import torch
import torch.nn as nn
import torch.nn.functional as F
from contextlib import contextmanager
from typing import Dict, Tuple, Any

Tensor = torch.Tensor
DeltaDict = Dict[str, Tuple[Tensor, Tensor]]  # name -> (A:[r,in], B:[out,r])


class FunctionalLoRAInjector:
    """
    Functionally injects LoRA at forward-time for specified Linear layers.
    - Never rebinds .weight (keeps them leaf Parameters).
    - Deltas (A,B) are regular tensors (from your generator) so grads flow back.
    - detach() leaves each patched module's `forward` exactly as it was found.

    Typical flow: attach(specs) once, set_deltas(...) before the forwards that
    should use them, and enable()/disable()/enabled(flag) to toggle the LoRA
    term without unpatching.
    """

    def __init__(self, scale: float = 1.0):
        """
        Args:
            scale: multiplier applied to the LoRA term before it is added to
                the wrapped layer's output.
        """
        self.scale = float(scale)
        # Nothing in this class registers hooks; the list is only cleared in detach().
        self._hooks = []
        # module -> the `forward` callable that was in effect before patching
        self._orig_forwards: Dict[nn.Module, Any] = {}
        # module -> True when `forward` was already an instance attribute
        # before patching (so detach() knows whether to restore or remove it)
        self._had_instance_forward: Dict[nn.Module, bool] = {}
        self._name_to_module: Dict[str, nn.Module] = {}
        self._deltas: DeltaDict = {}
        self._enabled: bool = True

    def _make_forward(self, original_forward, lname: str):
        """Build the replacement forward for one layer, bound to its delta name."""

        def wrapped_forward(x: Tensor, *args, **kwargs):
            y = original_forward(x, *args, **kwargs)
            # Deltas are looked up at call time, so set_deltas()/clear_deltas()
            # and the enabled flag take effect without re-patching.
            delta = self._deltas.get(lname) if self._enabled else None
            if delta is None:
                return y
            A, B = delta  # A:[r,in], B:[out,r]
            # Ensure correct device/dtype (no grad-blocking ops!)
            A = A.to(device=y.device, dtype=y.dtype)
            B = B.to(device=y.device, dtype=y.dtype)
            # F.linear(x, W) computes x @ W^T, so the LoRA term is
            # (x @ A^T) @ B^T: [..., in] -> [..., r] -> [..., out]
            lora_down = F.linear(x, A)
            lora_up = F.linear(lora_down, B)
            return y + self.scale * lora_up

        return wrapped_forward

    def attach(self, layer_specs: Dict[str, Tuple[nn.Module, torch.Size, torch.Size]]):
        """
        Patch the forward of every module listed in `layer_specs`.

        layer_specs: name -> (module, shape_A, shape_B)
            module must be an nn.Linear with W:[out,in]
            shape_A == (r, in), shape_B == (out, r)
            The two shapes are part of the spec format but are not inspected
            here.

        Each module object is wrapped only once. If the same module appears
        under several names, the first name seen is the one used to look up
        its deltas.

        Raises:
            TypeError: if a module is not an nn.Linear. Modules processed
                before the offending entry stay patched.
        """
        for name, (mod, _shape_A, _shape_B) in layer_specs.items():
            if not isinstance(mod, nn.Linear):
                raise TypeError(
                    f"FunctionalLoRA expects nn.Linear, got {type(mod)} for '{name}'"
                )
            self._name_to_module[name] = mod

            # Wrap each module's forward once
            if mod in self._orig_forwards:
                continue

            original_forward = mod.forward
            # Remember whether `forward` was already overridden on the
            # instance; if it was not, detach() must remove our attribute
            # instead of pinning a bound method onto the module.
            self._had_instance_forward[mod] = "forward" in vars(mod)
            self._orig_forwards[mod] = original_forward
            mod.forward = self._make_forward(original_forward, name)

    def detach(self):
        """
        Undo attach(): restore every patched module and reset internal state.

        A module whose `forward` was an instance attribute before patching
        gets that attribute back; otherwise the instance attribute written by
        attach() is removed so the class-level `forward` applies again.
        Deltas are dropped and the injector is re-enabled.
        """
        for mod, orig in self._orig_forwards.items():
            if self._had_instance_forward.get(mod, True):
                mod.forward = orig
            else:
                vars(mod).pop("forward", None)
        self._orig_forwards.clear()
        self._had_instance_forward.clear()
        self._name_to_module.clear()
        self._hooks.clear()
        self._deltas = {}
        self._enabled = True

    def set_deltas(self, deltas: DeltaDict):
        """
        Provide current LoRA deltas per layer for the next forwards.
        Each A:[r,in], B:[out,r]. These can require_grad.

        The mapping is stored by reference (not copied). Names that were
        never attached are simply never looked up.
        """
        self._deltas = deltas

    def clear_deltas(self):
        """Drop all deltas; patched layers then return their plain output."""
        self._deltas = {}

    def enable(self):
        """Turn the LoRA term on for subsequent forwards."""
        self._enabled = True

    def disable(self):
        """Turn the LoRA term off for subsequent forwards (modules stay patched)."""
        self._enabled = False

    @contextmanager
    def enabled(self, flag: bool):
        """Temporarily set the enabled flag to `flag`; the previous value is restored on exit."""
        prev = self._enabled
        self._enabled = flag
        try:
            yield
        finally:
            self._enabled = prev


# utils_layer_specs.py
import re
import torch
import torch.nn as nn
from typing import Dict, Iterable, Tuple

# Layer-spec mapping: name -> (module, shape_A, shape_B).
LayerSpecs = Dict[str, Tuple[nn.Module, torch.Size, torch.Size]]


def _matches(name: str, targets: Iterable[str]) -> bool:
    """
    Report whether `name` matches at least one entry of `targets`.

    An entry matches when `name` equals it, ends with ".<entry>", or merely
    ends with the entry text. The last test is a plain suffix check, so an
    entry such as "proj" also matches "q_proj"; this is looser than
    dotted-component matching.
    """
    return any(
        name == token or name.endswith(f".{token}") or name.endswith(token)
        for token in targets
    )


def generate_layer_specs(
    model: nn.Module,
    target_modules: Iterable[str],
    default_rank: int,
    ranks=None,
    name_prefix="",
) -> LayerSpecs:
    """
    Build layer specs from a base HF model by finding target nn.Linear modules.
    Shapes follow LoRA: A:[r, in], B:[out, r].

    Args:
        model: HF transformer (e.g., model.lm or model.lm.base_model)
        target_modules: component names to match against module names,
                        e.g. ["q_proj","k_proj","v_proj","o_proj"].
                        A single bare string is treated as one target.
        default_rank: rank to use if `ranks` is not provided
        ranks: optional iterable of ranks, one per target_modules entry
        name_prefix: optional prefix for spec keys

    Returns:
        dict[name] -> (module, shape_A, shape_B)

    Raises:
        ValueError: if `ranks` and `target_modules` differ in length, if the
            rank of a target that matched some module is not a positive int,
            or if no nn.Linear module matched any target.

    When a module name matches several targets, the first one in
    `target_modules` order decides the rank. Matching modules that are not
    nn.Linear are skipped.
    """
    # A str is itself iterable: list("q_proj") would yield one-character
    # targets, which suffix matching would then apply to unrelated modules.
    if isinstance(target_modules, str):
        targets = [target_modules]
    else:
        targets = list(target_modules)

    if ranks is not None:
        ranks = list(ranks)  # accept any iterable, not only sized sequences
        if len(ranks) != len(targets):
            raise ValueError(
                f"'ranks' must have the same length as 'target_modules' "
                f"(got {len(ranks)} vs {len(targets)})"
            )
        comp_to_rank = dict(zip(targets, ranks))
    else:
        comp_to_rank = dict.fromkeys(targets, default_rank)

    specs: LayerSpecs = {}
    for name, mod in model.named_modules():
        # First matching target wins; later targets are not consulted.
        matched = next((t for t in targets if _matches(name, [t])), None)
        if matched is None:
            continue

        r = comp_to_rank[matched]
        if not isinstance(r, int) or r <= 0:
            raise ValueError(f"Invalid rank {r} for component '{matched}'")

        if isinstance(mod, nn.Linear):
            # Use the layer's declared feature sizes rather than the stored
            # weight's shape; for a plain nn.Linear the two agree.
            shape_A = torch.Size([r, mod.in_features])
            shape_B = torch.Size([mod.out_features, r])
            specs[f"{name_prefix}{name}"] = (mod, shape_A, shape_B)

    if not specs:
        raise ValueError(
            f"No target layers found for {list(targets)}. "
            f"Check names via: [n for n,_ in model.named_modules()]."
        )

    return specs
