/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.search.aggregations.metrics;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.common.Nullable;
import org.opensearch.common.hash.MurmurHash3;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.BitArray;
import org.opensearch.common.util.BitMixer;
import org.opensearch.common.util.LongArray;
import org.opensearch.common.util.ObjectArray;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
import org.opensearch.index.fielddata.SortedNumericDoubleValues;
import org.opensearch.search.SearchService;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.aggregations.metrics.AbstractHyperLogLogPlusPlus;
import org.opensearch.search.aggregations.metrics.HyperLogLogPlusPlus;
import org.opensearch.search.aggregations.metrics.InternalCardinality;
import org.opensearch.search.aggregations.metrics.NumericMetricsAggregator;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;

/**
 * Approximates the number of distinct values of a field with a {@link HyperLogLogPlusPlus} sketch.
 * <p>
 * A collector is picked per segment: numeric sources hash every value directly, byte sources backed
 * by ordinals may first record visited ordinals and hash each distinct ordinal once, and any other
 * byte source hashes every value. For a top-level aggregation without sub-aggregations, missing value
 * or script, a segment whose field has at most the configured number of terms may be counted eagerly
 * ("dynamic pruning"), visiting only documents that still contain a term that has not been counted.
 */
public class CardinalityAggregator extends NumericMetricsAggregator.SingleValue {
    private static final Logger logger = LogManager.getLogger(CardinalityAggregator.class);

    private final int precision;
    /** {@code null} when the config has no values; {@link #counts} is then {@code null} as well. */
    private final ValuesSource valuesSource;
    private final ValuesSourceConfig valuesSourceConfig;
    @Nullable
    private HyperLogLogPlusPlus counts;
    /** Collector of the current segment; post-collected and closed when the next segment starts. */
    private Collector collector;

    // Counters reported through collectDebugInfo.
    private int emptyCollectorsUsed;
    private int numericCollectorsUsed;
    private int ordinalsCollectorsUsed;
    private int ordinalsCollectorsOverheadTooHigh;
    private int stringHashingCollectorsUsed;
    private int dynamicPrunedSegments;

    public CardinalityAggregator(String name, ValuesSourceConfig valuesSourceConfig, int precision, SearchContext context, Aggregator parent, Map<String, Object> metadata) throws IOException {
        super(name, context, parent, metadata);
        this.valuesSource = valuesSourceConfig.hasValues() ? valuesSourceConfig.getValuesSource() : null;
        this.precision = precision;
        this.counts = this.valuesSource == null ? null : new HyperLogLogPlusPlus(precision, context.bigArrays(), 1L);
        this.valuesSourceConfig = valuesSourceConfig;
    }

    @Override
    public ScoreMode scoreMode() {
        boolean needsScores = this.valuesSource != null && this.valuesSource.needsScores();
        return needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
    }

    /**
     * Chooses the collector for one segment and bumps the matching debug counter.
     * <p>
     * When the segment is eligible for dynamic pruning and pruning succeeds, the segment has already
     * been counted when this returns, and a no-op collector is returned instead.
     */
    private Collector pickCollector(LeafReaderContext ctx) throws IOException {
        if (this.valuesSource == null) {
            ++this.emptyCollectorsUsed;
            return new EmptyCollector();
        }
        if (this.valuesSource instanceof ValuesSource.Numeric) {
            ValuesSource.Numeric source = (ValuesSource.Numeric) this.valuesSource;
            MurmurHash3Values hashValues = source.isFloatingPoint() || source.isBigInteger()
                ? MurmurHash3Values.hash(source.doubleValues(ctx))
                : MurmurHash3Values.hash(source.longValues(ctx));
            ++this.numericCollectorsUsed;
            return new DirectCollector(this.counts, hashValues);
        }

        Collector collector = null;
        if (this.valuesSource instanceof ValuesSource.Bytes.WithOrdinals) {
            ValuesSource.Bytes.WithOrdinals source = (ValuesSource.Bytes.WithOrdinals) this.valuesSource;
            SortedSetDocValues ordinalValues = source.ordinalsValues(ctx);
            long maxOrd = ordinalValues.getValueCount();
            if (maxOrd == 0L) {
                ++this.emptyCollectorsUsed;
                return new EmptyCollector();
            }
            long ordinalsMemoryUsage = OrdinalsCollector.memoryOverhead(maxOrd);
            long countsMemoryUsage = HyperLogLogPlusPlus.memoryUsage(this.precision);
            // Only track ordinals when that bookkeeping is small compared with the sketch itself.
            if (ordinalsMemoryUsage < countsMemoryUsage / 4L) {
                ++this.ordinalsCollectorsUsed;
                collector = new OrdinalsCollector(this.counts, ordinalValues, this.context.bigArrays());
            } else {
                ++this.ordinalsCollectorsOverheadTooHigh;
            }
        }
        if (collector == null) {
            ++this.stringHashingCollectorsUsed;
            collector = new DirectCollector(this.counts, MurmurHash3Values.hash(this.valuesSource.bytesValues(ctx)));
        }

        if (!this.canPrune(this.parent, this.subAggregators, this.valuesSourceConfig)) {
            return collector;
        }
        return this.tryPrune(ctx, collector);
    }

    /**
     * Tries to count the whole segment up front through a {@link PruningCollector} wrapping
     * {@code collector}. Returns a no-op collector when that succeeded, otherwise {@code collector}.
     */
    private Collector tryPrune(LeafReaderContext ctx, Collector collector) throws IOException {
        Terms terms = ctx.reader().terms(this.valuesSourceConfig.fieldContext().field());
        if (terms == null || this.exceedMaxThreshold(terms)) {
            return collector;
        }
        Collector pruningCollector = this.tryWrapWithPruningCollector(collector, terms, ctx);
        if (pruningCollector == null || !this.tryScoreWithPruningCollector(ctx, pruningCollector)) {
            return collector;
        }
        logger.debug("Dynamic pruned segment {} of shard {}", ctx.ord, this.context.indexShard().shardId());
        ++this.dynamicPrunedSegments;
        return this.getNoOpCollector();
    }

    /**
     * Pruning needs a plain field: no parent or sub-aggregations, no missing value, no script,
     * and a field context whose terms can be enumerated.
     */
    private boolean canPrune(Aggregator parent, Aggregator[] subAggregators, ValuesSourceConfig valuesSourceConfig) {
        return parent == null
            && subAggregators.length == 0
            && valuesSourceConfig.missing() == null
            && valuesSourceConfig.script() == null
            && valuesSourceConfig.fieldContext() != null;
    }

    /**
     * Returns {@code true} when the segment has too many terms to prune. An unknown term count
     * ({@link Terms#size()} returning -1) is treated as too many, so a threshold of 0 always
     * disables pruning and the number of per-term scorers stays bounded.
     */
    private boolean exceedMaxThreshold(Terms terms) throws IOException {
        long termsSize = terms.size();
        long threshold = this.context.cardinalityAggregationPruningThreshold();
        if (termsSize < 0L) {
            logger.debug("Cannot prune because terms size is unknown");
            return true;
        }
        if (termsSize > threshold) {
            logger.debug("Cannot prune because terms size {} is greater than the threshold {}", termsSize, threshold);
            return true;
        }
        return false;
    }

    /** Builds the pruning wrapper, or returns {@code null} (after logging) if that fails. */
    private Collector tryWrapWithPruningCollector(Collector collector, Terms terms, LeafReaderContext ctx) {
        try {
            return new PruningCollector(collector, terms.iterator(), ctx, this.context, this.valuesSourceConfig.fieldContext().field());
        } catch (Exception e) {
            logger.warn("Failed to build collector for dynamic pruning.", e);
            return null;
        }
    }

    /**
     * Runs the rewritten search query over {@code ctx} into {@code pruningCollector}, post-collects
     * and closes it.
     *
     * @return {@code false} if the query has no bulk scorer for this segment; the pruning collector
     *         is then left open because its delegate is still used for normal collection
     * @throws OpenSearchStatusException if scoring fails; the pruning collector (and its delegate)
     *         is released before throwing so its memory is not leaked
     */
    private boolean tryScoreWithPruningCollector(LeafReaderContext ctx, Collector pruningCollector) throws IOException {
        try {
            Weight weight = this.context.query()
                .rewrite((IndexSearcher) this.context.searcher())
                .createWeight((IndexSearcher) this.context.searcher(), ScoreMode.TOP_DOCS, 1.0f);
            BulkScorer scorer = weight.bulkScorer(ctx);
            if (scorer == null) {
                return false;
            }
            Bits liveDocs = ctx.reader().getLiveDocs();
            scorer.score((LeafCollector) pruningCollector, liveDocs);
            pruningCollector.postCollect();
        } catch (Exception e) {
            // The delegate is not tracked anywhere else yet, so release it here.
            Releasables.closeWhileHandlingException(pruningCollector);
            throw new OpenSearchStatusException(
                "Failed when performing dynamic pruning in cardinality aggregation. You can set cluster setting ["
                    + SearchService.CARDINALITY_AGGREGATION_PRUNING_THRESHOLD.getKey()
                    + "] to 0 to disable.",
                RestStatus.INTERNAL_SERVER_ERROR,
                e
            );
        }
        Releasables.close(pruningCollector);
        return true;
    }

    /** Collector for an already counted segment: ends collection of the segment on the first document. */
    private Collector getNoOpCollector() {
        return new Collector() {

            @Override
            public void close() {
            }

            @Override
            public void postCollect() throws IOException {
            }

            @Override
            public void collect(int doc, long owningBucketOrd) throws IOException {
                throw new CollectionTerminatedException();
            }
        };
    }

    @Override
    public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
        this.postCollectLastCollector();
        this.collector = this.pickCollector(ctx);
        return this.collector;
    }

    /** Post-collects and closes the previous segment's collector, if any; always clears it. */
    private void postCollectLastCollector() throws IOException {
        if (this.collector != null) {
            try {
                this.collector.postCollect();
            } finally {
                this.collector.close();
                this.collector = null;
            }
        }
    }

    @Override
    protected void doPostCollection() throws IOException {
        this.postCollectLastCollector();
    }

    @Override
    public double metric(long owningBucketOrd) {
        return this.counts == null ? 0.0 : (double) this.counts.cardinality(owningBucketOrd);
    }

    @Override
    public InternalAggregation buildAggregation(long owningBucketOrdinal) {
        if (this.counts == null || owningBucketOrdinal >= this.counts.maxOrd() || this.counts.cardinality(owningBucketOrdinal) == 0L) {
            return this.buildEmptyAggregation();
        }
        // Copy with a non-recycling BigArrays because the sketch outlives this aggregator.
        AbstractHyperLogLogPlusPlus copy = this.counts.clone(owningBucketOrdinal, BigArrays.NON_RECYCLING_INSTANCE);
        return new InternalCardinality(this.name, copy, this.metadata());
    }

    @Override
    public InternalAggregation buildEmptyAggregation() {
        return new InternalCardinality(this.name, null, this.metadata());
    }

    @Override
    protected void doClose() {
        Releasables.close(this.counts, this.collector);
    }

    @Override
    public void collectDebugInfo(BiConsumer<String, Object> add) {
        super.collectDebugInfo(add);
        add.accept("empty_collectors_used", this.emptyCollectorsUsed);
        add.accept("numeric_collectors_used", this.numericCollectorsUsed);
        add.accept("ordinals_collectors_used", this.ordinalsCollectorsUsed);
        add.accept("ordinals_collectors_overhead_too_high", this.ordinalsCollectorsOverheadTooHigh);
        add.accept("string_hashing_collectors_used", this.stringHashingCollectorsUsed);
        add.accept("dynamic_pruned_segments", this.dynamicPrunedSegments);
    }

    /**
     * Per-document iterator over 64-bit hashes of a document's values: doubles and longs are mixed
     * with {@link BitMixer#mix64}, bytes are hashed with MurmurHash3 (first 64 bits of the 128-bit hash).
     */
    static abstract class MurmurHash3Values {
        MurmurHash3Values() {
        }

        public abstract boolean advanceExact(int docId) throws IOException;

        public abstract int count();

        public abstract long nextValue() throws IOException;

        public static MurmurHash3Values hash(SortedNumericDoubleValues values) {
            return new HashedDoubles(values);
        }

        public static MurmurHash3Values hash(SortedNumericDocValues values) {
            return new HashedLongs(values);
        }

        public static MurmurHash3Values hash(SortedBinaryDocValues values) {
            return new HashedBytes(values);
        }

        private static class HashedBytes extends MurmurHash3Values {
            private final MurmurHash3.Hash128 hash = new MurmurHash3.Hash128();
            private final SortedBinaryDocValues values;

            HashedBytes(SortedBinaryDocValues values) {
                this.values = values;
            }

            @Override
            public boolean advanceExact(int docId) throws IOException {
                return this.values.advanceExact(docId);
            }

            @Override
            public int count() {
                return this.values.docValueCount();
            }

            @Override
            public long nextValue() throws IOException {
                BytesRef bytes = this.values.nextValue();
                MurmurHash3.hash128(bytes.bytes, bytes.offset, bytes.length, 0L, this.hash);
                return this.hash.h1;
            }
        }

        private static class HashedDoubles extends MurmurHash3Values {
            private final SortedNumericDoubleValues values;

            HashedDoubles(SortedNumericDoubleValues values) {
                this.values = values;
            }

            @Override
            public boolean advanceExact(int docId) throws IOException {
                return this.values.advanceExact(docId);
            }

            @Override
            public int count() {
                return this.values.docValueCount();
            }

            @Override
            public long nextValue() throws IOException {
                return BitMixer.mix64(Double.doubleToLongBits(this.values.nextValue()));
            }
        }

        private static class HashedLongs extends MurmurHash3Values {
            private final SortedNumericDocValues values;

            HashedLongs(SortedNumericDocValues values) {
                this.values = values;
            }

            @Override
            public boolean advanceExact(int docId) throws IOException {
                return this.values.advanceExact(docId);
            }

            @Override
            public int count() {
                return this.values.docValueCount();
            }

            @Override
            public long nextValue() throws IOException {
                return BitMixer.mix64(this.values.nextValue());
            }
        }
    }

    /**
     * Records which ordinals each bucket saw during collection. In {@link #postCollect()} it hashes
     * every visited ordinal once and feeds the hashes into the sketch.
     */
    private static class OrdinalsCollector extends Collector {
        private static final long SHALLOW_FIXEDBITSET_SIZE = RamUsageEstimator.shallowSizeOfInstance(FixedBitSet.class);
        private final BigArrays bigArrays;
        private final SortedSetDocValues values;
        private final int maxOrd;
        private final HyperLogLogPlusPlus counts;
        private ObjectArray<BitArray> visitedOrds;

        /** Estimated bytes needed to track {@code maxOrd} ordinals as a bit set (one bit per ordinal). */
        public static long memoryOverhead(long maxOrd) {
            return (long) RamUsageEstimator.NUM_BYTES_OBJECT_REF + SHALLOW_FIXEDBITSET_SIZE + (maxOrd + 7L) / 8L;
        }

        OrdinalsCollector(HyperLogLogPlusPlus counts, SortedSetDocValues values, BigArrays bigArrays) {
            long valueCount = values.getValueCount();
            if (valueCount > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Cannot use ordinals collector with more than Integer.MAX_VALUE ordinals: " + valueCount);
            }
            this.maxOrd = (int) valueCount;
            this.bigArrays = bigArrays;
            this.counts = counts;
            this.values = values;
            this.visitedOrds = bigArrays.newObjectArray(1L);
        }

        @Override
        public void collect(int doc, long bucketOrd) throws IOException {
            this.visitedOrds = this.bigArrays.grow(this.visitedOrds, bucketOrd + 1L);
            BitArray bits = this.visitedOrds.get(bucketOrd);
            if (bits == null) {
                bits = new BitArray(this.maxOrd, this.bigArrays);
                this.visitedOrds.set(bucketOrd, bits);
            }
            if (this.values.advanceExact(doc)) {
                for (long ord = this.values.nextOrd(); ord != -1L; ord = this.values.nextOrd()) {
                    bits.set((int) ord);
                }
            }
        }

        @Override
        public void postCollect() throws IOException {
            try (BitArray allVisitedOrds = new BitArray(this.maxOrd, this.bigArrays)) {
                // Union of all buckets, so each distinct ordinal is looked up and hashed only once.
                for (long bucket = this.visitedOrds.size() - 1L; bucket >= 0L; --bucket) {
                    BitArray bits = this.visitedOrds.get(bucket);
                    if (bits != null) {
                        allVisitedOrds.or(bits);
                    }
                }
                try (LongArray hashes = this.bigArrays.newLongArray(this.maxOrd, false)) {
                    MurmurHash3.Hash128 hash = new MurmurHash3.Hash128();
                    long ord = allVisitedOrds.nextSetBit(0L);
                    while (ord < Long.MAX_VALUE) {
                        BytesRef value = this.values.lookupOrd(ord);
                        MurmurHash3.hash128(value.bytes, value.offset, value.length, 0L, hash);
                        hashes.set(ord, hash.h1);
                        ord = ord + 1L < (long) this.maxOrd ? allVisitedOrds.nextSetBit(ord + 1L) : Long.MAX_VALUE;
                    }
                    for (long bucket = this.visitedOrds.size() - 1L; bucket >= 0L; --bucket) {
                        BitArray bits = this.visitedOrds.get(bucket);
                        if (bits == null) {
                            continue;
                        }
                        long bucketOrd = bits.nextSetBit(0L);
                        while (bucketOrd < Long.MAX_VALUE) {
                            this.counts.collect(bucket, hashes.get(bucketOrd));
                            bucketOrd = bucketOrd + 1L < (long) this.maxOrd ? bits.nextSetBit(bucketOrd + 1L) : Long.MAX_VALUE;
                        }
                    }
                }
            }
        }

        @Override
        public void close() {
            // long index: the ObjectArray is sized by long bucket ordinals.
            for (long i = 0L; i < this.visitedOrds.size(); ++i) {
                Releasables.close(this.visitedOrds.get(i));
            }
            Releasables.close(this.visitedOrds);
        }
    }

    /** Feeds the hash of every value of every collected document straight into the sketch. */
    private static class DirectCollector extends Collector {
        private final MurmurHash3Values hashes;
        private final HyperLogLogPlusPlus counts;

        DirectCollector(HyperLogLogPlusPlus counts, MurmurHash3Values values) {
            this.counts = counts;
            this.hashes = values;
        }

        @Override
        public void collect(int doc, long bucketOrd) throws IOException {
            if (this.hashes.advanceExact(doc)) {
                int valueCount = this.hashes.count();
                for (int i = 0; i < valueCount; ++i) {
                    this.counts.collect(bucketOrd, this.hashes.nextValue());
                }
            }
        }

        @Override
        public void postCollect() {
        }

        @Override
        public void close() {
        }
    }

    /** Collector for segments without values: ignores every document. */
    private static class EmptyCollector extends Collector {
        private EmptyCollector() {
        }

        @Override
        public void collect(int doc, long bucketOrd) {
        }

        @Override
        public void postCollect() {
        }

        @Override
        public void close() {
        }
    }

    /**
     * Competitive iterator over the union of the remaining per-term postings in {@code queue}.
     * Only {@link #advance(int)} and {@link #docID()} are supported.
     */
    private static class DisjunctionDISI extends DocIdSetIterator {
        private final DisiPriorityQueue queue;
        private int slowDocId = -1;

        public DisjunctionDISI(DisiPriorityQueue queue) {
            this.queue = queue;
        }

        @Override
        public int docID() {
            return this.slowDocId;
        }

        @Override
        public int advance(int target) throws IOException {
            DisiWrapper top = this.queue.top();
            if (top == null) {
                // every term was pruned: nothing left to visit in this segment
                this.slowDocId = Integer.MAX_VALUE;
                return Integer.MAX_VALUE;
            }
            if (top.doc >= target) {
                this.slowDocId = top.doc;
                return top.doc;
            }
            do {
                top.doc = top.approximation.advance(target);
                top = this.queue.updateTop();
            } while (top.doc < target);
            // advance never pops, so the queue is still non-empty here
            this.slowDocId = top.doc;
            return this.slowDocId;
        }

        @Override
        public int nextDoc() {
            throw new UnsupportedOperationException();
        }

        @Override
        public long cost() {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Wraps a collector and exposes a competitive iterator limited to documents that still contain
     * at least one term not yet collected. Each collected document removes the postings positioned
     * on it from the queue.
     */
    private static class PruningCollector extends Collector {
        private final Collector delegate;
        private final DisiPriorityQueue queue;
        private final DocIdSetIterator competitiveIterator;

        PruningCollector(Collector delegate, TermsEnum terms, LeafReaderContext ctx, SearchContext context, String field) throws IOException {
            this.delegate = delegate;
            // Terms from a TermsEnum are unique, so no dedup map is needed. The enum may reuse the
            // returned BytesRef, so it must not be retained (e.g. as a map key) across next() calls.
            List<Scorer> scorers = new ArrayList<>();
            while (terms.next() != null) {
                BytesRef term = terms.term();
                TermQuery termQuery = new TermQuery(new Term(field, term));
                Weight subWeight = termQuery.createWeight((IndexSearcher) context.searcher(), ScoreMode.COMPLETE_NO_SCORES, 1.0f);
                Scorer scorer = subWeight.scorer(ctx);
                if (scorer != null) {
                    scorers.add(scorer);
                }
            }
            this.queue = new DisiPriorityQueue(scorers.size());
            for (Scorer scorer : scorers) {
                this.queue.add(new DisiWrapper(scorer));
            }
            this.competitiveIterator = new DisjunctionDISI(this.queue);
        }

        @Override
        public void close() {
            this.delegate.close();
        }

        @Override
        public void collect(int doc, long owningBucketOrd) throws IOException {
            this.delegate.collect(doc, owningBucketOrd);
            this.prune(doc);
        }

        /** Drops every term whose postings are positioned on {@code doc}: its value is now counted. */
        private void prune(int doc) {
            DisiWrapper top = this.queue.top();
            while (top != null && top.doc == doc) {
                this.queue.pop();
                top = this.queue.top();
            }
        }

        @Override
        public DocIdSetIterator competitiveIterator() {
            return this.competitiveIterator;
        }

        @Override
        public void postCollect() throws IOException {
            this.delegate.postCollect();
        }
    }

    /** Per-segment collector that is post-collected once the segment is done and then released. */
    private static abstract class Collector extends LeafBucketCollector implements Releasable {
        private Collector() {
        }

        public abstract void postCollect() throws IOException;
    }
}

