import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple

from adapter_generator import apply_adapter_deltas

# Special marker tokens: <|MEM|> prefixes memory text in teacher inputs,
# <|RECON|> is the prompt from which a memory is reconstructed.
MEM_TOKEN, RECON_TOKEN = "<|MEM|>", "<|RECON|>"


# --------------------------------------------------------------------------
# Special-tokens setup
# --------------------------------------------------------------------------
def ensure_token(tokenizer, model_lm, tok: str) -> int:
    """
    Make sure `tok` is an additional special token of `tokenizer`; return its id.

    If `tok` is not yet in the vocabulary it is registered as an additional
    special token and `model_lm`'s token embeddings are resized to the new
    vocabulary size. Tokens that are already present are left untouched (no
    resize).

    The existing additional special tokens are passed along with the new one:
    HF `add_special_tokens({"additional_special_tokens": [...]})` replaces the
    current list by default, which would un-flag tokens registered by earlier
    calls (e.g. <|MEM|> when <|RECON|> is added next) as special.
    """
    if tok not in tokenizer.get_vocab():
        existing = list(getattr(tokenizer, "additional_special_tokens", None) or [])
        tokenizer.add_special_tokens(
            {"additional_special_tokens": existing + [tok]}
        )
        model_lm.resize_token_embeddings(len(tokenizer))
    return tokenizer.convert_tokens_to_ids(tok)


# --------------------------------------------------------------------------
# Memory generator: z = F(h) (ReLU MLP), Δ = W_mem z   (bias-free head ⇒ Δ is exactly additive in z)
# --------------------------------------------------------------------------
class MemoryAdapterGenerator(nn.Module):
    """
    Map a pooled memory representation h to a flat LoRA delta vector.

        z = F(h)       # 2-layer MLP with ReLUs (non-linear in h; z >= 0)
        Δ = W_mem z    # bias-free linear head

    Because W_mem has no bias, Δ is exactly additive in latent space:
    W_mem(z1 + z2) == W_mem(z1) + W_mem(z2).
    """

    def __init__(self, d_hidden: int, k_latent: int, D_flat: int, d_mlp: int = 512):
        """
        d_hidden: pooled LM hidden size (e.g., model.lm.config.hidden_size)
        k_latent: latent dimension (z)
        D_flat:   total flattened LoRA delta dimension over all target modules
        d_mlp:    width of the MLP hidden layer (default 512)
        """
        super().__init__()
        # Attribute names F / W_mem determine the state_dict keys; keep them stable.
        self.F = nn.Sequential(
            nn.Linear(d_hidden, d_mlp),
            nn.ReLU(),
            nn.Linear(d_mlp, k_latent),
            nn.ReLU(),
        )
        # Bias-free linear head (exact additivity in adapter space)
        self.W_mem = nn.Linear(k_latent, D_flat, bias=False)

    def forward(self, pooled_hidden: torch.Tensor) -> torch.Tensor:
        """
        pooled_hidden: [B, d_hidden]  (e.g. mean-pooled last hidden state)
        returns Δ_flat: [B, D_flat] in the generator's parameter dtype

        The input is cast to the generator's parameter dtype first: the LM's
        hidden states may be fp16 while the generator is kept in fp32, and
        nn.Linear does not accept mixed dtypes.
        """
        pooled_hidden = pooled_hidden.to(self.W_mem.weight.dtype)
        z = self.F(pooled_hidden)
        delta_flat = self.W_mem(z)  # bias=False ⇒ linear/additive in z
        return delta_flat



def kd_loss(student_logits, teacher_logits, mask: torch.Tensor, T: float = 2.0):
    """
    Temperature-scaled KL(teacher || student), averaged over masked positions.

    student_logits, teacher_logits: [B, L, V]
    mask: [B, L] (1/True = keep, 0/False = ignore); float or bool
    T: softmax temperature; the result is multiplied by T**2 (the usual
       knowledge-distillation scaling).

    Returns a float32 scalar tensor. An all-zero mask yields 0.
    """
    # Work in float32: a softmax over a large vocabulary in bf16/fp16 is lossy.
    log_s = torch.log_softmax(student_logits.float() / T, dim=-1)
    # Exact teacher log-probs (instead of log(softmax + eps)): no eps bias, no log(0).
    log_t = torch.log_softmax(teacher_logits.float() / T, dim=-1)
    kl = torch.sum(log_t.exp() * (log_t - log_s), dim=-1)  # [B, L]
    mask = mask.to(kl.dtype)
    kl = (kl * mask).sum() / (mask.sum() + 1e-8)
    return (T**2) * kl


def make_recon_batch(
    tokenizer, target_text_ids: torch.Tensor, device, max_len: int = None
):
    """
    Build a teacher-forced reconstruction batch.

    inputs = [<|RECON|>] + target[:-1], labels = target, so the logit at
    position t is trained to predict target[t].

    target_text_ids: [B, T] token ids.
    max_len: optional cap on T (keeps the first max_len tokens).

    Label positions equal to `tokenizer.pad_token_id` are set to -100 so that
    padding added when memories were batch-tokenized is not a reconstruction
    target. NOTE(review): if pad_token_id == eos_token_id, genuine EOS tokens
    in the target are ignored as well.

    Returns: (inp [B, T], labels [B, T])
    Raises: ValueError if the (truncated) target has no tokens — the input
        would have length 1 while the labels have length 0.
    """
    rid = tokenizer.convert_tokens_to_ids(RECON_TOKEN)
    tgt = target_text_ids.to(device)
    if max_len is not None:
        tgt = tgt[:, :max_len]
    if tgt.size(1) == 0:
        raise ValueError("make_recon_batch: target is empty; nothing to reconstruct.")
    bos = torch.full((tgt.size(0), 1), rid, dtype=tgt.dtype, device=device)
    inp = torch.cat([bos, tgt[:, :-1]], dim=1)
    labels = tgt.clone()
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is not None:
        labels[tgt == pad_id] = -100
    return inp, labels


def build_teacher_ids(
    tokenizer,
    mem_texts: List[str],
    q_ids: torch.Tensor,
    device,
    mem_token_id: int,
    max_len: int = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Prepend <|MEM|> + memory text to each query to form the teacher inputs.

    Row layout: [mem_token_id] + memory tokens + query tokens.

    mem_texts: B memory strings (tokenized together with padding=True).
    q_ids: [B, Lq] query token ids.
    max_len: optional cap on the teacher length. Only the memory segment is
        truncated, so the last Lq positions are always the query (callers
        align teacher logits with `[:, -Lq:]`).

    Returns:
        teacher: LongTensor [B, L_teacher]
        mem_ids: [B, Tm] tokenized memory ids (NOT truncated by max_len)

    Raises:
        ValueError: if the tokenizer returns a batch size other than B, or if
            max_len < Lq + 1 (no room for <|MEM|> plus the full query).
    """
    B = len(mem_texts)
    q_ids = q_ids.to(device)
    mem_bos = torch.full((B, 1), mem_token_id, dtype=q_ids.dtype, device=device)
    mem_ids = tokenizer(
        mem_texts, return_tensors="pt", padding=True, truncation=True
    ).input_ids.to(device)
    if mem_ids.size(0) != B:
        raise ValueError("Tokenizer returned unexpected batch size for memory texts.")
    # NOTE(review): shorter memories carry pad tokens and no attention mask is
    # built here, so the teacher attends to those pads — confirm acceptable.
    mem_part = mem_ids
    if max_len is not None:
        mem_budget = max_len - 1 - q_ids.size(1)
        if mem_budget < 0:
            raise ValueError(
                f"max_len={max_len} leaves no room for <|MEM|> plus a "
                f"{q_ids.size(1)}-token query."
            )
        mem_part = mem_ids[:, :mem_budget]
    teacher = torch.cat([mem_bos, mem_part, q_ids], dim=1)
    return teacher, mem_ids


def flatten_deltas(
    deltas: Dict[str, Tuple[torch.Tensor, torch.Tensor]], layer_specs
) -> torch.Tensor:
    """
    Concatenate per-module (ΔA, ΔB) pairs into one flat [D_flat] vector.

    Modules are visited in `layer_specs` iteration order; for each, ΔA's
    elements precede ΔB's (row-major), matching `unflatten_delta_flat`.
    """
    pieces: List[torch.Tensor] = []
    for module_name in layer_specs:
        delta_A, delta_B = deltas[module_name]
        pieces.extend((delta_A.flatten(), delta_B.flatten()))
    return torch.cat(pieces, dim=0)  # [D_flat]


def unflatten_delta_flat(
    delta_flat: torch.Tensor, layer_specs
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Inverse of flatten_deltas: [D_flat] or [B, D_flat] → dict[name] -> (ΔA, ΔB).

    For a 1-D input the tensors have shapes shape_A / shape_B; for a 2-D input
    they carry a leading batch dim ([B, *shape_A] / [B, *shape_B]). Results are
    views into `delta_flat`, so gradients flow back to it.

    Raises:
        ValueError: if `delta_flat` is not 1-D/2-D, or its last dimension does
            not equal the total size implied by `layer_specs`.
    """
    if delta_flat.dim() not in (1, 2):
        raise ValueError(
            f"delta_flat must be 1-D or 2-D, got shape {tuple(delta_flat.shape)}"
        )
    batched = delta_flat.dim() == 2
    flat2d = delta_flat if batched else delta_flat.unsqueeze(0)  # [B, D]
    n_batch, total = flat2d.size()

    sizes = [
        (math.prod(shape_A), math.prod(shape_B))
        for _, shape_A, shape_B in layer_specs.values()
    ]
    expected = sum(sA + sB for sA, sB in sizes)
    if total != expected:
        raise ValueError(
            f"delta_flat has {total} elements per row but layer_specs expect {expected}"
        )

    out = {}
    ptr = 0
    for (name, (_, shape_A, shape_B)), (sA, sB) in zip(layer_specs.items(), sizes):
        A = flat2d[:, ptr : ptr + sA].view(n_batch, *shape_A)
        ptr += sA
        Bm = flat2d[:, ptr : ptr + sB].view(n_batch, *shape_B)
        ptr += sB
        out[name] = (A, Bm) if batched else (A.squeeze(0), Bm.squeeze(0))
    return out


# --------------------------------------------------------------------------
# Training step for one batch (memory distill + (optional) CE + reconstruction)
# --------------------------------------------------------------------------
def train_memory_batch(
    batch: Dict,
    model,
    tokenizer,
    generator: MemoryAdapterGenerator,
    layer_specs,
    base_adapter_weights,
    student_adapter_name: str,
    kd_T: float = 2.0,
    lambda_ce: float = 0.0,
    lambda_recon: float = 1.0,
    max_teacher_len: int = None,
    max_recon_len: int = None,
):
    """
    Compute the memory-training loss for one batch (no optimizer step here).

    Teacher: the LM run without grad on [<|MEM|>] + memory tokens + query
    tokens; its logits on the last Lq positions are the distillation targets.

    Student: the memory tokens are encoded without grad and mean-pooled. The
    generator maps each pooled vector to a flat LoRA delta. For each example i,
    the delta is unflattened and applied to adapter `student_adapter_name`.
    The model then runs on:
      * the bare query: KD against the teacher, plus optional CE on `labels`;
      * <|RECON|> + memory tokens: reconstruction CE.
    Examples are processed one at a time.

    batch: {
      "memory_text": List[str] (len=B),
      "query_ids":   LongTensor [B, Lq],
      "labels":      LongTensor [B, Lq] or None  (use -100 for ignore)
    }

    Returns:
        (total, parts). `total` is kd + lambda_ce * ce + lambda_recon * recon,
        where each term is averaged over the B examples. `parts` maps
        "kd" / "ce" / "recon" to Python floats; "ce" is 0.0 when there are no
        labels.
    """
    device = next(model.parameters()).device
    B = len(batch["memory_text"])
    q_ids = batch["query_ids"].to(device)  # [B, Lq]
    labels = batch.get("labels", None)
    if labels is not None:
        labels = labels.to(device)


    # NOTE(review): on the first call these may add tokens and call
    # resize_token_embeddings mid-training. An optimizer built earlier would
    # not hold any newly created embedding parameters. Consider registering
    # the tokens once at setup time.
    mem_tok_id = ensure_token(tokenizer, model.lm, MEM_TOKEN)
    _recon_tok_id = ensure_token(tokenizer, model.lm, RECON_TOKEN)

    teacher_ids, mem_ids = build_teacher_ids(
        tokenizer,
        batch["memory_text"],
        q_ids,
        device,
        mem_tok_id,
        max_len=max_teacher_len,
    )  # teacher on [<|MEM|> m] + q
    with torch.no_grad():
        # Run the base LM (no adapters) as the teacher
        # NOTE(review): nothing here disables the LoRA adapters injected into
        # model.lm. If `base_model` routes through those layers, the teacher
        # sees whatever delta was applied last. Also, for a transformers
        # PreTrainedModel, `base_model` is usually the backbone without the LM
        # head, whose output has no `.logits`. Confirm what `model.lm` is.
        t_out = model.lm.base_model(teacher_ids, return_dict=True)
        # Align teacher logits to the last |q| positions (behavior on the query tokens)
        # Only valid if every teacher row still ends with the full query, i.e.
        # max_teacher_len truncation must not cut the query.
        Lq = q_ids.size(1)
        logits_T = t_out.logits[:, -Lq:, :]  # [B, Lq, V]

    # --- STUDENT: build Δ(m) and apply once PER-EXAMPLE (or grouped by identical memories) ---
    # Encode memory => pooled hidden
    with torch.no_grad():
        mem_out = model.lm.base_model(
            mem_ids, output_hidden_states=True, return_dict=True
        )
        H = mem_out.hidden_states[-1]  # [B, Tm, d]
        # Plain mean over all Tm positions, including any padding tokens.
        h = H.mean(dim=1)  # [B, d_hidden]
    # Only the generator sees gradients from here on (encoding ran under no_grad).
    delta_flat_batch = generator(h)  # [B, D_flat], linear & bias-free
    # Apply per example (LoRA weights are not batched; apply sequentially)
    # We do student forward per example to avoid weight clashes; small B is recommended.
    loss_kd, loss_ce, loss_recon = 0.0, 0.0, 0.0
    ce_fn = nn.CrossEntropyLoss(ignore_index=-100)

    for i in range(B):
        delta_i = unflatten_delta_flat(delta_flat_batch[i], layer_specs)
        # apply Δ(m_i)
        # NOTE(review): apply_adapter_deltas is opaque here. Gradients reach
        # `generator` only if it installs the deltas differentiably. Because
        # backward runs after this whole loop, an in-place overwrite of the
        # adapter weights on a later iteration could invalidate earlier graph
        # nodes. Verify.
        apply_adapter_deltas(
            layer_specs, delta_i, student_adapter_name, base_adapter_weights
        )
        model.set_adapter(student_adapter_name)

        # STUDENT forward on q_i
        s_out = model(q_ids[i : i + 1])  # [1, Lq, V]
        logits_S = s_out.logits  # [1, Lq, V]

        # KD on query positions
        # NOTE(review): if tokenizer.pad_token_id is None, this comparison
        # likely does not yield a tensor and `.float()` fails.
        mask = (q_ids[i : i + 1] != tokenizer.pad_token_id).float()
        loss_kd += kd_loss(logits_S, logits_T[i : i + 1], mask, T=kd_T)

        # Optional CE on gold labels
        # NOTE(review): logits and labels are compared position-by-position
        # with no shift. This is next-token CE only if `labels` arrive already
        # shifted. HF-style labels (== input ids) would need logits[:, :-1]
        # vs labels[:, 1:]. Confirm the data convention.
        if labels is not None:
            loss_ce += ce_fn(
                logits_S.view(-1, logits_S.size(-1)), labels[i : i + 1].view(-1)
            )

        # Reconstruction: memory m_i from <|RECON|>
        # mem_ids[i] is a row of the padded memory batch; pad positions are
        # reconstruction targets unless make_recon_batch masks them.
        rinp, rlab = make_recon_batch(
            tokenizer, mem_ids[i : i + 1], device, max_len=max_recon_len
        )
        r_out = model(rinp)
        loss_recon += ce_fn(r_out.logits.view(-1, r_out.logits.size(-1)), rlab.view(-1))

    # Average per-example losses over the batch (B == 0 would divide by zero).
    loss_kd = loss_kd / B
    loss_ce = (loss_ce / B) if labels is not None else 0.0
    loss_recon = loss_recon / B

    total = loss_kd + lambda_ce * loss_ce + lambda_recon * loss_recon
    return total, {
        "kd": float(loss_kd),
        "ce": float(loss_ce),
        "recon": float(loss_recon),
    }


def train_memory_epoch(
    dataloader,
    model,
    tokenizer,
    generator: MemoryAdapterGenerator,
    layer_specs,
    base_adapter_weights,
    student_adapter_name: str = "student_mem",
    optimizer: torch.optim.Optimizer = None,
    kd_T: float = 2.0,
    lambda_ce: float = 0.0,
    lambda_recon: float = 1.0,
    max_teacher_len: int = None,
    max_recon_len: int = None,
    log_every: int = 50,
):
    """
    Run one pass over `dataloader`, training whatever parameters `optimizer` holds.

    For each batch, `train_memory_batch` computes the combined loss. The loss
    is back-propagated and `optimizer` is stepped. Every `log_every` steps,
    the current total loss and the kd / ce / recon parts averaged over the
    last `log_every` steps are printed.

    Raises:
        ValueError: if `optimizer` is None (it is required even though the
            signature defaults it to None), or if `log_every` < 1.
    """
    if optimizer is None:
        raise ValueError("train_memory_epoch: `optimizer` is required.")
    if log_every < 1:
        raise ValueError(f"train_memory_epoch: log_every must be >= 1, got {log_every}.")

    # NOTE(review): model.train() presumably also enables dropout in the LM
    # that train_memory_batch uses as the teacher; confirm this is intended.
    model.train()
    generator.train()

    running = {"kd": 0.0, "ce": 0.0, "recon": 0.0}
    for step, batch in enumerate(dataloader, 1):
        loss, parts = train_memory_batch(
            batch=batch,
            model=model,
            tokenizer=tokenizer,
            generator=generator,
            layer_specs=layer_specs,
            base_adapter_weights=base_adapter_weights,
            student_adapter_name=student_adapter_name,
            kd_T=kd_T,
            lambda_ce=lambda_ce,
            lambda_recon=lambda_recon,
            max_teacher_len=max_teacher_len,
            max_recon_len=max_recon_len,
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        for k in running:
            running[k] += parts[k]
        if step % log_every == 0:
            print(
                f"[step {step}] loss={loss.item():.4f} | "
                f"kd={running['kd']/log_every:.4f} "
                f"ce={running['ce']/log_every:.4f} "
                f"recon={running['recon']/log_every:.4f}"
            )
            running = dict.fromkeys(running, 0.0)



@torch.no_grad()
def reconstruct_memory_text(
    model,
    tokenizer,
    layer_specs,
    base_adapter_weights,
    delta_flat: torch.Tensor,  # [D_flat] for one memory
    adapter_name: str = "student_mem",
    max_new_tokens: int = 128,
    do_sample: bool = False,
    temperature: float = 1.0,
):
    """
    Decode the memory text encoded by a single flat adapter delta.

    The delta is unflattened, applied to adapter `adapter_name`, and that
    adapter is activated. The model then generates from a lone <|RECON|>
    token, using greedy decoding unless `do_sample` is set.

    Returns the decoded first sequence, with special tokens skipped. A leading
    <|RECON|> id is dropped first; presumably `generate` echoes the prompt.
    """
    # apply Δ(m)
    deltas = unflatten_delta_flat(delta_flat, layer_specs)
    apply_adapter_deltas(layer_specs, deltas, adapter_name, base_adapter_weights)
    model.set_adapter(adapter_name)

    # generate from <|RECON|>
    recon_id = ensure_token(tokenizer, model.lm, RECON_TOKEN)
    inp = torch.tensor([[recon_id]], device=next(model.parameters()).device)
    # Explicit None check: `pad_token_id or eos_token_id` would discard a
    # valid pad id of 0.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    out = model.generate(
        input_ids=inp,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        num_beams=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_id,
    )
    seq = out[0].tolist()
    if seq and seq[0] == recon_id:
        seq = seq[1:]
    return tokenizer.decode(seq, skip_special_tokens=True)


def init_memory_generator(
    model, peft_config, student_adapter_name: str, k_latent: int = 256
):
    """
    Attach a fresh LoRA adapter to `model.lm` and build a matching generator.

    1. Add `student_adapter_name` to model.lm (configured by `peft_config`)
       and make it the active adapter.
    2. Scan model.lm.base_model for modules that carry this adapter in both
       `lora_A` and `lora_B`. For each, record (module, shape_A, shape_B) and
       a detached clone of the adapter's initial (A, B) weights.
    3. Build a MemoryAdapterGenerator whose output size D_flat is the total
       element count of all recorded A and B weights. It is placed on the
       device of the model's first parameter, in bf16 if that parameter is
       bf16 and fp32 otherwise.

    Returns: (generator, layer_specs, base_adapter_weights)
    """
    lm = model.lm
    lm.add_adapter(peft_config, adapter_name=student_adapter_name)
    lm.set_adapter(student_adapter_name)

    layer_specs = {}
    base_adapter_weights = {}
    for mod_name, mod in lm.base_model.named_modules():
        if not (hasattr(mod, "lora_A") and hasattr(mod, "lora_B")):
            continue
        if (
            student_adapter_name not in mod.lora_A
            or student_adapter_name not in mod.lora_B
        ):
            continue
        weight_A = mod.lora_A[student_adapter_name].weight
        weight_B = mod.lora_B[student_adapter_name].weight
        layer_specs[mod_name] = (mod, tuple(weight_A.shape), tuple(weight_B.shape))
        base_adapter_weights[mod_name] = (
            weight_A.detach().clone(),
            weight_B.detach().clone(),
        )

    D_flat = sum(
        sA[0] * sA[1] + sB[0] * sB[1] for _, sA, sB in layer_specs.values()
    )

    ref_param = next(model.parameters())
    gen_dtype = torch.bfloat16 if ref_param.dtype == torch.bfloat16 else torch.float32
    generator = MemoryAdapterGenerator(
        d_hidden=lm.config.hidden_size, k_latent=k_latent, D_flat=D_flat
    ).to(device=ref_param.device, dtype=gen_dtype)
    return generator, layer_specs, base_adapter_weights
