import pytest
import torch
from torch.nn import Parameter
from torch.optim import SGD

from spastra.algebra import BlockGroupSpec
from spastra.algebra import GroupCoupler
from spastra.controllers import EMAController
from spastra.controllers import LambdaController
from spastra.controllers import AlphaController
from spastra.astra import SASTRA
from spastra.astra import IHTSparsifier


# Fixtures for Parameters
@pytest.fixture
def param1():
    """Return a fresh 10x20 ``Parameter`` of standard-normal samples."""
    values = torch.randn(10, 20)
    return Parameter(values)


@pytest.fixture
def param2():
    """Return a fresh 20x30 ``Parameter`` of standard-normal samples."""
    values = torch.randn(20, 30)
    return Parameter(values)


@pytest.fixture
def sp1(param1):
    """Wrap ``param1`` in a ``BlockGroupSpec`` with empty size tuples."""
    spec = BlockGroupSpec(param1, block_size=(), group_size=())
    return spec


@pytest.fixture
def sp2(param2):
    """Wrap ``param2`` in a ``BlockGroupSpec`` with empty size tuples."""
    spec = BlockGroupSpec(param2, block_size=(), group_size=())
    return spec


# Tests for EMAController
class TestController:
    """Exercise the EMA bookkeeping of ``EMAController``."""

    def test_update_single(self, param1, sp1):
        """Two unit-direction updates yield 0.1 and then 0.19."""
        ema = EMAController(rho=0.1)
        unit_step = torch.ones_like(param1)

        ema.update_single(sp1, unit_step)
        assert sp1 in ema._ema
        after_first = ema.get(sp1)
        assert torch.allclose(after_first, unit_step * 0.1)

        ema.update_single(sp1, unit_step)
        # Second update: ema = 0.1 * 0.9 + 0.1 = 0.19
        expected = torch.full_like(param1, 0.19)
        assert torch.allclose(ema.get(sp1), expected)

    def test_update_all(self, param1, param2, sp1, sp2):
        """A batched update stores 0.1-scaled directions for every spec."""
        ema = EMAController(rho=0.1)
        pairs = ((sp1, param1), (sp2, param2))
        ema.update_all({spec: torch.ones_like(p) for spec, p in pairs})

        # Every spec must be registered before its value is inspected.
        for spec, _ in pairs:
            assert spec in ema._ema
        for spec, p in pairs:
            assert torch.allclose(ema._ema[spec], torch.ones_like(p) * 0.1)


# Tests for LambdaController
class TestLambdaController:
    """Check momentum and step-counter updates of ``LambdaController``."""

    @pytest.fixture
    def lambda_controller(self):
        """Return a CPU ``LambdaController`` with ``beta=0.5``, ``t0=1``."""
        cpu = torch.device("cpu")
        return LambdaController(device=cpu, beta=0.5, t0=1)

    def test_update_single(self, lambda_controller: LambdaController, param1):
        """Momentum goes 0.25 then 0.375; the step counter goes 2 then 3."""
        half = torch.ones_like(param1) * 0.5

        lambda_controller.update_single(param1, half)
        assert param1 in lambda_controller._momentums
        # beta = 0.5, t0 = 1, t starts at 1.
        # lambda = 0 + 0.5 * 0.5 = 0.25
        after_first = lambda_controller._momentums[param1]
        assert torch.allclose(after_first, torch.full_like(param1, 0.25))
        assert lambda_controller._t[param1] == 2

        lambda_controller.update_single(param1, half)
        # lambda = 0.25 * 0.5 + 0.5 * 0.5 = 0.125 + 0.25 = 0.375
        after_second = lambda_controller._momentums[param1]
        assert torch.allclose(after_second, torch.full_like(param1, 0.375))
        assert lambda_controller._t[param1] == 3

    def test_reset_time(self, lambda_controller, param1):
        """``reset_time`` puts the per-parameter step counter back to 0."""
        unit_psi = torch.ones_like(param1)
        lambda_controller.update_single(param1, unit_psi)
        assert lambda_controller._t[param1] > 1

        lambda_controller.reset_time()
        assert lambda_controller._t[param1] == 0


# Fixtures for Sparsifiers
@pytest.fixture
def spec1():
    """Return a spec over a random 4x8 parameter: 2x2 blocks, (1, 2) groups."""
    weight = Parameter(torch.randn(4, 8))
    spec = BlockGroupSpec(weight, block_size=(2, 2), group_size=(1, 2))
    return spec


@pytest.fixture
def spec2():
    """Return a spec over a random 8x4 parameter: 2x2 blocks, (2, 1) groups."""
    weight = Parameter(torch.randn(8, 4))
    spec = BlockGroupSpec(weight, block_size=(2, 2), group_size=(2, 1))
    return spec


@pytest.fixture
def coupler(spec1, spec2):
    """Build a ``GroupCoupler`` over both specs with orders (0, 1), (1, 0)."""
    specs = [spec1, spec2]
    orders = [(0, 1), (1, 0)]
    return GroupCoupler(specs, orders=orders)


# Tests for IHTSparsifier
class TestIHTSparsifier:
    """Smoke test for ``IHTSparsifier.step`` on a coupled pair of specs."""

    def test_step(self, coupler):
        """One step with ``kappa=1`` strictly reduces at least one norm."""
        sparsifier = IHTSparsifier(groups=[coupler], kappa=1)
        initial_norms = [p.data.norm() for p in coupler.params]
        sparsifier.step()
        final_norms = [p.data.norm() for p in coupler.params]
        # Hard thresholding should shrink the parameters. The norm lists are
        # paired directly instead of being indexed by position.
        assert any(
            after < before
            for before, after in zip(initial_norms, final_norms)
        )


# Tests for SASTRA
class TestASTRASparsifier:
    """Tests for ``SASTRA`` driven by a momentum ``SGD`` optimizer."""

    @pytest.fixture
    def astra_elements(self, coupler: GroupCoupler):
        """Build a ``SASTRA`` sparsifier attached to a momentum SGD optimizer.

        Returns:
            A ``(sparsifier, optimizer)`` tuple; the optimizer is built over
            the parameters of the coupler's specs.
        """
        params = [s.param for s in coupler.specs]
        opt = SGD(params, lr=0.1, momentum=0.9)
        num_groups = coupler.num_blocks
        sp = SASTRA(
            groups=[coupler],
            lambdas=LambdaController(device=torch.device("cpu")),
            ema_grad=EMAController(rho=0.05),
            alphas=AlphaController(default=0.01),
            sparsity={coupler: (num_groups - 1) / num_groups},
            device=torch.device("cpu"),
        )
        sp.attach_optimizer(opt)
        return sp, opt

    def test_step(self, astra_elements, coupler):
        """A sparsifier step never grows a norm and shrinks at least one."""
        params = [s.param for s in coupler.specs]
        astra_sparsifier, optimizer = astra_elements
        # Simulate a gradient update followed by an optimizer step so the
        # momentum buffers are populated before the sparsifier runs.
        for p in params:
            p.grad = torch.randn_like(p)
        optimizer.step(lambda: torch.ones(1))
        pre_prox_norms = [p.data.norm() for p in params]
        astra_sparsifier.step()

        final_norms = [p.data.norm() for p in params]

        # Soft thresholding should not increase norms and should shrink at
        # least one.
        norm_pairs = list(zip(pre_prox_norms, final_norms))
        for before, after in norm_pairs:
            assert after <= before
        assert any(after < before for before, after in norm_pairs)

    def test_checkpoint_gradients(self, astra_elements, coupler):
        """``gather_gradients`` returns each parameter's all-ones gradient."""
        astra_sparsifier, optimizer = astra_elements
        params = [s.param for s in coupler.specs]
        for p in params:
            p.grad = torch.ones_like(p)
        grad_dict = astra_sparsifier.gather_gradients(optimizer)
        assert len(grad_dict) == len(params)
        for p in params:
            assert p in grad_dict
            assert torch.all(grad_dict[p] == 1.0)

    def test_refresh_information(self, astra_elements, coupler):
        """``gather_info`` reports lr 0.1 per param and any momentum buffer."""
        astra_sparsifier, optimizer = astra_elements
        params = [s.param for s in coupler.specs]
        # NOTE(review): no gradients are assigned before this step, so SGD
        # may not create momentum buffers here; the guard below keeps the
        # direction check conditional. Confirm whether grads should be set
        # so that the momentum comparison is actually exercised.
        optimizer.step(lambda: torch.ones(1))

        new_dirs, lrs = astra_sparsifier.gather_info({}, optimizer)
        assert len(lrs) == len(params)
        for p in params:
            assert p in lrs
            assert torch.all(
                lrs[p] == torch.tensor(0.1, dtype=p.dtype, device=p.device)
            )
            if "momentum_buffer" in optimizer.state[p]:
                assert torch.all(
                    new_dirs[p] == optimizer.state[p]["momentum_buffer"]
                )

    def test_lambda_controller_constant_mode_with_momentum_sgd(self, coupler):
        """Constant-mode lambda momentum matches the expected psi after a step.

        Uses momentum SGD, an ``EMAController`` with ``rho=1.0`` and an
        ``AlphaController`` defaulting to 0.0, then compares the coupler's
        lambda momentum against a ``kth_largest``-based expectation.
        """
        # Make param magnitudes large so param_z doesn't cap psi (min won't
        # change psi).
        for s in coupler.specs:
            s.param.data.fill_(1e9)

        optimizer = SGD(coupler.params, lr=0.1, momentum=0.9)

        lambdas = LambdaController(
            device=torch.device("cpu"), beta=0.5, mode="constant"
        )
        # rho=1.0 so the EMA equals the provided direction.
        ema_grad = EMAController(rho=1.0)
        num_groups = coupler.num_blocks

        sp = SASTRA(
            groups=[coupler],
            lambdas=lambdas,
            ema_grad=ema_grad,
            alphas=AlphaController(default=0.0),
            sparsity={coupler: (num_groups - 1) / num_groups},
            device=torch.device("cpu"),
        )

        sp.attach_optimizer(optimizer)

        # Set grads to ones everywhere and run a step to populate momentum
        # and cache info.
        for p in coupler.params:
            p.grad = torch.ones_like(p)
        optimizer.step(lambda: torch.ones(1))

        # Now run the sparsifier step which will update lambda_momentums.
        sp.step()

        # The controller runs in constant mode with beta=0.5. The expected
        # momentum is kth_largest over tensors filled with 0.5 (then +1e-6
        # before the lambda update).
        # NOTE(review): the 0.5 fill presumably mirrors beta - confirm.
        half_filled = {s: torch.ones_like(s.param) * 0.5 for s in coupler.specs}
        expected_psi = coupler.kth_largest(half_filled, kappa=1) + 1e-6

        assert coupler in sp.lambdas._momentums
        lam = sp.lambdas._momentums[coupler]
        assert torch.allclose(lam, expected_psi)
