import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from peft import LoraConfig
import types
from flow import *


from adapters import (
    peft_factory,
    save_adapter_weights,
    extract_adapter_specs,
    apply_adapter_deltas,
)
from transformers import AutoTokenizer
from models import TRANSFORMERS

from functools import partial
from iterator import collate_queries, collate_demonstrations
import math
from metrics import *


from transformers import StoppingCriteria, StoppingCriteriaList


# ===== Special tokens for this setup =====
# Both tokens are registered on the tokenizer via ensure_token() in run_experiment.
INST_TOKEN = "<|INST|>"  # first token of the teacher input: INST + instruction + query
RECON_TOKEN = "<|RECON|>"  # start token from which the adapted student regenerates the instruction


# ----------------- helpers -----------------
def ensure_token(tokenizer, model_lm, tok: str) -> int:
    """Return the id of ``tok``, registering it as a special token if needed.

    When ``tok`` is absent from the tokenizer vocabulary it is added through
    ``add_special_tokens`` and ``model_lm`` is asked to resize its embedding
    matrix to ``len(tokenizer)``. An ``AttributeError`` raised by that resize
    (e.g. ``model_lm`` has no ``resize_token_embeddings``) is ignored on
    purpose, so the token is still registered on the tokenizer.
    """
    if tok in tokenizer.get_vocab():
        # Already known: nothing to add, nothing to resize.
        return tokenizer.convert_tokens_to_ids(tok)

    tokenizer.add_special_tokens({"additional_special_tokens": [tok]})
    try:
        model_lm.resize_token_embeddings(len(tokenizer))
    except AttributeError:
        # Best effort: objects without resizable embeddings are left untouched.
        pass
    return tokenizer.convert_tokens_to_ids(tok)


def kd_loss(student_logits, teacher_logits, mask: torch.Tensor, T: float = 1.0):
    """Masked knowledge-distillation loss KL(teacher || student).

    Args:
        student_logits: student scores, ``[B, L, V]``.
        teacher_logits: teacher scores, same shape as ``student_logits``.
        mask: ``[B, L]`` weights (1 = keep, 0 = ignore) applied per position.
        T: softmax temperature applied to both sets of logits.

    Returns:
        Scalar tensor: the per-position KL divergence averaged over the kept
        positions and scaled by ``T**2``. An all-zero mask yields 0.
    """
    log_p_student = torch.log_softmax(student_logits / T, dim=-1)
    # Take the teacher log-probabilities from log_softmax as well instead of
    # log(softmax + eps): it is numerically stable and removes the eps bias,
    # so identical teacher/student logits give a loss of exactly zero.
    log_p_teacher = torch.log_softmax(teacher_logits / T, dim=-1)

    per_position = torch.sum(
        log_p_teacher.exp() * (log_p_teacher - log_p_student), dim=-1
    )  # [B, L]

    # The small constant only protects against an all-zero mask.
    mean_kl = (per_position * mask).sum() / (mask.sum() + 1e-8)
    return (T**2) * mean_kl


def make_recon_batch(
    tokenizer, target_text_ids: torch.Tensor, device, max_len: int = None
):
    """Build teacher-forcing inputs/labels for the reconstruction objective.

    The target is optionally truncated to ``max_len`` tokens, then the
    tokenizer's EOS id (if it has one) is appended. The model input is the
    ``RECON_TOKEN`` id followed by the target without its last token, so that
    position ``i`` of the input predicts position ``i`` of the labels.

    Args:
        tokenizer: provides ``convert_tokens_to_ids`` and ``eos_token_id``.
        target_text_ids: token ids, ``[L]`` or ``[B, L]``.
        device: device for the returned tensors; ``None`` keeps the target's
            own device.
        max_len: optional cap on the number of target tokens (before EOS).

    Returns:
        ``(inp, labels)``, both ``[B, L']`` on the same device.
    """
    recon_id = tokenizer.convert_tokens_to_ids(RECON_TOKEN)
    eos_id = tokenizer.eos_token_id

    tgt = target_text_ids
    if tgt.dim() == 1:
        tgt = tgt.unsqueeze(0)
    if device is not None:
        # Move the target once so every tensor concatenated below shares a
        # device; previously the extra columns were created on `device` while
        # the target stayed wherever the caller had left it.
        tgt = tgt.to(device)
    if max_len is not None:
        tgt = tgt[:, :max_len]

    n_rows = tgt.size(0)
    if eos_id is not None:
        # (A) append EOS; new_full inherits dtype and device from the target.
        tgt = torch.cat([tgt, tgt.new_full((n_rows, 1), eos_id)], dim=1)

    # (B) teacher forcing: [RECON] + tgt[:-1] -> predict tgt
    start = tgt.new_full((n_rows, 1), recon_id)
    inp = torch.cat([start, tgt[:, :-1]], dim=1)
    labels = tgt.contiguous()
    return inp, labels


@torch.no_grad()
def adapter_to_text(
    model,  # your PEFT‐wrapped model.lm (HF CausalLM)
    tokenizer,
    layer_specs,
    base_adapter_weights,
    deltas,  # dict[name] -> (ΔA, ΔB)
    adapter_name="student",
    max_length=128,
    do_sample=False,
    temperature=1.0,
):
    """Apply adapter deltas, then generate from <|RECON|> to read the instruction back.

    Args:
        model: object exposing ``set_adapter``, ``parameters`` and ``generate``.
        tokenizer: tokenizer matching ``model``.
        layer_specs, base_adapter_weights, deltas: forwarded unchanged to
            ``apply_adapter_deltas``.
        adapter_name: adapter that receives the deltas and is activated.
        max_length: maximum number of newly generated tokens.
        do_sample: sample instead of greedy decoding.
        temperature: forwarded to ``generate``.

    Returns:
        The decoded continuation with the leading RECON token and all special
        tokens removed.
    """
    apply_adapter_deltas(layer_specs, deltas, adapter_name, base_adapter_weights)
    model.set_adapter(adapter_name)

    recon_id = tokenizer.convert_tokens_to_ids(RECON_TOKEN)
    prompt = torch.tensor([[recon_id]], device=next(model.parameters()).device)

    # `pad_token_id or eos_token_id` would discard a legitimate pad id of 0,
    # so fall back to EOS only when no pad token is defined at all.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    out = model.generate(
        input_ids=prompt,
        # Single unpadded prompt token: pass the mask explicitly so generate
        # does not have to infer it (ambiguous when pad and EOS share an id).
        attention_mask=torch.ones_like(prompt),
        max_new_tokens=max_length,
        do_sample=do_sample,
        temperature=temperature,
        num_beams=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_id,
    )
    seq = out[0].tolist()
    if seq and seq[0] == recon_id:
        seq = seq[1:]  # drop the prompt token so only the reconstruction remains
    return tokenizer.decode(seq, skip_special_tokens=True)


class MetaGenerator(nn.Module):
    """Hyper-network that reads a support set through the LM and emits LoRA deltas.

    A dedicated LoRA adapter (``gen_adapter``) is added to ``model.lm`` at
    construction time. On every forward pass the wrapped model is run with
    that adapter selected, its final hidden states are summarised by a GRU,
    and an MLP maps the summary to one flat vector which is cut into a
    ``(ΔA, ΔB)`` pair per entry of ``layer_specs``.

    The last projector layer is zero-initialised, so a freshly built
    generator emits all-zero deltas. ``model`` is stored as an attribute and
    is therefore registered as a submodule when it is an ``nn.Module``.
    """

    def __init__(
        self,
        model: nn.Module,
        layer_specs: dict,
        gen_config: LoraConfig,
        input_dim: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        gen_adapter: str = "generator",
    ):
        """Attach the generator adapter to ``model.lm`` and build encoder + projector.

        ``layer_specs`` maps a layer name to ``(module, shape_A, shape_B)``;
        only the two shapes are used to size the projector output.
        """
        super().__init__()
        self.model = model
        self.gen_adapter = gen_adapter
        model.lm.add_adapter(gen_config, adapter_name=gen_adapter)

        self.layer_specs = layer_specs
        self.param_shapes = {
            layer: (shape_A, shape_B)
            for layer, (_, shape_A, shape_B) in layer_specs.items()
        }
        # Length of the flat vector holding every ΔA and ΔB back to back.
        flat_size = sum(
            shape_A[0] * shape_A[1] + shape_B[0] * shape_B[1]
            for shape_A, shape_B in self.param_shapes.values()
        )

        self.encoder = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        hidden_layer = nn.Linear(hidden_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, flat_size)
        # Zero output layer: the generated deltas start out as exact zeros.
        for tensor in (output_layer.weight, output_layer.bias):
            nn.init.zeros_(tensor)
        self.projector = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, support_inputs: torch.Tensor, *args, **kwargs):
        """Map support-set token ids to a dict ``name -> (ΔA, ΔB)``.

        support_inputs: token ids of the support set, ``[B, T]``. The final
        per-layer ``view`` only succeeds when the projected vector has a
        single row. Extra positional/keyword arguments are accepted and
        ignored.
        """
        # Run the LM with the generator adapter selected; train mode keeps
        # dropout in the LoRA adapter active.
        self.model.set_adapter(self.gen_adapter)
        self.model.train()
        lm_out = self.model(
            input_ids=support_inputs, output_hidden_states=True, return_dict=True
        )
        features = lm_out.hidden_states[-1]

        # Summarise the sequence with the GRU and keep the top layer's final state.
        _, final_states = self.encoder(features)
        flat = self.projector(final_states[-1])
        flat = flat.view(flat.size(0), -1)

        # Walk through the flat vector, carving out ΔA then ΔB for each layer.
        deltas = {}
        offset = 0
        for layer, (shape_A, shape_B) in self.param_shapes.items():
            a_end = offset + shape_A[0] * shape_A[1]
            b_end = a_end + shape_B[0] * shape_B[1]
            deltas[layer] = (
                flat[:, offset:a_end].view(*shape_A),
                flat[:, a_end:b_end].view(*shape_B),
            )
            offset = b_end

        return deltas

    def set_trainable_params(self):
        """Enable gradients for the generator adapter weights, encoder and projector."""
        for module, *_ in self.layer_specs.values():
            for lora in (module.lora_A, module.lora_B):
                lora[self.gen_adapter].weight.requires_grad = True

        for part in (self.encoder, self.projector):
            part.requires_grad_(True)

    def get_opt_params(self):
        """Return the list of parameters to optimise.

        Order: generator-adapter A/B weights per layer, then the encoder
        parameters, then the projector parameters.
        """
        adapter_weights = [
            lora[self.gen_adapter].weight
            for module, *_ in self.layer_specs.values()
            for lora in (module.lora_A, module.lora_B)
        ]
        return [
            *adapter_weights,
            *self.encoder.parameters(),
            *self.projector.parameters(),
        ]


class AdapterGenerator(nn.Module):
    """GRU + MLP hyper-network turning LM hidden states into LoRA weight pairs.

    The language model is only used as a feature extractor (run under
    ``torch.no_grad``); the trainable parts are the GRU encoder and the
    projector. The projector's output layer is zero-initialised, so a fresh
    generator returns all-zero tensors.
    """

    def __init__(self, layer_specs, input_dim=3072, hidden_dim=512, num_layers=2):
        """Size the projector from ``layer_specs``.

        ``layer_specs`` maps a layer name to ``(module, shape_A, shape_B)``;
        the shapes are stored verbatim in ``param_shapes`` and the module
        handle is not used here.
        """
        super().__init__()
        self.layer_specs = layer_specs

        self.param_shapes = {
            layer: (shape_A, shape_B)
            for layer, (_, shape_A, shape_B) in layer_specs.items()
        }
        # Length of the flat vector holding every A and B matrix back to back.
        flat_size = sum(
            shape_A[0] * shape_A[1] + shape_B[0] * shape_B[1]
            for shape_A, shape_B in self.param_shapes.values()
        )

        self.encoder = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        hidden_layer = nn.Linear(hidden_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, flat_size)
        for tensor in (output_layer.weight, output_layer.bias):
            nn.init.zeros_(tensor)
        self.projector = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, inputs, model):
        """Return ``{layer name: (A, B)}`` generated from ``inputs``.

        ``model`` is switched to eval mode and ``model.lm.base_model`` is
        queried for hidden states without tracking gradients. The last entry
        of ``hidden_states`` feeds the GRU, whose top-layer final state is
        projected and reshaped to the stored shapes. Each ``view`` needs the
        projected vector to have exactly one row.
        """
        model.eval()
        with torch.no_grad():
            lm_out = model.lm.base_model(
                inputs,
                output_hidden_states=True,
                return_dict=True,
            )
            features = lm_out.hidden_states[-1]

        # Only the final hidden state of the top GRU layer is projected.
        _, final_states = self.encoder(features)
        flat = self.projector(final_states[-1])

        weights = {}
        offset = 0
        for layer, (shape_A, shape_B) in self.param_shapes.items():
            a_end = offset + shape_A[0] * shape_A[1]
            b_end = a_end + shape_B[0] * shape_B[1]
            weights[layer] = (
                flat[:, offset:a_end].view(*shape_A),
                flat[:, a_end:b_end].view(*shape_B),
            )
            offset = b_end

        return weights

    def set_trainable_params(self):
        """Enable gradients on the encoder and the projector."""
        for part in (self.encoder, self.projector):
            part.requires_grad_(True)

    def get_opt_params(self):
        """Return an iterator over every parameter of this module."""
        return self.parameters()


def _report_grad_norms(generator, model):
    """Print the gradient norm of every trainable parameter of generator and model."""
    updated, skipped = [], []
    for owner in (generator, model):
        for name, param in owner.named_parameters():
            if not param.requires_grad:
                continue
            if param.grad is None:
                skipped.append(name)
            else:
                updated.append((name, param.grad.norm().item()))

    print(f"Will update {len(updated)} params, skip {len(skipped)} params")
    print("Updated parameters and their grad‑norms:")
    for name, gnorm in updated:
        print(f"  {name:50s} grad‑norm={gnorm:.3e}")


def _quick_exact_match(model, tokenizer, test_split, data_handler, device):
    """Print how often the arg-max next token decodes to the gold verbaliser."""
    # One query per step: the scoring below reads the last position of a
    # single sequence and only the first target of each batch, so larger
    # batches would be scored incorrectly.
    loader = DataLoader(
        test_split,
        batch_size=1,
        collate_fn=partial(
            collate_queries, dh=data_handler, tokenizer=tokenizer, device=device
        ),
    )
    model.eval()
    verbalizers = data_handler.class_labels
    correct_count_exact = 0
    total_count = 0
    with torch.inference_mode():
        for batch in tqdm(loader, desc="Evaluating on Test Set"):
            logits = model(batch.input_ids).logits[:, -1, :].squeeze()
            pred = tokenizer.decode(torch.argmax(logits))
            label = batch.target[0]
            correct_count_exact += int(verbalizers[label] == pred)
            total_count += 1
    print(f"Exact match: {correct_count_exact}/{total_count}")


def run_experiment(model, tokenizer, splits, device, args, data_handler, demos):
    """Train a generator that compresses the demonstrations into LoRA deltas.

    Per training step:
      * a teacher pass on ``<|INST|> + demonstrations + query`` (no grad),
      * the generator maps the demonstration tokens to adapter deltas which
        are applied to the "student" adapter,
      * the student is run on the query alone (KD + next-token CE) and on the
        ``<|RECON|>`` sequence (reconstruction of the demonstrations).

    After each epoch the reconstructed text, gradient norms and a quick
    exact-match score on the test split are printed.

    Args:
        model: project wrapper exposing ``lm``, ``set_adapter`` and a forward
            returning ``.logits``.
        tokenizer: tokenizer with a pad token id set.
        splits: ``(test, train, val)``.
        device: device for all tensors.
        args: needs ``peft``, ``hidden_dim``, ``lr``, ``l2``, ``batch_size``
            and ``epochs``.
        data_handler: passed to ``collate_queries``; provides ``class_labels``.
        demos: list of demonstration strings joined into the support text.
    """
    test, train, _val = splits

    # Register the special tokens once (resizes the embeddings if needed).
    inst_id = ensure_token(tokenizer, model.lm, INST_TOKEN)
    ensure_token(tokenizer, model.lm, RECON_TOKEN)

    student_adapter = "student"
    peft_config = peft_factory(args.peft)
    model.lm.add_adapter(peft_config, adapter_name=student_adapter)

    # Layers whose LoRA matrices the generator has to produce.
    layer_specs = extract_adapter_specs(
        model.lm, peft_config.target_modules, student_adapter
    )

    # Snapshot of the freshly initialised student weights, handed to
    # apply_adapter_deltas together with the generated deltas.
    base_adapter_weights = {
        name: (
            module.lora_A[student_adapter].weight.detach().clone(),
            module.lora_B[student_adapter].weight.detach().clone(),
        )
        for name, (module, _shape_a, _shape_b) in layer_specs.items()
    }

    # Instruction text (support): distilled into the adapter and reconstructed.
    support_text = "\n".join(demos)
    support_ids = tokenizer(
        support_text, return_tensors="pt", padding=True, truncation=True
    )["input_ids"].to(device)
    # The generator only ever sees the instruction tokens.
    instruction_tokens_for_generator = support_ids

    # MetaGenerator(model, layer_specs, peft_config, ...) is the alternative
    # generator defined in this file.
    generator = AdapterGenerator(
        layer_specs,
        input_dim=model.lm.config.hidden_size,
        hidden_dim=args.hidden_dim,
        num_layers=2,
    ).to(device=device, dtype=torch.bfloat16)

    model.lm.requires_grad_(False)
    generator.set_trainable_params()
    optimizer = torch.optim.AdamW(
        generator.get_opt_params(), lr=args.lr, weight_decay=args.l2
    )

    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

    # Small debugging subset of the train split, capped so that splits with
    # fewer than 50 examples do not raise IndexError.
    n_train = min(50, len(train))
    dl = DataLoader(
        [train[i] for i in range(n_train)],
        batch_size=args.batch_size,
        collate_fn=partial(
            collate_queries, dh=data_handler, tokenizer=tokenizer, device=device
        ),
    )

    # Loop-invariant tensors: the INST prefix and the reconstruction batch
    # depend only on the support text.
    inst_tok = torch.tensor([[inst_id]], device=device)
    rinp, rlab = make_recon_batch(tokenizer, support_ids, device, max_len=None)

    # ===== Training =====
    for epoch in range(args.epochs):
        generator.train()
        total_loss = 0.0
        total_parts = {"kd": 0.0, "ce": 0.0, "recon": 0.0}
        steps = 0

        for batch in tqdm(dl, desc=f"Epoch {epoch+1}"):
            steps += 1
            q_ids = batch.input_ids
            n_queries = q_ids.size(0)
            Lq = q_ids.size(1)

            # Teacher input: [<|INST|>] + instruction + query, one row per
            # query. The INST column is expanded to the query batch size (it
            # used to be repeated only support_ids.size(0) times, which broke
            # the concatenation for batch sizes above one).
            teacher_ids = torch.cat(
                [
                    inst_tok.expand(n_queries, -1),
                    support_ids.expand(n_queries, -1),
                    q_ids,
                ],
                dim=1,
            )
            # NOTE(review): nothing here deactivates the student adapter
            # before the teacher pass, so from the second step on the teacher
            # presumably runs with the previous step's deltas — confirm.
            with torch.no_grad():
                # Align on the query positions (last |q| tokens).
                logits_T = model(teacher_ids).logits[:, -Lq:, :]

            # Build the adapter from the instruction only and install it.
            adapter_deltas = generator(instruction_tokens_for_generator, model)
            model.set_adapter(student_adapter)
            apply_adapter_deltas(
                layer_specs, adapter_deltas, student_adapter, base_adapter_weights
            )

            # Student forward on the query only (adapter active).
            logits_S = model(q_ids).logits[:, -Lq:, :]

            # KD loss on the non-padding query tokens.
            keep = q_ids != tokenizer.pad_token_id
            L_kd = kd_loss(logits_S, logits_T, keep.float(), T=1.0)

            # Next-token CE on the query. Padding targets are set to the
            # criterion's ignore index using the same mask as the KD term;
            # before, they were scored like real tokens.
            labels = q_ids[:, 1:].masked_fill(~keep[:, 1:], -100)
            logits_ce = model(q_ids[:, :-1]).logits
            L_ce = loss_fn(
                logits_ce.reshape(-1, logits_ce.size(-1)), labels.reshape(-1)
            )

            # Reconstruction loss: <|RECON|> -> instruction.
            r_logits = model(rinp).logits
            L_recon = loss_fn(
                r_logits.reshape(-1, r_logits.size(-1)), rlab.reshape(-1)
            )

            # Loss schedule kept as-is: KD is tracked but carries zero weight.
            if epoch == 0:
                loss = 0.0 * L_kd + 1.0 * L_ce + 1.0 * L_recon
            else:
                loss = 1.0 * L_ce + 0.05 * L_recon

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_parts["kd"] += float(L_kd)
            total_parts["ce"] += float(L_ce)
            total_parts["recon"] += float(L_recon)

        denom = max(steps, 1)  # an empty loader must not divide by zero
        print(
            f"Epoch {epoch+1} — Avg Loss: {total_loss/denom:.4f} | "
            f"KD:{total_parts['kd']/denom:.4f} CE:{total_parts['ce']/denom:.4f} RECON:{total_parts['recon']/denom:.4f}"
        )

        # ===== After epoch: reconstruct the instruction from the adapter =====
        generator.eval()
        with torch.no_grad():
            final_deltas = generator(instruction_tokens_for_generator, model)
        recon_text = adapter_to_text(
            model=model.lm,
            tokenizer=tokenizer,
            layer_specs=layer_specs,
            base_adapter_weights=base_adapter_weights,
            deltas=final_deltas,
            adapter_name=student_adapter,
            max_length=256,
        )
        print("[ORIG INSTRUCTION]")
        print(support_text)
        print("[RECONSTRUCTED]")
        print(recon_text)

        _report_grad_norms(generator, model)

        # ===== Evaluation (quick exact-match on the test split) =====
        _quick_exact_match(model, tokenizer, test, data_handler, device)


def adapter_ft(model, tokenizer, splits, device, args, data_handler):
    """Fine-tune a single PEFT adapter on the train split and score it each epoch.

    Training uses ``collate_demonstrations`` batches and supervises only the
    final token of every sequence (all other label positions are -100).
    After each epoch the test split is scored one query at a time: the
    arg-max next token is decoded and compared with the gold verbaliser.

    Args:
        model: project wrapper exposing ``lm`` and a forward returning ``.logits``.
        tokenizer: tokenizer used for collation and decoding.
        splits: ``(test, train, dev)``.
        device: device handed to the collate functions.
        args: needs ``peft``, ``lr``, ``l2``, ``batch_size`` and ``epochs``.
        data_handler: passed to the collate functions; provides ``class_labels``.
    """
    import copy

    test, train, _dev = splits

    adapter_name = "adapter"
    # Work on a copy so the config returned by the factory is never mutated.
    peft_config = copy.deepcopy(peft_factory(args.peft))
    model.lm.add_adapter(peft_config, adapter_name=adapter_name)
    model.lm.set_adapter(adapter_name)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.l2)
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    # The loaders do not change between epochs, so build them once.
    train_dl = DataLoader(
        train,
        batch_size=args.batch_size,
        collate_fn=partial(
            collate_demonstrations,
            dh=data_handler,
            tokenizer=tokenizer,
            device=device,
        ),
    )
    test_dl = DataLoader(
        test,
        batch_size=1,
        collate_fn=partial(
            collate_queries,
            dh=data_handler,
            tokenizer=tokenizer,
            device=device,
        ),
    )
    verbalizers = data_handler.class_labels

    for epoch in range(args.epochs):
        model.train()
        loss_agg = 0.0
        n_batches = 0

        for batch in tqdm(train_dl, desc=f"Epoch {epoch+1}"):
            optimizer.zero_grad()
            input_ids = batch.input_ids
            inputs = input_ids[:, :-1]

            # Supervise only the last token: the prediction made at the final
            # input position must be the sequence's last token.
            labels = torch.full_like(inputs, -100)
            labels[:, -1] = input_ids[:, -1]

            logits = model(inputs).logits
            loss = criterion(logits.view(-1, logits.shape[-1]), labels.view(-1))
            loss_agg += loss.item()
            n_batches += 1
            loss.backward()
            optimizer.step()

        # The running loss used to be accumulated but never surfaced.
        print(f"Epoch {epoch+1} — Avg Loss: {loss_agg / max(n_batches, 1):.4f}")

        # Leave train mode before scoring; otherwise dropout stays active
        # during evaluation and makes the exact-match count noisy.
        model.eval()
        correct_count_exact = 0
        total_count = 0
        with torch.inference_mode():
            for batch in tqdm(test_dl):
                label = batch.target[0]
                logits = model(batch.input_ids).logits[:, -1, :].squeeze()
                pred = tokenizer.decode(torch.argmax(logits))
                correct_count_exact += int(verbalizers[label] == pred)
                total_count += 1

        print(f"Exact match: {correct_count_exact}/{total_count}")


def evaluate_model(
    model, tokenizer, name, dl, data_handler, demo_key_values=None, demo_ids=None
):
    """Score ``model`` on ``dl`` and return a ``Result`` named ``name``.

    For every batch the next-token logits at the last position are read.
    Two predictions are derived from them: the class whose verbaliser token
    (last token of each encoded class label) scores highest, and the decoded
    overall arg-max token, which is compared with the gold verbaliser string
    for the printed exact-match count.

    Only the first target of each batch is used, and the squeezed logits are
    indexed as a 1-D vector, so the loader is expected to yield one example
    per batch.

    Args:
        demo_key_values: forwarded to the model as ``past_key_values``.
        demo_ids: optional token ids prepended to every query along dim 1.
    """
    model.eval()

    class_scores, predictions, targets = [], [], []
    n_exact = 0
    n_seen = 0

    with torch.inference_mode():
        verbalizers = data_handler.class_labels
        verbalizer_ids = [tokenizer.encode(word)[-1] for word in verbalizers]

        for batch in tqdm(dl):
            tokens = batch.input_ids
            if demo_ids is not None:
                tokens = torch.concat([demo_ids, tokens], dim=1)

            gold = batch.target[0]
            targets.append(gold.cpu())

            model_out = model(tokens, past_key_values=demo_key_values)
            next_token_logits = model_out.logits[:, -1, :].squeeze()

            # Restricted prediction: best-scoring verbaliser token.
            scores = next_token_logits[verbalizer_ids]
            class_scores.append(scores.cpu())
            predictions.append(torch.argmax(scores).cpu())

            # Unrestricted prediction: decoded arg-max over the whole vocabulary.
            decoded = tokenizer.decode(torch.argmax(next_token_logits))
            n_exact += verbalizers[gold] == decoded
            n_seen += 1

        print(f"Exact match: {n_exact}/{n_seen}")

    y_true = torch.tensor(targets)
    y_pred = torch.tensor(predictions)
    y_logits = torch.stack(class_scores)

    result = Result(name=name)
    result.evaluate(y_true, y_pred)
    result.logits = (y_true, y_logits)

    print(result)

    return result


def icl_eval(
    model,
    device,
    tokenizer,
    data_handler,
    demos,
    dl,
):
    """Evaluate n-shot in-context learning against the zero-shot baseline.

    The demonstrations are joined with newlines, tokenised once and passed
    to ``evaluate_model`` as ``demo_ids`` for the n-shot run; the zero-shot
    run uses the queries alone.

    Returns:
        ``(result_n_shot, result_zero_shot)``.
    """
    demo_text = "\n".join(demos)
    demo_ids = tokenizer(demo_text, return_tensors="pt")["input_ids"].to(device)

    # The demonstrations reach the model as explicit tokens. A forward pass
    # that pre-computed their key/value cache used to run here, but its
    # result was never passed on (demo_key_values stayed None), so that
    # wasted pass has been removed.
    result_n_shot = evaluate_model(
        tokenizer=tokenizer,
        model=model,
        name="n-shot",
        dl=dl,
        data_handler=data_handler,
        demo_key_values=None,
        demo_ids=demo_ids,
    )

    result_zero_shot = evaluate_model(
        model,
        tokenizer,
        name="0-shot",
        dl=dl,
        data_handler=data_handler,
    )

    return result_n_shot, result_zero_shot


# === Main entry point ===
if __name__ == "__main__":
    import argparse
    from models import initialize_model
    from dataloader import DATASETS

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="llama-3.2-1b-it")
    parser.add_argument("--num-demos", type=int, default=4)
    parser.add_argument("--data", type=str, default="mmlu-misc")
    parser.add_argument("--peft", type=str, default="lora")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--l2", type=float, default=0.01)
    parser.add_argument("--hidden-dim", type=int, default=512)
    parser.add_argument("--n-flows", type=int, default=4)
    args = parser.parse_args()

    dh = DATASETS[args.data]
    splits = dh.get_dataset_splits()
    test, train, val = splits

    tokenizer = AutoTokenizer.from_pretrained(TRANSFORMERS[args.model])
    if tokenizer.pad_token is None:
        # No dedicated pad token: reuse EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # safer for causal LM

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    meta = type("Meta", (), {"task_type": "lm"})()
    model = initialize_model(args, meta, device)
    model.to(device)

    model.lm.config.pad_token_id = tokenizer.pad_token_id
    # Mirror the pad id into any generation config. Generation in this file
    # goes through model.lm, so check the inner LM as well as the wrapper
    # (only the wrapper used to be checked).
    for holder in (model, model.lm):
        gen_cfg = getattr(holder, "generation_config", None)
        if gen_cfg is not None:
            gen_cfg.pad_token_id = tokenizer.pad_token_id

    # Only needed by the icl_eval baseline below.
    test_dl = DataLoader(
        test,
        batch_size=args.batch_size,
        collate_fn=partial(collate_queries, dh=dh, tokenizer=tokenizer, device=device),
    )

    # The first --num-demos validation examples serve as demonstrations.
    demos = [dh.make_demonstration(val[i]) for i in range(args.num_demos)]

    # Alternative baselines; enable as needed.
    # icl_eval(model, device, tokenizer, dh, demos, test_dl)

    # adapter_ft(model, tokenizer, splits, device, args, dh)

    run_experiment(
        model,
        tokenizer,
        splits=splits,
        device=device,
        args=args,
        data_handler=dh,
        demos=demos,
    )
