import types
import torch
import torch.nn as nn
import copy, os, random

from typing import List

class Chicken(nn.Module):
    """
    An incremental (continual) learning wrapper for a ``torch.nn.Module``.

    ``Chicken(model)`` returns an instance of a dynamically created subclass
    ``Conti<BaseName>`` that inherits from both ``Chicken`` and ``type(model)``
    and shares the wrapped model's submodules, parameters and buffers.

    Every adapted 2-D weight is decomposed once as ``W = U diag(S) Vh``. Each
    mask holds one logit per singular value. Applying a mask rewrites the live
    weight as ``U diag(S * sigmoid(logits) * max_mult) Vh``, rescaled so the
    scaled singular values keep the original sum of ``S``.
    """
    # optional cache so we don't recreate the same subclass over and over
    _cls_cache: dict[type[nn.Module], type[nn.Module]] = {}

    # -------- object construction --------
    def __new__(cls, model: nn.Module, *args, **kwargs):
        if cls is Chicken:  # only when user calls Chicken(...)
            # Validate here: building a subclass of type(model) for a non-module
            # would fail (or misbehave) before __init__ ever runs.
            if not isinstance(model, nn.Module):
                raise TypeError("model is not an torch.nn.Module")
            base = type(model)

            # reuse cached subclass if it exists
            wrapped = cls._cls_cache.get(base)
            if wrapped is None:
                wrapped = types.new_class(
                    f"Conti{base.__name__}",  # e.g. ContiVisionTransformer
                    (Chicken, base),          # MRO: Chicken -> base model
                    {},
                )
                cls._cls_cache[base] = wrapped

            # allocate instance of the *new* subclass (nn.Module.__init__ is not run)
            inst = super().__new__(wrapped)
            # Share every weight / buffer / attribute with the original model,
            # but give the wrapper its own registries. Otherwise registering new
            # submodules on the wrapper (e.g. ``learnable_params``) would also
            # add them to ``model``. The Parameter / Module objects stay shared.
            state = dict(model.__dict__)
            for name in ("_parameters", "_buffers", "_modules", "_non_persistent_buffers_set"):
                if name in state:
                    state[name] = copy.copy(state[name])
            inst.__dict__.update(state)
            return inst

        # if somebody subclasses Chicken explicitly, honour normal behaviour
        return super().__new__(cls)

    def __init__(
        self,
        model,
        device: str = "cpu",
        init_val: float = 0.1,
        max_mult: float = 1.0,
        matching_texts: List[str] = ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm"),
        rank=None,  # optional truncation
    ):
        """
        Parameters
        ----------
        model: torch.nn.Module, required
            The model to wrap. Its parameters are shared with the wrapper, so
            applying a mask also changes the weights seen through ``model``.
        device: string, optional
            Device on which the SVD factors and mask logits are stored (default cpu).
        init_val: float, optional
            Standard deviation of the zero-mean Gaussian used to initialise the
            mask logits, so sigmoid(logits) starts close to 0.5 (default 0.1).
        max_mult: float, optional
            Maximum possible value the mask can take; the mask lies in
            (0, max_mult) (default 1.0).
        matching_texts: List[str], optional
            Substrings of state-dict keys that must not be decomposed and
            reconstructed (default ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm")).
        rank: int, optional
            Keep only the top-``rank`` singular triplets of each decomposed
            weight; None keeps the full SVD (default None).

        Raises
        ------
        TypeError
            If ``model`` is not a ``torch.nn.Module``.

        Examples
        --------
        >>> from transformers import ViTModel
        >>> model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        >>> model = Chicken(model, device="cuda", init_val=0.05, max_mult=1.0)
        """
        # super().__init__() is deliberately NOT called: the instance already
        # carries the wrapped model's nn.Module state copied in __new__.
        if not isinstance(model, nn.Module):
            raise TypeError("model is not an torch.nn.Module")

        self.init_val = float(init_val)
        self.matching_texts = tuple(matching_texts)
        self.max_mult = float(max_mult)
        self.rank = rank  # None = full SVD

        # The mask device is stored privately because some base classes
        # (e.g. Hugging Face models) expose a read-only ``device`` property
        # that cannot be assigned. ``self.device`` is mirrored only when the
        # class does not define such a property.
        self._mask_device = device
        if not isinstance(getattr(type(self), "device", None), property):
            self.device = device

        # snapshot of base params and buffers (on their current device/dtype)
        self.base_params = copy.deepcopy(model.state_dict())
        # precompute decomposition only for 2-D weights we intend to adapt
        self.decomposed_params = self.decompose(
            self.base_params, self.matching_texts, self.rank, self._mask_device
        )

        self.class_policy_map = {}  # class name -> mask index

        # register mask containers so they move with .to()
        self.learnable_params = nn.ModuleDict()  # key: str(mask_idx) -> ParameterDict
        self._mask_param_lists = {}              # mask_idx (int) -> list[Parameter]

        self.num_params = 0      # total number of mask logits over all masks
        self.enable_mask = []    # enable_mask[i] is False while mask i is toggled off
        self.new_mask_idx = 0    # index the next added mask will receive
        self.selected_mask = -1  # -1 selects the unmasked base weights

    # ---------- properties / helpers ----------
    @property
    def class_map(self):
        """
        Returns a string listing each mask index and the classes associated with it.

        Returns
        -------
        str

        Examples
        --------
        >>> print(model.class_map)
        CLASS MAP
        ------------------
        0: cat, dog, horse, cow
        1: mouse, lion
        ------------------
        """
        # invert class -> mask into mask -> [classes], keeping insertion order
        inverse_map = {}
        for name, mask_idx in self.class_policy_map.items():
            inverse_map.setdefault(mask_idx, []).append(name)

        lines = ["CLASS MAP", "------------------"]
        lines += [f"{mask_idx}: {', '.join(names)}" for mask_idx, names in inverse_map.items()]
        lines.append("------------------")
        return "\n".join(lines) + "\n"

    @property
    def latest_mask_idx(self):
        """
        Returns the index of the most recently added mask (-1 if none was added).

        Returns
        -------
        int
        """
        return self.new_mask_idx - 1

    @staticmethod
    def _mask_key(param_name):
        """Map a state-dict key to its ParameterDict key ('.' is not allowed there)."""
        return param_name.replace('.', '__')

    def _live_tensor(self, name):
        """Return the live parameter, or failing that the buffer, registered under state-dict key ``name``."""
        try:
            return self.get_parameter(name)
        except AttributeError:
            return self.get_buffer(name)

    def _adapted_keys(self):
        """Return the state-dict keys whose weights are SVD-decomposed and therefore masked."""
        return [
            k for k in self.base_params
            if not any(text in k for text in self.matching_texts)
            and f"{k}::S" in self.decomposed_params
        ]

    @staticmethod
    def decompose(base_params, skip_match_texts=(), rank=None, device="cpu"):
        """
        Compute a thin SVD for every decomposable tensor in ``base_params``.

        Only floating-point 2-D tensors whose key contains none of
        ``skip_match_texts`` are decomposed. Everything else is left out,
        including 1-D/0-D entries, conv kernels and integer buffers.

        Parameters
        ----------
        base_params: dict[str, torch.Tensor]
            A state dict.
        skip_match_texts: iterable of str, optional
            Key substrings to exclude.
        rank: int, optional
            If given, keep only the top-``rank`` singular triplets.
        device: str, optional
            Device for the returned factors (computed in float32).

        Returns
        -------
        dict
            ``"<key>::U"`` [m, r], ``"<key>::S"`` [r], ``"<key>::Vh"`` [r, n],
            where r = min(m, n), further capped by ``rank``.
        """
        decomposed_params = {}
        for k, v in base_params.items():
            if any(text in k for text in skip_match_texts):
                continue  # skip this param
            if not torch.is_tensor(v) or v.ndim != 2 or not v.is_floating_point():
                continue  # singular-value masking is only defined for 2-D float matrices
            W = v.detach().to(device=device, dtype=torch.float32)
            # U: [m,r], S: [r], Vh: [r,n], r = min(m,n)
            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
            if rank is not None:
                # clone() so the truncated factors do not keep the full ones alive
                U, S, Vh = U[:, :rank].clone(), S[:rank].clone(), Vh[:rank, :].clone()
            decomposed_params[f"{k}::U"] = U
            decomposed_params[f"{k}::S"] = S
            decomposed_params[f"{k}::Vh"] = Vh
        return decomposed_params

    def add_mask(self):
        """
        Add a new mask: one learnable logit vector per decomposed matrix.

        Returns
        -------
        bool
            Always True.
        """
        mask_params = nn.ParameterDict()
        for k in self._adapted_keys():
            S = self.decomposed_params[f"{k}::S"]
            # zero-mean logits with std init_val, so sigmoid(logits) ≈ 0.5 at start
            m = nn.Parameter(torch.randn_like(S, dtype=torch.float32, device=self._mask_device) * self.init_val)
            mask_params[self._mask_key(k)] = m
            self.num_params += m.numel()

        key = str(self.new_mask_idx)
        self.learnable_params[key] = mask_params
        self._mask_param_lists[self.new_mask_idx] = list(mask_params.parameters())
        self.enable_mask.append(True)

        self.new_mask_idx += 1
        return True

    def add_class(self, class_names: List[str]):
        """
        Add a new set of classes, creating a new mask that they map to.

        Parameters
        ----------
        class_names: List[str], required
            A list of class names.

        Returns
        -------
        bool
            True if the classes were added successfully.

        Examples
        --------
        >>> model.add_class(["cat", "dog"])
        True
        """
        mask_idx = self.new_mask_idx
        # create the mask first so a failure does not leave dangling class entries
        self.add_mask()
        for name in class_names:
            self.class_policy_map[name] = mask_idx
        return True

    def set_mask(self, mask_idx: int = 0):
        """
        Set the selected mask.

        Parameters
        ----------
        mask_idx: int, optional
            Index of the mask to select, or -1 for the unmasked base weights (default 0).

        Returns
        -------
        bool
            True if the selected mask was set successfully.

        Raises
        ------
        IndexError
            If ``mask_idx`` is neither -1 nor an existing mask index.
        """
        if mask_idx == -1:  # special: base weights
            self.selected_mask = -1
            return True
        # Other negative indices have no matching "learnable_params" key, so reject them.
        if not 0 <= mask_idx < len(self.enable_mask):
            raise IndexError("the mask number is out of range")
        self.selected_mask = mask_idx
        return True

    def get_mask(self, mask_idx: int = -1):
        """
        Return the ParameterDict of mask logits for a mask.

        Parameters
        ----------
        mask_idx: int, optional
            The mask index. Negative values count from the end, so the
            default -1 returns the latest mask (default -1).

        Returns
        -------
        torch.nn.ParameterDict
            The logits of the requested mask (not a copy).

        Raises
        ------
        IndexError
            If ``mask_idx`` does not refer to an existing mask.
        """
        if mask_idx < 0:
            mask_idx += len(self.enable_mask)
        if not 0 <= mask_idx < len(self.enable_mask):
            raise IndexError("the mask number is out of range")
        return self.learnable_params[str(mask_idx)]

    def get_trainable_parameters(self, mask_idx=None):
        """
        Return the list of mask parameters for ``mask_idx``.

        If ``mask_idx`` is None, the selected mask is used. The list is empty
        for -1 (base weights).
        """
        if mask_idx is None:
            mask_idx = self.selected_mask
        if mask_idx == -1:
            return []  # nothing to train when using base weights
        return self._mask_param_lists[mask_idx]

    def save_weights(self, path: str):
        """
        Save the mask weights and bookkeeping state to ``path``.

        Parameters
        ----------
        path: str, required
            Location where the masks should be saved (a .pt file).
        """
        payload = {
            "learnable_params": {
                idx: {n: p.detach().cpu() for n, p in params.items()}
                for idx, params in self.learnable_params.items()
            },
            "enable_mask": self.enable_mask,
            "new_mask_idx": self.new_mask_idx,
            "class_policy_map": self.class_policy_map,
            "rank": self.rank,
            "matching_texts": self.matching_texts,
            "init_val": self.init_val,
            "max_mult": self.max_mult,
        }
        torch.save(payload, path)

    def load_weights(self, path: str):
        """
        Load masks previously written by ``save_weights``.

        All existing masks are replaced. Afterwards mask 0 is selected, or the
        base weights (-1) if the file holds no masks.

        Parameters
        ----------
        path: str, required
            Location of the .pt file containing the masks.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # files. Recent PyTorch releases default to weights_only=True, which
        # limits what can be unpickled.
        info = torch.load(path, map_location=self._mask_device)
        self.learnable_params = nn.ModuleDict()
        self._mask_param_lists.clear()
        self.enable_mask = list(info["enable_mask"])
        self.new_mask_idx = int(info["new_mask_idx"])
        self.class_policy_map = dict(info["class_policy_map"])
        for idx, d in info["learnable_params"].items():
            pd = nn.ParameterDict({n: nn.Parameter(t.to(self._mask_device)) for n, t in d.items()})
            self.learnable_params[idx] = pd
            self._mask_param_lists[int(idx)] = list(pd.parameters())
        # keep the logit counter consistent with the masks just loaded
        self.num_params = sum(p.numel() for params in self._mask_param_lists.values() for p in params)
        self.set_mask(0 if self.enable_mask else -1)

    def activate_mask(self, p, mask_idx):
        """
        Map mask logits ``p`` to multipliers in (0, max_mult).

        Returns all ones for the base weights (-1) or when the mask is disabled.
        """
        if mask_idx == -1 or not self.enable_mask[mask_idx]:
            return torch.ones_like(p, dtype=torch.float32)
        return torch.sigmoid(p).to(torch.float32) * self.max_mult

    def forward(self, *args, **kwargs):
        # delegate to the wrapped base model's forward (next class in the MRO)
        return super().forward(*args, **kwargs)

    def compose_new_params(self, param_name, mask_idx):
        """
        Rebuild weight ``param_name`` from its SVD with mask ``mask_idx`` applied.

        The masked singular values are rescaled so that they keep the sum of
        the original ones. For ``mask_idx == -1`` the multipliers are all ones,
        which gives the (float32) SVD reconstruction of the base weight.

        Returns
        -------
        torch.Tensor
            A float32 [m, n] tensor on the device of the stored factors.
        """
        U = self.decomposed_params[f"{param_name}::U"]    # [m,r]
        S = self.decomposed_params[f"{param_name}::S"]    # [r]
        Vh = self.decomposed_params[f"{param_name}::Vh"]  # [r,n]

        if mask_idx == -1:
            # base weights: there are no mask parameters stored under "-1"
            mm = torch.ones_like(S, dtype=torch.float32)
        else:
            mparam = self.learnable_params[str(mask_idx)][self._mask_key(param_name)]  # [r]
            mm = self.activate_mask(mparam, mask_idx)  # [r]
        S_scaled = S * mm

        eps = torch.finfo(S.dtype).eps
        scale = S.sum() / (S_scaled.sum() + eps)

        Wp = torch.einsum('mr,r->mr', U, S_scaled)
        Wp = torch.einsum('mr,rn->mn', Wp, Vh) * scale
        return Wp

    def toggle_mask(self, mask_value: bool = True, mask_idx: int = None):
        """
        Turn a mask on or off and re-apply it to the live model.

        Parameters
        ----------
        mask_value: bool, optional
            Whether the mask should be on (True) or off (default True).
        mask_idx: int, optional
            If None, the latest mask index is used (default None).
        """
        if mask_idx is None:
            mask_idx = self.latest_mask_idx

        self.enable_mask[mask_idx] = mask_value
        self.apply_policy_to_model(mask_idx)

    def update_backward(self, mask_idx: int = None):
        """
        Backpropagate dL/dW of the composed weights into the mask logits (a VJP).

        Requires that ``set_train`` was called and ``loss.backward()`` has
        populated ``.grad`` on every adapted base weight. The mask gradients
        are accumulated into the mask parameters' ``.grad``. Nothing happens
        for the base weights (-1) or for a disabled mask, because those do
        not depend on any logits.

        Parameters
        ----------
        mask_idx: int, optional
            If None, the mask selected via set_mask is used (default None).

        Raises
        ------
        RuntimeError
            If an adapted base weight has no gradient.
        """
        if mask_idx is None:
            mask_idx = self.selected_mask
        if mask_idx == -1 or not self.enable_mask[mask_idx]:
            return

        for k in self._adapted_keys():
            g = self._live_tensor(k).grad
            if g is None:
                raise RuntimeError(f"No grad for {k}; call set_train() and loss.backward() first.")
            # each key builds its own independent graph, so it can be freed immediately
            self.compose_new_params(k, mask_idx).backward(g)

    def set_train(self, mask_idx: int = None):
        """
        Prepare gradients for training one mask.

        Everything is frozen first. Then gradients are enabled on the adapted
        base weights (to obtain dL/dW) and on the chosen mask's logits.

        Parameters
        ----------
        mask_idx: int, optional
            If None, the mask index from set_mask is used.
        """
        if mask_idx is None:
            mask_idx = self.selected_mask

        # 1) freeze everything (base model and all masks)
        for p in self.parameters():
            p.requires_grad_(False)

        # 2) enable grads on the base weights we compose, so loss.backward()
        #    computes dL/dW (parameters are leaves, so .grad is kept automatically)
        for k in self._adapted_keys():
            self._live_tensor(k).requires_grad_(True)

        # 3) ensure the selected mask is trainable
        for p in self.get_trainable_parameters(mask_idx):
            p.requires_grad_(True)

    def apply_policy_to_model(self, mask_idx: int = None):
        """
        Compose the weights for a mask and write them into the live model in place.

        Adapted weights receive the masked SVD reconstruction. Every other
        state-dict entry (parameters and buffers) is restored from the base
        snapshot. With ``mask_idx == -1`` every entry, adapted or not, is
        restored exactly from the snapshot.

        Parameters
        ----------
        mask_idx: int, optional
            Index of the mask to apply. If None, the mask chosen via set_mask is used.

        Examples
        --------
        >>> model.apply_policy_to_model(1)
        """
        if mask_idx is None:
            mask_idx = self.selected_mask

        with torch.no_grad():
            for k, base in self.base_params.items():
                target = self._live_tensor(k)
                adapted = (
                    mask_idx != -1
                    and not any(skip in k for skip in self.matching_texts)
                    and f"{k}::S" in self.decomposed_params
                )
                # copy_ converts dtype/device to match the live tensor
                if adapted:
                    target.copy_(self.compose_new_params(k, mask_idx))
                else:
                    target.copy_(base)