import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig


import math
from metrics import *

from typing import Iterable, Dict, Tuple, Mapping, Sequence, Optional, Union


from transformers import StoppingCriteria, StoppingCriteriaList
from functional_lora import FunctionalLoRAInjector, generate_layer_specs

# Shorthand for torch.Tensor, used in the annotations of this module.
Tensor = torch.Tensor
# Mapping from a layer name to the pair of LoRA factors produced for that layer.
DeltaDict = Dict[str, Tuple[Tensor, Tensor]]  # name -> (A:[r,in], B:[out,r])


class MetaGenerator(nn.Module):
    """
    Predict per-layer LoRA deltas (A, B) from support text.

    A PEFT LoRA adapter (named ``gen_adapter``) is added to ``model.lm`` and is
    activated while the support tokens run through the LM. The last hidden states
    are encoded by a GRU, passed through a small trunk, and two linear heads emit
    the flattened A and B matrices for every entry of ``layer_specs``.

    Intended to be used with a FunctionalLoRAInjector so that we never rebind
    Parameters (gradients flow cleanly).
    """

    # Containers inspected on every target module, in the order their tensors are
    # reported by get_opt_params().
    _LORA_CONTAINERS = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B")

    def __init__(
        self,
        model: nn.Module,  # wrapper with .lm (HF CausalLM)
        layer_specs: Dict[
            str, Tuple[nn.Module, torch.Size, torch.Size]
        ],  # name -> (base_linear, shape_A, shape_B)
        gen_config: LoraConfig,
        input_dim: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        gen_adapter: str = "generator",
    ):
        """
        Args:
            model: wrapper exposing the language model as ``model.lm``.
            layer_specs: name -> (target module, shape of A, shape of B).
            gen_config: LoRA config passed to ``model.lm.add_adapter``.
            input_dim: feature size fed to the GRU (last hidden state width).
            hidden_dim: width of the GRU state and of the trunk.
            num_layers: number of stacked GRU layers.
            gen_adapter: adapter name used for the generator LoRA.
        """
        super().__init__()
        self.model = model
        self.gen_adapter = gen_adapter

        # Install the generator LoRA on the LM (used only to condition features)
        model.lm.add_adapter(gen_config, adapter_name=gen_adapter)

        # Shapes for produced deltas, plus the total width of each flat head
        self.layer_specs = layer_specs
        self.param_shapes: Dict[str, Tuple[torch.Size, torch.Size]] = {}
        total_A, total_B = 0, 0
        for name, (_, shape_A, shape_B) in layer_specs.items():
            self.param_shapes[name] = (shape_A, shape_B)
            total_A += shape_A.numel()  # A:[r,in]
            total_B += shape_B.numel()  # B:[out,r]

        # Encoder over last hidden states from LM
        self.encoder = nn.GRU(
            input_dim, hidden_dim, batch_first=True, num_layers=num_layers
        )

        # Small trunk + split heads so A and B can be initialised differently
        self.trunk = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
        self.head_A = nn.Linear(hidden_dim, total_A)  # A ~ Kaiming
        self.head_B = nn.Linear(hidden_dim, total_B)  # B = 0
        self._init_heads()

    def _init_heads(self):
        """Kaiming-uniform weights for the A head, all-zero B head and biases."""
        # Zero B head => every predicted B starts at 0, so the LoRA product starts at 0.
        nn.init.kaiming_uniform_(self.head_A.weight, a=0.0, nonlinearity="relu")
        nn.init.zeros_(self.head_A.bias)

        nn.init.zeros_(self.head_B.weight)
        nn.init.zeros_(self.head_B.bias)

    def forward(self, support_inputs: Tensor, reduce: str = "first") -> DeltaDict:
        """
        Produce one (A, B) pair per target layer from the support tokens.

        Args:
            support_inputs: token IDs for support text, shape [B, T].
            reduce: ``"mean"`` averages the flat predictions over the batch when
                the batch holds more than one item; any other value keeps the
                prediction of the first batch item.

        Returns:
            dict[name] -> (A:[r,in], B:[out,r]) to feed into the functional injector.

        Note: autograd stays ON for the LM forward so gradients reach the generator
              LoRA as well as the GRU/heads.
        """
        # 1) Condition the LM with the generator LoRA.
        self.model.lm.set_adapter(self.gen_adapter)
        # Follow this module's own mode: dropout in the generator LoRA is active
        # while training, and switched off once the generator is put in eval().
        self.model.lm.train(self.training)

        # 2) Last hidden states from the LM with the generator adapter active
        out = self.model.lm(
            input_ids=support_inputs,
            output_hidden_states=True,
            return_dict=True,
        )
        hs = out.hidden_states[-1]  # [B, T, D_model]

        # The GRU may live in another dtype/device than the LM (e.g. fp32 vs bf16);
        # this is a no-op when they already agree.
        enc_ref = next(self.encoder.parameters())
        hs = hs.to(device=enc_ref.device, dtype=enc_ref.dtype)

        # 3) Encode support (pool via GRU final state)
        _, h_n = self.encoder(hs)  # h_n: [num_layers, B, hidden_dim]
        z = self.trunk(h_n[-1])  # [B, hidden_dim]

        # 4) Predict flattened A/B
        flat_A = self.head_A(z)  # [B, total_A]
        flat_B = self.head_B(z)  # [B, total_B]

        Bsz = flat_A.size(0)
        if reduce == "mean" and Bsz > 1:
            flat_A = flat_A.mean(dim=0, keepdim=True)
            flat_B = flat_B.mean(dim=0, keepdim=True)
            Bsz = 1

        # Reference dtype/device taken from the first target layer (for clean casting)
        ref_module = next(iter(self.layer_specs.values()))[0]
        ref_param = next(ref_module.parameters())
        ref_dtype, ref_device = ref_param.dtype, ref_param.device

        # 5) Unflatten per-layer tensors, walking both flat vectors in lockstep
        deltas: DeltaDict = {}
        pa = pb = 0
        for name, (shape_A, shape_B) in self.param_shapes.items():
            size_A = shape_A.numel()
            size_B = shape_B.numel()

            A_chunk = flat_A[:, pa : pa + size_A].view(Bsz, *shape_A)
            B_chunk = flat_B[:, pb : pb + size_B].view(Bsz, *shape_B)

            # Only the first batch row is used (the single row after a "mean" reduce).
            deltas[name] = (
                A_chunk[0].to(dtype=ref_dtype, device=ref_device),
                B_chunk[0].to(dtype=ref_dtype, device=ref_device),
            )

            pa += size_A
            pb += size_B

        return deltas

    # ---- training plumbing ----

    def _generator_lora_weights(self) -> list:
        """Collect the generator-adapter LoRA tensors found on the target modules."""
        found = []
        for module, *_ in self.layer_specs.values():
            for attr in self._LORA_CONTAINERS:
                container = getattr(module, attr, None)
                if container is None or self.gen_adapter not in container:
                    continue
                entry = container[self.gen_adapter]
                # Entries are either sub-modules carrying `.weight` or bare
                # Parameters (ParameterDict); accept both.
                found.append(getattr(entry, "weight", entry))
        return found

    def set_trainable_params(self):
        """
        Make ONLY the generator LoRA + GRU/trunk/heads trainable.
        Freeze everything else on model.lm to avoid accidental updates.
        """
        # Freeze entire LM
        self.model.lm.requires_grad_(False)

        # Unfreeze the generator LoRA weights found on the target modules
        for weight in self._generator_lora_weights():
            weight.requires_grad_(True)

        # Unfreeze generator heads/encoder
        for part in (self.encoder, self.trunk, self.head_A, self.head_B):
            part.requires_grad_(True)

    def get_opt_params(self):
        """
        Return ONLY the parameters we want the optimizer to update:
        generator LoRA tensors first, then encoder, trunk, head_A and head_B.
        """
        opt = self._generator_lora_weights()
        for part in (self.encoder, self.trunk, self.head_A, self.head_B):
            opt += list(part.parameters())
        return opt


class AdapterGenerator(nn.Module):
    """
    GRU-based generator of per-layer LoRA factors.

    The frozen base model supplies last-layer hidden states for the input tokens;
    a GRU summarises them, a shared trunk transforms the final state, and two
    linear heads emit the flattened A and B matrices of every layer in
    ``layer_specs`` (name -> (module, shape_A, shape_B)).
    """

    def __init__(self, layer_specs, input_dim=3072, hidden_dim=512, num_layers=2):
        super().__init__()
        self.layer_specs = layer_specs

        # Tally A/B sizes independently
        self.param_shapes: Dict[str, Tuple[torch.Size, torch.Size]] = {}
        total_A, total_B = 0, 0
        for name, (_, shape_A, shape_B) in layer_specs.items():
            self.param_shapes[name] = (shape_A, shape_B)
            total_A += shape_A[0] * shape_A[1]  # A:[r, in]
            total_B += shape_B[0] * shape_B[1]  # B:[out, r]

        self.encoder = nn.GRU(
            input_dim, hidden_dim, batch_first=True, num_layers=num_layers
        )

        # Shared trunk
        self.trunk = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Separate heads so we can init differently
        self.head_A = nn.Linear(hidden_dim, total_A)  # tiny Gaussian
        self.head_B = nn.Linear(hidden_dim, total_B)  # zero

        self._init_heads()

    def _init_heads(self):
        """Tiny Gaussian weights for the A head; zeros for the B head and all biases."""
        # A head: small random weights (std 1e-5), zero bias.
        nn.init.normal_(self.head_A.weight, mean=0.0, std=1e-5)
        nn.init.zeros_(self.head_A.bias)

        # B head: strict zeros -> every predicted B (and so the LoRA product) starts at 0
        nn.init.zeros_(self.head_B.weight)
        nn.init.zeros_(self.head_B.bias)

    def forward(self, inputs, model):
        """
        Predict (A, B) for every target layer.

        Args:
            inputs: passed positionally to ``model.lm.base_model``
                (presumably token IDs of shape [B, T] -- confirm with the caller).
            model: wrapper exposing ``.lm.base_model``; it is switched to eval mode
                and queried without gradients.

        Returns:
            dict[name] -> (A, B) shaped as given in ``layer_specs``. When the batch
            holds several items only the first one is used.
        """
        # Get last hidden from base model (feature extractor); no grads reach the LM
        model.eval()
        with torch.no_grad():
            out = model.lm.base_model(
                inputs, output_hidden_states=True, return_dict=True
            )
            hs = out.hidden_states[-1]  # [B, T, D]

        # The GRU may use another dtype/device than the LM; no-op when they agree.
        enc_ref = next(self.encoder.parameters())
        hs = hs.to(device=enc_ref.device, dtype=enc_ref.dtype)

        # Encode support
        _, h_n = self.encoder(hs)  # h_n: [num_layers, B, hidden_dim]
        z = self.trunk(h_n[-1])  # [B, hidden_dim]

        flat_A = self.head_A(z)  # [B, total_A]
        flat_B = self.head_B(z)  # [B, total_B]  (all zeros at init)

        Bsz = flat_A.size(0)
        deltas = {}
        pa = pb = 0
        for name, (shape_A, shape_B) in self.param_shapes.items():
            size_A = shape_A[0] * shape_A[1]
            size_B = shape_B[0] * shape_B[1]

            A_chunk = flat_A[:, pa : pa + size_A].view(Bsz, *shape_A)
            B_chunk = flat_B[:, pb : pb + size_B].view(Bsz, *shape_B)

            # If batching, only the first item is used
            deltas[name] = (A_chunk[0], B_chunk[0])

            pa += size_A
            pb += size_B

        return deltas

    def set_trainable_params(self):
        """Enable gradients on the encoder, trunk and both heads."""
        for part in (self.encoder, self.trunk, self.head_A, self.head_B):
            part.requires_grad_(True)

    def get_opt_params(self):
        """Return all parameters of this generator as a list (reusable, unlike a generator)."""
        return list(self.parameters())


class LinearAdapterGeneratorSmall(nn.Module):
    """
    Linear generator of per-layer LoRA factors.

    Mean-pooled last hidden states of the frozen base model are projected to a
    ``d_z``-dimensional latent, from which two bias-free linear heads emit the
    flattened A and B matrices of every layer in ``layer_specs``
    (name -> (module, shape_A, shape_B)).
    """

    def __init__(self, layer_specs, input_dim, d_z=64, a_std=1e-3):
        super().__init__()
        self.layer_specs = layer_specs
        self.a_std = float(a_std)

        self.param_shapes = {}
        total_A, total_B = 0, 0
        for name, (_, shape_A, shape_B) in layer_specs.items():
            self.param_shapes[name] = (shape_A, shape_B)
            total_A += shape_A.numel()
            total_B += shape_B.numel()

        # shared bottleneck
        self.encoder = nn.Linear(input_dim, d_z, bias=False)

        # small heads from latent
        self.head_A = nn.Linear(d_z, total_A, bias=False)
        self.head_B = nn.Linear(d_z, total_B, bias=False)
        self._init_heads()

    def _init_heads(self):
        """Gaussian A head (std ``a_std``), zero B head, Gaussian encoder (std 1e-2)."""
        nn.init.normal_(self.head_A.weight, mean=0.0, std=self.a_std)
        # Zero B head => every predicted B starts at 0.
        nn.init.zeros_(self.head_B.weight)
        nn.init.normal_(self.encoder.weight, mean=0.0, std=1e-2)

    @staticmethod
    def _mean_pool(hs, mask=None):
        """Average ``hs`` over dim 1; with ``mask``, average only positions where it is non-zero."""
        if mask is None:
            return hs.mean(dim=1)
        mask = mask.to(device=hs.device, dtype=hs.dtype).unsqueeze(-1)
        # clamp avoids a division by zero for rows whose mask is entirely 0
        return (hs * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)

    def forward(self, support_inputs, model, attention_mask=None):
        """
        Predict (A, B) for every target layer.

        Args:
            support_inputs: token IDs handed to ``model.lm.base_model`` as ``input_ids``.
            model: wrapper exposing ``.lm.base_model``; it is switched to eval mode
                and queried without gradients.
            attention_mask: optional [B, T] mask (non-zero = real token). When given
                it is forwarded to the LM and padding positions are excluded from
                the mean pooling. ``None`` keeps the plain mean over all positions.

        Returns:
            dict[name] -> (A, B), cast to the dtype/device of the first target
            module's parameters. With several batch items only the first is used.
        """
        lm_kwargs = {
            "input_ids": support_inputs,
            "output_hidden_states": True,
            "return_dict": True,
        }
        if attention_mask is not None:
            lm_kwargs["attention_mask"] = attention_mask

        model.eval()
        with torch.no_grad():
            out = model.lm.base_model(**lm_kwargs)
            pooled = self._mean_pool(out.hidden_states[-1], mask=attention_mask)

        # The encoder may use another dtype/device than the LM; no-op when they agree.
        enc_w = self.encoder.weight
        pooled = pooled.to(device=enc_w.device, dtype=enc_w.dtype)  # [B, D]

        z = self.encoder(pooled)  # [B, d_z]
        flat_A = self.head_A(z)  # [B, total_A]
        flat_B = self.head_B(z)  # [B, total_B]

        Bsz = flat_A.size(0)

        # Reference dtype/device from the first target module
        ref_mod = next(iter(self.layer_specs.values()))[0]
        ref_param = next(ref_mod.parameters())
        ref_dtype, ref_device = ref_param.dtype, ref_param.device

        deltas = {}
        pa = pb = 0
        for name, (shape_A, shape_B) in self.param_shapes.items():
            size_A, size_B = shape_A.numel(), shape_B.numel()
            A_chunk = flat_A[:, pa : pa + size_A].view(Bsz, *shape_A)
            B_chunk = flat_B[:, pb : pb + size_B].view(Bsz, *shape_B)
            # Only the first batch item is used.
            deltas[name] = (
                A_chunk[0].to(dtype=ref_dtype, device=ref_device),
                B_chunk[0].to(dtype=ref_dtype, device=ref_device),
            )
            pa += size_A
            pb += size_B
        return deltas

    def set_trainable_params(self):
        """Enable gradients on the encoder and both heads."""
        for m in [self.encoder, self.head_A, self.head_B]:
            m.requires_grad_(True)

    def get_opt_params(self):
        """Return all parameters of this generator as a list."""
        return list(self.parameters())


class BasePack(nn.Module):
    """Owns the two basis matrices ``U_in`` [k_in, in_dim] and ``U_out`` [out_dim, k_out].

    Both are registered as Parameters, orthogonally initialised, and require
    gradients only when ``train_bases`` is true.
    """

    def __init__(
        self,
        k_in: int,
        in_dim: int,
        out_dim: int,
        k_out: int,
        train_bases: bool,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        in_basis = torch.empty(k_in, in_dim, **factory)
        out_basis = torch.empty(out_dim, k_out, **factory)

        # Registration order (U_in, then U_out) is also the initialisation order.
        self.U_in = nn.Parameter(in_basis, requires_grad=train_bases)
        self.U_out = nn.Parameter(out_basis, requires_grad=train_bases)
        for basis in (self.U_in, self.U_out):
            self._orthogonal_init_safe(basis)

    @staticmethod
    def _orthogonal_init_safe(param: torch.Tensor) -> None:
        """Fill ``param`` in place with an orthogonal matrix built in float32.

        The fp32 scratch buffer sidesteps the bf16 CUDA QR limitation; the result
        is cast back to the parameter's own dtype before being copied in.
        """
        with torch.no_grad():
            scratch = torch.empty_like(
                param, dtype=torch.float32, device=param.device
            )
            nn.init.orthogonal_(scratch)
            converted = scratch.to(dtype=param.dtype)
            param.copy_(converted)


def _sanitize(name: str) -> str:
    # ModuleDict keys cannot contain ".", so replace with a safe token
    return name.replace(".", "__")


class BasisFactorizedAdapterGenerator(nn.Module):
    """
    Generator that predicts small LoRA "cores" and expands them with per-layer bases.

    For every layer in ``layer_specs`` a ``BasePack`` holds ``U_in`` [k_in, in_dim]
    and ``U_out`` [out_dim, k_out]. The network predicts cores ``C_A`` [r, k_in] and
    ``C_B`` [k_out, r] from mean-pooled LM features and returns
    ``A = C_A @ U_in`` and ``B = U_out @ C_B``.
    """

    def __init__(
        self,
        layer_specs: Dict[
            str, Tuple[nn.Module, torch.Size, torch.Size]
        ],  # name -> (module, A_shape[r,in], B_shape[out,r])
        input_dim: int,
        d_z: int = 256,
        k_in: int = 64,
        k_out: int = 64,
        a_std: float = 1e-3,
        train_bases: bool = False,
    ):
        """
        Args:
            layer_specs: name -> (module, A_shape [r, in], B_shape [out, r]).
            input_dim: width of the pooled LM features.
            d_z: size of the shared latent.
            k_in: number of rows of each ``U_in`` basis.
            k_out: number of columns of each ``U_out`` basis.
            a_std: std of the Gaussian init of the core-A head.
            train_bases: whether the bases require gradients.

        Raises:
            ValueError: if a layer's A and B ranks differ, or two layer names map
                to the same sanitized key.
        """
        super().__init__()
        self.layer_specs = layer_specs
        self.a_std = float(a_std)
        self.k_in = int(k_in)
        self.k_out = int(k_out)

        # Reference dtype/device from the first target module
        ref_mod = next(iter(self.layer_specs.values()))[0]
        ref_param = next(ref_mod.parameters())
        ref_dtype = ref_param.dtype
        ref_device = ref_param.device

        # Keep mapping original <-> sanitized names
        self.name2key: Dict[str, str] = {}
        self.key2name: Dict[str, str] = {}

        # Register bases per layer as submodules
        self.bases = nn.ModuleDict()
        self.core_shapes: Dict[str, Tuple[int, int, int, int]] = (
            {}
        )  # name -> (r, k_in, k_out, r)

        total_core_A = 0
        total_core_B = 0

        for name, (_, A_shape, B_shape) in self.layer_specs.items():
            r, in_dim = A_shape
            out_dim, rB = B_shape
            # Explicit raise: an `assert` would disappear under `python -O`.
            if r != rB:
                raise ValueError(f"LoRA ranks mismatch for {name}: {r} vs {rB}")

            key = _sanitize(name)
            if key in self.bases:
                raise ValueError(
                    f"Sanitized key collision for layer name '{name}' -> '{key}'"
                )

            self.bases[key] = BasePack(
                self.k_in,
                in_dim,
                out_dim,
                self.k_out,
                train_bases=train_bases,
                device=ref_device,
                dtype=ref_dtype,
            )

            self.name2key[name] = key
            self.key2name[key] = name
            self.core_shapes[name] = (r, self.k_in, self.k_out, r)

            total_core_A += r * self.k_in  # C_A size
            total_core_B += self.k_out * r  # C_B size

        # Shared latent + small heads for cores (not full mats)
        self.encoder = nn.Linear(
            input_dim, d_z, bias=False, device=ref_device, dtype=ref_dtype
        )
        self.head_core_A = nn.Linear(
            d_z, total_core_A, bias=False, device=ref_device, dtype=ref_dtype
        )
        self.head_core_B = nn.Linear(
            d_z, total_core_B, bias=False, device=ref_device, dtype=ref_dtype
        )

        # inits
        nn.init.normal_(self.encoder.weight, mean=0.0, std=1e-2)
        nn.init.normal_(self.head_core_A.weight, mean=0.0, std=self.a_std)
        nn.init.zeros_(self.head_core_B.weight)  # zero C_B => B starts at 0

    @staticmethod
    def _mean_pool(hs: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Average ``hs`` over dim 1; with ``mask``, average only positions where it is non-zero."""
        if mask is None:
            return hs.mean(dim=1)
        mask = mask.to(device=hs.device, dtype=hs.dtype).unsqueeze(-1)
        # clamp avoids a division by zero for rows whose mask is entirely 0
        return (hs * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)

    def forward(
        self,
        support_inputs: torch.Tensor,
        model,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Predict (A, B) for every target layer.

        Args:
            support_inputs: token IDs handed to ``model.lm.base_model`` as ``input_ids``.
            model: wrapper exposing ``.lm.base_model``; it is switched to eval mode
                and queried without gradients.
            attention_mask: optional [B, T] mask (non-zero = real token). When given
                it is forwarded to the LM and padding positions are excluded from
                the mean pooling. ``None`` keeps the plain mean over all positions.

        Returns:
            dict[name] -> (A:[r, in_dim], B:[out_dim, r]) in the dtype/device of the
            first target module. With several batch items only the first is used.
        """
        lm_kwargs = {
            "input_ids": support_inputs,
            "output_hidden_states": True,
            "return_dict": True,
        }
        if attention_mask is not None:
            lm_kwargs["attention_mask"] = attention_mask

        # extract pooled support features without training the base LM
        model.eval()
        with torch.no_grad():
            out = model.lm.base_model(**lm_kwargs)
            pooled = self._mean_pool(out.hidden_states[-1], mask=attention_mask)

        # The encoder may use another dtype/device than the LM; no-op when they agree.
        enc_w = self.encoder.weight
        pooled = pooled.to(device=enc_w.device, dtype=enc_w.dtype)  # [B, D]

        z = self.encoder(pooled)  # [B, d_z]
        flat_CA = self.head_core_A(z)  # [B, sum r*k_in]
        flat_CB = self.head_core_B(z)  # [B, sum k_out*r]

        Bsz = z.size(0)

        # Reference device/dtype from the first target module
        ref_mod = next(iter(self.layer_specs.values()))[0]
        ref_param = next(ref_mod.parameters())
        ref_dtype = ref_param.dtype
        ref_device = ref_param.device

        deltas: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
        pa = pb = 0
        for name, (_, A_shape, _) in self.layer_specs.items():
            r = A_shape[0]
            size_CA = r * self.k_in
            size_CB = self.k_out * r

            CA = flat_CA[:, pa : pa + size_CA].view(Bsz, r, self.k_in)
            CB = flat_CB[:, pb : pb + size_CB].view(Bsz, self.k_out, r)
            pa += size_CA
            pb += size_CB

            # Only the first batch item is used.
            CA = CA[0].to(device=ref_device, dtype=ref_dtype)  # [r, k_in]
            CB = CB[0].to(device=ref_device, dtype=ref_dtype)  # [k_out, r]

            pack: BasePack = self.bases[self.name2key[name]]

            # Reconstruct A and B with the bases on the reference device/dtype
            U_in = pack.U_in.to(device=ref_device, dtype=ref_dtype)
            U_out = pack.U_out.to(device=ref_device, dtype=ref_dtype)

            deltas[name] = (CA @ U_in, U_out @ CB)  # ([r, in_dim], [out_dim, r])

        return deltas

    def set_trainable_params(self):
        """Enable gradients on the encoder and both core heads; the bases are left as configured."""
        self.encoder.requires_grad_(True)
        self.head_core_A.requires_grad_(True)
        self.head_core_B.requires_grad_(True)
        # The bases keep the requires_grad flag chosen via train_bases at init.

    def get_opt_params(self):
        """Return all parameters of this generator (bases included) as a list."""
        return list(self.parameters())
