import numpy as np

from ... import backend as T
from .. import khatri_rao
from ...testing import assert_array_equal, assert_raises


# Author: Jean Kossaifi


def test_khatri_rao():
    """Test for khatri_rao"""
    columns = 4
    rows = [3, 4, 2]
    matrices = [T.tensor(np.arange(k * columns).reshape((k, columns))) for k in rows]
    res = khatri_rao(matrices)
    # resulting matrix must be of shape (prod(n_rows), n_columns)
    n_rows = 3 * 4 * 2
    n_columns = 4
    assert res.shape[0] == n_rows
    assert res.shape[1] == n_columns

    # fail case: all matrices must have same number of columns
    shapes = [[3, 4], [3, 4], [3, 2]]
    matrices = [T.tensor(np.arange(i * j).reshape((i, j))) for (i, j) in shapes]
    with assert_raises(ValueError):
        khatri_rao(matrices)

    # all matrices should be of dim 2...
    matrices = [T.tensor(np.eye(3)), T.tensor(np.arange(3 * 2 * 2).reshape((3, 2, 2)))]
    with assert_raises(ValueError):
        khatri_rao(matrices)

    # Classic example/test
    t1 = T.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    t2 = T.tensor([[1, 4, 7], [2, 5, 8], [3, 6, 9]])
    true_res = T.tensor(
        [
            [1.0, 8.0, 21.0],
            [2.0, 10.0, 24.0],
            [3.0, 12.0, 27.0],
            [4.0, 20.0, 42.0],
            [8.0, 25.0, 48.0],
            [12.0, 30.0, 54.0],
            [7.0, 32.0, 63.0],
            [14.0, 40.0, 72.0],
            [21.0, 48.0, 81.0],
        ]
    )
    res = khatri_rao([t1, t2])
    assert_array_equal(res, true_res)

    # A = np.hstack((np.eye(3), np.arange(3)[:, None]))
    A = T.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 2.0]])
    B = T.tensor([[1.0, 0.0, 0.0, 3.0], [0.0, 1.0, 0.0, 4.0], [0.0, 0.0, 1.0, 5.0]])
    true_res = T.tensor(
        [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 3.0],
            [0.0, 1.0, 0.0, 4.0],
            [0.0, 0.0, 0.0, 5.0],
            [0.0, 0.0, 0.0, 6.0],
            [0.0, 0.0, 0.0, 8.0],
            [0.0, 0.0, 1.0, 10.0],
        ]
    )
    assert_array_equal(khatri_rao([A, B]), true_res)

    U1 = np.reshape(np.arange(1, 10), (3, 3))
    U2 = np.reshape(np.arange(10, 22), (4, 3))
    U3 = np.reshape(np.arange(22, 28), (2, 3))
    U4 = np.reshape(np.arange(28, 34), (2, 3))
    U = [T.tensor(matrix) for matrix in [U1, U2, U3, U4]]
    true_res = T.tensor(
        [
            [6160, 14674, 25920],
            [6820, 16192, 28512],
            [7000, 16588, 29160],
            [7750, 18304, 32076],
            [8008, 18676, 32400],
            [8866, 20608, 35640],
            [9100, 21112, 36450],
            [10075, 23296, 40095],
            [9856, 22678, 38880],
            [10912, 25024, 42768],
            [11200, 25636, 43740],
            [12400, 28288, 48114],
            [11704, 26680, 45360],
            [12958, 29440, 49896],
            [13300, 30160, 51030],
            [14725, 33280, 56133],
            [24640, 36685, 51840],
            [27280, 40480, 57024],
            [28000, 41470, 58320],
            [31000, 45760, 64152],
            [32032, 46690, 64800],
            [35464, 51520, 71280],
            [36400, 52780, 72900],
            [40300, 58240, 80190],
            [39424, 56695, 77760],
            [43648, 62560, 85536],
            [44800, 64090, 87480],
            [49600, 70720, 96228],
            [46816, 66700, 90720],
            [51832, 73600, 99792],
            [53200, 75400, 102060],
            [58900, 83200, 112266],
            [43120, 58696, 77760],
            [47740, 64768, 85536],
            [49000, 66352, 87480],
            [54250, 73216, 96228],
            [56056, 74704, 97200],
            [62062, 82432, 106920],
            [63700, 84448, 109350],
            [70525, 93184, 120285],
            [68992, 90712, 116640],
            [76384, 100096, 128304],
            [78400, 102544, 131220],
            [86800, 113152, 144342],
            [81928, 106720, 136080],
            [90706, 117760, 149688],
            [93100, 120640, 153090],
            [103075, 133120, 168399],
        ]
    )
    res = khatri_rao(U)
    assert_array_equal(res, true_res)

    res_1 = khatri_rao(U, skip_matrix=1)
    res_2 = khatri_rao([U[0]] + U[2:])
    assert_array_equal(res_1, res_2)

    res_1 = khatri_rao(U, skip_matrix=0)
    res_2 = khatri_rao(U[1:])
    assert_array_equal(res_1, res_2)

    # Test with one matrix only: khatri-rao of one matrix = that matrix
    assert_array_equal(khatri_rao([U[0]]), U[0])
