from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple, List
import torch
import torch.nn as nn
import warnings

from spastra.algebra import GroupCoupler
from spastra.controllers import EMAController
from spastra.controllers import LambdaController
from spastra.controllers import AlphaController

from spastra import hooks

# Short aliases for the torch types used in annotations throughout this module.
Tensor = torch.Tensor
Parameter = nn.Parameter
Optimizer = torch.optim.Optimizer


@dataclass
class BaseASTRASparsifier:
    """Optimizer-hooked base class for the ASTRA sparsifiers.

    The sparsifier registers a pre-step and a post-step hook on every
    attached optimizer. The hooks cache, per parameter, an update
    "direction" (the raw gradient by default) and a learning rate; ``step``
    later consumes and clears those caches.

    Subclasses must implement ``gather_info`` to extract the
    optimizer-specific directions and learning rates.
    """

    groups: List[GroupCoupler]
    lambdas: LambdaController
    ema_grad: EMAController
    alphas: AlphaController
    device: Optional[torch.device] = None
    # Added to the per-group threshold in ``step``.
    eps: float = 1e-7

    # Per-parameter directions collected by the optimizer hooks.
    _cached_directions: Dict[Parameter, Tensor] = field(
        default_factory=dict, init=False, repr=False
    )
    # Per-parameter learning rates collected by the post-step hook.
    _cached_learning_rates: Dict[Parameter, Tensor] = field(
        default_factory=dict, init=False, repr=False
    )
    # Hook handles per attached optimizer, kept so they can be removed.
    _hook_handles: Dict[Optimizer, list] = field(
        default_factory=dict, init=False, repr=False
    )

    # Not annotated, so this is a plain class-level default rather than a
    # dataclass field; the hooks and ``step`` shadow it on the instance.
    _updated = False

    @property
    def params(self) -> set[Parameter]:
        """Return the set of all parameters exposed by the groups."""
        return {p for group in self.groups for p in group.params}

    @property
    def specs(self):
        """Return the set of all specs exposed by the groups."""
        return {s for group in self.groups for s in group.specs}

    def _pre_step_hook(self, optimizer, *args, **kwargs):
        """Cache the gradients of tracked parameters before ``optimizer`` steps."""
        self._cached_directions.update(self.gather_gradients(optimizer))
        self._updated = True

    def _post_step_hook(self, optimizer, *args, **kwargs):
        """Refresh cached directions and learning rates after ``optimizer`` steps.

        Only the cached directions that belong to ``optimizer`` are handed
        to ``gather_info``; its results are merged back into the caches.
        """
        params_for_opt = {
            p for group in optimizer.param_groups for p in group["params"]
        }
        directions_for_opt = {
            p: d
            for p, d in self._cached_directions.items()
            if p in params_for_opt
        }
        new_directions, new_lrs = self.gather_info(
            directions_for_opt, optimizer
        )
        self._cached_directions.update(new_directions)
        self._cached_learning_rates.update(new_lrs)
        self._updated = True

    def attach_optimizer(self, optimizer: Optimizer):
        """Register the pre/post step hooks on ``optimizer``.

        Warns and does nothing if the optimizer is already attached.

        Returns:
            ``self``, so calls can be chained.
        """
        if optimizer in self._hook_handles:
            warnings.warn("Optimizer already attached to ASTRA.")
            return self
        pre_hook = optimizer.register_step_pre_hook(self._pre_step_hook)
        post_hook = optimizer.register_step_post_hook(self._post_step_hook)
        self._hook_handles[optimizer] = [pre_hook, post_hook]
        return self

    def detach_optimizer(self, optimizer: Optimizer):
        """Remove the hooks previously registered on ``optimizer``.

        Warns and does nothing if the optimizer was never attached.

        Returns:
            ``self`` in every case, so calls can be chained.
        """
        if optimizer not in self._hook_handles:
            warnings.warn("Optimizer was not attached to ASTRA before")
            return self
        for handle in self._hook_handles.pop(optimizer):
            handle.remove()
        return self

    def detach_all_optimizers(self):
        """Detach every attached optimizer and return ``self``."""
        # Iterate over a snapshot: detach_optimizer mutates the dict.
        for optimizer in list(self._hook_handles):
            self.detach_optimizer(optimizer)
        return self

    @torch.no_grad()
    def step(self, sparsify: bool = True):
        """Consume the cached optimizer information and update the groups.

        Warns when no hook has fired since the previous call, and returns
        early when no directions are cached. Otherwise the caches are
        cleared, the gradient EMA is updated for every spec, and for each
        group a threshold is derived and (optionally) applied.

        Args:
            sparsify: When false, the controllers are still updated but
                ``group.soft_threshold`` is not called.

        Raises:
            KeyError: If a spec's parameter has no cached learning rate
                (e.g. its optimizer's post-step hook never ran).
        """
        if not self._updated:
            warnings.warn(
                "Sparsifier states not updated. did you attach optimizer?"
            )
        if not self._cached_directions:
            return

        # Take ownership of the caches and reset them for the next round.
        directions = self._cached_directions
        learning_rates = self._cached_learning_rates
        self._cached_directions = {}
        self._cached_learning_rates = {}

        for sp in self.specs:
            # Parameters without a cached direction are passed as None.
            self.ema_grad.update_single(sp, directions.get(sp.param, None))

        for group in self.groups:
            grad_bar_values = {}
            lrs = {}
            for sp in group.specs:
                lrs[sp] = learning_rates[sp.param]
                alpha = self.alphas.get(sp)
                ema = self.ema_grad.get(sp)
                grad_bar_values[sp] = ema - alpha * sp.param.data

            psi = group.kth_largest(grad_bar_values, group.kappa)
            self.lambdas.update_single(group, psi)
            # eps is added on top of the controller's value for the group.
            threshold = self.lambdas.get(group).add(self.eps)
            if sparsify:
                group.soft_threshold(threshold, learning_rates=lrs)

        self._updated = False

    def gather_gradients(self, optimizer: Optimizer) -> Dict[Parameter, Tensor]:
        """Collect detached gradients of tracked parameters in ``optimizer``.

        Parameters that are not referenced by any spec, or whose ``grad``
        is ``None``, are skipped. The returned tensors are detached views
        of ``p.grad``, not copies.
        """
        param_set = {s.param for g in self.groups for s in g.specs}
        directions = {}
        for group in optimizer.param_groups:
            for p in group["params"]:
                if p not in param_set or p.grad is None:
                    continue
                directions[p] = p.grad.detach()
        return directions

    def gather_info(
        self, directions: Dict[Parameter, Tensor], optimizer: Optimizer
    ) -> Tuple[Dict[Parameter, Tensor], Dict[Parameter, Tensor]]:
        """Return ``(directions, learning_rates)`` for ``optimizer``.

        Called from the post-step hook with the cached directions that
        belong to ``optimizer``.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError(
            "Subclasses must implement gather_info()"
        )


class SASTRA(BaseASTRASparsifier):
    """ASTRA sparsifier for SGD, with or without momentum.

    - Direction per param: the optimizer's ``momentum_buffer`` when one is
      present in its state; otherwise the direction passed in is kept.
    - Learning rate per param: the param group's scalar ``lr`` as a tensor.
    """

    def gather_info(
        self, directions: Dict[Parameter, Tensor], optimizer: Optimizer
    ) -> Tuple[Dict[Parameter, Tensor], Dict[Parameter, Tensor]]:
        """Return ``(directions, learning_rates)`` for an SGD optimizer.

        ``directions`` is updated in place and returned. Only parameters
        in ``self.params`` are considered. The learning rate falls back to
        ``1.0`` when a param group has no ``"lr"`` entry.
        """
        tracked = self.params
        learning_rates: Dict[Parameter, Tensor] = {}

        for opt_group in optimizer.param_groups:
            group_lr = opt_group.get("lr", 1.0)
            for param in opt_group["params"]:
                if param in tracked:
                    state = optimizer.state.get(param, {})
                    buf = state.get("momentum_buffer", None)
                    if buf is not None:
                        # Prefer the momentum buffer over the cached direction.
                        directions[param] = buf.detach()
                    learning_rates[param] = torch.as_tensor(
                        group_lr, device=param.device, dtype=param.dtype
                    )

        return directions, learning_rates


class AdASTRA(BaseASTRASparsifier):
    """ASTRA sparsifier specialized for Adam/AdamW.

    - Direction per param: ``exp_avg`` (first moment EMA) when available.
    - Learning rate per param: elementwise
      ``step_size / (sqrt(second_moment) + eps)`` with bias correction,
      where ``second_moment`` is ``exp_avg_sq``, or ``max_exp_avg_sq`` for
      param groups running with ``amsgrad=True``.
    """

    def gather_info(
        self, directions: Dict[Parameter, Tensor], optimizer: Optimizer
    ) -> Tuple[Dict[Parameter, Tensor], Dict[Parameter, Tensor]]:
        """Return ``(directions, learning_rates)`` for an Adam-style optimizer.

        ``directions`` is updated in place and returned. Only parameters
        referenced by a spec of ``self.groups`` are considered. Parameters
        without second-moment state get the group's scalar ``lr``.
        """
        lrs: Dict[Parameter, Tensor] = {}
        param_set = {s.param for g in self.groups for s in g.specs}

        for group in optimizer.param_groups:
            # Group-level hyperparameters are invariant over the inner loop.
            base_lr = group.get("lr", 1.0)
            b1, b2 = group.get("betas", (0.9, 0.999))
            eps = group.get("eps", 1e-8)
            amsgrad = group.get("amsgrad", False)

            for p in group["params"]:
                if p not in param_set:
                    continue
                st = optimizer.state.get(p, {})

                exp_avg = st.get("exp_avg", None)
                if exp_avg is not None:
                    directions[p] = exp_avg.detach()

                base_lr_t = torch.as_tensor(
                    base_lr, device=p.device, dtype=p.dtype
                )

                second_moment = st.get("exp_avg_sq", None)
                if amsgrad:
                    # With AMSGrad the optimizer divides by the running
                    # maximum of the second moment, not the current EMA.
                    second_moment = st.get("max_exp_avg_sq", second_moment)

                if second_moment is None:
                    # No optimizer state for this parameter yet.
                    lrs[p] = base_lr_t
                    continue

                step_val = st.get("step", 0)
                step_t = (
                    int(step_val.item())
                    if isinstance(step_val, torch.Tensor)
                    else int(step_val)
                )
                # Clamp so the bias corrections below are never zero.
                step_t = max(1, step_t)
                bc1 = 1 - (b1**step_t)
                bc2 = 1 - (b2**step_t)

                # sqrt() allocates a new tensor, so the in-place add_ does
                # not touch the optimizer's state.
                denom = second_moment.detach().sqrt().add_(eps)
                step_size = base_lr_t * (bc2**0.5) / bc1
                lrs[p] = step_size / denom

        return directions, lrs


@dataclass
class IHTSparsifier:
    """Sparsifier that hard-thresholds its groups.

    ``step`` delegates to ``group.hard_threshold()`` for every group.
    ``freeze_support`` thresholds once and then pins the resulting zero
    pattern by masking gradients (and SGD momentum) outside of it.
    """

    groups: List[GroupCoupler]
    kappa: Optional[int] = None
    sparsity: Optional[float] = None
    device: Optional[torch.device] = None

    # Boolean support mask per parameter, filled by ``freeze_support``.
    _masks: dict = field(default_factory=dict, init=False, repr=False)

    # Handles of the gradient hooks registered by ``freeze_support``.
    _hook_handles: List = field(default_factory=list, init=False, repr=False)

    def step(self, sparsify=True, *args, **kwargs):
        """Hard-threshold every group; a no-op when ``sparsify`` is false.

        Extra positional/keyword arguments are accepted and ignored.
        """
        if not sparsify:
            return
        for group in self.groups:
            group.hard_threshold()

    def freeze_support(self, optimizer: Optimizer):
        """Freeze the current support of all tracked parameters.

        Hard-thresholds the groups, records for each parameter a mask of
        entries with magnitude ``>= 1e-12``, zeroes the remaining entries,
        and registers a post-accumulate-grad hook per parameter that
        applies the mask to the gradient. Existing SGD momentum buffers
        are masked as well.

        NOTE: every call registers a new set of hooks; handles from
        earlier calls are kept in ``_hook_handles`` and are not removed.

        Args:
            optimizer: The optimizer training the tracked parameters.

        Raises:
            ValueError: If ``optimizer`` is not a ``torch.optim.SGD``. The
                check runs first, so a rejected call has no side effects.
        """
        # Validate up front so an unsupported optimizer does not leave the
        # parameters thresholded and hooks registered.
        if not isinstance(optimizer, torch.optim.SGD):
            raise ValueError(
                "IHT.freeze_support only supports SGD (for now), "
                f"got {optimizer.__class__}"
            )

        # hard threshold
        self.step()

        # get the masks
        for group in self.groups:
            for sp in group.specs:
                p = sp.param
                mask = p.data.abs() >= 1e-12
                p.data.mul_(mask.to(p))
                self._masks[p] = mask
                # register hooks to zero-out gradients after each backward
                self._hook_handles.append(
                    p.register_post_accumulate_grad_hook(
                        # Bind the mask as a default argument: a plain
                        # closure would late-bind `mask`, and every hook
                        # would end up using the last loop iteration's mask.
                        lambda p, m=mask: hooks.mask_post_accumulate_hook(
                            p, m.to(p)
                        )
                    )
                )

        # zero out the momentum outside the support
        for param_g in optimizer.param_groups:
            if param_g["momentum"] == 0.0:
                continue
            for p in param_g["params"]:
                if p not in self._masks:
                    continue
                st = optimizer.state.get(p, {})
                momentum = st.get("momentum_buffer", None)
                if momentum is not None:
                    momentum.mul_(self._masks[p].to(momentum))
