"""Knowledge distillation: train a single student model on ensemble soft labels.

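Loads K teacher bundles, predicts their soft probabilities on the training
spans, and trains a single student encoder (``--student-encoder``, default
roberta-large) on a weighted mix (``--alpha``) of KL divergence against the
averaged teacher probabilities and hard-label cross-entropy.
Goal: approach ensemble quality with one model.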
"""
from __future__ import annotations

import argparse
import json
import os
import sys
import time
from functools import partial
from pathlib import Path

# When a ModernBERT encoder is requested anywhere on the command line, switch off
# TorchDynamo / torch.compile through environment variables. This sits before the
# `torch` import below — presumably so torch sees the variables when it is
# imported (TODO confirm). setdefault keeps any value the user already exported.
if any("modernbert" in arg.lower() for arg in sys.argv):
    for _env_var in ("TORCHDYNAMO_DISABLE", "TORCH_COMPILE_DISABLE"):
        os.environ.setdefault(_env_var, "1")

import joblib
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

from analysis.exploration.llm_validation._client import PRIMITIVES
from analysis.exploration.llm_validation.classifier.dbert_pipeline import (
    DistilBertPipeline, DistilBertClassifier,
)
from analysis.exploration.llm_validation.classifier.ensemble import predict_proba
from analysis.exploration.llm_validation.classifier.prepare_dataset import (
    load_silver, report_split, trace_group_key,
)


class DistillDataset(Dataset):
    """Map-style dataset of tokenised (context, span) pairs plus a target vector.

    Item ``i`` is the tokenizer output for ``(preceding_context, span_text)`` of
    ``rows[i]`` — truncated to ``max_len`` and left unpadded (``collate`` pads) —
    together with row ``i`` of ``soft_labels`` under the key ``"soft"``.
    """

    def __init__(self, rows, soft_labels, tokenizer, max_len=512):
        # rows: list of dicts; soft_labels: one target row per entry in rows.
        self.rows, self.soft = rows, soft_labels
        self.tok, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        row = self.rows[i]
        context = (row.get("preceding_context") or "").strip()
        if not context:
            # Keep the first segment non-empty so every item is a real pair.
            context = "<empty>"
        span = (row.get("span_text") or "").strip()
        encoded = self.tok(
            context,
            span,
            truncation=True,
            max_length=self.max_len,
            padding=False,
            return_tensors=None,
        )
        return dict(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            soft=self.soft[i],
        )


def collate(batch, pad_id):
    """Right-pad a list of ``DistillDataset`` items into one batch of tensors.

    Returns ``input_ids`` (long, padded with ``pad_id``), ``attention_mask``
    (long, 0 on padding) and ``soft`` (float32, one target row per item).
    """
    batch_size = len(batch)
    width = max(len(item["input_ids"]) for item in batch)
    n_targets = len(batch[0]["soft"])

    input_ids = torch.full((batch_size, width), pad_id, dtype=torch.long)
    attention = torch.zeros((batch_size, width), dtype=torch.long)
    targets = torch.zeros((batch_size, n_targets), dtype=torch.float)

    for row, item in enumerate(batch):
        length = len(item["input_ids"])
        input_ids[row, :length] = torch.tensor(item["input_ids"])
        attention[row, :length] = torch.tensor(item["attention_mask"])
        targets[row] = torch.tensor(item["soft"])

    return {"input_ids": input_ids, "attention_mask": attention, "soft": targets}


@torch.no_grad()
def eval_macro(model, loader, device):
    """Macro-F1 of the model's argmax predictions over every batch in ``loader``.

    The reference label for each example is the argmax of ``batch["soft"]``
    (``main`` passes one-hot silver labels for val/test). The model is switched
    to eval mode and left there; gradients are disabled by the decorator.
    """
    model.eval()
    predicted, expected = [], []
    for batch in loader:
        output = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        )
        predicted += output.logits.argmax(dim=-1).cpu().tolist()
        expected += batch["soft"].argmax(dim=-1).tolist()
    return f1_score(expected, predicted, average="macro", zero_division=0)


def _load_deduped_silvers(paths):
    """Load every silver file in ``paths``, keeping the first row seen per span_id.

    Earlier paths therefore take precedence over later ones on duplicates.
    """
    seen = set()
    deduped = []
    for sp in paths:
        for r in load_silver(sp):
            if r["span_id"] in seen:
                continue
            seen.add(r["span_id"])
            deduped.append(r)
    return deduped


def _split_by_v1_traces(rows, v1_splits_path):
    """Split ``rows`` into (train, val, test) by trace group, reusing v1's val/test.

    A row goes to val (resp. test) when its ``trace_group_key`` matches the trace
    of any v1 ``val_ids`` (resp. ``test_ids``) span present in ``rows``; val wins
    if a trace appears in both. Every other row — including traces v1 never saw —
    goes to train, so new silver data only ever enlarges the training set.
    """
    with open(v1_splits_path) as f:
        v1_splits = json.load(f)
    # span_id -> row across ALL silvers so multi-domain splits resolve.
    by_id = {r["span_id"]: r for r in rows}

    def traces_for(key):
        return {trace_group_key(by_id[sid]) for sid in v1_splits[key] if sid in by_id}

    val_traces = traces_for("val_ids")
    test_traces = traces_for("test_ids")
    train, val, test = [], [], []
    for r in rows:
        tk = trace_group_key(r)
        if tk in val_traces:
            val.append(r)
        elif tk in test_traces:
            test.append(r)
        else:
            train.append(r)
    return train, val, test


def _ensemble_teacher_probs(teacher_paths, rows, n_classes):
    """Average each teacher bundle's ``predict_proba`` over ``rows``.

    Returns a float32 array of shape (len(rows), n_classes).
    Raises ValueError if any teacher returns a differently-shaped array.
    NOTE(review): assumes every teacher's probability columns follow PRIMITIVES
    order — confirm against the bundles' ``label_order``.
    """
    total = None
    for t_path in teacher_paths:
        bundle = joblib.load(t_path)
        proba = np.asarray(predict_proba(bundle, rows), dtype=np.float64)
        if proba.shape != (len(rows), n_classes):
            raise ValueError(
                f"Teacher {t_path} returned probabilities of shape {proba.shape}; "
                f"expected {(len(rows), n_classes)}"
            )
        total = proba if total is None else total + proba
        print(f"  loaded {t_path.stem}")
    return (total / len(teacher_paths)).astype(np.float32)


def _soften_probs(probs, temperature):
    """Apply temperature to a probability matrix: softmax(log(probs) / T) per row.

    For a single teacher with logits z this equals softmax(z / T) exactly; for an
    averaged ensemble it softens the averaged distribution. T > 1 flattens,
    T == 1 returns ``probs`` renormalised. Zeros are clipped to 1e-12 first.
    """
    logp = np.log(np.clip(np.asarray(probs, dtype=np.float64), 1e-12, None)) / temperature
    logp -= logp.max(axis=1, keepdims=True)  # numerical stability before exp
    out = np.exp(logp)
    out /= out.sum(axis=1, keepdims=True)
    return out.astype(np.float32)


def _one_hot(rows, label_idx, n_classes):
    """float32 one-hot matrix of each row's ``llm_label`` (KeyError on unknown labels)."""
    return np.eye(n_classes, dtype=np.float32)[[label_idx[r["llm_label"]] for r in rows]]


def main():
    """CLI entry point: distil the teacher ensemble into one student model.

    Writes the best-val student checkpoint to ``<out>.hf/``, a joblib bundle to
    ``--out`` and the train/val/test span_id lists to ``<out>.splits.json``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--silvers", nargs="+", required=True, type=Path)
    ap.add_argument("--v1-splits", required=True, type=Path)
    ap.add_argument("--teachers", nargs="+", required=True, type=Path,
                    help="Teacher bundle paths (their predict_proba is averaged)")
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--student-encoder", default="roberta-large")
    ap.add_argument("--epochs", type=int, default=4)
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--lr", type=float, default=1e-5)
    ap.add_argument("--temperature", type=float, default=2.0,
                    help="Distillation temperature (>1 softens teacher distribution)")
    ap.add_argument("--alpha", type=float, default=0.7,
                    help="Weight on KL loss (1-alpha goes to hard-label CE)")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    # With 0 epochs no checkpoint is ever saved and the final reload would fail.
    if args.epochs < 1:
        ap.error("--epochs must be >= 1")
    if args.temperature <= 0:
        ap.error("--temperature must be > 0")
    if not 0.0 <= args.alpha <= 1.0:
        ap.error("--alpha must be in [0, 1]")

    args.out.parent.mkdir(parents=True, exist_ok=True)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # --- Load + dedupe silvers ---
    deduped = _load_deduped_silvers(args.silvers)
    print(f"Total deduped silver rows: {len(deduped)}")

    # --- Reconstruct trace splits from v1 ---
    train, val, test = _split_by_v1_traces(deduped, args.v1_splits)
    report_split("train", train)
    report_split("val", val)
    report_split("test", test)

    label_idx = {l: i for i, l in enumerate(PRIMITIVES)}
    n_classes = len(PRIMITIVES)
    T = args.temperature
    alpha = args.alpha

    # --- Predict ensemble soft labels on train ---
    print(f"\nPredicting teacher soft labels with {len(args.teachers)} teachers...")
    teacher_probs = _ensemble_teacher_probs(args.teachers, train, n_classes)
    print(f"Teacher probs shape: {teacher_probs.shape}")
    # Per train row: [teacher probs softened at T | one-hot silver label].
    # DistillDataset/collate carry a single float vector per row, so both targets
    # are packed side by side and split again in the training loop.
    train_targets = np.concatenate(
        [_soften_probs(teacher_probs, T), _one_hot(train, label_idx, n_classes)],
        axis=1,
    )

    # --- Set up student ---
    print(f"\nLoading student {args.student_encoder}...")
    tok = AutoTokenizer.from_pretrained(args.student_encoder)
    # Causal LMs (Qwen, Llama) usually have no pad_token; reuse eos for padding.
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    pad_id = tok.pad_token_id
    if pad_id is None:
        raise ValueError(
            f"Tokenizer for {args.student_encoder!r} has neither a pad nor an eos "
            "token; cannot pad batches."
        )
    model_kwargs = {"num_labels": n_classes}
    if "modernbert" in args.student_encoder.lower():
        model_kwargs["attn_implementation"] = "sdpa"
    model = AutoModelForSequenceClassification.from_pretrained(
        args.student_encoder, **model_kwargs,
    )
    if model.config.pad_token_id is None:
        model.config.pad_token_id = pad_id
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    print(f"  device: {device}")

    train_ds = DistillDataset(train, train_targets, tok)
    # For val/test we use one-hot from the silver labels, just for tracking
    val_ds = DistillDataset(val, _one_hot(val, label_idx, n_classes), tok)
    test_ds = DistillDataset(test, _one_hot(test, label_idx, n_classes), tok)
    # partial (not a lambda) so DataLoader workers can pickle it under "spawn".
    collate_fn = partial(collate, pad_id=pad_id)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_fn, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size * 2, shuffle=False,
                            collate_fn=collate_fn, num_workers=2)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size * 2, shuffle=False,
                             collate_fn=collate_fn, num_workers=2)

    n_train_steps = len(train_loader) * args.epochs
    optim = torch.optim.AdamW(model.parameters(), lr=args.lr)
    sched = get_linear_schedule_with_warmup(
        optim, num_warmup_steps=int(0.1 * n_train_steps),
        num_training_steps=n_train_steps,
    )

    best_val = -1.0
    best_epoch = 0
    student_dir = args.out.with_suffix(".hf")
    student_dir.mkdir(parents=True, exist_ok=True)

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        model.train()
        loss_sum = 0.0
        n = 0
        for batch in train_loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            targets = batch["soft"].to(device)
            teacher_T = targets[:, :n_classes]
            hard_labels = targets[:, n_classes:].argmax(dim=-1)
            logits = model(input_ids=ids, attention_mask=mask).logits

            # KL(teacher_T || student_T) with BOTH sides at temperature T; the
            # T**2 factor keeps gradient scale comparable across temperatures.
            log_p_student = F.log_softmax(logits / T, dim=-1)
            kl = F.kl_div(log_p_student, teacher_T, reduction="batchmean") * (T * T)
            # Hard-label CE against the silver label, at T = 1.
            ce = F.cross_entropy(logits, hard_labels)
            loss = alpha * kl + (1 - alpha) * ce

            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            sched.step()
            loss_sum += loss.item() * len(hard_labels)
            n += len(hard_labels)

        train_loss = loss_sum / max(1, n)
        val_f1 = eval_macro(model, val_loader, device)
        dt = time.time() - t0
        print(f"  Epoch {epoch}: train_loss={train_loss:.4f}  val_macro_f1={val_f1:.4f}  ({dt:.1f}s)")
        if val_f1 > best_val:
            best_val = val_f1
            best_epoch = epoch
            model.save_pretrained(student_dir)
            tok.save_pretrained(student_dir)

    print(f"\nBest val macro-F1: {best_val:.4f} at epoch {best_epoch}")
    # Release the training copy and AdamW state before loading the best checkpoint.
    del model, optim, sched
    if device == "cuda":
        torch.cuda.empty_cache()
    # Same kwargs as training so e.g. ModernBERT reloads with the same attention impl.
    model = AutoModelForSequenceClassification.from_pretrained(
        student_dir, **model_kwargs,
    ).to(device)
    test_f1 = eval_macro(model, test_loader, device)
    print(f"Test macro-F1: {test_f1:.4f}")

    pipe = DistilBertPipeline(model_dir=str(student_dir), label_order=list(PRIMITIVES))
    clf = DistilBertClassifier(model_dir=str(student_dir), label_order=list(PRIMITIVES))
    bundle = {
        "feature_pipeline": pipe,
        "classifier": clf,
        "label_order": list(PRIMITIVES),
        "hyperparams": {"distill_T": T, "distill_alpha": alpha,
                        "best_val_macro_f1": best_val, "test_macro_f1": test_f1,
                        "n_teachers": len(args.teachers)},
        "pipeline_type": "distilled",
    }
    joblib.dump(bundle, args.out)
    print(f"Saved bundle -> {args.out}")
    splits_path = args.out.with_suffix(".splits.json")
    with open(splits_path, "w") as f:
        json.dump({
            "train_ids": [r["span_id"] for r in train],
            "val_ids": [r["span_id"] for r in val],
            "test_ids": [r["span_id"] for r in test],
        }, f)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
