/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.runtime.graph;

import java.nio.ByteBuffer;
import java.util.BitSet;
import java.util.List;
import java.util.Objects;
import jdk.vm.ci.meta.ResolvedJavaMethod;
import uk.ac.manchester.tornado.api.common.Access;
import uk.ac.manchester.tornado.api.common.SchedulableTask;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.runtime.TornadoCoreRuntime;
import uk.ac.manchester.tornado.runtime.common.TornadoXPUDevice;
import uk.ac.manchester.tornado.runtime.graph.TornadoExecutionContext;
import uk.ac.manchester.tornado.runtime.graph.TornadoGraph;
import uk.ac.manchester.tornado.runtime.graph.nodes.AbstractNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.AllocateMultipleBuffersNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.AllocateNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.ConstantNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.ContextNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.ContextOpNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.CopyInNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.CopyOutNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.DeallocateNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.DependentReadNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.ObjectNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.OnDeviceObjectNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.PersistedObjectNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.StreamInNode;
import uk.ac.manchester.tornado.runtime.graph.nodes.TaskNode;
import uk.ac.manchester.tornado.runtime.sketcher.Sketch;
import uk.ac.manchester.tornado.runtime.sketcher.TornadoSketcher;
import uk.ac.manchester.tornado.runtime.tasks.CompilableTask;
import uk.ac.manchester.tornado.runtime.tasks.LocalObjectState;
import uk.ac.manchester.tornado.runtime.tasks.TornadoGraphBitcodes;

/**
 * Builds a {@link TornadoGraph} from the operation stream recorded for a
 * {@link TornadoExecutionContext}.
 *
 * <p>The single entry point is {@link #buildGraph(TornadoExecutionContext, ByteBuffer)}.
 * It works in three steps: it interprets the operations found in the buffer
 * (compared against {@link TornadoGraphBitcodes}), then appends copy-out /
 * stream-in nodes driven by the per-object {@link LocalObjectState}, and
 * finally appends a persist or deallocate node for every object that was
 * attached to a device context.
 */
public class TornadoGraphBuilder {

    /**
     * Adds a {@link StreamInNode} for {@code arg} to the graph, records it as a use of
     * {@code context}, stores it in {@code args[argIndex]} and registers {@code arg}
     * with {@code persistNode}.
     */
    private static void createStreamInNode(ContextNode context, TornadoGraph graph, ObjectNode arg, AbstractNode[] args, int argIndex, AllocateMultipleBuffersNode persistNode) {
        StreamInNode streamInNode = new StreamInNode(context);
        streamInNode.setValue(arg);
        graph.add(streamInNode);
        context.addUse(streamInNode);
        args[argIndex] = streamInNode;
        persistNode.addValue(arg);
    }

    /**
     * Adds an {@link AllocateNode} for {@code arg} (which must be an {@link ObjectNode})
     * to the graph, records it as a use of {@code context}, stores it in
     * {@code args[argIndex]} and registers the object with {@code persistNode}.
     */
    private static void createAllocateNode(ContextNode context, TornadoGraph graph, AbstractNode arg, AbstractNode[] args, int argIndex, AllocateMultipleBuffersNode persistNode) {
        AllocateNode allocateNode = new AllocateNode(context);
        ObjectNode target = (ObjectNode) arg;
        allocateNode.setValue(target);
        graph.add(allocateNode);
        context.addUse(allocateNode);
        args[argIndex] = allocateNode;
        persistNode.addValue(target);
    }

    /**
     * Adds an {@link OnDeviceObjectNode} for {@code arg} (which must be an
     * {@link ObjectNode}) to the graph, records it as a use of {@code context}, stores it
     * in {@code args[argIndex]} and registers the object with {@code persistNode}.
     */
    private static void createOnDeviceNode(ContextNode context, TornadoGraph graph, AbstractNode arg, AbstractNode[] args, int argIndex, AllocateMultipleBuffersNode persistNode) {
        OnDeviceObjectNode onDeviceObjectNode = new OnDeviceObjectNode(context);
        ObjectNode target = (ObjectNode) arg;
        onDeviceObjectNode.setValue(target);
        graph.add(onDeviceObjectNode);
        context.addUse(onDeviceObjectNode);
        args[argIndex] = onDeviceObjectNode;
        persistNode.addValue(target);
    }

    /**
     * Adds a {@link CopyInNode} for {@code arg} (which must be an {@link ObjectNode}) to
     * the graph, records it as a use of {@code context}, stores it in
     * {@code args[argIndex]} and registers the object with {@code persistNode}.
     */
    private static void createCopyInNode(ContextNode context, TornadoGraph graph, AbstractNode arg, AbstractNode[] args, int argIndex, AllocateMultipleBuffersNode persistNode) {
        CopyInNode copyInNode = new CopyInNode(context);
        ObjectNode target = (ObjectNode) arg;
        copyInNode.setValue(target);
        graph.add(copyInNode);
        context.addUse(copyInNode);
        args[argIndex] = copyInNode;
        persistNode.addValue(target);
    }

    /**
     * Decides whether an argument that is already produced by a context operation needs
     * an additional copy-in for {@code contextNode}.
     *
     * @param arg         node that must be a {@link ContextOpNode}
     * @param contextNode context of the task that consumes {@code arg}
     * @return {@code true} when the producing context does not have exactly one use and
     *         its device index differs from the one of {@code contextNode}
     */
    private static boolean shouldPerformSharedObjectCopy(AbstractNode arg, ContextNode contextNode) {
        ContextNode producer = ((ContextOpNode) arg).getContext();
        return producer.getUses().size() != 1 && contextNode.getDeviceIndex() != producer.getDeviceIndex();
    }

    /**
     * Interprets the operations in {@code buffer} and builds the corresponding graph.
     *
     * <p>One {@link ConstantNode} is created per constant and one {@link ObjectNode} per
     * object of {@code executionContext} before the buffer is read. Reading stops when
     * the buffer is exhausted or when an operation is found that is none of
     * {@code ARG_LIST}, {@code LOAD_REF}, {@code LOAD_PRIM}, {@code LAUNCH} or
     * {@code CONTEXT}.
     *
     * @param executionContext source of constants, objects, object states, tasks and devices
     * @param buffer           operation stream; it is consumed from its current position
     * @return the populated graph
     * @throws TornadoRuntimeException if a written argument is currently represented by a
     *                                 node type that carries no object
     */
    public static TornadoGraph buildGraph(TornadoExecutionContext executionContext, ByteBuffer buffer) {
        TornadoGraph graph = new TornadoGraph();
        List<Object> constants = executionContext.getConstants();
        List<Object> objects = executionContext.getObjects();

        ConstantNode[] constantNodes = new ConstantNode[constants.size()];
        for (int i = 0; i < constantNodes.length; i++) {
            constantNodes[i] = new ConstantNode(i);
            graph.add(constantNodes[i]);
        }

        // objectNodes[i] always holds the node that currently stands for object i; it
        // starts as the plain ObjectNode and is replaced as tasks consume/produce it.
        AbstractNode[] objectNodes = new AbstractNode[objects.size()];
        for (int i = 0; i < objectNodes.length; i++) {
            objectNodes[i] = new ObjectNode(i);
            graph.add(objectNodes[i]);
        }

        List<LocalObjectState> states = executionContext.getObjectStates();

        // State carried from one operation to the next.
        Access[] accesses = null;
        AbstractNode[] args = null;
        ContextNode context = null;
        AllocateMultipleBuffersNode persist = null;
        TaskNode taskNode = null;
        int argIndex = 0;
        int taskIndex = 0;

        while (buffer.hasRemaining()) {
            byte op = buffer.get();
            if (op == TornadoGraphBitcodes.ARG_LIST.index()) {
                args = new AbstractNode[buffer.getInt()];
                argIndex = 0;
                taskNode = new TaskNode(context, taskIndex, args);
            } else if (op == TornadoGraphBitcodes.LOAD_REF.index()) {
                int variableIndex = buffer.getInt();
                AbstractNode arg = objectNodes[variableIndex];

                if (arg instanceof ContextOpNode) {
                    // The object is already the result of an earlier context operation.
                    if (shouldPerformSharedObjectCopy(arg, context)) {
                        createCopyInNode(context, graph, arg.getInputs().get(0), args, argIndex, persist);
                    }
                    // Deliberately overwrites whatever the copy-in helper stored.
                    args[argIndex] = arg;
                } else {
                    ObjectNode objectNode = (ObjectNode) arg;
                    LocalObjectState state = states.get(objectNode.getIndex());
                    if (Objects.requireNonNull(accesses)[argIndex] == Access.WRITE_ONLY) {
                        createAllocateNode(context, graph, arg, args, argIndex, persist);
                    } else if (state.isOnDevice()) {
                        createOnDeviceNode(context, graph, arg, args, argIndex, persist);
                    } else if (state.isStreamIn()) {
                        createStreamInNode(context, graph, objectNode, args, argIndex, persist);
                    } else if (!state.isUnderDemand()) {
                        createCopyInNode(context, graph, arg, args, argIndex, persist);
                    } else {
                        createAllocateNode(context, graph, arg, args, argIndex, persist);
                    }
                }

                assert accesses != null;
                Access access = accesses[argIndex];
                if (access == Access.WRITE_ONLY || access == Access.READ_WRITE) {
                    // The task writes the object: later readers must depend on the task.
                    DependentReadNode depRead = new DependentReadNode(context);
                    ObjectNode value = unwrapObjectNode(arg);
                    if (arg instanceof DependentReadNode && states.get(variableIndex).isForcedStreamIn()) {
                        createStreamInNode(context, graph, value, args, argIndex, persist);
                    }
                    depRead.setValue(value);
                    depRead.setDependent(taskNode);
                    graph.add(depRead);
                    objectNodes[variableIndex] = depRead;
                } else {
                    objectNodes[variableIndex] = args[argIndex];
                }
                argIndex++;
            } else if (op == TornadoGraphBitcodes.LOAD_PRIM.index()) {
                int constantIndex = buffer.getInt();
                args[argIndex] = constantNodes[constantIndex];
                argIndex++;
            } else if (op == TornadoGraphBitcodes.LAUNCH.index()) {
                context.addUse(taskNode);
                graph.add(taskNode);
            } else if (op == TornadoGraphBitcodes.CONTEXT.index()) {
                // First operand (the global task id) is not needed here, but it must
                // still be consumed to keep the buffer position in sync.
                buffer.getInt();
                taskIndex = buffer.getInt();
                SchedulableTask task = executionContext.getTask(taskIndex);
                TornadoXPUDevice deviceForTask = executionContext.getDeviceForTask(taskIndex);
                context = graph.addUnique(new ContextNode(executionContext.getDevices().indexOf(deviceForTask), deviceForTask));
                persist = graph.addUnique(new AllocateMultipleBuffersNode(context));
                context.addUse(persist);
                if (task instanceof CompilableTask) {
                    ResolvedJavaMethod resolvedMethod = TornadoCoreRuntime.getTornadoRuntime().resolveMethod(((CompilableTask) task).getMethod());
                    Sketch sketch = TornadoSketcher.lookup(resolvedMethod, task.meta().getBackendIndex(), task.meta().getDeviceIndex());
                    accesses = sketch.getArgumentsAccess();
                } else {
                    accesses = task.getArgumentsAccess();
                }
            } else {
                // Any other operation ends the interpretation.
                break;
            }
        }

        addStreamNodes(graph, states, objectNodes, context, persist);
        addReleaseNodes(graph, executionContext, objects);
        return graph;
    }

    /**
     * Returns the {@link ObjectNode} behind the node that currently represents an object.
     *
     * @param node an {@link ObjectNode} itself, or a {@link DependentReadNode},
     *             {@link CopyInNode}, {@link AllocateNode}, {@link OnDeviceObjectNode} or
     *             {@link StreamInNode} wrapping one
     * @return the wrapped object node
     * @throws TornadoRuntimeException if {@code node} is of any other type
     */
    private static ObjectNode unwrapObjectNode(AbstractNode node) {
        if (node instanceof ObjectNode) {
            return (ObjectNode) node;
        }
        if (node instanceof DependentReadNode) {
            return ((DependentReadNode) node).getValue();
        }
        if (node instanceof CopyInNode) {
            return ((CopyInNode) node).getValue();
        }
        if (node instanceof AllocateNode) {
            return ((AllocateNode) node).getValue();
        }
        if (node instanceof OnDeviceObjectNode) {
            return ((OnDeviceObjectNode) node).getValue();
        }
        if (node instanceof StreamInNode) {
            return ((StreamInNode) node).getValue();
        }
        throw new TornadoRuntimeException("Invalid graph node in TornadoGraph builder for node: " + node.getClass().getName());
    }

    /**
     * Appends a {@link CopyOutNode} for every stream-out object whose current node is a
     * {@link DependentReadNode}, and a {@link StreamInNode} for every stream-in object
     * that is still a plain {@link ObjectNode}.
     *
     * @param context context left active by the last {@code CONTEXT} operation (may be
     *                {@code null} if none was seen); a copy-out switches the active
     *                context to the one of the node being copied out
     * @param persist buffer-allocation node of the last {@code CONTEXT} operation
     */
    private static void addStreamNodes(TornadoGraph graph, List<LocalObjectState> states, AbstractNode[] objectNodes, ContextNode context, AllocateMultipleBuffersNode persist) {
        ContextNode activeContext = context;
        for (int i = 0; i < states.size(); i++) {
            LocalObjectState state = states.get(i);
            if (state.isStreamOut()) {
                if (objectNodes[i] instanceof DependentReadNode) {
                    DependentReadNode readNode = (DependentReadNode) objectNodes[i];
                    activeContext = ((ContextOpNode) objectNodes[i]).getContext();
                    CopyOutNode copyOutNode = new CopyOutNode(activeContext);
                    copyOutNode.setValue(readNode);
                    graph.add(copyOutNode);
                    activeContext.addUse(copyOutNode);
                }
            } else if (state.isStreamIn() && objectNodes[i] instanceof ObjectNode) {
                // Check the preconditions before anything is created or the graph is
                // touched, so a violation does not leave a half-registered node behind.
                assert activeContext != null;
                assert persist != null;
                ObjectNode objectNode = (ObjectNode) objectNodes[i];
                StreamInNode streamInNode = new StreamInNode(activeContext);
                streamInNode.setValue(objectNode);
                graph.add(streamInNode);
                activeContext.addUse(streamInNode);
                persist.addValue(objectNode);
            }
        }
    }

    /**
     * Appends, for every {@link CopyInNode}, {@link AllocateNode}, {@link StreamInNode}
     * and {@link OnDeviceObjectNode} in the graph, either a {@link PersistedObjectNode}
     * (when the underlying object is listed in the execution context's persisted
     * objects) or a {@link DeallocateNode} that depends on the context operation with
     * the highest graph index.
     *
     * <p>If the graph contains no {@link ContextOpNode} at all there is nothing to
     * release and the method returns without looking up a node. Without that guard the
     * lookup index would be {@code -1}.
     */
    private static void addReleaseNodes(TornadoGraph graph, TornadoExecutionContext executionContext, List<Object> objects) {
        BitSet asyncNodes = graph.filter(ContextOpNode.class::isInstance);
        if (asyncNodes.isEmpty()) {
            return;
        }
        // BitSet.length() is "highest set bit + 1", so this is the last context operation.
        ContextOpNode dependencyNode = (ContextOpNode) graph.getNode(asyncNodes.length() - 1);

        // asyncNodes is the set computed above; nodes added in the loop are not visited.
        for (int i = asyncNodes.nextSetBit(0); i >= 0; i = asyncNodes.nextSetBit(i + 1)) {
            ContextOpNode node = (ContextOpNode) graph.getNode(i);
            if (!(node instanceof CopyInNode || node instanceof AllocateNode || node instanceof StreamInNode || node instanceof OnDeviceObjectNode)) {
                continue;
            }
            ObjectNode objectNode = getObjectNodeFromNode(node);
            ContextNode contextNode = node.getContext();
            Object targetObject = objects.get(objectNode.getIndex());
            if (executionContext.getPersistedObjects().contains(targetObject)) {
                PersistedObjectNode persistNode = new PersistedObjectNode(contextNode);
                persistNode.setValue(objectNode);
                graph.add(persistNode);
                contextNode.addUse(persistNode);
            } else {
                DeallocateNode deallocateNode = new DeallocateNode(contextNode);
                deallocateNode.setValue(objectNode);
                deallocateNode.setDependent(dependencyNode);
                graph.add(deallocateNode);
                contextNode.addUse(deallocateNode);
            }
        }
    }

    /**
     * Returns the {@link ObjectNode} carried by a data-movement node.
     *
     * @param node a {@link CopyInNode}, {@link AllocateNode}, {@link StreamInNode} or
     *             {@link OnDeviceObjectNode}
     * @return the value of {@code node}
     * @throws IllegalArgumentException if {@code node} is of any other type
     */
    private static ObjectNode getObjectNodeFromNode(ContextOpNode node) {
        if (node instanceof CopyInNode) {
            return ((CopyInNode) node).getValue();
        }
        if (node instanceof AllocateNode) {
            return ((AllocateNode) node).getValue();
        }
        if (node instanceof StreamInNode) {
            return ((StreamInNode) node).getValue();
        }
        if (node instanceof OnDeviceObjectNode) {
            return ((OnDeviceObjectNode) node).getValue();
        }
        throw new IllegalArgumentException("Unknown node type: " + String.valueOf(node.getClass()));
    }
}

