/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.drivers.opencl.graal.phases;

import java.util.Optional;
import jdk.vm.ci.meta.JavaKind;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.AddNode;
import org.graalvm.compiler.nodes.calc.MulNode;
import org.graalvm.compiler.nodes.calc.SubNode;
import org.graalvm.compiler.phases.BasePhase;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.GlobalThreadIdNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.GlobalThreadSizeNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLIntBinaryIntrinsicNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.calc.DivNode;
import uk.ac.manchester.tornado.runtime.TornadoCoreRuntime;
import uk.ac.manchester.tornado.runtime.common.TornadoSchedulingStrategy;
import uk.ac.manchester.tornado.runtime.common.TornadoXPUDevice;
import uk.ac.manchester.tornado.runtime.graal.nodes.AbstractParallelNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.ParallelOffsetNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.ParallelRangeNode;
import uk.ac.manchester.tornado.runtime.graal.nodes.ParallelStrideNode;
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoHighTierContext;

/**
 * Compiler phase that rewrites the parallel-loop marker nodes
 * ({@link ParallelRangeNode}, {@link ParallelOffsetNode} and
 * {@link ParallelStrideNode}) found in a graph.
 *
 * <p>For every {@link ParallelRangeNode} the phase looks up the device's
 * maximum work-item size for the loop's index. If that size is greater than
 * one, the range/offset/stride nodes are replaced by expressions built from
 * {@link GlobalThreadIdNode} and {@link GlobalThreadSizeNode} according to
 * the device's preferred {@link TornadoSchedulingStrategy}; otherwise the
 * marker nodes are replaced by their wrapped values and removed.
 */
public class TornadoParallelScheduler
extends BasePhase<TornadoHighTierContext> {
    /**
     * Reports that this phase has no applicability restrictions.
     *
     * @param graphState the state of the graph being queried (not inspected)
     * @return {@code ALWAYS_APPLICABLE}
     */
    public Optional<BasePhase.NotApplicable> notApplicableTo(GraphState graphState) {
        return ALWAYS_APPLICABLE;
    }

    /**
     * Rewrites every {@link ParallelRangeNode} in {@code graph} and then clears
     * the graph's last schedule. Does nothing when the context carries no meta
     * information.
     *
     * @param graph   the graph to transform
     * @param context supplies the meta information and the target device
     */
    protected void run(StructuredGraph graph, TornadoHighTierContext context) {
        if (context.getMeta() == null) {
            return;
        }
        TornadoXPUDevice device = context.getDeviceMapping();
        TornadoSchedulingStrategy strategy = device.getPreferredSchedule();
        long[] maxWorkItemSizes = device.getPhysicalDevice().getDeviceMaxWorkItemSizes();
        graph.getNodes()
                .filter(ParallelRangeNode.class)
                .forEach(range -> scheduleLoop(graph, strategy, maxWorkItemSizes, range));
        graph.clearLastSchedule();
    }

    /**
     * Handles a single parallel loop: serialises it when the device's maximum
     * work-item size for the loop index is at most one, otherwise replaces its
     * range, offset and stride nodes following {@code strategy}. A debug dump
     * of the graph is emitted afterwards in both cases.
     */
    private void scheduleLoop(StructuredGraph graph, TornadoSchedulingStrategy strategy, long[] maxWorkItemSizes, ParallelRangeNode range) {
        if (maxWorkItemSizes[range.index()] <= 1L) {
            serialiseLoop(range);
        } else {
            // Capture the inputs first: the range node may be deleted by the replacement below.
            ParallelOffsetNode loopOffset = range.offset();
            ParallelStrideNode loopStride = range.stride();
            ValueNode blockSize = replaceRangeNode(strategy, graph, range);
            replaceOffsetNode(strategy, graph, loopOffset, range, blockSize);
            replaceStrideNode(strategy, graph, loopStride);
        }
        TornadoCoreRuntime.getDebugContext().dump(1, graph, "after scheduling loop index=" + range.index());
    }

    /**
     * Dispatches the offset replacement on the scheduling strategy. Strategies
     * other than {@code PER_CPU_BLOCK} and {@code PER_ACCELERATOR_ITERATION}
     * leave the offset untouched.
     */
    private void replaceOffsetNode(TornadoSchedulingStrategy schedule, StructuredGraph graph, ParallelOffsetNode offset, ParallelRangeNode range, ValueNode blockSize) {
        switch (schedule) {
            case PER_CPU_BLOCK:
                replaceOffsetPerBlock(graph, offset, blockSize);
                break;
            case PER_ACCELERATOR_ITERATION:
                replaceOffsetPerIteration(graph, offset, range);
                break;
            default:
                break;
        }
    }

    /**
     * Replaces {@code offset} with {@code (globalThreadId + offset.value) * range.stride.value}
     * and deletes the offset node.
     */
    private void replaceOffsetPerIteration(StructuredGraph graph, ParallelOffsetNode offset, ParallelRangeNode range) {
        ConstantNode dimension = graph.addOrUnique(ConstantNode.forInt(offset.index()));
        ValueNode globalId = graph.addOrUnique(new GlobalThreadIdNode(dimension));
        ValueNode shifted = graph.addOrUnique(new AddNode(globalId, offset.value()));
        MulNode scaled = graph.addOrUnique(new MulNode(shifted, range.stride().value()));
        offset.replaceAtUsages(scaled);
        offset.safeDelete();
    }

    /**
     * Replaces {@code offset} with {@code globalThreadId * blockSize} and
     * deletes the offset node.
     */
    private void replaceOffsetPerBlock(StructuredGraph graph, ParallelOffsetNode offset, ValueNode blockSize) {
        ConstantNode dimension = ConstantNode.forInt(offset.index(), graph);
        ValueNode globalId = graph.addOrUnique(new GlobalThreadIdNode(dimension));
        MulNode blockStart = graph.addOrUnique(new MulNode(globalId, blockSize));
        offset.replaceAtUsages(blockStart);
        offset.safeDelete();
    }

    /**
     * Dispatches the stride replacement on the scheduling strategy. Strategies
     * other than {@code PER_CPU_BLOCK} and {@code PER_ACCELERATOR_ITERATION}
     * leave the stride untouched.
     */
    private void replaceStrideNode(TornadoSchedulingStrategy schedule, StructuredGraph graph, ParallelStrideNode stride) {
        switch (schedule) {
            case PER_CPU_BLOCK:
                replaceStridePerBlock(stride);
                break;
            case PER_ACCELERATOR_ITERATION:
                replaceStridePerIteration(graph, stride);
                break;
            default:
                break;
        }
    }

    /**
     * Replaces {@code stride} with a {@link GlobalThreadSizeNode} for the
     * stride's index and deletes the stride node.
     */
    private void replaceStridePerIteration(StructuredGraph graph, ParallelStrideNode stride) {
        ConstantNode dimension = graph.addOrUnique(ConstantNode.forInt(stride.index()));
        GlobalThreadSizeNode globalSize = graph.addOrUnique(new GlobalThreadSizeNode(dimension));
        stride.replaceAtUsages(globalSize);
        stride.safeDelete();
    }

    /**
     * Replaces {@code stride} with the value it wraps and deletes the stride node.
     */
    private void replaceStridePerBlock(ParallelStrideNode stride) {
        stride.replaceAtUsages(stride.value());
        stride.safeDelete();
    }

    /**
     * Dispatches the range replacement on the scheduling strategy.
     *
     * @return the block-size node built for {@code PER_CPU_BLOCK}; {@code null}
     *         for every other strategy
     */
    private ValueNode replaceRangeNode(TornadoSchedulingStrategy schedule, StructuredGraph graph, ParallelRangeNode range) {
        ValueNode blockSize = null;
        switch (schedule) {
            case PER_CPU_BLOCK:
                blockSize = replaceRangePerBlock(graph, range);
                break;
            case PER_ACCELERATOR_ITERATION:
                replaceRangePerIteration(range);
                break;
            default:
                break;
        }
        return blockSize;
    }

    /**
     * Redirects all usages of {@code range} to the value it wraps. The range
     * node itself is not deleted by this method.
     */
    private void replaceRangePerIteration(ParallelRangeNode range) {
        range.replaceAtUsages(range.value());
    }

    /**
     * Builds the node computing
     * {@code (((range / stride) - offset) + (threadCount - 1)) / threadCount * stride},
     * where {@code threadCount} is the {@link GlobalThreadSizeNode} for the
     * range's index.
     *
     * @return the node holding the resulting block size
     */
    private ValueNode buildBlockSize(StructuredGraph graph, ParallelRangeNode range) {
        ValueNode iterations = (ValueNode) graph.addOrUnique(DivNode.create(range.value(), range.stride().value()));
        ValueNode effectiveRange = graph.addOrUnique(new SubNode(iterations, range.offset().value()));

        ConstantNode dimension = ConstantNode.forInt(range.index(), graph);
        ValueNode globalSize = graph.addOrUnique(new GlobalThreadSizeNode(dimension));
        ValueNode one = ConstantNode.forInt(1, graph);
        ValueNode globalSizeMinusOne = graph.addOrUnique(new SubNode(globalSize, one));

        // Adding (threadCount - 1) before dividing rounds the quotient up.
        ValueNode padded = graph.addOrUnique(new AddNode(effectiveRange, globalSizeMinusOne));
        ValueNode perThread = (ValueNode) graph.addOrUnique(DivNode.create(padded, globalSize));
        return graph.addOrUnique(new MulNode(perThread, range.stride().value()));
    }

    /**
     * Replaces {@code range} with
     * {@code min(((globalThreadId * blockSize) + blockSize) * stride, range.value)}
     * and deletes the range node.
     *
     * @return the block-size node produced by {@link #buildBlockSize}
     */
    private ValueNode replaceRangePerBlock(StructuredGraph graph, ParallelRangeNode range) {
        ValueNode blockSize = buildBlockSize(graph, range);
        ConstantNode dimension = ConstantNode.forInt(range.index(), graph);
        ValueNode globalId = graph.addOrUnique(new GlobalThreadIdNode(dimension));
        ValueNode blockStart = graph.addOrUnique(new MulNode(globalId, blockSize));
        ValueNode blockEnd = graph.addOrUnique(new AddNode(blockStart, blockSize));
        ValueNode scaledEnd = graph.addOrUnique(new MulNode(blockEnd, range.stride().value()));
        ValueNode clampedEnd = (ValueNode) graph.addOrUnique(
                OCLIntBinaryIntrinsicNode.create(scaledEnd, range.value(), OCLIntBinaryIntrinsicNode.Operation.MIN, JavaKind.Int));
        range.replaceAtUsages(clampedEnd);
        range.safeDelete();
        return blockSize;
    }

    /**
     * Turns a parallel loop back into a plain loop: the range, offset and
     * stride nodes are each replaced by the value they wrap and then removed.
     */
    private void serialiseLoop(ParallelRangeNode range) {
        // Read the inputs before the range node is killed, since killing clears its inputs.
        ParallelOffsetNode loopOffset = range.offset();
        ParallelStrideNode loopStride = range.stride();

        range.replaceAtUsages(range.value());
        killNode(range);

        loopOffset.replaceAtUsages(loopOffset.value());
        loopStride.replaceAtUsages(loopStride.value());
        killNode(loopOffset);
        killNode(loopStride);
    }

    /**
     * Clears any remaining inputs of {@code node} and deletes it unless it has
     * already been deleted.
     */
    private void killNode(AbstractParallelNode node) {
        boolean hasInputs = node.inputs().isNotEmpty();
        if (hasInputs) {
            node.clearInputs();
        }
        if (node.isDeleted()) {
            return;
        }
        node.safeDelete();
    }
}

