import pytest
import torch
from torch.nn import Parameter

from spastra.algebra import BlockGroupSpec, GroupCoupler
from spastra.linalg import kth_largest


@pytest.fixture
def param_4d():
    """Return a float32 ``Parameter`` of shape (4, 8, 2, 2) holding 0..127.

    Values are laid out in row-major order, so every element is distinct and
    only the very first entry is zero.
    """
    n_elems = 4 * 8 * 2 * 2
    ramp = torch.arange(n_elems, dtype=torch.float32)
    return Parameter(ramp.reshape(4, 8, 2, 2))


@pytest.fixture
def spec_u(param_4d):
    """Build the ``"U"`` ``BlockGroupSpec`` over ``param_4d``.

    Blocks are (2, 4, 1, 1) elements; groups span (2, 1, 1, 2) blocks.
    """
    block_shape = (2, 4, 1, 1)
    group_shape = (2, 1, 1, 2)
    return BlockGroupSpec(
        param_4d, block_size=block_shape, group_size=group_shape, name="U"
    )


class TestBlockGroupSpec:
    """Tests for ``BlockGroupSpec``: construction, views, norms, thresholding."""

    def test_init_valid(self, param_4d):
        """Test valid BlockGroupSpec initialization."""
        spec = BlockGroupSpec(param_4d, block_size=(2, 4, 1, 1), group_size=())
        assert spec.shape == (4, 8, 2, 2)
        assert spec.param.ndim == 4
        assert spec.block_size == (2, 4, 1, 1)
        assert spec.block_grid_size == (2, 2, 2, 2)
        # An empty group_size yields a single group spanning the block grid.
        assert spec.group_size == (2, 2, 2, 2)
        assert spec.group_grid_size == (1,)

    def test_init_invalid_block(self, param_4d):
        """Test initialization with invalid block size."""
        with pytest.raises(
            ValueError,
            match=r"axis 0: size 4 not divisible by block_size\[0\]=3",
        ):
            BlockGroupSpec(param_4d, block_size=(3, 8, 2, 2), group_size=())

    def test_init_invalid_group(self, param_4d):
        """Test initialization with invalid group size."""
        with pytest.raises(
            ValueError,
            match=r"axis 0: block_grid\[0\]=2 not divisible by group_size\[0\]=3",
        ):
            BlockGroupSpec(param_4d, block_size=(2, 4, 2, 2), group_size=(3, 1))

    def test_init_group_neg_one(self, param_4d):
        """Test initialization with -1 in group, making it full size."""
        spec = BlockGroupSpec(
            param_4d, block_size=(2, 4, 1, 1), group_size=(-1, -1, -1, -1)
        )
        assert spec.group_size == spec.block_grid_size

    def test_view_conversions(self, spec_u):
        """Test the tensor view conversion methods."""
        blocked = spec_u.element_to_block(spec_u.param)
        assert blocked.shape == (2, 2, 2, 4, 2, 2)

        # Unsqueezed layout interleaves (grid, in-block) sizes per axis:
        # grid (2, 2, 2, 2) with blocks (2, 4, 1, 1).
        blocked = spec_u.element_to_block(spec_u.param, squeeze=False)
        assert blocked.shape == (2, 2, 2, 4, 2, 1, 2, 1)

        # Same interleaving one level up: group grid (1, 2, 2, 1) with
        # groups of (2, 1, 1, 2) blocks.
        block_tensor = torch.randn(spec_u.block_grid_size)
        grouped = spec_u.block_to_group(block_tensor, squeeze=False)
        assert grouped.shape == (1, 2, 2, 1, 2, 1, 1, 2)

        group_tensor = torch.randn(spec_u.group_grid_size)
        ungrouped = spec_u.group_to_block(group_tensor)
        assert ungrouped.shape == spec_u.block_grid_size

        unblocked = spec_u.block_to_element(ungrouped)
        assert unblocked.shape == spec_u.shape

    def test_block_norms(self, spec_u: BlockGroupSpec):
        """Test calculation of block norms."""
        norms = spec_u.block_norms(spec_u.param)
        assert norms.shape == spec_u.block_grid_size

    def test_hard_threshold(self, param_4d):
        """Groups under a huge threshold are zeroed; the others keep mass."""
        spec = BlockGroupSpec(
            Parameter(param_4d.detach().clone()),
            block_size=(2, 4, 1, 1),
            group_size=(2, 1, 1, 2),
        )
        # group_grid_size is (1, 2, 2, 1), so [0, 0] selects every group with
        # grid index 0 on axes 0 and 1 (all grid positions on axes 2 and 3).
        thresholds = torch.zeros(spec.group_grid_size)
        thresholds[0, 0] = 1e6  # A very high threshold for that slab
        spec.hard_threshold(group_thresholds=thresholds)

        block_norms = spec.block_norms(spec.param.data)
        # squeeze=False keeps the interleaved layout
        # (g0, s0, g1, s1, g2, s2, g3, s3) = (1, 2, 2, 1, 2, 1, 1, 2),
        # so axis 2 is the group-grid index along tensor axis 1.
        group_norms = spec.block_to_group(block_norms, squeeze=False)
        assert torch.all(group_norms[:, :, 0] == 0)
        # Groups thresholded at zero must survive.
        assert torch.any(group_norms[:, :, 1] != 0)

    def test_soft_threshold(self, param_4d):
        """Soft thresholding with positive lambdas shrinks the parameter."""
        spec = BlockGroupSpec(
            Parameter(param_4d.detach().clone()),
            block_size=(2, 4, 1, 1),
            group_size=(),
        )
        initial_norm = torch.linalg.vector_norm(spec.param.data)
        lambdas = torch.ones(spec.group_grid_size)
        spec.soft_threshold(lambdas, eta_t=0.1)
        final_norm = torch.linalg.vector_norm(spec.param.data)
        assert final_norm < initial_norm

    def test_soft_threshold_uniform_blocks_exact_factor(self):
        """Uniform 2x2 blocks of ones shrink by a predictable factor."""
        W = torch.ones(4, 4)
        spec = BlockGroupSpec(
            Parameter(W.clone()), block_size=(2, 2), group_size=()
        )
        # Each 2x2 block has L2 norm = 2; scaled by sqrt(4) = 2 -> score = 1.
        c = 0.3
        # Base lambda 1.0, passed pre-multiplied by block_numel.
        lambdas = torch.full(spec.group_grid_size, 1.0) * spec.block_numel
        spec.soft_threshold(lambdas, eta_t=c)
        # Expect every element scaled by (1 - 1.0 * c / score) = (1 - c).
        assert torch.allclose(spec.param.data, torch.full_like(W, 1 - c))

    def test_soft_threshold_zero_and_nonzero_blocks(self):
        """Zero blocks stay zero; non-zero blocks shrink appropriately."""
        W = torch.zeros(4, 4)
        # Only the bottom-right 2x2 block is non-zero (all 2s).
        W[2:, 2:] = 2.0
        spec = BlockGroupSpec(
            Parameter(W.clone()), block_size=(2, 2), group_size=()
        )
        # Scaled scores: 0 for the three zero blocks; the block of 2s has
        # L2 = 4, divided by sqrt(4) = 2 => score 2.
        # Base lambda 2.0 (passed pre-multiplied by block_numel) with
        # eta_t = 0.5 gives a relative shrink of 2.0 * 0.5 / 2 = 0.5.
        lambdas = torch.tensor([2.0]) * spec.block_numel
        spec.soft_threshold(lambdas, eta_t=0.5)
        # bview is laid out as (grid0, block0, grid1, block1).
        bview = spec.element_to_block(spec.param.data)
        # Zero blocks remain zero.
        assert torch.all(bview[0, :, 0, :] == 0)
        assert torch.all(bview[0, :, 1, :] == 0)
        assert torch.all(bview[1, :, 0, :] == 0)
        # Non-zero block halved: 2.0 -> 1.0.
        assert torch.allclose(spec.param.data[2:, 2:], torch.ones((2, 2)))

    def test_soft_threshold_strong_lambda_zeroes(self):
        """Lambda > score should zero the entire block."""
        W = torch.full((4, 4), 3.0)
        spec = BlockGroupSpec(
            Parameter(W.clone()), block_size=(2, 2), group_size=()
        )
        # Score per block: L2 = sqrt(4 * 9) = 6; scaled by 2 => 3.
        # Base lambda 3.1 (before the block_numel factor) exceeds the score.
        lambdas = torch.full(spec.group_grid_size, 3.1) * spec.block_numel
        spec.soft_threshold(lambdas, eta_t=1.0)
        assert torch.count_nonzero(spec.param.data) == 0

    def test_block_norms_ord1_and_scaling(self):
        """Check ord=1 vs ord=2 and the scaling behavior."""
        W = torch.ones(4, 4)
        spec = BlockGroupSpec(
            Parameter(W.clone()), block_size=(2, 2), group_size=()
        )
        # ord=2 scaled: score=1 everywhere
        n2 = spec.block_norms(spec.param.data, ord=2, scale=True)
        assert torch.allclose(n2, torch.ones_like(n2))
        # ord=1 without scale: sum over 2x2 block => 4
        n1_raw = spec.block_norms(spec.param.data, ord=1, scale=False)
        assert torch.allclose(n1_raw, torch.full_like(n1_raw, 4.0))
        # ord=1 with scale: divides by sqrt(4)=2 => 2
        n1_scaled = spec.block_norms(spec.param.data, ord=1, scale=True)
        assert torch.allclose(n1_scaled, torch.full_like(n1_scaled, 2.0))


# Fixtures for GroupCoupler
@pytest.fixture
def param_u():
    """Random (4, 8, 2, 2) parameter for the "U" coupler spec.

    Drawn from a dedicated fixed-seed generator so the coupler tests are
    reproducible and neither depend on nor perturb the global RNG state.
    """
    gen = torch.Generator().manual_seed(0)
    return Parameter(torch.randn(4, 8, 2, 2, generator=gen))


@pytest.fixture
def param_v():
    """Random (8, 16, 2, 2) parameter for the "V" coupler spec.

    Drawn from a dedicated fixed-seed generator so the coupler tests are
    reproducible and neither depend on nor perturb the global RNG state.
    """
    gen = torch.Generator().manual_seed(1)
    return Parameter(torch.randn(8, 16, 2, 2, generator=gen))


@pytest.fixture
def coupler_spec_u(param_u):
    """``BlockGroupSpec`` named "U" over ``param_u`` with (2, 2, 1, 1) blocks.

    ``-1`` in ``group_size`` requests the full block-grid extent on the last
    two axes.
    """
    spec_kwargs = dict(block_size=(2, 2, 1, 1), group_size=(1, 1, -1, -1))
    return BlockGroupSpec(param_u, name="U", **spec_kwargs)


@pytest.fixture
def coupler_spec_v(param_v):
    """``BlockGroupSpec`` named "V" over ``param_v`` with (2, 8, 1, 1) blocks.

    ``-1`` in ``group_size`` requests the full block-grid extent on the last
    two axes.
    """
    spec_kwargs = dict(block_size=(2, 8, 1, 1), group_size=(1, 1, -1, -1))
    return BlockGroupSpec(param_v, name="V", **spec_kwargs)


@pytest.fixture
def coupler(coupler_spec_u, coupler_spec_v):
    """A GroupCoupler tying the U and V specs together.

    The order tuples pair U's group-grid axes (0, 1) with V's (1, 0).
    """
    specs = [coupler_spec_u, coupler_spec_v]
    axis_orders = [(0, 1), (1, 0)]
    return GroupCoupler(specs, orders=axis_orders)


# Tests for GroupCoupler
class TestGroupCoupler:
    """Tests for ``GroupCoupler`` built over the U/V coupler fixtures."""

    @staticmethod
    def _param_norms(coupler):
        """Return the global L2 norm of each coupled spec's parameter, in order."""
        return [torch.linalg.vector_norm(s.param.data) for s in coupler.specs]

    def test_init_valid(self, coupler_spec_u, coupler_spec_v):
        """Test valid GroupCoupler initialization."""
        coupler = GroupCoupler(
            [coupler_spec_u, coupler_spec_v],
            orders=[(0, 1), (1, 0)],
        )
        assert len(coupler.specs) == 2
        assert len(coupler.orders) == 2

    def test_init_invalid_order(self, coupler_spec_u, coupler_spec_v):
        """Test initialization with incompatible orders."""
        with pytest.raises(ValueError, match="Incompatible grouped shapes"):
            GroupCoupler(
                [coupler_spec_u, coupler_spec_v],
                orders=[(0, 1), (0, 1)],
            )

    def test_kth_largest(self, coupler):
        """Test the kth_largest method."""
        k = 2
        thresholds = coupler.kth_largest(
            {s: s.param.data for s in coupler.specs}, kappa=k
        )
        # Shape should be the reference spec's group grid permuted by its order.
        ref_spec = coupler.specs[0]
        ref_order = coupler.orders[0]
        expect_grid_size = tuple(ref_spec.group_grid_size[i] for i in ref_order)
        assert thresholds.shape == expect_grid_size

    def test_hard_threshold(self, coupler):
        """Test hard thresholding with the coupler."""
        initial_norms = self._param_norms(coupler)
        coupler.hard_threshold(kappa=1)
        final_norms = self._param_norms(coupler)

        # At least one of the coupled parameters must have lost mass.
        assert any(
            after < before for before, after in zip(initial_norms, final_norms)
        )

    def test_soft_threshold(self, coupler):
        """Test soft thresholding with the coupler."""
        initial_norms = self._param_norms(coupler)
        group_lambdas = coupler.kth_largest(
            {s: s.param.data for s in coupler.specs}, kappa=1
        )
        learning_rates = {s: 0.1 for s in coupler.specs}
        coupler.soft_threshold(group_lambdas, learning_rates)
        final_norms = self._param_norms(coupler)

        # Every coupled parameter must shrink.
        for before, after in zip(initial_norms, final_norms):
            assert after < before

    def test_soft_threshold_does_not_increase_any_block(self, coupler):
        """Per-block L2 norms should not increase after soft-thresholding."""
        before = [s.block_norms(s.param.data).clone() for s in coupler.specs]
        group_lambdas = coupler.kth_largest(
            {s: s.param.data for s in coupler.specs}, kappa=1
        )
        learning_rates = {s: 0.1 for s in coupler.specs}
        coupler.soft_threshold(group_lambdas, learning_rates)
        after = [s.block_norms(s.param.data).clone() for s in coupler.specs]
        for b, a in zip(before, after):
            # Small tolerance absorbs float round-off on untouched blocks.
            assert torch.all(a <= b + 1e-6)


class TestKthLargest:
    """Tests for the standalone ``kth_largest`` helper."""

    def test_kth_largest_global(self):
        """Without ``dim``, the k-th largest is taken over the whole tensor."""
        values = torch.tensor([3.0, 1.0, 4.0, 2.0])
        # Sorted descending: 4, 3, 2, 1.
        for k, expected in ((1, 4.0), (2, 3.0), (4, 1.0)):
            assert kth_largest(values, k).item() == expected

    def test_kth_largest_dimensional(self):
        """With ``dim=1``, the k-th largest is taken independently per row."""
        rows = torch.tensor([[1.0, 3.0, 2.0], [9.0, 7.0, 8.0]])
        second = kth_largest(rows, 2, dim=1)
        expected = torch.tensor([2.0, 8.0])  # 2nd largest of each row
        assert torch.allclose(second, expected)

    def test_kth_largest_errors(self):
        """Empty input and k beyond the element count raise ValueError."""
        empty = torch.tensor([])
        too_short = torch.tensor([1.0, 2.0])
        with pytest.raises(ValueError):
            kth_largest(empty, 1)
        with pytest.raises(ValueError):
            kth_largest(too_short, 3)
