/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.resource;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.spark.SparkConf;
import org.apache.sysds.api.DMLOptions;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.parser.AssignmentStatement;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.Expression;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionDictionary;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.OutputStatement;
import org.apache.sysds.parser.ParserFactory;
import org.apache.sysds.parser.ParserWrapper;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.StringIdentifier;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.resource.CloudUtils;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;

public class ResourceCompiler {
    public static final long DEFAULT_DRIVER_MEMORY = 0x20000000L;
    public static final int DEFAULT_DRIVER_THREADS = 1;
    public static final long DEFAULT_EXECUTOR_MEMORY = 0x20000000L;
    public static final int DEFAULT_EXECUTOR_THREADS = 2;
    public static final int DEFAULT_NUMBER_EXECUTORS = 2;

    public static Program compile(String filePath, Map<String, String> args) throws IOException {
        return ResourceCompiler.compile(filePath, args, null);
    }

    public static Program compile(String filePath, Map<String, String> args, HashMap<String, String> replaceVars) throws IOException {
        DMLOptions dmlOptions = DMLOptions.defaultOptions;
        dmlOptions.argVals = args;
        String dmlScriptStr = DMLScript.readDMLScript(true, filePath);
        Map<String, String> argVals = dmlOptions.argVals;
        ParserWrapper parser = ParserFactory.createParser();
        DMLProgram dmlProgram = parser.parse(null, dmlScriptStr, argVals);
        DMLTranslator dmlTranslator = new DMLTranslator(dmlProgram);
        dmlTranslator.liveVariableAnalysis(dmlProgram);
        dmlTranslator.validateParseTree(dmlProgram);
        if (replaceVars != null && !replaceVars.isEmpty()) {
            ResourceCompiler.replaceFilename(dmlProgram, replaceVars);
        }
        dmlTranslator.constructHops(dmlProgram);
        dmlTranslator.rewriteHopsDAG(dmlProgram);
        dmlTranslator.constructLops(dmlProgram);
        dmlTranslator.rewriteLopDAG(dmlProgram);
        return dmlTranslator.getRuntimeProgram(dmlProgram, ConfigurationManager.getDMLConfig());
    }

    public static void replaceFilename(DMLProgram dmlp, HashMap<String, String> replaceVars) {
        for (int i = 0; i < dmlp.getNumStatementBlocks(); ++i) {
            StatementBlock sb = dmlp.getStatementBlock(i);
            for (Statement statement : sb.getStatements()) {
                StringIdentifier stringIdentifier;
                if (!(statement instanceof AssignmentStatement) && !(statement instanceof OutputStatement)) continue;
                if (statement instanceof AssignmentStatement) {
                    Expression assignExpression = ((AssignmentStatement)statement).getSource();
                    if (!(assignExpression instanceof StringIdentifier) && !(assignExpression instanceof DataExpression)) continue;
                    if (assignExpression instanceof DataExpression) {
                        Expression filenameExpression = ((DataExpression)assignExpression).getVarParam("iofilename");
                        if (!(filenameExpression instanceof StringIdentifier)) continue;
                        stringIdentifier = (StringIdentifier)filenameExpression;
                    } else {
                        stringIdentifier = (StringIdentifier)assignExpression;
                    }
                } else {
                    Expression filenameExpression = ((OutputStatement)statement).getExprParam("iofilename");
                    if (!(filenameExpression instanceof StringIdentifier)) continue;
                    stringIdentifier = (StringIdentifier)filenameExpression;
                }
                if (!replaceVars.containsKey(stringIdentifier.getValue())) continue;
                String valToReplace = replaceVars.get(stringIdentifier.getValue());
                stringIdentifier.setValue(valToReplace);
            }
        }
    }

    public static Program doFullRecompilation(Program program) {
        OptimizerUtils.resetDefaultSize();
        Program newProgram = new Program(program.getDMLProg());
        ArrayList B = Stream.concat(program.getProgramBlocks().stream(), program.getFunctionProgramBlocks().values().stream()).collect(Collectors.toCollection(ArrayList::new));
        ResourceCompiler.doRecompilation(B, newProgram);
        return newProgram;
    }

    private static void doRecompilation(ArrayList<ProgramBlock> origin, Program target) {
        for (ProgramBlock originBlock : origin) {
            ResourceCompiler.doRecompilation(originBlock, target);
        }
    }

    private static void doRecompilation(ProgramBlock originBlock, Program target) {
        if (originBlock instanceof FunctionProgramBlock) {
            FunctionProgramBlock fpb = (FunctionProgramBlock)originBlock;
            Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0L, true, Recompiler.ResetType.NO_RESET);
            String functionName = ((FunctionStatement)fpb.getStatementBlock().getStatement(0)).getName();
            String namespace = null;
            for (Map.Entry<String, FunctionDictionary<FunctionStatementBlock>> pairNS : target.getDMLProg().getNamespaces().entrySet()) {
                if (!pairNS.getValue().containsFunction(functionName)) continue;
                namespace = pairNS.getKey();
            }
            target.addFunctionProgramBlock(namespace, functionName, fpb);
        } else if (originBlock instanceof IfProgramBlock) {
            IfProgramBlock ipb = (IfProgramBlock)originBlock;
            IfStatementBlock sb = (IfStatementBlock)ipb.getStatementBlock();
            if (sb != null && sb.getPredicateHops() != null) {
                ArrayList<Hop> hopAsList = new ArrayList<Hop>(Collections.singletonList(sb.getPredicateHops()));
                ArrayList<Instruction> inst = Recompiler.recompile(null, hopAsList, null, null, true, false, true, false, false, null, 0L);
                ipb.setPredicate(inst);
                target.addProgramBlock(ipb);
            }
            ResourceCompiler.doRecompilation(ipb.getChildBlocksIfBody(), target);
            ResourceCompiler.doRecompilation(ipb.getChildBlocksElseBody(), target);
        } else if (originBlock instanceof WhileProgramBlock) {
            WhileProgramBlock wpb = (WhileProgramBlock)originBlock;
            WhileStatementBlock sb = (WhileStatementBlock)originBlock.getStatementBlock();
            if (sb != null && sb.getPredicateHops() != null) {
                ArrayList<Hop> hopAsList = new ArrayList<Hop>(Collections.singletonList(sb.getPredicateHops()));
                ArrayList<Instruction> inst = Recompiler.recompile(null, hopAsList, null, null, true, false, true, false, false, null, 0L);
                wpb.setPredicate(inst);
                target.addProgramBlock(wpb);
            }
            ResourceCompiler.doRecompilation(wpb.getChildBlocks(), target);
        } else if (originBlock instanceof ForProgramBlock) {
            ForProgramBlock fpb = (ForProgramBlock)originBlock;
            ForStatementBlock sb = (ForStatementBlock)fpb.getStatementBlock();
            if (sb != null) {
                ArrayList<Instruction> inst;
                ArrayList<Hop> hopAsList;
                if (sb.getFromHops() != null) {
                    hopAsList = new ArrayList<Hop>(Collections.singletonList(sb.getFromHops()));
                    inst = Recompiler.recompile(null, hopAsList, null, null, true, false, true, false, false, null, 0L);
                    fpb.setFromInstructions(inst);
                }
                if (sb.getToHops() != null) {
                    hopAsList = new ArrayList<Hop>(Collections.singletonList(sb.getToHops()));
                    inst = Recompiler.recompile(null, hopAsList, null, null, true, false, true, false, false, null, 0L);
                    fpb.setToInstructions(inst);
                }
                if (sb.getIncrementHops() != null) {
                    hopAsList = new ArrayList<Hop>(Collections.singletonList(sb.getIncrementHops()));
                    inst = Recompiler.recompile(null, hopAsList, null, null, true, false, true, false, false, null, 0L);
                    fpb.setIncrementInstructions(inst);
                }
                target.addProgramBlock(fpb);
            }
            ResourceCompiler.doRecompilation(fpb.getChildBlocks(), target);
        } else {
            BasicProgramBlock bpb = (BasicProgramBlock)originBlock;
            StatementBlock sb = bpb.getStatementBlock();
            ArrayList<Instruction> inst = Recompiler.recompile(sb, sb.getHops(), ExecutionContextFactory.createContext(target), null, true, false, true, false, false, null, 0L);
            bpb.setInstructions(inst);
            target.addProgramBlock(bpb);
        }
    }

    public static void setSingleNodeResourceConfigs(long nodeMemory, int nodeCores) {
        DMLScript.setGlobalExecMode(Types.ExecMode.SINGLE_NODE);
        long effectiveSingleNodeMemory = (long)((double)nodeMemory * 0.9);
        InfrastructureAnalyzer.setLocalMaxMemory(effectiveSingleNodeMemory);
        InfrastructureAnalyzer.setLocalPar(nodeCores);
    }

    public static void setSparkClusterResourceConfigs(long driverMemory, int driverCores, int numExecutors, long executorMemory, int executorCores) {
        if (numExecutors <= 0) {
            throw new RuntimeException("The given number of executors was non-positive");
        }
        long effectiveDriverMemory = CloudUtils.calculateEffectiveDriverMemoryBudget(driverMemory, numExecutors * executorCores);
        if (effectiveDriverMemory <= CloudUtils.GBtoBytes(1.0) || driverMemory > 2L * effectiveDriverMemory) {
            throw new IllegalArgumentException("Driver resources are not sufficient to handle the cluster");
        }
        InfrastructureAnalyzer.setLocalMaxMemory(effectiveDriverMemory);
        InfrastructureAnalyzer.setLocalPar(driverCores);
        DMLScript.setGlobalExecMode(Types.ExecMode.HYBRID);
        SparkConf sparkConf = SparkExecutionContext.createSystemDSSparkConf();
        sparkConf.set("spark.master", "local[*]");
        sparkConf.set("spark.app.name", "SystemDS");
        sparkConf.set("spark.memory.useLegacyMode", "false");
        int[] effectiveValues = CloudUtils.getEffectiveExecutorResources(executorMemory, executorCores, numExecutors);
        int effectiveExecutorMemory = effectiveValues[0];
        int effectiveExecutorCores = effectiveValues[1];
        int effectiveNumExecutor = effectiveValues[2];
        sparkConf.set("spark.executor.memory", effectiveExecutorMemory + "m");
        sparkConf.set("spark.executor.instances", Integer.toString(effectiveNumExecutor));
        sparkConf.set("spark.executor.cores", Integer.toString(effectiveExecutorCores));
        SparkExecutionContext.initLocalSparkContext(sparkConf);
    }
}

