"""Stage v3: fine-tune DistilBERT end-to-end on silver labels.

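Reuses v1 trace splits (apples-to-apples eval): rows whose trace was in the
v1 val/test split stay there; every other row (v1-train or new trace) goes
to train. Trains DistilBERT-base with class-weighted CE loss, AdamW,
lr=2e-5, batch 16, 3 epochs by default. Every epoch runs (there is no early
stopping); the checkpoint with the best val macro-F1 is the one kept.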

Inference will run on GPU; ~5–10 ms/span on a single A100.
"""
from __future__ import annotations

import argparse
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path

# ModernBERT triggers torch.compile/Triton at runtime which fails to compile
# on this host. Disable proactively if the user passes a ModernBERT encoder.
#
# Detection is a case-insensitive substring match over every raw command-line
# token, so it fires before argparse has run. This block sits ahead of the
# `import torch` line further down -- presumably so the variables are already
# set when torch is imported (NOTE(review): confirm torch reads them at import).
# `setdefault` leaves any value the user already exported untouched.
if any("modernbert" in a.lower() for a in sys.argv):
    os.environ.setdefault("TORCHDYNAMO_DISABLE", "1")
    os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")

import joblib
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

from analysis.exploration.llm_validation._client import PRIMITIVES
from analysis.exploration.llm_validation.classifier.prepare_dataset import (
    load_silver, report_split, trace_group_key,
)


class SpanDataset(Dataset):
    """Map-style dataset encoding each row as a (preceding_context, span_text) pair.

    Each item is a dict with un-padded ``input_ids`` and ``attention_mask``
    (whatever the tokenizer returns for those keys) plus an integer ``label``
    looked up from the row's ``llm_label``. Padding is left to the collate
    function so batches are only padded to their own longest sequence.
    """

    def __init__(self, rows, tokenizer, label_idx, max_len=512):
        self.rows, self.tok = rows, tokenizer
        self.label_idx, self.max_len = label_idx, max_len

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        row = self.rows[i]

        # Missing / None / whitespace-only context becomes a literal placeholder
        # so the first segment of the pair is never empty.
        context = (row.get("preceding_context") or "").strip()
        if not context:
            context = "<empty>"
        span = (row.get("span_text") or "").strip()

        encoded = self.tok(
            context,
            span,
            truncation=True,
            max_length=self.max_len,
            padding=False,
            return_tensors=None,
        )

        item = {key: encoded[key] for key in ("input_ids", "attention_mask")}
        item["label"] = self.label_idx[row["llm_label"]]
        return item


def collate(batch, pad_id):
    """Right-pad a list of dataset items into batch tensors.

    ``batch`` is a list of dicts with ``input_ids``, ``attention_mask`` and
    ``label``. Sequences are padded to the longest ``input_ids`` in the batch:
    ids with ``pad_id``, the attention mask with zeros. Returns a dict of
    ``torch.long`` tensors under ``input_ids``, ``attention_mask`` and
    ``labels``.
    """
    lengths = [len(item["input_ids"]) for item in batch]
    width = max(lengths)
    rows = len(batch)

    ids = torch.full((rows, width), pad_id, dtype=torch.long)
    mask = torch.zeros((rows, width), dtype=torch.long)
    for row, (item, n) in enumerate(zip(batch, lengths)):
        # Only the first n positions are real tokens; the rest stay padding.
        ids[row, :n] = torch.tensor(item["input_ids"])
        mask[row, :n] = torch.tensor(item["attention_mask"])

    labels = torch.tensor([item["label"] for item in batch], dtype=torch.long)
    return {"input_ids": ids, "attention_mask": mask, "labels": labels}


from analysis.exploration.llm_validation.classifier.dbert_pipeline import (
    DistilBertPipeline,
    DistilBertClassifier,
)


def evaluate_macro_f1(model, loader, device, n_classes):
    """Return the macro-averaged F1 of ``model`` over every batch in ``loader``.

    Puts the model in eval mode (and leaves it there), runs without gradient
    tracking, takes the argmax of the logits as the prediction, and scores
    with ``zero_division=0``. ``n_classes`` is accepted for interface
    compatibility but is not used.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch in loader:
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            y_pred += outputs.logits.argmax(dim=-1).cpu().tolist()
            # Labels never leave the CPU; they are only needed for scoring.
            y_true += batch["labels"].tolist()
    return f1_score(y_true, y_pred, average="macro", zero_division=0)


def _dedupe_by_span_id(rows):
    """Return ``rows`` keeping only the first row seen for each ``span_id``."""
    seen = set()
    unique = []
    for r in rows:
        sid = r["span_id"]
        if sid in seen:
            continue
        seen.add(sid)
        unique.append(r)
    return unique


def _trace_keys(span_ids, rows_by_id):
    """Return the trace-group keys of the ``span_ids`` that exist in ``rows_by_id``."""
    return {
        trace_group_key(rows_by_id[sid]) for sid in span_ids if sid in rows_by_id
    }


def _route_rows(rows, v1_splits):
    """Split ``rows`` into (train, val, test) according to the v1 trace assignment.

    A row goes to val/test when its trace group contained a v1 val/test span
    (val wins if a trace is in both). Everything else -- v1 train traces and
    traces the v1 splits never saw -- goes to train, so the v1 train ids do
    not need to be consulted.
    """
    # Index across ALL silvers so multi-domain splits (puzzle + math) resolve
    # trace keys for every referenced span.
    rows_by_id = {r["span_id"]: r for r in rows}
    val_traces = _trace_keys(v1_splits["val_ids"], rows_by_id)
    test_traces = _trace_keys(v1_splits["test_ids"], rows_by_id)

    train, val, test = [], [], []
    for r in rows:
        tk = trace_group_key(r)
        if tk in val_traces:
            val.append(r)
        elif tk in test_traces:
            test.append(r)
        else:  # train traces or new
            train.append(r)
    return train, val, test


def _compute_class_weights(train_labels, n_classes, scheme):
    """Return float32 per-class loss weights derived from ``train_labels``.

    ``scheme`` is ``"balanced"`` (N / (K * c_i)), ``"sqrt"``
    (sqrt(N / (K * c_i))) or anything else for uniform weights. Classes with
    no training rows are counted as 1 to avoid division by zero.
    """
    counts = np.bincount(train_labels, minlength=n_classes)
    safe_counts = np.maximum(counts, 1)
    if scheme == "balanced":
        weights = len(train_labels) / (n_classes * safe_counts)
    elif scheme == "sqrt":
        weights = np.sqrt(len(train_labels) / (n_classes * safe_counts))
    else:  # none
        weights = np.ones(n_classes)
    return weights.astype(np.float32)


def main():
    """Fine-tune a sequence classifier on silver labels and save the artifacts.

    Steps: parse CLI args, load and dedupe the silver rows, route them to
    train/val/test using the v1 split file, fine-tune ``--encoder`` with
    class-weighted cross-entropy, keep the checkpoint with the best val
    macro-F1, score it on the test split, then write the joblib bundle, the
    HuggingFace model directory and a ``.splits.json`` next to ``--out``.

    Exits via ``SystemExit`` (with a message) when ``--epochs`` is below 1,
    when no row is routed to the train split, or when the tokenizer has no
    usable pad token.
    """
    # Function-scope import keeps this block self-contained.
    from functools import partial

    ap = argparse.ArgumentParser()
    ap.add_argument("--silvers", nargs="+", required=True, type=Path,
                    help="One or more silver JSONL files to combine for training")
    ap.add_argument("--v1-splits", required=True, type=Path,
                    help="classifier_v1.splits.json (provides trace→split assignment)")
    ap.add_argument("--out", required=True, type=Path,
                    help="Joblib bundle path; the underlying HF model goes to --model-dir")
    ap.add_argument("--model-dir", default=None, type=Path,
                    help="Where to save the HuggingFace model. "
                         "Defaults to --out with its suffix replaced by .hf")
    ap.add_argument("--encoder", default="distilbert-base-uncased")
    ap.add_argument("--max-len", type=int, default=512)
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--lr", type=float, default=2e-5)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument(
        "--class-weight", choices=["balanced", "sqrt", "none"], default="balanced",
        help="balanced=N/(K*c_i); sqrt=sqrt(N/(K*c_i)); none=uniform",
    )
    ap.add_argument(
        "--label-smoothing", type=float, default=0.0,
        help="Cross-entropy label smoothing (0 = off)",
    )
    args = ap.parse_args()

    # With zero epochs no checkpoint is ever written, and the best-checkpoint
    # reload before the test eval would fail on an empty model directory.
    if args.epochs < 1:
        ap.error("--epochs must be >= 1")

    if args.model_dir is None:
        args.model_dir = args.out.with_suffix(".hf")
    args.model_dir.mkdir(parents=True, exist_ok=True)
    args.out.parent.mkdir(parents=True, exist_ok=True)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # ---- Load and combine silvers ----
    all_rows = []
    for sp in args.silvers:
        rows = load_silver(sp)
        print(f"  Loaded {len(rows)} from {sp}")
        all_rows.extend(rows)
    # Dedupe on span_id (some files may overlap)
    deduped = _dedupe_by_span_id(all_rows)
    print(f"Total deduped: {len(deduped)}")

    # ---- Reconstruct v1 trace assignments and route rows ----
    with open(args.v1_splits, encoding="utf-8") as f:
        v1_splits = json.load(f)
    train, val, test = _route_rows(deduped, v1_splits)
    report_split("train", train)
    report_split("val",   val)
    report_split("test",  test)

    if not train:
        raise SystemExit(
            "No rows were routed to the train split; check --silvers and --v1-splits."
        )
    for split_name, split_rows in (("val", val), ("test", test)):
        if not split_rows:
            print(f"WARNING: {split_name} split is empty; "
                  f"its macro-F1 will not be meaningful")

    # ---- Tokenizer + model ----
    label_order = list(PRIMITIVES)
    label_idx = {l: i for i, l in enumerate(label_order)}
    n_classes = len(label_order)

    print(f"\nLoading {args.encoder}...")
    tok = AutoTokenizer.from_pretrained(args.encoder)
    # Causal LMs (Qwen, Llama) usually have no pad_token; reuse eos for padding.
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    pad_id = tok.pad_token_id
    if pad_id is None:
        # Fail before loading weights rather than inside the first collate call.
        raise SystemExit(
            f"Tokenizer for {args.encoder} has neither a pad nor an eos token; "
            f"cannot pad batches."
        )
    model_kwargs = {"num_labels": n_classes}
    if "modernbert" in args.encoder.lower():
        model_kwargs["attn_implementation"] = "sdpa"
    # NLI- or task-tuned variants have a head with the wrong shape; rebuild it.
    if any(t in args.encoder.lower() for t in ("mnli", "nli", "sst", "tasksource")):
        model_kwargs["ignore_mismatched_sizes"] = True
    model = AutoModelForSequenceClassification.from_pretrained(
        args.encoder, **model_kwargs,
    )
    if model.config.pad_token_id is None and tok.pad_token_id is not None:
        model.config.pad_token_id = tok.pad_token_id
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    print(f"  device: {device}")

    # Class weights from training distribution
    train_labels = np.array([label_idx[r["llm_label"]] for r in train])
    class_weights = _compute_class_weights(train_labels, n_classes, args.class_weight)
    cw_tensor = torch.tensor(class_weights).to(device)
    print(f"Class weights ({args.class_weight}): "
          f"{dict(zip(label_order, class_weights.round(3)))}")

    # ---- Datasets / loaders ----
    train_ds = SpanDataset(train, tok, label_idx, max_len=args.max_len)
    val_ds   = SpanDataset(val,   tok, label_idx, max_len=args.max_len)
    test_ds  = SpanDataset(test,  tok, label_idx, max_len=args.max_len)

    # A partial over the module-level `collate` can be pickled, so the worker
    # processes also start under the spawn/forkserver multiprocessing start
    # methods; a lambda cannot be pickled and only works with fork.
    collate_fn = partial(collate, pad_id=pad_id)
    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        collate_fn=collate_fn, num_workers=2,
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size * 2, shuffle=False,
        collate_fn=collate_fn, num_workers=2,
    )
    test_loader = DataLoader(
        test_ds, batch_size=args.batch_size * 2, shuffle=False,
        collate_fn=collate_fn, num_workers=2,
    )

    # ---- Optimizer + schedule ----
    # Linear warmup over the first 10% of steps, then linear decay.
    n_train_steps = len(train_loader) * args.epochs
    optim = torch.optim.AdamW(model.parameters(), lr=args.lr)
    sched = get_linear_schedule_with_warmup(
        optim, num_warmup_steps=int(0.1 * n_train_steps),
        num_training_steps=n_train_steps,
    )

    # ---- Train ----
    # Every epoch runs; only the checkpoint with the best val macro-F1 is kept.
    best_val = -1.0
    best_epoch = 0
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        model.train()  # evaluate_macro_f1 leaves the model in eval mode
        n = 0
        loss_sum = 0.0
        for batch in train_loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            logits = model(input_ids=ids, attention_mask=mask).logits
            loss = F.cross_entropy(
                logits, labels, weight=cw_tensor,
                label_smoothing=args.label_smoothing,
            )
            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            sched.step()
            # Weight by batch size so the epoch average is per-example.
            loss_sum += loss.item() * len(labels)
            n += len(labels)
        train_loss = loss_sum / max(1, n)
        val_f1 = evaluate_macro_f1(model, val_loader, device, n_classes)
        dt = time.time() - t0
        print(f"  Epoch {epoch}: train_loss={train_loss:.4f}  val_macro_f1={val_f1:.4f}  ({dt:.1f}s)")
        if val_f1 > best_val:
            best_val = val_f1
            best_epoch = epoch
            model.save_pretrained(args.model_dir)
            tok.save_pretrained(args.model_dir)
            print(f"    [saved best epoch {epoch}]")

    # ---- Test eval (best checkpoint) ----
    print()
    print(f"Best val macro-F1: {best_val:.4f} at epoch {best_epoch}")
    model = AutoModelForSequenceClassification.from_pretrained(args.model_dir).to(device)
    test_f1 = evaluate_macro_f1(model, test_loader, device, n_classes)
    print(f"Test macro-F1:  {test_f1:.4f}")

    # ---- Save sklearn-compatible bundle ----
    pipeline = DistilBertPipeline(model_dir=str(args.model_dir), label_order=label_order)
    classifier = DistilBertClassifier(model_dir=str(args.model_dir), label_order=label_order,
                                      max_len=args.max_len)
    bundle = {
        "feature_pipeline": pipeline,
        "classifier": classifier,
        "label_order": label_order,
        "hyperparams": {"encoder": args.encoder, "epochs": args.epochs, "lr": args.lr,
                        "best_epoch": best_epoch, "best_val_macro_f1": best_val,
                        "test_macro_f1": test_f1},
        "pipeline_type": "distilbert",
    }
    joblib.dump(bundle, args.out)
    print(f"Saved bundle -> {args.out}")
    print(f"Saved HF model -> {args.model_dir}")

    # Save splits (same as v1)
    splits_path = args.out.with_suffix(".splits.json")
    splits = {
        "train_ids": [r["span_id"] for r in train],
        "val_ids":   [r["span_id"] for r in val],
        "test_ids":  [r["span_id"] for r in test],
    }
    with open(splits_path, "w", encoding="utf-8") as f:
        json.dump(splits, f)
    print(f"Saved splits -> {splits_path}")


if __name__ == "__main__":
    main()
