/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.runtime.graph;

import java.lang.reflect.Array;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TornadoDeviceContext;
import uk.ac.manchester.tornado.api.common.Access;
import uk.ac.manchester.tornado.api.common.Event;
import uk.ac.manchester.tornado.api.common.SchedulableTask;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.api.profiler.ProfilerType;
import uk.ac.manchester.tornado.api.profiler.TornadoProfiler;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
import uk.ac.manchester.tornado.api.types.collections.TornadoCollectionInterface;
import uk.ac.manchester.tornado.api.types.images.TornadoImagesInterface;
import uk.ac.manchester.tornado.api.types.matrix.TornadoMatrixInterface;
import uk.ac.manchester.tornado.api.types.vectors.TornadoVectorsInterface;
import uk.ac.manchester.tornado.api.types.volumes.TornadoVolumesInterface;
import uk.ac.manchester.tornado.runtime.common.KernelStackFrame;
import uk.ac.manchester.tornado.runtime.common.RuntimeUtilities;
import uk.ac.manchester.tornado.runtime.common.TornadoLogger;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
import uk.ac.manchester.tornado.runtime.common.TornadoXPUDevice;
import uk.ac.manchester.tornado.runtime.common.XPUDeviceBufferState;
import uk.ac.manchester.tornado.runtime.common.enums.DataTypeSize;
import uk.ac.manchester.tornado.runtime.profiler.TimeProfiler;
import uk.ac.manchester.tornado.runtime.tasks.LocalObjectState;
import uk.ac.manchester.tornado.runtime.tasks.meta.ScheduleContext;

/**
 * Execution state of a single task graph: the registered tasks, the constants and
 * objects they reference (with their access mode and local state), the devices in
 * use and the task-to-device mapping table.
 *
 * <p>Instances are mutable and no synchronization is performed by this class.
 */
public class TornadoExecutionContext {
    /** Sentinel meaning "not set" for {@link #getBatchSize()} and {@link #getExecutionPlanMemoryLimit()}. */
    public static int INIT_VALUE = -1;
    /** Capacity of the kernel stack-frame table and of the task-to-device table. */
    private static final int MAX_TASKS = 256;
    /** Initial capacity of the device list. */
    private static final int INITIAL_DEVICE_CAPACITY = 16;
    private final String name;
    private ScheduleContext meta;
    private KernelStackFrame[] kernelStackFrame;
    private List<SchedulableTask> tasks;
    private List<Object> constants;
    // Maps Object#hashCode() of a registered object to its index in `objects` / `objectState`.
    // NOTE(review): keyed by hash code, so two distinct objects with the same hash code
    // would share one slot — confirm this is acceptable for the types passed in.
    private Map<Integer, Integer> objectMap;
    private HashMap<Object, Access> objectsAccesses;
    private List<Object> objects;
    private List<Object> persistedObjects;
    private Map<String, List<Object>> persistedTaskToObjectsMap;
    // Parallel to `objects`: objectState.get(i) belongs to objects.get(i).
    private List<LocalObjectState> objectState;
    private List<TornadoXPUDevice> devices;
    private TornadoXPUDevice[] taskToDeviceMapTable;
    private int nextTask;
    private long batchSize;
    private long executionPlanMemoryLimit;
    private Set<TornadoXPUDevice> lastDevices;
    private boolean redeployOnDevice;
    private boolean defaultScheduler;
    private boolean isDataDependencyDetected;
    private TornadoProfiler profiler;
    private boolean isPrintKernel;
    private long executionPlanId;
    private long currentDeviceMemoryUsage;

    /**
     * Creates an empty execution context.
     *
     * @param id
     *     identifier of the context; also used to build its {@link ScheduleContext}
     *     and returned by {@link #getId()}
     */
    public TornadoExecutionContext(String id) {
        this.name = id;
        this.meta = new ScheduleContext(this.name);
        this.tasks = new ArrayList<>();
        this.constants = new ArrayList<>();
        this.objectMap = new HashMap<>();
        this.objects = new ArrayList<>();
        this.persistedObjects = new ArrayList<>();
        this.objectsAccesses = new HashMap<>();
        this.objectState = new ArrayList<>();
        this.persistedTaskToObjectsMap = new HashMap<>();
        this.devices = new ArrayList<>(INITIAL_DEVICE_CAPACITY);
        this.kernelStackFrame = new KernelStackFrame[MAX_TASKS];
        // A freshly allocated array is already filled with null entries.
        this.taskToDeviceMapTable = new TornadoXPUDevice[MAX_TASKS];
        this.nextTask = 0;
        this.batchSize = INIT_VALUE;
        this.executionPlanMemoryLimit = INIT_VALUE;
        this.lastDevices = new HashSet<>();
        this.currentDeviceMemoryUsage = 0L;
        this.profiler = null;
        // NOTE(review): the task list is still empty at this point, so this always
        // evaluates to false and scheduleTaskToDevices() never takes its
        // single-device branch. Kept as is to preserve the current scheduling
        // behaviour — confirm whether the check should run at scheduling time.
        this.isDataDependencyDetected = this.isDataDependencyInTaskGraph();
    }

    /**
     * @return the kernel stack-frame table (the internal array, not a copy)
     */
    public KernelStackFrame[] getKernelStackFrame() {
        return this.kernelStackFrame;
    }

    /**
     * Registers a task parameter and returns its index.
     *
     * <p>Boxed primitives are stored in the constant table (deduplicated with
     * {@code equals}). Any other value is stored in the object table, looked up by
     * its hash code. When an already registered object is inserted again with
     * {@link Access#READ_WRITE}, its recorded access is upgraded to {@code READ_WRITE}.
     *
     * @param parameter
     *     value to register; must not be {@code null}
     * @param access
     *     access mode requested for the value
     * @return index in the constant table (boxed primitives) or in the object table
     */
    public int insertVariable(Object parameter, Access access) {
        if (isConstant(parameter)) {
            int constantIndex = this.constants.indexOf(parameter);
            if (constantIndex == -1) {
                constantIndex = this.constants.size();
                this.constants.add(parameter);
            }
            return constantIndex;
        }

        int key = parameter.hashCode();
        Integer knownIndex = this.objectMap.get(key);
        if (knownIndex != null) {
            // Enum constants are compared by identity; this is also null-safe when no
            // access has been recorded for the parameter.
            if (access == Access.READ_WRITE && this.objectsAccesses.get(parameter) != Access.READ_WRITE) {
                this.objectsAccesses.replace(parameter, access);
            }
            return knownIndex;
        }

        int newIndex = this.objects.size();
        this.objects.add(parameter);
        this.objectsAccesses.put(parameter, access);
        this.objectMap.put(key, newIndex);
        this.objectState.add(newIndex, new LocalObjectState(parameter));
        return newIndex;
    }

    /**
     * Tells whether a value belongs to the constant table.
     *
     * <p>{@code Object#getClass()} never yields a primitive class, so only the
     * boxed-primitive test is needed.
     */
    private static boolean isConstant(Object value) {
        return RuntimeUtilities.isBoxedPrimitiveClass(value.getClass());
    }

    /**
     * @return the batch size, or {@link #INIT_VALUE} if none was set
     */
    public long getBatchSize() {
        return this.batchSize;
    }

    /**
     * @param size
     *     the batch size to store
     */
    public void setBatchSize(long size) {
        this.batchSize = size;
    }

    /**
     * @return the memory limit of the execution plan, or {@link #INIT_VALUE} if none was set
     */
    public long getExecutionPlanMemoryLimit() {
        return this.executionPlanMemoryLimit;
    }

    /**
     * @param memoryLimitSize
     *     the memory limit to store
     */
    public void setExecutionPlanMemoryLimit(long memoryLimitSize) {
        this.executionPlanMemoryLimit = memoryLimitSize;
    }

    /**
     * @return {@code true} if a memory limit different from {@link #INIT_VALUE} is set
     */
    public boolean isMemoryLimited() {
        return this.getExecutionPlanMemoryLimit() != (long) INIT_VALUE;
    }

    /**
     * Sums the size of every registered object and constant and compares the total
     * with the execution-plan memory limit.
     *
     * <p>{@link KernelContext} and {@link AtomicInteger} objects contribute nothing
     * to the total.
     *
     * @return {@code true} if the accumulated size is strictly greater than
     *     {@link #getExecutionPlanMemoryLimit()}
     * @throws TornadoRuntimeException
     *     if an array component type or a constant type has no known
     *     {@link DataTypeSize}, or if an object is of an unsupported type
     */
    public boolean doesExceedExecutionPlanLimit() {
        long totalSize = 0L;
        for (Object parameter : this.getObjects()) {
            // The order of the type tests is significant for objects that implement
            // more than one of these interfaces; it mirrors the original sequence.
            if (parameter.getClass().isArray()) {
                Class<?> componentType = parameter.getClass().getComponentType();
                DataTypeSize dataTypeSize = DataTypeSize.findDataTypeSize(componentType);
                if (dataTypeSize == null) {
                    throw new TornadoRuntimeException("[UNSUPPORTED] Data type not supported for processing in batches");
                }
                long length = Array.getLength(parameter);
                totalSize += length * (long) dataTypeSize.getSize();
            } else if (parameter instanceof TornadoNativeArray) {
                totalSize += ((TornadoNativeArray) parameter).getNumBytesOfSegment();
            } else if (parameter instanceof TornadoVectorsInterface) {
                totalSize += ((TornadoVectorsInterface) parameter).getNumBytes();
            } else if (parameter instanceof TornadoCollectionInterface) {
                totalSize += ((TornadoCollectionInterface) parameter).getNumBytesWithHeader();
            } else if (parameter instanceof TornadoVolumesInterface) {
                totalSize += ((TornadoVolumesInterface) parameter).getNumBytesWithHeader();
            } else if (parameter instanceof TornadoMatrixInterface) {
                totalSize += ((TornadoMatrixInterface) parameter).getNumBytesWithHeader();
            } else if (parameter instanceof TornadoImagesInterface) {
                totalSize += ((TornadoImagesInterface) parameter).getNumBytesWithHeader();
            } else if (!(parameter instanceof KernelContext || parameter instanceof AtomicInteger)) {
                throw new TornadoRuntimeException("Unsupported type: " + String.valueOf(parameter.getClass()));
            }
        }
        for (Object field : this.constants) {
            DataTypeSize dataTypeSize = DataTypeSize.findDataTypeSize(field.getClass());
            if (dataTypeSize == null) {
                throw new TornadoRuntimeException("[UNSUPPORTED] Data type not supported for processing in batches");
            }
            totalSize += (long) dataTypeSize.getSize();
        }
        return totalSize > this.getExecutionPlanMemoryLimit();
    }

    /**
     * Replaces a registered constant or object in place, keeping its index.
     *
     * <p>For objects, the stream-in, forced-stream-in, stream-out and on-device flags
     * of the old {@link LocalObjectState} are carried over to a new state created for
     * {@code newObj}, and the recorded access mode is moved to {@code newObj}.
     *
     * @param oldObj
     *     value currently registered
     * @param newObj
     *     value that takes its place
     * @return the index shared by the old and the new value
     * @throws TornadoRuntimeException
     *     if {@code oldObj} is not registered in this context
     */
    public int replaceVariable(Object oldObj, Object newObj) {
        if (isConstant(oldObj)) {
            int constantIndex = this.constants.indexOf(oldObj);
            if (constantIndex == -1) {
                throw new TornadoRuntimeException("Constant to replace is not registered in this execution context: " + oldObj);
            }
            this.constants.set(constantIndex, newObj);
            return constantIndex;
        }

        Integer registeredIndex = this.objectMap.remove(oldObj.hashCode());
        if (registeredIndex == null) {
            throw new TornadoRuntimeException("Object to replace is not registered in this execution context: " + oldObj.getClass().getName());
        }
        int index = registeredIndex;

        LocalObjectState oldLocalObjectState = this.objectState.get(index);
        LocalObjectState newLocalObjectState = new LocalObjectState(newObj);
        newLocalObjectState.setStreamIn(oldLocalObjectState.isStreamIn());
        newLocalObjectState.setForceStreamIn(oldLocalObjectState.isForcedStreamIn());
        newLocalObjectState.setStreamOut(oldLocalObjectState.isStreamOut());
        newLocalObjectState.setOnDevice(oldLocalObjectState.isOnDevice());

        // Overwrite in place: same effect as removing and re-adding at the same index.
        this.objects.set(index, newObj);
        this.objectState.set(index, newLocalObjectState);

        Access access = this.objectsAccesses.remove(oldObj);
        this.objectsAccesses.put(newObj, access);
        this.objectMap.put(newObj.hashCode(), index);
        return index;
    }

    /**
     * @return the current value of the task counter
     */
    public int getTaskCount() {
        return this.nextTask;
    }

    /**
     * @return the value of the task counter before it is incremented by one
     */
    public int getTaskCountAndIncrement() {
        return this.nextTask++;
    }

    /**
     * Adds a task unless an equal one is already registered.
     *
     * @param task
     *     task to register
     * @return index of the task in the task list
     */
    public int addTask(SchedulableTask task) {
        int index = this.tasks.indexOf(task);
        if (index == -1) {
            index = this.tasks.size();
            this.tasks.add(task);
        }
        return index;
    }

    /**
     * Appends an object to the list of persisted objects; {@code null} is ignored.
     *
     * @param object
     *     object to remember
     */
    public void addPersistedObject(Object object) {
        if (object != null) {
            this.persistedObjects.add(object);
        }
    }

    /**
     * @return the list of persisted objects (the internal list, not a copy)
     */
    public List<Object> getPersistedObjects() {
        return this.persistedObjects;
    }

    /**
     * Replaces the task stored at {@code index}.
     *
     * @param index
     *     position in the task list
     * @param task
     *     task to store
     */
    public void setTask(int index, SchedulableTask task) {
        this.tasks.set(index, task);
    }

    /**
     * @return the constant table (the internal list, not a copy)
     */
    public List<Object> getConstants() {
        return this.constants;
    }

    /**
     * @return the object table (the internal list, not a copy)
     */
    public List<Object> getObjects() {
        return this.objects;
    }

    /**
     * @return the access mode recorded for each registered object (the internal map)
     */
    public HashMap<Object, Access> getObjectsAccesses() {
        return this.objectsAccesses;
    }

    /**
     * @param index
     *     task index
     * @return the entry of the task-to-device table at {@code index}; may be {@code null}
     */
    public TornadoXPUDevice getDeviceForTask(int index) {
        return this.taskToDeviceMapTable[index];
    }

    /**
     * @param index
     *     position in the device list
     * @return the device stored at {@code index}
     */
    public TornadoXPUDevice getDevice(int index) {
        return this.devices.get(index);
    }

    /**
     * @param index
     *     position in the task list
     * @return the task stored at {@code index}
     */
    public SchedulableTask getTask(int index) {
        return this.tasks.get(index);
    }

    /**
     * Passes every registered task to {@code consumer}, in list order.
     *
     * @param consumer
     *     action to run on each task
     */
    public void apply(Consumer<SchedulableTask> consumer) {
        for (SchedulableTask task : this.tasks) {
            consumer.accept(task);
        }
    }

    /**
     * Makes {@code tornadoDevice} the only device of this context, assigns it to every
     * task and fills the whole task-to-device table with it.
     *
     * @param tornadoDevice
     *     target device; must be a {@link TornadoXPUDevice}
     * @throws TornadoRuntimeException
     *     if the device is not a {@link TornadoXPUDevice}
     */
    public void mapAllTasksToSingleDevice(TornadoDevice tornadoDevice) {
        if (!(tornadoDevice instanceof TornadoXPUDevice)) {
            throw new TornadoRuntimeException("Device " + String.valueOf(tornadoDevice.getClass()) + " not supported yet");
        }
        TornadoXPUDevice tornadoAcceleratorDevice = (TornadoXPUDevice) tornadoDevice;
        this.devices.clear();
        // The list was just cleared, so appending is the same as inserting at the front.
        this.devices.add(tornadoAcceleratorDevice);
        this.apply(task -> task.setDevice(tornadoDevice));
        Arrays.fill(this.taskToDeviceMapTable, tornadoAcceleratorDevice);
    }

    /**
     * Adds a device to the device list unless it is already present.
     *
     * @param device
     *     device to add
     */
    public void setDevice(TornadoXPUDevice device) {
        if (!this.devices.contains(device)) {
            this.devices.add(device);
        }
    }

    /**
     * Records the device of {@code task} in the device list and in the
     * task-to-device table at {@code index}.
     *
     * @throws TornadoRuntimeException
     *     if the device of the task is not a {@link TornadoXPUDevice}
     */
    private void assignTaskToDevice(int index, SchedulableTask task) {
        String id = task.getId();
        TornadoDevice target = task.getDevice();
        if (!(target instanceof TornadoXPUDevice)) {
            throw new TornadoRuntimeException("Device " + String.valueOf(target.getClass()) + " not supported yet");
        }
        TornadoXPUDevice accelerator = (TornadoXPUDevice) target;
        this.setDevice(accelerator);
        new TornadoLogger().info("assigning %s to %s", id, target.getDeviceName());
        this.taskToDeviceMapTable[index] = accelerator;
    }

    /**
     * Fills the task-to-device table: each task goes to its own device, unless a data
     * dependency was flagged, in which case all tasks go to the device of the first task.
     */
    public void scheduleTaskToDevices() {
        if (this.isDataDependencyDetected) {
            this.mapAllTasksToSingleDevice(this.getDeviceOfFirstTask());
            return;
        }
        for (int i = 0; i < this.tasks.size(); ++i) {
            this.assignTaskToDevice(i, this.tasks.get(i));
        }
    }

    /**
     * @return the number of non-null entries in the device list
     */
    public int getValidContextSize() {
        return (int) this.getDevices().stream().filter(Objects::nonNull).count();
    }

    /**
     * @return the device of the task at index 0
     */
    public TornadoDevice getDeviceOfFirstTask() {
        return this.tasks.get(0).getDevice();
    }

    /**
     * Returns the local state of {@code object}, registering the object first if needed.
     *
     * <p>The returned state is looked up in the object-state list using the index from
     * {@link #insertVariable(Object, Access)}.
     *
     * @param object
     *     object whose state is requested
     * @param access
     *     access mode passed to {@link #insertVariable(Object, Access)}
     * @return the matching {@link LocalObjectState}
     */
    public LocalObjectState getLocalStateObject(Object object, Access access) {
        return this.objectState.get(this.insertVariable(object, access));
    }

    /**
     * Replaces {@code oldObj} with {@code newObj} and returns the state stored at the
     * resulting index.
     *
     * @see #replaceVariable(Object, Object)
     */
    @Deprecated
    public LocalObjectState replaceObjectState(Object oldObj, Object newObj) {
        return this.objectState.get(this.replaceVariable(oldObj, newObj));
    }

    /**
     * Checks every pair of distinct tasks for a shared argument that is written.
     *
     * @return {@code true} if some pair of tasks with different IDs has at least one
     *     common argument and {@link #hasWriteAccess} reports a write
     */
    private boolean isDataDependencyInTaskGraph() {
        for (int i = 0; i < this.tasks.size(); ++i) {
            SchedulableTask task = this.tasks.get(i);
            for (int j = i + 1; j < this.tasks.size(); ++j) {
                SchedulableTask otherTask = this.tasks.get(j);
                if (this.doTasksHaveSameIDs(task, otherTask)) {
                    continue;
                }
                // The helper returns a fresh list, so emptiness must be tested by
                // content rather than by reference.
                if (this.getCommonArgumentsInTasks(task, otherTask).isEmpty()) {
                    continue;
                }
                if (this.hasWriteAccess(task, otherTask)) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Two tasks have the same IDs when both their task name and their id are equal. */
    private boolean doTasksHaveSameIDs(SchedulableTask task1, SchedulableTask task2) {
        return task1.getTaskName().equals(task2.getTaskName()) && task1.getId().equals(task2.getId());
    }

    /** Collects the arguments of {@code task1} that are equal to some argument of {@code task2}. */
    private List<Object> getCommonArgumentsInTasks(SchedulableTask task1, SchedulableTask task2) {
        List<Object> commonArguments = new ArrayList<>();
        for (Object arg1 : task1.getArguments()) {
            for (Object arg2 : task2.getArguments()) {
                if (arg1.equals(arg2)) {
                    commonArguments.add(arg1);
                    // One match is enough; move on to the next argument of task1.
                    break;
                }
            }
        }
        return commonArguments;
    }

    /**
     * Tells whether {@code task} writes an argument that {@code otherTask} has at the
     * same position.
     *
     * <p>Only positions present in both argument arrays are compared, so tasks with a
     * different number of arguments no longer cause an out-of-bounds access.
     */
    private boolean hasWriteAccess(SchedulableTask task, SchedulableTask otherTask) {
        Object[] arguments = task.getArguments();
        Object[] otherArguments = otherTask.getArguments();
        Access[] accesses = task.getArgumentsAccess();
        int comparable = Math.min(arguments.length, otherArguments.length);
        for (int i = 0; i < comparable; ++i) {
            if (!arguments[i].equals(otherArguments[i])) {
                continue;
            }
            Access access = accesses[i];
            if (access == Access.WRITE_ONLY || access == Access.READ_WRITE) {
                return true;
            }
        }
        return false;
    }

    /**
     * @return the local object states (the internal list, not a copy)
     */
    public List<LocalObjectState> getObjectStates() {
        return this.objectState;
    }

    /**
     * @return the registered tasks (the internal list, not a copy)
     */
    public List<SchedulableTask> getTasks() {
        return this.tasks;
    }

    /**
     * @return the device list (the internal list, not a copy)
     */
    public List<TornadoXPUDevice> getDevices() {
        return this.devices;
    }

    /**
     * @return the indexes of the non-null entries of the device list, in ascending
     *     order from the head of the deque
     */
    public Deque<Integer> getActiveDeviceIndexes() {
        Deque<Integer> activeIndexes = new ArrayDeque<>();
        for (int i = 0; i < this.devices.size(); ++i) {
            if (this.devices.get(i) != null) {
                activeIndexes.addLast(i);
            }
        }
        return activeIndexes;
    }

    /**
     * @param deviceContext
     *     device context to match (compared by reference)
     * @return a new list with the tasks whose device reports {@code deviceContext}
     */
    public List<SchedulableTask> getTasksForDevice(TornadoDeviceContext deviceContext) {
        List<SchedulableTask> tasksForDevice = new ArrayList<>();
        for (SchedulableTask task : this.tasks) {
            if (task.getDevice().getDeviceContext() == deviceContext) {
                tasksForDevice.add(task);
            }
        }
        return tasksForDevice;
    }

    /**
     * @return the device reported by the {@link ScheduleContext} of this context
     */
    @Deprecated
    public TornadoXPUDevice getDefaultDevice() {
        return this.meta.getXPUDevice();
    }

    /**
     * Finds a task by id, ignoring case.
     *
     * <p>If {@code id} does not start with the id of this context, the context id and
     * a dot are prepended before comparing.
     *
     * @param id
     *     task id, with or without the context prefix
     * @return the first matching task, or {@code null} if there is none or {@code id}
     *     is {@code null}
     */
    public SchedulableTask getTask(String id) {
        if (id == null) {
            return null;
        }
        String canonicalId = this.canonicalizeId(id);
        for (SchedulableTask task : this.tasks) {
            if (task.getId().equalsIgnoreCase(canonicalId)) {
                return task;
            }
        }
        return null;
    }

    /** Prefixes {@code id} with {@code <contextId>.} unless it already starts with the context id. */
    private String canonicalizeId(String id) {
        return id.startsWith(this.getId()) ? id : this.getId() + "." + id;
    }

    /**
     * Returns the device of the task with the given id.
     *
     * @param id
     *     task id, resolved as in {@link #getTask(String)}
     * @return the device of the task, or {@code null} if no such task is registered
     * @throws RuntimeException
     *     if the device of the task is not a {@link TornadoXPUDevice}
     */
    public TornadoXPUDevice getDeviceForTask(String id) {
        SchedulableTask task = this.getTask(id);
        if (task == null) {
            // Check before dereferencing: an unknown id yields null, not an NPE.
            return null;
        }
        TornadoDevice device = task.getDevice();
        if (!(device instanceof TornadoXPUDevice)) {
            throw new RuntimeException("Device " + String.valueOf(device.getClass()) + " not supported yet");
        }
        return (TornadoXPUDevice) device;
    }

    /**
     * @return the identifier passed to the constructor
     */
    public String getId() {
        return this.name;
    }

    /**
     * @return the {@link ScheduleContext} of this context
     */
    public ScheduleContext meta() {
        return this.meta;
    }

    /**
     * Calls {@code sync} on the local state of every non-null registered object, using
     * the device reported by {@link #meta()}.
     *
     * <p>When the profiler option is enabled, a profiler is attached and the sync
     * returned an event, the elapsed time of the event is added to
     * {@code COPY_OUT_TIME_SYNC} and the device buffer size to
     * {@code COPY_OUT_SIZE_BYTES_SYNC}.
     */
    public void sync() {
        for (int i = 0; i < this.objects.size(); ++i) {
            Object object = this.objects.get(i);
            if (object == null) {
                continue;
            }
            LocalObjectState localState = this.objectState.get(i);
            Event event = localState.sync(this.executionPlanId, object, this.meta().getXPUDevice());
            // The profiler is optional (see withProfiler); skip the bookkeeping
            // instead of failing when none is attached.
            if (!TornadoOptions.isProfilerEnabled() || event == null || this.profiler == null) {
                continue;
            }
            long value = this.profiler.getTimer(ProfilerType.COPY_OUT_TIME_SYNC);
            value += event.getElapsedTime();
            this.profiler.setTimer(ProfilerType.COPY_OUT_TIME_SYNC, value);
            XPUDeviceBufferState deviceObjectState = localState.getDataObjectState().getDeviceBufferState(this.meta().getXPUDevice());
            this.profiler.addValueToMetric(ProfilerType.COPY_OUT_SIZE_BYTES_SYNC, TimeProfiler.NO_TASK_NAME, deviceObjectState.getXPUBuffer().size());
        }
    }

    /**
     * @param device
     *     device to add to the set of last devices
     */
    public void addLastDevice(TornadoXPUDevice device) {
        this.lastDevices.add(device);
    }

    /**
     * @return the set of last devices (the internal set, not a copy)
     */
    public Set<TornadoXPUDevice> getLastDevices() {
        return this.lastDevices;
    }

    /**
     * Stores the flag returned by {@link #redeployOnDevice()}.
     *
     * @param newCallWrapper
     *     new value of the flag
     */
    public void newCallWrapper(boolean newCallWrapper) {
        this.redeployOnDevice = newCallWrapper;
    }

    /**
     * @return the flag last stored with {@link #newCallWrapper(boolean)}
     */
    public boolean redeployOnDevice() {
        return this.redeployOnDevice;
    }

    /**
     * @param use
     *     new value of the default-thread-scheduler flag
     */
    public void setDefaultThreadScheduler(boolean use) {
        this.defaultScheduler = use;
    }

    /**
     * @return the flag last stored with {@link #setDefaultThreadScheduler(boolean)}
     */
    public boolean useDefaultThreadScheduler() {
        return this.defaultScheduler;
    }

    /**
     * Prints the device, constant, object and task tables to standard output, with
     * ANSI-coloured section titles.
     */
    public void dumpExecutionContextMeta() {
        final String ansiReset = "\u001b[0m";
        final String ansiCyan = "\u001b[36m";
        final String ansiYellow = "\u001b[33m";
        final String ansiPurple = "\u001b[35m";
        final String ansiGreen = "\u001b[32m";
        final String separator = "-----------------------------------";

        System.out.println(separator);
        System.out.println(ansiCyan + "Device Table:" + ansiReset);
        for (int i = 0; i < this.devices.size(); ++i) {
            System.out.printf("[%d]: %s\n", i, this.devices.get(i));
        }
        System.out.println(ansiYellow + "Constant Table:" + ansiReset);
        for (int i = 0; i < this.constants.size(); ++i) {
            System.out.printf("[%d]: %s\n", i, this.constants.get(i));
        }
        System.out.println(ansiPurple + "Object Table:" + ansiReset);
        for (int i = 0; i < this.objects.size(); ++i) {
            Object obj = this.objects.get(i);
            System.out.printf("[%d]: 0x%x %s\n", i, obj.hashCode(), obj);
        }
        System.out.println(ansiGreen + "Task Table:" + ansiReset);
        for (int i = 0; i < this.tasks.size(); ++i) {
            SchedulableTask task = this.tasks.get(i);
            System.out.printf("[%d]: %s\n", i, task.getFullName());
        }
        System.out.println(separator);
    }

    /**
     * Attaches the profiler used by {@link #sync()}.
     *
     * @param timeProfiler
     *     profiler to use
     */
    public void withProfiler(TornadoProfiler timeProfiler) {
        this.profiler = timeProfiler;
    }

    /**
     * Creates a new context with the same id and copies of the collections and tables
     * of this one.
     *
     * <p>The collections and arrays are copied shallowly, except the local object
     * states, which are cloned one by one. The profiler reference, the print-kernel
     * flag, the task counter and the memory limit are copied as well. The batch size,
     * execution plan id, current device memory usage, redeploy flag and default
     * scheduler flag are not copied and keep the values set by the constructor.
     *
     * @return the new context
     */
    @Override
    public TornadoExecutionContext clone() {
        TornadoExecutionContext newExecutionContext = new TornadoExecutionContext(this.getId());
        newExecutionContext.tasks = new ArrayList<>(this.tasks);
        newExecutionContext.kernelStackFrame = this.kernelStackFrame.clone();
        newExecutionContext.constants = new ArrayList<>(this.constants);
        newExecutionContext.objectMap = new HashMap<>(this.objectMap);
        newExecutionContext.objectsAccesses = new HashMap<>(this.objectsAccesses);
        newExecutionContext.objects = new ArrayList<>(this.objects);
        newExecutionContext.persistedObjects = new ArrayList<>(this.persistedObjects);
        newExecutionContext.persistedTaskToObjectsMap = new HashMap<>(this.persistedTaskToObjectsMap);
        List<LocalObjectState> objectStateCopy = new ArrayList<>();
        for (LocalObjectState localObjectState : this.objectState) {
            objectStateCopy.add(localObjectState.clone());
        }
        newExecutionContext.objectState = objectStateCopy;
        newExecutionContext.devices = new ArrayList<>(this.devices);
        newExecutionContext.taskToDeviceMapTable = this.taskToDeviceMapTable.clone();
        newExecutionContext.lastDevices = new HashSet<>(this.lastDevices);
        newExecutionContext.isPrintKernel = this.isPrintKernel;
        newExecutionContext.profiler = this.profiler;
        newExecutionContext.nextTask = this.nextTask;
        newExecutionContext.executionPlanMemoryLimit = this.executionPlanMemoryLimit;
        return newExecutionContext;
    }

    /**
     * @return the execution plan id passed to {@code sync} of the local object states
     */
    public long getExecutionPlanId() {
        return this.executionPlanId;
    }

    /**
     * @param executionPlanId
     *     the execution plan id to store
     */
    public void setExecutionPlanId(long executionPlanId) {
        this.executionPlanId = executionPlanId;
    }

    /**
     * @return the stored device memory usage
     */
    public long getCurrentDeviceMemoryUsage() {
        return this.currentDeviceMemoryUsage;
    }

    /**
     * @param currentDeviceMemoryUsage
     *     the device memory usage to store
     */
    public void setCurrentDeviceMemoryUsage(long currentDeviceMemoryUsage) {
        this.currentDeviceMemoryUsage = currentDeviceMemoryUsage;
    }

    /**
     * Appends {@code value} to the list of persisted objects kept for a task graph,
     * creating the list on first use.
     *
     * @param taskgraphUniqueName
     *     key of the task graph
     * @param value
     *     object to remember
     */
    public void addPersistedObject(String taskgraphUniqueName, Object value) {
        this.persistedTaskToObjectsMap.computeIfAbsent(taskgraphUniqueName, k -> new ArrayList<>()).add(value);
    }

    /**
     * @return the persisted objects per task graph (the internal map, not a copy)
     */
    public Map<String, List<Object>> getPersistedTaskToObjectsMap() {
        return this.persistedTaskToObjectsMap;
    }
}

