/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.privacy.propagation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AppendCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListAppendRemoveCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.SqlCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.privacy.DMLPrivacyException;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.runtime.privacy.PrivacyUtils;
import org.apache.sysds.runtime.privacy.propagation.AppendPropagator;
import org.apache.sysds.runtime.privacy.propagation.CBindPropagator;
import org.apache.sysds.runtime.privacy.propagation.ListAppendPropagator;
import org.apache.sysds.runtime.privacy.propagation.ListRemovePropagator;
import org.apache.sysds.runtime.privacy.propagation.MatrixMultiplicationPropagatorPrivateFirst;
import org.apache.sysds.runtime.privacy.propagation.OperatorType;
import org.apache.sysds.runtime.privacy.propagation.RBindPropagator;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

public class PrivacyPropagator {
    public static Data parseAndSetPrivacyConstraint(Data cd, JSONObject mtd) throws JSONException {
        PrivacyConstraint mtdPrivConstraint = PrivacyPropagator.parseAndReturnPrivacyConstraint(mtd);
        if (mtdPrivConstraint != null) {
            cd.setPrivacyConstraints(mtdPrivConstraint);
        }
        return cd;
    }

    public static PrivacyConstraint parseAndReturnPrivacyConstraint(JSONObject mtd) throws JSONException {
        String privacyLevel;
        if (mtd.containsKey("privacy") && (privacyLevel = mtd.getString("privacy")) != null) {
            return new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.valueOf(privacyLevel));
        }
        return null;
    }

    private static boolean anyInputHasLevel(PrivacyConstraint.PrivacyLevel[] inputLevels, PrivacyConstraint.PrivacyLevel targetLevel) {
        return Arrays.stream(inputLevels).anyMatch(i -> i == targetLevel);
    }

    public static PrivacyConstraint.PrivacyLevel corePropagation(PrivacyConstraint.PrivacyLevel[] inputLevels, OperatorType operatorType) {
        if (PrivacyPropagator.anyInputHasLevel(inputLevels, PrivacyConstraint.PrivacyLevel.Private)) {
            return PrivacyConstraint.PrivacyLevel.Private;
        }
        if (operatorType == OperatorType.Aggregate) {
            return PrivacyConstraint.PrivacyLevel.None;
        }
        if (operatorType == OperatorType.NonAggregate && PrivacyPropagator.anyInputHasLevel(inputLevels, PrivacyConstraint.PrivacyLevel.PrivateAggregation)) {
            return PrivacyConstraint.PrivacyLevel.PrivateAggregation;
        }
        return PrivacyConstraint.PrivacyLevel.None;
    }

    private static PrivacyConstraint mergeNary(PrivacyConstraint[] privacyConstraints, OperatorType operatorType) {
        PrivacyConstraint.PrivacyLevel[] privacyLevels = (PrivacyConstraint.PrivacyLevel[])Arrays.stream(privacyConstraints).map(constraint -> {
            if (constraint != null) {
                return constraint.getPrivacyLevel();
            }
            return PrivacyConstraint.PrivacyLevel.None;
        }).toArray(PrivacyConstraint.PrivacyLevel[]::new);
        PrivacyConstraint.PrivacyLevel outputPrivacyLevel = PrivacyPropagator.corePropagation(privacyLevels, operatorType);
        return new PrivacyConstraint(outputPrivacyLevel);
    }

    public static PrivacyConstraint mergeBinary(PrivacyConstraint privacyConstraint1, PrivacyConstraint privacyConstraint2) {
        if (privacyConstraint1 != null && privacyConstraint2 != null) {
            PrivacyConstraint.PrivacyLevel[] privacyLevels = new PrivacyConstraint.PrivacyLevel[]{privacyConstraint1.getPrivacyLevel(), privacyConstraint2.getPrivacyLevel()};
            return new PrivacyConstraint(PrivacyPropagator.corePropagation(privacyLevels, OperatorType.NonAggregate));
        }
        if (privacyConstraint1 != null) {
            return privacyConstraint1;
        }
        if (privacyConstraint2 != null) {
            return privacyConstraint2;
        }
        return null;
    }

    public static void hopPropagation(Hop hop) {
        PrivacyConstraint[] inputConstraints = (PrivacyConstraint[])hop.getInput().stream().map(Hop::getPrivacy).toArray(PrivacyConstraint[]::new);
        if (hop instanceof TernaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp) {
            hop.setPrivacy(PrivacyPropagator.mergeNary(inputConstraints, OperatorType.NonAggregate));
        } else if (hop instanceof AggBinaryOp || hop instanceof AggUnaryOp || hop instanceof UnaryOp) {
            hop.setPrivacy(PrivacyPropagator.mergeNary(inputConstraints, OperatorType.Aggregate));
        }
    }

    public static void postProcessInstruction(Instruction inst, ExecutionContext ec) {
        List<CPOperand> instOutputs = PrivacyPropagator.getOutputOperands(inst);
        if (!instOutputs.isEmpty()) {
            for (CPOperand output : instOutputs) {
                PrivacyConstraint outputPrivacyConstraint = output.getPrivacyConstraint();
                if (!PrivacyUtils.someConstraintSetUnary(outputPrivacyConstraint)) continue;
                PrivacyPropagator.setOutputPrivacyConstraint(ec, outputPrivacyConstraint, output.getName());
            }
        }
    }

    public static Instruction preprocessInstruction(Instruction inst, ExecutionContext ec) {
        switch (inst.getType()) {
            case CONTROL_PROGRAM: {
                return PrivacyPropagator.preprocessCPInstruction((CPInstruction)inst, ec);
            }
            case BREAKPOINT: 
            case SPARK: 
            case GPU: 
            case FEDERATED: {
                return inst;
            }
        }
        return PrivacyPropagator.throwExceptionIfInputOrInstPrivacy(inst, ec);
    }

    private static Instruction preprocessCPInstruction(CPInstruction inst, ExecutionContext ec) {
        switch (inst.getCPInstructionType()) {
            case Binary: 
            case Builtin: 
            case BuiltinNary: 
            case FCall: 
            case ParameterizedBuiltin: 
            case Quaternary: 
            case Reorg: 
            case Ternary: 
            case Unary: 
            case MultiReturnBuiltin: 
            case MultiReturnParameterizedBuiltin: 
            case MatrixIndexing: {
                return PrivacyPropagator.mergePrivacyConstraintsFromInput(inst, ec, OperatorType.NonAggregate);
            }
            case AggregateTernary: 
            case AggregateUnary: {
                return PrivacyPropagator.mergePrivacyConstraintsFromInput(inst, ec, OperatorType.Aggregate);
            }
            case Append: {
                return PrivacyPropagator.preprocessAppendCPInstruction((AppendCPInstruction)inst, ec);
            }
            case AggregateBinary: {
                if (inst instanceof AggregateBinaryCPInstruction) {
                    return PrivacyPropagator.preprocessAggregateBinaryCPInstruction((AggregateBinaryCPInstruction)inst, ec);
                }
                return PrivacyPropagator.throwExceptionIfInputOrInstPrivacy(inst, ec);
            }
            case MMTSJ: {
                OperatorType mmtsjOpType = OperatorType.getAggregationType((MMTSJCPInstruction)inst, ec);
                return PrivacyPropagator.mergePrivacyConstraintsFromInput(inst, ec, mmtsjOpType);
            }
            case MMChain: {
                OperatorType mmChainOpType = OperatorType.getAggregationType((MMChainCPInstruction)inst, ec);
                return PrivacyPropagator.mergePrivacyConstraintsFromInput(inst, ec, mmChainOpType);
            }
            case Variable: {
                return PrivacyPropagator.preprocessVariableCPInstruction((VariableCPInstruction)inst, ec);
            }
        }
        return PrivacyPropagator.throwExceptionIfInputOrInstPrivacy(inst, ec);
    }

    private static Instruction preprocessVariableCPInstruction(VariableCPInstruction inst, ExecutionContext ec) {
        switch (inst.getVariableOpcode()) {
            case CopyVariable: 
            case MoveVariable: 
            case RemoveVariableAndFile: 
            case CastAsMatrixVariable: 
            case CastAsFrameVariable: 
            case Write: 
            case SetFileName: 
            case CastAsScalarVariable: 
            case CastAsDoubleVariable: 
            case CastAsIntegerVariable: 
            case CastAsBooleanVariable: {
                return PrivacyPropagator.propagateFirstInputPrivacy(inst, ec);
            }
            case CreateVariable: {
                return PrivacyPropagator.propagateSecondInputPrivacy(inst, ec);
            }
            case AssignVariable: 
            case RemoveVariable: {
                return PrivacyPropagator.mergePrivacyConstraintsFromInput(inst, ec, OperatorType.NonAggregate);
            }
            case Read: {
                return inst;
            }
        }
        return PrivacyPropagator.throwExceptionIfInputOrInstPrivacy(inst, ec);
    }

    private static Instruction preprocessAggregateBinaryCPInstruction(AggregateBinaryCPInstruction inst, ExecutionContext ec) {
        PrivacyConstraint[] privacyConstraints = PrivacyPropagator.getInputPrivacyConstraints(ec, inst.getInputs());
        if (PrivacyUtils.someConstraintSetBinary(privacyConstraints)) {
            PrivacyConstraint mergedPrivacyConstraint;
            if (privacyConstraints[0] != null && privacyConstraints[0].hasFineGrainedConstraints() || privacyConstraints[1] != null && privacyConstraints[1].hasFineGrainedConstraints()) {
                MatrixBlock input1 = ec.getMatrixInput(inst.input1.getName());
                MatrixBlock input2 = ec.getMatrixInput(inst.input2.getName());
                MatrixMultiplicationPropagatorPrivateFirst propagator = new MatrixMultiplicationPropagatorPrivateFirst(input1, privacyConstraints[0], input2, privacyConstraints[1]);
                mergedPrivacyConstraint = propagator.propagate();
                ec.releaseMatrixInput(inst.input1.getName(), inst.input2.getName());
            } else {
                mergedPrivacyConstraint = PrivacyPropagator.mergeNary(privacyConstraints, OperatorType.getAggregationType(inst, ec));
                inst.setPrivacyConstraint(mergedPrivacyConstraint);
            }
            inst.output.setPrivacyConstraint(mergedPrivacyConstraint);
        }
        return inst;
    }

    private static Instruction preprocessAppendCPInstruction(AppendCPInstruction inst, ExecutionContext ec) {
        PrivacyConstraint[] privacyConstraints = PrivacyPropagator.getInputPrivacyConstraints(ec, inst.getInputs());
        if (PrivacyUtils.someConstraintSetBinary(privacyConstraints)) {
            if (inst.getAppendType() == AppendCPInstruction.AppendType.STRING) {
                PrivacyConstraint.PrivacyLevel[] privacyLevels = new PrivacyConstraint.PrivacyLevel[]{PrivacyUtils.getGeneralPrivacyLevel(privacyConstraints[0]), PrivacyUtils.getGeneralPrivacyLevel(privacyConstraints[1])};
                PrivacyConstraint outputConstraint = new PrivacyConstraint(PrivacyPropagator.corePropagation(privacyLevels, OperatorType.NonAggregate));
                inst.output.setPrivacyConstraint(outputConstraint);
            } else if (inst.getAppendType() == AppendCPInstruction.AppendType.LIST) {
                ListObject input1 = (ListObject)ec.getVariable(inst.input1);
                if (inst.getOpcode().equals("remove")) {
                    ScalarObject removePosition = ec.getScalarInput(inst.input2);
                    ListRemovePropagator propagator = new ListRemovePropagator(input1, privacyConstraints[0], removePosition, removePosition.getPrivacyConstraint());
                    PrivacyConstraint[] outputConstraints = propagator.propagate();
                    inst.output.setPrivacyConstraint(outputConstraints[0]);
                    ((ListAppendRemoveCPInstruction)inst).getOutput2().setPrivacyConstraint(outputConstraints[1]);
                } else {
                    ListObject input2 = (ListObject)ec.getVariable(inst.input2);
                    ListAppendPropagator propagator = new ListAppendPropagator(input1, privacyConstraints[0], input2, privacyConstraints[1]);
                    inst.output.setPrivacyConstraint(propagator.propagate());
                }
            } else {
                AppendPropagator propagator;
                MatrixBlock input1 = ec.getMatrixInput(inst.input1.getName());
                MatrixBlock input2 = ec.getMatrixInput(inst.input2.getName());
                if (inst.getAppendType() == AppendCPInstruction.AppendType.RBIND) {
                    propagator = new RBindPropagator(input1, privacyConstraints[0], input2, privacyConstraints[1]);
                } else if (inst.getAppendType() == AppendCPInstruction.AppendType.CBIND) {
                    propagator = new CBindPropagator(input1, privacyConstraints[0], input2, privacyConstraints[1]);
                } else {
                    throw new DMLPrivacyException("Instruction " + (Object)((Object)inst.getCPInstructionType()) + " with append type " + (Object)((Object)inst.getAppendType()) + " is not supported by the privacy propagator");
                }
                inst.output.setPrivacyConstraint(propagator.propagate());
                ec.releaseMatrixInput(inst.input1.getName(), inst.input2.getName());
            }
        }
        return inst;
    }

    private static Instruction mergePrivacyConstraintsFromInput(Instruction inst, ExecutionContext ec, OperatorType operatorType) {
        return PrivacyPropagator.mergePrivacyConstraintsFromInput(inst, ec, PrivacyPropagator.getInputOperands(inst), PrivacyPropagator.getOutputOperands(inst), operatorType);
    }

    private static Instruction mergePrivacyConstraintsFromInput(Instruction inst, ExecutionContext ec, CPOperand[] inputs, List<CPOperand> outputs, OperatorType operatorType) {
        PrivacyConstraint[] privacyConstraints;
        if (inputs != null && inputs.length > 0 && (privacyConstraints = PrivacyPropagator.getInputPrivacyConstraints(ec, inputs)) != null) {
            PrivacyConstraint mergedPrivacyConstraint = PrivacyPropagator.mergeNary(privacyConstraints, operatorType);
            inst.setPrivacyConstraint(mergedPrivacyConstraint);
            for (CPOperand output : outputs) {
                if (output == null) continue;
                output.setPrivacyConstraint(mergedPrivacyConstraint);
            }
        }
        return inst;
    }

    private static Instruction throwExceptionIfInputOrInstPrivacy(Instruction inst, ExecutionContext ec) {
        PrivacyPropagator.throwExceptionIfPrivacyActivated(inst);
        CPOperand[] inputOperands = PrivacyPropagator.getInputOperands(inst);
        if (inputOperands != null) {
            for (CPOperand input : inputOperands) {
                PrivacyConstraint privacyConstraint = PrivacyPropagator.getInputPrivacyConstraint(ec, input);
                if (privacyConstraint == null) continue;
                throw new DMLPrivacyException("Input of instruction " + inst + " has privacy constraints activated, but the constraints are not propagated during preprocessing of instruction.");
            }
        }
        return inst;
    }

    private static void throwExceptionIfPrivacyActivated(Instruction inst) {
        if (inst.getPrivacyConstraint() != null && inst.getPrivacyConstraint().hasConstraints()) {
            throw new DMLPrivacyException("Instruction " + inst + " has privacy constraints activated, but the constraints are not propagated during preprocessing of instruction.");
        }
    }

    private static Instruction propagateFirstInputPrivacy(VariableCPInstruction inst, ExecutionContext ec) {
        return PrivacyPropagator.propagateInputPrivacy(inst, ec, inst.getInput1(), inst.getOutput());
    }

    private static Instruction propagateSecondInputPrivacy(VariableCPInstruction inst, ExecutionContext ec) {
        return PrivacyPropagator.propagateInputPrivacy(inst, ec, inst.getInput2(), inst.getOutput());
    }

    private static Instruction propagateInputPrivacy(Instruction inst, ExecutionContext ec, CPOperand inputOperand, CPOperand outputOperand) {
        PrivacyConstraint privacyConstraint = PrivacyPropagator.getInputPrivacyConstraint(ec, inputOperand);
        if (privacyConstraint != null) {
            inst.setPrivacyConstraint(privacyConstraint);
            if (outputOperand != null) {
                outputOperand.setPrivacyConstraint(privacyConstraint);
            }
        }
        return inst;
    }

    private static PrivacyConstraint getInputPrivacyConstraint(ExecutionContext ec, CPOperand input) {
        Data dd;
        if (input != null && input.getName() != null && (dd = ec.getVariable(input.getName())) != null) {
            return dd.getPrivacyConstraint();
        }
        return null;
    }

    private static PrivacyConstraint[] getInputPrivacyConstraints(ExecutionContext ec, CPOperand[] inputs) {
        if (inputs != null && inputs.length > 0) {
            boolean privacyFound = false;
            PrivacyConstraint[] privacyConstraints = new PrivacyConstraint[inputs.length];
            for (int i = 0; i < inputs.length; ++i) {
                privacyConstraints[i] = PrivacyPropagator.getInputPrivacyConstraint(ec, inputs[i]);
                if (privacyConstraints[i] == null) continue;
                privacyFound = true;
            }
            if (privacyFound) {
                return privacyConstraints;
            }
        }
        return null;
    }

    private static void setOutputPrivacyConstraint(ExecutionContext ec, PrivacyConstraint privacyConstraint, String outputName) {
        Data dd;
        if (privacyConstraint != null && (dd = ec.getVariable(outputName)) != null) {
            dd.setPrivacyConstraints(privacyConstraint);
            ec.setVariable(outputName, dd);
        }
    }

    private static CPOperand[] getInputOperands(Instruction inst) {
        if (inst instanceof ComputationCPInstruction) {
            return ((ComputationCPInstruction)inst).getInputs();
        }
        if (inst instanceof BuiltinNaryCPInstruction) {
            return ((BuiltinNaryCPInstruction)inst).getInputs();
        }
        if (inst instanceof FunctionCallCPInstruction) {
            return ((FunctionCallCPInstruction)inst).getInputs();
        }
        if (inst instanceof SqlCPInstruction) {
            return ((SqlCPInstruction)inst).getInputs();
        }
        return null;
    }

    private static List<CPOperand> getOutputOperands(Instruction inst) {
        if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction) {
            return ((MultiReturnParameterizedBuiltinCPInstruction)inst).getOutputs();
        }
        if (inst instanceof MultiReturnBuiltinCPInstruction) {
            return ((MultiReturnBuiltinCPInstruction)inst).getOutputs();
        }
        if (inst instanceof ComputationCPInstruction) {
            return PrivacyPropagator.getSingletonList(((ComputationCPInstruction)inst).getOutput());
        }
        if (inst instanceof VariableCPInstruction) {
            return PrivacyPropagator.getSingletonList(((VariableCPInstruction)inst).getOutput());
        }
        if (inst instanceof SqlCPInstruction) {
            return PrivacyPropagator.getSingletonList(((SqlCPInstruction)inst).getOutput());
        }
        if (inst instanceof BuiltinNaryCPInstruction) {
            return PrivacyPropagator.getSingletonList(((BuiltinNaryCPInstruction)inst).getOutput());
        }
        return new ArrayList<CPOperand>();
    }

    private static List<CPOperand> getSingletonList(CPOperand operand) {
        if (operand != null) {
            return new ArrayList<CPOperand>(Collections.singletonList(operand));
        }
        return new ArrayList<CPOperand>();
    }
}

