import re
from types import SimpleNamespace
from pathlib import Path

import torch
import pytest
from omegaconf import OmegaConf

from spastra.configs import (
    get_sparsity_groups,
    get_sparsity_specs,
    hash_config,
    get_optimizer,
    get_lr_scheduler,
)
from spastra.models.resnets import get_resnet
from torch.optim import SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR


class _MissingKeyError(KeyError, AttributeError):
    """Raised by ``Namespace.__getitem__`` for an absent key.

    Subclasses both ``KeyError`` (what mapping-style callers expect from
    ``cfg[key]``) and ``AttributeError`` (what the previous ``getattr``-based
    implementation raised), so handlers written for either keep working.
    """


class Namespace(SimpleNamespace):
    """``SimpleNamespace`` with a small read-only, dict-like interface.

    Used to build lightweight config stand-ins for the factories under test.
    Values can be read as attributes (``cfg.lr``) or through a mapping-style
    interface (``cfg["lr"]``, ``cfg.items()``, ``"lr" in cfg``, ``cfg.get("lr")``).
    The mapping views (``keys``/``items``/iteration/``in``/``get``) expose only
    public attributes, i.e. names that do not start with ``_``.
    """

    def __getitem__(self, k):
        """Return attribute ``k``; raise a KeyError/AttributeError if absent."""
        try:
            return getattr(self, k)
        except AttributeError:
            raise _MissingKeyError(k) from None

    def keys(self):
        """Return the public attribute names, in insertion order."""
        return [k for k in self.__dict__ if not k.startswith("_")]

    def items(self):
        """Return ``(name, value)`` pairs for the public attributes."""
        return [(k, self.__dict__[k]) for k in self.keys()]

    def get(self, k, default=None):
        """Return public attribute ``k``, or ``default`` if it is absent."""
        return self.__dict__[k] if k in self else default

    def __contains__(self, k):
        # Without this, ``in`` falls back to the legacy sequence protocol and
        # calls ``self[0]``, which raises TypeError (non-string attribute name).
        return k in self.keys()

    def __iter__(self):
        return iter(self.keys())


@pytest.fixture(scope="session")
def unstructured_cfg_dict():
    """Load ``config/sparsifiers/unstructured.yaml`` as a plain, resolved container.

    The YAML path is taken relative to this test file, starting from the parent
    of the directory that contains it. The fixture is session-scoped, so the
    file is read once and the same object is handed to every test that asks
    for it; tests should therefore not mutate it.
    """
    base_dir = Path(__file__).resolve().parents[1]
    yaml_file = base_dir.joinpath("config", "sparsifiers", "unstructured.yaml")
    loaded = OmegaConf.load(str(yaml_file))
    # resolve=True expands any interpolations before converting to builtins.
    return OmegaConf.to_container(loaded, resolve=True)


@pytest.fixture(scope="function")
def resnet32():
    """Build a fresh ``get_resnet(32, num_classes=10)`` model for each test."""
    depth, n_classes = 32, 10
    return get_resnet(depth, num_classes=n_classes)


@pytest.fixture(scope="function")
def model_params(resnet32):
    """Materialize the parameters of the per-test ``resnet32`` model as a list.

    A list (rather than the one-shot iterator returned by ``parameters()``)
    lets a single test pass the same parameters to several optimizers.
    """
    params = list(resnet32.parameters())
    return params


# --------------------
# Tests
# --------------------


class TestHashConfig:
    """Tests for ``hash_config``."""

    def test_changes_when_cfg_changes(self):
        """Configs differing in a single value must produce different str hashes."""
        base = {"alpha": 0.1, "foo": "bar"}
        variant = {**base, "alpha": 0.2}
        digest_a, digest_b = (
            hash_config(OmegaConf.create(d)) for d in (base, variant)
        )
        assert all(isinstance(d, str) for d in (digest_a, digest_b))
        assert digest_a != digest_b


class TestOptimizers:
    """Tests for ``get_optimizer``."""

    def test_builds_sgd_and_adamw(self, model_params):
        """``get_optimizer`` builds SGD/AdamW and forwards every configured value.

        Each hyperparameter set in the config is asserted on the resulting
        param group. This includes the non-default ``momentum`` and ``betas``,
        so a value silently dropped by the factory would fail the test.
        """
        sgd_cfg = Namespace(
            name="sgd", lr=0.05, momentum=0.9, weight_decay=1e-4
        )
        opt_sgd = get_optimizer(sgd_cfg, model_params)
        assert isinstance(opt_sgd, SGD)
        sgd_group = opt_sgd.param_groups[0]
        assert sgd_group["lr"] == pytest.approx(0.05)
        assert sgd_group["momentum"] == pytest.approx(0.9)
        assert sgd_group["weight_decay"] == pytest.approx(1e-4)

        adamw_cfg = Namespace(
            name="adamw", lr=1e-3, weight_decay=1e-2, betas=(0.8, 0.88)
        )
        opt_adamw = get_optimizer(adamw_cfg, model_params)
        assert isinstance(opt_adamw, AdamW)
        adamw_group = opt_adamw.param_groups[0]
        assert adamw_group["lr"] == pytest.approx(1e-3)
        assert adamw_group["weight_decay"] == pytest.approx(1e-2)
        # tuple() normalizes the value in case the factory forwards it as a list.
        assert tuple(adamw_group["betas"]) == pytest.approx((0.8, 0.88))


class TestSchedulers:
    """Tests for ``get_lr_scheduler``."""

    @staticmethod
    def _sgd(params):
        """Build the plain SGD optimizer that each scheduler under test wraps."""
        return SGD(params, lr=0.1, momentum=0.9)

    def test_cosine(self, model_params):
        """A ``cosine`` config yields CosineAnnealingLR with the configured values."""
        opt = self._sgd(model_params)
        cosine_cfg = Namespace(name="cosine", eta_min=0.0, T_max=100)
        sch_cos = get_lr_scheduler(cosine_cfg, opt)
        assert isinstance(sch_cos, CosineAnnealingLR)
        assert sch_cos.T_max == 100
        assert sch_cos.eta_min == pytest.approx(0.0)

    def test_multistep_ratio_list(self, model_params):
        """A list of step ratios maps to one milestone per ratio of ``num_epochs``."""
        opt = self._sgd(model_params)
        ms_cfg = Namespace(
            name="multistep", step_ratio=[0.5, 0.75], gamma=0.1, num_epochs=160
        )
        sch_ms = get_lr_scheduler(ms_cfg, opt)
        assert isinstance(sch_ms, MultiStepLR)
        assert sch_ms.gamma == pytest.approx(0.1)
        # MultiStepLR stores milestones as a Counter (epoch -> multiplicity).
        # Comparing the full mapping, not just its keys, also rejects duplicated
        # milestones, which would apply ``gamma`` more than once at one epoch.
        assert dict(sch_ms.milestones) == {80: 1, 120: 1}

    def test_multistep_single_ratio_with_default_offset(self, model_params):
        """A scalar step ratio yields evenly spaced milestones.

        With ``step_ratio=0.2`` over 100 epochs, milestones are expected at
        20, 40, 60 and 80, with none at epoch 100. ``offset_ratio=-1`` is
        presumably the "use default offset" sentinel (per the test name);
        confirm against ``get_lr_scheduler``.
        """
        opt = self._sgd(model_params)
        ms_cfg = Namespace(
            name="multistep",
            step_ratio=0.2,
            gamma=0.1,
            num_epochs=100,
            offset_ratio=-1,
        )
        sch_ms = get_lr_scheduler(ms_cfg, opt)
        assert isinstance(sch_ms, MultiStepLR)
        assert sch_ms.gamma == pytest.approx(0.1)
        # Full-mapping comparison also catches duplicated milestones.
        assert dict(sch_ms.milestones) == {20: 1, 40: 1, 60: 1, 80: 1}


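# NOTE(review): this test is disabled. The alpha checks below reference
# `alphas`, which is never defined here; bind it before re-enabling the test.
# class TestSparsityGroups: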
#     def test_from_unstructured_yaml_excludes_bn(
#         self, unstructured_cfg_dict, resnet32
#     ):
#         groups = get_sparsity_groups(resnet32, unstructured_cfg_dict)
#         assert isinstance(groups, list) and len(groups) >= 1

#         # Gather spec names from first GroupCoupler
#         names = [s.name for s in groups[0].specs]
#         names = [n for n in names if n is not None]

#         # Excludes anything matching .*bn.*
#         assert all("bn" not in n for n in names)

#         # Should include conv weights and likely linear weights
#         assert any(re.match(r".*conv.*weight", n) for n in names)
#         assert any("linear.weight" in n for n in names)

#         # Every spec should have an alpha
#         assert all(s in alphas for s in groups[0].specs)
#         # Alphas are tensors on the same device as their params
#         for s in groups[0].specs:
#             a = alphas[s]
#             assert torch.is_tensor(a)
#             assert a.device == s.param.device
