import numpy as np
import torch
import pytest
import itertools
from symo.group import I, O, S, B, Eq, stable_dims
from symo.factor import factor_from_param, ZeroFactor
from symo.invariance import invariance_from_spec
from symo.compiler import Compiler

# Concrete size of each named axis, looked up by its index name.
dim_sizes = {"N": 5, "M": 7, "K": 7, "P": 11}
# NOTE(review): not referenced in the visible part of this file — confirm it is still needed.
dim_surr = {"O": 1, "S": 4}
# Axis index names combined with every group class below.
indices = ["N", "M", "K", "P"]
# Group classes under test.
groups = [S, I]
# All tensors and compiled objects are created on this device.
device = "cpu"


def construct_v1(pair):
    """Build a group instance from a ``(group_cls, index)`` pair.

    The axis size is taken from the module-level ``dim_sizes`` table and the
    group is created by subscripting the class as ``group_cls[index, size]``.
    """
    group_cls, index = pair
    return group_cls[index, dim_sizes[index]]


def construct_v2(pair):
    """Compiler version: encode ``(group_cls, index)`` as ``"<ClassName>_<index>"``."""
    cls, axis = pair
    return "{}_{}".format(cls.__name__, axis)


def group_comb(group_set):
    """Enumerate (lhs, rhs) layouts built from the elements of ``group_set``.

    Returns, in this order:
      * every ordered pair ``(a, b)``;
      * every ``(v, (a, b))`` — a single element against a pair;
      * every ``((a, b), (c, d))`` — a pair against a pair.
    """
    pairs = [(a, b) for a, b in itertools.product(group_set, group_set)]
    singles_vs_pairs = [(v, p) for v in group_set for p in pairs]
    pairs_vs_pairs = [(p, q) for p in pairs for q in pairs]
    return pairs + singles_vs_pairs + pairs_vs_pairs


# Seed for the torch generator created inside each test.
seed = 111
# Every (group class, axis index) combination.
group_pairs = [(grp, idx) for grp, idx in itertools.product(groups, indices)]
# v1: group instances; v2: compiler spec strings of the form "<ClassName>_<index>".
set_v1 = [construct_v1(p) for p in group_pairs]
set_v2 = [construct_v2(p) for p in group_pairs]
group_set = set_v2
# Parallel lhs/rhs layouts in both representations, zipped so every
# parametrized test case receives the v1 and v2 form of the same layout.
group_set_v1 = group_comb(set_v1)
group_set_v2 = group_comb(set_v2)
group_set_vars = list(zip(group_set_v1, group_set_v2))


def find_surr_dims(group, out: dict):
    """Record ``stable_dims`` of every ``I``/``S`` group into ``out``.

    ``group`` is either a single group or an arbitrarily nested list/tuple of
    groups. For each leaf whose ``name()`` is ``"I"`` or ``"S"``, the result
    of ``stable_dims(leaf)`` is stored as ``out[name][index]``, where
    ``index`` is the first element of ``leaf._type_parameter``. Other group
    names are ignored. ``out`` is mutated in place; nothing is returned.
    """
    if not isinstance(group, (list, tuple)):
        index, _ = group._type_parameter
        dim = stable_dims(group)
        kind = group.name()
        if kind in ("I", "S"):
            out.setdefault(kind, {})[index] = dim
        return
    for member in group:
        find_surr_dims(member, out)


@pytest.mark.parametrize("pair_v1, pair_v2", group_set_vars)
def test_outer_estimation(pair_v1, pair_v2):
    """Compare outer-estimate weights of the reference factor and the compiler.

    Reference weights come from ``factor_from_param`` on seeded random
    ``lhs``/``rhs`` tensors. Candidate weights come from
    ``Compiler.outer_estimate`` on the same tensors. Both are flattened and
    sorted before comparison, so only the multiset of values is checked.
    """

    def as_axes(side):
        # A lone axis spec is wrapped so each side is a list of specs.
        return list(side) if isinstance(side, (tuple, list)) else [side]

    def split_axis(axis):
        # Axis spec -> (group name, index name), split on "_".
        parts = axis.split("_")
        return parts[0], parts[1]

    surr_dims = {}
    find_surr_dims(pair_v1, surr_dims)
    eq = Eq[pair_v1]()
    lhs_shape, rhs_shape = eq.shape
    gen = torch.Generator(device=device).manual_seed(seed)
    # Both tensors are drawn from one seeded generator, lhs first.
    lhs = torch.randn(lhs_shape, generator=gen, device=device)
    rhs = torch.randn(rhs_shape, generator=gen, device=device)
    factor = factor_from_param(eq, (lhs, rhs))

    if isinstance(factor, ZeroFactor):
        pytest.skip(f"Factor-v1 doesn't have implementation for {pair_v2}")
    reference = np.sort(factor.weights.detach().cpu().numpy().reshape(-1))

    lhs_spec, rhs_spec = pair_v2
    lhs_axes = as_axes(lhs_spec)
    rhs_axes = as_axes(rhs_spec)
    sides = (lhs_axes, rhs_axes)

    dims = [[dim_sizes[split_axis(a)[1]] for a in side] for side in sides]
    surr_dims_ = [
        [surr_dims[grp][idx] for grp, idx in map(split_axis, side)]
        for side in sides
    ]
    compiled = Compiler(
        invariance_from_spec(lhs_axes, rhs_axes), dims, surr_dims_, device
    )

    candidate = compiled.outer_estimate([lhs, rhs]).squeeze()
    candidate = np.sort(candidate.detach().cpu().numpy().reshape(-1))

    n_ref, n_cand = reference.shape[0], candidate.shape[0]
    if n_ref != n_cand:
        pytest.fail(
            f"shape not match factor: {n_ref}, compiler: {n_cand}, {pair_v2}"
        )
    np.testing.assert_array_almost_equal(reference, candidate, decimal=5)


@pytest.mark.parametrize("pair_v1, pair_v2", group_set_vars)
def test_cov(pair_v1, pair_v2):
    """Compare the surrogate covariance of the reference factor and the compiler.

    The compiler is fed the same seeded ``lhs``/``rhs`` tensors through
    ``outer_estimate_``. Its ``cov(surrogate=True)`` must then match the
    factor's element-wise to 5 decimals.
    """
    surr_dims = {}
    find_surr_dims(pair_v1, surr_dims)
    eq = Eq[pair_v1]()
    lhs_shape, rhs_shape = eq.shape
    rng = torch.Generator(device=device).manual_seed(seed)
    # Drawn in order: lhs first, then rhs, from the same seeded generator.
    lhs, rhs = (
        torch.randn(shape, generator=rng, device=device)
        for shape in (lhs_shape, rhs_shape)
    )
    factor = factor_from_param(eq, (lhs, rhs))

    if isinstance(factor, ZeroFactor):
        pytest.skip(f"Factor-v1 doesn't have implementation for {pair_v2}")
    expected = factor.cov(surrogate=True).detach().cpu().numpy()

    axes = []
    for side in pair_v2:
        # A lone axis spec is wrapped so each side is a list of specs.
        axes.append(list(side) if isinstance(side, (tuple, list)) else [side])
    lhs_v2, rhs_v2 = axes

    pieces = [[axis.split("_") for axis in side] for side in axes]
    dims = [[dim_sizes[p[1]] for p in side] for side in pieces]
    surr_dims_ = [[surr_dims[p[0]][p[1]] for p in side] for side in pieces]

    compiled = Compiler(
        invariance_from_spec(lhs_v2, rhs_v2), dims, surr_dims_, device
    )
    compiled.outer_estimate_(lhs, rhs)
    actual = compiled.cov(surrogate=True).detach().cpu().numpy()

    if expected.shape != actual.shape:
        pytest.fail(
            f"shape not match factor: {expected.shape}, compiler: {actual.shape}, {pair_v2}, {surr_dims}"
        )
    np.testing.assert_array_almost_equal(expected, actual, decimal=5)


@pytest.mark.parametrize("pair_v1, pair_v2", group_set_vars)
def test_matvec(pair_v1, pair_v2):
    """Check that ``Compiler.matvec`` matches the reference factor's ``matvec``.

    Two products are compared after the compiler has been fed the same
    ``lhs``/``rhs`` tensors through ``outer_estimate_``:

    * the plain product, with a random vector shaped like ``rhs``;
    * the transposed product (``transpose=True``), with a random vector
      shaped like ``lhs``.
    """
    surr_dims = {}
    find_surr_dims(pair_v1, surr_dims)
    eq = Eq[pair_v1]()
    lhs_shape, rhs_shape = eq.shape
    generator = torch.Generator(device=device).manual_seed(seed)
    # All four tensors come from one seeded generator; the draw order fixes
    # which values each tensor receives.
    lhs = torch.randn(lhs_shape, generator=generator, device=device)
    rhs = torch.randn(rhs_shape, generator=generator, device=device)
    v_rhs = torch.randn(rhs_shape, generator=generator, device=device)
    v_lhs = torch.randn(lhs_shape, generator=generator, device=device)

    factor = factor_from_param(eq, (lhs, rhs))

    if isinstance(factor, ZeroFactor):
        pytest.skip(f"Factor-v1 doesn't have implementation for {pair_v2}")
    out_r_v1 = factor.matvec(v_rhs).detach().cpu().numpy()
    out_l_v1 = factor.matvec(v_lhs, transpose=True).detach().cpu().numpy()

    lhs_v2, rhs_v2 = pair_v2

    # A lone axis spec is wrapped so each side is a list of specs.
    lhs_v2 = list([lhs_v2] if not isinstance(lhs_v2, (tuple, list)) else lhs_v2)
    rhs_v2 = list([rhs_v2] if not isinstance(rhs_v2, (tuple, list)) else rhs_v2)

    # Each spec is "<group>_<index>": the index selects the axis size, and
    # (group, index) selects the surrogate size collected by find_surr_dims.
    dims = [
        [dim_sizes[axis.split("_")[1]] for axis in lhs_v2],
        [dim_sizes[axis.split("_")[1]] for axis in rhs_v2],
    ]
    surr_dims_ = [
        [surr_dims[axis.split("_")[0]][axis.split("_")[1]] for axis in lhs_v2],
        [surr_dims[axis.split("_")[0]][axis.split("_")[1]] for axis in rhs_v2],
    ]

    invariance = invariance_from_spec(lhs_v2, rhs_v2)
    compiled = Compiler(invariance, dims, surr_dims_, device)

    compiled.outer_estimate_(lhs, rhs)
    out_r_v2 = compiled.matvec(v_rhs).detach().cpu().numpy()
    out_l_v2 = compiled.matvec(v_lhs, transpose=True).detach().cpu().numpy()

    # Check both directions. numpy's AssertionError is allowed to propagate so
    # the element-wise diff stays visible; the label and pair ids in err_msg
    # identify which product and which case failed.
    for label, out_v1, out_v2 in (
        ("matvec", out_r_v1, out_r_v2),
        ("matvec(transpose=True)", out_l_v1, out_l_v2),
    ):
        if out_v1.shape != out_v2.shape:
            pytest.fail(
                f"{label} shape not match factor: {out_v1.shape}, compiler: {out_v2.shape}, {pair_v2}, {surr_dims}"
            )
        np.testing.assert_array_almost_equal(
            out_v1,
            out_v2,
            decimal=5,
            err_msg=f"{label} mismatch for pair_v1={pair_v1}, pair_v2={pair_v2}",
        )


if __name__ == "__main__":
    # pytest.main() returns an exit code rather than exiting; propagate it so
    # running this file directly reports failures to the shell / CI.
    raise SystemExit(pytest.main([__file__, "-q", "-r", "f"]))
