/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.Callable;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.compress.estim.sample.SampleEstimatorFactory;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockCSR;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.TransformStatistics;

/**
 * Column encoder that turns a free-text string column into bag-of-words counts.
 *
 * <p>During build, every distinct token receives a 1-based id in the token dictionary. During
 * apply, a token with id {@code k} is written to output column {@code outputCol + k - 1}, with
 * the number of times the token occurs in that row as the value. Tokens are produced by
 * {@link #tokenize(String, boolean, String)}.
 */
public class ColumnEncoderBagOfWords extends ColumnEncoder {
    /** Upper bound on the number of sampled rows used by the estimation methods. */
    public static int NUM_SAMPLES_MAP_ESTIMATION = 16000;
    /** Token -> 1-based dictionary id. */
    private Map<Object, Integer> _tokenDictionary;
    /** Distinct tokens collected by {@link #buildPartial(FrameBlock)}; lazily created. */
    private HashSet<Object> _tokenDictionaryPart = null;
    /** Regex used to split the normalized string into tokens. */
    protected String _seperatorRegex = "\\s+";
    /** If false, letters are lower-cased before tokens are compared. */
    protected boolean _caseSensitive = false;
    /** Number of distinct (dictionary) tokens per input row. */
    protected int[] _nnzPerRow;
    /** Total number of non-zeros produced by this column. */
    protected long _nnz = 0L;
    /** Per-block non-zero counts written by the parallel partial-build tasks. */
    protected long[] _nnzPartials;
    /** Initial capacity of the per-row token set used in {@link #build(CacheBlock)}. */
    protected int _defaultNnzCapacity = 64;
    /** Average number of distinct tokens per sampled row. */
    protected double _avgNnzPerRow = 1.0;

    protected ColumnEncoderBagOfWords(int colID) {
        super(colID);
    }

    public ColumnEncoderBagOfWords() {
        super(-1);
    }

    /**
     * Copy constructor. The per-row nnz array is cloned, but the token dictionary is shared by
     * reference with {@code enc}.
     *
     * @param enc encoder to copy from
     */
    public ColumnEncoderBagOfWords(ColumnEncoderBagOfWords enc) {
        super(enc._colID);
        this._nnzPerRow = enc._nnzPerRow != null ? enc._nnzPerRow.clone() : null;
        this._tokenDictionary = enc._tokenDictionary;
        this._seperatorRegex = enc._seperatorRegex;
        this._caseSensitive = enc._caseSensitive;
    }

    public void setTokenDictionary(HashMap<Object, Integer> dict) {
        this._tokenDictionary = dict;
    }

    public Map<Object, Integer> getTokenDictionary() {
        return this._tokenDictionary;
    }

    /**
     * Allocates the per-row and per-block nnz buffers used by the parallel partial build.
     *
     * @param rows      number of input rows
     * @param numBlocks number of row blocks / partial-build tasks
     */
    protected void initNnzPartials(int rows, int numBlocks) {
        this._nnzPerRow = new int[rows];
        this._nnzPartials = new long[numBlocks];
    }

    /**
     * Estimates the average number of dictionary-token occurrences per row from a sample.
     *
     * <p>NOTE(review): this counts repeated tokens once per occurrence, whereas
     * {@link #computeNnzPerRow} counts distinct tokens. Confirm that this is intended.
     *
     * @param in            input block
     * @param sampleIndices row indices to sample (at most {@link #NUM_SAMPLES_MAP_ESTIMATION} are used)
     * @return average occurrences per sampled row, or 0 for an empty sample
     */
    public double computeNnzEstimate(CacheBlock<?> in, int[] sampleIndices) {
        int max_index = Math.min(NUM_SAMPLES_MAP_ESTIMATION, sampleIndices.length);
        if (max_index == 0) {
            // Avoid 0/0 = NaN for an empty sample.
            return 0.0;
        }
        long nnz = 0L;
        for (int i = 0; i < max_index; ++i) {
            String current = in.getString(sampleIndices[i], this._colID - 1);
            if (current == null) {
                continue;
            }
            for (String token : tokenize(current, this._caseSensitive, this._seperatorRegex)) {
                if (!token.isEmpty() && this._tokenDictionary.containsKey(token)) {
                    ++nnz;
                }
            }
        }
        return (double) nnz / (double) max_index;
    }

    /**
     * Estimates the dictionary size and memory footprint from a sample of rows. It also derives
     * {@link #_avgNnzPerRow} and {@link #_defaultNnzCapacity} from the sampled per-row distinct
     * token counts.
     *
     * @param in            input block
     * @param sampleIndices candidate sample rows (at most a third of them, capped by
     *                      {@link #NUM_SAMPLES_MAP_ESTIMATION}, are used)
     */
    @Override
    public void computeMapSizeEstimate(CacheBlock<?> in, int[] sampleIndices) {
        HashMap<String, Integer> distinctFreq = new HashMap<>();
        long totSize = 0L;
        int max_index = Math.min(NUM_SAMPLES_MAP_ESTIMATION, sampleIndices.length / 3);
        int numTokensSample = 0;
        int[] nnzPerRow = new int[max_index];
        for (int i = 0; i < max_index; ++i) {
            String current = in.getString(sampleIndices[i], this._colID - 1);
            HashSet<String> tokenSetRow = new HashSet<>();
            if (current != null) {
                for (String token : tokenize(current, this._caseSensitive, this._seperatorRegex)) {
                    if (token.isEmpty()) {
                        continue;
                    }
                    tokenSetRow.add(token);
                    // A merged count of 1 means this token was seen for the first time:
                    // account for its approximate String footprint (2 bytes/char + overhead).
                    if (distinctFreq.merge(token, 1, Integer::sum) == 1) {
                        totSize += (long) token.length() * 2L + 16L;
                    }
                    ++numTokensSample;
                }
            }
            nnzPerRow[i] = tokenSetRow.size();
        }
        if (max_index > 0) {
            Arrays.sort(nnzPerRow);
            this._avgNnzPerRow = (double) Arrays.stream(nnzPerRow).asLongStream().sum() / (double) nnzPerRow.length;
            // Size the per-row set for the 75th percentile at a 0.9 load factor, with a minimum of 64.
            this._defaultNnzCapacity = (int) Math.max((double) nnzPerRow[(int) ((double) nnzPerRow.length * 0.75)] / 0.9, 64.0);
        }
        long valSize = 16L;
        if (distinctFreq.isEmpty()) {
            // No tokens in the sample: nothing to extrapolate, and avoid dividing by zero below.
            this._estNumDistincts = 0;
            this._avgEntrySize = valSize;
            this._estMetaSize = 0L;
            return;
        }
        double avgSentenceLength = (double) numTokensSample * 1.2 / (double) max_index;
        int[] freq = distinctFreq.values().stream().mapToInt(Integer::intValue).toArray();
        this._estNumDistincts = SampleEstimatorFactory.distinctCount(freq, (int) (avgSentenceLength * (double) in.getNumRows()), numTokensSample, SampleEstimatorFactory.EstimationType.HassAndStokes);
        // Inflate the estimate by 20% as a safety margin.
        this._estNumDistincts = (int) ((double) this._estNumDistincts * 1.2);
        long avgKeySize = totSize / (long) distinctFreq.size();
        this._avgEntrySize = avgKeySize + valSize;
        this._estMetaSize = (long) this._estNumDistincts * this._avgEntrySize;
    }

    /**
     * Fills {@link #_nnzPerRow} for rows {@code [start, end)} with the number of distinct tokens
     * that are present in the token dictionary.
     */
    public void computeNnzPerRow(CacheBlock<?> in, int start, int end) {
        for (int i = start; i < end; ++i) {
            String current = in.getString(i, this._colID - 1);
            HashSet<String> distinctTokens = new HashSet<>();
            if (current != null) {
                for (String token : tokenize(current, this._caseSensitive, this._seperatorRegex)) {
                    if (!token.isEmpty() && this._tokenDictionary.containsKey(token)) {
                        distinctTokens.add(token);
                    }
                }
            }
            this._nnzPerRow[i] = distinctTokens.size();
        }
    }

    /**
     * Splits a string into tokens. Every non-letter character is replaced by a space, letters
     * are lower-cased unless {@code caseSensitive} is set, and the result is split on
     * {@code seperatorRegex}. The returned array may contain empty strings (for example, from
     * leading separators); callers skip those.
     *
     * @param current        the string to tokenize (must not be null)
     * @param caseSensitive  keep the original letter case if true
     * @param seperatorRegex regex passed to {@link String#split(String)}
     * @return the raw token array
     */
    public static String[] tokenize(String current, boolean caseSensitive, String seperatorRegex) {
        StringBuilder finalString = new StringBuilder(current.length());
        for (char c : current.toCharArray()) {
            if (Character.isLetter(c)) {
                finalString.append(caseSensitive ? c : Character.toLowerCase(c));
            } else {
                finalString.append(' ');
            }
        }
        return finalString.toString().split(seperatorRegex);
    }

    @Override
    public int getDomainSize() {
        return this._tokenDictionary.size();
    }

    @Override
    protected double getCode(CacheBlock<?> in, int row) {
        throw new NotImplementedException();
    }

    @Override
    protected double[] getCodeCol(CacheBlock<?> in, int startInd, int rowEnd, double[] tmp) {
        throw new NotImplementedException();
    }

    @Override
    protected ColumnEncoder.TransformType getTransformType() {
        return ColumnEncoder.TransformType.BAG_OF_WORDS;
    }

    @Override
    public Callable<Object> getBuildTask(CacheBlock<?> in) {
        return new ColumnBagOfWordsBuildTask(this, in);
    }

    /**
     * Single-threaded build. It assigns 1-based ids to tokens in order of first appearance and
     * records per-row and total distinct-token counts.
     */
    @Override
    public void build(CacheBlock<?> in) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        this._tokenDictionary = new HashMap<>(this._estNumDistincts);
        int nextId = 1;
        this._nnz = 0L;
        this._nnzPerRow = new int[in.getNumRows()];
        for (int r = 0; r < in.getNumRows(); ++r) {
            HashSet<String> tokenSetPerRow = new HashSet<>(this._defaultNnzCapacity);
            String current = in.getString(r, this._colID - 1);
            if (current != null) {
                for (String token : tokenize(current, this._caseSensitive, this._seperatorRegex)) {
                    if (token.isEmpty()) {
                        continue;
                    }
                    tokenSetPerRow.add(token);
                    if (!this._tokenDictionary.containsKey(token)) {
                        this._tokenDictionary.put(token, nextId++);
                    }
                }
            }
            this._nnzPerRow[r] = tokenSetPerRow.size();
            this._nnz += tokenSetPerRow.size();
        }
        if (DMLScript.STATISTICS) {
            TransformStatistics.incBagOfWordsBuildTime(System.nanoTime() - t0);
        }
    }

    @Override
    public Callable<Object> getPartialBuildTask(CacheBlock<?> in, int startRow, int blockSize, HashMap<Integer, Object> ret, int pos) {
        return new BowPartialBuildTask(in, this._colID, startRow, blockSize, ret, this._nnzPerRow, this._caseSensitive, this._seperatorRegex, this._nnzPartials, pos);
    }

    /** Resets the dictionary and returns a task that merges the partial token sets in {@code ret}. */
    @Override
    public Callable<Object> getPartialMergeBuildTask(HashMap<Integer, ?> ret) {
        this._tokenDictionary = new HashMap<>(this._estNumDistincts);
        return new BowMergePartialBuildTask(this, ret);
    }

    @Override
    public void prepareBuildPartial() {
        if (this._tokenDictionaryPart == null) {
            this._tokenDictionaryPart = new HashSet<>();
        }
    }

    public HashSet<Object> getPartialTokenDictionary() {
        return this._tokenDictionaryPart;
    }

    /** Adds all non-empty tokens of this column in {@code in} to the partial token set. */
    @Override
    public void buildPartial(FrameBlock in) {
        if (!this.isApplicable()) {
            return;
        }
        for (int r = 0; r < in.getNumRows(); ++r) {
            String current = in.getString(r, this._colID - 1);
            if (current == null) {
                continue;
            }
            for (String token : tokenize(current, this._caseSensitive, this._seperatorRegex)) {
                if (!token.isEmpty()) {
                    this._tokenDictionaryPart.add(token);
                }
            }
        }
    }

    /**
     * Writes token counts for rows {@code [rowStart, rowStart + blk)} directly into the CSR
     * arrays of {@code out}. Rows without any dictionary token are recorded through
     * {@code addSparseRowsWZeros}. Only the CSR layout is supported.
     */
    @Override
    protected void applySparse(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        ArrayList<Integer> sparseRowsWZeros = new ArrayList<>();
        int rowEnd = UtilFunctions.getEndIndex(in.getNumRows(), rowStart, blk);
        for (int r = rowStart; r < rowEnd; ++r) {
            HashMap<String, Integer> counter = this.countTokenAppearances(in, r);
            if (counter.isEmpty()) {
                sparseRowsWZeros.add(r);
                continue;
            }
            SparseBlockCSR csrblock = (SparseBlockCSR) out.getSparseBlock();
            int[] rptr = csrblock.rowPointers();
            // _nnzPerRow[r] holds the number of distinct dictionary tokens in row r, which equals
            // counter.size() because countTokenAppearances only keeps dictionary tokens.
            Pair[] columnValuePairs = new Pair[this._nnzPerRow[r]];
            int i = 0;
            for (Map.Entry<String, Integer> entry : counter.entrySet()) {
                int tokenId = this._tokenDictionary.get(entry.getKey());
                columnValuePairs[i++] = new Pair(outputCol + tokenId - 1, entry.getValue());
            }
            // CSR column indexes must be sorted; insertion sort is cheaper for short rows.
            if (columnValuePairs.length >= 128) {
                Arrays.sort(columnValuePairs, Comparator.comparingInt(pair -> pair.key));
            } else {
                insertionSort(columnValuePairs);
            }
            for (i = 0; i < columnValuePairs.length; ++i) {
                int index = this.sparseRowPointerOffset != null ? this.sparseRowPointerOffset[r] - 1 + i : i;
                index += rptr[r] + this._colID - 1;
                csrblock.indexes()[index] = columnValuePairs[i].key;
                csrblock.values()[index] = columnValuePairs[i].value;
            }
        }
        if (!sparseRowsWZeros.isEmpty()) {
            this.addSparseRowsWZeros(sparseRowsWZeros);
        }
    }

    /** Sorts {@code arr} in place by ascending {@link Pair#key}. */
    private static void insertionSort(Pair[] arr) {
        for (int i = 1; i < arr.length; ++i) {
            Pair current = arr[i];
            int j = i - 1;
            while (j >= 0 && arr[j].key > current.key) {
                arr[j + 1] = arr[j];
                --j;
            }
            arr[j + 1] = current;
        }
    }

    /** Writes token counts for rows {@code [rowStart, rowStart + blk)} into a dense output. */
    @Override
    protected void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        // Bound by the block end (not max(numRows, ...)), so that rows are neither reprocessed
        // by every block nor read past the end of the input.
        int rowEnd = UtilFunctions.getEndIndex(in.getNumRows(), rowStart, blk);
        for (int r = rowStart; r < rowEnd; ++r) {
            HashMap<String, Integer> counter = this.countTokenAppearances(in, r);
            for (Map.Entry<String, Integer> entry : counter.entrySet()) {
                out.set(r, outputCol + this._tokenDictionary.get(entry.getKey()) - 1, entry.getValue().intValue());
            }
        }
    }

    /**
     * Counts the occurrences of each dictionary token in row {@code r}. Tokens that are not in
     * the dictionary, and empty tokens, are ignored.
     */
    private HashMap<String, Integer> countTokenAppearances(CacheBlock<?> in, int r) {
        String current = in.getString(r, this._colID - 1);
        HashMap<String, Integer> counter = new HashMap<>();
        if (current != null) {
            for (String token : tokenize(current, this._caseSensitive, this._seperatorRegex)) {
                if (!token.isEmpty() && this._tokenDictionary.containsKey(token)) {
                    counter.merge(token, 1, Integer::sum);
                }
            }
        }
        return counter;
    }

    @Override
    public void allocateMetaData(FrameBlock meta) {
        meta.ensureAllocatedColumns(this.getDomainSize());
    }

    /** Writes one recode-map entry per dictionary token into column {@code _colID - 1} of {@code out}. */
    @Override
    public FrameBlock getMetaData(FrameBlock out) {
        int rowID = 0;
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<Object, Integer> e : this._tokenDictionary.entrySet()) {
            out.set(rowID++, this._colID - 1, ColumnEncoderRecode.constructRecodeMapEntry(e.getKey(), e.getValue(), sb));
        }
        return out;
    }

    /** Restores the token dictionary from a non-empty metadata frame. */
    @Override
    public void initMetaData(FrameBlock meta) {
        if (meta != null && meta.getNumRows() > 0) {
            this._tokenDictionary = meta.getRecodeMap(this._colID - 1);
        }
    }

    /** Serializes the dictionary as a size followed by (UTF token, int id) pairs. */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        super.writeExternal(out);
        out.writeInt(this._tokenDictionary == null ? 0 : this._tokenDictionary.size());
        if (this._tokenDictionary != null) {
            for (Map.Entry<Object, Integer> e : this._tokenDictionary.entrySet()) {
                out.writeUTF(e.getKey().toString());
                out.writeInt(e.getValue());
            }
        }
    }

    /** Reads the format written by {@link #writeExternal(ObjectOutput)}. */
    @Override
    public void readExternal(ObjectInput in) throws IOException {
        super.readExternal(in);
        int size = in.readInt();
        this._tokenDictionary = new HashMap<>(size * 4 / 3);
        for (int j = 0; j < size; ++j) {
            String key = in.readUTF();
            int value = in.readInt();
            this._tokenDictionary.put(key, value);
        }
    }

    /** Runs {@link ColumnEncoderBagOfWords#build(CacheBlock)} as a task. */
    private static class ColumnBagOfWordsBuildTask implements Callable<Object> {
        private final ColumnEncoderBagOfWords _encoder;
        private final CacheBlock<?> _input;

        protected ColumnBagOfWordsBuildTask(ColumnEncoderBagOfWords encoder, CacheBlock<?> input) {
            this._encoder = encoder;
            this._input = input;
        }

        @Override
        public Void call() {
            this._encoder.build(this._input);
            return null;
        }

        @Override
        public String toString() {
            return this.getClass().getSimpleName() + "<ColId: " + this._encoder._colID + ">";
        }
    }

    /**
     * Merges the partial token sets into the encoder's dictionary (assigning ids by insertion
     * order) and sums the per-block nnz counts into {@code _nnz}.
     */
    private static class BowMergePartialBuildTask implements Callable<Object> {
        private final HashMap<Integer, ?> _partialMaps;
        private final ColumnEncoderBagOfWords _encoder;

        private BowMergePartialBuildTask(ColumnEncoderBagOfWords encoderRecode, HashMap<Integer, ?> partialMaps) {
            this._partialMaps = partialMaps;
            this._encoder = encoderRecode;
        }

        @Override
        public Object call() {
            long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            Map<Object, Integer> tokenDictionary = this._encoder._tokenDictionary;
            for (Object tokenSet : this._partialMaps.values()) {
                for (Object token : (HashSet<?>) tokenSet) {
                    if (!tokenDictionary.containsKey(token)) {
                        tokenDictionary.put(token, tokenDictionary.size() + 1);
                    }
                }
            }
            // Recompute the total from the partials (as build() does), rather than adding to a
            // stale value from a previous build.
            long nnz = 0L;
            if (this._encoder._nnzPartials != null) {
                for (long nnzPartial : this._encoder._nnzPartials) {
                    nnz += nnzPartial;
                }
            }
            this._encoder._nnz = nnz;
            if (DMLScript.STATISTICS) {
                TransformStatistics.incBagOfWordsBuildTime(System.nanoTime() - t0);
            }
            return null;
        }

        @Override
        public String toString() {
            return this.getClass().getSimpleName() + "<ColId: " + this._encoder._colID + ">";
        }
    }

    /**
     * Tokenizes one row block. It writes per-row distinct-token counts into the shared
     * {@code nnzPerRow} array, writes the block total into {@code nnzPartials[pos]}, and
     * publishes the block's token set under {@code startRow}.
     */
    private static class BowPartialBuildTask implements Callable<Object> {
        private final CacheBlock<?> _input;
        private final int _blockSize;
        private final int _startRow;
        private final int _colID;
        private final boolean _caseSensitive;
        private final String _seperator;
        private final HashMap<Integer, Object> _partialMaps;
        private final int[] _nnzPerRow;
        private final long[] _nnzPartials;
        private final int _pos;

        protected BowPartialBuildTask(CacheBlock<?> input, int colID, int startRow, int blocksize, HashMap<Integer, Object> partialMaps, int[] nnzPerRow, boolean caseSensitive, String seperator, long[] nnzPartials, int pos) {
            this._input = input;
            this._blockSize = blocksize;
            this._colID = colID;
            this._startRow = startRow;
            this._partialMaps = partialMaps;
            this._caseSensitive = caseSensitive;
            this._seperator = seperator;
            this._nnzPerRow = nnzPerRow;
            this._nnzPartials = nnzPartials;
            this._pos = pos;
        }

        @Override
        public Object call() {
            long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            int endRow = UtilFunctions.getEndIndex(this._input.getNumRows(), this._startRow, this._blockSize);
            HashSet<String> tokenSetPartial = new HashSet<>();
            long nnzPartial = 0L;
            for (int r = this._startRow; r < endRow; ++r) {
                HashSet<String> tokenSetPerRow = new HashSet<>(64);
                String current = this._input.getString(r, this._colID - 1);
                if (current != null) {
                    for (String token : tokenize(current, this._caseSensitive, this._seperator)) {
                        if (token.isEmpty()) {
                            continue;
                        }
                        tokenSetPerRow.add(token);
                        tokenSetPartial.add(token);
                    }
                }
                this._nnzPerRow[r] = tokenSetPerRow.size();
                nnzPartial += tokenSetPerRow.size();
            }
            this._nnzPartials[this._pos] = nnzPartial;
            // The result map is shared by all partial tasks, so guard the insertion.
            synchronized (this._partialMaps) {
                this._partialMaps.put(this._startRow, tokenSetPartial);
            }
            if (DMLScript.STATISTICS) {
                TransformStatistics.incBagOfWordsBuildTime(System.nanoTime() - t0);
            }
            return null;
        }

        @Override
        public String toString() {
            return this.getClass().getSimpleName() + "<Start row: " + this._startRow + "; Block size: " + this._blockSize + ">";
        }
    }

    /** (output column index, token count) pair used to emit sorted CSR entries. */
    static class Pair {
        int key;
        int value;

        Pair(int key, int value) {
            this.key = key;
            this.value = value;
        }
    }
}

