/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.runtime.graal.phases.sketcher;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import jdk.vm.ci.meta.ResolvedJavaMethod;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.FrameState;
import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.ValuePhiNode;
import org.graalvm.compiler.nodes.calc.IntegerLessThanNode;
import org.graalvm.compiler.nodes.loop.InductionVariable;
import org.graalvm.compiler.nodes.loop.LoopEx;
import org.graalvm.compiler.phases.BasePhase;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.enums.TornadoDeviceType;
import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException;
import uk.ac.manchester.tornado.api.exceptions.TornadoCompilationException;
import uk.ac.manchester.tornado.runtime.ASMClassVisitorProvider;
import uk.ac.manchester.tornado.runtime.common.ParallelAnnotationProvider;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
import uk.ac.manchester.tornado.runtime.graal.nodes.ParallelOffsetNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.ParallelRangeNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.ParallelStrideNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.TornadoLoopsData;
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoSketchTierContext;

public class TornadoApiReplacement
extends BasePhase<TornadoSketchTierContext> {
    private static ASMClassVisitorProvider asmClassVisitorProvider;

    public Optional<BasePhase.NotApplicable> notApplicableTo(GraphState graphState) {
        return ALWAYS_APPLICABLE;
    }

    protected void run(StructuredGraph graph, TornadoSketchTierContext context) {
        this.replaceLocalAnnotations(graph, context);
    }

    private void replaceLocalAnnotations(StructuredGraph graph, TornadoSketchTierContext context) throws TornadoCompilationException {
        Map<Node, ParallelAnnotationProvider> parallelNodes = this.getAnnotatedNodes(graph, context);
        this.addParallelProcessingNodes(graph, parallelNodes, context.getDevice());
    }

    private Map<Node, ParallelAnnotationProvider> getAnnotatedNodes(StructuredGraph graph, TornadoSketchTierContext context) {
        HashMap<ResolvedJavaMethod, ParallelAnnotationProvider[]> methodToAnnotations = new HashMap<ResolvedJavaMethod, ParallelAnnotationProvider[]>();
        methodToAnnotations.put(context.getMethod(), asmClassVisitorProvider.getParallelAnnotations(context.getMethod()));
        for (ResolvedJavaMethod resolvedJavaMethod : graph.getMethods()) {
            ParallelAnnotationProvider[] inlineParallelAnnotations = asmClassVisitorProvider.getParallelAnnotations(resolvedJavaMethod);
            if (inlineParallelAnnotations.length <= 0) continue;
            methodToAnnotations.put(resolvedJavaMethod, inlineParallelAnnotations);
        }
        HashMap<Node, ParallelAnnotationProvider> parallelNodes = new HashMap<Node, ParallelAnnotationProvider>();
        graph.getNodes().filter(FrameState.class).forEach(frameState -> {
            if (methodToAnnotations.containsKey(frameState.getMethod())) {
                for (ParallelAnnotationProvider annotation : (ParallelAnnotationProvider[])methodToAnnotations.get(frameState.getMethod())) {
                    ValueNode localNode;
                    if (frameState.bci < annotation.getStart() || frameState.bci >= annotation.getStart() + annotation.getLength() || parallelNodes.containsKey(localNode = frameState.localAt(annotation.getIndex()))) continue;
                    parallelNodes.put((Node)localNode, annotation);
                }
            }
        });
        return parallelNodes;
    }

    private void addParallelProcessingNodes(StructuredGraph graph, Map<Node, ParallelAnnotationProvider> parallelNodes, TornadoDevice device) {
        if (graph.hasLoops()) {
            TornadoLoopsData data = new TornadoLoopsData(graph);
            data.detectCountedLoops();
            int loopIndex = 0;
            List loops = data.outerFirst();
            if (device.getDeviceType() != TornadoDeviceType.CPU && TornadoOptions.TORNADO_LOOP_INTERCHANGE) {
                Collections.reverse(loops);
            }
            for (LoopEx loop : loops) {
                for (InductionVariable iv : loop.getInductionVariables().getValues()) {
                    if (!parallelNodes.containsKey(iv.valueNode())) continue;
                    List conditions = iv.valueNode().usages().filter(IntegerLessThanNode.class).snapshot();
                    IntegerLessThanNode lessThan = (IntegerLessThanNode)conditions.getFirst();
                    ValueNode maxIterations = lessThan.getY();
                    this.parallelizationReplacement(graph, iv, loopIndex, maxIterations, conditions);
                    ++loopIndex;
                }
            }
        }
    }

    private void parallelizationReplacement(StructuredGraph graph, InductionVariable inductionVar, int loopIndex, ValueNode maxIterations, List<IntegerLessThanNode> conditions) throws TornadoCompilationException {
        ValueNode oldStride;
        ValuePhiNode phi;
        ParallelRangeNode range;
        ParallelStrideNode stride;
        ParallelOffsetNode offset;
        if (inductionVar.isConstantInit() && inductionVar.isConstantStride()) {
            ConstantNode newInit = (ConstantNode)graph.addWithoutUnique((Node)ConstantNode.forInt((int)((int)inductionVar.constantInit())));
            ConstantNode newStride = (ConstantNode)graph.addWithoutUnique((Node)ConstantNode.forInt((int)((int)inductionVar.constantStride())));
            offset = (ParallelOffsetNode)graph.addWithoutUnique((Node)new ParallelOffsetNode(loopIndex, (ValueNode)newInit));
            stride = (ParallelStrideNode)graph.addWithoutUnique((Node)new ParallelStrideNode(loopIndex, (ValueNode)newStride));
            range = (ParallelRangeNode)graph.addWithoutUnique((Node)new ParallelRangeNode(loopIndex, maxIterations, offset, stride));
            phi = (ValuePhiNode)inductionVar.valueNode();
            oldStride = phi.singleBackValueOrThis();
            if (oldStride.usages().count() > 1) {
                ValueNode duplicateStride = (ValueNode)oldStride.copyWithInputs(true);
                oldStride.replaceAtMatchingUsages((Node)duplicateStride, usage -> !usage.equals(phi));
            }
        } else {
            throw new TornadoBailoutRuntimeException("Failed to parallelize because of non-constant loop strides. \nSequential code will run on the device!");
        }
        inductionVar.initNode().replaceAtMatchingUsages((Node)offset, node -> node.equals(phi));
        inductionVar.strideNode().replaceAtMatchingUsages((Node)stride, node -> node.equals(oldStride));
        maxIterations.replaceAtMatchingUsages((Node)range, node -> node.equals(conditions.getFirst()));
    }

    static {
        try {
            String tornadoAnnotationImplementation = System.getProperty("tornado.load.annotation.implementation");
            if (tornadoAnnotationImplementation == null) {
                throw new RuntimeException("[ERROR] Tornado Annotation Implementation class not specified. Did you remember to add @tornado-argfile?");
            }
            Class<?> klass = Class.forName(tornadoAnnotationImplementation);
            Constructor<?> constructor = klass.getConstructor(new Class[0]);
            asmClassVisitorProvider = (ASMClassVisitorProvider)constructor.newInstance(new Object[0]);
        }
        catch (ClassNotFoundException | IllegalAccessException | IllegalArgumentException | InstantiationException | NoSuchMethodException | SecurityException | InvocationTargetException e) {
            throw new RuntimeException("[ERROR] Tornado Annotation Implementation class not found", e);
        }
    }
}

