/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupValue;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupRLE;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingleZeros;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCZeros;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictLibMatrixMult;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/**
 * Base class for value-based column groups that evaluate matrix multiplications by pre-aggregation: input values are
 * first summed into one accumulator per distinct dictionary entry ({@link #getPreAggregateSize()} accumulators per
 * input row), and that small pre-aggregate is then multiplied with the dictionary.
 */
public abstract class APreAgg extends AColGroupValue {
    private static final long serialVersionUID = 3250955207277128281L;

    /**
     * Set once the "direct tsmm not implemented" warning has been logged, so it is not repeated. Deliberately not
     * synchronized: a race can at worst log the warning more than once.
     */
    private static boolean loggedWarningForDirect = false;

    /**
     * @param colIndices   the column indexes covered by this group
     * @param dict         the dictionary of distinct value tuples
     * @param cachedCounts cached per-value counts, passed unchanged to the super constructor
     */
    protected APreAgg(IColIndex colIndices, IDictionary dict, int[] cachedCounts) {
        super(colIndices, dict, cachedCounts);
    }

    /**
     * Adds the transpose-self-multiplication contribution between {@code other} and this group into {@code result}
     * (written via the upper-triangle helpers of {@link DictLibMatrixMult}).
     *
     * @param other  the other column group
     * @param result the dense tsmm output to add into
     * @throws DMLCompressionException if {@code other} is neither empty, an {@link APreAgg}, nor uncompressed
     */
    @Override
    public final void tsmmAColGroup(AColGroup other, MatrixBlock result) {
        if (other instanceof ColGroupEmpty) {
            return; // empty group: nothing to add
        }
        if (other instanceof APreAgg) {
            tsmmAPreAgg((APreAgg) other, result);
        } else if (other instanceof ColGroupUncompressed) {
            tsmmColGroupUncompressed((ColGroupUncompressed) other, result);
        } else {
            throw new DMLCompressionException("Unsupported column group type " + other.getClass().getSimpleName());
        }
    }

    /**
     * Left-multiplies this group by the column group {@code lhs} and adds the product into {@code result}.
     *
     * @param lhs    the left-hand column group
     * @param result the output block
     * @param nRows  not used by this implementation
     * @throws DMLCompressionException if {@code lhs} is neither an {@link APreAgg} nor uncompressed
     */
    @Override
    public final void leftMultByAColGroup(AColGroup lhs, MatrixBlock result, int nRows) {
        if (lhs instanceof APreAgg) {
            leftMultByColGroupValue((APreAgg) lhs, result);
        } else if (lhs instanceof ColGroupUncompressed) {
            leftMultByUncompressedColGroup((ColGroupUncompressed) lhs, result);
        } else {
            throw new DMLCompressionException("Not supported left multiplication with A ColGroup of type: "
                + lhs.getClass().getSimpleName());
        }
    }

    /**
     * Pre-aggregates {@code that} group onto this group's index structure. The result holds
     * {@code getNumValues() * that._colIndexes.size()} cells, exposed as a dictionary with
     * {@code that._colIndexes.size()} columns.
     *
     * @param that the group whose values are aggregated
     * @return the pre-aggregated dictionary, or {@code null} if it would contain no cells
     * @throws NotImplementedException if the output would exceed {@link Integer#MAX_VALUE} cells
     * @throws DMLRuntimeException     if {@code that} has an unsupported index structure
     */
    public final IDictionary preAggregateThatIndexStructure(APreAgg that) {
        final long outputLength = (long) that._colIndexes.size() * (long) getNumValues();
        if (outputLength > Integer.MAX_VALUE) {
            throw new NotImplementedException("Not supported pre aggregate of above integer length");
        }
        if (outputLength <= 0L) {
            return null;
        }
        final Dictionary ret = Dictionary.createNoCheck(new double[(int) outputLength]);
        if (that instanceof ColGroupDDC) {
            preAggregateThatDDCStructure((ColGroupDDC) that, ret);
        } else if (that instanceof ColGroupSDCSingleZeros) {
            preAggregateThatSDCSingleZerosStructure((ColGroupSDCSingleZeros) that, ret);
        } else if (that instanceof ColGroupSDCZeros) {
            preAggregateThatSDCZerosStructure((ColGroupSDCZeros) that, ret);
        } else if (that instanceof ColGroupRLE) {
            preAggregateThatRLEStructure((ColGroupRLE) that, ret);
        } else {
            throw new DMLRuntimeException("Not supported pre aggregate using index structure of :"
                + that.getClass().getSimpleName() + " in " + getClass().getSimpleName());
        }
        return ret.getMBDict(that._colIndexes.size());
    }

    /**
     * Pre-aggregates rows {@code [rl, ru)} of {@code m} into {@code preAgg}, dispatching on the storage format of
     * {@code m}. In the dense case all columns of {@code m} are used.
     *
     * @param m      the input matrix
     * @param preAgg the output pre-aggregate buffer
     * @param rl     first row, inclusive
     * @param ru     last row, exclusive
     */
    public final void preAggregate(MatrixBlock m, double[] preAgg, int rl, int ru) {
        if (m.isInSparseFormat()) {
            preAggregateSparse(m.getSparseBlock(), preAgg, rl, ru);
        } else {
            preAggregateDense(m, preAgg, rl, ru, 0, m.getNumColumns());
        }
    }

    /**
     * Dense pre-aggregation of rows {@code [rl, ru)} and columns {@code [cl, cu)} of {@code m} into {@code preAgg}.
     * NOTE(review): presumably {@code preAgg} is row-major with {@link #getPreAggregateSize()} cells per row — confirm
     * in implementations.
     */
    public abstract void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, int cl, int cu);

    /** Sparse pre-aggregation of rows {@code [rl, ru)} of {@code sb} into {@code preAgg}. */
    public abstract void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru);

    /** Pre-aggregates a DDC group {@code that} onto this group's index structure, writing into {@code ret}. */
    protected abstract void preAggregateThatDDCStructure(ColGroupDDC that, Dictionary ret);

    /** Pre-aggregates an SDC-zeros group {@code that} onto this group's index structure, writing into {@code ret}. */
    protected abstract void preAggregateThatSDCZerosStructure(ColGroupSDCZeros that, Dictionary ret);

    /**
     * Pre-aggregates an SDC-single-zeros group {@code that} onto this group's index structure, writing into
     * {@code ret}.
     */
    protected abstract void preAggregateThatSDCSingleZerosStructure(ColGroupSDCSingleZeros that, Dictionary ret);

    /** Pre-aggregates an RLE group {@code that} onto this group's index structure, writing into {@code ret}. */
    protected abstract void preAggregateThatRLEStructure(ColGroupRLE that, Dictionary ret);

    /** @return the number of pre-aggregate cells per input row, i.e. the number of distinct values. */
    public int getPreAggregateSize() {
        return getNumValues();
    }

    /**
     * tsmm between two pre-aggregating groups. With a shared index structure the two dictionaries are combined
     * directly, scaled by the value counts; otherwise one side is pre-aggregated onto the other's index structure.
     */
    private void tsmmAPreAgg(APreAgg lg, MatrixBlock result) {
        final IColIndex rightIdx = _colIndexes;
        final IColIndex leftIdx = lg._colIndexes;
        if (sameIndexStructure(lg)) {
            DictLibMatrixMult.TSMMToUpperTriangleScaling(lg._dict, _dict, leftIdx, rightIdx, getCounts(), result);
            return;
        }

        final boolean left = shouldPreAggregateLeft(lg);
        if (!loggedWarningForDirect && shouldDirectMultiply(lg, leftIdx.size(), rightIdx.size(), left)) {
            loggedWarningForDirect = true;
            LOG.warn("Not implemented direct tsmm colgroup: " + lg.getClass().getSimpleName() + " %*% "
                + getClass().getSimpleName());
        }

        if (left) {
            final IDictionary lpa = preAggregateThatIndexStructure(lg);
            if (lpa != null) {
                DictLibMatrixMult.TSMMToUpperTriangle(lpa, _dict, leftIdx, rightIdx, result);
            }
        } else {
            final IDictionary rpa = lg.preAggregateThatIndexStructure(this);
            if (rpa != null) {
                DictLibMatrixMult.TSMMToUpperTriangle(lg._dict, rpa, leftIdx, rightIdx, result);
            }
        }
    }

    /**
     * Cost model deciding whether a direct (non-pre-aggregated) multiplication would be cheaper. All terms are
     * computed in {@code long}; the previous version multiplied in {@code int} first and could overflow.
     *
     * @param lg         the left group
     * @param nColL      number of columns of the left group
     * @param nColR      number of columns of this (right) group
     * @param leftPreAgg whether pre-aggregation uses this group's values (see {@link #tsmmAPreAgg})
     * @return {@code true} if the estimated direct FLOPs are below the pre-aggregation FLOPs
     */
    private boolean shouldDirectMultiply(APreAgg lg, int nColL, int nColR, boolean leftPreAgg) {
        final long commonDim = Math.min(lg.numRowsToMultiply(), numRowsToMultiply());
        final long directFLOPS = commonDim * nColL * nColR * 2L;

        // the pre-aggregate has nColPreAgg columns for each distinct value of the group whose structure is used
        final long nVal = leftPreAgg ? getNumValues() : lg.getNumValues();
        final long nColPreAgg = leftPreAgg ? nColL : nColR;
        final long preAggFLOPS = nColPreAgg * nVal // size of the pre-aggregate
            + nColPreAgg * commonDim // pre-aggregation pass over the common rows
            + (long) nColR * nColL * nVal; // pre-aggregate times dictionary
        return directFLOPS < preAggFLOPS;
    }

    /**
     * Left multiplication {@code t(lhs) %*% this} between two pre-aggregating groups, added into {@code result}.
     */
    private void leftMultByColGroupValue(APreAgg lhs, MatrixBlock result) {
        final IColIndex rightIdx = _colIndexes;
        final IColIndex leftIdx = lhs._colIndexes;
        final IDictionary rDict = _dict;
        final IDictionary lDict = lhs._dict;
        final boolean sameIdx = sameIndexStructure(lhs);
        if (sameIdx && rDict == lDict) {
            // same structure and the very same dictionary: a scaled self product of the dictionary
            DictLibMatrixMult.TSMMDictionaryWithScaling(rDict, getCounts(), leftIdx, rightIdx, result);
        } else if (sameIdx) {
            DictLibMatrixMult.MMDictsWithScaling(lDict, rDict, leftIdx, rightIdx, result, getCounts());
        } else if (shouldPreAggregateLeft(lhs)) {
            final IDictionary lhsPA = lhs.preAggregateThatIndexStructure(this);
            if (lhsPA != null) {
                DictLibMatrixMult.MMDicts(lDict, lhsPA, leftIdx, rightIdx, result);
            }
        } else {
            final IDictionary rhsPA = preAggregateThatIndexStructure(lhs);
            if (rhsPA != null) {
                DictLibMatrixMult.MMDicts(rhsPA, rDict, leftIdx, rightIdx, result);
            }
        }
    }

    /**
     * Left multiplication by an uncompressed group: transposes its data, pre-aggregates the transposed rows onto
     * this group's values, multiplies with the dictionary and adds the product into {@code result}.
     */
    private void leftMultByUncompressedColGroup(ColGroupUncompressed lhs, MatrixBlock result) {
        if (lhs.getNumCols() != 1) {
            LOG.warn("Transpose of uncompressed to fit to template need t(a) %*% b");
        }
        final MatrixBlock tmp = LibMatrixReorg.transpose(lhs.getData(), InfrastructureAnalyzer.getLocalParallelism());
        final int numVals = getNumValues();
        final MatrixBlock preAgg = new MatrixBlock(tmp.getNumRows(), numVals, false);
        preAgg.allocateDenseBlock();
        preAggregate(tmp, preAgg.getDenseBlockValues(), 0, tmp.getNumRows());
        preAgg.recomputeNonZeros();
        final MatrixBlock dictM = _dict.getMBDict(getNumCols()).getMatrixBlock();
        if (dictM != null) {
            final MatrixBlock tmpRes = new MatrixBlock(preAgg.getNumRows(), _colIndexes.size(), false);
            LibMatrixMult.matrixMult(preAgg, dictM, tmpRes);
            addMatrixToResult(tmpRes, result, lhs._colIndexes);
        }
    }

    /**
     * Adds {@code tmp} into the dense {@code result}: row {@code r} of {@code tmp} goes to output row
     * {@code rowIndexes.get(r)}, and column {@code c} to output column {@code _colIndexes.get(c)}.
     */
    private void addMatrixToResult(MatrixBlock tmp, MatrixBlock result, IColIndex rowIndexes) {
        if (tmp.isEmpty()) {
            return;
        }
        final double[] retV = result.getDenseBlockValues();
        final int nColRet = result.getNumColumns();
        final int nRowsTmp = rowIndexes.size();
        if (tmp.isInSparseFormat()) {
            final SparseBlock sb = tmp.getSparseBlock();
            for (int row = 0; row < nRowsTmp; row++) {
                if (sb.isEmpty(row)) {
                    continue;
                }
                final int apos = sb.pos(row);
                final int aend = apos + sb.size(row);
                final int[] aix = sb.indexes(row);
                final double[] avals = sb.values(row);
                final int offR = rowIndexes.get(row) * nColRet;
                for (int i = apos; i < aend; i++) {
                    retV[offR + _colIndexes.get(aix[i])] += avals[i];
                }
            }
        } else {
            final double[] tmpV = tmp.getDenseBlockValues();
            final int nCol = _colIndexes.size();
            for (int row = 0, offT = 0; row < nRowsTmp; row++, offT += nCol) {
                final int offR = rowIndexes.get(row) * nColRet;
                for (int col = 0; col < nCol; col++) {
                    retV[offR + _colIndexes.get(col)] += tmpV[offT + col];
                }
            }
        }
    }

    /**
     * tsmm with an uncompressed group: multiplies the transposed uncompressed data with this group (without
     * pre-aggregation) into a dense temporary, then adds the relevant cells into the upper triangle of
     * {@code result}.
     */
    private void tsmmColGroupUncompressed(ColGroupUncompressed other, MatrixBlock result) {
        LOG.warn("Inefficient multiplication with uncompressed column group");
        final int nCols = result.getNumColumns();
        final MatrixBlock otherMBT = LibMatrixReorg.transpose(other.getData());
        final int nRows = otherMBT.getNumRows();
        final MatrixBlock tmp = new MatrixBlock(nRows, nCols, false);
        tmp.allocateDenseBlock();
        leftMultByMatrixNoPreAgg(otherMBT, tmp, 0, nRows, 0, otherMBT.getNumColumns());
        final double[] r = tmp.getDenseBlockValues();
        final double[] resV = result.getDenseBlockValues();
        final int otLen = other._colIndexes.size();
        final int thisLen = _colIndexes.size();
        for (int i = 0; i < otLen; i++) {
            final int oid = other._colIndexes.get(i);
            final int offR = i * nCols;
            for (int j = 0; j < thisLen; j++) {
                final int tid = _colIndexes.get(j);
                DictLibMatrixMult.addToUpperTriangle(nCols, oid, tid, resV, r[offR + tid]);
            }
        }
    }

    /**
     * Compares the dense dictionary sizes (values times columns) of both groups; computed in {@code long} so large
     * dictionaries cannot overflow the product.
     *
     * @return {@code true} if this (right) group's dictionary is strictly smaller than {@code lhs}'s
     */
    private boolean shouldPreAggregateLeft(APreAgg lhs) {
        final long costLeftDense = (long) lhs.getNumValues() * lhs._colIndexes.size();
        final long costRightDense = (long) getNumValues() * _colIndexes.size();
        return costRightDense < costLeftDense;
    }

    /**
     * Multiplies a pre-aggregate with this group's dictionary and adds the product into {@code ret} for rows
     * {@code [rl, ru)}. Works on copies of {@code preAgg} and {@code tmpRes}, so the caller's blocks are not passed
     * to the multiplication.
     *
     * @param preAgg the pre-aggregated input
     * @param tmpRes a temporary output block (copied before use)
     * @param ret    the output block to add into
     * @param k      the parallelization degree passed to the multiplication
     * @param rl     first row, inclusive
     * @param ru     last row (bound as interpreted by {@code ColGroupUtils.addMatrixToResult})
     */
    public void mmWithDictionary(MatrixBlock preAgg, MatrixBlock tmpRes, MatrixBlock ret, int k, int rl, int ru) {
        final MatrixBlock preAggCopy = new MatrixBlock();
        preAggCopy.copy(preAgg);
        final MatrixBlock tmpResCopy = new MatrixBlock();
        tmpResCopy.copy(tmpRes);
        final MatrixBlock dict = getDictionary().getMBDict(_colIndexes.size()).getMatrixBlock();
        if (dict != null) {
            LibMatrixMult.matrixMult(preAggCopy, dict, tmpResCopy, k);
            ColGroupUtils.addMatrixToResult(tmpResCopy, ret, _colIndexes, rl, ru);
        }
    }

    /**
     * @return the number of rows this group contributes to a multiplication; used as the common dimension in
     *         {@link #shouldDirectMultiply}. NOTE(review): exact semantics are defined by subclasses — confirm.
     */
    protected abstract int numRowsToMultiply();
}

