/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.api;

import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.TornadoExecutionResult;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.unittests.TestHello;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

/**
 * Tests for sharing device buffers between task graphs of a single {@link TornadoExecutionPlan}.
 *
 * <p>Each test builds two or more {@link TaskGraph}s. An earlier graph keeps one or more arrays on the
 * device with {@code persistOnDevice}. A later graph picks them up with {@code consumeFromDevice}, either
 * naming the producing graph or not. The last graph copies a result back to the host, and the test checks it.
 */
public class TestSharedBuffers extends TornadoTestBase {

    /** Number of elements in the arrays used by most tests in this class. */
    private static final int NUM_ELEMENTS = 16;

    // Transfer-mode arguments for transferToDevice/transferToHost. These presumably mirror TornadoVM's
    // DataTransferMode.FIRST_EXECUTION (0) and DataTransferMode.EVERY_EXECUTION (1). That class is not
    // imported in this file, so confirm against the API if those values ever change.
    private static final int FIRST_EXECUTION = 0;
    private static final int EVERY_EXECUTION = 1;

    /**
     * Kernel that performs no work on its arguments.
     *
     * <p>It only compares the sizes of {@code a} and {@code b}, and the comparison has no effect. It is used as
     * a placeholder task in a graph whose purpose is to move data onto the device.
     */
    public static void empty(IntArray a, IntArray b) {
        if (a.getSize() != b.getSize()) {
            // Intentionally empty. Presumably this exists only so that the kernel references both arguments.
        }
    }

    /** Kernel: sets the first {@code size} elements of {@code context} to 1. */
    public static void initializeContext(IntArray context, int size) {
        for (int i = 0; i < size; ++i) {
            context.set(i, 1);
        }
    }

    /** Kernel: rewrites element 0 of {@code output} with its own value, leaving the contents unchanged. */
    public static void forcePropagate(IntArray output) {
        output.set(0, output.get(0));
    }

    /** Kernel: increments each of the first {@code size} elements of {@code context} by 1. */
    public static void updateContext(IntArray context, int size) {
        for (int i = 0; i < size; ++i) {
            context.set(i, context.get(i) + 1);
        }
    }

    /** Kernel: doubles each of the first {@code size} elements of {@code context}. */
    public static void prepareOutput(IntArray context, int size) {
        for (int i = 0; i < size; ++i) {
            context.set(i, context.get(i) * 2);
        }
    }

    /** Kernel: adds 10 to each of the first {@code size} elements of {@code context}. */
    public static void finalizeContext(IntArray context, int size) {
        for (int i = 0; i < size; ++i) {
            context.set(i, context.get(i) + 10);
        }
    }

    /** Kernel: {@code output[i] = input[i] + context[i]} for the first {@code size} elements. */
    public static void processBuffer(IntArray input, IntArray output, IntArray context, int size) {
        for (int i = 0; i < size; ++i) {
            output.set(i, input.get(i) + context.get(i));
        }
    }

    /**
     * Asserts that every element of {@code array} equals {@code expected}.
     *
     * @param expected value each element must have
     * @param array    array to check; its whole length is inspected
     */
    private static void assertAllEqual(int expected, IntArray array) {
        for (int i = 0; i < array.getSize(); i++) {
            Assert.assertEquals("Unexpected value at index " + i, expected, array.get(i));
        }
    }

    /**
     * Returns whether any element of {@code array} is non-zero.
     *
     * @param array array to scan; its whole length is inspected
     * @return {@code true} as soon as a non-zero element is found, otherwise {@code false}
     */
    private static boolean containsNonZero(IntArray array) {
        for (int i = 0; i < array.getSize(); i++) {
            if (array.get(i) != 0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Graph s0 computes {@code c} and persists it. Graph s1 consumes {@code c} by graph name and updates it in
     * place with {@code add(c, c, c)}. The expected result (60) corresponds to 2 * (10 + 20).
     */
    @Test
    public void testSingleReadWriteSharedObject() throws TornadoExecutionPlanException {
        IntArray a = new IntArray(NUM_ELEMENTS);
        IntArray b = new IntArray(NUM_ELEMENTS);
        IntArray c = new IntArray(NUM_ELEMENTS);
        a.init(10);
        b.init(20);

        TaskGraph tg1 = new TaskGraph("s0")
                .transferToDevice(FIRST_EXECUTION, a, b)
                .task("t0", TestHello::add, a, b, c)
                .persistOnDevice(c);

        TaskGraph tg2 = new TaskGraph("s1")
                .consumeFromDevice(tg1.getTaskGraphName(), c)
                .task("t1", TestHello::add, c, c, c)
                .transferToHost(EVERY_EXECUTION, c);

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(tg1.snapshot(), tg2.snapshot())) {
            executionPlan.withGraph(0).execute();
            executionPlan.withGraph(1).execute();
            assertAllEqual(60, c);
        }
    }

    /**
     * Graph s1 mixes a consumed buffer ({@code c}, produced by s0) with a freshly copied input ({@code d}).
     * The expected result (80) corresponds to (10 + 20) + 50.
     */
    @Test
    public void testMixInputConsumeAndCopy() throws TornadoExecutionPlanException {
        IntArray a = new IntArray(NUM_ELEMENTS);
        IntArray b = new IntArray(NUM_ELEMENTS);
        IntArray c = new IntArray(NUM_ELEMENTS);
        IntArray d = new IntArray(NUM_ELEMENTS);
        a.init(10);
        b.init(20);
        d.init(50);

        TaskGraph tg1 = new TaskGraph("s0")
                .transferToDevice(FIRST_EXECUTION, a, b)
                .task("t0", TestHello::add, a, b, c)
                .persistOnDevice(c);

        TaskGraph tg2 = new TaskGraph("s1")
                .consumeFromDevice(tg1.getTaskGraphName(), c)
                .transferToDevice(FIRST_EXECUTION, d)
                .task("t1", TestHello::add, c, d, c)
                .transferToHost(EVERY_EXECUTION, c);

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(tg1.snapshot(), tg2.snapshot())) {
            executionPlan.withGraph(0).execute();
            executionPlan.withGraph(1).execute();
            assertAllEqual(80, c);
        }
    }

    /**
     * The first graph only copies data in, running the no-op {@link #empty} kernel, and persists {@code input}.
     * The second graph consumes {@code input} and adds the re-copied {@code intermediateValues}.
     * The expected result (30) corresponds to 25 + 5.
     */
    @Test
    public void testForcedCopyInData() throws TornadoExecutionPlanException {
        IntArray input = new IntArray(NUM_ELEMENTS);
        IntArray intermediateValues = new IntArray(NUM_ELEMENTS);
        IntArray output = new IntArray(NUM_ELEMENTS);
        input.init(25);
        intermediateValues.init(5);

        TaskGraph forceCopyGraph = new TaskGraph("forceCopyGraph")
                .transferToDevice(EVERY_EXECUTION, input, intermediateValues)
                .task("emptyTask", TestSharedBuffers::empty, input, intermediateValues)
                .persistOnDevice(input);

        TaskGraph computeGraph = new TaskGraph("computeGraph")
                .consumeFromDevice(forceCopyGraph.getTaskGraphName(), input)
                .transferToDevice(EVERY_EXECUTION, intermediateValues)
                .task("addTask", TestHello::add, input, intermediateValues, output)
                .transferToHost(EVERY_EXECUTION, output);

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(forceCopyGraph.snapshot(), computeGraph.snapshot())) {
            executionPlan.withGraph(0).execute();
            executionPlan.withGraph(1).execute();
            assertAllEqual(30, output);
        }
    }

    /**
     * Two input buffers ({@code a}, {@code b}) are persisted by s0 and consumed by name in s1.
     * The expected result (30) corresponds to 10 + 20.
     */
    @Test
    public void testMultipleSharedObjects() throws TornadoExecutionPlanException {
        IntArray a = new IntArray(NUM_ELEMENTS);
        IntArray b = new IntArray(NUM_ELEMENTS);
        IntArray c = new IntArray(NUM_ELEMENTS);
        IntArray d = new IntArray(NUM_ELEMENTS);
        a.init(10);
        b.init(20);

        TaskGraph tg1 = new TaskGraph("s0")
                .transferToDevice(FIRST_EXECUTION, a, b)
                .task("t0", TestHello::add, a, b, c)
                .persistOnDevice(a, b);

        TaskGraph tg2 = new TaskGraph("s1")
                .consumeFromDevice(tg1.getTaskGraphName(), a, b)
                .task("t1", TestHello::add, a, b, d)
                .transferToHost(EVERY_EXECUTION, d);

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(tg1.snapshot(), tg2.snapshot())) {
            executionPlan.withGraph(0).execute();
            executionPlan.withGraph(1).execute();
            assertAllEqual(30, d);
        }
    }

    /**
     * Same as {@link #testMultipleSharedObjects()}, but s1 calls {@code consumeFromDevice} without naming
     * the producing graph.
     */
    @Test
    public void testMultipleSharedObjectsEmptyConsume() throws TornadoExecutionPlanException {
        IntArray a = new IntArray(NUM_ELEMENTS);
        IntArray b = new IntArray(NUM_ELEMENTS);
        IntArray c = new IntArray(NUM_ELEMENTS);
        IntArray d = new IntArray(NUM_ELEMENTS);
        a.init(10);
        b.init(20);

        TaskGraph tg1 = new TaskGraph("s0")
                .transferToDevice(FIRST_EXECUTION, a, b)
                .task("t0", TestHello::add, a, b, c)
                .persistOnDevice(a, b);

        TaskGraph tg2 = new TaskGraph("s1")
                .consumeFromDevice(a, b)
                .task("t1", TestHello::add, a, b, d)
                .transferToHost(EVERY_EXECUTION, d);

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(tg1.snapshot(), tg2.snapshot())) {
            executionPlan.withGraph(0).execute();
            executionPlan.withGraph(1).execute();
            assertAllEqual(30, d);
        }
    }

    /**
     * Chain of three graphs, each consuming by name the buffer persisted by the previous one:
     * c = 10 + 20 = 30, r = c + 5 = 35, r = r + r = 70.
     */
    @Test
    public void testThreeTaskGraphsWithSharedBuffers() throws TornadoExecutionPlanException {
        IntArray a = new IntArray(NUM_ELEMENTS);
        IntArray b = new IntArray(NUM_ELEMENTS);
        IntArray c = new IntArray(NUM_ELEMENTS);
        IntArray d = new IntArray(NUM_ELEMENTS);
        IntArray r = new IntArray(NUM_ELEMENTS);
        a.init(10);
        b.init(20);
        d.init(5);

        TaskGraph tg1 = new TaskGraph("s0")
                .transferToDevice(FIRST_EXECUTION, a, b)
                .task("t0", TestHello::add, a, b, c)
                .persistOnDevice(c);

        TaskGraph tg2 = new TaskGraph("s1")
                .consumeFromDevice("s0", c)
                .transferToDevice(FIRST_EXECUTION, d)
                .task("t1", TestHello::add, c, d, r)
                .persistOnDevice(r);

        TaskGraph tg3 = new TaskGraph("s2")
                .consumeFromDevice("s1", r)
                .task("t1", TestHello::add, r, r, r)
                .transferToHost(EVERY_EXECUTION, r);

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(tg1.snapshot(), tg2.snapshot(), tg3.snapshot())) {
            executionPlan.withGraph(0).execute();
            executionPlan.withGraph(1).execute();
            executionPlan.withGraph(2).execute();
            assertAllEqual(70, r);
        }
    }

    /**
     * Same chain as {@link #testThreeTaskGraphsWithSharedBuffers()}, but every {@code consumeFromDevice} call
     * omits the producing graph's name.
     */
    @Test
    public void testThreeTaskGraphsWithSharedBuffersEmptyConsume() throws TornadoExecutionPlanException {
        IntArray a = new IntArray(NUM_ELEMENTS);
        IntArray b = new IntArray(NUM_ELEMENTS);
        IntArray c = new IntArray(NUM_ELEMENTS);
        IntArray d = new IntArray(NUM_ELEMENTS);
        IntArray r = new IntArray(NUM_ELEMENTS);
        a.init(10);
        b.init(20);
        d.init(5);

        TaskGraph tg1 = new TaskGraph("s0")
                .transferToDevice(FIRST_EXECUTION, a, b)
                .task("t0", TestHello::add, a, b, c)
                .persistOnDevice(c);

        TaskGraph tg2 = new TaskGraph("s1")
                .consumeFromDevice(c)
                .transferToDevice(FIRST_EXECUTION, d)
                .task("t1", TestHello::add, c, d, r)
                .persistOnDevice(r);

        TaskGraph tg3 = new TaskGraph("s2")
                .consumeFromDevice(r)
                .task("t1", TestHello::add, r, r, r)
                .transferToHost(EVERY_EXECUTION, r);

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(tg1.snapshot(), tg2.snapshot(), tg3.snapshot())) {
            executionPlan.withGraph(0).execute();
            executionPlan.withGraph(1).execute();
            executionPlan.withGraph(2).execute();
            assertAllEqual(70, r);
        }
    }

    /**
     * Four graphs pass an intermediate buffer and a shared {@code contextBuffer} down the chain.
     * Each graph mutates the context before using it, and only the last graph copies results to the host.
     */
    @Test
    public void testFourTaskGraphsWithPersistentBuffers() throws TornadoExecutionPlanException {
        final int size = 10;
        IntArray inputBuffer = new IntArray(size);
        IntArray intermediateBuffer1 = new IntArray(size);
        IntArray intermediateBuffer2 = new IntArray(size);
        IntArray outputBuffer = new IntArray(size);
        IntArray contextBuffer = new IntArray(size);
        inputBuffer.init(10);
        contextBuffer.init(2);

        TaskGraph firstGraph = new TaskGraph("firstProcessing")
                .transferToDevice(FIRST_EXECUTION, inputBuffer, contextBuffer)
                .task("initializeContext", TestSharedBuffers::initializeContext, contextBuffer, size)
                .task("processInitial", TestSharedBuffers::processBuffer, inputBuffer, intermediateBuffer1, contextBuffer, size)
                .persistOnDevice(intermediateBuffer1, contextBuffer);

        TaskGraph secondGraph = new TaskGraph("intermediateProcessing")
                .consumeFromDevice(firstGraph.getTaskGraphName(), intermediateBuffer1, contextBuffer)
                .task("updateContext", TestSharedBuffers::updateContext, contextBuffer, size)
                .task("processIntermediate", TestSharedBuffers::processBuffer, intermediateBuffer1, intermediateBuffer2, contextBuffer, size)
                .persistOnDevice(intermediateBuffer2, contextBuffer);

        TaskGraph thirdGraph = new TaskGraph("preFinalProcessing")
                .consumeFromDevice(secondGraph.getTaskGraphName(), intermediateBuffer2, contextBuffer)
                .task("prepareOutput", TestSharedBuffers::prepareOutput, contextBuffer, size)
                .task("processPreFinal", TestSharedBuffers::processBuffer, intermediateBuffer2, outputBuffer, contextBuffer, size)
                .persistOnDevice(outputBuffer, contextBuffer);

        TaskGraph fourthGraph = new TaskGraph("finalTransfer")
                .consumeFromDevice(thirdGraph.getTaskGraphName(), outputBuffer, contextBuffer)
                .task("finalizeContext", TestSharedBuffers::finalizeContext, contextBuffer, size)
                .task("empty", TestSharedBuffers::forcePropagate, outputBuffer)
                .transferToHost(EVERY_EXECUTION, outputBuffer, contextBuffer);

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(firstGraph.snapshot(), secondGraph.snapshot(), thirdGraph.snapshot(),
                fourthGraph.snapshot())) {
            for (int graphIndex = 0; graphIndex < 4; graphIndex++) {
                executionPlan.withGraph(graphIndex).execute();
            }
            // NOTE(review): this only checks that some output is non-zero. Following the kernels above, the output
            // would be 17 everywhere (10+1 -> +2 -> +4) and the context 14, if every update propagates between
            // graphs. Consider tightening the assertion.
            Assert.assertTrue("Output array should have non-zero values", containsNonZero(outputBuffer));
        }
    }

    /**
     * Three graphs share {@code c} and {@code sharedContext}, and the whole plan is executed five times in a row.
     */
    @Test
    public void testThreeTaskGraphsWithSharedContextBuffer() throws TornadoExecutionPlanException {
        IntArray a = new IntArray(NUM_ELEMENTS);
        IntArray b = new IntArray(NUM_ELEMENTS);
        IntArray c = new IntArray(NUM_ELEMENTS);
        IntArray sharedContext = new IntArray(NUM_ELEMENTS);
        a.init(10);
        b.init(20);
        sharedContext.init(1);

        TaskGraph tg1 = new TaskGraph("s0")
                .transferToDevice(FIRST_EXECUTION, a, b)
                .task("t0", TestHello::add, a, b, c)
                .persistOnDevice(c);

        // NOTE(review): sharedContext is never passed to transferToDevice, so whether its host-side init(1)
        // reaches the device depends on the runtime's handling of untransferred arguments. Confirm this is
        // intended.
        TaskGraph tg2 = new TaskGraph("s1")
                .consumeFromDevice("s0", c)
                .task("updateContext", TestSharedBuffers::updateContext, sharedContext, NUM_ELEMENTS)
                .task("addWithContext", TestSharedBuffers::processBuffer, c, c, sharedContext, NUM_ELEMENTS)
                .persistOnDevice(c, sharedContext);

        TaskGraph tg3 = new TaskGraph("s2")
                .consumeFromDevice("s1", c, sharedContext)
                .task("finalizeContext", TestSharedBuffers::finalizeContext, sharedContext, NUM_ELEMENTS)
                .task("processWithFinalContext", TestSharedBuffers::processBuffer, c, c, sharedContext, NUM_ELEMENTS)
                .task("forcePropagate", TestSharedBuffers::forcePropagate, c)
                .transferToHost(EVERY_EXECUTION, c, sharedContext);

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(tg1.snapshot(), tg2.snapshot(), tg3.snapshot())) {
            for (int iteration = 0; iteration < 5; iteration++) {
                executionPlan.withGraph(0).execute();
                executionPlan.withGraph(1).execute();
                executionPlan.withGraph(2).execute();
            }
            // The return value is discarded; the call is kept as in the original test.
            executionPlan.getTraceExecutionPlan();
            Assert.assertTrue("Output array should have non-zero values", containsNonZero(c));
        }
    }
}

