import pytest
import jax.numpy as jnp

from symo.factor import (
    O,
    B,
    S,
    I,
    Eq,
    normalize_eq,
    nested_compare,
    NormCache,
    factor_init,
    factor_from_cov,
    factor_from_param,
    flatten_sequences,
    cov_shape,
)


class TestNormalizeEq:
    """Tests for ``normalize_eq``.

    A single group normalizes to a 3-tuple whose first entry is the group type
    and whose middle entry is the id a ``NormCache`` assigns to the group's
    dimension name (see ``test_normalize_eq_dimension_ordering``). In every
    case below the third entry is 1. Tuples of groups normalize element-wise,
    keeping their nesting.
    """

    @pytest.mark.parametrize(
        "eq_input,expected",
        [
            # Named groups.
            (O["N"], (O, 1, 1)),
            (B["M"], (B, 1, 1)),
            (S["K"], (S, 1, 1)),
            (I["L"], (I, 1, 1)),
            # Bare group types normalize the same way.
            (O, (O, 1, 1)),
            (B, (B, 1, 1)),
            (S, (S, 1, 1)),
            (I, (I, 1, 1)),
            # Distinct names inside one tuple receive distinct ids.
            ((O["N"], B["M"]), ((O, 1, 1), (B, 2, 1))),
            ((S["K"], I["L"]), ((S, 1, 1), (I, 2, 1))),
            # Nesting is preserved.
            (((O["N"], B["M"]),), (((O, 1, 1), (B, 2, 1)),)),
        ],
    )
    def test_normalize_eq_basic(self, eq_input, expected):
        normalized = normalize_eq(eq_input)
        assert nested_compare(normalized, expected)

    def test_normalize_eq_cache_reuse(self):
        shared_cache = NormCache()
        first = normalize_eq(O["N"], shared_cache)
        second = normalize_eq(O["N"], shared_cache)

        # The cache must hand back the very same object, not merely an equal one.
        assert first == second
        assert first is second

    def test_normalize_eq_dimension_ordering(self):
        shared_cache = NormCache()
        dim_ids = [
            normalize_eq(group, shared_cache)[1]
            for group in (O["A"], B["B"], S["A"])
        ]

        # "A" gets the first id and "B" the next. A repeated "A" reuses its id
        # even when it is attached to a different group type.
        assert dim_ids == [1, 2, 1]


class TestNestedCompare:
    """Tests for ``nested_compare(base, instance)``.

    The table pins down matching rules:
    - a ``None`` base accepts the instances given;
    - group types and tuple order/length must agree;
    - a base without a dimension name does not match a named instance.
    """

    @pytest.mark.parametrize(
        "base,instance,expected",
        [
            # A None base.
            (None, O["N"], True),
            (None, (O["N"], B["M"]), True),
            # Bare group types.
            (O, O, True),
            (B, B, True),
            (S, S, True),
            (I, I, True),
            (O, B, False),
            (S, I, False),
            # Flat tuples: order and length matter.
            ((O, B), (O, B), True),
            ((O, B), (B, O), False),
            ((O, B), (O, B, S), False),
            # Nested tuples with O/B, named vs unnamed.
            (((O, B),), ((O, B),), True),
            (((O, B["N"]),), ((O, B["N"]),), True),
            (((O, B),), ((O, B["N"]),), False),
            (((O, B),), ((O["N"], B["N"]),), False),
            (((O, B),), ((B["N"], O["M"]),), False),
            # Nested tuples with O/I, named vs unnamed.
            (((O, I),), ((O, I),), True),
            (((O["N"], I["M"]),), ((O["N"], I["M"]),), True),
            (((O, I),), ((O["N"], I["N"]),), False),
            (((O, I),), ((O["N"], I["M"]),), False),
            (((O["N"], I),), ((O["N"], I["N"]),), False),
            (((I["N"], O["M"]),), ((O["N"], I["M"]),), False),
            # Structural mismatches between nesting levels.
            ((B, (O, I)), ((O, I), B), False),
            (((O, I), B), ((O, I, B),), False),
            (((O, I, B), S), ((O, I, B),), False),
        ],
    )
    def test_nested_compare(self, base, instance, expected):
        assert nested_compare(base, instance) == expected


class TestEqTypeParameter:
    """Tests for the ``shape`` reported by instantiated ``Eq[...]`` objects.

    Unsized groups (e.g. bare ``S``) contribute ``None`` to the shape.
    """

    def test_eq_single_parameter(self):
        assert Eq[O(5)]().shape == (5,)

    def test_eq_two_separate_groups(self):
        assert Eq[O(3), B(4)]().shape == (3, 4)

    def test_eq_two_identical_groups(self):
        # Both slots reuse the same two O group objects.
        o3, o4 = O(3), O(4)
        eq = Eq[(o3, o4), (o3, o4)]()
        assert eq.shape == ((3, 4), (3, 4))

    def test_eq_none_shape(self):
        i2, unsized_s = I(2), S
        eq = Eq[i2, (unsized_s, i2)]()
        assert eq.shape == (2, (None, 2))

    def test_eq_asymmetric_groups(self):
        i2, s3 = I(2), S(3)
        eq = Eq[i2, (s3, i2)]()
        assert eq.shape == (2, (3, 2))

    def test_eq_non_instantiated(self):
        # NOTE(review): the final `assert` sits inside `pytest.raises`, so this
        # test also passes when that assert itself fails. ``Eq[...]`` is not
        # called here, so it is presumably the Eq type rather than a shape
        # tuple. As written, the test cannot tell "Eq rejects the unsized
        # B['M']" apart from "Eq accepts it". Consider asserting the intended
        # behavior directly.
        with pytest.raises(AssertionError):
            sized_o, unsized_b = O["N"](3), B["M"]
            shape = Eq[(sized_o, unsized_b), (sized_o, unsized_b)]
            assert shape == (3, None)


class TestCovShape:
    """Tests for ``Eq(...).cov_shape``.

    A single-group entry keeps its size. A tuple entry is flattened to the
    product of its sizes, so ``(O(3), B(4))`` contributes 12.
    """

    @pytest.mark.parametrize(
        "groups_spec, expected_shape",
        [
            (Eq[O(4)], (4,)),
            (Eq[I(3)], (3,)),
            (Eq[O(3), B(4)], (3, 4)),
            (Eq[(O(3), B(4)), (O(3), B(4))], (12, 12)),
            (Eq[(S(2), I(5)), (S(2), I(5))], (10, 10)),
            (Eq[I(2), (S(6), I(2))], (2, 12)),
        ],
    )
    def test_cov_shape(self, groups_spec, expected_shape):
        assert groups_spec().cov_shape == expected_shape


class TestTwoIdenticalGroupFactors:
    """Factors built from two identical groups (auto-covariance).

    Each test builds ``Eq[(g1, g2), (g1, g2)]`` (or ``Eq[g, g]``) and runs
    ``factor_init`` on it. It then checks the reported group shape, the weight
    shape and the dense covariance shape. The covariance is ``(n, n)``, where
    ``n`` is the product of the sizes within one group.
    """

    @staticmethod
    def _check_factor(groups, groups_shape, cov_dims, weights_shape=None):
        """Initialize a factor for ``groups`` and assert its shapes.

        Args:
            groups: an instantiated ``Eq[...]``.
            groups_shape: expected ``factor.groups.shape``.
            cov_dims: expected shape of ``factor.cov()``.
            weights_shape: expected ``factor.weights.shape``. ``None`` checks
                only that a ``weights`` attribute exists.
        """
        factor_obj = factor_init(groups)

        assert hasattr(factor_obj, "groups")
        assert factor_obj.groups.shape == groups_shape
        assert hasattr(factor_obj, "weights")
        if weights_shape is not None:
            assert factor_obj.weights.shape == weights_shape

        assert factor_obj.cov().shape == cov_dims

    def test_on_or_bn_om_or_bm_factor(self):
        On, Om = O["N"](3), O(4)
        # NOTE(review): the weight shape for this combination is not pinned down.
        self._check_factor(
            Eq[(On, Om), (On, Om)](), ((3, 4), (3, 4)), (12, 12)
        )

    def test_on1_or_bn1_on2_or_bn2_factor(self):
        On1, On2 = O["N", 1](3), O["N", 2](3)
        # NOTE(review): the weight shape for this combination is not pinned down.
        self._check_factor(
            Eq[(On1, On2), (On1, On2)](), ((3, 3), (3, 3)), (9, 9)
        )

    def test_on_or_bn_on_or_bn_factor(self):
        On = O(3)
        self._check_factor(
            Eq[(On, On), (On, On)](), ((3, 3), (3, 3)), (9, 9), (3,)
        )

    def test_single_on_or_bn_on_or_bn_factor(self):
        On = O(3)
        # Equivalent to the original ``Eq[(On), (On)]``: parentheses around a
        # single name do not create a tuple.
        self._check_factor(Eq[On, On](), (3, 3), (3, 3), ())

    def test_o_or_b_s_factor(self):
        On, Sm = O["N"](3), S(4)
        self._check_factor(
            Eq[(On, Sm), (On, Sm)](), ((3, 4), (3, 4)), (12, 12), (2,)
        )

    def test_s_o_or_b_factor(self):
        Sn, Om = S["N"](3), O(4)
        self._check_factor(
            Eq[(Sn, Om), (Sn, Om)](), ((3, 4), (3, 4)), (12, 12), (2,)
        )

    def test_sn_sm_factor(self):
        Sn, Sm = S["N"](3), S(4)
        self._check_factor(
            Eq[(Sn, Sm), (Sn, Sm)](), ((3, 4), (3, 4)), (12, 12), (4,)
        )

    def test_sn1_sn2_factor(self):
        Sn1, Sn2 = S["N", 1](3), S["N", 2](3)
        self._check_factor(
            Eq[(Sn1, Sn2), (Sn1, Sn2)](), ((3, 3), (3, 3)), (9, 9), (4,)
        )

    def test_sn_sn_factor(self):
        Sn = S(3)
        self._check_factor(
            Eq[(Sn, Sn), (Sn, Sn)](), ((3, 3), (3, 3)), (9, 9), (15,)
        )

    def test_i_o_or_b_factor(self):
        In, Om = I["N"](3), O(4)
        self._check_factor(
            Eq[(In, Om), (In, Om)](), ((3, 4), (3, 4)), (12, 12), (3, 3)
        )

    def test_o_or_b_i_factor(self):
        In, Om = I["N"](4), O(3)
        self._check_factor(
            Eq[(Om, In), (Om, In)](), ((3, 4), (3, 4)), (12, 12), (4, 4)
        )

    def test_i_s_factor(self):
        In, Sm = I["N"](3), S(4)
        self._check_factor(
            Eq[(In, Sm), (In, Sm)](), ((3, 4), (3, 4)), (12, 12), (2, 3, 3)
        )

    def test_s_i_factor(self):
        In, Sm = I["N"](4), S(3)
        self._check_factor(
            Eq[(Sm, In), (Sm, In)](), ((3, 4), (3, 4)), (12, 12), (2, 4, 4)
        )


class TestCrossGroupFactors:
    """Factors whose two groups differ (cross-covariance).

    The second group is a permutation or relabelling of the first. The dense
    covariance is still square, sized by the product of one group's sizes.
    """

    @staticmethod
    def _check_factor(groups, groups_shape, weights_shape, cov_dims):
        """Initialize a factor for ``groups`` and assert group/weight/cov shapes."""
        factor_obj = factor_init(groups)

        assert hasattr(factor_obj, "groups")
        assert factor_obj.groups.shape == groups_shape
        assert hasattr(factor_obj, "weights")
        assert factor_obj.weights.shape == weights_shape

        assert factor_obj.cov().shape == cov_dims

    def test_i_o_o_i_factor(self):
        In, Om = I["N"](3), O(4)
        self._check_factor(
            Eq[(In, Om), (Om, In)](), ((3, 4), (4, 3)), (3, 3), (12, 12)
        )

    def test_i_s_s_i_factor(self):
        In, Sm = I["N"](3), S(4)
        self._check_factor(
            Eq[(In, Sm), (Sm, In)](), ((3, 4), (4, 3)), (2, 3, 3), (12, 12)
        )

    def test_i_sM_sM_iL_factor(self):
        # The I groups carry different names on each side (unnamed vs "L").
        In, Sm, IL = I(3), S["M"](4), I["L"](3)
        self._check_factor(
            Eq[(In, Sm), (Sm, IL)](), ((3, 4), (4, 3)), (2, 3, 3), (12, 12)
        )


class TestAsymmetricFactors:
    """Factors whose two slots have different structure (single vs tuple)."""

    @pytest.mark.parametrize(
        "groups_spec,dims",
        [
            (Eq[I(2), (I(2), S(3))], (2, (2, 3))),
            (Eq[I(2), (S(3), I(2))], (2, (3, 2))),
            (Eq[S(2), (I(3), I(3))], (2, (3, 3))),
            (Eq[S["N"](2), (I(3), S["L"](4))], (2, (3, 4))),
            (Eq[S["N"](2), (S["M"](3), I(4))], (2, (3, 4))),
            (Eq[I(2), (S(3), S(3))], (2, (3, 3))),
            (Eq[S["N"](2), (S["M"](3), S["L"](4))], (2, (3, 4))),
            (Eq[S["M"](2), (S["N"](3), S["M"](2))], (2, (3, 2))),
            (Eq[O(2), (I(3), O(2))], (2, (3, 2))),
            (Eq[O(2), (O(2), I(3))], (2, (2, 3))),
            (Eq[S(2), (I(3), S(2))], (2, (3, 2))),
            (Eq[S(2), (S(2), I(3))], (2, (2, 3))),
        ],
    )
    def test_asymmetric_factors(self, groups_spec, dims):
        factor_obj = factor_init(groups_spec())
        assert factor_obj.groups.shape == dims

        # Only the rank of the covariance is checked for these cases.
        assert factor_obj.cov().ndim == 2

    def test_i_s_i_factor_example(self):
        factor_obj = factor_init(Eq[I(2), (S(2), I["L"](3))]())

        assert factor_obj.groups.shape == (2, (2, 3))
        assert type(factor_obj).__name__ == "I_S_I_Factor"

        # Rows follow the single I(2) slot; columns follow the flattened 2*3 tuple.
        assert factor_obj.cov().shape == (2, 6)


class TestFactorFromCov:
    """Tests for ``factor_from_cov``: fitting a factor to a dense covariance."""

    def test_factor_from_cov_two_groups(self):
        On, Om = O["N"](2), O["M"](3)
        groups = Eq[(On, Om), (On, Om)]()
        factor_obj = factor_from_cov(groups, jnp.eye(6) * 2.0)

        assert factor_obj.groups.shape == ((2, 3), (2, 3))
        assert factor_obj.cov().shape == (6, 6)

    def test_factor_from_cov_asymmetric(self):
        In, Sm = I(2), S(3)
        groups = Eq[In, (Sm, In)]()
        # Named ``expected_shape`` so it does not shadow the imported ``cov_shape``.
        expected_shape = (2, 6)

        factor_obj = factor_from_cov(groups, jnp.ones(expected_shape) * 1.5)

        assert factor_obj.groups.shape == (2, (3, 2))
        assert factor_obj.cov().shape == expected_shape

    def test_factor_from_cov_wrong_shape(self):
        On, Om = O["N"](2), O(3)
        groups = Eq[(On, Om), (On, Om)]()

        # These groups describe a 6x6 covariance, so a 5x5 matrix is rejected.
        with pytest.raises(AssertionError):
            factor_from_cov(groups, jnp.eye(5))


class TestFactorFromParam:
    """Tests for factor creation from parameters (``factor_from_param``)."""

    @pytest.mark.skip(reason="Not implemented due to `factor_from_param`")
    def test_factor_from_param(self):
        """Placeholder until ``factor_from_param`` can be exercised."""


class TestEdgeCases:
    """Edge cases: unsupported group combinations, custom init, helpers."""

    def test_unsupported_factor_type(self):
        On, Bm, Sl = O["N"](2), B["M"](3), S["L"](4)
        groups = Eq[On, Bm, Sl]()
        factor_obj = factor_init(groups)

        # factor_init returns a ZeroFactor for this three-group combination,
        # and its covariance is all zeros.
        assert factor_obj.__class__.__name__ == "ZeroFactor"

        cov = factor_obj.cov()
        assert jnp.allclose(cov, 0.0)

    def test_factor_with_custom_init_fn(self):
        On, Om = O["N"](2), O(3)
        groups = Eq[(On, Om), (On, Om)]()

        def custom_init(shape):
            return jnp.ones(shape) * 3.0

        factor_obj = factor_init(groups, custom_init)
        # Compare element-wise and then reduce. A bare ``weights == 3.0`` only
        # works for scalar weights: for an array with more than one element its
        # truth value is ambiguous, and the assert raises instead of checking.
        assert jnp.all(factor_obj.weights == 3.0)

    def test_flatten_sequences_function(self):
        # Flat input passes through unchanged.
        assert list(flatten_sequences((1, 2, 3))) == [1, 2, 3]

        # Nested tuples are flattened in order.
        assert list(flatten_sequences(((1, 2), (3, 4)))) == [1, 2, 3, 4]

        # Mixed scalars and tuples are flattened in order.
        assert list(flatten_sequences((1, (2, 3), 4))) == [1, 2, 3, 4]


class TestSpecialCases:
    """Special cases: Eq isinstance checks, rejected group pairings, cov fitting."""

    def test_eq_meta_instancecheck(self):
        ob_eq = Eq[O["N"](3), B["M"](4)]()
        os_eq = Eq[O["N"](3), S["M"](4)]()
        si_eq = Eq[S["N"](3), I["M"](4)]()

        # isinstance against an unsized Eq[...] matches on the group types.
        # Sizes on the instance do not prevent a match.
        assert isinstance(ob_eq, Eq[O["N"], B["M"]])
        assert isinstance(os_eq, Eq[O["N"], S["M"]])
        assert not isinstance(si_eq, Eq[O["N"], B["M"]])

    def test_group_dims_assertion_errors(self):
        o3, o4 = O(3), O(4)
        groups = Eq[(o3, o4), (o3, o4)]()
        # Building the Eq itself succeeds and reports the sizes...
        assert groups.shape == ((3, 4), (3, 4))

        # ...but factor_init rejects two unnamed O groups of different sizes.
        with pytest.raises(AssertionError):
            factor_init(groups)

    def test_cov_reconstruction_quality(self):
        Sn = S(3)
        groups = Eq[(Sn, Sn), (Sn, Sn)]()
        true_cov = jnp.eye(9) + 0.1 * jnp.ones((9, 9))

        reconstructed_cov = factor_from_cov(groups, true_cov).cov()

        assert reconstructed_cov.shape == true_cov.shape
        # NOTE(review): trace(true_cov) is 9.9, so a tolerance of 5.0 only
        # catches gross failures. Tighten it once the expected fit quality
        # of factor_from_cov is known.
        trace_gap = abs(jnp.trace(reconstructed_cov) - jnp.trace(true_cov))
        assert trace_gap < 5.0


if __name__ == "__main__":
    # Propagate pytest's exit code so `python <this file>` fails when tests fail.
    # A bare pytest.main(...) call discards it and always exits 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
