/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.drivers.opencl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import jdk.vm.ci.hotspot.HotSpotJVMCIRuntime;
import org.graalvm.compiler.options.OptionValues;
import org.graalvm.compiler.phases.util.Providers;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.enums.TornadoDeviceType;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException;
import uk.ac.manchester.tornado.api.exceptions.TornadoDeviceNotFound;
import uk.ac.manchester.tornado.drivers.opencl.OCLContextInterface;
import uk.ac.manchester.tornado.drivers.opencl.OCLTargetDevice;
import uk.ac.manchester.tornado.drivers.opencl.OpenCL;
import uk.ac.manchester.tornado.drivers.opencl.TornadoPlatformInterface;
import uk.ac.manchester.tornado.drivers.opencl.enums.OCLDeviceType;
import uk.ac.manchester.tornado.drivers.opencl.graal.OCLHotSpotBackendFactory;
import uk.ac.manchester.tornado.drivers.opencl.graal.OCLSuitesProvider;
import uk.ac.manchester.tornado.drivers.opencl.graal.backend.OCLBackend;
import uk.ac.manchester.tornado.runtime.TornadoAcceleratorBackend;
import uk.ac.manchester.tornado.runtime.TornadoVMConfigAccess;
import uk.ac.manchester.tornado.runtime.common.TornadoLogger;
import uk.ac.manchester.tornado.runtime.common.TornadoXPUDevice;

/**
 * OpenCL implementation of {@link TornadoAcceleratorBackend}.
 *
 * <p>On construction, one {@link OCLBackend} is created for every device of every
 * platform reported by {@link OpenCL}. The backends are kept both per
 * {@code [platform][device]} index and as a flat array ordered by device type
 * (see {@link #DEVICE_TYPE_LIST}); the flat array is what the device-index based
 * accessors ({@link #getDevice(int)}, {@link #getDefaultDevice()}) operate on.
 */
public final class OCLBackendImpl implements TornadoAcceleratorBackend {

    /** Device-type priority used to order the flat backend list: GPU, CPU, accelerator, custom. */
    private static final List<OCLDeviceType> DEVICE_TYPE_LIST = Arrays.asList(OCLDeviceType.CL_DEVICE_TYPE_GPU, OCLDeviceType.CL_DEVICE_TYPE_CPU, OCLDeviceType.CL_DEVICE_TYPE_ACCELERATOR, OCLDeviceType.CL_DEVICE_TYPE_CUSTOM);

    /** Backends indexed by {@code [platform][device]}; a row stays {@code null} until its platform is installed. */
    private final OCLBackend[][] backends;

    /** One context per installed platform, in platform order. */
    private final List<OCLContextInterface> contexts;

    /** Backends in device-index order; element 0 is the default device. Elements are swapped in place by {@link #setDefaultDevice(int)}. */
    private final OCLBackend[] flatBackends;

    /** Lazily built cache for {@link #getAllDevices()}; only ever assigned a fully populated list. */
    private volatile List<TornadoDevice> devices;

    private final TornadoLogger logger;

    /**
     * Discovers all OpenCL platforms and devices and creates a backend for each device.
     *
     * @param options   compiler options forwarded to the backend factory
     * @param vmRuntime JVMCI runtime forwarded to the backend factory
     * @param vmConfig  VM configuration forwarded to the backend factory
     * @throws TornadoBailoutRuntimeException if no OpenCL platform is available
     */
    public OCLBackendImpl(OptionValues options, HotSpotJVMCIRuntime vmRuntime, TornadoVMConfigAccess vmConfig) {
        int numPlatforms = OpenCL.getNumPlatforms();
        if (numPlatforms < 1) {
            throw new TornadoBailoutRuntimeException("[WARNING] No OpenCL platforms found. Deoptimizing to sequential execution.");
        }
        this.backends = new OCLBackend[numPlatforms][];
        this.contexts = new ArrayList<>();
        this.logger = new TornadoLogger(this.getClass());
        this.discoverDevices(options, vmRuntime, vmConfig);
        this.flatBackends = orderFlattenBackends(flattenBackends(this.backends));
    }

    /**
     * Concatenates the per-platform backend rows into a single array, platform by platform.
     * Rows that are {@code null} contribute no elements.
     */
    private static OCLBackend[] flattenBackends(OCLBackend[][] backends) {
        int total = 0;
        for (OCLBackend[] platformBackends : backends) {
            if (platformBackends != null) {
                total += platformBackends.length;
            }
        }
        OCLBackend[] flat = new OCLBackend[total];
        int offset = 0;
        for (OCLBackend[] platformBackends : backends) {
            if (platformBackends != null) {
                System.arraycopy(platformBackends, 0, flat, offset, platformBackends.length);
                offset += platformBackends.length;
            }
        }
        return flat;
    }

    /**
     * Orders backends by device type following {@link #DEVICE_TYPE_LIST} and, within each
     * type, by descending maximum threads per block. The per-type sort is stable, so
     * backends with equal values keep their relative input order.
     *
     * <p>Backends whose device type is not listed in {@link #DEVICE_TYPE_LIST} are not
     * part of the result, so the returned array can be shorter than the input.
     */
    private static OCLBackend[] orderFlattenBackends(OCLBackend[] unordered) {
        Map<OCLDeviceType, List<OCLBackend>> backendsByType = new EnumMap<>(OCLDeviceType.class);
        for (OCLBackend backend : unordered) {
            OCLDeviceType deviceType = backend.getDeviceContext().getDevice().getDeviceType();
            backendsByType.computeIfAbsent(deviceType, key -> new ArrayList<>()).add(backend);
        }

        List<OCLBackend> ordered = new ArrayList<>(unordered.length);
        for (OCLDeviceType deviceType : DEVICE_TYPE_LIST) {
            List<OCLBackend> backendsOfType = backendsByType.get(deviceType);
            if (backendsOfType == null) {
                continue;
            }
            // Arguments are reversed on purpose: larger values sort first.
            backendsOfType.sort((first, second) -> Long.compare(maxThreadsPerBlock(second), maxThreadsPerBlock(first)));
            ordered.addAll(backendsOfType);
        }
        return ordered.toArray(new OCLBackend[0]);
    }

    /** Reads the sort key used by {@link #orderFlattenBackends(OCLBackend[])} from a backend's device. */
    private static long maxThreadsPerBlock(OCLBackend backend) {
        return backend.getDeviceContext().getDevice().getDeviceContext().getDevice().getMaxThreadsPerBlock();
    }

    /**
     * Returns the device of the backend currently at index 0 of the flat backend list.
     */
    public TornadoXPUDevice getDefaultDevice() {
        return this.flatBackends[0].getDeviceContext().toDevice();
    }

    /**
     * Makes the device at {@code index} the default one by swapping it with the current
     * default, and initialises its backend if needed.
     *
     * <p>The list cached by {@link #getAllDevices()} is not refreshed by this call.
     *
     * @param index position of the new default device in the flat backend list
     * @throws TornadoDeviceNotFound if {@code index} is outside the flat backend list
     */
    public void setDefaultDevice(int index) {
        this.swapDefaultDevice(index);
    }

    /**
     * Returns the device at the given position of the flat backend list.
     *
     * @param index device index
     * @return the device of the backend at {@code index}
     * @throws TornadoDeviceNotFound if {@code index} is negative or not smaller than the
     *                               number of ordered backends
     */
    public TornadoXPUDevice getDevice(int index) {
        if (index >= 0 && index < this.flatBackends.length) {
            return this.flatBackends[index].getDeviceContext().toDevice();
        }
        throw new TornadoDeviceNotFound("[ERROR] device required not found: " + index + " - Max: " + this.flatBackends.length);
    }

    /**
     * Returns the devices for indices {@code 0 .. getNumDevices() - 1}, as produced by
     * {@link #getDevice(int)}. The list is built on the first call and cached.
     *
     * @return the cached device list
     */
    public List<TornadoDevice> getAllDevices() {
        List<TornadoDevice> result = this.devices;
        if (result == null) {
            synchronized (this) {
                result = this.devices;
                if (result == null) {
                    int numDevices = this.getNumDevices();
                    result = new ArrayList<>(numDevices);
                    for (int deviceIndex = 0; deviceIndex < numDevices; ++deviceIndex) {
                        result.add((TornadoDevice) this.getDevice(deviceIndex));
                    }
                    // Publish only after the list is complete so that a reader taking the
                    // unsynchronised fast path never observes a partially filled list.
                    this.devices = result;
                }
            }
        }
        return result;
    }

    /**
     * Returns the number of backends summed over all platforms.
     */
    public int getNumDevices() {
        int total = 0;
        for (OCLBackend[] platformBackends : this.backends) {
            if (platformBackends != null) {
                total += platformBackends.length;
            }
        }
        return total;
    }

    /** Returns the backend at {@code [platform][device]}, initialising it first if it is not yet initialised. */
    private OCLBackend checkAndInitBackend(int platform, int device) {
        OCLBackend backend = this.backends[platform][device];
        if (!backend.isInitialised()) {
            backend.init();
        }
        return backend;
    }

    /** Swaps flat-list entries 0 and {@code device}, then initialises the new entry 0 if needed. */
    private void swapDefaultDevice(int device) {
        if (device < 0 || device >= this.flatBackends.length) {
            throw new TornadoDeviceNotFound("[ERROR] device required not found: " + device + " - Max: " + this.flatBackends.length);
        }
        OCLBackend previousDefault = this.flatBackends[0];
        this.flatBackends[0] = this.flatBackends[device];
        this.flatBackends[device] = previousDefault;
        OCLBackend backend = this.flatBackends[0];
        if (!backend.isInitialised()) {
            backend.init();
        }
    }

    /** Creates the backend for device {@code deviceIndex} of {@code context} via {@link OCLHotSpotBackendFactory}. */
    private OCLBackend createOCLJITCompiler(OptionValues options, HotSpotJVMCIRuntime jvmciRuntime, TornadoVMConfigAccess vmConfig, OCLContextInterface context, int deviceIndex) {
        OCLTargetDevice device = context.devices().get(deviceIndex);
        this.logger.info("Creating backend for %s", new Object[]{device.getDeviceName()});
        return OCLHotSpotBackendFactory.createJITCompiler(options, jvmciRuntime, vmConfig, context, device);
    }

    /** Creates the context of one platform, records it, and creates a backend for each of its devices. */
    private void installDevices(int platformIndex, TornadoPlatformInterface platform, OptionValues options, HotSpotJVMCIRuntime vmRuntime, TornadoVMConfigAccess vmConfig) {
        this.logger.info("OpenCL[%d]: Platform %s", new Object[]{platformIndex, platform.getName()});
        OCLContextInterface context = platform.createContext();
        assert (context != null) : "OpenCL context is null";
        this.contexts.add(context);
        int numDevices = context.getNumDevices();
        this.logger.info("OpenCL[%d]: Has %d devices...", new Object[]{platformIndex, numDevices});
        this.backends[platformIndex] = new OCLBackend[numDevices];
        for (int deviceIndex = 0; deviceIndex < numDevices; ++deviceIndex) {
            OCLTargetDevice device = context.devices().get(deviceIndex);
            this.logger.info("OpenCL[%d]: device=%s", new Object[]{platformIndex, device.getDeviceName()});
            this.backends[platformIndex][deviceIndex] = this.createOCLJITCompiler(options, vmRuntime, vmConfig, context, deviceIndex);
        }
    }

    /** Installs the devices of every platform for which a row was allocated in {@link #backends}. */
    private void discoverDevices(OptionValues options, HotSpotJVMCIRuntime vmRuntime, TornadoVMConfigAccess vmConfig) {
        // Iterate over the already-sized array rather than re-querying the platform count,
        // so the loop bound always matches the rows that can be written.
        for (int platformIndex = 0; platformIndex < this.backends.length; ++platformIndex) {
            TornadoPlatformInterface platform = OpenCL.getPlatform(platformIndex);
            this.installDevices(platformIndex, platform, options, vmRuntime, vmConfig);
        }
    }

    /**
     * Returns the backend for the given platform and device, initialising it if needed.
     *
     * @param platform platform index
     * @param device   device index within the platform
     */
    public OCLBackend getBackend(int platform, int device) {
        return this.checkAndInitBackend(platform, device);
    }

    /**
     * Returns the backend of device 0 on platform 0, initialising it if needed.
     *
     * <p>NOTE(review): this is addressed by platform/device index, not through the ordered
     * flat list, so it is not necessarily the backend behind {@link #getDefaultDevice()} —
     * confirm this is intended.
     */
    public OCLBackend getDefaultBackend() {
        return this.checkAndInitBackend(0, 0);
    }

    /**
     * Returns the number of backends of one platform.
     *
     * @param platform platform index
     * @return the number of backends, or 0 if no devices were installed for the platform
     */
    public int getNumDevices(int platform) {
        OCLBackend[] platformBackends = this.backends[platform];
        return platformBackends == null ? 0 : platformBackends.length;
    }

    /**
     * Returns the number of platforms this instance was sized for at construction.
     */
    public int getNumPlatforms() {
        return this.backends.length;
    }

    /**
     * Returns the context at {@code index}, or the first context when {@code index} is
     * not smaller than the number of contexts.
     *
     * @param index platform index
     */
    public OCLContextInterface getPlatformContext(int index) {
        return index < this.contexts.size() ? this.contexts.get(index) : this.contexts.get(0);
    }

    /** Returns the providers of {@link #getDefaultBackend()}. */
    public Providers getProviders() {
        return this.getDefaultBackend().getProviders();
    }

    /** Returns the suites provider of {@link #getDefaultBackend()}. */
    public OCLSuitesProvider getSuitesProvider() {
        return this.getDefaultBackend().getTornadoSuites();
    }

    /** Returns the backend name, {@code "OpenCL"}. */
    public String getName() {
        return "OpenCL";
    }

    /** Returns {@link TornadoVMBackendType#OPENCL}. */
    public TornadoVMBackendType getBackendType() {
        return TornadoVMBackendType.OPENCL;
    }

    /** Returns the device type of {@link #getDefaultDevice()}. */
    public TornadoDeviceType getTypeDefaultDevice() {
        return this.getDefaultDevice().getDeviceType();
    }
}

