/*
 * Decompiled with CFR 0.152.
 */
package org.dflib;

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Predicate;
import org.dflib.ColumnDataFrame;
import org.dflib.DataFrame;
import org.dflib.Exp;
import org.dflib.Index;
import org.dflib.IntSeries;
import org.dflib.JoinType;
import org.dflib.Series;
import org.dflib.Sorter;
import org.dflib.agg.GroupByAggregator;
import org.dflib.concat.SeriesConcat;
import org.dflib.concat.VConcat;
import org.dflib.exp.Exps;
import org.dflib.series.EmptySeries;
import org.dflib.slice.FixedColumnSetIndex;
import org.dflib.sort.GroupBySorter;
import org.dflib.sort.IntComparator;
import org.dflib.window.DenseRanker;
import org.dflib.window.Ranker;

/**
 * A grouping of the rows of a source {@link DataFrame}. Each group is identified by a key and described by an
 * {@link IntSeries} of row positions within the source. A GroupBy may optionally carry a row comparator
 * ({@code sorter}, used by {@link #rank()} / {@link #denseRank()}) and a column set (configured via the
 * {@code cols(..)} / {@code colsExcept(..)} methods) that {@link #select()}, {@link #select(Exp[])},
 * {@link #merge(Exp[])} and {@link #agg(Exp[])} apply to their output.
 */
public class GroupBy {

    // ConcurrentHashMap does not permit null keys, while "groupsIndex" is an arbitrary Map that may contain a
    // null group key. This private object stands in for "null" in the group cache.
    private static final Object NULL_KEY = new Object();

    private final DataFrame source;
    private final Map<Object, IntSeries> groupsIndex;
    private final IntComparator sorter;
    private final ConcurrentMap<Object, DataFrame> groupsCache;
    private final FixedColumnSetIndex columnSetIndex;

    /**
     * Creates a GroupBy with no column set.
     *
     * @param source      the DataFrame whose rows are grouped
     * @param groupsIndex group key to the row positions (in {@code source}) of that group
     * @param sorter      row comparator used for ranking; may be null
     */
    public GroupBy(DataFrame source, Map<Object, IntSeries> groupsIndex, IntComparator sorter) {
        this(source, groupsIndex, sorter, null);
    }

    /**
     * Creates a GroupBy with an optional column set.
     *
     * @param columnSetIndex column set applied by select/merge/agg; null means "no column set"
     */
    protected GroupBy(DataFrame source, Map<Object, IntSeries> groupsIndex, IntComparator sorter, FixedColumnSetIndex columnSetIndex) {
        this.source = source;
        this.groupsIndex = groupsIndex;
        this.sorter = sorter;
        this.columnSetIndex = columnSetIndex;
        this.groupsCache = new ConcurrentHashMap<>();
    }

    /**
     * Returns the number of groups.
     */
    public int size() {
        return this.groupsIndex.size();
    }

    /**
     * Returns the DataFrame whose rows are grouped.
     */
    public DataFrame getSource() {
        return this.source;
    }

    /**
     * Returns a copy of this GroupBy with a column set made of the source columns whose labels match the predicate.
     */
    public GroupBy cols(Predicate<String> colsPredicate) {
        Index srcIndex = this.source.getColumnsIndex();
        return new GroupBy(this.source, this.groupsIndex, this.sorter, FixedColumnSetIndex.of(srcIndex, srcIndex.positions(colsPredicate)));
    }

    /**
     * Returns a copy of this GroupBy with a column set made of the given column labels.
     */
    public GroupBy cols(String... cols) {
        return new GroupBy(this.source, this.groupsIndex, this.sorter, FixedColumnSetIndex.of(cols));
    }

    /**
     * Returns a copy of this GroupBy with a column set made of the given source column positions.
     */
    public GroupBy cols(int... cols) {
        Index srcIndex = this.source.getColumnsIndex();
        return new GroupBy(this.source, this.groupsIndex, this.sorter, FixedColumnSetIndex.of(srcIndex, cols));
    }

    /**
     * Returns a copy of this GroupBy with a column set made of all source columns except the given labels.
     */
    public GroupBy colsExcept(String... cols) {
        return this.cols(this.source.getColumnsIndex().positionsExcept(cols));
    }

    /**
     * Returns a copy of this GroupBy with a column set made of all source columns except the given positions.
     */
    public GroupBy colsExcept(int... cols) {
        return this.cols(this.source.getColumnsIndex().positionsExcept(cols));
    }

    /**
     * Returns a copy of this GroupBy with a column set made of the source columns whose labels do NOT match
     * the predicate.
     */
    public GroupBy colsExcept(Predicate<String> colsPredicate) {
        return this.cols(colsPredicate.negate());
    }

    /**
     * Returns the group keys. This is the key set of the internal groups map, not a copy.
     */
    public Set<Object> getGroupKeys() {
        return this.groupsIndex.keySet();
    }

    /**
     * Returns whether a group with the given key exists.
     */
    public boolean hasGroup(Object key) {
        return this.groupsIndex.containsKey(key);
    }

    /**
     * Returns the row positions (in the source DataFrame) of the given group, or null if there is no such group.
     */
    public IntSeries getGroupIndex(Object key) {
        return this.groupsIndex.get(key);
    }

    /**
     * Returns the rows of the given group as a DataFrame with all source columns, or null if there is no such
     * group. Resolved groups are cached, so repeated calls for the same key return the same DataFrame.
     * A null key is supported.
     */
    public DataFrame getGroup(Object key) {
        // map null to a sentinel, as the ConcurrentHashMap cache rejects null keys
        Object cacheKey = key != null ? key : NULL_KEY;

        // if resolveGroup returns null (unknown key), computeIfAbsent records no mapping and returns null
        return this.groupsCache.computeIfAbsent(cacheKey, k -> this.resolveGroup(key));
    }

    /**
     * Ranks the rows within each group using this GroupBy's sorter. Returns an empty series when there are no
     * groups. When there is no sorter, the result comes from {@code Ranker.sameRank(source.height())}.
     */
    public IntSeries rank() {
        if (this.groupsIndex.isEmpty()) {
            return Series.ofInt(new int[0]);
        }
        if (this.sorter == null) {
            return Ranker.sameRank(this.source.height());
        }
        return new Ranker(this.sorter).rank(this.source, this.groupsIndex.values());
    }

    /**
     * Like {@link #rank()}, but uses {@link DenseRanker} when a sorter is present.
     */
    public IntSeries denseRank() {
        if (this.groupsIndex.isEmpty()) {
            return Series.ofInt(new int[0]);
        }
        if (this.sorter == null) {
            return Ranker.sameRank(this.source.height());
        }
        return new DenseRanker(this.sorter).rank(this.source, this.groupsIndex.values());
    }

    /**
     * Shifts the values of a column within each group. See {@link #shift(int, int, Object)}.
     */
    public <T> Series<T> shift(String column, int offset, T filler) {
        int pos = this.source.getColumnsIndex().position(column);
        return this.shift(pos, offset, filler);
    }

    /**
     * Shifts the values of a source column by {@code offset} within each group, filling vacated slots with
     * {@code filler}. The per-group results are reassembled in the order of the source row positions covered
     * by the groups.
     *
     * @param column source column position
     * @param offset shift offset passed to {@code Series.shift}
     * @param filler value used for positions left empty by the shift
     * @return the shifted values, or an empty series if there are no groups
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <T> Series<T> shift(int column, int offset, T filler) {
        if (this.groupsIndex.isEmpty()) {
            return new EmptySeries();
        }

        // only the one column is needed, so select it per group directly instead of materializing
        // every column of every group via getGroup(..)
        Series sourceColumn = this.source.getColumn(column);
        Series[] shifted = new Series[this.groupsIndex.size()];
        int i = 0;

        // iterate "values()" - the same collection (and order) that intConcat walks below, so the
        // concatenated shifted values line up with the concatenated row positions
        for (IntSeries groupRows : this.groupsIndex.values()) {
            shifted[i++] = sourceColumn.select(groupRows).shift(offset, filler);
        }

        IntSeries groupsIndexAll = SeriesConcat.intConcat(this.groupsIndex.values());
        Series shiftedAll = SeriesConcat.concat(shifted);

        // reorder from "group order" back to the order of the source row positions
        return shiftedAll.select(groupsIndexAll.sortIndexInt());
    }

    /**
     * Returns a new GroupBy where each group is trimmed using {@code IntSeries.head(len)}. The sorter and the
     * column set of this GroupBy are preserved. {@code head(0)} produces empty groups.
     */
    public GroupBy head(int len) {
        // no shortcut for len == 0: trimming to zero rows must produce empty groups, not the untrimmed original
        Map<Object, IntSeries> trimmed = new LinkedHashMap<>((int) (this.groupsIndex.size() / 0.75) + 1);
        for (Map.Entry<Object, IntSeries> e : this.groupsIndex.entrySet()) {
            trimmed.put(e.getKey(), e.getValue().head(len));
        }
        return new GroupBy(this.source, trimmed, this.sorter, this.columnSetIndex);
    }

    /**
     * Returns a new GroupBy where each group is trimmed using {@code IntSeries.tail(len)}. The sorter and the
     * column set of this GroupBy are preserved. {@code tail(0)} produces empty groups.
     */
    public GroupBy tail(int len) {
        // no shortcut for len == 0: trimming to zero rows must produce empty groups, not the untrimmed original
        Map<Object, IntSeries> trimmed = new LinkedHashMap<>((int) (this.groupsIndex.size() / 0.75) + 1);
        for (Map.Entry<Object, IntSeries> e : this.groupsIndex.entrySet()) {
            trimmed.put(e.getKey(), e.getValue().tail(len));
        }
        return new GroupBy(this.source, trimmed, this.sorter, this.columnSetIndex);
    }

    public GroupBy sort(Sorter... sorters) {
        return new GroupBySorter(this).sort(sorters);
    }

    public GroupBy sort(IntComparator sorter) {
        return new GroupBySorter(this).sort(sorter);
    }

    public GroupBy sort(String column, boolean ascending) {
        return new GroupBySorter(this).sort(column, ascending);
    }

    public GroupBy sort(int column, boolean ascending) {
        return new GroupBySorter(this).sort(column, ascending);
    }

    public GroupBy sort(String[] columns, boolean[] ascending) {
        return new GroupBySorter(this).sort(columns, ascending);
    }

    public GroupBy sort(int[] columns, boolean[] ascending) {
        return new GroupBySorter(this).sort(columns, ascending);
    }

    /**
     * Returns a DataFrame with the rows of all groups concatenated in group order, restricted to the column set
     * if one is defined.
     */
    public DataFrame select() {
        IntSeries index = SeriesConcat.intConcat(this.groupsIndex.values());
        return this.columnSetIndex != null
                ? this.source.rows(index).cols(this.columnSetIndex.getIndex()).select()
                : this.source.rows(index).select();
    }

    /**
     * Evaluates the expressions against each group's rows (through the column set, if defined) and vertically
     * concatenates the per-group results with an inner join on columns.
     */
    public DataFrame select(Exp<?>... exps) {
        int len = this.groupsIndex.size();
        DataFrame[] dfs = new DataFrame[len];
        int i = 0;
        if (this.columnSetIndex != null) {
            Index customIndex = this.columnSetIndex.getIndex();
            for (IntSeries gi : this.groupsIndex.values()) {
                dfs[i++] = this.source.rows(gi).cols(customIndex).select(exps);
            }
        } else {
            for (IntSeries gi : this.groupsIndex.values()) {
                dfs[i++] = this.source.rows(gi).cols().select(exps);
            }
        }
        return VConcat.concat(JoinType.inner, dfs);
    }

    /**
     * Merges the expression results into each group's rows (through the column set, if defined) and vertically
     * concatenates the per-group results with an inner join on columns.
     */
    public DataFrame merge(Exp<?>... exps) {
        int len = this.groupsIndex.size();
        DataFrame[] dfs = new DataFrame[len];
        int i = 0;
        if (this.columnSetIndex != null) {
            Index customIndex = this.columnSetIndex.getIndex();
            for (IntSeries gi : this.groupsIndex.values()) {
                dfs[i++] = this.source.rows(gi).select().cols(customIndex).merge(exps);
            }
        } else {
            for (IntSeries gi : this.groupsIndex.values()) {
                dfs[i++] = this.source.rows(gi).select().cols().merge(exps);
            }
        }
        return VConcat.concat(JoinType.inner, dfs);
    }

    /**
     * Aggregates each group with the given expressions, producing one result column per expression.
     *
     * @throws IllegalArgumentException if a column set is defined and its size differs from the number of
     *                                  aggregators
     */
    public DataFrame agg(Exp<?>... aggregators) {
        Index index;
        if (this.columnSetIndex != null) {
            int w = aggregators.length;
            index = this.columnSetIndex.getIndex();
            if (w != index.size()) {
                throw new IllegalArgumentException("Can't perform 'agg': Exp[] size is different from the ColumnSet size: " + w + " vs. " + index.size());
            }
        } else {
            index = Exps.index(this.source, aggregators);
        }
        return new ColumnDataFrame(null, index, GroupByAggregator.agg(this, aggregators));
    }

    /**
     * Builds the DataFrame for a single group: every source column, restricted to the group's row positions.
     * Does not consult the group cache.
     *
     * @return the group DataFrame, or null if there is no group with this key
     */
    protected DataFrame resolveGroup(Object key) {
        IntSeries index = this.groupsIndex.get(key);
        if (index == null) {
            return null;
        }
        int w = this.source.width();
        Series[] data = new Series[w];
        for (int j = 0; j < w; ++j) {
            data[j] = this.source.getColumn(j).select(index);
        }
        return new ColumnDataFrame(null, this.source.getColumnsIndex(), data);
    }
}

