/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.runtime.analyzer;

import java.lang.annotation.Annotation;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
import org.graalvm.compiler.graph.Graph;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.iterators.NodeIterable;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.EndNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.IfNode;
import org.graalvm.compiler.nodes.InvokeNode;
import org.graalvm.compiler.nodes.LogicNode;
import org.graalvm.compiler.nodes.LoopBeginNode;
import org.graalvm.compiler.nodes.MergeNode;
import org.graalvm.compiler.nodes.ParameterNode;
import org.graalvm.compiler.nodes.PhiNode;
import org.graalvm.compiler.nodes.StartNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.AddNode;
import org.graalvm.compiler.nodes.calc.BinaryArithmeticNode;
import org.graalvm.compiler.nodes.calc.BinaryNode;
import org.graalvm.compiler.nodes.calc.IntegerLessThanNode;
import org.graalvm.compiler.nodes.calc.MulNode;
import org.graalvm.compiler.nodes.java.ArrayLengthNode;
import org.graalvm.compiler.nodes.java.MethodCallTargetNode;
import org.graalvm.compiler.nodes.java.StoreIndexedNode;
import uk.ac.manchester.tornado.api.annotations.Reduce;
import uk.ac.manchester.tornado.api.common.PrebuiltTaskPackage;
import uk.ac.manchester.tornado.api.common.TaskPackage;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
import uk.ac.manchester.tornado.runtime.TornadoCoreRuntime;
import uk.ac.manchester.tornado.runtime.analyzer.CodeAnalysis;
import uk.ac.manchester.tornado.runtime.analyzer.MetaReduceCodeAnalysis;
import uk.ac.manchester.tornado.runtime.analyzer.MetaReduceTasks;
import uk.ac.manchester.tornado.runtime.graal.nodes.StoreAtomicIndexedNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.TornadoReduceAddNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.WriteAtomicNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkFloatingPointIntrinsicsNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkIntIntrinsicNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkIntrinsicsNode;

/**
 * Static helpers that inspect Graal high-level graphs of TornadoVM tasks to find parameters annotated
 * with {@link Reduce}, the operator used to reduce into them, and the upper bound of the loop that
 * performs the reduction.
 */
public class ReduceCodeAnalysis {

    /**
     * Walks the control-flow predecessors of {@code store} and reports whether a {@link LoopBeginNode}
     * is reached before the graph's {@link StartNode}. Merges are crossed through their first forward end.
     *
     * @param store fixed node to start the backwards walk from
     * @return {@code true} if a loop header is reached; {@code false} if the start node is reached or the
     *         predecessor chain ends ({@code null}) first
     */
    private static boolean checkIfVarIsInLoop(Node store) {
        Node node = store.predecessor();
        // A null predecessor means the chain is broken; treat it like reaching the start node.
        while (node != null && !(node instanceof StartNode)) {
            if (node instanceof LoopBeginNode) {
                return true;
            }
            node = stepBackwards(node);
        }
        return false;
    }

    /**
     * Returns the control-flow predecessor of {@code node}; a {@link MergeNode} is entered through its
     * first forward {@link EndNode}. May return {@code null} when the node has no predecessor.
     */
    private static Node stepBackwards(Node node) {
        if (node instanceof MergeNode) {
            return ((MergeNode) node).forwardEndAt(0);
        }
        return node.predecessor();
    }

    /**
     * Maps every candidate reduction node to the {@link REDUCE_OPERATION} it implements.
     *
     * @param reduceOperation nodes that compute the value stored into a reduction variable
     * @return one operation per input node, in input order
     * @throws TornadoRuntimeException if a node does not correspond to a supported operation
     */
    public static List<REDUCE_OPERATION> getReduceOperation(List<ValueNode> reduceOperation) {
        ArrayList<REDUCE_OPERATION> operations = new ArrayList<>();
        for (ValueNode operation : reduceOperation) {
            operations.add(classifyReduceOperation(operation));
        }
        return operations;
    }

    /** Classifies a single node as SUM, MUL, MIN or MAX, or throws if it is not a supported reduction. */
    private static REDUCE_OPERATION classifyReduceOperation(ValueNode operation) {
        if (operation instanceof TornadoReduceAddNode || operation instanceof AddNode) {
            return REDUCE_OPERATION.SUM;
        }
        if (operation instanceof MulNode) {
            return REDUCE_OPERATION.MUL;
        }
        if (operation instanceof InvokeNode) {
            String targetName = ((InvokeNode) operation).callTarget().targetName();
            if (targetName.equals("Math.max")) {
                return REDUCE_OPERATION.MAX;
            }
            if (targetName.equals("Math.min")) {
                return REDUCE_OPERATION.MIN;
            }
            throw unsupportedReduceOperation(operation);
        }
        if (operation instanceof BinaryNode && operation instanceof MarkFloatingPointIntrinsicsNode) {
            MarkIntrinsicsNode mark = (MarkFloatingPointIntrinsicsNode) operation;
            String op = mark.getOperation();
            if (op.equals("FMAX")) {
                return REDUCE_OPERATION.MAX;
            }
            if (op.equals("FMIN")) {
                return REDUCE_OPERATION.MIN;
            }
            throw unsupportedReduceOperation(operation);
        }
        if (operation instanceof BinaryNode && operation instanceof MarkIntIntrinsicNode) {
            MarkIntrinsicsNode mark = (MarkIntIntrinsicNode) operation;
            String op = mark.getOperation();
            if (op.equals("MAX")) {
                return REDUCE_OPERATION.MAX;
            }
            if (op.equals("MIN")) {
                return REDUCE_OPERATION.MIN;
            }
            throw unsupportedReduceOperation(operation);
        }
        throw unsupportedReduceOperation(operation);
    }

    /** Builds the exception reported for a node that is not a supported reduction operator. */
    private static TornadoRuntimeException unsupportedReduceOperation(ValueNode operation) {
        return new TornadoRuntimeException("[ERROR] Automatic reduce operation not supported yet: " + String.valueOf(operation));
    }

    /**
     * Returns whether a Java-level parameter index should be ignored. The check applies only to static
     * methods whose index is not below the number of {@link ParameterNode}s in the graph.
     */
    private static boolean shouldSkip(int index, StructuredGraph graph) {
        return graph.method().isStatic() && index >= ReduceCodeAnalysis.getNumberOfParameterNodes(graph);
    }

    /**
     * Converts a Java-level parameter index into a graph parameter index. Non-static methods are shifted
     * by one (presumably to skip the receiver parameter).
     */
    private static int toGraphParameterIndex(StructuredGraph graph, int javaParameterIndex) {
        return graph.method().isStatic() ? javaParameterIndex : javaParameterIndex + 1;
    }

    /** Returns whether {@code value} is an {@link InvokeNode} whose call target name starts with "Math". */
    private static boolean isMathInvoke(ValueNode value) {
        return value instanceof InvokeNode && ((InvokeNode) value).callTarget().targetName().startsWith("Math");
    }

    /** Appends {@code value} to {@code candidates} when it is a binary node or a {@code Math.*} invoke. */
    private static void addIfReduceCandidate(ValueNode value, List<ValueNode> candidates) {
        if (value instanceof BinaryNode || value instanceof BinaryArithmeticNode || isMathInvoke(value)) {
            candidates.add(value);
        }
    }

    /**
     * Finds the reduction operators applied to the given reduce parameters of a high-level graph.
     *
     * <p>For each parameter, two kinds of usage are inspected:
     * <ul>
     * <li>{@link StoreIndexedNode}s that sit inside a loop contribute their stored value when it is a
     * binary node or a {@code Math.*} invoke.</li>
     * <li>{@link MethodCallTargetNode}s contribute their first {@link BinaryNode} input.</li>
     * </ul>
     *
     * @param graph         high-level graph of the task
     * @param reduceIndices Java-level indices of the {@link Reduce}-annotated parameters
     * @return the detected operations
     * @throws TornadoRuntimeException if a detected node is not a supported reduction
     */
    public static List<REDUCE_OPERATION> getReduceOperation(StructuredGraph graph, List<Integer> reduceIndices) {
        ArrayList<ValueNode> reduceOperation = new ArrayList<>();
        for (Integer reduceIndex : reduceIndices) {
            int paramIndex = toGraphParameterIndex(graph, reduceIndex);
            if (shouldSkip(paramIndex, graph)) {
                continue;
            }
            ParameterNode parameterNode = graph.getParameter(paramIndex);
            if (parameterNode == null) {
                // The parameter is absent from the graph, so it cannot carry a reduction.
                continue;
            }
            for (Node usage : parameterNode.usages()) {
                if (usage instanceof StoreIndexedNode) {
                    StoreIndexedNode store = (StoreIndexedNode) usage;
                    if (checkIfVarIsInLoop(store)) {
                        addIfReduceCandidate(store.value(), reduceOperation);
                    }
                } else if (usage instanceof MethodCallTargetNode) {
                    MethodCallTargetNode method = (MethodCallTargetNode) usage;
                    ValueNode binaryInput = (ValueNode) method.inputs().filter(BinaryNode.class).first();
                    if (binaryInput != null) {
                        reduceOperation.add(binaryInput);
                    }
                }
            }
        }
        return getReduceOperation(reduceOperation);
    }

    /**
     * Finds the reduction operators in a (sketch) graph by inspecting atomic writes to the reduce
     * parameters. The lookup runs on a copy of {@code graph}, which is therefore not modified.
     *
     * @param graph         graph to analyse; must copy to a {@link StructuredGraph}
     * @param reduceIndices Java-level indices of the {@link Reduce}-annotated parameters
     * @return the detected operations
     * @throws TornadoRuntimeException if a detected node is not a supported reduction
     */
    public static List<REDUCE_OPERATION> getReduceOperatorFromSketch(Graph graph, List<Integer> reduceIndices) {
        ArrayList<ValueNode> reduceOperation = new ArrayList<>();
        StructuredGraph sg = (StructuredGraph) graph.copy(TornadoCoreRuntime.getDebugContext());
        for (Integer reduceIndex : reduceIndices) {
            ParameterNode parameterNode = sg.getParameter(toGraphParameterIndex(sg, reduceIndex));
            if (parameterNode == null) {
                continue;
            }
            for (Node usage : parameterNode.usages()) {
                if (usage instanceof StoreAtomicIndexedNode) {
                    addIfReduceCandidate(((StoreAtomicIndexedNode) usage).value(), reduceOperation);
                } else if (usage instanceof WriteAtomicNode) {
                    addIfReduceCandidate(((WriteAtomicNode) usage).value(), reduceOperation);
                }
            }
        }
        return getReduceOperation(reduceOperation);
    }

    /**
     * Looks at the first successor of {@code node}. If it is an {@link IfNode} guarded by an
     * {@link IntegerLessThanNode}, returns the first operand (x, then y) of the requested type.
     * Otherwise returns {@code null}.
     */
    private static <T extends ValueNode> T lessThanOperandOfType(Node node, Class<T> type) {
        Node successor = node.successors().first();
        if (!(successor instanceof IfNode)) {
            return null;
        }
        LogicNode condition = ((IfNode) successor).condition();
        if (!(condition instanceof IntegerLessThanNode)) {
            return null;
        }
        IntegerLessThanNode lessThan = (IntegerLessThanNode) condition;
        if (type.isInstance(lessThan.getX())) {
            return type.cast(lessThan.getX());
        }
        if (type.isInstance(lessThan.getY())) {
            return type.cast(lessThan.getY());
        }
        return null;
    }

    /** Returns the {@link ArrayLengthNode} compared in the loop condition following {@code aux}, if any. */
    private static ArrayLengthNode inspectArrayLengthNode(Node aux) {
        return lessThanOperandOfType(aux, ArrayLengthNode.class);
    }

    /** Returns the {@link ConstantNode} compared in the loop condition following {@code aux}, if any. */
    private static ValueNode inspectConstantNode(Node aux) {
        return lessThanOperandOfType(aux, ConstantNode.class);
    }

    /** Counts the {@link ParameterNode}s present in {@code graph}. */
    private static int getNumberOfParameterNodes(StructuredGraph graph) {
        return graph.getNodes().filter(ParameterNode.class).count();
    }

    /**
     * Walks backwards from a Panama-style {@code set} invoke to the enclosing loop header and records the
     * loop bound into {@code loopBound}.
     *
     * <p>The receiver of every {@code getSize} invoke met on the way is recorded. If no such invoke is
     * found, the constant in the loop condition is recorded instead. When a {@code getSize} receiver is
     * found it is added both on discovery and again as the final bound. The duplicates only cause
     * {@code analyzeTaskGraph} to recompute the same size.
     *
     * @throws NullPointerException if a loop header is found but no bound can be determined
     */
    private static void obtainLoopBoundForPanamaRegions(Node aux, ArrayList<ValueNode> loopBound) {
        LoopBeginNode loopBegin = null;
        ValueNode loopBoundNode = null;
        Node current = aux;
        while (current != null && !(current instanceof LoopBeginNode)) {
            current = stepBackwards(current);
            if (current == null || current instanceof StartNode) {
                break;
            }
            if (current instanceof LoopBeginNode) {
                loopBegin = (LoopBeginNode) current;
            } else if (current instanceof InvokeNode && ((InvokeNode) current).getTargetMethod().getName().equals("getSize")) {
                ValueNode sizeReceiver = (ValueNode) ((InvokeNode) current).callTarget().arguments().first();
                loopBoundNode = sizeReceiver;
                loopBound.add(sizeReceiver);
            }
        }
        if (loopBegin == null) {
            return;
        }
        if (loopBoundNode == null) {
            loopBoundNode = inspectConstantNode(loopBegin);
        }
        loopBound.add(Objects.requireNonNull(loopBoundNode, "Unable to determine the loop bound of a reduction loop"));
    }

    /**
     * Walks backwards from an on-heap array store to the enclosing loop header and records the loop bound.
     *
     * <p>The bound is resolved in this order:
     * <ol>
     * <li>the furthest {@link ArrayLengthNode} met during the walk;</li>
     * <li>an {@link ArrayLengthNode} in the loop condition;</li>
     * <li>a {@link ConstantNode} in the loop condition.</li>
     * </ol>
     * For an array length, the array itself is recorded rather than the length node.
     *
     * @throws NullPointerException if a loop header is found but no bound can be determined
     */
    private static void obtainLoopBoundForOnHeapArrays(Node aux, ArrayList<ValueNode> loopBound) {
        LoopBeginNode loopBegin = null;
        // Declared as ValueNode: the fallback may be a ConstantNode, not only an ArrayLengthNode.
        ValueNode loopBoundNode = null;
        Node current = aux;
        while (current != null && !(current instanceof LoopBeginNode)) {
            current = stepBackwards(current);
            if (current == null || current instanceof StartNode) {
                break;
            }
            if (current instanceof LoopBeginNode) {
                loopBegin = (LoopBeginNode) current;
            } else if (current instanceof ArrayLengthNode) {
                loopBoundNode = (ArrayLengthNode) current;
            }
        }
        if (loopBegin == null) {
            return;
        }
        if (loopBoundNode == null) {
            loopBoundNode = inspectArrayLengthNode(loopBegin);
        }
        if (loopBoundNode == null) {
            loopBoundNode = inspectConstantNode(loopBegin);
        }
        Objects.requireNonNull(loopBoundNode, "Unable to determine the loop bound of a reduction loop");
        if (loopBoundNode instanceof ArrayLengthNode) {
            loopBound.add(((ArrayLengthNode) loopBoundNode).array());
        } else {
            loopBound.add(loopBoundNode);
        }
    }

    /**
     * Collects loop-bound nodes for every store into {@code parameterNode}. The stores considered are
     * Panama {@code set} invokes and on-heap {@link StoreIndexedNode}s.
     */
    private static void getInputRangeForReductionNode(ParameterNode parameterNode, ArrayList<ValueNode> loopBound) {
        for (Node usage : parameterNode.usages()) {
            if (usage instanceof MethodCallTargetNode) {
                Node callSite = usage.usages().first();
                if (callSite instanceof InvokeNode && ((InvokeNode) callSite).getTargetMethod().getName().equals("set")) {
                    obtainLoopBoundForPanamaRegions(callSite, loopBound);
                }
            } else if (usage instanceof StoreIndexedNode) {
                obtainLoopBoundForOnHeapArrays(usage, loopBound);
            }
        }
    }

    /** Collects the loop-bound nodes of all reduction loops writing to the given reduce parameters. */
    private static ArrayList<ValueNode> findLoopUpperBoundNode(StructuredGraph graph, ArrayList<Integer> reduceIndexes) {
        ArrayList<ValueNode> loopBoundNodes = new ArrayList<>();
        for (Integer reduceIndex : reduceIndexes) {
            int paramIndex = toGraphParameterIndex(graph, reduceIndex);
            if (shouldSkip(paramIndex, graph)) {
                continue;
            }
            ParameterNode parameterNode = graph.getParameter(paramIndex);
            if (parameterNode != null) {
                getInputRangeForReductionNode(parameterNode, loopBoundNodes);
            }
        }
        return loopBoundNodes;
    }

    /** Returns the Java-level indices of the parameters of {@code graph}'s method annotated with {@link Reduce}. */
    private static ArrayList<Integer> findReduceParameterIndices(StructuredGraph graph) {
        Annotation[][] annotations = graph.method().getParameterAnnotations();
        ArrayList<Integer> reduceIndices = new ArrayList<>();
        for (int paramIndex = 0; paramIndex < annotations.length; paramIndex++) {
            for (Annotation annotation : annotations[paramIndex]) {
                if (annotation instanceof Reduce) {
                    reduceIndices.add(paramIndex);
                }
            }
        }
        return reduceIndices;
    }

    /**
     * Returns the element count of a reduction input argument.
     *
     * @throws TornadoRuntimeException if the argument is neither a {@link TornadoNativeArray} nor a Java array
     */
    private static int sizeOfReductionInput(Object object) {
        if (object instanceof TornadoNativeArray) {
            return ((TornadoNativeArray) object).getSize();
        }
        if (object != null && object.getClass().isArray()) {
            return Array.getLength(object);
        }
        throw new TornadoRuntimeException("[ERROR] Unsupported type for reductions: " + String.valueOf(object == null ? null : object.getClass()));
    }

    /**
     * Builds reduction metadata for every non-prebuilt task that has {@link Reduce}-annotated parameters.
     *
     * <p>For each such task the high-level graph is built and the reduction loop bounds are located. The
     * input size is then taken from the task argument matching a bound, or from a constant bound.
     *
     * @param taskPackages tasks of the task graph; element 0 of each task's parameters is the task code and
     *                     the following elements are its arguments
     * @return the collected metadata, or {@code null} if no task contains a reduction
     * @throws TornadoRuntimeException if a bound refers to an argument of unsupported type
     */
    public static MetaReduceCodeAnalysis analyzeTaskGraph(List<TaskPackage> taskPackages) {
        // NOTE(review): taskIndex is not advanced for PrebuiltTaskPackage entries, so it counts only
        // non-prebuilt tasks — confirm this matches how MetaReduceTasks indices are consumed.
        int taskIndex = 0;
        // NOTE(review): inputSize is not reset between tasks; a task whose bound cannot be resolved keeps
        // the size computed for a previous reduce task — confirm this is intended.
        int inputSize = 0;
        HashMap<Integer, MetaReduceTasks> tableMetaDataReduce = new HashMap<>();
        for (TaskPackage taskMetadata : taskPackages) {
            if (taskMetadata instanceof PrebuiltTaskPackage) {
                continue;
            }
            Object[] taskParameters = taskMetadata.getTaskParameters();
            StructuredGraph graph = CodeAnalysis.buildHighLevelGraalGraph(taskParameters[0]);
            assert graph != null;
            ArrayList<Integer> reduceIndices = findReduceParameterIndices(graph);
            if (reduceIndices.isEmpty()) {
                taskIndex++;
                continue;
            }
            ArrayList<ValueNode> loopBound = findLoopUpperBoundNode(graph, reduceIndices);
            int parameterCount = graph.method().getParameters().length;
            for (int i = 0; i < parameterCount; i++) {
                ParameterNode parameter = graph.getParameter(toGraphParameterIndex(graph, i));
                for (ValueNode valueNode : loopBound) {
                    if (valueNode.equals(parameter)) {
                        // Read the argument from the package being analysed (not taskPackages.get(taskIndex),
                        // which diverges once a prebuilt task has been skipped).
                        inputSize = sizeOfReductionInput(taskParameters[i + 1]);
                    } else if (valueNode instanceof ConstantNode) {
                        inputSize = Integer.parseInt(((ConstantNode) valueNode).getValue().toValueString());
                    }
                }
            }
            tableMetaDataReduce.put(taskIndex, new MetaReduceTasks(taskIndex, graph, reduceIndices, inputSize));
            taskIndex++;
        }
        return tableMetaDataReduce.isEmpty() ? null : new MetaReduceCodeAnalysis(tableMetaDataReduce);
    }

    /**
     * Follows the first successor of each node starting at {@code beginNode.next()} and returns the first
     * {@link IfNode} found. Returns {@code null} if the chain ends first.
     */
    private static IfNode findFirstIfNode(LoopBeginNode beginNode) {
        Node node = beginNode.next();
        while (node != null && !(node instanceof IfNode)) {
            node = node.successors().first();
        }
        return (IfNode) node;
    }

    /**
     * Rewrites the initial value of each loop's induction variable to {@code lowValue}.
     *
     * <p>A loop is rewritten only when all of the following hold:
     * <ul>
     * <li>its first {@link IfNode} tests an {@link IntegerLessThanNode};</li>
     * <li>the x operand of that test is a {@link PhiNode};</li>
     * <li>the phi's value 0 is a {@link ConstantNode}.</li>
     * </ul>
     *
     * @param graph    graph to modify in place
     * @param lowValue new start value
     */
    public static void performLoopBoundNodeSubstitution(StructuredGraph graph, int lowValue) {
        for (LoopBeginNode beginNode : graph.getNodes().filter(LoopBeginNode.class)) {
            IfNode ifNode = findFirstIfNode(beginNode);
            if (ifNode == null || !(ifNode.condition() instanceof IntegerLessThanNode)) {
                continue;
            }
            ValueNode x = ((IntegerLessThanNode) ifNode.condition()).getX();
            if (!(x instanceof PhiNode)) {
                continue;
            }
            PhiNode phi = (PhiNode) x;
            if (!(phi.valueAt(0) instanceof ConstantNode)) {
                continue;
            }
            // Only materialise the constant when it is actually used.
            ConstantNode lowBound = (ConstantNode) graph.addOrUnique(ConstantNode.forInt(lowValue));
            phi.setValueAt(0, lowBound);
        }
    }

    /** Reduction operators recognised by this analysis. */
    public static enum REDUCE_OPERATION {
        SUM,
        MUL,
        MIN,
        MAX;

    }
}

