/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.foundation;

import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;
import uk.ac.manchester.tornado.unittests.foundation.TestKernels;

/**
 * Element-wise integer kernels from {@link TestKernels}. Each test is run as a single-task
 * {@link TaskGraph} and its output array is compared against a host-computed expected result.
 */
public class TestLinearAlgebra extends TornadoTestBase {

    /*
     * Transfer-mode codes passed to transferToDevice / transferToHost, kept as the literal values
     * found in the compiled test (0 for the device-bound inputs, 1 for the host-bound output).
     * NOTE(review): these presumably correspond to TornadoVM's DataTransferMode.FIRST_EXECUTION and
     * DataTransferMode.EVERY_EXECUTION - confirm and switch to the named constants.
     */
    private static final int INPUT_TRANSFER_MODE = 0;
    private static final int OUTPUT_TRANSFER_MODE = 1;

    /**
     * Snapshots {@code taskGraph}, executes it once through a {@link TornadoExecutionPlan}, and
     * closes the plan afterwards via try-with-resources.
     *
     * @param taskGraph the graph to run
     * @throws TornadoExecutionPlanException if creating, running or closing the plan fails
     */
    private static void executeOnce(TaskGraph taskGraph) throws TornadoExecutionPlanException {
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }
    }

    /**
     * Asserts that {@code actual} has the same size as {@code expected} and matches it element by
     * element, naming the first mismatching index in the failure message.
     *
     * @param expected the host-computed reference values
     * @param actual   the array produced by the kernel
     */
    private static void assertArrayContents(IntArray expected, IntArray actual) {
        Assert.assertEquals("array size", expected.getSize(), actual.getSize());
        for (int i = 0; i < expected.getSize(); i++) {
            Assert.assertEquals("mismatch at index " + i, expected.get(i), actual.get(i));
        }
    }

    /** Runs {@code vectorAddCompute} with b = 100 and c = 200 everywhere; expects 300 in every slot of a. */
    @Test
    public void vectorAdd() throws TornadoExecutionPlanException {
        final int numElements = 256;
        IntArray a = new IntArray(numElements);
        IntArray b = new IntArray(numElements);
        IntArray c = new IntArray(numElements);
        b.init(100);
        c.init(200);
        IntArray expectedResult = new IntArray(numElements);
        expectedResult.init(300);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(INPUT_TRANSFER_MODE, b, c)
                .task("t0", TestKernels::vectorAddCompute, a, b, c)
                .transferToHost(OUTPUT_TRANSFER_MODE, a);
        executeOnce(taskGraph);

        assertArrayContents(expectedResult, a);
    }

    /** Runs {@code vectorMul} with b = 100 and c = 5 everywhere; expects 500 in every slot of a. */
    @Test
    public void vectorMul() throws TornadoExecutionPlanException {
        final int numElements = 256;
        IntArray a = new IntArray(numElements);
        IntArray b = new IntArray(numElements);
        IntArray c = new IntArray(numElements);
        b.init(100);
        c.init(5);
        IntArray expectedResult = new IntArray(numElements);
        expectedResult.init(500);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(INPUT_TRANSFER_MODE, b, c)
                .task("t0", TestKernels::vectorMul, a, b, c)
                .transferToHost(OUTPUT_TRANSFER_MODE, a);
        executeOnce(taskGraph);

        assertArrayContents(expectedResult, a);
    }

    /** Runs {@code vectorSub} with b = 100 and c = 75 everywhere; expects 25 in every slot of a. */
    @Test
    public void vectorSub() throws TornadoExecutionPlanException {
        final int numElements = 256;
        IntArray a = new IntArray(numElements);
        IntArray b = new IntArray(numElements);
        IntArray c = new IntArray(numElements);
        b.init(100);
        c.init(75);
        IntArray expectedResult = new IntArray(numElements);
        expectedResult.init(25);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(INPUT_TRANSFER_MODE, b, c)
                .task("t0", TestKernels::vectorSub, a, b, c)
                .transferToHost(OUTPUT_TRANSFER_MODE, a);
        executeOnce(taskGraph);

        assertArrayContents(expectedResult, a);
    }

    /** Runs {@code vectorDiv} with b = 512 and c = 2 everywhere; expects 256 in every slot of a. */
    @Test
    public void vectorDiv() throws TornadoExecutionPlanException {
        final int numElements = 256;
        IntArray a = new IntArray(numElements);
        IntArray b = new IntArray(numElements);
        IntArray c = new IntArray(numElements);
        b.init(512);
        c.init(2);
        IntArray expectedResult = new IntArray(numElements);
        expectedResult.init(256);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(INPUT_TRANSFER_MODE, b, c)
                .task("t0", TestKernels::vectorDiv, a, b, c)
                .transferToHost(OUTPUT_TRANSFER_MODE, a);
        executeOnce(taskGraph);

        assertArrayContents(expectedResult, a);
    }

    /** Runs {@code vectorSquare} with b[i] = i; expects a[i] == i * i. */
    @Test
    public void square() throws TornadoExecutionPlanException {
        final int numElements = 32;
        IntArray a = new IntArray(numElements);
        IntArray b = new IntArray(numElements);
        IntArray expectedResult = new IntArray(numElements);
        for (int i = 0; i < numElements; i++) {
            b.set(i, i);
            expectedResult.set(i, i * i);
        }

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(INPUT_TRANSFER_MODE, b)
                .task("t0", TestKernels::vectorSquare, a, b)
                .transferToHost(OUTPUT_TRANSFER_MODE, a);
        executeOnce(taskGraph);

        assertArrayContents(expectedResult, a);
    }

    /**
     * Runs {@code saxpy} with b[i] = c[i] = i and alpha = 2; expects a[i] == alpha * i + i. Because
     * b and c hold the same values, this input does not distinguish which operand alpha scales.
     */
    @Test
    public void saxpy() throws TornadoExecutionPlanException {
        final int numElements = 512;
        final int alpha = 2;
        IntArray a = new IntArray(numElements);
        IntArray b = new IntArray(numElements);
        IntArray c = new IntArray(numElements);
        IntArray expectedResult = new IntArray(numElements);
        for (int i = 0; i < numElements; i++) {
            b.set(i, i);
            c.set(i, i);
            expectedResult.set(i, alpha * i + i);
        }

        // The same alpha drives both the kernel argument and the expected values, so they cannot drift apart.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(INPUT_TRANSFER_MODE, b, c)
                .task("t0", TestKernels::saxpy, a, b, c, alpha)
                .transferToHost(OUTPUT_TRANSFER_MODE, a);
        executeOnce(taskGraph);

        assertArrayContents(expectedResult, a);
    }
}

