/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.drivers.opencl.graal.phases;

import java.util.Optional;
import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.memory.ReadNode;
import org.graalvm.compiler.phases.BasePhase;
import org.graalvm.compiler.phases.Phase;
import uk.ac.manchester.tornado.api.TornadoDeviceContext;
import uk.ac.manchester.tornado.api.TornadoTargetDevice;
import uk.ac.manchester.tornado.api.exceptions.TornadoDeviceFP16NotSupported;
import uk.ac.manchester.tornado.drivers.opencl.OCLDevice;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLDecompressedReadFieldNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLFPBinaryIntrinsicNode;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.ReadHalfFloatNode;
import uk.ac.manchester.tornado.drivers.opencl.virtual.VirtualOCLDevice;

public class OCLFP16SupportPhase
extends Phase {
    private TornadoDeviceContext deviceContext;

    public OCLFP16SupportPhase(TornadoDeviceContext deviceContext) {
        this.deviceContext = deviceContext;
    }

    public Optional<BasePhase.NotApplicable> notApplicableTo(GraphState graphState) {
        return ALWAYS_APPLICABLE;
    }

    private boolean isMinMaxOperation(String operation) {
        return operation.equals("FMIN") || operation.equals("FMAX");
    }

    private boolean operatesOnHalfFloat(OCLFPBinaryIntrinsicNode node) {
        return node.getX() instanceof ReadHalfFloatNode || node.getY() instanceof ReadHalfFloatNode;
    }

    protected void run(StructuredGraph graph) {
        boolean fp16Support = false;
        String extensions = null;
        TornadoTargetDevice tornadoTargetDevice = this.deviceContext.getDevice();
        if (tornadoTargetDevice instanceof OCLDevice) {
            OCLDevice oclDevice = (OCLDevice)tornadoTargetDevice;
            extensions = oclDevice.getDeviceExtensions();
        } else {
            tornadoTargetDevice = this.deviceContext.getDevice();
            if (tornadoTargetDevice instanceof VirtualOCLDevice) {
                VirtualOCLDevice virtualOCLDevice = (VirtualOCLDevice)tornadoTargetDevice;
                extensions = virtualOCLDevice.getDeviceExtensions();
            }
        }
        if (extensions != null && extensions.contains("cl_khr_fp16")) {
            fp16Support = true;
        }
        for (OCLDecompressedReadFieldNode decompressedField : graph.getNodes().filter(OCLDecompressedReadFieldNode.class)) {
            if (!decompressedField.getObject().stamp(NodeView.DEFAULT).toString().contains("VectorHalf") || fp16Support) continue;
            throw new TornadoDeviceFP16NotSupported("The current OpenCL device (" + this.deviceContext.getDeviceName() + ") does not support FP16");
        }
        for (OCLFPBinaryIntrinsicNode binaryIntrinsicNode : graph.getNodes().filter(OCLFPBinaryIntrinsicNode.class)) {
            String operation = binaryIntrinsicNode.getOperation();
            if (!this.isMinMaxOperation(operation) || !this.operatesOnHalfFloat(binaryIntrinsicNode) || fp16Support) continue;
            throw new TornadoDeviceFP16NotSupported("The current OpenCL device (" + this.deviceContext.getDeviceName() + ") does not support the " + binaryIntrinsicNode.getOperation() + " operation for FP16 types");
        }
        for (ReadNode readNode : graph.getNodes().filter(ReadNode.class)) {
            if (!readNode.getLocationIdentity().toString().contains("VectorHalf") || fp16Support) continue;
            throw new TornadoDeviceFP16NotSupported("The current OpenCL device (" + this.deviceContext.getDeviceName() + ") does not support FP16");
        }
    }
}

