/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.runtime.graal.phases.sketcher;

import java.lang.annotation.Annotation;
import java.util.Optional;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.iterators.NodeIterable;
import org.graalvm.compiler.nodes.EndNode;
import org.graalvm.compiler.nodes.FixedNode;
import org.graalvm.compiler.nodes.FixedWithNextNode;
import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.IfNode;
import org.graalvm.compiler.nodes.LoopBeginNode;
import org.graalvm.compiler.nodes.MergeNode;
import org.graalvm.compiler.nodes.ParameterNode;
import org.graalvm.compiler.nodes.PhiNode;
import org.graalvm.compiler.nodes.PiNode;
import org.graalvm.compiler.nodes.StartNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.AddNode;
import org.graalvm.compiler.nodes.calc.BinaryArithmeticNode;
import org.graalvm.compiler.nodes.calc.BinaryNode;
import org.graalvm.compiler.nodes.calc.CompareNode;
import org.graalvm.compiler.nodes.calc.MulNode;
import org.graalvm.compiler.nodes.calc.SubNode;
import org.graalvm.compiler.nodes.extended.JavaReadNode;
import org.graalvm.compiler.nodes.extended.JavaWriteNode;
import org.graalvm.compiler.nodes.java.LoadFieldNode;
import org.graalvm.compiler.nodes.java.LoadIndexedNode;
import org.graalvm.compiler.nodes.java.StoreFieldNode;
import org.graalvm.compiler.nodes.java.StoreIndexedNode;
import org.graalvm.compiler.nodes.memory.address.AddressNode;
import org.graalvm.compiler.nodes.memory.address.OffsetAddressNode;
import org.graalvm.compiler.phases.BasePhase;
import uk.ac.manchester.tornado.api.annotations.Reduce;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.runtime.graal.nodes.StoreAtomicIndexedNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.StoreAtomicIndexedNodeExtension;
import uk.ac.manchester.tornado.runtime.graal.nodes.TornadoReduceAddNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.TornadoReduceMulNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.TornadoReduceSubNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.WriteAtomicNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.WriteAtomicNodeExtension;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkFloatingPointIntrinsicsNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkIntIntrinsicNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.interfaces.MarkIntrinsicsNode;
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoSketchTierContext;

/**
 * Sketch-tier phase that turns stores into {@link Reduce}-annotated parameters into atomic
 * reduction stores.
 *
 * <p>For every annotated parameter the phase inspects the parameter's usages:
 * <ul>
 * <li>a {@link StoreIndexedNode} whose stored value reads back the same array element, and which
 * has a {@link LoopBeginNode} somewhere up its control-flow predecessor chain, is replaced by a
 * {@link StoreAtomicIndexedNode};</li>
 * <li>a {@link JavaWriteNode} reached through a {@link PiNode} and an {@link OffsetAddressNode},
 * under the same two conditions, is replaced by a {@link WriteAtomicNode};</li>
 * <li>a {@link StoreFieldNode} is rejected with a {@link TornadoRuntimeException}.</li>
 * </ul>
 * When the stored value is an {@link AddNode}, {@link MulNode} or {@link SubNode}, it is swapped for
 * the matching Tornado reduce node.
 */
public class TornadoReduceReplacement extends BasePhase<TornadoSketchTierContext> {

    /**
     * Selects the node recorded as the "optional operation" of the atomic store.
     *
     * <p>If {@code value} is a {@link TornadoReduceAddNode}, its operands are checked in this order:
     * X as {@link BinaryArithmeticNode}, Y as {@link BinaryArithmeticNode}, X as
     * {@link MarkFloatingPointIntrinsicsNode}, Y as {@link MarkFloatingPointIntrinsicsNode}. If
     * nothing matched, the accumulator is used when it is a {@link BinaryNode}.
     *
     * @param value the value stored by the atomic node
     * @param accumulator the accumulator operand of the reduction
     * @return the selected node, or {@code null} if none qualifies
     */
    private static ValueNode getArithmeticNode(ValueNode value, ValueNode accumulator) {
        ValueNode arithmeticNode = null;
        if (value instanceof TornadoReduceAddNode reduce) {
            ValueNode x = reduce.getX();
            ValueNode y = reduce.getY();
            if (x instanceof BinaryArithmeticNode) {
                arithmeticNode = x;
            } else if (y instanceof BinaryArithmeticNode) {
                arithmeticNode = y;
            } else if (x instanceof MarkFloatingPointIntrinsicsNode) {
                arithmeticNode = x;
            } else if (y instanceof MarkFloatingPointIntrinsicsNode) {
                arithmeticNode = y;
            }
        }
        if (arithmeticNode == null && accumulator instanceof BinaryNode) {
            arithmeticNode = accumulator;
        }
        return arithmeticNode;
    }

    @Override
    public Optional<BasePhase.NotApplicable> notApplicableTo(GraphState graphState) {
        return ALWAYS_APPLICABLE;
    }

    @Override
    protected void run(StructuredGraph graph, TornadoSketchTierContext context) {
        findParametersWithReduceAnnotations(graph);
    }

    /**
     * Checks whether {@code currentNode} (searched recursively through {@link BinaryNode} operands)
     * reads the location that is being stored to.
     *
     * @param arrayToStore array written by the store
     * @param indexToStore index written by the store, or the write's {@link OffsetAddressNode}
     * @param currentNode node to inspect
     * @return {@code true} if a matching {@link LoadIndexedNode} (same array and index node) or a
     *     {@link JavaReadNode} whose offset address is value-equal to {@code indexToStore} is found
     */
    private boolean recursiveCheck(ValueNode arrayToStore, ValueNode indexToStore, ValueNode currentNode) {
        if (currentNode instanceof BinaryNode binaryNode) {
            return recursiveCheck(arrayToStore, indexToStore, binaryNode.getX())
                    || recursiveCheck(arrayToStore, indexToStore, binaryNode.getY());
        }
        if (currentNode instanceof LoadIndexedNode loadNode) {
            return loadNode.array() == arrayToStore && loadNode.index() == indexToStore;
        }
        if (currentNode instanceof JavaReadNode readNode
                && readNode.getAddress() instanceof OffsetAddressNode readAddress
                && indexToStore instanceof OffsetAddressNode writeAddress) {
            return writeAddress.valueEquals(readAddress);
        }
        return false;
    }

    /** Returns whether the value stored by {@code store} reads back the element being written. */
    private boolean checkIfReduction(StoreIndexedNode store) {
        return recursiveCheck(store.array(), store.index(), store.value());
    }

    /**
     * Walks control-flow predecessors backwards from {@code store} (following the first forward end
     * of each {@link MergeNode}) looking for a {@link LoopBeginNode}.
     *
     * @return {@code true} if a loop header is reached before the {@link StartNode}; {@code false}
     *     when the start node or a node without a predecessor is reached
     */
    private boolean checkIfVarIsInLoop(Node store) {
        Node node = store.predecessor();
        while (node != null && !(node instanceof StartNode)) {
            if (node instanceof LoopBeginNode) {
                return true;
            }
            if (node instanceof MergeNode merge) {
                EndNode forwardEnd = merge.forwardEndAt(0);
                node = forwardEnd.predecessor();
            } else {
                node = node.predecessor();
            }
        }
        return false;
    }

    /**
     * Searches the expression feeding a store for an array (or parameter) other than the output.
     *
     * @param currentNode node to inspect; for a {@link StoreIndexedNode} its value is searched and
     *     its array becomes the output array
     * @param outputArray array being written; never returned
     * @return the first input array found, or {@code null}
     */
    private ValueNode obtainInputArray(ValueNode currentNode, ValueNode outputArray) {
        if (currentNode instanceof StoreIndexedNode store) {
            return obtainInputArray(store.value(), store.array());
        }
        // Only arithmetic nodes and intrinsic markers are searched; other binary nodes yield null.
        if (currentNode instanceof BinaryNode binary
                && (binary instanceof BinaryArithmeticNode || binary instanceof MarkIntrinsicsNode)) {
            ValueNode fromX = obtainInputArray(binary.getX(), outputArray);
            return fromX != null ? fromX : obtainInputArray(binary.getY(), outputArray);
        }
        if (currentNode instanceof LoadIndexedNode loadNode) {
            return loadNode.array() != outputArray ? loadNode.array() : null;
        }
        if (currentNode instanceof JavaReadNode readNode) {
            return obtainInputArray((ValueNode) readNode.getAddress(), outputArray);
        }
        if (currentNode instanceof OffsetAddressNode address) {
            return obtainInputArray(address.getBase(), outputArray);
        }
        if (currentNode instanceof LoadFieldNode loadFieldNode) {
            if (loadFieldNode.getValue() instanceof PiNode piNode) {
                ParameterNode parameter = (ParameterNode) piNode.inputs().filter(ParameterNode.class).first();
                if (parameter != outputArray) {
                    return parameter;
                }
            }
            return null;
        }
        if (currentNode instanceof PiNode piNode
                && piNode.object() instanceof ParameterNode parameter
                && parameter != outputArray) {
            return parameter;
        }
        return null;
    }

    /**
     * Builds the reduction value for {@code store}, inserting a Tornado reduce node into
     * {@code graph} when the stored value is an Add, Mul or Sub.
     *
     * <p>The superseded arithmetic node is NOT deleted here: the original store still uses it as
     * its value. It is returned in the metadata and deleted by {@link #performNodeReplacement} once
     * the store has been replaced and the node has no usages left.
     *
     * @throws TornadoRuntimeException if the store or its value kind is not supported
     */
    private ReductionMetadataNode createReductionNode(StructuredGraph graph, Node store, ValueNode inputArray, ValueNode startNode) throws RuntimeException {
        ValueNode storeValue = null;
        if (store instanceof StoreIndexedNode storeIndexed) {
            storeValue = storeIndexed.value();
        } else if (store instanceof JavaWriteNode javaWrite) {
            storeValue = javaWrite.value();
        }
        if (storeValue == null) {
            throw new TornadoRuntimeException("\n\n[NODE REDUCTION NOT SUPPORTED] Node : " + String.valueOf(store) + " not supported yet.");
        }

        ValueNode value;
        ValueNode accumulator;
        ValueNode replacedArithmetic = null;
        // NOTE(review): Add takes Y as the accumulator while Mul/Sub take X; kept as-is because the
        // backends consuming the accumulator are not visible here — confirm this is intentional.
        if (storeValue instanceof AddNode addNode) {
            TornadoReduceAddNode atomicAdd = (TornadoReduceAddNode) graph.addOrUnique((Node) new TornadoReduceAddNode(addNode.getX(), addNode.getY()));
            accumulator = addNode.getY();
            value = atomicAdd;
            replacedArithmetic = addNode;
        } else if (storeValue instanceof MulNode mulNode) {
            TornadoReduceMulNode atomicMultiplication = (TornadoReduceMulNode) graph.addOrUnique((Node) new TornadoReduceMulNode(mulNode.getX(), mulNode.getY()));
            accumulator = mulNode.getX();
            value = atomicMultiplication;
            replacedArithmetic = mulNode;
        } else if (storeValue instanceof SubNode subNode) {
            TornadoReduceSubNode atomicSub = (TornadoReduceSubNode) graph.addOrUnique((Node) new TornadoReduceSubNode(subNode.getX(), subNode.getY()));
            accumulator = subNode.getX();
            value = atomicSub;
            replacedArithmetic = subNode;
        } else if (storeValue instanceof BinaryNode binaryNode) {
            boolean isIntrinsic = storeValue instanceof MarkFloatingPointIntrinsicsNode || storeValue instanceof MarkIntIntrinsicNode;
            accumulator = isIntrinsic ? binaryNode.getY() : storeValue;
            value = storeValue;
        } else {
            throw new TornadoRuntimeException("\n\n[NODE REDUCTION NOT SUPPORTED] Node : " + String.valueOf(storeValue) + " not supported yet.");
        }
        return new ReductionMetadataNode(value, accumulator, inputArray, startNode, replacedArithmetic);
    }

    /**
     * Replaces {@code node} in the control flow with the matching atomic store built from
     * {@code reductionNode}, then deletes the superseded arithmetic node if it became unused.
     *
     * @throws TornadoRuntimeException if {@code node} is neither a {@link StoreIndexedNode} nor a
     *     {@link JavaWriteNode} (checked before the graph is modified)
     */
    private void performNodeReplacement(StructuredGraph graph, FixedWithNextNode node, Node predecessor, ReductionMetadataNode reductionNode, ValueNode outArray) {
        ValueNode value = reductionNode.value();
        ValueNode accumulator = reductionNode.accumulator();
        ValueNode inputArray = reductionNode.inputArray();
        ValueNode startNode = reductionNode.startNode();

        FixedWithNextNode storeNode;
        if (node instanceof StoreIndexedNode store) {
            StoreAtomicIndexedNodeExtension extension = new StoreAtomicIndexedNodeExtension(startNode);
            graph.addOrUnique((Node) extension);
            storeNode = (FixedWithNextNode) graph.addOrUnique((Node) new StoreAtomicIndexedNode(store.array(), store.index(), store.elementKind(), store.getBoundsCheck(), value, accumulator, inputArray, extension));
        } else if (node instanceof JavaWriteNode javaWriteNode) {
            WriteAtomicNodeExtension extension = new WriteAtomicNodeExtension(startNode);
            graph.addOrUnique((Node) extension);
            storeNode = (FixedWithNextNode) graph.addOrUnique((Node) new WriteAtomicNode(javaWriteNode.getWriteKind(), javaWriteNode.getAddress(), value, accumulator, inputArray, outArray, extension));
        } else {
            // Fail before touching the graph instead of splicing a null node into the control flow.
            throw new TornadoRuntimeException("\n\n[NODE REDUCTION NOT SUPPORTED] Node : " + String.valueOf(node) + " not supported yet.");
        }

        ValueNode arithmeticNode = getArithmeticNode(value, accumulator);
        if (storeNode instanceof StoreAtomicIndexedNode storeAtomicIndexedNode) {
            storeAtomicIndexedNode.setOptionalOperation(arithmeticNode);
        } else if (storeNode instanceof WriteAtomicNode writeAtomicNode) {
            writeAtomicNode.setOptionalOperation(arithmeticNode);
        }

        FixedNode next = node.next();
        predecessor.replaceFirstSuccessor((Node) node, storeNode);
        node.replaceAndDelete(storeNode);
        storeNode.setNext(next);

        // The old store is gone now, so the replaced Add/Mul/Sub may have lost its last usage.
        ValueNode replaced = reductionNode.replacedArithmetic();
        if (replaced != null && replaced.isAlive() && replaced.hasNoUsages()) {
            replaced.safeDelete();
        }
    }

    /** Returns whether a static method's parameter {@code index} is beyond the graph's parameter nodes. */
    private boolean shouldSkip(int index, StructuredGraph graph) {
        return graph.method().isStatic() && index >= getNumberOfParameterNodes(graph);
    }

    /**
     * Rewrites the qualifying stores that use the graph parameter at {@code index}.
     *
     * <p>Usage lists are snapshotted because the replacements add and remove usages of the
     * parameter (and of the {@link PiNode}) while they are being iterated.
     *
     * @throws TornadoRuntimeException if the parameter is used by a {@link StoreFieldNode}
     */
    private void processReduceAnnotation(StructuredGraph graph, int index) {
        if (shouldSkip(index, graph)) {
            return;
        }
        ParameterNode reduceParameter = graph.getParameter(index);
        assert (reduceParameter != null);
        if (reduceParameter == null) {
            return;
        }
        for (Node usage : reduceParameter.usages().snapshot()) {
            if (usage.isDeleted()) {
                continue;
            }
            if (usage instanceof StoreIndexedNode store) {
                if (checkIfReduction(store) && checkIfVarIsInLoop(store)) {
                    insertReduceStoreNode(graph, store, store.value(), store.array());
                }
            } else if (usage instanceof StoreFieldNode) {
                throw new TornadoRuntimeException("\n[NOT SUPPORTED] Node StoreFieldNode is not supported yet.");
            } else if (usage instanceof PiNode piNode) {
                processPiUsage(graph, piNode);
            }
        }
    }

    /** Rewrites {@link JavaWriteNode}s that write through an {@link OffsetAddressNode} based on {@code piNode}. */
    private void processPiUsage(StructuredGraph graph, PiNode piNode) {
        ParameterNode parameterNode = (ParameterNode) piNode.inputs().filter(ParameterNode.class).first();
        for (OffsetAddressNode offsetAddressNode : piNode.usages().filter(OffsetAddressNode.class).snapshot()) {
            if (offsetAddressNode.isDeleted()) {
                continue;
            }
            JavaWriteNode javaWriteNode = (JavaWriteNode) offsetAddressNode.usages().filter(JavaWriteNode.class).first();
            if (javaWriteNode == null) {
                continue;
            }
            boolean isReductionValue = recursiveCheck(parameterNode, offsetAddressNode, javaWriteNode.value());
            if (isReductionValue && checkIfVarIsInLoop(javaWriteNode)) {
                insertReduceStoreNode(graph, javaWriteNode, javaWriteNode.value(), parameterNode);
            }
        }
    }

    /** Gathers the reduction metadata for {@code storeNode} and replaces it with an atomic store. */
    private void insertReduceStoreNode(StructuredGraph graph, FixedWithNextNode storeNode, ValueNode storeNodeValue, ValueNode array) {
        ValueNode inputArray = obtainInputArray(storeNodeValue, array);
        ValueNode startNode = obtainStartLoopNode(storeNode);
        ReductionMetadataNode reductionNode = createReductionNode(graph, storeNode, inputArray, startNode);
        Node predecessor = storeNode.predecessor();
        performNodeReplacement(graph, storeNode, predecessor, reductionNode, array);
    }

    /**
     * Walks predecessors backwards from {@code store} to find the loop's initial value.
     *
     * <p>At an {@link IfNode}, the walk climbs to the node directly below a {@link LoopBeginNode};
     * if the if's condition is a {@link CompareNode} whose X input is a {@link PhiNode}, that phi's
     * first value is returned. At a {@link LoopBeginNode}, the first value of the last
     * {@link PhiNode} among its usages is returned.
     *
     * @return the start value, or {@code null} when no loop header / phi is found (including a loop
     *     header without phi usages, which previously caused an endless loop)
     */
    private ValueNode obtainStartLoopNode(Node store) {
        Node node = store.predecessor();
        while (node != null && !(node instanceof StartNode)) {
            if (node instanceof IfNode ifNode) {
                while (!(node.predecessor() instanceof LoopBeginNode)) {
                    node = node.predecessor();
                    if (node == null || node instanceof StartNode) {
                        return null;
                    }
                }
                if (ifNode.condition() instanceof CompareNode condition && condition.getX() instanceof PhiNode phi) {
                    return phi.valueAt(0);
                }
            }
            if (node instanceof MergeNode merge) {
                EndNode forwardEnd = merge.forwardEndAt(0);
                node = forwardEnd.predecessor();
            } else if (node instanceof LoopBeginNode loopBeginNode) {
                ValueNode startNode = null;
                NodeIterable<PhiNode> phis = loopBeginNode.usages().filter(PhiNode.class);
                for (PhiNode phiNode : phis) {
                    // The last phi wins, as in the original implementation.
                    startNode = phiNode.valueAt(0);
                }
                return startNode;
            } else {
                node = node.predecessor();
            }
        }
        return null;
    }

    /** Counts the {@link ParameterNode}s in {@code graph}. */
    private int getNumberOfParameterNodes(StructuredGraph graph) {
        return graph.getNodes().filter(ParameterNode.class).count();
    }

    /**
     * Processes every method parameter annotated with {@link Reduce}.
     *
     * <p>Annotation indices are mapped to graph parameter indices with a fixed offset of one when
     * the method is not static, or when the graph has more parameter nodes than declared
     * parameters. For instance methods the extra node is presumably the receiver. The offset is
     * computed once, so the loop counter is never modified; the previous in-loop increment
     * accumulated and skipped later annotated parameters.
     */
    private void findParametersWithReduceAnnotations(StructuredGraph graph) {
        Annotation[][] parameterAnnotations = graph.method().getParameterAnnotations();
        int parameterOffset = !graph.method().isStatic() || getNumberOfParameterNodes(graph) > parameterAnnotations.length ? 1 : 0;
        for (int index = 0; index < parameterAnnotations.length; ++index) {
            for (Annotation annotation : parameterAnnotations[index]) {
                if (annotation instanceof Reduce) {
                    processReduceAnnotation(graph, index + parameterOffset);
                }
            }
        }
    }

    /**
     * Data collected for one reduction store.
     *
     * @param value value stored by the atomic node (a Tornado reduce node or the original binary node)
     * @param accumulator operand recorded as the accumulator
     * @param inputArray array found by {@code obtainInputArray}; may be {@code null}
     * @param startNode loop start value found by {@code obtainStartLoopNode}; may be {@code null}
     * @param replacedArithmetic the Add/Mul/Sub node superseded by {@code value}, to be deleted once
     *     unused; {@code null} when nothing was replaced
     */
    private record ReductionMetadataNode(ValueNode value, ValueNode accumulator, ValueNode inputArray, ValueNode startNode, ValueNode replacedArithmetic) {
    }
}

