import itertools
from dataclasses import dataclass, asdict
import pytest
import torch
import numpy as np
from symo.experiments.mlp_groups import group_config, group_config_v2
from symo.experiments.models import MLP
from symo.factory import CovFactory, MeanFactory
from symo.factory2 import CovFactory as CovFactory2, MeanFactory as MeanFactory2

from symo.factor import O, B, S, I


@dataclass(frozen=True)
class ModelConfig:
    input_dim: int = 11
    hidden_dims: tuple[int, ...] = (7,) * 2
    output_dim: int = 5

    skip_every: int | None = None
    use_bias: bool = True
    use_bias_last: bool = True
    use_bias: bool = False
    use_bias_last: bool = False
    activation: str = "tanh"


def gen_parametrized_variables():
    """Build every ``(hid_group, inout_group, same)`` case for the tests.

    Returns the Cartesian product of the groups ``S, I, O, B`` with themselves
    and with both values of the ``same`` flag, as a list of 3-tuples. The
    first element varies slowest and ``same`` varies fastest.
    """
    group_choices = (S, I, O, B)
    same_choices = (True, False)
    # NOTE: filter_out_same_s is not applied here, so the cases with
    # hid_group == S and same == True are part of the returned list.
    return [
        (hid, inout, same)
        for hid in group_choices
        for inout in group_choices
        for same in same_choices
    ]


def filter_out_same_s(values):
    """Drop the cases whose hidden group equals ``S`` while ``same`` is truthy.

    Every element of ``values`` must unpack into ``(hid, inout, same)``. The
    surviving cases are returned, in their original order, as three-element
    lists.
    """
    kept = []
    for hid, inout, same in values:
        if hid == S and same:
            continue
        kept.append([hid, inout, same])
    return kept


@pytest.mark.parametrize(
    "hid_group, inout_group, same",
    gen_parametrized_variables(),
)
@torch.no_grad()
def test_mean_construction(hid_group, inout_group, same):
    """Check that ``MeanFactory`` and ``MeanFactory2`` average parameters alike.

    A default ``MLP`` is built from ``ModelConfig``. The v1 factory is fed the
    groups returned by ``group_config``; the v2 factory is fed the spec
    returned by ``group_config_v2``, which receives the group names
    (``__qualname__``) rather than the group objects. Both factories then
    average the same parameter list and the results are compared element-wise.

    Args:
        hid_group: Group passed as ``hid_group`` to both config helpers.
        inout_group: Group passed as ``inout_group`` to both config helpers.
        same: Flag forwarded unchanged as ``same`` to both config helpers.
    """
    model = MLP(**asdict(ModelConfig()))

    with torch.inference_mode():
        # group_config yields pairs; only the second item of each pair is
        # handed to the v1 factory.
        named_groups = group_config(
            model, hid_group=hid_group, inout_group=inout_group, same=same
        )
        group_list = [g for _, g in named_groups]
        groups_spec = group_config_v2(
            model,
            hid_group=hid_group.__qualname__,
            inout_group=inout_group.__qualname__,
            same=same,
        )
        params = list(model.parameters())
        factory = MeanFactory(group_list)
        factory2 = MeanFactory2(groups_spec)

        avg_params = factory.avg(params)
        avg_params2 = factory2.avg(params)

        # strict=True: if the two implementations return a different number
        # of tensors the test must fail instead of silently comparing only
        # the common prefix.
        for w1, w2 in zip(avg_params, avg_params2, strict=True):
            np.testing.assert_array_almost_equal(
                w1.detach().cpu().numpy(),
                w2.detach().cpu().numpy(),
            )


@pytest.mark.parametrize(
    "hid_group, inout_group, same",
    gen_parametrized_variables(),
)
@torch.no_grad()
def test_cov_construction(hid_group, inout_group, same):
    """Check that ``CovFactory2`` reproduces covariances built by ``CovFactory``.

    Two v1 factories (``surrogate=True`` and ``surrogate=False``) are updated
    with the model parameters via ``outer_update`` and their ``cov()`` results
    are taken as references. Each reference is loaded into a v2 factory with
    ``cov_update`` and read back with ``cov``; the read-back value must match
    the reference. The weights of the two v2 factories are also required to
    match each other.

    Args:
        hid_group: Group passed as ``hid_group`` to both config helpers.
        inout_group: Group passed as ``inout_group`` to both config helpers.
        same: Flag forwarded unchanged as ``same`` to both config helpers.
    """
    model = MLP(**asdict(ModelConfig()))

    with torch.inference_mode():
        # group_config yields pairs; only the second item of each pair is
        # handed to the v1 factories.
        named_groups = group_config(
            model, hid_group=hid_group, inout_group=inout_group, same=same
        )
        group_list = [g for _, g in named_groups]
        params = list(model.parameters())

        groups_spec = group_config_v2(
            model,
            hid_group=hid_group.__qualname__,
            inout_group=inout_group.__qualname__,
            same=same,
        )

        f_surr = CovFactory(group_list, surrogate=True)
        f_full = CovFactory(group_list, surrogate=False)

        f_surr2 = CovFactory2(groups_spec)
        f_full2 = CovFactory2(groups_spec)

        # Reference covariances from the v1 factories.
        f_surr.outer_update(params)
        f_full.outer_update(params)

        cov_surr = f_surr.cov()
        cov_full = f_full.cov()

        # Load the references into the v2 factories and read them back.
        f_surr2.cov_update(cov_surr, surrogate=True)
        f_full2.cov_update(cov_full, surrogate=False)

        cov_surr2 = f_surr2.cov(surrogate=True)
        cov_full2 = f_full2.cov(surrogate=False)
        surr2_ws = list(f_surr2.weights())
        full2_ws = list(f_full2.weights())

        # strict=True: a differing number of weights between the two v2
        # factories must fail rather than be silently truncated by zip.
        for surr_w, full_w in zip(surr2_ws, full2_ws, strict=True):
            np.testing.assert_array_almost_equal(surr_w, full_w)

        # TODO(bla): how to test them properly?
        # for surr_w, full_w in zip(surr_ws, full_ws):
        #    np.testing.assert_array_almost_equal(surr_w, full_w)
        # for surr_w, surr2_w in zip(surr_ws, surr2_ws):
        #    np.testing.assert_array_almost_equal(surr_w, surr2_w)

        np.testing.assert_array_almost_equal(cov_surr, cov_surr2)
        np.testing.assert_array_almost_equal(cov_full, cov_full2)


# @pytest.mark.parametrize(
#     "hid_group, inout_group, same",
#     gen_parametrized_variables(),
# )
# @torch.no_grad
# def test_cov_matvec(hid_group, inout_group, same):
#     generator = torch.Generator().manual_seed(2)
#     config = ModelConfig()
#     model = MLP(**asdict(config))

#     with torch.inference_mode():
#         # groups = dict(
#         #    group_config(model, hid_group=hid_group, inout_group=inout_group, same=same)
#         # )
#         params = list(model.parameters())
#         named_params = list(model.named_parameters())
#         # group_list = [groups[k] for k, _ in named_params]

#         groups2, dims = group_config_v2(
#             model,
#             hid_group=hid_group.__qualname__,
#             inout_group=inout_group.__qualname__,
#             same=same,
#         )
#         groups2 = dict(groups2)
#         group2_list = [groups2[k] for k, _ in named_params]

#         block_only = False
#         full_f = CovFactory2(
#             group2_list, dims, surrogate=False, block_diag_only=block_only
#         )
#         surr_f = CovFactory2(
#             group2_list, dims, surrogate=True, block_diag_only=block_only
#         )

#         full_f.outer_update(params)
#         surr_f.outer_update(params)

#         # rv = [
#         #     torch.normal(mean=0, std=10, size=p.shape, generator=generator)
#         #     for p in params
#         # ]

#         rv = [torch.ones_like(p) for p in params]

#         rv_flat = torch.concat([v.flatten() for v in rv])
#         full_cov = full_f.cov()
#         full_cov_rv_flat = full_cov @ rv_flat
#         surr_cov_rv = surr_f.matvec(rv)
#         surr_cov_rv_flat = torch.concat([v.flatten() for v in surr_cov_rv])
#         np.testing.assert_array_almost_equal(full_cov_rv_flat, surr_cov_rv_flat)


# TODO: test_cov_cov_transpose (disabled below) fails for S with same=True
# @pytest.mark.parametrize(
#     "hid_group, inout_group, same",
#     gen_parametrized_variables(),
# )
# def test_cov_cov_transpose(hid_group, inout_group, same):
#     # generator = torch.Generator().manual_seed(2)
#     config = ModelConfig()
#     model = MLP(**asdict(config))

#     with torch.inference_mode():
#         args = dict(
#             hid_group=hid_group,
#             inout_group=inout_group,
#             same=same,
#         )
#         groups = dict(group_config(model, **args))
#         groups_tr = dict(group_config(model, transpose=True, **args))

#         params = list(model.parameters())
#         params_tr = [p.T for p in params]

#         named_params = list(model.named_parameters())
#         group_list = [groups[k] for k, _ in named_params]
#         group_tr_list = [groups_tr[k] for k, _ in named_params]

#         factory = FactorGrid(group_list)
#         factory_tr = FactorGrid(group_tr_list)

#         cov_factors = factory.cov_factors_from_vectors(params)
#         cov_factors_tr = factory_tr.cov_factors_from_vectors(params_tr)

#         for f, f_tr in zip(cov_factors, cov_factors_tr):
#             if f.weights.ndim > 2:
#                 np.testing.assert_array_almost_equal(f.weights, f_tr.weights.mT)
#             else:
#                 np.testing.assert_array_almost_equal(f.weights, f_tr.weights)


if __name__ == "__main__":
    # Run only this module when the file is executed directly, and exit with
    # pytest's return code so a failing run is visible to the calling shell.
    raise SystemExit(pytest.main([__file__, "-q", "-r", "f"]))
