from __future__ import annotations
import math
import re
from dataclasses import dataclass, asdict
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------- Config ----------

# Selector describing which modules should receive an adapter. The matcher in
# this file tries the forms in this order: callable, compiled regex, single
# string, then any other iterable of strings.
TargetSpec = Union[
    Iterable[str],  # several names, e.g. ("o_proj", "up_proj", "down_proj"); any one may match
    str,  # single name: equals the leaf name or occurs anywhere in the full module name
    re.Pattern,  # compiled regex, applied with .search() to the *full module name*
    Callable[[str, nn.Module], bool],  # predicate(full_name, module) -> bool
]


@dataclass
class LoraConfigLite:
    """Hyper-parameters for one LoRA adapter plus the selector of layers to wrap.

    Values are validated on construction so that a bad configuration is
    rejected before any layer of a model has been modified.

    Raises:
        ValueError: if ``r`` is not positive, ``dropout`` is outside [0, 1],
            or ``init`` is not one of the supported schemes.
    """

    # Rank of the low-rank factors; also the divisor of the ``alpha / r`` scale.
    r: int = 8
    # Numerator of the ``alpha / r`` scale applied to the adapter output.
    alpha: float = 16.0
    # Dropout probability on the adapter input; 0 disables dropout.
    dropout: float = 0.0
    # Which modules to wrap (see TargetSpec for the accepted forms).
    target_modules: TargetSpec = ("o_proj", "up_proj", "down_proj")
    # Initialisation scheme for the factors: "zeros" | "kaiming".
    init: str = "zeros"
    # Rare: old GPT/Conv1D style.
    fan_in_fan_out: bool = False
    # In-place merge into base (export); not used during training.
    merge_adapter: bool = False

    def __post_init__(self) -> None:
        # r divides alpha in the adapter scale, so zero/negative ranks can
        # never produce a usable adapter.
        if self.r < 1:
            raise ValueError(f"r must be a positive integer, got {self.r!r}")
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout!r}")
        # Keep in sync with the schemes understood by LoRALinear.add_adapter.
        if self.init not in ("zeros", "kaiming"):
            raise ValueError(f"Unknown init '{self.init}'")

    @classmethod
    def from_dict(cls, d: Dict) -> "LoraConfigLite":
        """Build a config from a mapping of field names to values.

        Unknown keys raise ``TypeError`` (standard dataclass constructor
        behaviour); missing keys fall back to the field defaults.
        """
        return cls(**d)

    def to_dict(self) -> Dict:
        """Return the fields as a plain dict (inverse of ``from_dict``)."""
        return asdict(self)


# ---------- Core Layer ----------


class LoRALinear(nn.Module):
    """
    Drop-in replacement for nn.Linear that supports:
      - multiple named LoRA adapters
      - per-adapter trainability
      - summing multiple active adapters (adapter arithmetic)
      - apply_delta_on_base: set adapter = base + delta (useful for inference/export)
      - runtime deltas (dA, dB) attached per batch to keep autograd graph

    The output is ``F.linear(x, weight, bias)`` plus, for every active adapter,
    ``(dropout(x) @ A @ B) * (alpha / r)`` with ``A`` of shape (in_features, r)
    and ``B`` of shape (r, out_features).

    Base weight/bias are cloned and kept as parameters here. The clones inherit
    ``requires_grad`` from the wrapped layer, so a base that was frozen before
    wrapping stays frozen; otherwise freeze them at model level.

    Caveats:
      - Adapter tensors live in the plain ``self.adapters`` dict rather than in
        registered parameters, so ``parameters()``, ``state_dict()`` and
        ``.to()`` / ``.half()`` do not see them. Pass them to the optimizer and
        move/save them explicitly.
      - ``fan_in_fan_out`` is only recorded; ``forward`` always treats
        ``weight`` as (out_features, in_features).
    """

    def __init__(self, base: nn.Linear, fan_in_fan_out: bool = False):
        super().__init__()
        self.in_features = base.in_features
        self.out_features = base.out_features
        self.fan_in_fan_out = fan_in_fan_out

        # Clone base params into this module; we operate on these. Keep the
        # caller's freeze state instead of silently re-enabling gradients on a
        # layer that was frozen before being wrapped.
        self.weight = nn.Parameter(
            base.weight.detach().clone(), requires_grad=base.weight.requires_grad
        )
        self.bias = (
            nn.Parameter(
                base.bias.detach().clone(), requires_grad=base.bias.requires_grad
            )
            if base.bias is not None
            else None
        )

        # neuter original module (avoid training base twice / keep HF happy)
        base.weight = nn.Parameter(base.weight.detach(), requires_grad=False)
        if base.bias is not None:
            base.bias = nn.Parameter(base.bias.detach(), requires_grad=False)

        # adapters: name -> dict(A,B,alpha,dropout,trainable,runtime_delta)
        self.adapters: Dict[
            str,
            Dict[
                str,
                Union[
                    nn.Parameter,
                    float,
                    nn.Module,
                    bool,
                    Optional[Tuple[torch.Tensor, torch.Tensor]],
                ],
            ],
        ] = {}
        self.active_adapters: List[str] = []

    def train(self, mode: bool = True):
        """Switch train/eval mode, including the per-adapter dropout modules.

        The dropout modules are stored in ``self.adapters`` and are therefore
        not registered children; without this override ``model.eval()`` would
        leave adapter dropout active during inference.
        """
        super().train(mode)
        for entry in self.adapters.values():
            entry["dropout"].train(mode)
        return self

    def _entry(self, name: str):
        """Return the adapter record for ``name`` or raise a descriptive KeyError."""
        try:
            return self.adapters[name]
        except KeyError:
            raise KeyError(f"Adapter '{name}' not found.") from None

    # ---- Adapter management ----

    def add_adapter(
        self,
        name: str,
        r: int,
        alpha: float = 1.0,
        dropout: float = 0.0,
        init: str = "zeros",
        trainable: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Register a new (inactive) adapter called ``name``.

        Args:
            name: unique adapter name on this layer.
            r: rank of the factors; must be >= 1.
            alpha: numerator of the ``alpha / r`` output scale.
            dropout: dropout probability on the adapter input; values <= 0
                disable dropout.
            init: "zeros" (A and B zero) or "kaiming" (A kaiming-uniform, B zero).
                NOTE: with "zeros" both factors start at zero, so the gradient
                of each factor is zero as well; that mode is only useful when
                weights are set/loaded afterwards or runtime deltas are attached.
            trainable: initial ``requires_grad`` of A and B.
            device: device for A and B; defaults to the base weight's device.
            dtype: dtype for A and B; ``None`` uses torch's default dtype.

        Raises:
            ValueError: duplicate name, non-positive rank, or unknown ``init``.
        """
        if name in self.adapters:
            raise ValueError(f"Adapter '{name}' already exists.")
        # r divides alpha in forward(); reject it here rather than failing there.
        if r < 1:
            raise ValueError(f"Adapter rank r must be >= 1, got {r!r}.")
        if init not in ("zeros", "kaiming"):
            raise ValueError(f"Unknown init '{init}'")
        if device is None:
            # Follow the base weight so x @ A works on non-CPU models.
            device = self.weight.device

        A = nn.Parameter(torch.empty(self.in_features, r, device=device, dtype=dtype))
        B = nn.Parameter(torch.empty(r, self.out_features, device=device, dtype=dtype))

        if init == "kaiming":
            # NOTE(review): A is stored as (in_features, r), so kaiming_uniform_
            # derives fan_in from r rather than in_features - confirm that this
            # is the intended scale.
            nn.init.kaiming_uniform_(A, a=math.sqrt(5))
        else:
            nn.init.zeros_(A)
        nn.init.zeros_(B)

        A.requires_grad = trainable
        B.requires_grad = trainable
        drop = nn.Dropout(p=dropout) if dropout and dropout > 0 else nn.Identity()
        # Start in the same mode as this layer (it may already be in eval mode).
        drop.train(self.training)

        self.adapters[name] = dict(
            A=A,
            B=B,
            alpha=float(alpha),
            dropout=drop,
            trainable=bool(trainable),
            runtime_delta=None,  # holds (dA, dB) tensors for this adapter; do not detach
        )

    def remove_adapter(self, name: str):
        """Drop adapter ``name`` and deactivate it; unknown names are ignored."""
        if name not in self.adapters:
            return
        self.active_adapters = [n for n in self.active_adapters if n != name]
        del self.adapters[name]

    def set_adapter(self, names: Optional[Union[str, Iterable[str]]]):
        """Choose the active adapter(s): one name, several names (summed), or None.

        Raises:
            KeyError: if any requested adapter does not exist; the current
                selection is left untouched in that case.
        """
        if names is None:
            self.active_adapters = []
            return
        requested = [names] if isinstance(names, str) else list(names)
        for n in requested:
            if n not in self.adapters:
                raise KeyError(f"Adapter '{n}' not found.")
        self.active_adapters = requested

    def list_adapters(self) -> List[str]:
        """Return the names of all adapters registered on this layer."""
        return list(self.adapters.keys())

    def set_trainable(self, name: str, trainable: bool):
        """Enable or disable gradients for both factors of adapter ``name``."""
        entry = self._entry(name)
        entry["A"].requires_grad = trainable
        entry["B"].requires_grad = trainable
        entry["trainable"] = bool(trainable)

    def get_adapter_weights(self, name: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return detached copies of (A, B) for adapter ``name``."""
        entry = self._entry(name)
        return entry["A"].detach().clone(), entry["B"].detach().clone()

    def set_adapter_weights(self, name: str, A: torch.Tensor, B: torch.Tensor):
        """Overwrite the stored factors of adapter ``name`` in place (no autograd)."""
        entry = self._entry(name)
        entry["A"].data.copy_(A)
        entry["B"].data.copy_(B)

    def apply_delta_on_base(
        self,
        name: str,
        base_A: torch.Tensor,
        base_B: torch.Tensor,
        dA: torch.Tensor,
        dB: torch.Tensor,
        trainable: bool = False,
    ):
        """Set adapter ``name`` to ``base + delta`` and update its trainability."""
        # Hard copy for inference/export; does not preserve graph.
        entry = self._entry(name)
        A, B = entry["A"], entry["B"]
        A.data.copy_(base_A + dA)
        B.data.copy_(base_B + dB)
        A.requires_grad = trainable
        B.requires_grad = trainable
        entry["trainable"] = bool(trainable)

    # ---- Runtime deltas (keep autograd graph) ----

    def set_runtime_delta(self, name: str, dA: torch.Tensor, dB: torch.Tensor):
        """Attach (dA, dB), added to (A, B) inside forward without detaching.

        Raises:
            KeyError: if the adapter does not exist.
        """
        self._entry(name)["runtime_delta"] = (dA, dB)

    def clear_runtime_delta(self, name: str):
        """Remove any runtime delta from adapter ``name``; unknown names are ignored."""
        if name in self.adapters:
            self.adapters[name]["runtime_delta"] = None

    # ---- Forward ----

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the base linear map plus the sum of all active adapters."""
        out = F.linear(x, self.weight, self.bias)

        for name in self.active_adapters:
            entry = self.adapters[name]
            A: torch.Tensor = entry["A"]  # (in, r)
            B: torch.Tensor = entry["B"]  # (r, out)
            runtime_delta = entry.get("runtime_delta")
            if runtime_delta is not None:
                dA, dB = runtime_delta
                # Out-of-place adds keep the deltas' autograd graph intact.
                A, B = A + dA, B + dB

            # The rank is the last dim of A; reading it there stays correct
            # should a delta broadcast extra leading dims.
            scale = float(entry["alpha"]) / A.shape[-1]
            out = out + (entry["dropout"](x) @ A @ B) * scale

        return out


# ---------- Model injection & utils ----------


def _match_target(full_name: str, module: nn.Module, target: TargetSpec) -> bool:
    if callable(target):
        return bool(target(full_name, module))
    if isinstance(target, re.Pattern):
        return bool(target.search(full_name))
    if isinstance(target, str):
        leaf = full_name.split(".")[-1]
        return (leaf == target) or (target in full_name)
    targets = set(target)
    leaf = full_name.split(".")[-1]
    return (leaf in targets) or any(t in full_name for t in targets)


def _iter_target_linears(
    model: nn.Module, target: TargetSpec
) -> List[Tuple[str, nn.Linear]]:
    """Collect ``(full_name, module)`` for every ``nn.Linear`` selected by ``target``.

    The result is a list (not a lazy iterator) so callers may replace modules
    while walking it. Only genuine ``nn.Linear`` instances are returned.
    """
    # A one-shot iterable of names (e.g. a generator) would be drained by the
    # first _match_target call and silently match nothing afterwards; snapshot
    # it once so every module is tested against the same names.
    if not (callable(target) or isinstance(target, (str, re.Pattern))):
        target = tuple(target)

    return [
        (full_name, module)
        for full_name, module in model.named_modules()
        if isinstance(module, nn.Linear) and _match_target(full_name, module, target)
    ]


def _get_parent_and_leaf(model: nn.Module, full_name: str) -> Tuple[nn.Module, str]:
    parent = model
    parts = full_name.split(".")
    for p in parts[:-1]:
        parent = getattr(parent, p)
    return parent, parts[-1]


def inject_lora(
    model: nn.Module,
    config: LoraConfigLite,
    adapter_name: str,
    trainable: bool = True,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Dict[str, LoRALinear]:
    """
    Replace selected nn.Linear with LoRALinear, register 'adapter_name' in each.
    Layers that are already LoRALinear are reused, so calling this again with a
    different ``adapter_name`` adds a further adapter to the same layers.

    Args:
        model: module tree to modify in place.
        config: rank/alpha/dropout/init and the target selector.
        adapter_name: name under which the adapter is registered on each layer.
        trainable: initial ``requires_grad`` of the adapter factors.
        device: device for the factors; defaults to each layer's weight device.
        dtype: dtype for the factors; ``None`` uses torch's default dtype.

    Returns {full_name: LoRALinear} for every layer that received the adapter.

    Raises:
        ValueError: if ``adapter_name`` already exists on a selected layer
            (checked before anything is modified).
    """
    target = config.target_modules
    # Snapshot one-shot iterables so every module sees the same names.
    if not (callable(target) or isinstance(target, (str, re.Pattern))):
        target = tuple(target)

    # Materialise the matches first: the tree is mutated below. LoRALinear is
    # not an nn.Linear subclass, so it has to be accepted explicitly or layers
    # wrapped by an earlier call would never be seen again.
    candidates = [
        (full_name, module)
        for full_name, module in model.named_modules()
        if isinstance(module, (nn.Linear, LoRALinear))
        and _match_target(full_name, module, target)
    ]

    clashes = [
        full_name
        for full_name, module in candidates
        if isinstance(module, LoRALinear) and adapter_name in module.adapters
    ]
    if clashes:
        raise ValueError(
            f"Adapter '{adapter_name}' already exists on layers: {clashes}"
        )

    mapping: Dict[str, LoRALinear] = {}
    for full_name, module in candidates:
        already_wrapped = isinstance(module, LoRALinear)
        lora_lin: LoRALinear = (
            module
            if already_wrapped
            else LoRALinear(module, fan_in_fan_out=config.fan_in_fan_out)
        )

        lora_lin.add_adapter(
            name=adapter_name,
            r=config.r,
            alpha=config.alpha,
            dropout=config.dropout,
            init=config.init,
            trainable=trainable,
            # Keep the factors next to the weight they are combined with.
            device=device if device is not None else lora_lin.weight.device,
            dtype=dtype,
        )

        # Swap the wrapper in only after the adapter was created successfully.
        if not already_wrapped:
            parent, leaf = _get_parent_and_leaf(model, full_name)
            setattr(parent, leaf, lora_lin)
        mapping[full_name] = lora_lin
    return mapping


def set_active_adapters(model: nn.Module, names: Optional[Union[str, Iterable[str]]]):
    """
    Enable a single adapter, a list (sums them), or disable all with None.
    Robust to layers that don't have the requested adapter(s): those layers get no active adapter.

    ``names`` may be any iterable, including a generator: it is read exactly
    once and the same selection is applied to every LoRALinear in ``model``.
    """
    if names is None:
        requested: Optional[List[str]] = None
    elif isinstance(names, str):
        requested = [names]
    else:
        # Materialise once, outside the loop; a one-shot iterable would
        # otherwise be empty for every layer after the first.
        requested = list(names)

    for m in model.modules():
        if not isinstance(m, LoRALinear):
            continue
        if requested is None:
            m.set_adapter(None)
            continue
        # filter to adapters that exist on this layer
        present = [n for n in requested if n in m.adapters]
        m.set_adapter(present if present else None)


def list_all_adapters(model: nn.Module) -> Dict[str, List[str]]:
    """Map the full name of every LoRALinear in ``model`` to its adapter names."""
    return {
        qualified_name: layer.list_adapters()
        for qualified_name, layer in model.named_modules()
        if isinstance(layer, LoRALinear)
    }


def freeze_base_model(model: nn.Module):
    """Disable gradients for every parameter yielded by ``model.parameters()``.

    Only registered parameters are affected; tensors kept outside the module
    registry (such as the adapter factors stored in ``LoRALinear.adapters``)
    are not visited.
    """
    for param in model.parameters():
        param.requires_grad_(False)


def extract_adapter_specs(
    model: nn.Module,
    adapter_name: str,
) -> Dict[str, Tuple[LoRALinear, Tuple[int, int], Tuple[int, int]]]:
    """
    Describe every layer that carries 'adapter_name'.

    Returns {full_name: (module, shape_A, shape_B)}, where the shapes are the
    current shapes of the stored factors: (in_features, r) and (r, out_features).
    LoRALinear layers without the adapter are left out.
    """
    specs = {}
    for full_name, module in model.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        entry = module.adapters.get(adapter_name)
        if entry is None:
            continue
        factor_a: nn.Parameter = entry["A"]
        factor_b: nn.Parameter = entry["B"]
        specs[full_name] = (module, tuple(factor_a.shape), tuple(factor_b.shape))
    return specs


def apply_adapter_deltas(
    layer_specs: Dict[str, Tuple[LoRALinear, Tuple[int, int], Tuple[int, int]]],
    deltas: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    adapter_name: str,
    base_adapter_weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    trainable: bool = False,
):
    """Hard set adapter weights = base + delta for each layer (use for inference/export).

    ``deltas`` and ``base_adapter_weights`` are both looked up by the keys of
    ``layer_specs``; a key missing from either raises ``KeyError``.
    """
    for layer_name, spec in layer_specs.items():
        layer, _shape_a, _shape_b = spec
        delta_a, delta_b = deltas[layer_name]
        base_a, base_b = base_adapter_weights[layer_name]
        layer.apply_delta_on_base(
            adapter_name, base_a, base_b, delta_a, delta_b, trainable=trainable
        )


def set_runtime_adapter_deltas(
    layer_specs: Dict[str, Tuple[LoRALinear, Tuple[int, int], Tuple[int, int]]],
    deltas: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    adapter_name: str,
):
    """Attach per-batch deltas (keeps autograd graph) across layers.

    Each layer in ``layer_specs`` receives the ``(dA, dB)`` pair stored under
    the same key in ``deltas``; a missing key raises ``KeyError``.
    """
    for layer_name, spec in layer_specs.items():
        layer, _shape_a, _shape_b = spec
        # generator outputs are passed through untouched; do NOT detach
        delta_a, delta_b = deltas[layer_name]
        layer.set_runtime_delta(adapter_name, delta_a, delta_b)


def clear_runtime_adapter_deltas(
    layer_specs: Dict[str, Tuple[LoRALinear, Tuple[int, int], Tuple[int, int]]],
    adapter_name: str,
):
    """Ask every layer in ``layer_specs`` to drop its runtime delta for ``adapter_name``."""
    for layer, _shape_a, _shape_b in layer_specs.values():
        layer.clear_runtime_delta(adapter_name)


def get_adapter_state_dict(
    model: nn.Module, adapter_name: str
) -> Dict[str, torch.Tensor]:
    """Collect { '<full_name>.A': tensor, '<full_name>.B': tensor } for a given adapter.

    The returned tensors are detached CPU copies: they never share storage with
    the live adapter, so a snapshot taken during training is not changed by
    later optimizer steps. Layers without ``adapter_name`` are skipped.
    """
    state: Dict[str, torch.Tensor] = {}
    for full_name, module in model.named_modules():
        if isinstance(module, LoRALinear) and adapter_name in module.adapters:
            entry = module.adapters[adapter_name]
            # copy=True matters for CPU-resident adapters, where a plain
            # .cpu() would hand back a view of the live parameter.
            state[f"{full_name}.A"] = entry["A"].detach().to("cpu", copy=True)
            state[f"{full_name}.B"] = entry["B"].detach().to("cpu", copy=True)
    return state


def load_adapter_state_dict(
    model: nn.Module,
    adapter_name: str,
    state: Dict[str, torch.Tensor],
    strict: bool = True,
):
    """Load A/B tensors into an existing adapter slot.

    Expects the keys ``'<full_name>.A'`` and ``'<full_name>.B'`` for each
    LoRALinear that carries ``adapter_name``. A layer is loaded only when both
    of its keys are present, so it is never left with a new A and an old B.

    Args:
        model: module tree containing the LoRALinear layers.
        adapter_name: adapter whose factors are overwritten.
        state: mapping of key to tensor, as described above.
        strict: when True, any layer with a missing key raises before any
            tensor is modified; when False, such layers are skipped.

    Raises:
        KeyError: ``strict`` is True and at least one layer lacks a key.
    """
    targets = [
        (full_name, module.adapters[adapter_name])
        for full_name, module in model.named_modules()
        if isinstance(module, LoRALinear) and adapter_name in module.adapters
    ]

    # Validate up front so a strict failure leaves the model untouched.
    missing = [
        full_name
        for full_name, _ in targets
        if f"{full_name}.A" not in state or f"{full_name}.B" not in state
    ]
    if strict and missing:
        raise KeyError(f"Missing keys for layers: {missing}")

    skipped = set(missing)
    for full_name, entry in targets:
        if full_name in skipped:
            continue
        entry["A"].data.copy_(state[f"{full_name}.A"])
        entry["B"].data.copy_(state[f"{full_name}.B"])


def merge_adapter_into_base(model: nn.Module, adapter_name: str):
    """
    (Optional) Irreversibly merge adapter weights into base weights, so that the
    base layer alone reproduces base + adapter (without dropout):
    W <- W + ((A @ B) * (alpha/r)).T. Bias unchanged. Active adapters are not
    used here, and runtime deltas are not included.

    NOTE: the adapter stays registered (and active, if it was). Deactivate or
    remove it afterwards, otherwise forward applies its contribution twice.
    """
    for module in model.modules():
        if not (isinstance(module, LoRALinear) and adapter_name in module.adapters):
            continue
        entry = module.adapters[adapter_name]
        A = entry["A"]  # (in_features, r)
        B = entry["B"]  # (r, out_features)
        scale = float(entry["alpha"]) / A.shape[1]

        with torch.no_grad():  # export-time arithmetic; no graph needed
            delta = (A @ B) * scale  # (in_features, out_features)
            # LoRALinear.forward applies the weight through F.linear, i.e.
            # x @ W.T with W stored as (out_features, in_features), and adds
            # x @ delta on top. Folding the adapter in therefore requires the
            # transpose. forward never consults fan_in_fan_out, so the flag
            # must not change the layout used here either.
            update = delta.T.to(
                device=module.weight.device, dtype=module.weight.dtype
            )
            module.weight.data += update
