# --- adapter_generator_exp1.py ---

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from peft import LoraConfig
import types


from adapters import (
    peft_factory,
)
from transformers import AutoTokenizer
from models import TRANSFORMERS

from functools import partial
from iterator import collate_queries, collate_demonstrations
import math
from metrics import *

from typing import Dict, Tuple


from transformers import StoppingCriteria, StoppingCriteriaList
from functional_lora import FunctionalLoRAInjector, generate_layer_specs


from generators import *

Tensor = torch.Tensor

# Per-module functional LoRA deltas: module name -> (A: [r, in], B: [out, r]).
DeltaDict = Dict[str, Tuple[torch.Tensor, torch.Tensor]]


# ===== Special tokens for this setup =====
INST_TOKEN = "<|INST|>"  # teacher prefix: INST + instruction/demos + query
RECON_TOKEN = "<|RECON|>"  # student prompt: decode the instruction from the adapter


# ----------------- helpers -----------------
def ensure_token(tokenizer, model_lm, tok: str) -> int:
    """Return the id of ``tok``, registering it first if the tokenizer lacks it.

    A missing ``tok`` is added as an additional special token, and
    ``model_lm``'s embeddings are resized to ``len(tokenizer)``. An
    AttributeError from the resize (e.g. an object without
    ``resize_token_embeddings``) is ignored, so resizing is best effort.

    NOTE(review): ``ensure_token`` is defined again further down in this
    module; that later definition is the one bound after import.
    """
    if tok in tokenizer.get_vocab():
        return tokenizer.convert_tokens_to_ids(tok)

    tokenizer.add_special_tokens({"additional_special_tokens": [tok]})
    try:
        model_lm.resize_token_embeddings(len(tokenizer))
    except AttributeError:
        # Best effort: nothing to resize on this object.
        pass
    return tokenizer.convert_tokens_to_ids(tok)


def kd_loss(student_logits, teacher_logits, mask: torch.Tensor, T: float = 1.0):
    """
    Temperature-scaled KL( teacher || student ) averaged over masked positions.

    Args:
        student_logits: [B, L, V] logits of the model being trained.
        teacher_logits: [B, L, V] reference logits.
        mask: [B, L] position weights (1=keep, 0=ignore).
        T: softmax temperature. The result is multiplied by T**2 so gradient
            magnitudes stay comparable across temperatures.

    Returns:
        Scalar tensor: T**2 * sum(mask * KL) / (sum(mask) + 1e-8).

    NOTE(review): this module defines ``kd_loss`` again further down; that
    later definition is the one bound after import.
    """
    s_log = torch.log_softmax(student_logits / T, dim=-1)
    # Take the teacher's log-probs from log_softmax as well. log(softmax + eps)
    # is biased by eps and loses precision in low-precision dtypes.
    t_log = torch.log_softmax(teacher_logits / T, dim=-1)
    kl = torch.sum(t_log.exp() * (t_log - s_log), dim=-1)  # [B, L]
    kl = (kl * mask).sum() / (mask.sum() + 1e-8)
    return (T**2) * kl


def make_recon_batch(
    tokenizer, target_text_ids: torch.Tensor, device, max_len: int = None
):
    """
    Build a teacher-forcing batch for adapter -> text reconstruction.

    The target is optionally truncated to ``max_len`` tokens. EOS is appended
    when the tokenizer defines one. The input is the target shifted right by
    one position, with RECON_TOKEN in front:

        inp    = [RECON, t0, t1, ..., t_{n-1}]
        labels = [t0,    t1, ...,      t_{n-1}, EOS]

    Args:
        tokenizer: HF tokenizer that already contains RECON_TOKEN.
        target_text_ids: [T] or [B, T] token ids of the text to reconstruct.
        device: device of the returned tensors. The target is moved there too.
        max_len: optional cap on the target length, applied before EOS.

    Returns:
        (inp, labels), both [B, T'] with the dtype of ``target_text_ids``.
        No label is set to -100. NOTE(review): any padding inside
        ``target_text_ids`` would be trained on; confirm inputs are unpadded.

    Raises:
        ValueError: if RECON_TOKEN does not resolve to a real token id.

    NOTE(review): ``make_recon_batch`` is defined again further down in this
    module; that later definition is the one bound after import.
    """
    rid = tokenizer.convert_tokens_to_ids(RECON_TOKEN)
    unk = getattr(tokenizer, "unk_token_id", None)
    if rid is None or (unk is not None and rid == unk):
        raise ValueError(
            f"{RECON_TOKEN!r} is not in the tokenizer vocabulary; add it first"
        )
    eos = tokenizer.eos_token_id

    # Move the target first so the cat() calls below never mix devices.
    tgt = target_text_ids.to(device)
    if tgt.dim() == 1:
        tgt = tgt.unsqueeze(0)

    # (A) optionally truncate, then append EOS to the target
    if max_len is not None:
        tgt = tgt[:, :max_len]
    if eos is not None:
        eos_col = torch.full((tgt.size(0), 1), eos, dtype=tgt.dtype, device=device)
        tgt = torch.cat([tgt, eos_col], dim=1)

    # (B) teacher forcing: [RECON] + tgt[:-1] → predict tgt
    bos = torch.full((tgt.size(0), 1), rid, dtype=tgt.dtype, device=device)
    inp = torch.cat([bos, tgt[:, :-1]], dim=1)
    labels = tgt.contiguous()
    return inp, labels


@torch.no_grad()
def adapter_to_text(
    model_lm,  # HF CausalLM (e.g., model.lm)
    tokenizer,
    deltas,  # dict[name] -> (A:[r,in], B:[out,r]) from your generator
    injector=None,  # FunctionalLoRAInjector or None
    *,  # keyword-only below
    recon_token=RECON_TOKEN,
    max_length=128,
    do_sample=False,
    temperature=1.0,
    num_beams=1,
    # Fallback-only args (ignored when injector is used):
    layer_specs=None,
    base_adapter_weights=None,
    adapter_name="student",
    apply_adapter_deltas=None,  # callable(layer_specs, deltas, adapter_name, base_adapter_weights)
):
    """
    Decode the text an adapter encodes by generating from ``recon_token``.

    If ``injector`` is provided, ``deltas`` are installed functionally
    (``injector.set_deltas`` + ``enable()``) for the duration of ``generate``.
    The injector is disabled afterwards, also on error. The deltas stay set
    on the injector.

    Otherwise (fallback path), ``apply_adapter_deltas(layer_specs, deltas,
    adapter_name, base_adapter_weights)`` writes the deltas in place. The
    PEFT adapter ``adapter_name`` is then activated and ``generate`` runs.
    The adapter remains active afterwards.

    ``recon_token`` is added to the tokenizer (and the embeddings resized) if
    it is missing. Runs under ``torch.no_grad``.

    Args:
        max_length: maximum number of NEW tokens (passed as ``max_new_tokens``).
        num_beams: forced to 1 when ``do_sample`` is True.

    Returns:
        Decoded text, with the leading recon token and special tokens removed.

    Raises:
        ValueError: in the fallback path, if ``apply_adapter_deltas``,
            ``layer_specs`` or ``base_adapter_weights`` is missing.
    """
    device = next(model_lm.parameters()).device

    # Registers the recon token (and resizes embeddings) when it is missing.
    recon_id = ensure_token(tokenizer, model_lm, recon_token)
    input_ids = torch.tensor([[recon_id]], device=device)

    # `pad_token_id or eos` would discard a legitimate pad id of 0.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    gen_kwargs = dict(
        input_ids=input_ids,
        # A single unpadded prompt: every position is real. Passing the mask
        # explicitly avoids generate() having to infer it from pad_token_id.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=max_length,
        do_sample=do_sample,
        temperature=temperature,
        num_beams=1 if do_sample else num_beams,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_id,
    )

    if injector is not None:
        # Functional path: no PEFT adapter is touched.
        injector.set_deltas(deltas)
        injector.enable()
        try:
            out = model_lm.generate(**gen_kwargs)
        finally:
            injector.disable()
    else:
        # Fallback path: in-place writer + PEFT adapter.
        if apply_adapter_deltas is None:
            raise ValueError("apply_adapter_deltas must be provided for fallback path")
        if layer_specs is None or base_adapter_weights is None:
            raise ValueError("layer_specs/base_adapter_weights required")
        apply_adapter_deltas(layer_specs, deltas, adapter_name, base_adapter_weights)
        try:
            model_lm.set_adapter(adapter_name)
        except AttributeError:
            # Some wrappers expose adapters only through `peft_config`. In that
            # case, proceed as before and assume the written adapter is active.
            if not hasattr(model_lm, "peft_config"):
                raise
        out = model_lm.generate(**gen_kwargs)

    seq = out[0].tolist()
    if seq and seq[0] == recon_id:
        seq = seq[1:]
    return tokenizer.decode(seq, skip_special_tokens=True)


# ---------- Adapter merge & linearity helpers ----------
def merge_deltas(
    d1: DeltaDict,
    d2: DeltaDict,
    alpha: float = 1.0,
    beta: float = 1.0,
    normalize: bool = False,
) -> DeltaDict:
    """
    Blend two delta dicts layer by layer: Δ = alpha*Δ1 + beta*Δ2.

    A and B are combined separately. With ``normalize=True``, each layer's
    (A, B) pair is first divided by sqrt(||A||² + ||B||²), floored at 1e-8.
    Iterates over the keys of ``d1``. Each key must also exist in ``d2``
    (KeyError otherwise), and shapes are assumed to match.

    NOTE(review): ``merge_deltas`` is defined again further down in this
    module; that later definition is the one bound after import.
    """

    def _pair_norm(a, b):
        # Joint Frobenius norm of one layer's (A, B), or 1.0 when disabled.
        if not normalize:
            return 1.0
        return (a.pow(2).sum() + b.pow(2).sum()).sqrt().clamp_min(1e-8)

    merged: DeltaDict = {}
    for name, (a1, b1) in d1.items():
        a2, b2 = d2[name]
        s1 = _pair_norm(a1, b1)
        s2 = _pair_norm(a2, b2)
        merged[name] = (
            alpha * a1 / s1 + beta * a2 / s2,
            alpha * b1 / s1 + beta * b2 / s2,
        )
    return merged


def linearity_loss_deltas(
    d_ab: DeltaDict, d_a: DeltaDict, d_b: DeltaDict
) -> torch.Tensor:
    """
    Penalize deviation from additive composition: Δ(AB) ≈ Δ(A)+Δ(B).

    This is the relative squared error over all layers k and both LoRA
    factors:

        sum_k ||A_ab - (A_a + A_b)||² + ||B_ab - (B_a + B_b)||²
        -------------------------------------------------------
        sum_k max(||A_ab||² + ||B_ab||², 1e-8)

    Keys are taken from ``d_ab`` and must exist in ``d_a`` and ``d_b``.

    The sums are accumulated in float32. The deltas may be bfloat16, where a
    running sum over many layers drops low-order contributions. The result is
    returned as a 0-d tensor on the device and dtype of ``d_ab``'s tensors.

    Raises:
        ValueError: if ``d_ab`` is empty.

    NOTE(review): ``linearity_loss_deltas`` is defined again further down in
    this module; that later definition is the one bound after import.
    """
    if not d_ab:
        raise ValueError("linearity_loss_deltas: d_ab must contain at least one layer")
    ref = next(iter(d_ab.values()))[0]
    num = torch.zeros((), device=ref.device, dtype=torch.float32)
    den = torch.zeros((), device=ref.device, dtype=torch.float32)
    for k, (Aab, Bab) in d_ab.items():
        Aa, Ba = d_a[k]
        Ab, Bb = d_b[k]
        Aab, Bab = Aab.float(), Bab.float()
        num = (
            num
            + (Aab - (Aa.float() + Ab.float())).pow(2).sum()
            + (Bab - (Ba.float() + Bb.float())).pow(2).sum()
        )
        den = den + (Aab.pow(2).sum() + Bab.pow(2).sum()).clamp_min(1e-8)
    # num >= 0 and den > 0, so the ratio needs no clamping.
    return (num / den).to(ref.dtype)


def token_f1(ref: str, hyp: str) -> float:
    """Bag-of-words F1 between ``ref`` and ``hyp`` after whitespace splitting.

    Overlap is the multiset intersection of the two token lists. Two empty
    strings score 1.0. If there is no overlap (including exactly one empty
    side), the score is 0.0.

    NOTE(review): ``token_f1`` is defined again further down in this module;
    that later definition is the one bound after import.
    """
    from collections import Counter

    ref_tokens, hyp_tokens = ref.split(), hyp.split()
    if not (ref_tokens or hyp_tokens):
        return 1.0

    overlap = sum((Counter(ref_tokens) & Counter(hyp_tokens)).values())
    precision = overlap / max(1, len(hyp_tokens))
    recall = overlap / max(1, len(ref_tokens))
    denom = precision + recall
    if denom == 0:
        return 0.0
    return 2 * precision * recall / denom


# ----------------- helpers (kept minimal) -----------------
def ensure_token(tokenizer, model_lm, tok: str) -> int:
    """Return the id of ``tok``, registering it first if the tokenizer lacks it.

    A missing ``tok`` is added as an additional special token, and
    ``model_lm``'s embeddings are resized to ``len(tokenizer)``. An
    AttributeError from the resize (e.g. an object without
    ``resize_token_embeddings``) is ignored, so resizing is best effort.
    """
    if tok in tokenizer.get_vocab():
        return tokenizer.convert_tokens_to_ids(tok)

    tokenizer.add_special_tokens({"additional_special_tokens": [tok]})
    try:
        model_lm.resize_token_embeddings(len(tokenizer))
    except AttributeError:
        # Best effort: nothing to resize on this object.
        pass
    return tokenizer.convert_tokens_to_ids(tok)


def kd_loss(student_logits, teacher_logits, mask: torch.Tensor, T: float = 1.0):
    """
    Temperature-scaled KL( teacher || student ) averaged over masked positions.

    Args:
        student_logits: [B, L, V] logits of the model being trained.
        teacher_logits: [B, L, V] reference logits.
        mask: [B, L] position weights (1=keep, 0=ignore).
        T: softmax temperature. The result is multiplied by T**2 so gradient
            magnitudes stay comparable across temperatures.

    Returns:
        Scalar tensor: T**2 * sum(mask * KL) / (sum(mask) + 1e-8).
    """
    s_log = torch.log_softmax(student_logits / T, dim=-1)
    # Take the teacher's log-probs from log_softmax as well. log(softmax + eps)
    # is biased by eps and loses precision in low-precision dtypes.
    t_log = torch.log_softmax(teacher_logits / T, dim=-1)
    kl = torch.sum(t_log.exp() * (t_log - s_log), dim=-1)  # [B, L]
    kl = (kl * mask).sum() / (mask.sum() + 1e-8)
    return (T**2) * kl


def make_recon_batch(
    tokenizer, target_text_ids: torch.Tensor, device, max_len: int = None
):
    """
    Build a teacher-forcing batch for adapter -> text reconstruction.

    The target is optionally truncated to ``max_len`` tokens. EOS is appended
    when the tokenizer defines one. The input is the target shifted right by
    one position, with RECON_TOKEN in front:

        inp    = [RECON, t0, t1, ..., t_{n-1}]
        labels = [t0,    t1, ...,      t_{n-1}, EOS]

    Args:
        tokenizer: HF tokenizer that already contains RECON_TOKEN.
        target_text_ids: [T] or [B, T] token ids of the text to reconstruct.
        device: device of the returned tensors. The target is moved there too.
        max_len: optional cap on the target length, applied before EOS.

    Returns:
        (inp, labels), both [B, T'] with the dtype of ``target_text_ids``.
        No label is set to -100. NOTE(review): any padding inside
        ``target_text_ids`` would be trained on; confirm inputs are unpadded.

    Raises:
        ValueError: if RECON_TOKEN does not resolve to a real token id.
    """
    rid = tokenizer.convert_tokens_to_ids(RECON_TOKEN)
    unk = getattr(tokenizer, "unk_token_id", None)
    if rid is None or (unk is not None and rid == unk):
        raise ValueError(
            f"{RECON_TOKEN!r} is not in the tokenizer vocabulary; add it first"
        )
    eos = tokenizer.eos_token_id

    # Move the target first so the cat() calls below never mix devices.
    tgt = target_text_ids.to(device)
    if tgt.dim() == 1:
        tgt = tgt.unsqueeze(0)

    # (A) optionally truncate, then append EOS to the target
    if max_len is not None:
        tgt = tgt[:, :max_len]
    if eos is not None:
        eos_col = torch.full((tgt.size(0), 1), eos, dtype=tgt.dtype, device=device)
        tgt = torch.cat([tgt, eos_col], dim=1)

    # (B) teacher forcing: [RECON] + tgt[:-1] → predict tgt
    bos = torch.full((tgt.size(0), 1), rid, dtype=tgt.dtype, device=device)
    inp = torch.cat([bos, tgt[:, :-1]], dim=1)
    labels = tgt.contiguous()
    return inp, labels


import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from functools import partial
from tqdm import tqdm



# --- simple delta merge ---
def merge_deltas(
    d1: DeltaDict,
    d2: DeltaDict,
    alpha: float = 1.0,
    beta: float = 1.0,
    normalize: bool = False,
) -> DeltaDict:
    """
    Blend two delta dicts layer by layer: Δ = alpha*Δ1 + beta*Δ2.

    A and B are combined separately. With ``normalize=True``, each layer's
    (A, B) pair is first divided by sqrt(||A||² + ||B||²), floored at 1e-8.
    Iterates over the keys of ``d1``. Each key must also exist in ``d2``
    (KeyError otherwise), and shapes are assumed to match.
    """

    def _pair_norm(a, b):
        # Joint Frobenius norm of one layer's (A, B), or 1.0 when disabled.
        if not normalize:
            return 1.0
        return (a.pow(2).sum() + b.pow(2).sum()).sqrt().clamp_min(1e-8)

    merged: DeltaDict = {}
    for name, (a1, b1) in d1.items():
        a2, b2 = d2[name]
        s1 = _pair_norm(a1, b1)
        s2 = _pair_norm(a2, b2)
        merged[name] = (
            alpha * a1 / s1 + beta * a2 / s2,
            alpha * b1 / s1 + beta * b2 / s2,
        )
    return merged


# (optional) param-space linearity, can keep at small weight or 0
def linearity_loss_deltas(
    d_ab: DeltaDict, d_a: DeltaDict, d_b: DeltaDict
) -> torch.Tensor:
    """
    Penalize deviation from additive composition: Δ(AB) ≈ Δ(A)+Δ(B).

    This is the relative squared error over all layers k and both LoRA
    factors:

        sum_k ||A_ab - (A_a + A_b)||² + ||B_ab - (B_a + B_b)||²
        -------------------------------------------------------
        sum_k max(||A_ab||² + ||B_ab||², 1e-8)

    Keys are taken from ``d_ab`` and must exist in ``d_a`` and ``d_b``.

    The sums are accumulated in float32. The deltas may be bfloat16, where a
    running sum over many layers drops low-order contributions. The result is
    returned as a 0-d tensor on the device and dtype of ``d_ab``'s tensors.

    Raises:
        ValueError: if ``d_ab`` is empty.
    """
    if not d_ab:
        raise ValueError("linearity_loss_deltas: d_ab must contain at least one layer")
    ref = next(iter(d_ab.values()))[0]
    num = torch.zeros((), device=ref.device, dtype=torch.float32)
    den = torch.zeros((), device=ref.device, dtype=torch.float32)
    for k, (Aab, Bab) in d_ab.items():
        Aa, Ba = d_a[k]
        Ab, Bb = d_b[k]
        Aab, Bab = Aab.float(), Bab.float()
        num = (
            num
            + (Aab - (Aa.float() + Ab.float())).pow(2).sum()
            + (Bab - (Ba.float() + Bb.float())).pow(2).sum()
        )
        den = den + (Aab.pow(2).sum() + Bab.pow(2).sum()).clamp_min(1e-8)
    # num >= 0 and den > 0, so the ratio needs no clamping.
    return (num / den).to(ref.dtype)


def token_f1(ref: str, hyp: str) -> float:
    """Bag-of-words F1 between ``ref`` and ``hyp`` after whitespace splitting.

    Overlap is the multiset intersection of the two token lists. Two empty
    strings score 1.0. If there is no overlap (including exactly one empty
    side), the score is 0.0.
    """
    from collections import Counter

    ref_tokens, hyp_tokens = ref.split(), hyp.split()
    if not (ref_tokens or hyp_tokens):
        return 1.0

    overlap = sum((Counter(ref_tokens) & Counter(hyp_tokens)).values())
    precision = overlap / max(1, len(hyp_tokens))
    recall = overlap / max(1, len(ref_tokens))
    denom = precision + recall
    if denom == 0:
        return 0.0
    return 2 * precision * recall / denom


def _teacher_input_ids(inst_id, ctx_ids, q_ids):
    """Return ``[INST] + ctx + query`` token ids for every row of ``q_ids``.

    ``ctx_ids`` is one tokenized context of shape [1, S]. It and the INST id
    are tiled to the query batch size ``q_ids.size(0)``. Tiling by the
    context's own row count would break the concatenation for batches > 1.
    """
    bsz = q_ids.size(0)
    inst_col = torch.full((bsz, 1), inst_id, dtype=q_ids.dtype, device=q_ids.device)
    return torch.cat([inst_col, ctx_ids.repeat(bsz, 1), q_ids], dim=1)


def _count_exact_matches(model, dl, desc):
    """Greedy next-token exact match over ``dl``. Returns ``(correct, total)``.

    For every row, the argmax of the last-position logits is compared with
    ``batch.target_ids``. Counts are in examples, not batches.
    NOTE(review): the last position is the query's final token only when rows
    are unpadded or left-padded (always true for batch_size=1).
    """
    correct = 0
    total = 0
    with torch.inference_mode():
        for batch in tqdm(dl, desc=desc):
            preds = model(batch.input_ids).logits[:, -1, :].argmax(dim=-1)  # [B]
            targets = batch.target_ids.reshape(-1).to(preds.device)
            correct += int((preds == targets).sum())
            total += preds.numel()
    return correct, total


def run_experiment(model, tokenizer, splits, device, args, data_handler, demos):
    """
    Train an adapter generator and probe how well its adapters compose.

    ``LinearAdapterGeneratorSmall`` (bfloat16) maps tokenized demonstrations to
    functional LoRA deltas. The deltas are injected into the frozen LM
    (``model.lm.requires_grad_(False)``) through a FunctionalLoRAInjector.
    ``demos`` is split into halves A and B. With a single demo, A is empty and
    B repeats it.

    Per training step:
      * teachers: the LM with no deltas on ``[INST] + demos + query`` and on
        ``[INST] + A\\nB + query``;
      * KD: student(query | Δ(demos)) vs the first teacher, over non-pad
        query positions;
      * CE: last-position logits of ``query[:, :-1]`` against
        ``batch.target_ids``;
      * RECON: Δ(demos) decodes the demos after RECON_TOKEN (teacher forcing);
      * COMP_RECON: the merged Δ(A)+Δ(B) decodes ``A\\nB``;
      * COMP_KD: student(query | Δ(A)+Δ(B)) vs the second teacher;
      * LIN (weight ``args.lin_weight``, default 0): parameter-space
        additivity gap between Δ(AB) and Δ(A)+Δ(B).

    After each epoch, it prints averaged losses, reconstructions, token-F1
    scores, the Δ-linearity gap, and greedy exact match on ``splits[0]`` with
    the merged adapter.

    Args:
        model: wrapper with ``.lm`` (HF CausalLM). ``model(ids).logits`` gives
            [B, L, V].
        splits: (test, train, val) datasets.
        demos: demonstration strings, joined with newlines.
    """
    test, train, val = splits

    # special tokens
    _inst_id = ensure_token(tokenizer, model.lm, INST_TOKEN)
    _recon_id = ensure_token(tokenizer, model.lm, RECON_TOKEN)

    injector = FunctionalLoRAInjector(scale=1.0)
    _ = peft_factory(args.peft)  # kept for parity, injector path is used

    # target specs (per-component ranks across all layers)
    targets = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]
    layer_specs = generate_layer_specs(
        model,
        targets,
        default_rank=16,
        ranks=[16, 16, 4, 4, 8, 8, 8],
    )

    # ===== build supports =====
    def _encode(text):
        return tokenizer(text, return_tensors="pt", padding=True, truncation=True)[
            "input_ids"
        ].to(device)

    support_text = "\n".join(demos)
    support_ids = _encode(support_text)
    instruction_tokens_for_generator = support_ids

    mid = max(1, len(demos) // 2)
    demos_a = demos[:mid]
    demos_b = demos[mid:] if mid < len(demos) else demos[:1]
    support_text_a = "\n".join(demos_a)
    support_text_b = "\n".join(demos_b)
    support_text_ab = support_text_a + "\n" + support_text_b
    support_ids_a = _encode(support_text_a)
    support_ids_b = _encode(support_text_b)
    support_ids_ab = _encode(support_text_ab)

    # ===== generator =====
    generator = LinearAdapterGeneratorSmall(
        layer_specs=layer_specs,
        input_dim=model.lm.config.hidden_size,
        a_std=1e-3,
    ).to(device, dtype=torch.bfloat16)

    model.lm.requires_grad_(False)
    generator.set_trainable_params()
    optimizer = torch.optim.AdamW(
        generator.get_opt_params(), lr=args.lr, weight_decay=args.l2
    )
    ce = nn.CrossEntropyLoss(ignore_index=-100)

    # loaders & injector
    query_collate = partial(
        collate_queries, dh=data_handler, tokenizer=tokenizer, device=device
    )
    dl = DataLoader(train, batch_size=args.batch_size, collate_fn=query_collate)
    test_dl = DataLoader(test, batch_size=args.batch_size, collate_fn=query_collate)
    injector.attach(layer_specs)

    # loss weights
    lambda_lin_params = getattr(args, "lin_weight", 0.0)  # param-space linearity
    lambda_comp_recon = getattr(args, "comp_recon", 1.0)  # CE: (A+B) decodes AB
    lambda_comp_kd = getattr(args, "comp_kd", 0.5)  # KD: teacher(AB ctx + q) vs student(q | A+B)
    lambda_kd_main = 0.1
    lambda_ce_main = 1.0
    lambda_recon_main = 0.2

    # Reconstruction batches depend only on the fixed support ids: build once.
    rinp_full, rlab_full = make_recon_batch(tokenizer, support_ids, device)
    rinp_ab, rlab_ab = make_recon_batch(tokenizer, support_ids_ab, device)

    # ===== training =====
    for epoch in range(args.epochs):
        generator.train()
        agg = {
            k: 0.0
            for k in ("loss", "kd", "ce", "recon", "comp_rec", "comp_kd", "lin")
        }
        steps = 0

        for batch in tqdm(dl, desc=f"Epoch {epoch+1}"):
            steps += 1
            q_ids = batch.input_ids
            Lq = q_ids.size(1)

            # ---------- teacher logits (no deltas) ----------
            teacher_ids_full = _teacher_input_ids(_inst_id, support_ids, q_ids)
            teacher_ids_ab = _teacher_input_ids(_inst_id, support_ids_ab, q_ids)

            injector.disable()
            with torch.no_grad():
                logits_T_full = model(teacher_ids_full).logits[:, -Lq:, :]
                logits_T_ab = model(teacher_ids_ab).logits[:, -Lq:, :]

            # ---------- build adapters ----------
            d_full = generator(instruction_tokens_for_generator, model)
            dA = generator(support_ids_a, model)
            dB = generator(support_ids_b, model)
            dAB = generator(support_ids_ab, model)
            dM = merge_deltas(dA, dB, alpha=1.0, beta=1.0, normalize=False)

            # (i) baseline KD/CE/RECON using d_full
            injector.set_deltas(d_full)
            injector.enable()
            logits_S = model(q_ids).logits[:, -Lq:, :]
            mask = (q_ids != tokenizer.pad_token_id).float()
            L_kd = kd_loss(logits_S, logits_T_full, mask, T=1.0)

            logits_last = model(q_ids[:, :-1]).logits[:, -1, :]
            L_ce = ce(
                logits_last.view(-1, logits_last.size(-1)), batch.target_ids.view(-1)
            )

            r_logits_full = model(rinp_full).logits
            L_recon = ce(
                r_logits_full.view(-1, r_logits_full.size(-1)), rlab_full.view(-1)
            )
            injector.clear_deltas()
            injector.disable()

            # (ii) composition in decoded space: merged adapter (A+B) decodes AB
            injector.set_deltas(dM)
            injector.enable()
            r_logits_merge = model(rinp_ab).logits
            L_comp_recon = ce(
                r_logits_merge.view(-1, r_logits_merge.size(-1)), rlab_ab.view(-1)
            )

            # (iii) composition KD: student(q | A+B) should match teacher(q | AB ctx)
            logits_S_m = model(q_ids).logits[:, -Lq:, :]
            L_comp_kd = kd_loss(logits_S_m, logits_T_ab, mask, T=1.0)
            injector.clear_deltas()
            injector.disable()

            # (iv) optional param-space linearity penalty between dAB and dA+dB
            if lambda_lin_params > 0:
                L_lin = linearity_loss_deltas(dAB, dA, dB)
            else:
                L_lin = torch.tensor(0.0, device=device)

            loss = (
                lambda_kd_main * L_kd
                + lambda_ce_main * L_ce
                + lambda_recon_main * L_recon
                + lambda_comp_recon * L_comp_recon
                + lambda_comp_kd * L_comp_kd
            )
            if lambda_lin_params > 0:
                # Honour args.lin_weight. Only added when enabled, so the
                # default (0) loss is computed exactly as without it.
                loss = loss + lambda_lin_params * L_lin

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            agg["loss"] += float(loss)
            agg["kd"] += float(L_kd)
            agg["ce"] += float(L_ce)
            agg["recon"] += float(L_recon)
            agg["comp_rec"] += float(L_comp_recon)
            agg["comp_kd"] += float(L_comp_kd)
            agg["lin"] += float(L_lin)

        n = max(steps, 1)  # an empty training split must not divide by zero
        print(
            f"Epoch {epoch+1} — Loss:{agg['loss']/n:.4f} | "
            f"KD:{agg['kd']/n:.4f} CE:{agg['ce']/n:.4f} RECON:{agg['recon']/n:.4f} | "
            f"COMP_RECON:{agg['comp_rec']/n:.4f} COMP_KD:{agg['comp_kd']/n:.4f} "
            f"LIN:{agg['lin']/n:.4f}"
        )

        # ===== qualitative recon and merge eval =====
        generator.eval()
        with torch.no_grad():
            d_final = generator(instruction_tokens_for_generator, model)
            recon_text = adapter_to_text(
                model.lm, tokenizer, d_final, injector=injector, max_length=256, do_sample=False
            )
            dA = generator(support_ids_a, model)
            dB = generator(support_ids_b, model)
            dAB = generator(support_ids_ab, model)
            dM = merge_deltas(dA, dB)

        print("[ORIG INSTRUCTION]\n" + support_text)
        print("[RECONSTRUCTED]\n" + recon_text)

        recon_A = adapter_to_text(
            model.lm, tokenizer, dA, injector=injector, max_length=256, do_sample=False
        )
        recon_B = adapter_to_text(
            model.lm, tokenizer, dB, injector=injector, max_length=256, do_sample=False
        )
        recon_AB = adapter_to_text(
            model.lm, tokenizer, dAB, injector=injector, max_length=256, do_sample=False
        )
        recon_M = adapter_to_text(
            model.lm, tokenizer, dM, injector=injector, max_length=256, do_sample=False
        )

        print("\n[RECON QUALITY]")
        print(f"A   F1: {token_f1(support_text_a, recon_A):.3f}")
        print(f"B   F1: {token_f1(support_text_b, recon_B):.3f}")
        print(f"AB  (direct)  F1: {token_f1(support_text_ab, recon_AB):.3f}")
        print(f"A+B (merged)  F1: {token_f1(support_text_ab, recon_M):.3f}")
        print("\n[MERGED ADAPTER RECONSTRUCTION]\n" + recon_M)
        gap_lin = float(linearity_loss_deltas(dAB, dA, dB).detach().cpu())
        print(f"Δ-linearity (params) AB vs A+B: {gap_lin:.6f}")

        # task eval with merged adapter
        injector.set_deltas(dM)
        injector.enable()
        model.eval()
        try:
            correct, total = _count_exact_matches(model, test_dl, "Eval merged A+B")
        finally:
            injector.disable()
            injector.clear_deltas()
        print(f"[MERGE] Exact match: {correct}/{total}")



def adapter_ft(model, tokenizer, splits, device, args, data_handler):
    """
    Baseline: fine-tune one PEFT adapter on demonstrations.

    A fresh adapter named "adapter" (config from ``peft_factory(args.peft)``)
    is added to ``model.lm`` and activated. Training uses AdamW on a
    single-token objective: in each demonstration sequence only the final
    token is a label, predicted from all preceding tokens. Every other
    position is -100.

    After each epoch it prints the mean training loss and the greedy
    next-token exact match on the test split (batch size 1). Evaluation runs
    in ``model.eval()`` mode.

    Args:
        splits: (test, train, dev) datasets. dev is unused.
    """
    import copy

    test, train, _dev = splits

    adapter_name = "adapter"
    # Deep copy so the added adapter never shares a config object with
    # whatever peft_factory returns.
    peft_config = copy.deepcopy(peft_factory(args.peft))
    model.lm.add_adapter(peft_config, adapter_name=adapter_name)
    model.lm.set_adapter(adapter_name)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.l2)
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    # Both loaders are loop-invariant, so build them once.
    train_dl = DataLoader(
        train,
        batch_size=args.batch_size,
        collate_fn=partial(
            collate_demonstrations, dh=data_handler, tokenizer=tokenizer, device=device
        ),
    )
    test_dl = DataLoader(
        test,
        batch_size=1,
        collate_fn=partial(
            collate_queries, dh=data_handler, tokenizer=tokenizer, device=device
        ),
    )

    for epoch in range(args.epochs):
        model.train()
        loss_agg = 0.0
        n_batches = 0

        for batch in tqdm(train_dl, desc=f"Epoch {epoch+1}"):
            optimizer.zero_grad()
            input_ids = batch.input_ids
            inputs = input_ids[:, :-1]

            # Supervise only the last token, predicted at the last input position.
            labels = torch.full_like(inputs, -100)
            labels[:, -1] = input_ids[:, -1]

            logits = model(inputs).logits
            vocab_size = logits.shape[-1]
            loss = criterion(logits.view(-1, vocab_size), labels.view(-1))
            loss_agg += loss.item()
            n_batches += 1
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1} — train loss: {loss_agg / max(1, n_batches):.4f}")

        # Evaluate without dropout (e.g. LoRA dropout, if the PEFT config sets
        # one). The next epoch switches back with model.train().
        model.eval()
        correct_count_exact = 0
        total_count = 0
        with torch.inference_mode():
            for batch in tqdm(test_dl):
                preds = model(batch.input_ids).logits[:, -1, :].argmax(dim=-1)  # [B]
                targets = batch.target_ids.reshape(-1).to(preds.device)
                correct_count_exact += int((preds == targets).sum())
                total_count += preds.numel()

        print(f"Exact match: {correct_count_exact}/{total_count}")


def evaluate_model(
    model, tokenizer, name, dl, data_handler, demo_key_values=None, demo_ids=None
):
    """
    Evaluate ``model`` on ``dl`` by next-token prediction at the last position.

    Two views are computed for every batch:
      * verbalizer classification: argmax over the logits of each class
        label's last token id (``data_handler.class_labels``), which feeds the
        returned ``Result``;
      * exact match: full-vocabulary argmax compared with
        ``batch.target_ids``, which is printed.

    NOTE(review): assumes batch size 1 (uses ``batch.target[0]`` and squeezes
    the batch dimension); confirm before using larger batches.

    Args:
        name: label stored on the returned ``Result``.
        demo_key_values: optional ``past_key_values`` passed to every call.
            NOTE(review): the same object is reused for every batch; this is
            only safe if the model's forward does not update it in place.
        demo_ids: optional [1, D] token ids prepended to every query.

    Returns:
        ``Result`` evaluated on (y_true, y_pred), with
        ``result.logits = (y_true, verbalizer_logits)``.

    Raises:
        ValueError: if ``dl`` yields no batches.
    """
    model.eval()

    verbalizer_ids = [tokenizer.encode(verb)[-1] for verb in data_handler.class_labels]
    prob_list = []
    pred_list = []
    target_list = []
    correct_count_exact = 0
    total_count = 0
    with torch.inference_mode():
        for batch in tqdm(dl):
            input_ids = batch.input_ids
            if demo_ids is not None:
                input_ids = torch.concat([demo_ids, input_ids], dim=1)
            target_list.append(batch.target[0].cpu())

            outputs = model(input_ids, past_key_values=demo_key_values)
            logits = outputs.logits[:, -1, :].squeeze()
            preds = logits[verbalizer_ids]
            prob_list.append(preds.cpu())
            pred_list.append(torch.argmax(preds).cpu())

            token_id = torch.argmax(logits)
            correct_count_exact += int((batch.target_ids == token_id).sum())
            total_count += 1

    print(f"Exact match: {correct_count_exact}/{total_count}")
    if not prob_list:
        raise ValueError(f"evaluate_model({name!r}): dataloader yielded no batches")

    result = Result(name=name)

    y_pred = torch.tensor(pred_list)
    y_true = torch.tensor(target_list)
    y_logits = torch.stack(prob_list)
    result.evaluate(y_true, y_pred)
    result.logits = (y_true, y_logits)

    print(result)

    return result


def icl_eval(
    model,
    device,
    tokenizer,
    data_handler,
    demos,
    dl,
):
    """
    Compare n-shot in-context learning with 0-shot on ``dl``.

    For n-shot, the newline-joined ``demos`` are tokenized once and prepended
    to every query. For 0-shot, the bare queries are used.

    Returns:
        (result_n_shot, result_zero_shot) from ``evaluate_model``.
    """
    demo_text = "\n".join(demos)
    demo_ids = tokenizer(demo_text, return_tensors="pt")["input_ids"].to(device)

    # The demo prefix is re-encoded together with each query rather than
    # served from a precomputed KV cache. A forward pass that built such a
    # cache here was never used, since evaluate_model received
    # demo_key_values=None, so it has been dropped.
    # NOTE(review): one cache shared across all batches is only safe if
    # forward does not update it in place.
    result_n_shot = evaluate_model(
        tokenizer=tokenizer,
        model=model,
        name="n-shot",
        dl=dl,
        data_handler=data_handler,
        demo_key_values=None,
        demo_ids=demo_ids,
    )

    result_zero_shot = evaluate_model(
        model,
        tokenizer,
        name="0-shot",
        dl=dl,
        data_handler=data_handler,
    )

    return result_n_shot, result_zero_shot


# === Main entry point ===
if __name__ == "__main__":
    import argparse
    from models import initialize_model
    from dataloader import DATASETS

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="llama-3.2-1b-it")
    parser.add_argument("--num-demos", type=int, default=4)
    parser.add_argument("--data", type=str, default="mmlu-misc")
    parser.add_argument("--peft", type=str, default="lora")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--l2", type=float, default=0.01)
    parser.add_argument("--hidden-dim", type=int, default=512)
    parser.add_argument("--n-flows", type=int, default=4)
    # Loss weights that run_experiment reads via getattr(args, ...). The
    # defaults equal the fallbacks used there, so omitting them changes nothing.
    parser.add_argument("--lin-weight", type=float, default=0.0)
    parser.add_argument("--comp-recon", type=float, default=1.0)
    parser.add_argument("--comp-kd", type=float, default=0.5)
    args = parser.parse_args()

    dh = DATASETS[args.data]
    splits = dh.get_dataset_splits()
    test, train, val = splits

    tokenizer = AutoTokenizer.from_pretrained(TRANSFORMERS[args.model])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # NOTE(review): with right padding, if collate_queries pads via the
    # tokenizer, logits[:, -1] lands on a pad position for shorter rows when
    # batch_size > 1.
    tokenizer.padding_side = "right"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    meta = type("Meta", (), {"task_type": "lm"})()
    model = initialize_model(args, meta, device)
    model.to(device)

    model.lm.config.pad_token_id = tokenizer.pad_token_id
    # run_experiment generates through model.lm (it passes model.lm to
    # adapter_to_text), so update that generation config as well as the
    # wrapper's, whichever of them exist.
    for owner in (model, model.lm):
        generation_config = getattr(owner, "generation_config", None)
        if generation_config is not None:
            generation_config.pad_token_id = tokenizer.pad_token_id

    test_dl = DataLoader(
        test,
        batch_size=args.batch_size,
        collate_fn=partial(collate_queries, dh=dh, tokenizer=tokenizer, device=device),
    )

    demos = [dh.make_demonstration(val[i]) for i in range(args.num_demos)]

    # Alternative baselines (enable manually):
    # icl_eval(model, device, tokenizer, dh, demos, test_dl)
    # adapter_ft(model, tokenizer, splits, device, args, dh)

    run_experiment(
        model,
        tokenizer,
        splits=splits,
        device=device,
        args=args,
        data_handler=dh,
        demos=demos,
    )
