import itertools
from dataclasses import dataclass, asdict
import pytest
import torch
import numpy as np
from symo.experiments.mlp_groups import group_config
from symo.experiments.models import MLP
from symo.factory import CovFactory, MeanFactory

from symo.factor import O, B, S, I


@dataclass(frozen=True)
class ModelConfig:
    input_dim: int = 11
    hidden_dims: tuple[int, ...] = (7,) * 1
    output_dim: int = 5

    skip_every: int | None = None
    use_bias: bool = True
    use_bias_last: bool = True
    use_bias: bool = False
    use_bias_last: bool = False
    activation: str = "tanh"


def gen_parametrized_variables():
    """Build the ``(hid_group, inout_group, same)`` grid for the tests below.

    Every pairing of the groups S, I, O and B (hidden group first, in/out
    group second) is combined with ``same`` in (True, False). That gives
    4 * 4 * 2 = 32 tuples, ordered as ``itertools.product`` would yield them.

    ``filter_out_same_s`` can drop the ``hid_group == S and same`` cases;
    it is currently not applied.
    """
    candidates = (S, I, O, B)
    return [
        (hid, inout, flag)
        for hid in candidates
        for inout in candidates
        for flag in (True, False)
    ]


def filter_out_same_s(values):
    """Return ``values`` without the entries whose hidden group is S with ``same`` set.

    Each surviving ``(hid, inout, same)`` entry is returned as a 3-element list.
    """
    kept = []
    for hid, inout, flag in values:
        if hid == S and flag:
            continue  # drop hid_group=S combined with same=True
        kept.append([hid, inout, flag])
    return kept


@pytest.mark.parametrize(
    "hid_group, inout_group, same",
    gen_parametrized_variables(),
)
@torch.no_grad
def test_mean_construction(hid_group, inout_group, same):
    """Re-averaging already-averaged parameters must not change MeanFactory weights.

    The factory averages the model parameters once and its weights are
    snapshotted. It then averages its own output again, and the second
    snapshot must match the first.
    """
    # Deterministic MLP init so any tolerance failure is reproducible.
    torch.manual_seed(0)
    config = ModelConfig()
    model = MLP(**asdict(config))

    with torch.inference_mode():
        groups = dict(
            group_config(model, hid_group=hid_group, inout_group=inout_group, same=same)
        )
        # group_config is keyed by parameter name; derive both lists from the
        # same named_parameters() pass so groups and tensors line up by index.
        named_params = list(model.named_parameters())
        params = [p for _, p in named_params]
        group_list = [groups[name] for name, _ in named_params]
        factory = MeanFactory(group_list)

        avg_params = factory.avg(params)
        # Materialize the snapshot now: if weights() is lazy, iterating it after
        # the second avg() would compare the updated weights with themselves.
        avg_weights1 = list(factory.weights(clone=True))
        # Second pass over the averaged parameters; only the resulting weights
        # are checked, so the return value is not needed.
        factory.avg(avg_params)
        avg_weights2 = list(factory.weights(clone=True))

        # strict=True: a differing number of weight tensors is itself a failure.
        for w1, w2 in zip(avg_weights1, avg_weights2, strict=True):
            np.testing.assert_array_almost_equal(w1, w2)


@pytest.mark.parametrize(
    "hid_group, inout_group, same",
    gen_parametrized_variables(),
)
@torch.no_grad
def test_cov_construction(hid_group, inout_group, same):
    """Surrogate and full CovFactory must agree, and refitting from their covariance must round-trip.

    The first pair of factories is fitted with ``outer_update(params)``. The
    second pair is fitted with ``cov_update`` on the covariance produced by
    the first pair. Weights and covariances are compared across all of them.
    """
    # Deterministic MLP init so any tolerance failure is reproducible.
    torch.manual_seed(0)
    config = ModelConfig()
    model = MLP(**asdict(config))

    with torch.inference_mode():
        groups = dict(
            group_config(model, hid_group=hid_group, inout_group=inout_group, same=same)
        )
        # Derive tensors and groups from one named_parameters() pass so they
        # line up index-by-index.
        named_params = list(model.named_parameters())
        params = [p for _, p in named_params]
        group_list = [groups[name] for name, _ in named_params]
        f_surr = CovFactory(group_list, surrogate=True)
        f_full = CovFactory(group_list, surrogate=False)

        f_surr2 = CovFactory(group_list, surrogate=True)
        f_full2 = CovFactory(group_list, surrogate=False)

        # First pair: fitted from the parameters themselves.
        f_surr.outer_update(params)
        f_full.outer_update(params)

        cov_surr = f_surr.cov()
        cov_full = f_full.cov()

        # Second pair: fitted from the covariance the first pair produced.
        f_surr2.cov_update(cov_surr)
        f_full2.cov_update(cov_full)

        cov_surr2 = f_surr2.cov()
        cov_full2 = f_full2.cov()

        surr_ws = list(f_surr.weights())
        full_ws = list(f_full.weights())

        surr2_ws = list(f_surr2.weights())
        full2_ws = list(f_full2.weights())

        # strict=True: a differing number of weight tensors is itself a failure.
        # surrogate vs full, from parameters
        for surr_w, full_w in zip(surr_ws, full_ws, strict=True):
            np.testing.assert_array_almost_equal(surr_w, full_w)

        # surrogate fitted from parameters vs surrogate refitted from covariance
        for surr_w, surr2_w in zip(surr_ws, surr2_ws, strict=True):
            np.testing.assert_array_almost_equal(surr_w, surr2_w)

        # surrogate vs full, refitted from covariance
        for surr_w, full_w in zip(surr2_ws, full2_ws, strict=True):
            np.testing.assert_array_almost_equal(surr_w, full_w)

        np.testing.assert_array_almost_equal(cov_surr, cov_surr2)
        np.testing.assert_array_almost_equal(cov_full, cov_full2)


@pytest.mark.parametrize(
    "hid_group, inout_group, same",
    gen_parametrized_variables(),
)
@torch.no_grad
def test_cov_matvec(hid_group, inout_group, same):
    """The surrogate CovFactory's ``matvec`` must equal the full covariance times a vector.

    Both factories are fitted on the same parameters. The full factory's
    ``cov()`` is applied to the flattened probe vector, and the result is
    compared with the flattened output of the surrogate's ``matvec``.
    """
    # Deterministic MLP init so any tolerance failure is reproducible.
    torch.manual_seed(0)
    config = ModelConfig()
    model = MLP(**asdict(config))

    with torch.inference_mode():
        groups = dict(
            group_config(model, hid_group=hid_group, inout_group=inout_group, same=same)
        )
        # Derive tensors and groups from one named_parameters() pass so they
        # line up index-by-index.
        named_params = list(model.named_parameters())
        params = [p for _, p in named_params]
        group_list = [groups[name] for name, _ in named_params]

        block_only = False
        full_f = CovFactory(group_list, surrogate=False, block_diag_only=block_only)
        surr_f = CovFactory(group_list, surrogate=True, block_diag_only=block_only)

        full_f.outer_update(params)
        surr_f.outer_update(params)

        # NOTE(review): with an all-ones probe the check only compares row sums
        # of the covariance. It cannot detect a matvec that reads its input in
        # the wrong layout (e.g. transposed). Consider a seeded random probe,
        # e.g. torch.randn(p.shape, generator=torch.Generator().manual_seed(2)).
        rv = [torch.ones_like(p) for p in params]

        # Flatten in parameter order; full_cov is assumed to be a dense (P, P)
        # matrix over the concatenated parameters.
        rv_flat = torch.concat([v.flatten() for v in rv])
        full_cov = full_f.cov()
        full_cov_rv_flat = full_cov @ rv_flat
        surr_cov_rv = surr_f.matvec(rv)
        surr_cov_rv_flat = torch.concat([v.flatten() for v in surr_cov_rv])
        np.testing.assert_array_almost_equal(full_cov_rv_flat, surr_cov_rv_flat)


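# TODO: the disabled test below fails for hid_group=S with same=True. It also
# still uses `FactorGrid`, which is not imported in this file, so it must be
# ported to the factory API used above before it is re-enabled.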
# @pytest.mark.parametrize(
#     "hid_group, inout_group, same",
#     gen_parametrized_variables(),
# )
# def test_cov_cov_transpose(hid_group, inout_group, same):
#     # generator = torch.Generator().manual_seed(2)
#     config = ModelConfig()
#     model = MLP(**asdict(config))

#     with torch.inference_mode():
#         args = dict(
#             hid_group=hid_group,
#             inout_group=inout_group,
#             same=same,
#         )
#         groups = dict(group_config(model, **args))
#         groups_tr = dict(group_config(model, transpose=True, **args))

#         params = list(model.parameters())
#         params_tr = [p.T for p in params]

#         named_params = list(model.named_parameters())
#         group_list = [groups[k] for k, _ in named_params]
#         group_tr_list = [groups_tr[k] for k, _ in named_params]

#         factory = FactorGrid(group_list)
#         factory_tr = FactorGrid(group_tr_list)

#         cov_factors = factory.cov_factors_from_vectors(params)
#         cov_factors_tr = factory_tr.cov_factors_from_vectors(params_tr)

#         for f, f_tr in zip(cov_factors, cov_factors_tr):
#             if f.weights.ndim > 2:
#                 np.testing.assert_array_almost_equal(f.weights, f_tr.weights.mT)
#             else:
#                 np.testing.assert_array_almost_equal(f.weights, f_tr.weights)


if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script exits
    # non-zero when any test fails (pytest.main returns the code, it does not exit).
    raise SystemExit(pytest.main([__file__, "-v"]))
