/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.util.hnsw;

import java.io.IOException;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

/**
 * An {@link HnswGraphSearcher} that searches the base level (level 0) of an HNSW graph while
 * honouring a filter ({@code acceptOrds}).
 *
 * <p>Only nodes accepted by the filter are scored and collected. When a large enough fraction of a
 * candidate's unvisited neighbours is rejected by the filter, the search also walks the neighbours
 * of those rejected nodes (a second hop) so that enough accepted nodes are still found and scored.
 */
public class FilteredHnswGraphSearcher extends HnswGraphSearcher {
    /**
     * Fraction of a candidate's first-hop neighbours that must be rejected by the filter before the
     * search expands to second-hop neighbours.
     */
    private static final float EXPANDED_EXPLORATION_LAMBDA = 0.1f;

    /**
     * End-of-iteration sentinel: returned by {@link HnswGraph#nextNeighbor()} when a node's
     * neighbours are exhausted, and by {@link IntArrayQueue#poll()} when the queue is empty.
     */
    private static final int NO_MORE_DOCS = Integer.MAX_VALUE;

    /** Multiplier applied to the per-candidate scoring/exploration budget; grows as the filter tightens. */
    private final int maxExplorationMultiplier;

    /**
     * While fewer than this many accepted nodes are queued for scoring, rejected second-hop
     * neighbours keep being queued for further exploration.
     */
    private final int minToScore;

    /**
     * Derives the exploration budget from how restrictive the filter is.
     *
     * @param candidates queue of candidate nodes, passed through to the superclass
     * @param visited bit set of visited nodes, passed through to the superclass
     * @param filterSize number of nodes accepted by the filter
     * @param graph the graph to be searched; must report {@code maxConn() > 0}
     */
    private FilteredHnswGraphSearcher(NeighborQueue candidates, BitSet visited, int filterSize, HnswGraph graph) {
        super(candidates, visited);
        assert graph.maxConn() > 0 : "graph must have known max connections";
        // NOTE(review): this ratio uses graph.size() while create() validates filterSize against
        // getGraphSize(graph). If the two ever differ the ratio is not guaranteed to be < 1 -- confirm.
        float filterRatio = (float) filterSize / (float) graph.size();
        // The more restrictive the filter, the wider we explore, capped at half the max connections.
        this.maxExplorationMultiplier =
                (int) Math.round(Math.min((double) (1.0f / filterRatio), (double) graph.maxConn() / 2.0));
        // Only very restrictive filters (1/ratio beyond 2 * maxConn) require a non-zero minimum; it is
        // capped at maxConn.
        this.minToScore = (int) Math.round(Math.min(
                Math.max(0.0, 1.0 / (double) filterRatio - 2.0 * (double) graph.maxConn()),
                (double) graph.maxConn()));
    }

    /**
     * Creates a filtered searcher for a top-{@code k} query.
     *
     * @param k number of nearest neighbours requested; sizes the candidate queue and visited set
     * @param graph the graph to search
     * @param filterSize number of nodes accepted by {@code acceptOrds}
     * @param acceptOrds the filter; must not be {@code null}
     * @return a new searcher
     * @throws IllegalArgumentException if {@code acceptOrds} is {@code null}, or if {@code filterSize}
     *     is not strictly between 0 and the graph size
     */
    public static FilteredHnswGraphSearcher create(int k, HnswGraph graph, int filterSize, Bits acceptOrds) {
        if (acceptOrds == null) {
            throw new IllegalArgumentException("acceptOrds must not be null to used filtered search");
        }
        int graphSize = getGraphSize(graph);
        if (filterSize <= 0 || filterSize >= graphSize) {
            throw new IllegalArgumentException("filterSize must be > 0 and < graph size");
        }
        return new FilteredHnswGraphSearcher(
                new NeighborQueue(k, true), bitSet(filterSize, getGraphSize(graph), k), filterSize, graph);
    }

    /**
     * Picks a visited-set implementation from an estimate of how many nodes the search will touch.
     *
     * <p>The estimate is {@code ln(graphSize) * topk} divided by the accepted fraction of the graph.
     */
    private static BitSet bitSet(long filterSize, int graphSize, int topk) {
        float percentFiltered = (float) filterSize / (float) graphSize;
        assert percentFiltered > 0.0f && percentFiltered < 1.0f;
        double totalOps = Math.log(graphSize) * (double) topk;
        int approximateVisitation = (int) (totalOps / (double) percentFiltered);
        return bitSet(approximateVisitation, graphSize);
    }

    /**
     * Returns a {@link SparseFixedBitSet} when fewer than 1/128 of {@code totalBits} are expected to
     * be set, and a dense {@link FixedBitSet} otherwise.
     */
    private static BitSet bitSet(int expectedBits, int totalBits) {
        if (expectedBits < (totalBits >>> 7)) {
            return new SparseFixedBitSet(totalBits);
        }
        return new FixedBitSet(totalBits);
    }

    /**
     * Greedy best-first search of level 0, scoring and collecting only filter-accepted nodes.
     *
     * @param results collector receiving accepted nodes and visit counts; its early termination stops
     *     the search
     * @param scorer scores a node ordinal against the query
     * @param level graph level; only 0 is supported
     * @param eps entry-point ordinals used to seed the search
     * @param graph the graph to traverse
     * @param acceptOrds filter deciding which ordinals may be scored-and-collected
     * @throws IOException if scoring or graph navigation fails
     */
    @Override
    void searchLevel(KnnCollector results, RandomVectorScorer scorer, int level, int[] eps, HnswGraph graph, Bits acceptOrds) throws IOException {
        assert level == 0 : "Filtered search only works on the base level";
        int size = getGraphSize(graph);
        prepareScratchState();

        // Seed candidates with every unvisited entry point; only filter-accepted ones are collected.
        for (int ep : eps) {
            if (visited.getAndSet(ep)) {
                continue;
            }
            if (results.earlyTerminated()) {
                return;
            }
            float score = scorer.score(ep);
            results.incVisitedCount(1);
            candidates.add(ep, score);
            if (acceptOrds.get(ep)) {
                results.collect(ep, score);
            }
        }

        int queueCapacity = graph.maxConn() * 2 * maxExplorationMultiplier;
        IntArrayQueue toScore = new IntArrayQueue(queueCapacity);
        IntArrayQueue toExplore = new IntArrayQueue(queueCapacity);
        // A scored node must beat this bound to become a candidate.
        float minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity());

        while (candidates.size() > 0 && !results.earlyTerminated()) {
            float topCandidateSimilarity = candidates.topScore();
            if (minAcceptedSimilarity > topCandidateSimilarity) {
                // The best remaining candidate cannot improve the results.
                break;
            }
            int topCandidateNode = candidates.pop();
            graph.seek(level, topCandidateNode);
            int neighborCount = graph.neighborCount();
            toScore.clear();
            toExplore.clear();

            // First hop: split unvisited neighbours into accepted (to score) and rejected (to explore).
            int friendOrd;
            while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS && !toScore.isFull()) {
                assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
                if (visited.getAndSet(friendOrd)) {
                    continue;
                }
                if (acceptOrds.get(friendOrd)) {
                    toScore.add(friendOrd);
                } else {
                    toExplore.add(friendOrd);
                }
            }

            // Scale the scoring budget by how much of this neighbourhood the filter rejected.
            // When neighborCount == 0 this is 0/0 = NaN; the float->int cast below then yields 0,
            // so no expansion happens.
            float filteredAmount = (float) toExplore.count() / (float) neighborCount;
            int maxToScoreCount = (int) ((float) neighborCount
                    * Math.min((float) maxExplorationMultiplier, 1.0f / (1.0f - filteredAmount)));
            int maxAdditionalToExploreCount = toExplore.capacity() - 1;
            int totalExplored = toScore.count() + toExplore.count();

            if (toScore.count() < maxToScoreCount && filteredAmount > EXPANDED_EXPLORATION_LAMBDA) {
                // Second hop: look through the neighbours of rejected neighbours for accepted nodes.
                int exploreFriend;
                while ((exploreFriend = toExplore.poll()) != NO_MORE_DOCS
                        && totalExplored < maxAdditionalToExploreCount
                        && toScore.count() < maxToScoreCount) {
                    graphSeek(graph, level, exploreFriend);
                    int friendOfAFriendOrd;
                    while ((friendOfAFriendOrd = graph.nextNeighbor()) != NO_MORE_DOCS
                            && toScore.count() < maxToScoreCount) {
                        if (visited.getAndSet(friendOfAFriendOrd)) {
                            continue;
                        }
                        ++totalExplored;
                        if (acceptOrds.get(friendOfAFriendOrd)) {
                            toScore.add(friendOfAFriendOrd);
                        } else if (totalExplored < maxAdditionalToExploreCount && toScore.count() < minToScore) {
                            // Still short of accepted nodes: keep widening through rejected ones.
                            toExplore.add(friendOfAFriendOrd);
                        }
                    }
                }
            }

            // Score the gathered accepted nodes; competitive ones become candidates and are collected.
            int toScoreOrd;
            while ((toScoreOrd = toScore.poll()) != NO_MORE_DOCS) {
                float friendSimilarity = scorer.score(toScoreOrd);
                results.incVisitedCount(1);
                if (friendSimilarity > minAcceptedSimilarity) {
                    candidates.add(toScoreOrd, friendSimilarity);
                    if (results.collect(toScoreOrd, friendSimilarity)) {
                        minAcceptedSimilarity = Math.nextUp(results.minCompetitiveSimilarity());
                    }
                }
            }
            if (results.getSearchStrategy() != null) {
                results.getSearchStrategy().nextVectorsBlock();
            }
        }
    }

    /** Resets the candidate queue and visited set before a new search. */
    private void prepareScratchState() {
        candidates.clear();
        visited.clear();
    }

    /**
     * Fixed-capacity FIFO of node ordinals backed by an int array.
     *
     * <p>Polled slots are not reused until {@link #clear()}, so the total number of {@link #add}
     * calls between clears is bounded by {@link #capacity()}.
     */
    private static class IntArrayQueue {
        private final int[] nodes;
        /** Index of the next element to poll. */
        private int upto;
        /** Index one past the last added element. */
        private int size;

        IntArrayQueue(int capacity) {
            this.nodes = new int[capacity];
        }

        /** Returns the fixed capacity of the backing array. */
        int capacity() {
            return nodes.length;
        }

        /** Returns the number of elements added but not yet polled. */
        int count() {
            return size - upto;
        }

        /**
         * Appends a node.
         *
         * @throws UnsupportedOperationException if the backing array has no free slot left
         */
        void add(int node) {
            if (isFull()) {
                throw new UnsupportedOperationException("Initial capacity should remain unchanged");
            }
            nodes[size++] = node;
        }

        /** Returns true once every slot has been written since the last {@link #clear()}. */
        boolean isFull() {
            return size == nodes.length;
        }

        /** Removes and returns the oldest unpolled node, or {@link #NO_MORE_DOCS} if none remain. */
        int poll() {
            if (upto == size) {
                return NO_MORE_DOCS;
            }
            return nodes[upto++];
        }

        /** Empties the queue so all slots can be reused. */
        void clear() {
            upto = 0;
            size = 0;
        }
    }
}

