import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from peft import LoraConfig
import types


from adapters import (
    peft_factory,
)
from transformers import AutoTokenizer
from models import TRANSFORMERS

from functools import partial
from iterator import collate_queries, collate_demonstrations
import math
from metrics import *

from typing import Dict, Tuple


from transformers import StoppingCriteria, StoppingCriteriaList
from functional_lora import FunctionalLoRAInjector, generate_layer_specs


from generators import *

# Short alias used in the annotations of this module.
Tensor = torch.Tensor
# Per-module LoRA factors: name -> (A: [r, in], B: [out, r]).
DeltaDict = Dict[str, Tuple[Tensor, Tensor]]


# ===== Special tokens for this setup =====
# The teacher's input is this token + instruction + query.
INST_TOKEN = "<|INST|>"
# After this token the student regenerates the instruction from the adapter.
RECON_TOKEN = "<|RECON|>"


# ---------- Adapter merge & linearity helpers ----------
def merge_deltas(
    d1: DeltaDict,
    d2: DeltaDict,
    alpha: float = 1.0,
    beta: float = 1.0,
    normalize: bool = False,
) -> DeltaDict:
    """Combine two LoRA delta dicts layer by layer as ``alpha*d1 + beta*d2``.

    When ``normalize`` is set, each layer's (A, B) pair from either input is
    first divided by its joint Frobenius norm sqrt(||A||^2 + ||B||^2),
    clamped to at least 1e-8, so both inputs contribute at unit scale.

    Keys are taken from ``d1``; ``d2`` must contain each of them (a missing
    key raises KeyError) with matching shapes. A new dict is returned and
    the input tensors are not modified.
    """

    def joint_norm(a, b):
        return (a.pow(2).sum() + b.pow(2).sum()).sqrt().clamp_min(1e-8)

    merged: DeltaDict = {}
    for name, (a1, b1) in d1.items():
        a2, b2 = d2[name]
        if normalize:
            s1, s2 = joint_norm(a1, b1), joint_norm(a2, b2)
        else:
            s1 = s2 = 1.0
        merged[name] = (
            alpha * a1 / s1 + beta * a2 / s2,
            alpha * b1 / s1 + beta * b2 / s2,
        )
    return merged


def deltas_l2(maybe_tuple):
    """Joint L2 norm of one (A, B) pair or of every pair in a delta dict.

    Args:
        maybe_tuple: either a single ``(A, B)`` tuple or a mapping
            ``name -> (A, B)``.

    Returns:
        0-dim tensor sqrt(sum ||A||^2 + ||B||^2). An empty mapping yields
        ``tensor(0.)`` (previously this crashed calling ``.sqrt()`` on the
        float accumulator).
    """
    if isinstance(maybe_tuple, tuple):
        A, B = maybe_tuple
        return (A.pow(2).sum() + B.pow(2).sum()).sqrt()
    s = 0.0
    for A, B in maybe_tuple.values():
        s = s + (A.pow(2).sum() + B.pow(2).sum())
    # as_tensor is a no-op for the tensor accumulator and wraps the float 0.0
    # left over when the mapping is empty.
    return torch.as_tensor(s).sqrt()


def linearity_loss_deltas(
    d_ab: "DeltaDict", d_a: "DeltaDict", d_b: "DeltaDict", reduction: str = "mean"
) -> torch.Tensor:
    """
    Penalize deviation from additive composition: Δ(A+B) ≈ Δ(A)+Δ(B).

    Computes a single relative error over all layers:
        sum_k ||ΔA_ab - (ΔA_a + ΔA_b)||^2 + ||ΔB_ab - (ΔB_a + ΔB_b)||^2
        ---------------------------------------------------------------
        sum_k max(||ΔA_ab||^2 + ||ΔB_ab||^2, 1e-8)

    Keys are taken from ``d_ab``; ``d_a``/``d_b`` must contain each of them.
    ``reduction`` is accepted for API compatibility only: the result is
    already a scalar ratio, so "mean" and "sum" return the same value.
    An empty ``d_ab`` returns ``tensor(0.)`` instead of failing in stack().
    """
    if not d_ab:
        return torch.zeros(())

    num_terms = []
    den_terms = []
    for name, (Aab, Bab) in d_ab.items():
        Aa, Ba = d_a[name]
        Ab, Bb = d_b[name]
        num_terms.append(
            (Aab - (Aa + Ab)).pow(2).sum() + (Bab - (Ba + Bb)).pow(2).sum()
        )
        den_terms.append((Aab.pow(2).sum() + Bab.pow(2).sum()).clamp_min(1e-8))
    num = torch.stack(num_terms).sum()
    den = torch.stack(den_terms).sum()
    return (num / den).clamp_min(0.0)


def token_f1(ref: str, hyp: str) -> float:
    """Whitespace-token F1 between a reference and a hypothesis string.

    Tokens are matched as multisets: a repeated token counts once per
    occurrence on each side. Two empty strings score 1.0; if only one side
    is empty, or no tokens overlap, the score is 0.0.
    """
    from collections import Counter

    ref_t = ref.split()
    hyp_t = hyp.split()
    if not ref_t and not hyp_t:
        return 1.0
    # Multiset intersection keeps min(count_in_ref, count_in_hyp) per token.
    tp = sum((Counter(ref_t) & Counter(hyp_t)).values())
    if tp == 0:
        return 0.0
    # tp > 0 implies both token lists are non-empty, so no zero division.
    prec = tp / len(hyp_t)
    rec = tp / len(ref_t)
    return 2 * prec * rec / (prec + rec)


# ----------------- helpers -----------------
def ensure_token(tokenizer, model_lm, tok: str) -> int:
    """Return the id of ``tok``, registering it with the tokenizer if needed.

    When ``tok`` is absent from ``tokenizer.get_vocab()`` it is added as an
    additional special token, and ``model_lm.resize_token_embeddings`` is
    called with the new tokenizer length. Any AttributeError from that resize
    call (e.g. the object has no such method) is ignored.
    """
    if tok in tokenizer.get_vocab():
        return tokenizer.convert_tokens_to_ids(tok)

    tokenizer.add_special_tokens({"additional_special_tokens": [tok]})
    try:
        model_lm.resize_token_embeddings(len(tokenizer))
    except AttributeError:
        # Best effort: objects without resize support keep their embeddings.
        pass
    return tokenizer.convert_tokens_to_ids(tok)


def kd_loss(student_logits, teacher_logits, mask: torch.Tensor, T: float = 1.0):
    """
    KL( teacher || student ) averaged over masked positions, scaled by T**2.

    logits: [B, L, V], mask: [B, L] (1/True = keep, 0/False = ignore).

    Both distributions are formed with log_softmax, so the teacher log-probs
    are exact rather than log(softmax + 1e-8). fp16/bf16 logits are upcast to
    float32 for the softmax/reduction (the returned scalar is then float32);
    other dtypes are used as-is.
    """

    def _upcast(x):
        return x.float() if x.dtype in (torch.float16, torch.bfloat16) else x

    s_logp = torch.log_softmax(_upcast(student_logits) / T, dim=-1)
    t_logp = torch.log_softmax(_upcast(teacher_logits) / T, dim=-1)
    kl = torch.sum(t_logp.exp() * (t_logp - s_logp), dim=-1)  # [B, L]
    mask = mask.to(kl.dtype)  # bool masks are accepted as well
    kl = (kl * mask).sum() / (mask.sum() + 1e-8)
    return (T**2) * kl


def make_recon_batch(
    tokenizer, target_text_ids: torch.Tensor, device, max_len: int = None
):
    """Build teacher-forcing (inputs, labels) for reconstructing a text after RECON.

    With ``tgt`` = target ids (optionally truncated to ``max_len``, then with
    EOS appended when the tokenizer has one):
        inputs = [RECON] + tgt[:-1]
        labels = tgt
    so position i sees RECON + tgt[:i] and is trained to predict tgt[i].

    Args:
        tokenizer: supplies the RECON_TOKEN id and ``eos_token_id``.
        target_text_ids: [L] or [B, L] token ids; moved to ``device`` here
            (previously a CPU tensor broke the concat with device tensors).
        device: device of the returned tensors.
        max_len: if given, truncate the target to this many tokens before
            the EOS is appended.

    Returns:
        (inputs, labels), both [B, L'] with the same dtype as the targets.
        Pad ids inside ``target_text_ids`` are not replaced by -100.

    Raises:
        ValueError: if the (truncated) target is empty and there is no EOS,
            which would otherwise give inputs/labels of different lengths.
    """
    rid = tokenizer.convert_tokens_to_ids(RECON_TOKEN)
    eos = tokenizer.eos_token_id

    tgt = target_text_ids.to(device)
    if tgt.dim() == 1:
        tgt = tgt.unsqueeze(0)

    # (A) truncate, then append EOS to the target
    if max_len is not None:
        tgt = tgt[:, :max_len]
    if eos is not None:
        eos_col = torch.full((tgt.size(0), 1), eos, dtype=tgt.dtype, device=device)
        tgt = torch.cat([tgt, eos_col], dim=1)
    if tgt.size(1) == 0:
        raise ValueError("make_recon_batch: empty target and no EOS token to append")

    # (B) teacher forcing: [RECON] + tgt[:-1] → predict tgt
    bos = torch.full((tgt.size(0), 1), rid, dtype=tgt.dtype, device=device)
    inp = torch.cat([bos, tgt[:, :-1]], dim=1)
    labels = tgt.contiguous()
    return inp, labels


@torch.no_grad()
def adapter_to_text(
    model_lm,  # HF CausalLM (e.g., model.lm)
    tokenizer,
    deltas,  # dict[name] -> (A:[r,in], B:[out,r]) from your generator
    injector=None,  # FunctionalLoRAInjector or None
    *,  # keyword-only below
    recon_token=RECON_TOKEN,
    max_length=128,
    do_sample=False,
    temperature=1.0,
    num_beams=1,
    # Fallback-only args (ignored when injector is used):
    layer_specs=None,
    base_adapter_weights=None,
    adapter_name="student",
    apply_adapter_deltas=None,  # callable(layer_specs, deltas, adapter_name, base_adapter_weights)
):
    """Decode text from an adapter by generating after a lone ``recon_token``.

    If `injector` is provided:
        - Temporarily enable functional LoRA with `deltas`, generate, then
          disable (also on error). The deltas stay set on the injector.
    Else (fallback):
        - Apply deltas in-place to a PEFT adapter via apply_adapter_deltas,
          set that adapter active, then generate.

    ``recon_token`` is added to the tokenizer (and embeddings resized) if
    missing. Beam search is only used when ``do_sample`` is False.

    Returns:
        The decoded continuation with special tokens skipped and a leading
        ``recon_token`` id removed.

    Raises:
        AssertionError: on the fallback path when ``apply_adapter_deltas``,
            ``layer_specs`` or ``base_adapter_weights`` is missing.
    """
    device = next(model_lm.parameters()).device

    recon_id = ensure_token(tokenizer, model_lm, recon_token)
    input_ids = torch.tensor([[recon_id]], device=device)

    # Fall back to EOS only when no pad id is set; `pad_token_id or eos`
    # would also discard a legitimate pad id of 0.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_length,
        do_sample=do_sample,
        temperature=temperature,
        num_beams=1 if do_sample else num_beams,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_id,
    )

    if injector is not None:
        # Functional path: no PEFT adapter touching, grads irrelevant (no_grad)
        injector.set_deltas(deltas)
        injector.enable()
        try:
            out = model_lm.generate(**gen_kwargs)
        finally:
            injector.disable()
    else:
        # Fallback path: in-place writer + PEFT adapter
        assert (
            apply_adapter_deltas is not None
        ), "apply_adapter_deltas must be provided for fallback path"
        assert (
            layer_specs is not None and base_adapter_weights is not None
        ), "layer_specs/base_adapter_weights required"
        apply_adapter_deltas(layer_specs, deltas, adapter_name, base_adapter_weights)
        try:
            model_lm.set_adapter(adapter_name)
        except AttributeError:
            # Tolerated only for objects that carry a peft_config.
            if not hasattr(model_lm, "peft_config"):
                raise
        out = model_lm.generate(**gen_kwargs)

    seq = out[0].tolist()
    if seq and seq[0] == recon_id:
        seq = seq[1:]
    return tokenizer.decode(seq, skip_special_tokens=True)


def run_experiment(model, tokenizer, splits, device, args, data_handler, demos):
    """Distill an instruction into a generated LoRA adapter, then reconstruct it.

    Each training step:
      1. teacher pass (injector disabled, no grad) on
         ``[<|INST|>] + instruction + query``;
      2. ``generator`` maps the instruction tokens to LoRA deltas, which the
         functional injector applies to the student;
      3. student passes on the query only, trained with
         ``0.1 * KD + 1.0 * CE + 0.2 * RECON`` (only generator params update).
    After each epoch the adapter is decoded back to text with
    ``adapter_to_text``, gradient norms of the last step are printed, and
    exact-match accuracy of the arg-max next token on the test split is shown.

    Args:
        model: wrapper with an ``.lm`` HF causal LM; ``model(ids).logits``
            is used for every forward pass.
        tokenizer: HF tokenizer; ``<|INST|>`` / ``<|RECON|>`` are added if absent.
        splits: ``(test, train, val)``; ``val`` is unused here.
        device: device for all tensors built here.
        args: reads ``peft``, ``lr``, ``l2``, ``batch_size``, ``epochs`` and,
            optionally, ``max_train_examples`` (default 50).
        data_handler: forwarded to ``collate_queries``.
        demos: demonstration strings; joined with newlines they form the
            instruction that is distilled.
    """
    test, train, _ = splits

    # Register the special tokens once (resizes embeddings if they are new).
    # RECON's id is looked up again inside make_recon_batch / adapter_to_text.
    inst_id = ensure_token(tokenizer, model.lm, INST_TOKEN)
    ensure_token(tokenizer, model.lm, RECON_TOKEN)

    injector = FunctionalLoRAInjector(scale=1.0)
    # NOTE(review): the returned PEFT config is unused on the functional-LoRA
    # path; the call is kept in case peft_factory validates `args.peft`.
    peft_factory(args.peft)

    # Prepare generator target specs
    targets = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj"]
    # NOTE(review): 7 ranks are given for 6 targets — confirm how
    # generate_layer_specs pairs `ranks` with `targets` (module missing?).
    layer_specs = generate_layer_specs(
        model,
        targets,
        default_rank=16,
        ranks=[16, 16, 4, 4, 8, 8, 8],
    )

    # Instruction text (support) — distilled into the adapter and reconstructed.
    support_text = "\n".join(demos)
    support_ids = tokenizer(
        support_text, return_tensors="pt", padding=True, truncation=True
    )["input_ids"].to(device)
    # What the generator sees to produce the adapter: the instruction only.
    instruction_tokens_for_generator = support_ids

    # AdapterGenerator and MetaGenerator were alternatives tried here.
    generator = LinearAdapterGeneratorSmall(
        layer_specs=layer_specs,
        input_dim=model.lm.config.hidden_size,
        a_std=1e-3,  # or 1e-4 if you run fp32; use 1e-3 for bf16/fp16 to avoid underflow
    ).to(device, dtype=torch.bfloat16)

    # Only the generator is optimized; the LM stays frozen.
    model.lm.requires_grad_(False)
    generator.set_trainable_params()
    optimizer = torch.optim.AdamW(
        generator.get_opt_params(), lr=args.lr, weight_decay=args.l2
    )
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

    collate = partial(
        collate_queries, dh=data_handler, tokenizer=tokenizer, device=device
    )
    # Debug-sized training subset (was hard-coded to 50); clamped so a smaller
    # train split does not raise IndexError.
    n_train = min(getattr(args, "max_train_examples", 50), len(train))
    dl = DataLoader(
        [train[i] for i in range(n_train)],
        batch_size=args.batch_size,
        collate_fn=collate,
    )
    test_dl = DataLoader(test, batch_size=args.batch_size, collate_fn=collate)

    # Loop-invariant inputs: both depend only on the fixed instruction.
    inst_col = torch.tensor([[inst_id]], device=device)
    rinp, rlab = make_recon_batch(tokenizer, support_ids, device, max_len=None)

    injector.attach(layer_specs)

    # ===== Training =====
    for epoch in range(args.epochs):
        generator.train()
        total_loss = 0.0
        total_parts = {"kd": 0.0, "ce": 0.0, "recon": 0.0}
        steps = 0

        for batch in tqdm(dl, desc=f"Epoch {epoch+1}"):
            steps += 1
            q_ids = batch.input_ids
            bsz, Lq = q_ids.shape

            # Teacher input: [<|INST|>] + instruction + query, one row per
            # query. The INST column must be repeated per *query* row; repeating
            # it per instruction row broke the concat for batch_size > 1.
            teacher_ids = torch.cat(
                [inst_col.repeat(bsz, 1), support_ids.repeat(bsz, 1), q_ids],
                dim=1,
            )

            # Teacher: adapter off, no grad; the query occupies the last Lq
            # positions, so those logits align with the student's.
            injector.disable()
            with torch.no_grad():
                logits_T = model(teacher_ids).logits[:, -Lq:, :]

            # Build the adapter from the INSTRUCTION ONLY and switch it on.
            adapter_deltas = generator(instruction_tokens_for_generator, model)
            injector.set_deltas(adapter_deltas)
            injector.enable()

            # Student forward on the query only (adapter active); KD on non-pad.
            logits_S = model(q_ids).logits[:, -Lq:, :]
            mask = (q_ids != tokenizer.pad_token_id).float()
            L_kd = kd_loss(logits_S, logits_T, mask, T=1.0)

            # CE vs the gold answer id, read from the last position of the
            # query with its final token dropped. [B, V] works for any batch
            # size (the old .view(1, -1) only worked for B == 1).
            # NOTE(review): evaluation predicts from the *full* query; confirm
            # the [:, :-1] truncation here is intended.
            logits_ce = model(q_ids[:, :-1]).logits[:, -1, :]
            L_ce = loss_fn(logits_ce, batch.target_ids)

            # Reconstruction loss: <|RECON|> → instruction (teacher-forced).
            r_logits = model(rinp).logits
            L_recon = loss_fn(
                r_logits.reshape(-1, r_logits.size(-1)), rlab.reshape(-1)
            )

            loss = 0.1 * L_kd + 1.0 * L_ce + 0.2 * L_recon

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_parts["kd"] += float(L_kd)
            total_parts["ce"] += float(L_ce)
            total_parts["recon"] += float(L_recon)

            injector.clear_deltas()

        denom = max(steps, 1)  # avoid ZeroDivisionError on an empty loader
        print(
            f"Epoch {epoch+1} — Avg Loss: {total_loss/denom:.4f} | "
            f"KD:{total_parts['kd']/denom:.4f} CE:{total_parts['ce']/denom:.4f} RECON:{total_parts['recon']/denom:.4f}"
        )

        # ===== After epoch: reconstruct the instruction from the adapter =====
        generator.eval()
        with torch.no_grad():
            final_deltas = generator(instruction_tokens_for_generator, model)
        recon_text = adapter_to_text(
            model_lm=model.lm,
            tokenizer=tokenizer,
            deltas=final_deltas,
            injector=injector,
            max_length=256,
            do_sample=False,
        )
        print("[ORIG INSTRUCTION]")
        print(support_text)
        print("[RECONSTRUCTED]")
        print(recon_text)

        # Grad report for the last training step (grads persist until the
        # next optimizer.zero_grad()).
        updated, skipped = [], []
        for module in (generator, model):
            for pname, param in module.named_parameters():
                if not param.requires_grad:
                    continue
                if param.grad is None:
                    skipped.append(pname)
                else:
                    updated.append((pname, param.grad.norm().item()))

        print(f"Will update {len(updated)} params, skip {len(skipped)} params")
        print("Updated parameters and their grad‑norms:")
        for pname, gnorm in updated:
            print(f"  {pname:50s} grad‑norm={gnorm:.3e}")

        # ===== Evaluation: exact match of the arg-max next token =====
        injector.set_deltas(final_deltas)
        injector.enable()
        model.eval()
        correct_count_exact = 0
        total_count = 0
        with torch.inference_mode():
            for batch in tqdm(test_dl, desc="Evaluating on Test Set"):
                logits = model(batch.input_ids).logits[:, -1, :]  # [B, V]
                token_ids = logits.argmax(dim=-1)  # [B]
                # NOTE(review): with right padding, position -1 is a pad token
                # for all but the longest row when batch_size > 1.
                correct_count_exact += int(
                    (batch.target_ids.reshape(-1) == token_ids).sum()
                )
                total_count += token_ids.numel()
        print(f"Exact match: {correct_count_exact}/{total_count}")
        # Don't leave the adapter switched on between epochs or after return.
        injector.disable()


def adapter_ft(model, tokenizer, splits, device, args, data_handler):
    """Baseline: fine-tune one PEFT adapter on demonstrations, report test EM.

    Adds an adapter named ``"adapter"`` (config from ``peft_factory(args.peft)``)
    to ``model.lm`` and activates it. Each epoch trains with cross-entropy on
    the final token of every demonstration (earlier positions are masked with
    -100), prints the average training loss, then prints exact-match accuracy
    of the arg-max next token on the test split (batch size 1) with the model
    in eval mode.

    Args:
        model: wrapper with an ``.lm`` model supporting ``add_adapter`` /
            ``set_adapter``; ``model(ids).logits`` is used for forwards.
        tokenizer: forwarded to the collate functions.
        splits: ``(test, train, dev)``; ``dev`` is unused.
        device: forwarded to the collate functions.
        args: reads ``peft``, ``lr``, ``l2``, ``batch_size``, ``epochs``.
        data_handler: forwarded to the collate functions.
    """
    import copy

    test, train, _ = splits

    adapter_name = "adapter"
    peft_config = copy.deepcopy(peft_factory(args.peft))
    model.lm.add_adapter(peft_config, adapter_name=adapter_name)
    model.lm.set_adapter(adapter_name)

    # Only parameters that can receive gradients are handed to AdamW; frozen
    # ones would never be stepped anyway.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=args.l2)
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    # Loaders do not change between epochs, so build them once.
    train_dl = DataLoader(
        train,
        batch_size=args.batch_size,
        collate_fn=partial(
            collate_demonstrations,
            dh=data_handler,
            tokenizer=tokenizer,
            device=device,
        ),
    )
    test_dl = DataLoader(
        test,
        batch_size=1,
        collate_fn=partial(
            collate_queries,
            dh=data_handler,
            tokenizer=tokenizer,
            device=device,
        ),
    )

    for epoch in range(args.epochs):
        model.train()
        loss_agg = 0.0
        n_batches = 0

        for batch in tqdm(train_dl, desc=f"Epoch {epoch+1}"):
            optimizer.zero_grad()
            input_ids = batch.input_ids
            inputs = input_ids[:, :-1]

            # Supervise only the last input position: it must predict the
            # final token of the demonstration.
            labels = torch.full_like(inputs, -100)
            labels[:, -1] = input_ids[:, -1]

            logits = model(inputs).logits
            loss = criterion(
                logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
            )
            loss_agg += loss.item()
            n_batches += 1
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1} — Avg Loss: {loss_agg / max(n_batches, 1):.4f}")

        # Evaluate with dropout disabled (this previously ran in train mode).
        model.eval()
        correct_count_exact = 0
        total_count = 0
        with torch.inference_mode():
            for batch in tqdm(test_dl):
                logits = model(batch.input_ids).logits[:, -1, :]  # [B, V]
                token_ids = logits.argmax(dim=-1)  # [B]
                correct_count_exact += int(
                    (batch.target_ids.reshape(-1) == token_ids).sum()
                )
                total_count += token_ids.numel()

        print(f"Exact match: {correct_count_exact}/{total_count}")


def evaluate_model(
    model, tokenizer, name, dl, data_handler, demo_key_values=None, demo_ids=None
):
    """Score ``dl`` by verbalizer logits and by exact next-token match.

    For each batch the last-position logits are (a) restricted to the
    verbalizer token ids to get a class prediction, and (b) arg-maxed over
    the full vocabulary and compared with ``batch.target_ids`` (exact match,
    printed).

    Assumes batch_size == 1: logits are ``.squeeze()``-d and only
    ``batch.target[0]`` is recorded — TODO confirm before using larger batches.

    Args:
        model: callable as ``model(ids, past_key_values=...)`` returning
            ``.logits``; put in eval mode.
        tokenizer: used to map each verbalizer to its last sub-token id.
        name: label for the returned ``Result``.
        dl: iterable of batches with ``input_ids``, ``target``, ``target_ids``.
        data_handler: provides ``class_labels`` (the verbalizers).
        demo_key_values: optional ``past_key_values`` passed with every batch;
            each batch receives its own deep copy.
        demo_ids: optional prefix ids concatenated before every query along
            dim 1 (batch dims must match). Do not combine with a cache of the
            same demos, or the prefix would be seen twice.

    Returns:
        ``Result`` after ``evaluate(y_true, y_pred)``, with
        ``logits = (y_true, verbalizer_logits)``.

    Raises:
        ValueError: if ``dl`` yields no batches.
    """
    import copy

    model.eval()

    verbalizers = data_handler.class_labels
    # Each verbalizer is scored by the id of its last sub-token.
    verbalizer_ids = [tokenizer.encode(verb)[-1] for verb in verbalizers]

    prob_list = []
    pred_list = []
    target_list = []
    correct_count_exact = 0
    total_count = 0
    with torch.inference_mode():
        for batch in tqdm(dl):
            input_ids = batch.input_ids
            if demo_ids is not None:
                input_ids = torch.concat([demo_ids, input_ids], dim=1)
            target_list.append(batch.target[0].cpu())

            # A cache object may be extended in place by the forward pass
            # (e.g. an HF DynamicCache), which would leak one query's tokens
            # into the next batch; give every batch a fresh copy.
            past = (
                None if demo_key_values is None else copy.deepcopy(demo_key_values)
            )
            outputs = model(input_ids, past_key_values=past)
            logits = outputs.logits[:, -1, :].squeeze()

            verb_logits = logits[verbalizer_ids]
            prob_list.append(verb_logits.cpu())
            pred_list.append(torch.argmax(verb_logits).cpu())

            token_id = torch.argmax(logits)
            correct_count_exact += int((batch.target_ids == token_id).sum())
            total_count += 1

    if total_count == 0:
        raise ValueError(f"evaluate_model({name!r}): dataloader yielded no batches")
    print(f"Exact match: {correct_count_exact}/{total_count}")

    result = Result(name=name)

    y_pred = torch.tensor(pred_list)
    y_true = torch.tensor(target_list)
    y_logits = torch.stack(prob_list)
    result.evaluate(y_true, y_pred)
    result.logits = (y_true, y_logits)

    print(result)

    return result


def icl_eval(
    model,
    device,
    tokenizer,
    data_handler,
    demos,
    dl,
):
    """Compare n-shot (demos prepended as token ids) against 0-shot on ``dl``.

    Returns:
        ``(result_n_shot, result_zero_shot)`` as produced by ``evaluate_model``.
    """
    demo_text = "\n".join(demos)
    # The demos are fed by concatenating `demo_ids` in front of every query.
    # No KV cache of the demos is built: it was previously computed here but
    # never passed on (demo_key_values=None), and passing both would feed the
    # demos twice — so that extra forward pass was dropped.
    demo_ids = tokenizer(demo_text, return_tensors="pt")["input_ids"].to(device)

    result_n_shot = evaluate_model(
        tokenizer=tokenizer,
        model=model,
        name="n-shot",
        dl=dl,
        data_handler=data_handler,
        demo_key_values=None,
        demo_ids=demo_ids,
    )

    result_zero_shot = evaluate_model(
        model,
        tokenizer,
        name="0-shot",
        dl=dl,
        data_handler=data_handler,
    )

    return result_n_shot, result_zero_shot


# === Main entry point ===
if __name__ == "__main__":
    import argparse
    from models import initialize_model
    from dataloader import DATASETS

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="llama-3.2-1b-it")
    parser.add_argument("--num-demos", type=int, default=4)
    parser.add_argument("--data", type=str, default="mmlu-misc")
    parser.add_argument("--peft", type=str, default="lora")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--l2", type=float, default=0.01)
    parser.add_argument("--hidden-dim", type=int, default=512)
    parser.add_argument("--n-flows", type=int, default=4)
    args = parser.parse_args()

    dh = DATASETS[args.data]
    splits = dh.get_dataset_splits()
    test, train, val = splits

    tokenizer = AutoTokenizer.from_pretrained(TRANSFORMERS[args.model])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # safer for causal LM

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    meta = type("Meta", (), {"task_type": "lm"})()
    model = initialize_model(args, meta, device)
    model.to(device)

    model.lm.config.pad_token_id = tokenizer.pad_token_id
    # generate() is invoked on model.lm, so its generation_config must carry
    # the pad id too; the wrapper's config (if it has one) is updated as well.
    for holder in (model, model.lm):
        gen_cfg = getattr(holder, "generation_config", None)
        if gen_cfg is not None:
            gen_cfg.pad_token_id = tokenizer.pad_token_id

    # Only used by the (commented-out) icl_eval baseline below.
    test_dl = DataLoader(
        test,
        batch_size=args.batch_size,
        collate_fn=partial(collate_queries, dh=dh, tokenizer=tokenizer, device=device),
    )

    demos = [dh.make_demonstration(val[i]) for i in range(args.num_demos)]

    # icl_eval(model, device, tokenizer, dh, demos, test_dl)

    # adapter_ft(model, tokenizer, splits, device, args, dh)

    run_experiment(
        model,
        tokenizer,
        splits=splits,
        device=device,
        args=args,
        data_handler=dh,
        demos=demos,
    )
