/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.lops.BinaryScalar;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.lops.rewrite.LopRewriteRule;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;

/**
 * LOP rewrite that prepends an eviction block to mini-batch {@code for} loops.
 *
 * <p>The rewrite applies only when all of the following hold:
 * <ul>
 *   <li>auto-eviction is enabled;</li>
 *   <li>the accelerator is in use ({@code DMLScript.USE_ACCELERATOR});</li>
 *   <li>lineage reuse is active.</li>
 * </ul>
 * When it applies, a {@link ForStatementBlock} whose first body block contains a mini-batch
 * slicing pattern is returned preceded by a new statement block. That block holds a single CP
 * {@code _EVICT} instruction. All other blocks are returned unchanged.
 */
public class RewriteAddGPUEvictLop
extends LopRewriteRule {
    /**
     * Literal operand of the inserted {@code _EVICT} operator.
     * NOTE(review): presumably the percentage of cached data to evict — confirm.
     */
    private static final int EVICT_FRACTION = 100;

    /**
     * Inserts an eviction block in front of {@code sb} if it is a mini-batch {@code for} loop.
     *
     * @param sb the statement block to rewrite; may be {@code null}
     * @return {@code [evictBlock, sb]} when the rewrite applies, otherwise a single-element
     *         list containing {@code sb}
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb) {
        if (sb == null) {
            // List.of rejects null elements, so a null block is passed through explicitly
            // instead of throwing a NullPointerException from a guard meant to skip it.
            ArrayList<StatementBlock> passThrough = new ArrayList<>(1);
            passThrough.add(null);
            return passThrough;
        }
        if (!ConfigurationManager.isAutoEvictionEnabled()
            || !(sb instanceof ForStatementBlock)
            || !DMLScript.USE_ACCELERATOR
            || LineageCacheConfig.ReuseCacheType.isNone()) {
            return List.of(sb);
        }
        List<StatementBlock> body = ((ForStatement) sb.getStatement(0)).getBody();
        if (body == null || body.isEmpty()) {
            // No body block to inspect; indexing get(0) here would throw.
            return List.of(sb);
        }
        // Only the first block of the loop body is inspected for the slicing pattern.
        List<Lop> lops = OperatorOrderingUtils.getLopList(body.get(0));
        ArrayList<StatementBlock> ret = new ArrayList<>(2);
        if (findMiniBatchSlicing(lops)) {
            ret.add(createEvictBlock(sb));
        }
        ret.add(sb);
        return ret;
    }

    /**
     * Returns the given blocks unchanged; this rule only rewrites individual statement blocks.
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
        return sbs;
    }

    /**
     * Builds a standalone statement block containing a single CP {@code _EVICT} lop and a
     * matching {@code _EVICT} hop, both with the literal operand {@link #EVICT_FRACTION}.
     *
     * @param sb the block the new block is placed before; supplies the DML program and parse info
     * @return the new eviction statement block
     */
    private static StatementBlock createEvictBlock(StatementBlock sb) {
        StatementBlock sb0 = new StatementBlock();
        sb0.setDMLProg(sb.getDMLProg());
        sb0.setParseInfo(sb);
        // The eviction block neither reads nor defines any variables.
        sb0.setLiveIn(new VariableSet());
        sb0.setLiveOut(new VariableSet());

        Data fr = Data.createLiteralLop(Types.ValueType.INT64, Integer.toString(EVICT_FRACTION));
        fr.getOutputParameters().setDimensions(0L, 0L, 0L, -1L);
        UnaryCP evict = new UnaryCP((Lop) fr, Types.OpOp1._EVICT, fr.getDataType(), fr.getValueType(), Types.ExecType.CP);
        LiteralOp in = new LiteralOp(EVICT_FRACTION);
        UnaryOp evictHop = new UnaryOp("tmp", Types.DataType.SCALAR, Types.ValueType.INT64, Types.OpOp1._EVICT, in);

        ArrayList<Lop> newlops = new ArrayList<>();
        newlops.add(evict);
        ArrayList<Hop> newhops = new ArrayList<>();
        newhops.add(evictHop);
        sb0.setLops(newlops);
        sb0.setHops(newhops);
        return sb0;
    }

    /**
     * Detects a mini-batch slicing pattern. It matches a {@link RightIndex} lop whose inputs are:
     * <ul>
     *   <li>input 0: a transient-read {@link Data} lop with no inputs of its own;</li>
     *   <li>input 1: a {@code PLUS} {@link BinaryScalar};</li>
     *   <li>input 2: a {@code MIN} {@link BinaryScalar}.</li>
     * </ul>
     * NOTE(review): inputs 1 and 2 are presumably the begin/end indices of the slice, for example
     * {@code beg = ... + 1; end = min(N, ...)} — confirm against RightIndex.
     *
     * @param lops the lops of a statement block
     * @return {@code true} if any lop matches the pattern
     */
    private static boolean findMiniBatchSlicing(List<Lop> lops) {
        for (Lop l : lops) {
            if (!(l instanceof RightIndex)) {
                continue;
            }
            List<Lop> inputs = l.getInputs();
            if (inputs == null || inputs.size() < 3) {
                continue;
            }
            Lop data = inputs.get(0);
            if (data instanceof Data
                && ((Data) data).isTransientRead()
                && data.getInputs().isEmpty()
                && isBinaryScalarOp(inputs.get(1), Types.OpOp2.PLUS)
                && isBinaryScalarOp(inputs.get(2), Types.OpOp2.MIN)) {
                return true;
            }
        }
        return false;
    }

    /** Returns {@code true} if {@code lop} is a {@link BinaryScalar} performing {@code op}. */
    private static boolean isBinaryScalarOp(Lop lop, Types.OpOp2 op) {
        return lop instanceof BinaryScalar && ((BinaryScalar) lop).getOperationType() == op;
    }
}

