/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.drivers.opencl.graal.phases;

import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.lang.runtime.SwitchBootstraps;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.ResolvedJavaField;
import org.graalvm.compiler.core.common.type.ObjectStamp;
import org.graalvm.compiler.graph.Graph;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.iterators.NodeIterable;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.FixedGuardNode;
import org.graalvm.compiler.nodes.FixedWithNextNode;
import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.LogicConstantNode;
import org.graalvm.compiler.nodes.LogicNode;
import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.ParameterNode;
import org.graalvm.compiler.nodes.PhiNode;
import org.graalvm.compiler.nodes.PiNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.IntegerLessThanNode;
import org.graalvm.compiler.nodes.calc.IsNullNode;
import org.graalvm.compiler.nodes.extended.UnboxNode;
import org.graalvm.compiler.nodes.java.ArrayLengthNode;
import org.graalvm.compiler.nodes.java.LoadFieldNode;
import org.graalvm.compiler.nodes.spi.CoreProviders;
import org.graalvm.compiler.nodes.util.GraphUtil;
import org.graalvm.compiler.phases.BasePhase;
import org.graalvm.compiler.phases.common.CanonicalizerPhase;
import org.graalvm.compiler.phases.common.DeadCodeEliminationPhase;
import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException;
import uk.ac.manchester.tornado.api.exceptions.TornadoInternalError;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.analysis.TornadoValueTypeReplacement;
import uk.ac.manchester.tornado.drivers.common.compiler.phases.loops.TornadoLoopUnroller;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLKernelContextAccessNode;
import uk.ac.manchester.tornado.runtime.TornadoCoreRuntime;
import uk.ac.manchester.tornado.runtime.common.RuntimeUtilities;
import uk.ac.manchester.tornado.runtime.common.TornadoLogger;
import uk.ac.manchester.tornado.runtime.graal.nodes.ParallelRangeNode;
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoHighTierContext;

/**
 * High-tier phase that specialises a task graph against the concrete arguments it is compiled for.
 *
 * <p>Each round of {@link #run} folds the actual argument values into the graph: boxed-primitive
 * parameters become constants, and array lengths / primitive fields reached through reference
 * arguments become constants. The round then canonicalises, strips redundant object
 * {@link PiNode}s, runs the loop unroller and {@code TornadoValueTypeReplacement}, canonicalises
 * again and eliminates dead code. Rounds repeat until the graph stops changing or
 * {@link #MAX_ITERATIONS} rounds have run.
 *
 * <p>When the grid scheduler is enabled, values feeding a {@link ParallelRangeNode} are not folded;
 * their usages are redirected to an {@link OCLKernelContextAccessNode} instead.
 *
 * <p>NOTE(review): per-compilation state lives in unsynchronised mutable fields, so an instance is
 * presumably not meant to run on several graphs concurrently — confirm with the phase-suite owner.
 */
public class TornadoTaskSpecialisation extends BasePhase<TornadoHighTierContext> {

    /** Upper bound on the number of specialisation rounds performed by {@link #run}. */
    private static final int MAX_ITERATIONS = 15;

    private static final String WARNING_GRID_SCHEDULER_DYNAMIC_LOOP_BOUNDS = "[TornadoVM] Warning: The loop bounds will be configured by the GridScheduler. Check the grid by using the flag --threadInfo.";

    /** Field-name fragment used to recognise size loads of (presumably Panama-backed) arrays. */
    private static final String PANAMA_SIZE_FIELD = "numberOfElements";

    private final CanonicalizerPhase canonicalizer;
    private final TornadoValueTypeReplacement valueTypeReplacement;
    private final DeadCodeEliminationPhase deadCodeElimination;
    private final TornadoLoopUnroller loopUnroll;

    // Per-compilation state, (re)initialised at the start of run().
    private long batchThreads;
    private boolean gridScheduling;
    /** Next kernel-context slot assigned to a grid-scheduled loop bound. */
    private int index;
    /** The grid-scheduler warning is printed at most once per phase instance. */
    private boolean printOnce = true;

    public TornadoTaskSpecialisation(CanonicalizerPhase canonicalizer) {
        this.canonicalizer = canonicalizer;
        this.valueTypeReplacement = new TornadoValueTypeReplacement();
        this.deadCodeElimination = new DeadCodeEliminationPhase();
        this.loopUnroll = new TornadoLoopUnroller(canonicalizer);
    }

    /**
     * Returns whether the graph still contains a primitive-typed field load whose textual form
     * mentions {@link #PANAMA_SIZE_FIELD}; such loads keep the specialisation loop going.
     */
    private static boolean hasPanamaArraySizeNode(StructuredGraph graph) {
        for (LoadFieldNode loadField : graph.getNodes().filter(LoadFieldNode.class)) {
            ResolvedJavaField field = loadField.field();
            if (field.getType().getJavaKind().isPrimitive() && loadField.toString().contains(PANAMA_SIZE_FIELD)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Finds the declared field {@code field} on {@code type} or its nearest superclass declaring it,
     * and makes it accessible.
     *
     * @return the field, or {@code null} if no class in the hierarchy yields it (the last lookup
     *         failure is printed in that case only — misses on intermediate classes are expected)
     */
    private Field lookupField(Class<?> type, String field) {
        Exception lastFailure = null;
        for (Class<?> current = type; current != null; current = current.getSuperclass()) {
            try {
                Field f = current.getDeclaredField(field);
                // A freshly obtained Field is never pre-opened, so the deprecated isAccessible() guard was moot.
                f.setAccessible(true);
                return f;
            } catch (NoSuchFieldException | SecurityException e) {
                // Not declared (or not reachable) here: keep walking up the hierarchy.
                lastFailure = e;
            }
        }
        if (lastFailure != null) {
            lastFailure.printStackTrace();
        }
        return null;
    }

    /** Like {@link #lookupField} but fails with a descriptive {@link NullPointerException} when absent. */
    private Field requireField(Class<?> type, String field) {
        return Objects.requireNonNull(lookupField(type, field), () -> "Field '" + field + "' not found in " + type.getName() + " or its superclasses");
    }

    @Override
    public Optional<BasePhase.NotApplicable> notApplicableTo(GraphState graphState) {
        return ALWAYS_APPLICABLE;
    }

    /**
     * Reads reference field {@code field} from {@code obj} reflectively.
     *
     * @throws RuntimeException wrapping any reflective access failure
     */
    private Object lookupRefField(StructuredGraph graph, Node node, Object obj, String field) {
        Field f = requireField(obj.getClass(), field);
        try {
            return f.get(obj);
        } catch (IllegalAccessException | IllegalArgumentException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads field {@code field} of {@code obj} and returns it as a {@link ConstantNode} of the given
     * kind. Byte, char and short constants are created directly in {@code graph}; the other
     * primitive kinds are returned unattached.
     *
     * <p>For {@link JavaKind#Object} nothing is returned: if the field is {@code final}, its value is
     * instead propagated into the field loads and array-length nodes that use {@code node}.
     *
     * @return the constant, or {@code null} for {@code Object}/unsupported kinds or when the
     *         reflective read fails (the failure is printed)
     */
    private ConstantNode lookupPrimField(StructuredGraph graph, Node node, Object obj, String field, JavaKind kind) {
        Field f = requireField(obj.getClass(), field);
        try {
            return switch (kind) {
                case Boolean -> ConstantNode.forBoolean(f.getBoolean(obj));
                case Byte -> ConstantNode.forByte(f.getByte(obj), graph);
                case Char -> ConstantNode.forChar(f.getChar(obj), graph);
                case Double -> ConstantNode.forDouble(f.getDouble(obj));
                case Float -> ConstantNode.forFloat(f.getFloat(obj));
                case Int -> ConstantNode.forInt(f.getInt(obj));
                case Long -> ConstantNode.forLong(f.getLong(obj));
                case Short -> ConstantNode.forShort(f.getShort(obj), graph);
                case Object -> {
                    if (Modifier.isFinal(f.getModifiers())) {
                        Object value = f.get(obj);
                        // Snapshot: evaluate() detaches nodes from node's usage list while we iterate.
                        for (LoadFieldNode load : node.usages().filter(LoadFieldNode.class).snapshot()) {
                            evaluate(graph, load, value);
                        }
                        for (ArrayLengthNode arrayLength : node.usages().filter(ArrayLengthNode.class).snapshot()) {
                            evaluate(graph, arrayLength, value);
                        }
                    }
                    yield null;
                }
                default -> null;
            };
        } catch (IllegalAccessException | IllegalArgumentException e) {
            e.printStackTrace();
            return null;
        }
    }

    private void printWarningMessageForDynamicLoopBounds() {
        if (printOnce) {
            System.out.println(WARNING_GRID_SCHEDULER_DYNAMIC_LOOP_BOUNDS);
            printOnce = false;
        }
    }

    /**
     * Redirects every usage of {@code node} to a kernel-context read of the next free slot
     * ({@link #index}), so the value is supplied at run time rather than folded.
     */
    private void replaceWithKernelContextAccess(StructuredGraph graph, Node node) {
        printWarningMessageForDynamicLoopBounds();
        ConstantNode slot = graph.addOrUnique(ConstantNode.forInt(index));
        OCLKernelContextAccessNode access = graph.addOrUnique(new OCLKernelContextAccessNode(slot));
        node.replaceAtUsages(access);
        index++;
    }

    /**
     * Folds the concrete runtime {@code value} flowing into {@code node}.
     *
     * <ul>
     * <li>{@link ArrayLengthNode}: replaced by the array's length (or by the batch thread count when
     * batching is configured), or by a kernel-context read when it bounds a parallel loop under the
     * grid scheduler.</li>
     * <li>{@link LoadFieldNode}: see {@link #evaluateLoadField}.</li>
     * <li>{@link IsNullNode}: replaced by a logic constant reflecting {@code value == null}.</li>
     * <li>{@link PiNode}: bypassed in favour of its original node.</li>
     * </ul>
     */
    private void evaluate(StructuredGraph graph, Node node, Object value) {
        if (node instanceof ArrayLengthNode arrayLength) {
            // Evaluated before the branch so a non-array value fails the same way on both paths.
            int length = Array.getLength(value);
            if (gridScheduling && isParameterInvolvedInParallelLoopBound(arrayLength)) {
                replaceWithKernelContextAccess(graph, arrayLength);
            } else {
                int foldedLength = batchThreads <= 0L ? length : (int) batchThreads;
                arrayLength.replaceAtUsages(graph.addOrUnique(ConstantNode.forInt(foldedLength)));
            }
            arrayLength.clearInputs();
            GraphUtil.removeFixedWithUnusedInputs(arrayLength);
        } else if (node instanceof LoadFieldNode loadField) {
            evaluateLoadField(graph, loadField, value);
        } else if (node instanceof IsNullNode isNullNode) {
            LogicConstantNode folded = value == null ? LogicConstantNode.tautology(graph) : LogicConstantNode.contradiction(graph);
            isNullNode.replaceAtUsages(folded);
            isNullNode.safeDelete();
        } else if (node instanceof PiNode piNode) {
            piNode.replaceAtUsages(piNode.getOriginalNode());
            piNode.safeDelete();
        }
    }

    /**
     * Folds a field load on the concrete object {@code value}. Primitive fields become constants
     * (size fields use the batch thread count when batching is configured). Final reference
     * fields are read and propagated into their usages.
     *
     * @throws TornadoBailoutRuntimeException for non-final reference fields, or when a primitive
     *                                        field cannot be read
     */
    private void evaluateLoadField(StructuredGraph graph, LoadFieldNode loadField, Object value) {
        ResolvedJavaField field = loadField.field();
        if (field.getType().getJavaKind().isPrimitive()) {
            ConstantNode constant;
            if (loadField.toString().contains(PANAMA_SIZE_FIELD) && batchThreads > 0L) {
                constant = ConstantNode.forInt((int) batchThreads);
            } else {
                constant = lookupPrimField(graph, loadField, value, field.getName(), field.getJavaKind());
            }
            if (constant == null) {
                // Previously surfaced as an NPE inside addOrUnique.
                throw new TornadoBailoutRuntimeException("Unable to read primitive field " + field.getName() + " introduced via scope");
            }
            constant = graph.addOrUnique(constant);
            loadField.replaceAtUsages(constant);
            loadField.clearInputs();
            graph.removeFixed(loadField);
        } else if (field.isFinal()) {
            Object fieldValue = lookupRefField(graph, loadField, value, field.getName());
            // Snapshot: evaluate() detaches nodes from this usage list while we iterate.
            for (Node usage : loadField.usages().snapshot()) {
                evaluate(graph, usage, fieldValue);
            }
        } else {
            throw new TornadoBailoutRuntimeException("Non-final objects introduced via scope are not supported");
        }
    }

    /**
     * Wraps a boxed primitive argument (or {@link HalfFloat}, stored as a 32-bit float) in a
     * graph-attached {@link ConstantNode}.
     *
     * @throws NullPointerException if {@code obj} is {@code null}
     */
    private ConstantNode createPrimitiveConstantFromObjectParameter(Object obj, StructuredGraph graph) {
        Objects.requireNonNull(obj);
        return switch (obj) {
            case Boolean objBoolean -> ConstantNode.forBoolean(objBoolean.booleanValue(), graph);
            case Byte objByte -> ConstantNode.forByte(objByte.byteValue(), graph);
            case Character objChar -> ConstantNode.forChar(objChar.charValue(), graph);
            case Short objShort -> ConstantNode.forShort(objShort.shortValue(), graph);
            case HalfFloat objHalfFloat -> ConstantNode.forFloat(objHalfFloat.getFloat32(), graph);
            case Integer objInteger -> ConstantNode.forInt(objInteger.intValue(), graph);
            case Float objFloat -> ConstantNode.forFloat(objFloat.floatValue(), graph);
            case Double objDouble -> ConstantNode.forDouble(objDouble.doubleValue(), graph);
            case Long objLong -> ConstantNode.forLong(objLong.longValue(), graph);
            default -> {
                TornadoInternalError.unimplemented("createPrimitiveConstantFromObjectParameter: %s", new Object[] {obj});
                yield null;
            }
        };
    }

    /** Returns whether any direct usage of {@code parameterNode} is a {@link ParallelRangeNode}. */
    private boolean isParameterInvolvedInParallelLoopBound(Node parameterNode) {
        for (Node usage : parameterNode.usages()) {
            if (usage instanceof ParallelRangeNode) {
                return true;
            }
        }
        return false;
    }

    /**
     * Specialises one parameter against its runtime argument: boxed primitives are folded to
     * constants (or routed through the kernel context when they bound a grid-scheduled parallel
     * loop); any other argument is propagated into the parameter's usages via {@link #evaluate}.
     */
    private void propagateParameters(StructuredGraph graph, ParameterNode parameterNode, Object[] args) {
        Object argument = args[parameterNode.index()];
        if (argument != null && RuntimeUtilities.isBoxedPrimitiveClass(argument.getClass())) {
            if (gridScheduling && isParameterInvolvedInParallelLoopBound(parameterNode)) {
                replaceWithKernelContextAccess(graph, parameterNode);
            } else {
                foldBoxedParameter(graph, parameterNode, argument);
            }
        } else {
            for (Node usage : parameterNode.usages().snapshot()) {
                evaluate(graph, usage, argument);
            }
        }
    }

    /**
     * Replaces a boxed-primitive parameter by a constant, then removes the PiNodes that now wrap
     * that constant together with the unboxing (and its null-check guard) they fed.
     */
    private void foldBoxedParameter(StructuredGraph graph, ParameterNode parameterNode, Object argument) {
        ConstantNode primitiveConstant = createPrimitiveConstantFromObjectParameter(argument, graph);
        parameterNode.replaceAtAllUsages(primitiveConstant, true);
        parameterNode.safeDelete();

        List<Node> pis = graph.getNodes().filter(n -> n instanceof PiNode pi && pi.object() == primitiveConstant).snapshot();
        for (Node pi : pis) {
            List<Node> piUsages = pi.usages().snapshot();
            pi.replaceAtAllUsages(primitiveConstant, true);
            pi.safeDelete();
            for (Node usage : piUsages) {
                if (usage instanceof UnboxNode unbox) {
                    removeUnboxOfConstant(graph, unbox, primitiveConstant);
                }
            }
        }
    }

    /** Removes {@code unbox} and, if directly preceded by one, its now-trivial null-check guard. */
    private static void removeUnboxOfConstant(StructuredGraph graph, UnboxNode unbox, ConstantNode primitiveConstant) {
        Node prev = unbox.predecessor();
        unbox.replaceAtAllUsages(primitiveConstant, true);
        graph.removeFixed(unbox);
        if (prev instanceof FixedGuardNode guard && guard.condition() instanceof IsNullNode isNull && isNull.getValue() == primitiveConstant) {
            guard.clearInputs();
            graph.removeFixed(guard);
        }
    }

    /**
     * Rewrites {@code phi < constant} comparisons on a parallel range whose value is a
     * {@link PhiNode}: a new {@link ParallelRangeNode} bounded by the constant replaces the old
     * range, and the comparison becomes {@code phi < newRange}.
     *
     * <p>NOTE(review): this iterates {@code range.usages()} live while deleting {@code range} and
     * {@code less}, and only re-wires the first usage of {@code less}. That is kept as-is from the
     * original, so confirm it is intended.
     */
    private static void swapConstantPhiInParallelRanges(StructuredGraph graph) {
        graph.getNodes().filter(ParallelRangeNode.class).forEach(range -> {
            if (!(range.value() instanceof PhiNode phiNode)) {
                return;
            }
            for (Node usage : range.usages()) {
                if (!(usage instanceof IntegerLessThanNode less)) {
                    continue;
                }
                ConstantNode constant = null;
                if (less.getX() instanceof ConstantNode x) {
                    constant = x;
                } else if (less.getY() instanceof ConstantNode y) {
                    constant = y;
                }
                if (constant == null) {
                    continue;
                }
                TornadoCoreRuntime.getDebugContext().dump(2, graph, "Before Swapping Constant-Phi");
                // Use the node addOrUnique returns: value numbering may hand back an existing equivalent.
                ParallelRangeNode pr = graph.addOrUnique(new ParallelRangeNode(range.index(), constant, range.offset(), range.stride()));
                range.safeDelete();
                IntegerLessThanNode intLess = graph.addOrUnique(new IntegerLessThanNode(phiNode, pr));
                less.usages().first().replaceAllInputs(less, intLess);
                less.safeDelete();
                TornadoCoreRuntime.getDebugContext().dump(2, graph, "After Swapping Constant-Phi");
            }
        });
    }

    /**
     * Runs up to {@link #MAX_ITERATIONS} specialisation rounds, stopping early once a round leaves
     * the node count unchanged, adds no nodes and leaves no size-field load behind. It then applies
     * the Constant-Phi swap to parallel ranges.
     */
    @Override
    protected void run(StructuredGraph graph, TornadoHighTierContext context) {
        batchThreads = context.getBatchCompilationConfig().getBatchThreads();
        gridScheduling = context.isGridSchedulerEnabled();
        // Reset up front too: a bailout in a previous run would otherwise leave a stale slot counter.
        index = 0;

        int iterations = 0;
        int lastNodeCount = graph.getNodeCount();
        boolean hasWork = true;
        while (hasWork && iterations < MAX_ITERATIONS) {
            Graph.Mark mark = graph.getMark();
            if (context.hasArgs()) {
                TornadoCoreRuntime.getDebugContext().dump(2, graph, "Before Phase Propagate Parameters");
                for (ParameterNode param : graph.getNodes(ParameterNode.TYPE).snapshot()) {
                    propagateParameters(graph, param, context.getArgs());
                }
                TornadoCoreRuntime.getDebugContext().dump(2, graph, "After Phase Propagate Parameters");
            }
            canonicalizer.apply(graph, context);
            // Drop PiNodes introduced this round that only narrow one object stamp to another.
            graph.getNewNodes(mark).filter(PiNode.class).forEach(pi -> {
                if (pi.stamp(NodeView.DEFAULT) instanceof ObjectStamp && pi.object().stamp(NodeView.DEFAULT) instanceof ObjectStamp) {
                    pi.replaceAtUsages(pi.object());
                    pi.clearInputs();
                    pi.safeDelete();
                }
            });
            TornadoCoreRuntime.getDebugContext().dump(2, graph, "After Phase Pi Node Removal");
            loopUnroll.execute(graph, (CoreProviders) context);
            valueTypeReplacement.execute(graph, context);
            canonicalizer.apply(graph, context);
            deadCodeElimination.run(graph);
            TornadoCoreRuntime.getDebugContext().dump(2, graph, "After TaskSpecialisation iteration = " + iterations);
            hasWork = lastNodeCount != graph.getNodeCount() || graph.getNewNodes(mark).isNotEmpty() || hasPanamaArraySizeNode(graph);
            lastNodeCount = graph.getNodeCount();
            iterations++;
        }

        swapConstantPhiInParallelRanges(graph);

        TornadoLogger logger = new TornadoLogger(getClass());
        // hasWork is still set only when the loop stopped because it hit the iteration cap.
        if (hasWork) {
            logger.warn("TaskSpecialisation unable to complete after %d iterations", new Object[] {iterations});
        }
        logger.debug("TaskSpecialisation ran %d iterations", new Object[] {iterations});
        logger.debug("valid graph? %s", new Object[] {graph.verify()});
        index = 0;
    }
}

