/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.frame.data.columns.ACompressedArray;
import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.frame.data.columns.DDCArray;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderFeatureHash;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Encodes a {@link FrameBlock} directly into a {@link CompressedMatrixBlock}, following the column encoders of a
 * {@link MultiColumnEncoder}.
 * <p>
 * Each {@link ColumnEncoderComposite} produces exactly one {@link AColGroup}. Supported encodings are recode,
 * recode-to-dummy, pass-through, bin, bin-to-dummy, feature hash and feature-hash-to-dummy. The produced groups are then
 * shifted so their column indexes are laid out consecutively in encoder order.
 * <p>
 * Side effects on the given encoder: recode encoders are replaced by {@link ColumnEncoderRecode} instances carrying the
 * recode map built from the input, and bin encoders are built on the input frame.
 */
public class CompressedEncode {
    protected static final Log LOG = LogFactory.getLog(CompressedEncode.class.getName());
    /** Encoder specification; some of its column encoders are replaced or built while encoding. */
    private final MultiColumnEncoder enc;
    /** The frame to encode. */
    private final FrameBlock in;
    /** Requested degree of parallelism. */
    private final int k;

    private CompressedEncode(MultiColumnEncoder enc, FrameBlock in, int k) {
        this.enc = enc;
        this.in = in;
        this.k = k;
    }

    /**
     * Encode the given frame into a compressed matrix block.
     *
     * @param enc The multi-column encoder describing the per-column transformations.
     * @param in  The frame to encode.
     * @param k   Degree of parallelism; multi-threading is used if this is above 1 and there are several encoders.
     * @return A compressed matrix block with one column group per column encoder.
     * @throws InterruptedException If interrupted while waiting for parallel encode tasks.
     * @throws ExecutionException   If a parallel encode task failed.
     */
    public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in, int k)
        throws InterruptedException, ExecutionException {
        return new CompressedEncode(enc, in, k).apply();
    }

    /** Encode every column, assemble the groups into a compressed block and log size statistics. */
    private MatrixBlock apply() throws InterruptedException, ExecutionException {
        final List<ColumnEncoderComposite> encoders = enc.getColumnEncoders();
        final List<AColGroup> groups = isParallel() ? multiThread(encoders) : singleThread(encoders);
        final int cols = shiftGroups(groups);
        final CompressedMatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1L, false, groups);
        mb.recomputeNonZeros();
        logging(mb);
        return mb;
    }

    /** @return true if more than one thread is requested and there is more than one encoder to run. */
    private boolean isParallel() {
        return k > 1 && enc.getEncoders().size() > 1;
    }

    /** Encode all columns sequentially, preserving encoder order. */
    private List<AColGroup> singleThread(List<ColumnEncoderComposite> encoders) {
        final List<AColGroup> groups = new ArrayList<>(encoders.size());
        for (ColumnEncoderComposite c : encoders) {
            groups.add(encode(c));
        }
        return groups;
    }

    /**
     * Encode all columns in a thread pool of size {@code k}, preserving encoder order.
     * <p>
     * The pool is always shut down, even if a task fails or the wait is interrupted.
     */
    private List<AColGroup> multiThread(List<ColumnEncoderComposite> encoders)
        throws InterruptedException, ExecutionException {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<EncodeTask> tasks = new ArrayList<>(encoders.size());
            for (ColumnEncoderComposite c : encoders) {
                tasks.add(new EncodeTask(c));
            }
            final List<AColGroup> groups = new ArrayList<>(encoders.size());
            // invokeAll returns the futures in the same order as the submitted tasks.
            for (Future<AColGroup> f : pool.invokeAll(tasks)) {
                groups.add(f.get());
            }
            return groups;
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Shift the column indexes of every group so the groups occupy consecutive, non-overlapping column ranges in list
     * order. The list is updated in place.
     *
     * @param groups The groups, each initially indexed from column 0.
     * @return The total number of columns covered by all groups (0 for an empty list).
     */
    private int shiftGroups(List<AColGroup> groups) {
        int cols = 0;
        for (int i = 0; i < groups.size(); ++i) {
            AColGroup g = groups.get(i);
            // The first group already starts at column 0 and is kept as-is.
            if (i > 0) {
                g = g.shiftColIndices(cols);
                groups.set(i, g);
            }
            cols += g.getColIndices().size();
        }
        return cols;
    }

    /**
     * Dispatch a single column encoder to the matching encoding routine.
     *
     * @throws NotImplementedException If the encoder combination is not supported.
     */
    private AColGroup encode(ColumnEncoderComposite c) {
        if (c.isRecodeToDummy()) {
            return recodeToDummy(c);
        }
        if (c.isRecode()) {
            return recode(c);
        }
        if (c.isPassThrough()) {
            return passThrough(c);
        }
        if (c.isBin()) {
            return bin(c);
        }
        if (c.isBinToDummy()) {
            return binToDummy(c);
        }
        if (c.isHash()) {
            return hash(c);
        }
        if (c.isHashToDummy()) {
            return hashToDummy(c);
        }
        throw new NotImplementedException("Not supporting : " + c);
    }

    /**
     * Build a recode encoder that carries the given recode map.
     * <p>
     * The raw cast mirrors the original code. It presumably relies on {@code Array.getRecodeMap()} returning a
     * {@link HashMap} — TODO confirm.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static ColumnEncoderRecode recodeEncoder(int colId, Map<?, Long> map) {
        return new ColumnEncoderRecode(colId, (HashMap) map);
    }

    /**
     * Recode followed by dummy coding: one output column per distinct value, as an identity dictionary over a DDC
     * mapping.
     */
    private AColGroup recodeToDummy(ColumnEncoderComposite c) {
        final int colId = c._colID;
        final Array<?> a = in.getColumn(colId - 1);
        final boolean containsNull = a.containsNull();
        final Map<?, Long> map = a.getRecodeMap();
        c.getEncoders().set(0, recodeEncoder(colId, map));
        final int domain = map.size();
        // Only nulls and no values: a single all-zero column.
        if (containsNull && domain == 0) {
            return new ColGroupEmpty(ColIndexFactory.create(1));
        }
        final IColIndex colIndexes = ColIndexFactory.create(0, domain);
        if (domain == 1 && !containsNull) {
            return ColGroupConst.create(colIndexes, new double[] {1.0});
        }
        final IdentityDictionary d = new IdentityDictionary(colIndexes.size(), containsNull);
        final AMapToData m = createMappingAMapToData(a, map, containsNull);
        return ColGroupDDC.create(colIndexes, d, m, null);
    }

    /**
     * Equi-width/height binning into a single output column.
     * <p>
     * The dictionary holds the bin codes 1..numBin, plus NaN for nulls.
     */
    private AColGroup bin(ColumnEncoderComposite c) {
        final int colId = c._colID;
        final Array<?> a = in.getColumn(colId - 1);
        final boolean containsNull = a.containsNull();
        final ColumnEncoderBin b = (ColumnEncoderBin) c.getEncoders().get(0);
        b.build(in);
        final IColIndex colIndexes = ColIndexFactory.create(1);
        final MatrixBlockDictionary d = createIncrementingVector(b._numBin, containsNull);
        final AMapToData m = binEncode(a, b, containsNull);
        return ColGroupDDC.create(colIndexes, d, m, null);
    }

    /**
     * Map every row to a zero-based bin index.
     * <p>
     * If the column contains nulls, NaN rows are mapped to the extra index {@code numBin}.
     */
    private AMapToData binEncode(Array<?> a, ColumnEncoderBin b, boolean containsNull) {
        final AMapToData m = MapToFactory.create(a.size(), b._numBin + (containsNull ? 1 : 0));
        if (containsNull) {
            for (int i = 0; i < a.size(); ++i) {
                final double v = a.getAsNaNDouble(i);
                try {
                    if (Double.isNaN(v)) {
                        m.set(i, b._numBin);
                    }
                    else {
                        // Clamp to the first bin for values at or below the lower boundary.
                        m.set(i, Math.max(0, (int) b.getCodeIndex(v) - 1));
                    }
                }
                catch (Exception e) {
                    // Retries with the value nudged slightly down. Presumably getCodeIndex fails on an upper bin
                    // boundary — TODO confirm.
                    m.set(i, (int) b.getCodeIndex(v - 1.0E-5) - 1);
                }
            }
        }
        else {
            for (int i = 0; i < a.size(); ++i) {
                try {
                    m.set(i, Math.max(0, (int) b.getCodeIndex(a.getAsDouble(i)) - 1));
                }
                catch (Exception e) {
                    m.set(i, (int) b.getCodeIndex(a.getAsDouble(i) - 1.0E-5) - 1);
                }
            }
        }
        return m;
    }

    /**
     * Create a single-column dictionary [1, 2, ..., nVals].
     * <p>
     * If {@code NaN} is true, an extra trailing NaN entry is appended for null rows.
     */
    private MatrixBlockDictionary createIncrementingVector(int nVals, boolean NaN) {
        final MatrixBlock bins = new MatrixBlock(nVals + (NaN ? 1 : 0), 1, false);
        for (int i = 0; i < nVals; ++i) {
            bins.quickSetValue(i, 0, i + 1);
        }
        if (NaN) {
            bins.quickSetValue(nVals, 0, Double.NaN);
        }
        return MatrixBlockDictionary.create(bins);
    }

    /** Binning followed by dummy coding: one output column per bin, as an identity dictionary over a DDC mapping. */
    private AColGroup binToDummy(ColumnEncoderComposite c) {
        final int colId = c._colID;
        final Array<?> a = in.getColumn(colId - 1);
        final boolean containsNull = a.containsNull();
        final ColumnEncoderBin b = (ColumnEncoderBin) c.getEncoders().get(0);
        b.build(in);
        final IColIndex colIndexes = ColIndexFactory.create(0, b._numBin);
        final IdentityDictionary d = new IdentityDictionary(colIndexes.size(), containsNull);
        final AMapToData m = binEncode(a, b, containsNull);
        return ColGroupDDC.create(colIndexes, d, m, null);
    }

    /**
     * Recode into a single output column holding codes 1..domain, plus NaN for nulls.
     * <p>
     * The column encoder is replaced by a recode encoder carrying the built map. This happens before the constant
     * shortcut, so constant columns also get the map, matching {@link #recodeToDummy}.
     */
    private AColGroup recode(ColumnEncoderComposite c) {
        final int colId = c._colID;
        final Array<?> a = in.getColumn(colId - 1);
        final Map<?, Long> map = a.getRecodeMap();
        final boolean containsNull = a.containsNull();
        c.getEncoders().set(0, recodeEncoder(colId, map));
        final int domain = map.size();
        final IColIndex colIndexes = ColIndexFactory.create(1);
        // Constant only if there are no nulls; otherwise null rows must still encode to NaN.
        if (domain == 1 && !containsNull) {
            return ColGroupConst.create(colIndexes, new double[] {1.0});
        }
        final MatrixBlockDictionary d = createIncrementingVector(domain, containsNull);
        final AMapToData m = createMappingAMapToData(a, map, containsNull);
        return ColGroupDDC.create(colIndexes, d, m, null);
    }

    /**
     * Pass the column values through unchanged as doubles.
     * <p>
     * The output form depends on the input:
     * <ul>
     * <li>An already-compressed DDC array reuses its mapping.</li>
     * <li>A high-cardinality column (at least the default block size of distinct values) is stored uncompressed.</li>
     * <li>Otherwise the distinct values become a DDC dictionary.</li>
     * </ul>
     *
     * @throws NotImplementedException For compressed input arrays other than DDC.
     */
    private AColGroup passThrough(ColumnEncoderComposite c) {
        final IColIndex colIndexes = ColIndexFactory.create(1);
        final int colId = c._colID;
        final Array<?> a = in.getColumn(colId - 1);
        if (a instanceof ACompressedArray) {
            switch (a.getFrameArrayType()) {
                case DDC: {
                    final DDCArray<?> aDDC = (DDCArray<?>) a;
                    final Array<?> dict = aDDC.getDict();
                    final double[] dictVals = new double[dict.size()];
                    for (int i = 0; i < dict.size(); ++i) {
                        dictVals[i] = dict.getAsDouble(i);
                    }
                    return ColGroupDDC.create(colIndexes, Dictionary.create(dictVals), aDDC.getMap(), null);
                }
            }
            throw new NotImplementedException();
        }
        final boolean containsNull = a.containsNull();
        final Map<?, Long> map = a.getRecodeMap();
        final int blockSz = ConfigurationManager.getDMLConfig().getIntValue("sysds.defaultblocksize");
        if (map.size() >= blockSz) {
            // Too many distinct values for a dictionary to pay off; keep the column uncompressed.
            final double[] colVals = (double[]) a.changeType(Types.ValueType.FP64).get();
            final MatrixBlock col = new MatrixBlock(a.size(), 1, colVals);
            col.recomputeNonZeros();
            return ColGroupUncompressed.create(colIndexes, col, false);
        }
        final double[] vals = new double[map.size() + (containsNull ? 1 : 0)];
        if (containsNull) {
            vals[map.size()] = Double.NaN;
        }
        final Types.ValueType t = a.getValueType();
        // Recode codes are 1-based; the dictionary is 0-based.
        map.forEach((key, code) -> vals[code.intValue() - 1] = UtilFunctions.objectToDouble(t, key));
        final Dictionary d = Dictionary.create(vals);
        final AMapToData m = createMappingAMapToData(a, map, containsNull);
        return ColGroupDDC.create(colIndexes, d, m, null);
    }

    /**
     * Map every row to the zero-based index of its value's recode code.
     * <p>
     * If the column contains nulls, null rows are mapped to the extra index {@code map.size()}.
     *
     * @throws RuntimeException Wrapping any failure, for example a non-null value missing from the map.
     */
    private AMapToData createMappingAMapToData(Array<?> a, Map<?, Long> map, boolean containsNull) {
        try {
            final int si = map.size();
            final AMapToData m = MapToFactory.create(in.getNumRows(), si + (containsNull ? 1 : 0));
            final Array.ArrayIterator it = a.getIterator();
            if (containsNull) {
                while (it.hasNext()) {
                    final Object v = it.next();
                    try {
                        m.set(it.getIndex(), v != null ? map.get(v).intValue() - 1 : si);
                    }
                    catch (Exception e) {
                        throw new RuntimeException("failed on " + v + " " + a.getValueType(), e);
                    }
                }
            }
            else {
                while (it.hasNext()) {
                    final Object v = it.next();
                    m.set(it.getIndex(), map.get(v).intValue() - 1);
                }
            }
            return m;
        }
        catch (Exception e) {
            throw new RuntimeException("failed constructing map: " + map, e);
        }
    }

    /**
     * Map every row to {@code (int) |hash| % k}.
     * <p>
     * If {@code nulls} is set, rows whose hash is NaN map to index {@code k}. Presumably {@code hashDouble} returns
     * NaN for null entries — TODO confirm.
     */
    private AMapToData createHashMappingAMapToData(Array<?> a, int k, boolean nulls) {
        final AMapToData m = MapToFactory.create(a.size(), k + (nulls ? 1 : 0));
        if (nulls) {
            for (int i = 0; i < a.size(); ++i) {
                final double h = Math.abs(a.hashDouble(i));
                if (Double.isNaN(h)) {
                    m.set(i, k);
                }
                else {
                    // The cast binds tighter than %: ((int) h) % k.
                    m.set(i, (int) h % k);
                }
            }
        }
        else {
            for (int i = 0; i < a.size(); ++i) {
                final double h = Math.abs(a.hashDouble(i));
                m.set(i, (int) h % k);
            }
        }
        return m;
    }

    /** Feature hashing into a single output column holding bucket codes 1..K, plus NaN for nulls. */
    private AColGroup hash(ColumnEncoderComposite c) {
        final int colId = c._colID;
        final Array<?> a = in.getColumn(colId - 1);
        final ColumnEncoderFeatureHash hashEnc = (ColumnEncoderFeatureHash) c.getEncoders().get(0);
        final int domain = (int) hashEnc.getK();
        final boolean nulls = a.containsNull();
        final IColIndex colIndexes = ColIndexFactory.create(0, 1);
        if (domain == 1 && !nulls) {
            return ColGroupConst.create(colIndexes, new double[] {1.0});
        }
        final MatrixBlockDictionary d = createIncrementingVector(domain, nulls);
        final AMapToData m = createHashMappingAMapToData(a, domain, nulls);
        return ColGroupDDC.create(colIndexes, d, m, null);
    }

    /** Feature hashing followed by dummy coding: one output column per hash bucket. */
    private AColGroup hashToDummy(ColumnEncoderComposite c) {
        final int colId = c._colID;
        final Array<?> a = in.getColumn(colId - 1);
        final ColumnEncoderFeatureHash hashEnc = (ColumnEncoderFeatureHash) c.getEncoders().get(0);
        final int domain = (int) hashEnc.getK();
        final boolean nulls = a.containsNull();
        final IColIndex colIndexes = ColIndexFactory.create(0, domain);
        if (domain == 1 && !nulls) {
            return ColGroupConst.create(colIndexes, new double[] {1.0});
        }
        final IdentityDictionary d = new IdentityDictionary(colIndexes.size(), nulls);
        final AMapToData m = createHashMappingAMapToData(a, domain, nulls);
        return ColGroupDDC.create(colIndexes, d, m, null);
    }

    /** Log estimated uncompressed vs. compressed sizes and ratios at debug level. */
    private void logging(MatrixBlock mb) {
        if (LOG.isDebugEnabled()) {
            final long dense = mb.estimateSizeDenseInMemory();
            final long sparse = mb.estimateSizeSparseInMemory();
            final long compressed = mb.estimateSizeInMemory();
            LOG.debug(String.format("Uncompressed transform encode Dense size:   %16d", dense));
            LOG.debug(String.format("Uncompressed transform encode Sparse size:  %16d", sparse));
            LOG.debug(String.format("Compressed transform encode size:           %16d", compressed));
            // Divide in floating point; long division would truncate the ratios to whole numbers.
            final double ratio = (double) Math.min(dense, sparse) / compressed;
            final double denseRatio = (double) dense / compressed;
            LOG.debug(String.format("Compression ratio: %10.3f", ratio));
            LOG.debug(String.format("Dense ratio:       %10.3f", denseRatio));
        }
    }

    /** Task encoding one column encoder into a column group; used by {@link #multiThread}. */
    private class EncodeTask implements Callable<AColGroup> {
        private final ColumnEncoderComposite c;

        protected EncodeTask(ColumnEncoderComposite c) {
            this.c = c;
        }

        @Override
        public AColGroup call() throws Exception {
            return CompressedEncode.this.encode(c);
        }
    }
}

