/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

public class CLALibLeftMultBy {
    private static final Log LOG = LogFactory.getLog((String)CLALibLeftMultBy.class.getName());

    /**
     * Left-multiplies the compressed matrix {@code right} by the transpose of the uncompressed matrix
     * {@code left}. The transpose is materialized and the work is delegated to
     * {@code leftMultByMatrix(CompressedMatrixBlock, MatrixBlock, MatrixBlock, int)}.
     *
     * @param right compressed right-hand side
     * @param left  uncompressed matrix whose transpose is used as the left-hand side
     * @param ret   optional output block; may be {@code null}, in which case one is created
     * @param k     parallelization degree forwarded to the transpose and the multiplication
     * @return the output block with {@code left.getNumColumns()} rows and {@code right.getNumColumns()}
     *         columns; never {@code null}
     */
    public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock right, MatrixBlock left, MatrixBlock ret, int k) {
        if (left.isEmpty()) {
            // Return a correctly sized, allocated output rather than the caller's argument as-is
            // (which may be null or have the wrong shape); this mirrors leftMultByMatrix.
            return prepareReturnMatrix(right, left, ret, true);
        }
        final MatrixBlock transposed = new MatrixBlock(left.getNumColumns(), left.getNumRows(), false);
        LibMatrixReorg.transpose(left, transposed, k);
        ret = leftMultByMatrix(right, transposed, ret, k);
        ret.recomputeNonZeros();
        return ret;
    }

    /**
     * Left-multiplies the compressed matrix {@code right} by the transpose of the compressed matrix
     * {@code left}, working directly on the column groups of both inputs.
     *
     * @param right compressed right-hand side
     * @param left  compressed matrix whose transpose is used as the left-hand side
     * @param ret   optional output block; created or reset by {@code prepareReturnMatrix} when needed
     * @param k     parallelization degree (forwarded, but not used by the current implementation)
     * @return the dense output block with its non-zero count recomputed
     */
    public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock right, CompressedMatrixBlock left, MatrixBlock ret, int k) {
        final MatrixBlock out = prepareReturnMatrix(right, left, ret, true);
        leftMultByCompressedTransposedMatrix(right, left, out, k);
        out.recomputeNonZeros();
        return out;
    }

    /**
     * Left-multiplies the compressed matrix {@code m1} by the uncompressed matrix {@code m2}.
     *
     * @param m1  compressed right-hand side
     * @param m2  uncompressed left-hand side
     * @param ret optional output block; created or reset by {@code prepareReturnMatrix} when needed
     * @param k   parallelization degree
     * @return the prepared output block; it is returned without further computation when {@code m2} is empty
     */
    public static MatrixBlock leftMultByMatrix(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
        final MatrixBlock prepared = prepareReturnMatrix(m1, m2, ret, false);
        if (!m2.isEmpty()) {
            LOG.trace("LeftMultByMatrix Execution");
            final MatrixBlock product = leftMultByMatrix(m1.getColGroups(), m2, prepared, k, m1.isOverlapping());
            product.recomputeNonZeros();
            return product;
        }
        // Nothing to multiply: hand back the prepared block untouched.
        return prepared;
    }

    /**
     * Accumulates the transpose-self product of the compressed matrix into {@code result}.
     *
     * <p>The column groups are first passed through {@code CLALibUtils.filterGroups}, which fills
     * {@code constV} when SDC/const groups are present. The filtered groups are multiplied via
     * {@code tsmmColGroups}; afterwards the upper triangle is corrected by adding
     * {@code constV[i] * colSum[j]} and {@code (colSum[i] + constV[i] * numRows) * constV[j]}.
     * Finally the upper triangle is copied to the lower one.
     *
     * @param cmb    compressed input matrix
     * @param result output block; its dense values are updated in place
     * @param k      parallelization degree (not used by the current implementation)
     */
    public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock result, int k) {
        final List<AColGroup> allGroups = cmb.getColGroups();
        final int nCols = cmb.getNumColumns();
        final int nRows = cmb.getNumRows();

        final boolean hasConstPart = CLALibUtils.containsSDCOrConst(allGroups);
        final double[] constV = hasConstPart ? new double[nCols] : null;
        final List<AColGroup> remaining = CLALibUtils.filterGroups(allGroups, constV);
        final double[] colSums = getColSum(remaining, nCols, nRows, hasConstPart);

        tsmmColGroups(remaining, result, nRows);

        final double[] out = result.getDenseBlockValues();
        if (constV != null) {
            outerProductUpperTriangle(constV, colSums, out);
            // Fold the const-times-const term into the column sums before the second pass.
            for (int c = 0; c < colSums.length; c++) {
                colSums[c] += constV[c] * (double) nRows;
            }
            outerProductUpperTriangle(colSums, constV, out);
        }

        result.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(result));
        result.examSparsity();
    }

    /**
     * Creates or re-initializes the dense output block for a left multiplication.
     *
     * <p>The output has {@code m1.getNumColumns()} columns and as many rows as {@code m2} has rows
     * (or columns, when {@code doTranspose} is set). An existing {@code ret} is only reset when its
     * shape differs or it is not allocated; otherwise it is reused as-is.
     * NOTE(review): a reused block is not cleared here — confirm callers never pass stale contents.
     *
     * @param m1          right-hand side, determines the number of output columns
     * @param m2          left-hand side, determines the number of output rows
     * @param ret         optional existing output block, may be {@code null}
     * @param doTranspose whether {@code m2} is to be treated as transposed
     * @return the output block with its dense block allocated
     */
    private static MatrixBlock prepareReturnMatrix(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean doTranspose) {
        final int numRowsOutput = doTranspose ? m2.getNumColumns() : m2.getNumRows();
        final int numColumnsOutput = m1.getNumColumns();
        // Widen before multiplying: the int product overflows for outputs with more than 2^31-1 cells.
        final long numCells = (long) numRowsOutput * numColumnsOutput;
        if (ret == null) {
            ret = new MatrixBlock(numRowsOutput, numColumnsOutput, false, numCells);
        } else if (ret.getNumColumns() != numColumnsOutput || ret.getNumRows() != numRowsOutput || !ret.isAllocated()) {
            ret.reset(numRowsOutput, numColumnsOutput, false, numCells);
        }
        ret.allocateDenseBlock();
        return ret;
    }

    /**
     * Multiplies the transpose of the compressed matrix {@code left} with the compressed matrix
     * {@code right}, accumulating into the dense block of {@code ret}
     * ({@code left.getNumColumns()} rows by {@code right.getNumColumns()} columns).
     *
     * <p>Both inputs are passed through {@code CLALibUtils.filterGroups}, which fills the per-column
     * constant vectors {@code cL} / {@code cR} when SDC/const groups are present. This method assumes
     * each input equals its filtered groups plus that constant added to every row, the same
     * assumption {@code leftMultByTransposeSelf} in this class relies on. Under it the product is
     * <pre>
     *   t(fL) * fR  +  cL * colSum(fR)  +  colSum(fL) * cR  +  sd * cL * cR
     * </pre>
     * where {@code sd} is the shared row dimension and the last three terms are outer products.
     *
     * @param right compressed right-hand side
     * @param left  compressed matrix whose transpose is the left-hand side
     * @param ret   dense, allocated output block that is accumulated into
     * @param k     parallelization degree (not used by the current implementation)
     * @return {@code ret}
     */
    private static MatrixBlock leftMultByCompressedTransposedMatrix(CompressedMatrixBlock right, CompressedMatrixBlock left, MatrixBlock ret, int k) {
        final int sd = right.getNumRows();
        final int cr = right.getNumColumns();
        final int rl = left.getNumColumns();
        final List<AColGroup> rightCG = right.getColGroups();
        final List<AColGroup> leftCG = left.getColGroups();

        final boolean containsRight = CLALibUtils.containsSDCOrConst(rightCG);
        final double[] cR = containsRight ? new double[cr] : null;
        final List<AColGroup> fRight = CLALibUtils.filterGroups(rightCG, cR);

        final boolean containsLeft = CLALibUtils.containsSDCOrConst(leftCG);
        final double[] cL = containsLeft ? new double[rl] : null;
        final List<AColGroup> fLeft = CLALibUtils.filterGroups(leftCG, cL);

        // Column sums of one side are only needed when the other side has a constant part.
        final double[] fRightSum = getColSum(fRight, cr, sd, containsLeft);
        final double[] fLeftSum = getColSum(fLeft, rl, sd, containsRight);

        for (int i = 0; i < fRight.size(); ++i) {
            for (int j = 0; j < fLeft.size(); ++j) {
                fRight.get(i).leftMultByAColGroup(fLeft.get(j), ret);
            }
        }

        final double[] retV = ret.getDenseBlockValues();
        if (containsLeft) {
            if (containsRight) {
                // Fold the const-times-const term (sd * cL[i] * cR[j]) into the right column sums,
                // as leftMultByTransposeSelf does for the self-product.
                for (int c = 0; c < cr; ++c) {
                    fRightSum[c] += cR[c] * sd;
                }
            }
            outerProduct(cL, fRightSum, retV);
        }
        if (containsRight) {
            // Rows of the output follow the left columns and columns follow the right columns, so the
            // left column sums must be the row factor here.
            outerProduct(fLeftSum, cR, retV);
        }
        return ret;
    }

    /**
     * Runs the transpose-self multiplication over all column groups: each group is multiplied with
     * itself via {@code tsmm} and then with every group that comes after it in the list via
     * {@code tsmmAColGroup}.
     *
     * @param filteredGroups column groups to process
     * @param ret            output block the groups accumulate into
     * @param nRows          row count forwarded to each group's {@code tsmm}
     */
    private static void tsmmColGroups(List<AColGroup> filteredGroups, MatrixBlock ret, int nRows) {
        final int nGroups = filteredGroups.size();
        for (int a = 0; a < nGroups; a++) {
            final AColGroup current = filteredGroups.get(a);
            current.tsmm(ret, nRows);
            // Only pair with later groups; earlier pairs were handled when those groups were current.
            for (int b = a + 1; b < nGroups; b++) {
                current.tsmmAColGroup(filteredGroups.get(b), ret);
            }
        }
    }

    /**
     * Adds the outer product {@code leftRowSum[r] * rightColumnSum[c]} to {@code result} for all
     * cells with {@code c >= r} (the upper triangle including the diagonal).
     *
     * @param leftRowSum     row factors, one per output row
     * @param rightColumnSum column factors; its length is also the row stride of {@code result}
     * @param result         row-major output array that is updated in place
     */
    private static void outerProductUpperTriangle(double[] leftRowSum, double[] rightColumnSum, double[] result) {
        final int stride = rightColumnSum.length;
        int rowStart = 0;
        for (int r = 0; r < leftRowSum.length; r++, rowStart += stride) {
            final double factor = leftRowSum[r];
            for (int c = r; c < stride; c++) {
                result[rowStart + c] += factor * rightColumnSum[c];
            }
        }
    }

    /**
     * Left-multiplies the given column groups by the uncompressed matrix {@code that}.
     *
     * <p>The groups are filtered through {@code CLALibUtils.filterGroups}; when that yields a
     * constant vector, the row sums of {@code that} are collected during the multiplication and the
     * outer product of row sums and constant vector is added to the result afterwards.
     *
     * @param colGroups   column groups of the compressed right-hand side
     * @param that        uncompressed left-hand side
     * @param ret         output block that is accumulated into
     * @param k           parallelization degree; {@code 1} selects the single-threaded path
     * @param overlapping whether the column groups overlap, forwarded to the parallel path
     * @return {@code ret} with its non-zero count recomputed (or set to 0 when {@code that} is empty)
     */
    private static MatrixBlock leftMultByMatrix(List<AColGroup> colGroups, MatrixBlock that, MatrixBlock ret, int k, boolean overlapping) {
        if (that.isEmpty()) {
            ret.setNonZeros(0L);
            return ret;
        }

        final int nColsOut = ret.getNumColumns();
        final boolean hasConst = CLALibUtils.containsSDCOrConst(colGroups);
        final int nRowsLeft = that.getNumRows();
        double[] constV = hasConst ? new double[nColsOut] : null;
        final List<AColGroup> remaining = CLALibUtils.filterGroups(colGroups, constV);
        if (remaining == colGroups) {
            // Filtering returned the input list itself, so no constant correction is applied.
            constV = null;
        }

        final double[] rowSums;
        if (remaining.isEmpty()) {
            rowSums = constV == null ? null : that.rowSum(k).getDenseBlockValues();
        } else if (k == 1) {
            rowSums = leftMultByMatrixPrimitive(remaining, that, ret, 0, nRowsLeft, hasConst ? new double[nRowsLeft] : null);
        } else {
            rowSums = leftMultByMatrixParallel(remaining, that, ret, hasConst, overlapping, k);
        }

        if (rowSums != null && constV != null) {
            ret.sparseToDense();
            outerProduct(rowSums, constV, ret.getDenseBlockValues());
        }
        ret.recomputeNonZeros();
        return ret;
    }

    /**
     * Multi-threaded left multiplication of the column groups by {@code that}.
     *
     * <p>Rows of {@code that} are cut into blocks of at most 8 rows. When there are fewer row blocks
     * than threads, the column groups are additionally divided by {@code split} so that several
     * tasks work on the same row block. If the groups overlap and there is more than one group,
     * each task writes into its own temporary block, which is then added into {@code ret};
     * otherwise all tasks write into {@code ret} directly.
     *
     * @param filteredGroups    column groups to multiply
     * @param that              uncompressed left-hand side
     * @param ret               output block
     * @param calculateRowSums  whether the row sums of {@code that} should be collected
     * @param overlapping       whether the column groups overlap
     * @param k                 number of threads requested from {@code CommonThreadPool}
     * @return the row sums of {@code that}, or {@code null} when {@code calculateRowSums} is false
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted
     */
    private static double[] leftMultByMatrixParallel(List<AColGroup> filteredGroups, MatrixBlock that, MatrixBlock ret, boolean calculateRowSums, boolean overlapping, int k) {
        LOG.debug("Parallel left matrix multiplication");
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<LeftMatrixColGroupMultTask> tasks = new ArrayList<>();
            final int rl = that.getNumRows();
            final int rowBlockSize = rl <= k ? 1 : Math.min(Math.max(rl / k * 2, 1), 8);
            final double[] rowSums = calculateRowSums ? new double[rl] : null;
            final int numberSplits = Math.max(k / (rl / rowBlockSize), 1);
            if (numberSplits == 1) {
                for (int blo = 0; blo < rl; blo += rowBlockSize) {
                    tasks.add(new LeftMatrixColGroupMultTask(filteredGroups, that, ret, blo, Math.min(blo + rowBlockSize, rl), rowSums));
                }
                for (Future<MatrixBlock> future : pool.invokeAll(tasks)) {
                    future.get();
                }
            } else {
                final List<List<AColGroup>> partitions = split(filteredGroups, numberSplits);
                final boolean useTmp = overlapping && filteredGroups.size() > 1;
                for (int blo = 0; blo < rl; blo += rowBlockSize) {
                    final int start = blo;
                    final int end = Math.min(blo + rowBlockSize, rl);
                    for (int i = 0; i < partitions.size(); ++i) {
                        final List<AColGroup> gr = partitions.get(i);
                        final MatrixBlock tmpRet = useTmp ? new MatrixBlock(rl, ret.getNumColumns(), false) : ret;
                        if (tmpRet.getDenseBlock() == null) {
                            tmpRet.allocateDenseBlock();
                        }
                        // Only the first partition of each row block collects the row sums, so that
                        // every row of 'that' is summed exactly once.
                        tasks.add(new LeftMatrixColGroupMultTask(gr, that, tmpRet, start, end, i == 0 ? rowSums : null));
                    }
                }
                if (useTmp) {
                    final BinaryOperator op = new BinaryOperator(Plus.getPlusFnObject());
                    for (Future<MatrixBlock> future : pool.invokeAll(tasks)) {
                        ret.binaryOperationsInPlace(op, future.get());
                    }
                } else {
                    for (Future<MatrixBlock> future : pool.invokeAll(tasks)) {
                        future.get();
                    }
                }
            }
            return rowSums;
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can observe the interruption.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            // Always release the pool, including when a task failed.
            pool.shutdown();
        }
    }

    /**
     * Distributes the column groups round-robin over {@code splits} lists, after ordering them by
     * descending {@code getNumValues()}.
     *
     * <p>The ordering is done on a private copy: the list passed in may be the compressed matrix's
     * own group list, which must not be reordered as a side effect of a multiplication.
     *
     * @param groups column groups to distribute; left unmodified
     * @param splits number of lists to create
     * @return {@code splits} lists that together contain every group exactly once
     */
    private static List<List<AColGroup>> split(List<AColGroup> groups, int splits) {
        final List<AColGroup> ordered = new ArrayList<>(groups);
        ordered.sort(Comparator.comparing(AColGroup::getNumValues).reversed());

        final List<List<AColGroup>> ret = new ArrayList<>(splits);
        for (int i = 0; i < splits; ++i) {
            ret.add(new ArrayList<>());
        }
        for (int j = 0; j < ordered.size(); ++j) {
            ret.get(j % splits).add(ordered.get(j));
        }
        return ret;
    }

    /**
     * Adds the full outer product {@code leftRowSum[r] * rightColumnSum[c]} to {@code result}.
     *
     * @param leftRowSum     row factors, one per output row
     * @param rightColumnSum column factors; its length is also the row stride of {@code result}
     * @param result         row-major output array that is updated in place
     */
    private static void outerProduct(double[] leftRowSum, double[] rightColumnSum, double[] result) {
        final int stride = rightColumnSum.length;
        int rowStart = 0;
        for (int r = 0; r < leftRowSum.length; r++, rowStart += stride) {
            final double factor = leftRowSum[r];
            for (int c = 0; c < stride; c++) {
                result[rowStart + c] += factor * rightColumnSum[c];
            }
        }
    }

    /**
     * Single-threaded left multiplication of the column groups by rows {@code [rl, ru)} of
     * {@code that}, dispatching on whether {@code that} is in sparse or dense format.
     *
     * @param colGroups column groups to multiply
     * @param that      uncompressed left-hand side
     * @param ret       output block that is accumulated into
     * @param rl        first row of {@code that} to process (inclusive)
     * @param ru        last row of {@code that} to process (exclusive)
     * @param rowSums   optional array receiving the row sums of {@code that}; may be {@code null}
     * @return {@code rowSums}, as passed in
     */
    private static double[] leftMultByMatrixPrimitive(List<AColGroup> colGroups, MatrixBlock that, MatrixBlock ret, int rl, int ru, double[] rowSums) {
        if (that.isInSparseFormat()) {
            leftMultByMatrixPrimitiveSparse(colGroups, that, ret, rl, ru, rowSums);
        } else {
            leftMultByMatrixPrimitiveDense(colGroups, that, ret, rl, ru, rowSums);
        }
        // Mark the output as fully dense; widen first so rows * cols cannot overflow int.
        ret.setNonZeros((long) ret.getNumRows() * ret.getNumColumns());
        return rowSums;
    }

    /**
     * Sparse-input variant: processes rows {@code [rl, ru)} of {@code that} one at a time, letting
     * every column group multiply the single row, and optionally accumulates that row's sum from the
     * sparse block.
     *
     * @param colGroups column groups to multiply
     * @param that      left-hand side in sparse format
     * @param ret       output block that is accumulated into
     * @param rl        first row to process (inclusive)
     * @param ru        last row to process (exclusive)
     * @param rowSum    optional array receiving the row sums of {@code that}; may be {@code null}
     */
    private static void leftMultByMatrixPrimitiveSparse(List<AColGroup> colGroups, MatrixBlock that, MatrixBlock ret, int rl, int ru, double[] rowSum) {
        for (int row = rl; row < ru; row++) {
            for (int g = 0; g < colGroups.size(); g++) {
                colGroups.get(g).leftMultByMatrix(that, ret, row, row + 1);
            }
            if (rowSum != null) {
                final SparseBlock sb = that.getSparseBlock();
                if (!sb.isEmpty(row)) {
                    final int begin = sb.pos(row);
                    final int end = begin + sb.size(row);
                    final double[] values = sb.values(row);
                    for (int p = begin; p < end; p++) {
                        rowSum[row] += values[p];
                    }
                }
            }
        }
    }

    /**
     * Dense-input variant of the left multiplication for rows {@code [rl, ru)} of {@code that}.
     *
     * <p>Groups that are not {@code ColGroupValue} instances are multiplied directly by
     * {@code preFilterAndMultiply}. The remaining value groups are processed in blocks of
     * {@code colGroupBlocking} groups: for every row, each group pre-aggregates the row in column
     * chunks of 32000, multiplies the pre-aggregate via {@code leftMultByPreAggregateMatrix} and adds
     * the outcome to {@code ret}.
     *
     * <p>Row sums are accumulated once, after the multiplication. Accumulating them inside the
     * group-block loop would add each row's sum once per block whenever there are more value groups
     * than {@code colGroupBlocking}.
     *
     * @param colGroups column groups to multiply
     * @param that      left-hand side in dense format
     * @param ret       output block that is accumulated into
     * @param rl        first row to process (inclusive)
     * @param ru        last row to process (exclusive)
     * @param rowSum    optional array receiving the row sums of {@code that}; may be {@code null}
     */
    private static void leftMultByMatrixPrimitiveDense(List<AColGroup> colGroups, MatrixBlock that, MatrixBlock ret, int rl, int ru, double[] rowSum) {
        final int numColsOut = ret.getNumColumns();
        final List<ColGroupValue> valueGroups = preFilterAndMultiply(colGroups, that, ret, rl, ru);
        final int numGroups = valueGroups.size();
        // Rows are processed one at a time.
        final int rowBlockSize = 1;
        // Picks 20 instead of 16 when 16 would leave fewer than 4 groups in the last block.
        final int colGroupBlocking = numGroups % 16 < 4 ? 20 : 16;
        final int colBlockSize = 32000;
        final MatrixBlock[] preAgg = populatePreAggregate(colGroupBlocking);
        final MatrixBlock tmpRes = new MatrixBlock(rowBlockSize, numColsOut, false);
        final int lc = that.getNumColumns();

        for (int g = 0; g < numGroups; g += colGroupBlocking) {
            final int gEnd = Math.min(g + colGroupBlocking, numGroups);
            // Size one pre-aggregate buffer per group in this block.
            for (int j = g; j < gEnd; ++j) {
                preAgg[j % colGroupBlocking].reset(rowBlockSize, valueGroups.get(j).getNumValues(), false);
            }
            for (int h = rl; h < ru; h += rowBlockSize) {
                final int rowUpper = Math.min(h + rowBlockSize, ru);
                for (int i = 0; i < lc; i += colBlockSize) {
                    final int colUpper = Math.min(i + colBlockSize, lc);
                    for (int j = g; j < gEnd; ++j) {
                        valueGroups.get(j).preAggregateDense(that, preAgg[j % colGroupBlocking], h, rowUpper, i, colUpper);
                    }
                }
                for (int j = g; j < gEnd; ++j) {
                    final ColGroupValue vj = valueGroups.get(j);
                    final MatrixBlock preAggJ = preAgg[j % colGroupBlocking];
                    preAggJ.recomputeNonZeros();
                    tmpRes.reset(rowBlockSize, vj.getNumCols(), false);
                    final MatrixBlock tmp = vj.leftMultByPreAggregateMatrix(preAggJ, tmpRes);
                    vj.addMatrixToResult(tmp, ret, h, rowUpper);
                    // Clear the buffer for the next row.
                    preAggJ.reset();
                }
            }
        }

        if (rowSum != null) {
            addDenseRowSums(that, rl, ru, rowSum);
        }
    }

    /** Adds the sum of each dense row {@code r} in {@code [rl, ru)} of {@code that} to {@code rowSum[r]}. */
    private static void addDenseRowSums(MatrixBlock that, int rl, int ru, double[] rowSum) {
        final double[] thatV = that.getDenseBlockValues();
        final int lc = that.getNumColumns();
        for (int r = rl; r < ru; ++r) {
            final int rowOff = r * lc;
            final int rowEnd = rowOff + lc;
            for (int c = rowOff; c < rowEnd; ++c) {
                rowSum[r] += thatV[c];
            }
        }
    }

    /**
     * Creates the reusable pre-aggregate buffers: {@code colGroupBlocking} dense 1x1 blocks with
     * their dense storage allocated.
     *
     * @param colGroupBlocking number of buffers to create
     * @return array of allocated blocks
     */
    private static MatrixBlock[] populatePreAggregate(int colGroupBlocking) {
        final MatrixBlock[] buffers = new MatrixBlock[colGroupBlocking];
        for (int idx = 0; idx < buffers.length; idx++) {
            buffers[idx] = new MatrixBlock(1, 1, false);
            buffers[idx].allocateDenseBlock();
        }
        return buffers;
    }

    /**
     * Separates the column groups into {@code ColGroupValue} instances and all others. The others
     * are multiplied immediately via their own {@code leftMultByMatrix}; the value groups are
     * collected and returned ordered by descending {@code getNumValues()}.
     *
     * @param colGroups column groups to inspect
     * @param that      left-hand side
     * @param ret       output block the non-value groups accumulate into
     * @param rl        first row to process (inclusive)
     * @param ru        last row to process (exclusive)
     * @return the value groups still to be multiplied
     */
    private static List<ColGroupValue> preFilterAndMultiply(List<AColGroup> colGroups, MatrixBlock that, MatrixBlock ret, int rl, int ru) {
        final List<ColGroupValue> valueGroups = new ArrayList<ColGroupValue>(colGroups.size());
        for (AColGroup group : colGroups) {
            if (!(group instanceof ColGroupValue)) {
                group.leftMultByMatrix(that, ret, rl, ru);
            } else {
                valueGroups.add((ColGroupValue) group);
            }
        }
        valueGroups.sort(Comparator.comparing(AColGroup::getNumValues).reversed());
        return valueGroups;
    }

    /**
     * Conditionally computes the column sums of the given groups via {@code AColGroup.colSum}.
     *
     * <p>Note the flag's meaning is the opposite of what its name suggests: the sums are computed
     * when {@code returnNull} is {@code true}, and {@code null} is returned when it is {@code false}.
     *
     * @param groups     column groups to sum
     * @param nCols      length of the freshly allocated sum vector
     * @param nRows      row count forwarded to {@code AColGroup.colSum}
     * @param returnNull {@code true} to compute the sums, {@code false} to skip and return {@code null}
     * @return the column sums, or {@code null} when skipped
     */
    private static double[] getColSum(List<AColGroup> groups, int nCols, int nRows, boolean returnNull) {
        if (!returnNull) {
            return null;
        }
        return AColGroup.colSum(groups, new double[nCols], nRows);
    }

    /**
     * Task that left-multiplies a list of column groups by rows {@code [rl, ru)} of a matrix and
     * returns the block it wrote into.
     */
    private static class LeftMatrixColGroupMultTask
    implements Callable<MatrixBlock> {
        private final List<AColGroup> _groups;
        private final MatrixBlock _that;
        private final MatrixBlock _ret;
        private final int _rl;
        private final int _ru;
        private final double[] _rowSums;

        /**
         * @param groups  column groups to multiply
         * @param that    uncompressed left-hand side
         * @param ret     output block this task accumulates into
         * @param rl      first row to process (inclusive)
         * @param ru      last row to process (exclusive)
         * @param rowSums optional array receiving the row sums of {@code that}; may be {@code null}
         */
        protected LeftMatrixColGroupMultTask(List<AColGroup> groups, MatrixBlock that, MatrixBlock ret, int rl, int ru, double[] rowSums) {
            this._groups = groups;
            this._that = that;
            this._ret = ret;
            this._rl = rl;
            this._ru = ru;
            this._rowSums = rowSums;
        }

        /**
         * Runs the multiplication for this task's row range.
         *
         * @return the output block passed to the constructor
         * @throws DMLRuntimeException wrapping any failure of the multiplication
         */
        @Override
        public MatrixBlock call() {
            try {
                CLALibLeftMultBy.leftMultByMatrixPrimitive(this._groups, this._that, this._ret, this._rl, this._ru, this._rowSums);
            }
            catch (Exception e) {
                // The cause travels with the rethrown exception and surfaces through Future.get();
                // no separate stderr dump is needed.
                throw new DMLRuntimeException(e);
            }
            return this._ret;
        }
    }
}

