/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.drivers.common.compiler.phases.analysis;

import java.util.LinkedHashMap;
import java.util.Optional;
import jdk.vm.ci.meta.JavaKind;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.IfNode;
import org.graalvm.compiler.nodes.LoopBeginNode;
import org.graalvm.compiler.nodes.ParameterNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.AddNode;
import org.graalvm.compiler.nodes.calc.AndNode;
import org.graalvm.compiler.nodes.calc.FloatEqualsNode;
import org.graalvm.compiler.nodes.calc.FloatLessThanNode;
import org.graalvm.compiler.nodes.calc.IntegerDivRemNode;
import org.graalvm.compiler.nodes.calc.IntegerEqualsNode;
import org.graalvm.compiler.nodes.calc.IntegerLessThanNode;
import org.graalvm.compiler.nodes.calc.LeftShiftNode;
import org.graalvm.compiler.nodes.calc.MulNode;
import org.graalvm.compiler.nodes.calc.OrNode;
import org.graalvm.compiler.nodes.calc.RemNode;
import org.graalvm.compiler.nodes.calc.RightShiftNode;
import org.graalvm.compiler.nodes.calc.ShiftNode;
import org.graalvm.compiler.nodes.calc.SignExtendNode;
import org.graalvm.compiler.nodes.calc.SignedDivNode;
import org.graalvm.compiler.nodes.calc.SignedRemNode;
import org.graalvm.compiler.nodes.calc.SubNode;
import org.graalvm.compiler.nodes.calc.UnaryArithmeticNode;
import org.graalvm.compiler.nodes.calc.UnsignedRightShiftNode;
import org.graalvm.compiler.nodes.calc.XorNode;
import org.graalvm.compiler.nodes.extended.IntegerSwitchNode;
import org.graalvm.compiler.nodes.memory.FloatingReadNode;
import org.graalvm.compiler.nodes.memory.ReadNode;
import org.graalvm.compiler.nodes.memory.WriteNode;
import org.graalvm.compiler.nodes.memory.address.AddressNode;
import org.graalvm.compiler.phases.BasePhase;
import org.graalvm.compiler.phases.Phase;
import uk.ac.manchester.tornado.api.TornadoDeviceContext;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkCastNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkFloatingPointIntrinsicsNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkGlobalThreadID;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkIntIntrinsicNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkLocalArray;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkOCLWriteNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkVectorLoad;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkVectorValueNode;
import uk.ac.manchester.tornado.runtime.profiler.FeatureExtractionUtilities;
import uk.ac.manchester.tornado.runtime.profiler.ProfilerCodeFeatures;

public class TornadoFeatureExtraction
extends Phase {
    private TornadoDeviceContext tornadoDeviceContext;

    public TornadoFeatureExtraction(TornadoDeviceContext tornadoDeviceContext) {
        this.tornadoDeviceContext = tornadoDeviceContext;
    }

    public Optional<BasePhase.NotApplicable> notApplicableTo(GraphState graphState) {
        return ALWAYS_APPLICABLE;
    }

    protected void run(StructuredGraph graph) {
        LinkedHashMap<ProfilerCodeFeatures, Integer> irfeatures = this.extractFeatures(graph, FeatureExtractionUtilities.initializeFeatureMap());
        FeatureExtractionUtilities.emitFeatureProfileJsonFile(irfeatures, (StructuredGraph)graph, (TornadoDeviceContext)this.tornadoDeviceContext);
    }

    private LinkedHashMap<ProfilerCodeFeatures, Integer> extractFeatures(StructuredGraph graph, LinkedHashMap<ProfilerCodeFeatures, Integer> initMap) {
        LinkedHashMap<ProfilerCodeFeatures, Integer> irFeatures = initMap;
        for (Node node : graph.getNodes().snapshot()) {
            if (node instanceof MulNode || node instanceof AddNode || node instanceof SubNode || node instanceof SignedDivNode || node instanceof AddNode || node instanceof IntegerDivRemNode || node instanceof RemNode || node instanceof SignedRemNode || node instanceof FloatEqualsNode || node instanceof IntegerEqualsNode) {
                this.updateWithType(irFeatures, node);
                continue;
            }
            if (node instanceof MarkOCLWriteNode || node instanceof WriteNode) {
                this.updateMemoryAccesses(irFeatures, node, false);
                continue;
            }
            if (node instanceof FloatingReadNode || node instanceof ReadNode) {
                this.updateMemoryAccesses(irFeatures, node, true);
                continue;
            }
            if (node instanceof LoopBeginNode) {
                this.updateCounter(irFeatures, ProfilerCodeFeatures.LOOPS);
                continue;
            }
            if (node instanceof IfNode) {
                this.updateCounter(irFeatures, ProfilerCodeFeatures.IFS);
                continue;
            }
            if (node instanceof IntegerSwitchNode) {
                this.updateCounter(irFeatures, ProfilerCodeFeatures.SWITCH);
                int countCases = irFeatures.get(ProfilerCodeFeatures.CASE);
                irFeatures.put(ProfilerCodeFeatures.CASE, countCases + ((IntegerSwitchNode)node).getSuccessorCount());
                continue;
            }
            if (node instanceof MarkVectorLoad || node instanceof MarkVectorValueNode) {
                this.updateCounter(irFeatures, ProfilerCodeFeatures.VECTORS);
                continue;
            }
            if (node instanceof IntegerLessThanNode) {
                this.updateCounter(irFeatures, ProfilerCodeFeatures.I_CMP);
                continue;
            }
            if (node instanceof OrNode || node instanceof AndNode || node instanceof LeftShiftNode || node instanceof RightShiftNode || node instanceof UnsignedRightShiftNode || node instanceof ShiftNode || node instanceof XorNode) {
                this.updateWithType(irFeatures, node);
                continue;
            }
            if (node instanceof MarkGlobalThreadID) {
                this.updateCounter(irFeatures, ProfilerCodeFeatures.PARALLEL_LOOPS);
                continue;
            }
            if (node instanceof ConstantNode || node instanceof ParameterNode || node instanceof SignExtendNode) {
                this.updateCounter(irFeatures, ProfilerCodeFeatures.PRIVATE_LOADS);
                this.updateCounter(irFeatures, ProfilerCodeFeatures.PRIVATE_STORES);
                continue;
            }
            if (node instanceof MarkCastNode) {
                this.updateCounter(irFeatures, ProfilerCodeFeatures.CAST);
                continue;
            }
            if (node instanceof FloatLessThanNode) {
                this.updateCounter(irFeatures, ProfilerCodeFeatures.F_CMP);
                continue;
            }
            if (node instanceof MarkFloatingPointIntrinsicsNode || node instanceof UnaryArithmeticNode) {
                this.updateCounter(irFeatures, ProfilerCodeFeatures.F_MATH);
                continue;
            }
            if (!(node instanceof MarkIntIntrinsicNode)) continue;
            this.updateCounter(irFeatures, ProfilerCodeFeatures.I_MATH);
        }
        return irFeatures;
    }

    private JavaKind getPrimitiveType(Node inputNode) {
        return ((ValueNode)inputNode).getStackKind();
    }

    private void updateCounter(LinkedHashMap<ProfilerCodeFeatures, Integer> irFeatures, ProfilerCodeFeatures feature) {
        irFeatures.put(feature, irFeatures.get(feature) + 1);
    }

    private void updateWithType(LinkedHashMap<ProfilerCodeFeatures, Integer> irFeatures, Node node) {
        JavaKind opType = this.getPrimitiveType(node);
        if (opType == JavaKind.Boolean || opType == JavaKind.Char || opType == JavaKind.Int || opType == JavaKind.Short || opType == JavaKind.Long) {
            this.updateCounter(irFeatures, ProfilerCodeFeatures.INTEGER_OPS);
        } else if (opType == JavaKind.Double) {
            this.updateCounter(irFeatures, ProfilerCodeFeatures.FLOAT_OPS);
            this.updateCounter(irFeatures, ProfilerCodeFeatures.DOUBLES);
        } else if (opType == JavaKind.Float) {
            this.updateCounter(irFeatures, ProfilerCodeFeatures.FLOAT_OPS);
            this.updateCounter(irFeatures, ProfilerCodeFeatures.FP32);
        }
    }

    private void updateMemoryAccesses(LinkedHashMap<ProfilerCodeFeatures, Integer> irFeatures, Node node, boolean isLoad) {
        for (Node memOpNode : node.inputs().filter(AddressNode.class)) {
            for (Node addressInput : memOpNode.inputs()) {
                if (addressInput instanceof MarkLocalArray) {
                    if (isLoad) {
                        this.updateCounter(irFeatures, ProfilerCodeFeatures.LOCAL_LOADS);
                        continue;
                    }
                    this.updateCounter(irFeatures, ProfilerCodeFeatures.LOCAL_STORES);
                    continue;
                }
                if (addressInput instanceof ParameterNode) {
                    if (isLoad) {
                        this.updateCounter(irFeatures, ProfilerCodeFeatures.GLOBAL_LOADS);
                        continue;
                    }
                    this.updateCounter(irFeatures, ProfilerCodeFeatures.GLOBAL_STORES);
                    continue;
                }
                if (!(addressInput instanceof FloatingReadNode) || isLoad) continue;
                this.updateCounter(irFeatures, ProfilerCodeFeatures.GLOBAL_STORES);
            }
        }
    }
}

