/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.common;

import org.junit.Before;
import uk.ac.manchester.tornado.api.TornadoBackend;
import uk.ac.manchester.tornado.api.TornadoRuntime;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
import uk.ac.manchester.tornado.unittests.common.SPIRVOptNotSupported;
import uk.ac.manchester.tornado.unittests.common.TornadoVMOpenCLNotSupported;
import uk.ac.manchester.tornado.unittests.common.TornadoVMPTXNotSupported;
import uk.ac.manchester.tornado.unittests.common.TornadoVMSPIRVNotSupported;
import uk.ac.manchester.tornado.unittests.tools.TornadoHelper;

/**
 * Common base class for TornadoVM unit tests.
 *
 * <p>Before every test it calls {@code clean()} on every device of every backend. On the first test
 * (tracked by a static flag), and unless the {@code tornado.virtual.device} system property is
 * {@code true}, it also applies the backend/device selection given by the
 * {@code tornado.unittests.device} system property. Subclasses can use the {@code assertNot*}
 * helpers to abort a test when the default device belongs to an unsupported backend.
 */
public abstract class TornadoTestBase {
    /** Small tolerance constant; presumably used by subclasses for floating-point comparisons. */
    public static final float DELTA = 0.001f;

    /** Coarse tolerance constant; presumably used by subclasses for low-precision comparisons. */
    public static final float DELTA_05 = 0.5f;

    /** System property holding the {@code <backendIndex>:<deviceIndex>} selection. */
    private static final String DEVICE_PROPERTY = "tornado.unittests.device";

    /** Set once the device selection from {@code tornado.unittests.device} has been applied. */
    protected static boolean wasDeviceInspected = false;

    /**
     * Returns the TornadoVM runtime provided by {@link TornadoRuntimeProvider}.
     *
     * @return the TornadoVM runtime
     */
    public static TornadoRuntime getTornadoRuntime() {
        return TornadoRuntimeProvider.getTornadoRuntime();
    }

    /**
     * JUnit per-test setup.
     *
     * <p>Calls {@code clean()} on every device of every backend. Then, only if the selection has not
     * been applied yet and {@code tornado.virtual.device} is not {@code true}, applies the selection
     * from {@code tornado.unittests.device}.
     */
    @Before
    public void before() {
        TornadoRuntime runtime = getTornadoRuntime();
        cleanAllDevices(runtime);
        if (!wasDeviceInspected && !getVirtualDeviceEnabled()) {
            applyDeviceSelection(runtime);
            wasDeviceInspected = true;
        }
    }

    /** Calls {@code clean()} on every device of every backend known to {@code runtime}. */
    private static void cleanAllDevices(TornadoRuntime runtime) {
        for (int backendIndex = 0; backendIndex < runtime.getNumBackends(); backendIndex++) {
            TornadoBackend backend = runtime.getBackend(backendIndex);
            for (int deviceIndex = 0; deviceIndex < backend.getNumDevices(); deviceIndex++) {
                backend.getDevice(deviceIndex).clean();
            }
        }
    }

    /** Makes the backend and device named by {@code tornado.unittests.device} the runtime defaults. */
    private void applyDeviceSelection(TornadoRuntime runtime) {
        Tuple2<Integer, Integer> backendAndDevice = this.getDriverAndDeviceIndex();
        int backendIndex = backendAndDevice.f0();
        if (backendIndex != 0) {
            runtime.setDefaultBackend(backendIndex);
        }
        int deviceIndex = backendAndDevice.f1();
        if (deviceIndex != 0) {
            // NOTE(review): backend 0 is used even when another backend was selected above; this
            // assumes setDefaultBackend moves the selected backend to index 0 — confirm.
            runtime.getBackend(0).setDefaultDevice(deviceIndex);
        }
    }

    /** Returns {@code true} when the {@code tornado.virtual.device} property parses as {@code true}. */
    private boolean getVirtualDeviceEnabled() {
        return Boolean.parseBoolean(System.getProperty("tornado.virtual.device", "False"));
    }

    /**
     * Parses the {@code tornado.unittests.device} system property (default {@code "0:0"}).
     *
     * <p>Surrounding whitespace in each field is ignored. Any fields after the second
     * {@code ':'}-separated field are ignored.
     *
     * @return a pair (backend index, device index)
     * @throws IllegalArgumentException if the property does not start with two integer fields
     *     separated by {@code ':'}
     */
    protected Tuple2<Integer, Integer> getDriverAndDeviceIndex() {
        String driverAndDevice = System.getProperty(DEVICE_PROPERTY, "0:0");
        String[] fields = driverAndDevice.split(":");
        if (fields.length < 2) {
            // Previously surfaced as an ArrayIndexOutOfBoundsException with no hint about the property.
            throw new IllegalArgumentException(invalidDevicePropertyMessage(driverAndDevice));
        }
        try {
            return new Tuple2<>(Integer.parseInt(fields[0].trim()), Integer.parseInt(fields[1].trim()));
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(invalidDevicePropertyMessage(driverAndDevice), e);
        }
    }

    /** Builds the error message reported for a malformed {@code tornado.unittests.device} value. */
    private static String invalidDevicePropertyMessage(String value) {
        return "Invalid value for system property " + DEVICE_PROPERTY + ": \"" + value + "\" (expected <backendIndex>:<deviceIndex>)";
    }

    /**
     * Aborts the test if the default device belongs to {@code backend}, using the default message.
     *
     * @param backend the backend the test does not support
     */
    public void assertNotBackend(TornadoVMBackendType backend) {
        this.assertNotBackend(backend, null);
    }

    /**
     * Aborts the test if the default device belongs to {@code backend}.
     *
     * @param backend the backend the test does not support
     * @param customBackendAssertionMessage message for the thrown exception, or {@code null} to use
     *     a default per-backend message
     * @throws TornadoVMPTXNotSupported if {@code backend} is PTX and is the default device's backend
     * @throws TornadoVMOpenCLNotSupported if {@code backend} is OpenCL and is the default device's
     *     backend
     * @throws TornadoVMSPIRVNotSupported if {@code backend} is SPIR-V and is the default device's
     *     backend
     * @throws IllegalStateException if {@code backend} is the default device's backend but is none of
     *     the above
     */
    public void assertNotBackend(TornadoVMBackendType backend, String customBackendAssertionMessage) {
        TornadoRuntime runtime = getTornadoRuntime();
        int backendIndex = runtime.getDefaultDevice().getBackendIndex();
        if (runtime.getBackendType(backendIndex) != backend) {
            return;
        }
        switch (backend) {
            case PTX:
                throw new TornadoVMPTXNotSupported(messageOrDefault(customBackendAssertionMessage, "Test not supported for the PTX backend"));
            case OPENCL:
                throw new TornadoVMOpenCLNotSupported(messageOrDefault(customBackendAssertionMessage, "Test not supported for the OpenCL backend"));
            case SPIRV:
                throw new TornadoVMSPIRVNotSupported(messageOrDefault(customBackendAssertionMessage, "Test not supported for the SPIR-V backend"));
            default:
                throw new IllegalStateException("Unexpected value for backend: " + String.valueOf(backend));
        }
    }

    /** Returns {@code custom} when non-null, otherwise {@code fallback}. */
    private static String messageOrDefault(String custom, String fallback) {
        return custom != null ? custom : fallback;
    }

    /**
     * Aborts the test when the SPIR-V load/store optimization is enabled and both {@code backend} and
     * the default device's backend are SPIR-V.
     *
     * <p>Does nothing when {@code TornadoHelper.OPTIMIZE_LOAD_STORE_SPIRV} is {@code false}.
     *
     * @param backend the backend to check against
     * @throws SPIRVOptNotSupported if the conditions above hold
     */
    public void assertNotBackendOptimization(TornadoVMBackendType backend) {
        if (!TornadoHelper.OPTIMIZE_LOAD_STORE_SPIRV) {
            return;
        }
        TornadoRuntime runtime = getTornadoRuntime();
        int backendIndex = runtime.getDefaultDevice().getBackendIndex();
        if (runtime.getBackendType(backendIndex) == backend && backend == TornadoVMBackendType.SPIRV) {
            throw new SPIRVOptNotSupported("Test not supported for the optimized SPIR-V BACKEND");
        }
    }

    /**
     * Handles the case where the given device is not an OpenCL device reporting SPIR-V support.
     *
     * <p>If backend {@code driverIndex} is OpenCL and {@code device} reports SPIR-V support, this
     * method does nothing. Otherwise it calls {@code assertNotBackend(OPENCL)}, which aborts the test
     * if the default device is on the OpenCL backend.
     */
    private void assertIfNeeded(TornadoDevice device, int driverIndex) {
        TornadoVMBackendType backendType = getTornadoRuntime().getBackend(driverIndex).getBackendType();
        if (backendType != TornadoVMBackendType.OPENCL || !device.isSPIRVSupported()) {
            this.assertNotBackend(TornadoVMBackendType.OPENCL);
        }
    }

    /**
     * Picks a device for SPIR-V tests.
     *
     * <p>If {@code tornado.unittests.device} selects a non-zero backend index, returns device 0 of
     * backend 0, after passing it through {@code assertIfNeeded}. Otherwise, scans every non-PTX
     * backend in index order and returns the first device whose {@code isSPIRVSupported()} is
     * {@code true}.
     *
     * @return the chosen device, or {@code null} if the scan finds no SPIR-V capable device
     */
    protected TornadoDevice getSPIRVSupportedDevice() {
        TornadoRuntime runtime = getTornadoRuntime();
        Tuple2<Integer, Integer> driverAndDeviceIndex = this.getDriverAndDeviceIndex();
        if (driverAndDeviceIndex.f0() != 0) {
            // NOTE(review): relies on the user-selected backend having been moved to index 0 — confirm.
            TornadoDevice device = runtime.getBackend(0).getDevice(0);
            this.assertIfNeeded(device, 0);
            return device;
        }
        int numBackends = runtime.getNumBackends();
        for (int backendIndex = 0; backendIndex < numBackends; backendIndex++) {
            TornadoBackend backend = runtime.getBackend(backendIndex);
            if (backend.getBackendType() == TornadoVMBackendType.PTX) {
                continue;
            }
            int numDevices = backend.getNumDevices();
            for (int deviceIndex = 0; deviceIndex < numDevices; deviceIndex++) {
                TornadoDevice device = backend.getDevice(deviceIndex);
                if (device.isSPIRVSupported()) {
                    return device;
                }
            }
        }
        return null;
    }

    /**
     * Minimal pair holder.
     *
     * @param <T0> type of the first element
     * @param <T1> type of the second element
     */
    protected static class Tuple2<T0, T1> {
        T0 t0;
        T1 t1;

        /**
         * Creates a pair.
         *
         * @param first the first element
         * @param second the second element
         */
        public Tuple2(T0 first, T1 second) {
            this.t0 = first;
            this.t1 = second;
        }

        /** Returns the first element. */
        public T0 f0() {
            return this.t0;
        }

        /** Returns the second element. */
        public T1 f1() {
            return this.t1;
        }
    }
}

