/*
 * Decompiled with CFR 0.152.
 */
package org.dflib.join;

import java.util.Map;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.function.UnaryOperator;
import org.dflib.ColumnDataFrame;
import org.dflib.DataFrame;
import org.dflib.Exp;
import org.dflib.Hasher;
import org.dflib.Index;
import org.dflib.IntSeries;
import org.dflib.JoinType;
import org.dflib.Series;
import org.dflib.builder.ObjectAccum;
import org.dflib.join.HashJoiner;
import org.dflib.join.JoinIndex;
import org.dflib.join.JoinIndicator;
import org.dflib.join.JoinPredicate;
import org.dflib.join.NestedLoopJoiner;
import org.dflib.series.IndexedSeries;
import org.dflib.series.SingleValueSeries;

/**
 * Builder and executor of a join between two DataFrames.
 *
 * <p>The join condition is configured either with one of the {@code on(..)} methods (a hash join on
 * key {@link Hasher}s) or with {@link #predicatedBy(JoinPredicate)} (a nested-loop join). The two are
 * mutually exclusive: configuring one discards the other. Result columns can optionally be narrowed with
 * {@code cols(..)} / {@code colsExcept(..)}. The join is executed by one of the {@code select(..)} /
 * {@code selectAs(..)} methods.
 *
 * <p>Configuration methods mutate this object and return {@code this} for chaining.
 */
public class Join {
    private final JoinType type;
    private final DataFrame leftFrame;
    private final DataFrame rightFrame;

    // Hash-join keys. Reset to null by predicatedBy(..); null until an on(..) method is called.
    private Hasher leftHasher;
    private Hasher rightHasher;

    // Nested-loop join condition. Reset to null by any on(..) call.
    private JoinPredicate predicate;

    // Optional indicator column name; null means no indicator column is produced.
    private String indicatorColumn;

    // Set once the caller explicitly chose result columns via cols(..) / colsExcept(..).
    private boolean userColumns;

    // Narrows the default JoinIndex to the requested columns; identity keeps the default index unchanged.
    private UnaryOperator<JoinIndex> colSelector = UnaryOperator.identity();

    /**
     * Creates a join of the given type between two frames. No join condition is set yet; one must be
     * configured before calling a {@code select(..)} method.
     *
     * @param type       the join type passed to the underlying joiner
     * @param leftFrame  the left side of the join
     * @param rightFrame the right side of the join
     */
    public Join(JoinType type, DataFrame leftFrame, DataFrame rightFrame) {
        this.type = type;
        this.leftFrame = leftFrame;
        this.rightFrame = rightFrame;
    }

    /**
     * Adds a hash-join condition on the column at the same position in both frames.
     */
    public Join on(int columnsIndex) {
        return on(columnsIndex, columnsIndex);
    }

    /**
     * Adds a hash-join condition matching a left column position against a right column position.
     */
    public Join on(int leftColumn, int rightColumn) {
        return on(Hasher.of(leftColumn), Hasher.of(rightColumn));
    }

    /**
     * Adds a hash-join condition on the column with the given label in both frames.
     */
    public Join on(String column) {
        return on(column, column);
    }

    /**
     * Adds a hash-join condition matching a left column label against a right column label.
     */
    public Join on(String leftColumn, String rightColumn) {
        return on(Hasher.of(leftColumn), Hasher.of(rightColumn));
    }

    /**
     * Adds a hash-join condition that uses the same hasher on both sides.
     */
    public Join on(Hasher hasher) {
        return on(hasher, hasher);
    }

    /**
     * Adds a hash-join condition. Repeated calls accumulate: each new hasher is combined with the
     * previously configured one via {@link Hasher#and}. Any predicate set earlier with
     * {@link #predicatedBy(JoinPredicate)} is discarded.
     *
     * @throws NullPointerException if {@code left} or {@code right} is null
     */
    public Join on(Hasher left, Hasher right) {
        this.leftHasher = combineHashers(this.leftHasher, left);
        this.rightHasher = combineHashers(this.rightHasher, right);
        this.predicate = null;
        return this;
    }

    /**
     * Switches this join to a nested-loop join driven by the given predicate, discarding any hashers
     * configured earlier with {@code on(..)}.
     */
    public Join predicatedBy(JoinPredicate predicate) {
        this.predicate = predicate;
        this.leftHasher = null;
        this.rightHasher = null;
        return this;
    }

    /**
     * Sets the name of an indicator column. The name is passed to {@link JoinIndex}. Values of the column
     * are {@link JoinIndicator} constants telling whether a row came from the left side only, the right
     * side only, or both.
     */
    public Join indicatorColumn(String name) {
        this.indicatorColumn = name;
        return this;
    }

    /**
     * Selects result columns by position. Replaces any earlier {@code cols(..)} / {@code colsExcept(..)}.
     */
    public Join cols(int... columns) {
        this.colSelector = i -> i.cols(columns);
        this.userColumns = true;
        return this;
    }

    /**
     * Selects result columns by label. Replaces any earlier {@code cols(..)} / {@code colsExcept(..)}.
     */
    public Join cols(String... columns) {
        this.colSelector = i -> i.cols(columns);
        this.userColumns = true;
        return this;
    }

    /**
     * Selects result columns whose labels match the condition. Replaces any earlier
     * {@code cols(..)} / {@code colsExcept(..)}.
     */
    public Join cols(Predicate<String> labelCondition) {
        this.colSelector = i -> i.cols(labelCondition);
        this.userColumns = true;
        return this;
    }

    /**
     * Selects all result columns except those at the given positions. Replaces any earlier
     * {@code cols(..)} / {@code colsExcept(..)}.
     */
    public Join colsExcept(int... columns) {
        this.colSelector = i -> i.colsExcept(columns);
        this.userColumns = true;
        return this;
    }

    /**
     * Selects all result columns except those with the given labels. Replaces any earlier
     * {@code cols(..)} / {@code colsExcept(..)}.
     */
    public Join colsExcept(String... columns) {
        this.colSelector = i -> i.colsExcept(columns);
        this.userColumns = true;
        return this;
    }

    /**
     * Selects all result columns except those whose labels match the condition. Replaces any earlier
     * {@code cols(..)} / {@code colsExcept(..)}.
     */
    public Join colsExcept(Predicate<String> labelCondition) {
        this.colSelector = i -> i.colsExcept(labelCondition);
        this.userColumns = true;
        return this;
    }

    /**
     * Executes the join and returns the selected columns under their default labels.
     *
     * @throws IllegalStateException if no join condition was configured
     */
    public DataFrame select() {
        JoinIndex index = selectedJoinIndex();
        return new ColumnDataFrame(null, index.getIndex(), joinColumns(index));
    }

    /**
     * Executes the join and returns the selected columns with labels transformed by {@code renamer}.
     *
     * @throws IllegalStateException if no join condition was configured
     */
    public DataFrame selectAs(UnaryOperator<String> renamer) {
        JoinIndex index = selectedJoinIndex();
        return new ColumnDataFrame(null, index.getIndex().replace(renamer), joinColumns(index));
    }

    /**
     * Executes the join and returns the selected columns labeled with {@code newColumnNames}, assigned in
     * order.
     *
     * @throws IllegalArgumentException if the number of names differs from the number of selected columns
     * @throws IllegalStateException     if no join condition was configured
     */
    public DataFrame selectAs(String... newColumnNames) {
        JoinIndex index = selectedJoinIndex();
        int width = index.getPositions().length;

        // Validate before running the (potentially expensive) join; a mismatch would otherwise produce a
        // frame whose label index and column array disagree in size.
        if (newColumnNames.length != width) {
            throw new IllegalArgumentException("Can't perform 'selectAs': String[] size is different from the ColumnSet size: "
                    + newColumnNames.length + " vs. " + width);
        }

        return new ColumnDataFrame(null, Index.of(newColumnNames), joinColumns(index));
    }

    /**
     * Executes the join and returns the selected columns with labels renamed per {@code oldToNewNames}.
     *
     * @throws IllegalStateException if no join condition was configured
     */
    public DataFrame selectAs(Map<String, String> oldToNewNames) {
        JoinIndex index = selectedJoinIndex();
        return new ColumnDataFrame(null, index.getIndex().replace(oldToNewNames), joinColumns(index));
    }

    /**
     * Executes the join and evaluates {@code exps} against the joined frame. The intermediate frame has its
     * column aliases expanded via {@link JoinIndex#colsExpandAliases()}. If result columns were chosen with
     * {@code cols(..)} / {@code colsExcept(..)}, the expressions produce exactly those columns, so their
     * count must match the selection size.
     *
     * @throws IllegalArgumentException if explicit columns were selected and {@code exps.length} differs
     *                                  from the selection size
     * @throws IllegalStateException    if no join condition was configured
     */
    public DataFrame select(Exp<?>... exps) {
        JoinIndex defaultIndex = defaultJoinIndex();
        JoinIndex resultIndex = colSelector.apply(defaultIndex);

        if (userColumns && exps.length != resultIndex.size()) {
            throw new IllegalArgumentException("Can't perform 'select': Exp[] size is different from the ColumnSet size: "
                    + exps.length + " vs. " + resultIndex.size());
        }

        JoinIndex allAliasesIndex = defaultIndex.colsExpandAliases();

        // Materialize each distinct default column once, then fan it out to every alias position.
        Series<?>[] uniqueColumns = joinColumns(defaultIndex);
        ColumnDataFrame allAliasesDf = new ColumnDataFrame(
                null,
                allAliasesIndex.getIndex(),
                pick(uniqueColumns, allAliasesIndex.getPositions()));

        return userColumns
                ? allAliasesDf.cols(resultIndex.getIndex()).select(exps)
                : allAliasesDf.cols().select(exps);
    }

    // Builds the full (unselected) join index from both frames' names and column indices plus the
    // optional indicator column.
    private JoinIndex defaultJoinIndex() {
        return JoinIndex.of(
                leftFrame.getName(),
                rightFrame.getName(),
                leftFrame.getColumnsIndex(),
                rightFrame.getColumnsIndex(),
                indicatorColumn);
    }

    // Default join index narrowed by the user's cols(..) / colsExcept(..) choice, if any.
    private JoinIndex selectedJoinIndex() {
        return colSelector.apply(defaultJoinIndex());
    }

    // Runs the join and materializes the columns at the index's positions.
    private Series<?>[] joinColumns(JoinIndex index) {
        IntSeries[] selectors = rowSelectors();
        return merge(selectors[0], selectors[1], index.getPositions());
    }

    // ANDs a new hasher onto an existing one, or starts a new chain when none exists yet.
    private Hasher combineHashers(Hasher possiblyNull, Hasher mustBeNotNull) {
        Objects.requireNonNull(mustBeNotNull);
        return possiblyNull != null ? possiblyNull.and(mustBeNotNull) : mustBeNotNull;
    }

    /**
     * Runs the configured joiner and returns two parallel row-position series: element 0 indexes rows of
     * the left frame, element 1 rows of the right frame. A negative position marks a missing side (this is
     * how {@link #buildIndicator} interprets it).
     *
     * @throws IllegalStateException if neither a predicate nor both hashers are configured
     */
    private IntSeries[] rowSelectors() {
        if (predicate != null) {
            return new NestedLoopJoiner(predicate, type).rowSelectors(leftFrame, rightFrame);
        }

        if (leftHasher != null && rightHasher != null) {
            return new HashJoiner(leftHasher, rightHasher, type).rowSelectors(leftFrame, rightFrame);
        }

        throw new IllegalStateException("No join condition set. Either join columns, Hashers or a predicate must be specified");
    }

    /**
     * Builds result columns for the given positions. Positions address a combined column space:
     * {@code [0, leftWidth)} are left-frame columns, {@code [leftWidth, leftWidth + rightWidth)} are
     * right-frame columns, and {@code leftWidth + rightWidth} is the indicator column when one is
     * configured. Any other position yields an all-null column of the result height.
     */
    private Series<?>[] merge(IntSeries leftIndex, IntSeries rightIndex, int[] positions) {
        int leftWidth = leftFrame.width();
        int combinedWidth = leftWidth + rightFrame.width();
        int height = leftIndex.size();
        int width = positions.length;

        Series<?>[] data = new Series<?>[width];
        for (int i = 0; i < width; i++) {
            int si = positions[i];

            if (si < leftWidth) {
                data[i] = new IndexedSeries<>(leftFrame.getColumn(si), leftIndex);
            } else if (si < combinedWidth) {
                data[i] = new IndexedSeries<>(rightFrame.getColumn(si - leftWidth), rightIndex);
            } else if (si == combinedWidth && indicatorColumn != null) {
                data[i] = buildIndicator(leftIndex, rightIndex);
            } else {
                data[i] = new SingleValueSeries<>(null, height);
            }
        }

        return data;
    }

    // Returns cols reordered/duplicated according to positions (cols[positions[i]] for each i).
    private Series<?>[] pick(Series<?>[] cols, int[] positions) {
        int width = positions.length;
        Series<?>[] data = new Series<?>[width];
        for (int i = 0; i < width; i++) {
            data[i] = cols[positions[i]];
        }
        return data;
    }

    // One JoinIndicator per result row: right_only if the left position is negative, left_only if the
    // right position is negative, otherwise both.
    private Series<JoinIndicator> buildIndicator(IntSeries leftIndex, IntSeries rightIndex) {
        int height = leftIndex.size();
        ObjectAccum<JoinIndicator> accum = new ObjectAccum<>(height);
        for (int i = 0; i < height; i++) {
            JoinIndicator indicator;
            if (leftIndex.getInt(i) < 0) {
                indicator = JoinIndicator.right_only;
            } else if (rightIndex.getInt(i) < 0) {
                indicator = JoinIndicator.left_only;
            } else {
                indicator = JoinIndicator.both;
            }
            accum.push(indicator);
        }
        return accum.toSeries();
    }
}

