/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../../index';
import { ALL_ENVS, describeWithFlags } from '../../jasmine_util';
import { expectArraysClose } from '../../test_util';
/**
 * Tests for `tf.fused.matMul`, run via `describeWithFlags` with `ALL_ENVS`.
 *
 * Covers:
 *  - the forward result with no activation and with each activation used
 *    here ('relu', 'elu', 'relu6', 'prelu', 'leakyrelu', 'linear');
 *  - an optional bias, either full-shape or broadcast;
 *  - transposed operands;
 *  - gradients of the fused op, compared against the equivalent unfused
 *    composition of `tf.matMul`, `tf.add` and the activation.
 *
 * Most forward tests share A = [[1, 2, 3], [4, 5, 6]] and
 * B = [[0, 1], [-3, 2], [2, 1]], for which A x B = [[0, 8], [-3, 20]].
 */
describeWithFlags('fused matmul', ALL_ENVS, () => {
    it('fused A x B', async () => {
        const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
        // No bias and no activation: plain matrix product.
        const c = tf.fused.matMul({ a, b });
        expect(c.shape).toEqual([2, 2]);
        expectArraysClose(await c.data(), [0, 8, -3, 20]);
    });
    it('fused A x B with relu', async () => {
        const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
        const transposeA = false;
        const transposeB = false;
        const c = tf.fused.matMul({ a, b, transposeA, transposeB, bias: null, activation: 'relu' });
        expect(c.shape).toEqual([2, 2]);
        // relu clamps the only negative entry (-3) to 0.
        expectArraysClose(await c.data(), [0, 8, 0, 20]);
    });
    it('fused A x B with elu', async () => {
        const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
        const transposeA = false;
        const transposeB = false;
        const c = tf.fused.matMul({ a, b, transposeA, transposeB, bias: null, activation: 'elu' });
        expect(c.shape).toEqual([2, 2]);
        // elu(-3) = exp(-3) - 1 ~= -0.9502; non-negative entries pass through.
        expectArraysClose(await c.data(), [0, 8, -0.9502, 20]);
    });
    it('fused A x B with relu6', async () => {
        const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
        const transposeA = false;
        const transposeB = false;
        const c = tf.fused.matMul({ a, b, transposeA, transposeB, bias: null, activation: 'relu6' });
        expect(c.shape).toEqual([2, 2]);
        // relu6 clamps to [0, 6]: -3 -> 0, and both 8 and 20 -> 6.
        expectArraysClose(await c.data(), [0, 6, 0, 6]);
    });
    it('fused A x B with prelu', async () => {
        const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
        const alpha = tf.tensor2d([0.5, 0.5], [1, 2]);
        const transposeA = false;
        const transposeB = false;
        const c = tf.fused.matMul({
            a,
            b,
            transposeA,
            transposeB,
            bias: null,
            activation: 'prelu',
            preluActivationWeights: alpha
        });
        expect(c.shape).toEqual([2, 2]);
        // prelu scales the negative entry by its alpha: -3 * 0.5 = -1.5.
        expectArraysClose(await c.data(), [0, 8, -1.5, 20]);
    });
    it('fused A x B with leakyrelu', async () => {
        const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
        const alpha = 0.3;
        const transposeA = false;
        const transposeB = false;
        const c = tf.fused.matMul({
            a,
            b,
            transposeA,
            transposeB,
            bias: null,
            activation: 'leakyrelu',
            leakyreluAlpha: alpha
        });
        expect(c.shape).toEqual([2, 2]);
        // -3 * 0.3, as produced in float32 (hence the trailing digits).
        expectArraysClose(await c.data(), [0, 8, -0.9000000357627869, 20]);
    });
    it('fused A x B with relu transpose', async () => {
        const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        // b is declared [2, 3] and used transposed, so the product is
        // A x B^T = [[-7, 9], [-13, 24]]; relu zeroes the negatives.
        const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [2, 3]);
        const transposeA = false;
        const transposeB = true;
        const c = tf.fused.matMul({ a, b, transposeA, transposeB, bias: null, activation: 'relu' });
        expect(c.shape).toEqual([2, 2]);
        expectArraysClose(await c.data(), [0, 9, 0, 24]);
    });
    it('fused A x B with 2d bias and relu', async () => {
        const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
        const c = tf.tensor2d([1, 1, 1, 1], [2, 2]);
        const transposeA = false;
        const transposeB = false;
        const d = tf.fused.matMul({ a, b, transposeA, transposeB, bias: c, activation: 'relu' });
        expect(d.shape).toEqual([2, 2]);
        // Adding 1 everywhere gives [[1, 9], [-2, 21]]; relu zeroes the -2.
        expectArraysClose(await d.data(), [1, 9, 0, 21]);
    });
    it('fused A x B with relu and broadcasted bias', async () => {
        const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
        // 1d bias broadcast across rows; same expectation as the 2d-bias case.
        const c = tf.tensor1d([1, 1]);
        const act = 'relu';
        const transposeA = false;
        const transposeB = false;
        const d = tf.fused.matMul({ a, b, transposeA, transposeB, bias: c, activation: act });
        expect(d.shape).toEqual([2, 2]);
        expectArraysClose(await d.data(), [1, 9, 0, 21]);
    });
    it('fused A x B with elu and broadcasted bias', async () => {
        const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
        const c = tf.tensor1d([1, 1]);
        const act = 'elu';
        const transposeA = false;
        const transposeB = false;
        const d = tf.fused.matMul({ a, b, transposeA, transposeB, bias: c, activation: act });
        expect(d.shape).toEqual([2, 2]);
        // -3 + 1 = -2 and elu(-2) = exp(-2) - 1 ~= -0.8647.
        expectArraysClose(await d.data(), [1, 9, -0.8647, 21]);
    });
    it('fused A x B with relu and broadcasted bias different rank', async () => {
        const a = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [2, 2, 3]);
        const b = tf.tensor3d([0, 1, -3, 2, 2, 1, 0, 1, -3, 2, 2, 1], [2, 3, 2]);
        // Rank-2 bias on a rank-3 (batched) product.
        const c = tf.tensor2d([1, 2], [1, 2]);
        const act = 'relu';
        const transposeA = false;
        const transposeB = false;
        const d = tf.fused.matMul({ a, b, transposeA, transposeB, bias: c, activation: act });
        expect(d.shape).toEqual([2, 2, 2]);
        // Per-batch products are [[1, 4], [-2, 16]] and [[-5, 28], [-8, 40]].
        // Adding [1, 2] to every row, then relu, gives the values below.
        expectArraysClose(await d.data(), [2, 6, 0, 18, 0, 30, 0, 42]);
    });
    it('fused A x B with 2d bias only', async () => {
        const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
        const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
        const c = tf.tensor2d([1, 1, 1, 1], [2, 2]);
        const transposeA = false;
        const transposeB = false;
        const d = tf.fused.matMul({ a, b, transposeA, transposeB, bias: c, activation: 'linear' });
        expect(d.shape).toEqual([2, 2]);
        // 'linear' applies no activation, so the negative entry (-2) survives.
        expectArraysClose(await d.data(), [1, 9, -2, 21]);
    });
    it('fused A x B with relu gradient', async () => {
        const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]);
        const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
        const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
        const transposeA = false;
        const transposeB = false;
        // Reference: unfused matMul followed by relu.
        const grads = tf.grads((a, b) => {
            const prod = tf.matMul(a, b, transposeA, transposeB);
            return tf.relu(prod);
        });
        const fusedGrads = tf.grads((a, b) => {
            return tf.fused.matMul({ a, b, transposeA, transposeB, bias: null, activation: 'relu' });
        });
        const [da, db] = grads([a, b], dy);
        const [fusedDa, fusedDb] = fusedGrads([a, b], dy);
        // Read both sides with .array() so each comparison is between nested
        // arrays of the same form (previously db was read flat via .data()).
        expectArraysClose(await da.array(), await fusedDa.array());
        expectArraysClose(await db.array(), await fusedDb.array());
    });
    it('gradient with clones A x B with relu', () => {
        const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]);
        const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
        const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
        const transposeA = false;
        const transposeB = false;
        // Inputs and output are cloned inside the differentiated function;
        // only the gradient shapes are checked here.
        const fusedGrads = tf.grads((a, b) => {
            return tf.fused
                .matMul({
                a: a.clone(),
                b: b.clone(),
                transposeA,
                transposeB,
                bias: null,
                activation: 'relu'
            })
                .clone();
        });
        const [fusedDa, fusedDb] = fusedGrads([a, b], dy);
        expect(fusedDa.shape).toEqual(a.shape);
        expect(fusedDb.shape).toEqual(b.shape);
    });
    it('fused A x B with relu bias gradient', async () => {
        const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]);
        const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
        const c = tf.tensor2d([1, 1, 1, 1], [2, 2]);
        const transposeA = false;
        const transposeB = false;
        const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
        // Reference: unfused matMul, add bias, relu.
        const grads = tf.grads((a, b, c) => {
            const prod = tf.matMul(a, b, transposeA, transposeB);
            const sum = tf.add(prod, c);
            return tf.relu(sum);
        });
        const fusedGrads = tf.grads((a, b, c) => {
            return tf.fused.matMul({ a, b, transposeA, transposeB, bias: c, activation: 'relu' });
        });
        const [da, db, dc] = grads([a, b, c], dy);
        const [fusedDa, fusedDb, fusedDc] = fusedGrads([a, b, c], dy);
        expectArraysClose(await da.array(), await fusedDa.array());
        expectArraysClose(await db.array(), await fusedDb.array());
        expectArraysClose(await dc.array(), await fusedDc.array());
    });
    it('fused A x B with relu bias gradient transpose', async () => {
        // a is declared [3, 2] and used transposed (transposeA = true).
        const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [3, 2]);
        const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
        const c = tf.tensor2d([1, 1, 1, 1], [2, 2]);
        const transposeA = true;
        const transposeB = false;
        const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
        const grads = tf.grads((a, b, c) => {
            const prod = tf.matMul(a, b, transposeA, transposeB);
            const sum = tf.add(prod, c);
            return tf.relu(sum);
        });
        const fusedGrads = tf.grads((a, b, c) => {
            return tf.fused.matMul({ a, b, transposeA, transposeB, bias: c, activation: 'relu' });
        });
        const [da, db, dc] = grads([a, b, c], dy);
        const [fusedDa, fusedDb, fusedDc] = fusedGrads([a, b, c], dy);
        expectArraysClose(await da.array(), await fusedDa.array());
        expectArraysClose(await db.array(), await fusedDb.array());
        expectArraysClose(await dc.array(), await fusedDc.array());
    });
    it('fused A x B with relu and broadcasted bias gradient', async () => {
        const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]);
        const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
        // [1, 1] bias broadcast over the [2, 2] product; its gradient is
        // compared against the unfused tf.add gradient.
        const c = tf.tensor2d([[1]]);
        const transposeA = false;
        const transposeB = false;
        const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
        const grads = tf.grads((a, b, c) => {
            const prod = tf.matMul(a, b, transposeA, transposeB);
            const sum = tf.add(prod, c);
            return tf.relu(sum);
        });
        const fusedGrads = tf.grads((a, b, c) => {
            return tf.fused.matMul({ a, b, transposeA, transposeB, bias: c, activation: 'relu' });
        });
        const [da, db, dc] = grads([a, b, c], dy);
        const [fusedDa, fusedDb, fusedDc] = fusedGrads([a, b, c], dy);
        expectArraysClose(await da.array(), await fusedDa.array());
        expectArraysClose(await db.array(), await fusedDb.array());
        expectArraysClose(await dc.array(), await fusedDc.array());
    });
    it('fused matmul with relu6 and gradients', async () => {
        const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]);
        const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
        const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
        const transposeA = false;
        const transposeB = false;
        const fusedGrads = tf.grads((a, b) => {
            return tf.fused.matMul({ a, b, transposeA, transposeB, bias: null, activation: 'relu6' });
        });
        const [fusedDa, fusedDb] = fusedGrads([a, b], dy);
        // Reference: unfused matMul followed by relu6.
        const grads = tf.grads((a, b) => {
            const prod = tf.matMul(a, b, transposeA, transposeB);
            return tf.relu6(prod);
        });
        const [da, db] = grads([a, b], dy);
        // Both sides read with .array() for a like-for-like nested comparison
        // (previously db was read flat via .data()).
        expectArraysClose(await da.array(), await fusedDa.array());
        expectArraysClose(await db.array(), await fusedDb.array());
    });
});
//# sourceMappingURL=fused_mat_mul_test.js.map