/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.drivers.opencl.graal.phases;

import java.util.Optional;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.MetaAccessProvider;
import jdk.vm.ci.meta.PrimitiveConstant;
import jdk.vm.ci.meta.ResolvedJavaType;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.NodeInputList;
import org.graalvm.compiler.graph.iterators.NodeIterable;
import org.graalvm.compiler.nodes.CallTargetNode;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.FixedWithNextNode;
import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.InvokeNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.util.GraphUtil;
import org.graalvm.compiler.phases.BasePhase;
import uk.ac.manchester.tornado.api.exceptions.TornadoInternalError;
import uk.ac.manchester.tornado.drivers.opencl.graal.OCLArchitecture;
import uk.ac.manchester.tornado.drivers.opencl.graal.OCLLoweringProvider;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.FixedArrayNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.GlobalThreadIdNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.GlobalThreadSizeNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.GroupIdNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.LocalArrayNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.LocalGroupSizeNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.LocalThreadIDFixedNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLBarrierNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OpenCLPrintf;
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoHighTierContext;

public class TornadoOpenCLIntrinsicsReplacements
extends BasePhase<TornadoHighTierContext> {
    private MetaAccessProvider metaAccess;

    public TornadoOpenCLIntrinsicsReplacements(MetaAccessProvider metaAccess) {
        this.metaAccess = metaAccess;
    }

    private ConstantNode getConstantNodeFromArguments(InvokeNode invoke, int index) {
        NodeInputList arguments = invoke.callTarget().arguments();
        return (ConstantNode)arguments.get(index);
    }

    public Optional<BasePhase.NotApplicable> notApplicableTo(GraphState graphState) {
        return ALWAYS_APPLICABLE;
    }

    protected void run(StructuredGraph graph, TornadoHighTierContext context) {
        NodeIterable invokeNodes = graph.getNodes().filter(InvokeNode.class);
        for (InvokeNode invoke : invokeNodes) {
            String methodName;
            switch (methodName = invoke.callTarget().targetName()) {
                case "Direct#NewArrayNode.newArray": {
                    this.lowerInvokeNode(invoke);
                    break;
                }
                case "Direct#OpenCLIntrinsics.localBarrier": {
                    OCLBarrierNode barrier = (OCLBarrierNode)graph.addOrUnique((Node)new OCLBarrierNode(OCLBarrierNode.OCLMemFenceFlags.LOCAL));
                    graph.replaceFixed((FixedWithNextNode)invoke, (Node)barrier);
                    break;
                }
                case "Direct#OpenCLIntrinsics.globalBarrier": {
                    OCLBarrierNode barrier = (OCLBarrierNode)graph.addOrUnique((Node)new OCLBarrierNode(OCLBarrierNode.OCLMemFenceFlags.GLOBAL));
                    graph.replaceFixed((FixedWithNextNode)invoke, (Node)barrier);
                    break;
                }
                case "Direct#OpenCLIntrinsics.get_local_id": {
                    ConstantNode dimension = this.getConstantNodeFromArguments(invoke, 0);
                    LocalThreadIDFixedNode localIDNode = (LocalThreadIDFixedNode)graph.addOrUnique((Node)new LocalThreadIDFixedNode(dimension));
                    graph.replaceFixed((FixedWithNextNode)invoke, (Node)localIDNode);
                    break;
                }
                case "Direct#OpenCLIntrinsics.get_local_size": {
                    ConstantNode dimension = this.getConstantNodeFromArguments(invoke, 0);
                    LocalGroupSizeNode groupSizeNode = (LocalGroupSizeNode)graph.addOrUnique((Node)new LocalGroupSizeNode(dimension));
                    graph.replaceFixed((FixedWithNextNode)invoke, (Node)groupSizeNode);
                    break;
                }
                case "Direct#OpenCLIntrinsics.get_global_id": {
                    ConstantNode dimension = this.getConstantNodeFromArguments(invoke, 0);
                    GlobalThreadIdNode globalThreadId = (GlobalThreadIdNode)graph.addOrUnique((Node)new GlobalThreadIdNode(dimension));
                    graph.replaceFixed((FixedWithNextNode)invoke, (Node)globalThreadId);
                    break;
                }
                case "Direct#OpenCLIntrinsics.get_global_size": {
                    ConstantNode dimension = this.getConstantNodeFromArguments(invoke, 0);
                    GlobalThreadSizeNode globalSize = (GlobalThreadSizeNode)graph.addOrUnique((Node)new GlobalThreadSizeNode(dimension));
                    graph.replaceFixed((FixedWithNextNode)invoke, (Node)globalSize);
                    break;
                }
                case "Direct#OpenCLIntrinsics.get_group_id": {
                    ConstantNode dimension = this.getConstantNodeFromArguments(invoke, 0);
                    GroupIdNode groupIdNode = (GroupIdNode)graph.addOrUnique((Node)new GroupIdNode(dimension));
                    graph.replaceFixed((FixedWithNextNode)invoke, (Node)groupIdNode);
                    break;
                }
                case "Direct#OpenCLIntrinsics.printEmpty": {
                    OpenCLPrintf printfNode = (OpenCLPrintf)graph.addOrUnique((Node)new OpenCLPrintf("\"\""));
                    graph.replaceFixed((FixedWithNextNode)invoke, (Node)printfNode);
                }
            }
        }
    }

    private void lowerLocalInvokeNodeNewArray(StructuredGraph graph, int length, JavaKind elementKind, InvokeNode newArray) {
        ConstantNode newLengthNode = ConstantNode.forInt((int)length, (StructuredGraph)graph);
        ResolvedJavaType elementType = this.metaAccess.lookupJavaType(elementKind.toJavaClass());
        LocalArrayNode localArrayNode = (LocalArrayNode)graph.addWithoutUnique((Node)new LocalArrayNode(OCLArchitecture.localSpace, elementType, (ValueNode)newLengthNode));
        newArray.replaceAtUsages((Node)localArrayNode);
    }

    private void lowerPrivateInvokeNodeNewArray(StructuredGraph graph, int size, JavaKind elementKind, InvokeNode newArray) {
        ConstantNode newLengthNode = ConstantNode.forInt((int)size, (StructuredGraph)graph);
        ResolvedJavaType elementType = this.metaAccess.lookupJavaType(elementKind.toJavaClass());
        FixedArrayNode fixedArrayNode = (FixedArrayNode)graph.addWithoutUnique((Node)new FixedArrayNode(OCLArchitecture.privateSpace, elementType, newLengthNode));
        newArray.replaceAtUsages((Node)fixedArrayNode);
    }

    private void lowerInvokeNode(InvokeNode newArray) {
        CallTargetNode callTarget = newArray.callTarget();
        StructuredGraph graph = newArray.graph();
        ValueNode secondInput = (ValueNode)callTarget.arguments().get(1);
        if (secondInput instanceof ConstantNode) {
            ConstantNode lengthNode = (ConstantNode)secondInput;
            if (lengthNode.getValue() instanceof PrimitiveConstant) {
                int length = ((PrimitiveConstant)lengthNode.getValue()).asInt();
                JavaKind elementKind = this.getJavaKindFromConstantNode((ConstantNode)callTarget.arguments().get(0));
                int offset = this.metaAccess.getArrayBaseOffset(elementKind);
                int size = offset + elementKind.getByteCount() * length;
                if (OCLLoweringProvider.isGPUSnippet()) {
                    this.lowerLocalInvokeNodeNewArray(graph, length, elementKind, newArray);
                } else {
                    this.lowerPrivateInvokeNodeNewArray(graph, size, elementKind, newArray);
                }
                newArray.clearInputs();
                GraphUtil.unlinkFixedNode((FixedWithNextNode)newArray);
            } else {
                TornadoInternalError.shouldNotReachHere();
            }
        } else {
            TornadoInternalError.unimplemented((String)"dynamically sized array declarations are not supported");
        }
    }

    private JavaKind getJavaKindFromConstantNode(ConstantNode signatureNode) {
        switch (signatureNode.getValue().toValueString()) {
            case "Class:int": 
            case "Class:uk.ac.manchester.tornado.api.types.arrays.IntArray": {
                return JavaKind.Int;
            }
            case "Class:long": 
            case "Class:uk.ac.manchester.tornado.api.types.arrays.LongArray": {
                return JavaKind.Long;
            }
            case "Class:float": 
            case "Class:uk.ac.manchester.tornado.api.types.arrays.FloatArray": {
                return JavaKind.Float;
            }
            case "Class:double": 
            case "Class:uk.ac.manchester.tornado.api.types.arrays.DoubleArray": {
                return JavaKind.Double;
            }
        }
        TornadoInternalError.unimplemented((String)("Other types not supported yet: " + signatureNode.getValue().toValueString()));
        return null;
    }
}

