/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.template;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;

/**
 * Common subexpression elimination (CSE) over the {@link CNode} DAG of a code-generation
 * template. Structurally equal subexpressions are collapsed onto a single shared node by
 * rewriting parents' input lists in place.
 */
public class CPlanCSERewriter {
    /**
     * Eliminates common subexpressions reachable from the template's output(s).
     *
     * <p>Runs three passes over the DAG, each preceded by a reset of the visit status:
     * (1) switch all {@link CNodeData} nodes to strict equality, (2) merge equal
     * subexpressions, (3) switch data nodes back to non-strict equality.
     *
     * <p>Note: only input edges are rewritten; the template's output roots themselves are
     * never replaced.
     *
     * @param tpl the template to rewrite; its node DAG is modified in place
     * @return the same {@code tpl} instance
     */
    public CNodeTpl eliminateCommonSubexpressions(CNodeTpl tpl) {
        // A multi-aggregate template has several output roots; other templates have one.
        List<CNode> outputs = tpl instanceof CNodeMultiAgg
                ? ((CNodeMultiAgg) tpl).getOutputs()
                : Collections.singletonList(tpl.getOutput());

        // Pass 1: strict comparison for data nodes while CSE runs.
        // NOTE(review): presumably so that distinct data inputs are never merged -- confirm in CNodeData.
        tpl.resetVisitStatusOutputs();
        for (CNode out : outputs) {
            this.rSetStrictDataNodeComparision(out, true);
        }

        // Pass 2: merge structurally equal subexpressions. The map holds, for each equivalence
        // class of nodes, the first node registered for it (the canonical representative).
        HashMap<CNode, CNode> cseSet = new HashMap<CNode, CNode>();
        tpl.resetVisitStatusOutputs();
        for (CNode out : outputs) {
            this.rEliminateCommonSubexpression(out, cseSet);
        }

        // Pass 3: restore non-strict comparison for data nodes.
        tpl.resetVisitStatusOutputs();
        for (CNode out : outputs) {
            this.rSetStrictDataNodeComparision(out, false);
        }
        tpl.resetVisitStatusOutputs();
        return tpl;
    }

    /**
     * Bottom-up CSE: first processes all inputs of {@code current}, then redirects each input
     * to its canonical representative, and finally registers {@code current} itself.
     *
     * <p>Inputs must be processed before they are looked up. A top-down lookup misses
     * duplicates among siblings: for inputs [A, B] with A.equals(B), neither is registered
     * yet at lookup time, so both would survive.
     *
     * @param current node to process; returns immediately if already visited
     * @param cseSet  canonical node for each equivalence class seen so far
     */
    private void rEliminateCommonSubexpression(CNode current, HashMap<CNode, CNode> cseSet) {
        if (current.isVisited()) {
            return;
        }
        List<CNode> inputs = current.getInput();
        for (CNode input : inputs) {
            this.rEliminateCommonSubexpression(input, cseSet);
        }
        // Every processed input is now registered (itself or an equal canonical node).
        for (int i = 0; i < inputs.size(); ++i) {
            CNode canonical = cseSet.get(inputs.get(i));
            if (canonical != null) {
                inputs.set(i, canonical);
            }
        }
        // Keep the FIRST registered node as canonical. A plain put on an equal key would keep
        // the old key but overwrite the value, so later parents would be redirected to a
        // different node than earlier ones and the duplicates would stay in the plan.
        if (!cseSet.containsKey(current)) {
            cseSet.put(current, current);
        }
        current.setVisited();
    }

    /**
     * Sets the strict-equality flag on every {@link CNodeData} reachable from {@code current}
     * and resets the hash of every visited node, because data-node comparison depends on the flag.
     *
     * @param current root of the sub-DAG to process; returns immediately if already visited
     * @param flag    value passed to {@link CNodeData#setStrictEquals(boolean)}
     */
    private void rSetStrictDataNodeComparision(CNode current, boolean flag) {
        if (current.isVisited()) {
            return;
        }
        for (CNode input : current.getInput()) {
            this.rSetStrictDataNodeComparision(input, flag);
        }
        if (current instanceof CNodeData) {
            ((CNodeData) current).setStrictEquals(flag);
        }
        // Reset the node's own hash (not only its inputs'), so that output roots, which are
        // nobody's input, do not keep a hash computed under the previous comparison mode.
        current.resetHash();
        current.setVisited();
    }
}

