/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.codegen.cplan;

import java.util.Arrays;
import java.util.Locale;
import org.apache.commons.lang.StringUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Code-generation plan node for a binary operation: combines exactly two input
 * {@link CNode}s according to a {@link BinType} and emits the corresponding code
 * fragment from the language-specific template class.
 */
public class CNodeBinary extends CNode {
    private final BinType _type;

    /**
     * Creates a binary node over {@code in1} and {@code in2}.
     *
     * <p>For commutative types, a scalar {@link CNodeData} given as the first operand
     * is swapped into the second position before the inputs are stored.
     *
     * @param in1  first (left) operand
     * @param in2  second (right) operand
     * @param type the binary operation
     */
    public CNodeBinary(CNode in1, CNode in2, BinType type) {
        if (type.isCommutative() && in1 instanceof CNodeData && in1.getDataType() == Types.DataType.SCALAR) {
            CNode tmp = in1;
            in1 = in2;
            in2 = tmp;
        }
        _inputs.add(in1);
        _inputs.add(in2);
        _type = type;
        // _type must be assigned before this call: output dims depend on it
        setOutputDims();
    }

    /** @return the binary operation type of this node */
    public BinType getType() {
        return _type;
    }

    /**
     * Emits the code of both inputs followed by this node's instantiated template.
     *
     * <p>The template placeholders are substituted as follows: {@code %TMP%} with this
     * node's variable name; for each input k in {1,2}, {@code %INkv%}, {@code %INki%},
     * {@code %INk%} and {@code %POSk%}; and finally either {@code %LEN1%}/{@code %LEN2%}
     * (outer-product add and vector-vector cbind) or {@code %LEN%} (taken from the
     * first matrix-typed input, if any).
     *
     * @param sparse whether sparse code is being generated
     * @param api    target generator API
     * @return the generated code, or {@code ""} if this node was already generated
     */
    @Override
    public String codegen(boolean sparse, SpoofCompiler.GeneratorAPI api) {
        if (isGenerated()) {
            return "";
        }
        CNode lhs = inputAt(0);
        CNode rhs = inputAt(1);

        StringBuilder sb = new StringBuilder();
        sb.append(lhs.codegen(sparse, api));
        sb.append(rhs.codegen(sparse, api));

        // a CNodeData input whose name starts with "a" selects the sparse template variant
        boolean lsparseLhs = sparse && lhs instanceof CNodeData && lhs.getVarname().startsWith("a");
        boolean lsparseRhs = sparse && rhs instanceof CNodeData && rhs.getVarname().startsWith("a");
        boolean scalarInput = lhs.getDataType().isScalar();
        boolean scalarVector = scalarInput && rhs.getDataType().isMatrix();
        boolean vectorVector = lhs.getDataType().isMatrix() && rhs.getDataType().isMatrix();

        String var = createVarname();
        String tmp = getLanguageTemplateClass(this, api)
            .getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector, scalarInput, vectorVector);
        tmp = tmp.replace("%TMP%", var);

        for (int j = 0; j < 2; ++j) {
            CNode in = inputAt(j);
            String varj = in.getVarname(api);
            String k = String.valueOf(j + 1);
            tmp = tmp.replace("%IN" + k + "v%", varj + "vals");
            tmp = tmp.replace("%IN" + k + "i%", varj + "ix");
            tmp = tmp.replace("%IN" + k + "%", operandExpression(in, varj, api));
            tmp = tmp.replace("%POS" + k + "%", operandPosition(in, varj));
        }

        if (_type == BinType.VECT_OUTERMULT_ADD || (_type == BinType.VECT_CBIND && vectorVector)) {
            for (int j = 0; j < 2; ++j) {
                tmp = tmp.replace("%LEN" + (j + 1) + "%", inputAt(j).getVectorLength(api));
            }
        } else {
            CNode mInput = getIntermediateInputVector();
            if (mInput != null) {
                tmp = tmp.replace("%LEN%", mInput.getVectorLength(api));
            }
        }

        sb.append(tmp);
        _generated = true;
        return sb.toString();
    }

    /** Returns the input at position {@code idx} typed as {@link CNode}. */
    private CNode inputAt(int idx) {
        return (CNode) _inputs.get(idx);
    }

    /**
     * Builds the replacement for {@code %INk%}.
     *
     * <p>Names starting with "b" become {@code .values(rix)} for JAVA. For other APIs they
     * become {@code .vals(0)}, except under VECT_MATRIXMULT where the bare name is used.
     * For every other name, a MATRIX-typed input becomes {@code .vals(0)} for non-JAVA APIs.
     * Otherwise the bare name is used.
     */
    private String operandExpression(CNode in, String varj, SpoofCompiler.GeneratorAPI api) {
        boolean javaApi = api == SpoofCompiler.GeneratorAPI.JAVA;
        if (varj.startsWith("b")) {
            if (javaApi) {
                return varj + ".values(rix)";
            }
            return _type == BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)";
        }
        // names starting with "a" and all remaining names follow the same rule
        return !javaApi && in.getDataType() == Types.DataType.MATRIX ? varj + ".vals(0)" : varj;
    }

    /**
     * Builds the replacement for {@code %POSk%}.
     *
     * <p>The result is {@code "0"} unless the input is a matrix-typed {@link CNodeData}.
     * In that case it is {@code <name>i} for names not starting with "b", or
     * {@code <name>.pos(rix)} for "b" names that are matrices (or column vectors under an
     * elementwise op), provided the op is not VECT_MATRIXMULT.
     */
    private String operandPosition(CNode in, String varj) {
        if (!(in instanceof CNodeData) || !in.getDataType().isMatrix()) {
            return "0";
        }
        if (!varj.startsWith("b")) {
            return varj + "i";
        }
        boolean usesRowPos = TemplateUtils.isMatrix(in)
            || (_type.isElementwise() && TemplateUtils.isColVector(in));
        return usesRowPos && _type != BinType.VECT_MATRIXMULT ? varj + ".pos(rix)" : "0";
    }

    /** Returns the first matrix-typed input, or {@code null} if both inputs are non-matrix. */
    private CNode getIntermediateInputVector() {
        for (int i = 0; i < 2; ++i) {
            CNode in = getInput().get(i);
            if (in.getDataType().isMatrix()) {
                return in;
            }
        }
        return null;
    }

    /**
     * Returns a short, human-readable label of the form {@code b(...)} for this operation.
     * Types without an explicit label fall back to their lower-cased enum name.
     */
    @Override
    public String toString() {
        switch (_type) {
            case DOT_PRODUCT:              return "b(dot)";
            case VECT_MATRIXMULT:          return "b(vmm)";
            case VECT_OUTERMULT_ADD:       return "b(voma)";
            case VECT_MULT_ADD:            return "b(vma)";
            case VECT_DIV_ADD:             return "b(vda)";
            case VECT_MINUS_ADD:           return "b(vmia)";
            case VECT_PLUS_ADD:            return "b(vpa)";
            case VECT_POW_ADD:             return "b(vpowa)";
            case VECT_MIN_ADD:             return "b(vmina)";
            case VECT_MAX_ADD:             return "b(vmaxa)";
            case VECT_EQUAL_ADD:           return "b(veqa)";
            case VECT_NOTEQUAL_ADD:        return "b(vneqa)";
            case VECT_LESS_ADD:            return "b(vlta)";
            case VECT_LESSEQUAL_ADD:       return "b(vltea)";
            case VECT_GREATEREQUAL_ADD:    return "b(vgtea)";
            case VECT_GREATER_ADD:         return "b(vgta)";
            case VECT_CBIND_ADD:           return "b(vcbinda)";
            case VECT_MULT_SCALAR:         return "b(vm)";
            case VECT_DIV_SCALAR:          return "b(vd)";
            case VECT_MINUS_SCALAR:        return "b(vmi)";
            case VECT_PLUS_SCALAR:         return "b(vp)";
            case VECT_XOR_SCALAR:          return "b(vxor)";
            case VECT_POW_SCALAR:          return "b(vpow)";
            case VECT_MIN_SCALAR:          return "b(vmin)";
            case VECT_MAX_SCALAR:          return "b(vmax)";
            case VECT_EQUAL_SCALAR:        return "b(veq)";
            case VECT_NOTEQUAL_SCALAR:     return "b(vneq)";
            case VECT_LESS_SCALAR:         return "b(vlt)";
            case VECT_LESSEQUAL_SCALAR:    return "b(vlte)";
            case VECT_GREATEREQUAL_SCALAR: return "b(vgte)";
            case VECT_GREATER_SCALAR:      return "b(vgt)";
            case VECT_MULT:                return "b(v2m)";
            case VECT_DIV:                 return "b(v2d)";
            case VECT_MINUS:               return "b(v2mi)";
            case VECT_PLUS:                return "b(v2p)";
            case VECT_XOR:                 return "b(v2xor)";
            case VECT_MIN:                 return "b(v2min)";
            case VECT_MAX:                 return "b(v2max)";
            case VECT_EQUAL:               return "b(v2eq)";
            case VECT_NOTEQUAL:            return "b(v2neq)";
            case VECT_LESS:                return "b(v2lt)";
            case VECT_LESSEQUAL:           return "b(v2lte)";
            case VECT_GREATEREQUAL:        return "b(v2gte)";
            case VECT_GREATER:             return "b(v2gt)";
            case VECT_CBIND:               return "b(cbind)";
            case VECT_BIASADD:             return "b(vbias+)";
            case VECT_BIASMULT:            return "b(vbias*)";
            case MULT:                     return "b(*)";
            case DIV:                      return "b(/)";
            case PLUS:                     return "b(+)";
            case MINUS:                    return "b(-)";
            case POW:                      return "b(^)";
            case MODULUS:                  return "b(%%)";
            case INTDIV:                   return "b(%/%)";
            case LESS:                     return "b(<)";
            case LESSEQUAL:                return "b(<=)";
            case GREATER:                  return "b(>)";
            case GREATEREQUAL:             return "b(>=)";
            case EQUAL:                    return "b(==)";
            case NOTEQUAL:                 return "b(!=)";
            case OR:                       return "b(|)";
            case AND:                      return "b(&)";
            case XOR:                      return "b(xor)";
            case BITWAND:                  return "b(bitwAnd)";
            case SEQ_RIX:                  return "b(seqr)";
            case MINUS1_MULT:              return "b(1-*)";
            case MINUS_NZ:                 return "b(-nz)";
            default:
                // Locale.ROOT keeps the label ASCII regardless of the JVM default locale
                return "b(" + _type.name().toLowerCase(Locale.ROOT) + ")";
        }
    }

    /**
     * Derives {@code _rows}, {@code _cols} and {@code _dataType} from the inputs and the
     * operation type.
     *
     * @throws RuntimeException if the type is not handled
     */
    @Override
    public void setOutputDims() {
        switch (_type) {
            case VECT_MULT_ADD:
            case VECT_DIV_ADD:
            case VECT_MINUS_ADD:
            case VECT_PLUS_ADD:
            case VECT_POW_ADD:
            case VECT_MIN_ADD:
            case VECT_MAX_ADD:
            case VECT_EQUAL_ADD:
            case VECT_NOTEQUAL_ADD:
            case VECT_LESS_ADD:
            case VECT_LESSEQUAL_ADD:
            case VECT_GREATEREQUAL_ADD:
            case VECT_GREATER_ADD:
            case VECT_CBIND_ADD:
            case VECT_XOR_ADD: {
                // dims come from the second input unless it is a scalar
                boolean vectorScalar = inputAt(1).getDataType() == Types.DataType.SCALAR;
                CNode src = inputAt(vectorScalar ? 0 : 1);
                setMatrixOutput(src._rows, src._cols);
                break;
            }
            case VECT_CBIND: {
                CNode left = inputAt(0);
                CNode right = inputAt(1);
                // vector-vector cbind (see codegen's %LEN1%/%LEN2%) appends all of the
                // right input's columns; every other form appends a single column
                boolean vectorVector = left.getDataType().isMatrix() && right.getDataType().isMatrix();
                long appended = vectorVector ? right._cols : 1L;
                setMatrixOutput(left._rows, left._cols + appended);
                break;
            }
            case VECT_OUTERMULT_ADD:
                setMatrixOutput(inputAt(0)._cols, inputAt(1)._cols);
                break;
            case VECT_MULT_SCALAR:
            case VECT_DIV_SCALAR:
            case VECT_MINUS_SCALAR:
            case VECT_PLUS_SCALAR:
            case VECT_XOR_SCALAR:
            case VECT_POW_SCALAR:
            case VECT_MIN_SCALAR:
            case VECT_MAX_SCALAR:
            case VECT_EQUAL_SCALAR:
            case VECT_NOTEQUAL_SCALAR:
            case VECT_LESS_SCALAR:
            case VECT_LESSEQUAL_SCALAR:
            case VECT_GREATEREQUAL_SCALAR:
            case VECT_GREATER_SCALAR:
            case VECT_MULT:
            case VECT_DIV:
            case VECT_MINUS:
            case VECT_PLUS:
            case VECT_XOR:
            case VECT_MIN:
            case VECT_MAX:
            case VECT_EQUAL:
            case VECT_NOTEQUAL:
            case VECT_LESS:
            case VECT_LESSEQUAL:
            case VECT_GREATEREQUAL:
            case VECT_GREATER:
            case VECT_BIASADD:
            case VECT_BIASMULT:
            case VECT_BITWAND_SCALAR:
            case VECT_BITWAND: {
                // dims come from the first input unless it is a scalar
                boolean scalarVector = inputAt(0).getDataType() == Types.DataType.SCALAR;
                CNode src = inputAt(scalarVector ? 1 : 0);
                setMatrixOutput(src._rows, src._cols);
                break;
            }
            case VECT_MATRIXMULT:
                setMatrixOutput(inputAt(0)._rows, inputAt(1)._cols);
                break;
            case DOT_PRODUCT:
            case MULT:
            case DIV:
            case PLUS:
            case MINUS:
            case POW:
            case MODULUS:
            case INTDIV:
            case LESS:
            case LESSEQUAL:
            case GREATER:
            case GREATEREQUAL:
            case EQUAL:
            case NOTEQUAL:
            case OR:
            case AND:
            case XOR:
            case BITWAND:
            case SEQ_RIX:
            case MINUS1_MULT:
            case MINUS_NZ:
            case ROWMAXS_VECTMULT:
            case MIN:
            case MAX:
            case LOG:
            case LOG_NZ:
                _rows = 0L;
                _cols = 0L;
                _dataType = Types.DataType.SCALAR;
                break;
            default:
                throw new RuntimeException("Unknown CNodeBinary type: " + _type);
        }
    }

    /** Marks this node as a matrix output with the given dimensions. */
    private void setMatrixOutput(long rows, long cols) {
        _rows = rows;
        _cols = cols;
        _dataType = Types.DataType.MATRIX;
    }

    /** Combines the base-class hash with the operation type; the result is cached in {@code _hash}. */
    @Override
    public int hashCode() {
        if (_hash == 0) {
            _hash = UtilFunctions.intHashCode(super.hashCode(), _type.hashCode());
        }
        return _hash;
    }

    /** Two binary nodes are equal if the base-class comparison holds and their types match. */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof CNodeBinary)) {
            return false;
        }
        CNodeBinary that = (CNodeBinary) o;
        return super.equals(that) && _type == that._type;
    }

    /**
     * Returns whether this node and all of its inputs can be generated for {@code api}.
     * Only JAVA and CUDA are accepted. For CUDA, the type must also not be flagged by
     * {@link BinType#isNotSupportedBySpoofCUDA()}.
     */
    @Override
    public boolean isSupported(SpoofCompiler.GeneratorAPI api) {
        boolean supported = api == SpoofCompiler.GeneratorAPI.CUDA || api == SpoofCompiler.GeneratorAPI.JAVA;
        if (api == SpoofCompiler.GeneratorAPI.CUDA) {
            supported = !_type.isNotSupportedBySpoofCUDA();
        }
        for (int i = 0; supported && i < _inputs.size(); i++) {
            supported = inputAt(i).isSupported(api);
        }
        return supported;
    }

    /** Binary operation types supported by {@link CNodeBinary}. */
    public static enum BinType {
        ROWMAXS_VECTMULT,
        DOT_PRODUCT,
        VECT_MATRIXMULT,
        VECT_OUTERMULT_ADD,
        VECT_MULT_ADD,
        VECT_DIV_ADD,
        VECT_MINUS_ADD,
        VECT_PLUS_ADD,
        VECT_POW_ADD,
        VECT_MIN_ADD,
        VECT_MAX_ADD,
        VECT_EQUAL_ADD,
        VECT_NOTEQUAL_ADD,
        VECT_LESS_ADD,
        VECT_LESSEQUAL_ADD,
        VECT_GREATER_ADD,
        VECT_GREATEREQUAL_ADD,
        VECT_CBIND_ADD,
        VECT_XOR_ADD,
        VECT_MULT_SCALAR,
        VECT_DIV_SCALAR,
        VECT_MINUS_SCALAR,
        VECT_PLUS_SCALAR,
        VECT_POW_SCALAR,
        VECT_MIN_SCALAR,
        VECT_MAX_SCALAR,
        VECT_EQUAL_SCALAR,
        VECT_NOTEQUAL_SCALAR,
        VECT_LESS_SCALAR,
        VECT_LESSEQUAL_SCALAR,
        VECT_GREATER_SCALAR,
        VECT_GREATEREQUAL_SCALAR,
        VECT_CBIND,
        VECT_XOR_SCALAR,
        VECT_BITWAND_SCALAR,
        VECT_MULT,
        VECT_DIV,
        VECT_MINUS,
        VECT_PLUS,
        VECT_MIN,
        VECT_MAX,
        VECT_EQUAL,
        VECT_NOTEQUAL,
        VECT_LESS,
        VECT_LESSEQUAL,
        VECT_GREATER,
        VECT_GREATEREQUAL,
        VECT_XOR,
        VECT_BITWAND,
        VECT_BIASADD,
        VECT_BIASMULT,
        MULT,
        DIV,
        PLUS,
        MINUS,
        MODULUS,
        INTDIV,
        LESS,
        LESSEQUAL,
        GREATER,
        GREATEREQUAL,
        EQUAL,
        NOTEQUAL,
        MIN,
        MAX,
        AND,
        OR,
        XOR,
        LOG,
        LOG_NZ,
        POW,
        BITWAND,
        SEQ_RIX,
        MINUS1_MULT,
        MINUS_NZ;

        /** Returns whether {@code value} is exactly the name of one of these constants. */
        public static boolean contains(String value) {
            return Arrays.stream(BinType.values()).anyMatch(bt -> bt.name().equals(value));
        }

        /** Returns whether operand order is irrelevant (scalar-scalar, vector-scalar or vector-vector). */
        public boolean isCommutative() {
            boolean ssComm = this == EQUAL || this == NOTEQUAL || this == PLUS || this == MULT || this == MIN || this == MAX || this == OR || this == AND || this == XOR || this == BITWAND;
            boolean vsComm = this == VECT_EQUAL_SCALAR || this == VECT_NOTEQUAL_SCALAR || this == VECT_PLUS_SCALAR || this == VECT_MULT_SCALAR || this == VECT_MIN_SCALAR || this == VECT_MAX_SCALAR || this == VECT_XOR_SCALAR || this == VECT_BITWAND_SCALAR;
            boolean vvComm = this == VECT_EQUAL || this == VECT_NOTEQUAL || this == VECT_PLUS || this == VECT_MULT || this == VECT_MIN || this == VECT_MAX || this == VECT_XOR || this == VECT_BITWAND;
            return ssComm || vsComm || vvComm;
        }

        /** Returns false only for DOT_PRODUCT, VECT_MATRIXMULT and VECT_OUTERMULT_ADD. */
        public boolean isElementwise() {
            return this != DOT_PRODUCT && this != VECT_MATRIXMULT && this != VECT_OUTERMULT_ADD;
        }

        /** Returns whether this is a vector-scalar, vector-vector or vector-matrix primitive. */
        public boolean isVectorPrimitive() {
            return isVectorScalarPrimitive() || isVectorVectorPrimitive() || isVectorMatrixPrimitive();
        }

        /** Returns whether this is a vector-scalar primitive (VECT_CBIND is counted here). */
        public boolean isVectorScalarPrimitive() {
            return this == VECT_DIV_SCALAR || this == VECT_MULT_SCALAR || this == VECT_MINUS_SCALAR || this == VECT_PLUS_SCALAR || this == VECT_POW_SCALAR || this == VECT_MIN_SCALAR || this == VECT_MAX_SCALAR || this == VECT_EQUAL_SCALAR || this == VECT_NOTEQUAL_SCALAR || this == VECT_LESS_SCALAR || this == VECT_LESSEQUAL_SCALAR || this == VECT_GREATER_SCALAR || this == VECT_GREATEREQUAL_SCALAR || this == VECT_CBIND || this == VECT_XOR_SCALAR || this == VECT_BITWAND_SCALAR;
        }

        /** Returns whether this is a vector-vector primitive (including the bias ops). */
        public boolean isVectorVectorPrimitive() {
            return this == VECT_DIV || this == VECT_MULT || this == VECT_MINUS || this == VECT_PLUS || this == VECT_MIN || this == VECT_MAX || this == VECT_EQUAL || this == VECT_NOTEQUAL || this == VECT_LESS || this == VECT_LESSEQUAL || this == VECT_GREATER || this == VECT_GREATEREQUAL || this == VECT_XOR || this == VECT_BITWAND || this == VECT_BIASADD || this == VECT_BIASMULT;
        }

        /** Returns whether this is VECT_MATRIXMULT or VECT_OUTERMULT_ADD. */
        public boolean isVectorMatrixPrimitive() {
            return this == VECT_MATRIXMULT || this == VECT_OUTERMULT_ADD;
        }

        /**
         * Maps this type to {@code VECT_<NAME>_ADD}, where NAME comes from
         * {@link #getVectorPrimitiveName()}.
         *
         * @throws IllegalArgumentException if no such constant exists (e.g. for VECT_BITWAND)
         */
        public BinType getVectorAddPrimitive() {
            // Locale.ROOT: locale-sensitive upper-casing can produce non-ASCII letters
            return BinType.valueOf("VECT_" + getVectorPrimitiveName().toUpperCase(Locale.ROOT) + "_ADD");
        }

        /**
         * Returns the second underscore-separated token of the constant name, lower-cased and
         * capitalized (e.g. VECT_MULT_SCALAR yields "Mult").
         *
         * @throws ArrayIndexOutOfBoundsException for names without an underscore (e.g. MULT)
         */
        public String getVectorPrimitiveName() {
            String[] tokens = name().split("_");
            // Locale.ROOT: under e.g. a Turkish default locale, "MIN" would become "mın"
            return StringUtils.capitalize(tokens[1].toLowerCase(Locale.ROOT));
        }

        /** Returns whether this type has no SPOOF CUDA implementation (the bias ops). */
        public boolean isNotSupportedBySpoofCUDA() {
            return this == VECT_BIASADD || this == VECT_BIASMULT;
        }
    }
}

