import pytest
import torch
import numpy as np

from symo.factor import O, B, S, I, Eq, factor_from_param


# Fixed RNG seed: every parametrized case seeds its own generator with this
# value, so the random test vectors are reproducible from run to run.
seed = 2

# Group descriptors used to index the factor kinds (B/O/S/I) in the
# parametrization below. Each one is a (label, integer) pair; the integer is
# presumably the size of that group -- confirm against symo.factor.
Ag, Bg, Cg, Dg, Eg = ("A", 15), ("B", 12), ("C", 10), ("D", 20), ("E", 18)
Fg, Gg, Jg, Kg, Lg = ("F", 21), ("G", 19), ("J", 23), ("K", 5), ("L", 13)


@pytest.mark.parametrize(
    "groups_spec, transpose",
    [
        ((B[Ag], B[Ag]), False),
        ((O[Ag], O[Ag]), False),
        ((S[Ag], S[Ag]), False),
        ((S[Ag], S[Bg]), False),
        ((S[Ag], I[Cg]), False),
        ((I[Ag], S[Cg]), False),
        (((O[Ag], O[Bg]), (O[Ag], O[Bg])), False),
        (((O[Ag], O[Ag]), (O[Ag], O[Ag])), False),
        (((B[Ag], B[Ag]), (B[Ag], B[Ag])), False),
        (((S[Ag], O[Ag]), (S[Ag], O[Ag])), False),
        (((B[Ag], S[Ag]), (B[Ag], S[Ag])), False),
        (((S[Dg], S[Ag]), (S[Dg], S[Ag])), False),
        (((S[Dg], S[Ag]), (S[Ag], S[Dg])), False),
        (((S[Dg], S[Ag]), (S[Ag], S[Dg])), True),
        (((S[Dg], S[Ag]), (S[Eg], S[Ag])), False),
        (((S[Dg], S[Ag]), (S[Ag], S[Eg])), False),
        (((S[Dg], S[Dg]), (S[Dg], S[Dg])), False),
        (((I[Ag], O[Dg]), (I[Ag], O[Dg])), False),
        (((I[Ag], O[Dg]), (O[Dg], I[Ag])), False),
        (((O[Ag], I[Dg]), (I[Dg], O[Ag])), False),
        (((O[Ag], I[Dg]), (O[Ag], I[Dg])), False),
        (((I[Ag], S[Dg]), (I[Ag], S[Dg])), False),
        (((I[Ag], S[Dg]), (S[Dg], I[Ag])), False),
        (((I[Ag], S[Dg]), (S[Dg], I[Ag])), True),
        (((I[Ag], S[Dg]), (S[Fg], I[Gg])), False),
        (((I[Ag], S[Dg]), (S[Fg], I[Kg])), True),
        (((S[Dg], I[Ag]), (I[Gg], S[Fg])), False),
        (((S[Ag], I[Dg]), (S[Ag], I[Dg])), False),
        (((S[Ag], I[Dg]), (S[Lg], I[Dg])), False),
        (((O[Ag], S[Dg]), (O[Ag], I[Dg])), False),
        (((O[Ag], S[Dg]), (O[Ag], I[Dg])), True),
        (((O[Ag], S[Dg]), (I[Eg], O[Ag])), False),
        (((O[Ag], S[Dg]), (I[Eg], O[Ag])), True),
        (((S[Dg], O[Ag]), (O[Ag], I[Eg])), False),
        (((S[Dg], O[Ag]), (O[Ag], I[Eg])), True),
        (((S[Dg], O[Ag]), (I[Eg], O[Ag])), False),
        (((S[Dg], O[Ag]), (I[Eg], O[Ag])), True),
        (((S[Ag], S[Dg]), (S[Ag], I[Dg])), False),
        (((S[Ag], S[Dg]), (S[Ag], I[Dg])), True),
        (((S[Ag], S[Dg]), (I[Dg], S[Ag])), False),
        (((S[Ag], S[Dg]), (I[Dg], S[Ag])), True),
        (((S[Ag], S[Dg]), (S[Dg], I[Dg])), False),
        (((S[Ag], S[Dg]), (S[Dg], I[Dg])), True),
        (((S[Ag], S[Dg]), (I[Dg], S[Dg])), False),
        (((S[Ag], S[Dg]), (I[Dg], S[Dg])), True),
        (((S[Dg], S[Dg]), (S[Dg], I[Ag])), False),
        (((S[Dg], S[Dg]), (S[Dg], I[Ag])), True),
        (((S[Dg], S[Dg]), (I[Ag], S[Dg])), False),
        (((S[Dg], S[Dg]), (I[Ag], S[Dg])), True),
        (((S[Dg], S[Fg]), (S[Jg], I[Ag])), False),
        (((S[Dg], S[Fg]), (S[Jg], I[Ag])), True),
        (((S[Dg], S[Fg]), (I[Ag], S[Jg])), False),
        (((S[Dg], S[Fg]), (I[Ag], S[Jg])), True),
        (((S[Dg], S[Fg]), (S[Ag], S[Jg])), False),
        ((I[Ag], (I[Bg], S[Dg])), False),
        ((I[Ag], (I[Bg], S[Dg])), True),
        ((I[Ag], (S[Dg], I[Bg])), False),
        ((I[Ag], (S[Dg], I[Bg])), True),
        ((S[Ag], (I[Dg], I[Kg])), False),
        ((S[Ag], (I[Dg], I[Kg])), True),
        ((S[Ag], (I[Dg], S[Bg])), False),
        ((S[Ag], (I[Dg], S[Bg])), True),
        ((S[Ag], (S[Dg], I[Bg])), False),
        ((S[Ag], (S[Dg], I[Bg])), True),
        ((I[Ag], (S[Dg], S[Bg])), False),
        ((I[Ag], (S[Dg], S[Bg])), True),
        ((S[Ag], (S[Dg], S[Bg])), False),
        ((S[Ag], (S[Dg], S[Bg])), True),
        ((O[Ag], (I[Dg], O[Ag])), False),
        ((O[Ag], (I[Dg], O[Ag])), True),
        ((O[Ag], (O[Ag], I[Dg])), False),
        ((O[Ag], (O[Ag], I[Dg])), True),
        ((S[Ag], (I[Dg], S[Ag])), False),
        ((S[Ag], (I[Dg], S[Ag])), True),
        ((S[Ag], (S[Ag], I[Dg])), False),
        ((S[Ag], (S[Ag], I[Dg])), True),
        ((S[Dg], (S[Ag], S[Dg])), False),
        ((S[Dg], (S[Ag], S[Dg])), True),
        ((S[Dg], (S[Dg], S[Ag])), False),
        ((S[Dg], (S[Dg], S[Ag])), True),
        ((S[Dg], (S[Dg], S[Dg])), False),
        ((S[Dg], (S[Dg], S[Dg])), True),
    ],
)
def test_mvn_and_vmp(groups_spec, transpose):
    """Check ``factor.matvec`` against an explicit product with ``factor.cov()``.

    For each ``groups_spec`` a factor is built from two seeded random vectors
    (one per side of ``groups.shape``), and the result of ``factor.matvec`` is
    compared with the same product computed from the matrix returned by
    ``factor.cov()``:

    * ``transpose=False``: ``cov @ vec(right)`` reshaped to the left shape,
      versus ``factor.matvec(right, transpose=False)``.
    * ``transpose=True``: ``vec(left) @ cov`` reshaped to the right shape,
      versus ``factor.matvec(left, transpose=True)``.

    The test runs in float64 on the CPU. Both settings are scoped to this
    test: the previous default dtype is restored on exit and the device
    override is a context manager, so no global torch state leaks into tests
    that run afterwards.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        # torch.device as a context manager overrides the default device only
        # for the duration of the block (unlike torch.set_default_device).
        with torch.device("cpu"):
            generator = torch.Generator().manual_seed(seed)

            groups = Eq[groups_spec]()
            shape_left, shape_right = groups.shape
            # A side made of a single group reports a bare (non-tuple) shape;
            # normalise it so it can be used as a tensor shape below.
            if not isinstance(shape_left, tuple):
                shape_left = (shape_left,)
            if not isinstance(shape_right, tuple):
                shape_right = (shape_right,)

            # Draw order (left first, then right) is kept so the seeded values
            # are the same as before.
            vector_left = torch.randn(shape_left, generator=generator)
            vector_right = torch.randn(shape_right, generator=generator)

            factor = factor_from_param(groups, (vector_left, vector_right))
            cov = factor.cov()

            if not transpose:
                # Reference: explicit matrix @ flattened right vector.
                cov_vec = (cov @ vector_right.flatten()).reshape(*shape_left)
                sur_vec = factor.matvec(vector_right, transpose=transpose)
            else:
                # Reference: flattened left vector @ explicit matrix.
                cov_vec = (vector_left.flatten() @ cov).reshape(*shape_right)
                sur_vec = factor.matvec(vector_left, transpose=transpose)

            np.testing.assert_array_almost_equal(cov_vec, sur_vec)
    finally:
        # Always undo the global dtype change, even when the assertion fails.
        torch.set_default_dtype(previous_dtype)


if __name__ == "__main__":
    # pytest.main returns the run's exit code; propagate it so that running
    # this file directly exits non-zero when a test fails (previously the
    # return value was dropped and the script always exited 0).
    raise SystemExit(pytest.main([__file__, "-v"]))
