/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.drivers.opencl.graal.phases;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.nodes.AbstractBeginNode;
import org.graalvm.compiler.nodes.AbstractEndNode;
import org.graalvm.compiler.nodes.AbstractMergeNode;
import org.graalvm.compiler.nodes.BeginNode;
import org.graalvm.compiler.nodes.EndNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.IfNode;
import org.graalvm.compiler.nodes.LogicNode;
import org.graalvm.compiler.nodes.LoopBeginNode;
import org.graalvm.compiler.nodes.LoopEndNode;
import org.graalvm.compiler.nodes.PhiNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.phases.BasePhase;
import org.graalvm.compiler.phases.common.DeadCodeEliminationPhase;
import uk.ac.manchester.tornado.api.exceptions.TornadoInternalError;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.logic.LogicalNotNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.logic.LogicalOrNode;
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoMidTierContext;

/**
 * Mid-tier phase that tries to fold a chain of {@link IfNode}s whose branches all flow into the
 * same {@link AbstractMergeNode} into a single {@link IfNode} whose condition is the disjunction
 * ({@link LogicalOrNode}) of the individual clauses.
 *
 * <p>When the common merge is a {@link LoopBeginNode} the chain is only analysed; no graph
 * transformation is performed for that case.
 *
 * <p>Diagnostic output is written to {@code System.out} only when the system property
 * {@code tornado.debug.ifcanonicalization} is {@code true}.
 */
public class TornadoIfCanonicalization extends BasePhase<TornadoMidTierContext> {

    /** Whether the diagnostic trace of this phase (formerly always printed) is emitted. */
    private static final boolean DEBUG = Boolean.getBoolean("tornado.debug.ifcanonicalization");

    /**
     * Prints a printf-style diagnostic message when {@link #DEBUG} is enabled.
     *
     * @param format printf format string
     * @param args   format arguments
     */
    private static void trace(String format, Object... args) {
        if (DEBUG) {
            System.out.printf(format, args);
        }
    }

    @Override
    protected void run(StructuredGraph graph, TornadoMidTierContext context) {
        graph.getNodes(IfNode.TYPE).forEach(ifNode -> this.canonicalize(graph, (IfNode) ifNode));
    }

    /** Returns true if {@code begin} is directly followed by an end node attached to a merge. */
    private boolean isMerge(AbstractBeginNode begin) {
        FixedNode next = begin.next();
        return next instanceof AbstractEndNode && ((AbstractEndNode) next).merge() != null;
    }

    /**
     * Returns true if {@code end} is the only node of a branch that starts at a {@link BeginNode}
     * hanging directly off an {@link IfNode}.
     *
     * @param end a forward end or loop end
     */
    private boolean isIf(FixedNode end) {
        Node begin = end.predecessor();
        return begin instanceof BeginNode && begin.predecessor() instanceof IfNode;
    }

    /** Returns the {@link IfNode} owning the branch that ends in {@code end}; requires {@link #isIf}. */
    private IfNode getIf(FixedNode end) {
        return (IfNode) end.predecessor().predecessor();
    }

    /** Returns true if {@code end} is reached through the true successor of {@code ifNode}. */
    private boolean getBranchTaken(IfNode ifNode, FixedNode end) {
        return ifNode.trueSuccessor().next().equals(end);
    }

    /** Returns the merge reached from {@code begin}; requires {@link #isMerge}. */
    private AbstractMergeNode getMerge(AbstractBeginNode begin) {
        return ((AbstractEndNode) begin.next()).merge();
    }

    /**
     * Tries to fold the if-chain rooted at {@code ifNode}. The true branch is tried when it leads
     * straight to a merge, otherwise the false branch. Ifs directly following a loop header are
     * left alone.
     */
    private void canonicalize(StructuredGraph graph, IfNode ifNode) {
        // A previous canonicalization (and the dead code elimination it runs) may already have
        // removed this node while the phase is still iterating over the graph.
        if (!ifNode.isAlive()) {
            return;
        }
        trace("if-canonicalize: ifNode=%s\n", ifNode);
        if (ifNode.predecessor() instanceof LoopBeginNode) {
            return;
        }
        AbstractBeginNode trueBranch = ifNode.trueSuccessor();
        AbstractBeginNode falseBranch = ifNode.falseSuccessor();
        if (this.isMerge(trueBranch)) {
            this.tryMergeClauses(graph, ifNode, trueBranch);
        } else if (this.isMerge(falseBranch)) {
            this.tryMergeClauses(graph, ifNode, falseBranch);
        }
    }

    /** Dispatches to the loop-header analysis or to the forward-merge folding. */
    private void tryMergeClauses(StructuredGraph graph, IfNode ifNode, AbstractBeginNode branch) {
        trace("if-canonicalize: trying merge for ifNode=%s\n", ifNode);
        AbstractMergeNode merge = this.getMerge(branch);
        trace("if-canonicalize: merge=%s\n", merge);
        if (merge instanceof LoopBeginNode) {
            this.analyseLoopClauses(ifNode, (LoopBeginNode) merge);
        } else {
            this.tryMergeForwardClauses(graph, ifNode, merge);
        }
    }

    /**
     * Collects the ifs guarding the loop ends of {@code loopBegin} and checks whether they form a
     * chain rooted at {@code ifNode}. Analysis only: the graph is not modified.
     */
    private void analyseLoopClauses(IfNode ifNode, LoopBeginNode loopBegin) {
        int endCount = loopBegin.loopEnds().count();
        IfNode[] clauses = new IfNode[endCount];
        boolean[] branchTaken = new boolean[endCount];
        int found = 0;
        for (LoopEndNode end : loopBegin.orderedLoopEnds()) {
            trace("if-canonicalize: search end=%s\n", end);
            if (!this.isIf(end)) {
                continue;
            }
            IfNode clause = this.getIf(end);
            // Consecutive loop ends owned by the same if are only recorded once.
            if (found != 0 && clause.equals(clauses[found - 1])) {
                continue;
            }
            clauses[found] = clause;
            branchTaken[found] = this.getBranchTaken(clause, end);
            trace("if-canonicalize: found clause %s on branch %s\n", clause, branchTaken[found]);
            ++found;
        }
        if (found == 0) {
            return;
        }
        // Only the first `found` slots are populated; unused null slots would make the check fail.
        IfNode[] foundClauses = Arrays.copyOf(clauses, found);
        boolean[] foundBranches = Arrays.copyOf(branchTaken, found);
        if (this.checkClauses(ifNode, foundClauses, foundBranches)) {
            trace("check-clauses: passed\n");
        }
    }

    /**
     * Folds the if-chain rooted at {@code ifNode} whose merge-bound branches all end at the
     * forward ends of {@code merge}. The graph is only modified when every forward end comes
     * from an if of the chain and the merge carries no phi values.
     */
    private void tryMergeForwardClauses(StructuredGraph graph, IfNode ifNode, AbstractMergeNode merge) {
        int endCount = merge.forwardEndCount();
        IfNode[] clauses = new IfNode[endCount];
        boolean[] branchTaken = new boolean[endCount];
        int i = 0;
        for (EndNode end : merge.forwardEnds()) {
            trace("if-canonicalize: search end=%s\n", end);
            if (!this.isIf(end)) {
                // The merge has a predecessor that is not an if-branch, so it cannot be folded.
                return;
            }
            clauses[i] = this.getIf(end);
            branchTaken[i] = this.getBranchTaken(clauses[i], end);
            trace("if-canonicalize: found clause %s on branch %s\n", clauses[i], branchTaken[i]);
            ++i;
        }
        if (!this.checkClauses(ifNode, clauses, branchTaken)) {
            return;
        }
        trace("check-clauses: passed\n");
        // The merge is removed below, so it must not carry values; check before mutating.
        if (merge.phis().count() != 0) {
            trace("if-canonicalize: merge=%s has phi values, skipping\n", merge);
            return;
        }
        int lastIndex = clauses.length - 1;
        LogicNode newCondition = this.mergeClauses(graph, clauses, branchTaken);
        clauses[lastIndex].setCondition(newCondition);
        this.cleanupClauses(graph, clauses, branchTaken, merge);
        new DeadCodeEliminationPhase().apply(graph);
    }

    /** Returns the true successor of {@code ifNode} if {@code branch} is true, else the false one. */
    private AbstractBeginNode getNode(IfNode ifNode, boolean branch) {
        return branch ? ifNode.trueSuccessor() : ifNode.falseSuccessor();
    }

    /**
     * Checks that {@code clauses} form a single chain starting at {@code root}, where each clause
     * reaches the next one through the branch that does NOT lead to the merge.
     *
     * <p>On success {@code clauses} and {@code branchTaken} are reordered in place into chain
     * order (index 0 is {@code root}). On failure their contents are unspecified.
     *
     * @return true if the chain is valid and no clause directly follows a loop header
     */
    private boolean checkClauses(IfNode root, IfNode[] clauses, boolean[] branchTaken) {
        if (clauses.length == 0) {
            return false;
        }
        HashSet<IfNode> pending = new HashSet<>();
        HashMap<IfNode, Boolean> branches = new HashMap<>();
        for (int i = 0; i < clauses.length; ++i) {
            pending.add(clauses[i]);
            branches.put(clauses[i], branchTaken[i]);
        }
        IfNode current = root;
        trace("check-clauses: start=%s\n", current);
        for (int i = 0; i < clauses.length; ++i) {
            if (!pending.remove(current)) {
                trace("check-clauses: ifNode=%s not in set\n", current);
                return false;
            }
            clauses[i] = current;
            branchTaken[i] = branches.get(current);
            if (current.predecessor() instanceof LoopBeginNode) {
                return false;
            }
            if (pending.isEmpty()) {
                // Remaining iterations (duplicate clauses) fail the membership test above.
                continue;
            }
            AbstractBeginNode begin = this.getNode(current, !branchTaken[i]);
            trace("check-clauses: current=%s, branch=%s -> begin=%s\n", current, !branchTaken[i], begin);
            if (!(begin.next() instanceof IfNode)) {
                trace("check-clauses: next != ifNode (%s)\n", begin.next());
                return false;
            }
            current = (IfNode) begin.next();
        }
        return true;
    }

    /**
     * Collapses the chain into its last clause and removes {@code merge}, re-attaching the code
     * that followed the merge at the position of its remaining live forward end.
     */
    private void cleanupClauses(StructuredGraph graph, IfNode[] clauses, boolean[] branchTaken, AbstractMergeNode merge) {
        for (int i = 0; i < clauses.length - 1; ++i) {
            this.cleanupBranch(clauses[i], branchTaken[i]);
            // clauses[i + 1] still hangs off the other successor of clauses[i]; detach it so it
            // can take over clauses[i]'s predecessor (a node may only have one predecessor).
            this.getNode(clauses[i], !branchTaken[i]).setNext(null);
            clauses[i].replaceAndDelete((Node) clauses[i + 1]);
        }
        EndNode validEnd = null;
        for (EndNode e : merge.forwardEnds()) {
            trace("merge-cleanup: forward end=%s\n", e);
            if (e.isAlive()) {
                validEnd = e;
            }
        }
        if (DEBUG) {
            for (PhiNode phi : merge.phis()) {
                trace("merge-cleanup: phi=%s\n", phi);
            }
        }
        TornadoInternalError.guarantee(merge.phis().count() == 0, "phi values exist on merge node that is to be removed", new Object[0]);
        TornadoInternalError.guarantee(validEnd != null, "no live forward end on merge node that is to be removed", new Object[0]);
        FixedNode current = merge.next();
        // Detach the successor from the merge before giving it a new predecessor.
        merge.setNext(null);
        validEnd.replaceAtPredecessor((Node) current);
    }

    /**
     * Marks the successor of {@code ifNode} selected by {@code b}, and the node following it, as
     * deleted. NOTE(review): nodes are only marked, not unlinked; the caller runs dead code
     * elimination afterwards.
     */
    private void cleanupBranch(IfNode ifNode, boolean b) {
        AbstractBeginNode begin = this.getNode(ifNode, b);
        begin.next().markDeleted();
        begin.markDeleted();
    }

    /** Adds (or reuses) a {@link LogicalNotNode} negating {@code value}. */
    private LogicNode negate(StructuredGraph graph, LogicNode value) {
        return (LogicNode) graph.addOrUnique((Node) new LogicalNotNode(value));
    }

    /** Builds {@code (left or !left) || (right or !right)} according to the negation flags. */
    private LogicNode createClause(StructuredGraph graph, LogicNode left, boolean negateLeft, LogicNode right, boolean negateRight) {
        LogicNode lhs = negateLeft ? this.negate(graph, left) : left;
        LogicNode rhs = negateRight ? this.negate(graph, right) : right;
        return (LogicNode) graph.addOrUnique((Node) new LogicalOrNode(lhs, rhs));
    }

    /**
     * Builds the condition for the last clause of the chain so that it takes its merge-bound
     * successor ({@code branchTaken[last]}) exactly when any clause of the original chain would
     * have taken its merge-bound successor.
     *
     * <p>Each clause contributes {@code c} when it reaches the merge on true and {@code !c}
     * otherwise; the disjunction is negated when the last clause reaches the merge on false.
     *
     * @param clauses     clauses in chain order (as produced by {@link #checkClauses})
     * @param branchTaken for each clause, whether the merge is reached through its true successor
     */
    private LogicNode mergeClauses(StructuredGraph graph, IfNode[] clauses, boolean[] branchTaken) {
        int lastIndex = clauses.length - 1;
        if (lastIndex == 0) {
            // Nothing to combine: the single clause keeps its own condition.
            return clauses[0].condition();
        }
        LogicNode combined = clauses[0].condition();
        // The first clause's polarity is applied when it is combined with the second one.
        boolean negateCombined = !branchTaken[0];
        for (int i = 1; i < clauses.length; ++i) {
            trace("i=%d\n", i);
            LogicNode rightCondition = clauses[i].condition();
            trace("merge-clauses: left=%s, right=%s\n", combined, rightCondition);
            combined = this.createClause(graph, combined, negateCombined, rightCondition, !branchTaken[i]);
            negateCombined = false;
        }
        if (!branchTaken[lastIndex]) {
            // The last clause reaches the merge through its false successor.
            combined = this.negate(graph, combined);
        }
        return combined;
    }
}

