/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.drivers.opencl.graal.phases;

import java.util.ArrayList;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import jdk.vm.ci.meta.Constant;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.RawConstant;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.FrameState;
import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.ValuePhiNode;
import org.graalvm.compiler.nodes.calc.AddNode;
import org.graalvm.compiler.nodes.extended.JavaReadNode;
import org.graalvm.compiler.nodes.extended.JavaWriteNode;
import org.graalvm.compiler.nodes.java.LoadIndexedNode;
import org.graalvm.compiler.nodes.memory.address.OffsetAddressNode;
import org.graalvm.compiler.phases.BasePhase;
import uk.ac.manchester.tornado.runtime.common.BatchCompilationConfig;
import uk.ac.manchester.tornado.runtime.graal.phases.TornadoHighTierContext;

public class TornadoBatchGlobalIndexOffset
extends BasePhase<TornadoHighTierContext> {
    private long batchSize;
    private int batchNumber;

    public Optional<BasePhase.NotApplicable> notApplicableTo(GraphState graphState) {
        return ALWAYS_APPLICABLE;
    }

    protected void run(StructuredGraph graph, TornadoHighTierContext context) {
        BatchCompilationConfig batchCompilationConfig = context.getBatchCompilationConfig();
        this.batchSize = batchCompilationConfig.getBatchSize();
        if (this.batchSize == 0L) {
            return;
        }
        this.batchNumber = batchCompilationConfig.getBatchNumber();
        for (ValuePhiNode phiNode : graph.getNodes().filter(ValuePhiNode.class)) {
            ArrayList<ValueNode> indexUsages = new ArrayList<ValueNode>();
            for (Node phiNodeUsage : phiNode.usages()) {
                if (!TornadoBatchGlobalIndexOffset.isIndexUsedInJavaWrite(phiNodeUsage)) continue;
                indexUsages.add((ValueNode)phiNodeUsage);
            }
            for (ValueNode phiIndexUsage : indexUsages) {
                RawConstant batchNumberConstant = new RawConstant((long)this.batchNumber * this.batchSize);
                ConstantNode batchNumberNode = new ConstantNode((Constant)batchNumberConstant, StampFactory.forKind((JavaKind)JavaKind.Int));
                graph.addWithoutUnique((Node)batchNumberNode);
                AddNode addOffsets = new AddNode((ValueNode)batchNumberNode, (ValueNode)phiNode);
                graph.addWithoutUnique((Node)addOffsets);
                phiIndexUsage.replaceFirstInput((Node)phiNode, (Node)addOffsets);
            }
        }
    }

    private static boolean isIndexUsedInJavaWrite(Node indexUsage) {
        if (indexUsage instanceof OffsetAddressNode || indexUsage instanceof FrameState || indexUsage instanceof LoadIndexedNode || indexUsage instanceof JavaReadNode) {
            return false;
        }
        if (indexUsage instanceof JavaWriteNode) {
            return true;
        }
        Iterator iterator = indexUsage.usages().iterator();
        if (iterator.hasNext()) {
            Node usage = (Node)iterator.next();
            return TornadoBatchGlobalIndexOffset.isIndexUsedInJavaWrite(usage);
        }
        return false;
    }
}

