#!/usr/bin/env python3
import argparse, os, json, math
from dataclasses import dataclass
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from transformers.optimization import get_linear_schedule_with_warmup

# ---------- data ----------
def load_jsonl(p):
    """Read a JSON Lines file and return its records as a list.

    Args:
        p: Path of the file to read; it is opened as UTF-8 text.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    rows = []
    with open(p, "r", encoding="utf-8") as f:
        for ln in f:
            # Blank lines (e.g. a trailing empty line) carry no record and
            # would make json.loads fail, so skip them.
            if not ln.strip():
                continue
            rows.append(json.loads(ln))
    return rows

class JsonlDataset(Dataset):
    """Dataset over in-memory rows that tokenizes each row when it is indexed.

    Every row is read through its "text" and "label" keys. The label is
    encoded as 1 when it equals "FLAGGED" and as 0 for anything else.
    """

    def __init__(self, rows, tok, max_len=256):
        self.rows = rows
        self.tok = tok
        self.max_len = max_len

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        row = self.rows[i]
        encoded = self.tok(
            row["text"],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )

        # squeeze(0) drops the leading axis of each tensor when it has size 1.
        sample = {}
        for key, tensor in encoded.items():
            sample[key] = tensor.squeeze(0)

        target = 0
        if row["label"] == "FLAGGED":
            target = 1
        sample["label"] = torch.tensor(target, dtype=torch.long)
        return sample

# ---------- model head ----------
class SimpleHead(torch.nn.Module):
    """Classification head: dropout followed by a linear layer with 2 outputs."""

    def __init__(self, hidden, drop=0.1):
        super().__init__()
        nn = torch.nn
        # Attribute names are part of the saved state_dict keys; keep them stable.
        self.dropout = nn.Dropout(p=drop)
        self.head = nn.Linear(in_features=hidden, out_features=2)

    def forward(self, cls):
        dropped = self.dropout(cls)
        logits = self.head(dropped)
        return logits

# ---------- utils ----------
def class_weights_from_labels(int_labels):
    """Compute balanced class weights for binary labels.

    Each class gets ``n / (2 * count)``, so the rarer class receives the
    larger weight. A class with no samples is treated as having one sample to
    avoid dividing by zero.

    Args:
        int_labels: Sized collection of 0/1 integer labels (1 = flagged).

    Returns:
        Float tensor ``[w_notflagged, w_flagged]``. For empty input both
        weights are 1.0, because all-zero weights would make a weighted
        cross-entropy loss undefined.
    """
    n = len(int_labels)
    if n == 0:
        return torch.ones(2, dtype=torch.float)
    pos = sum(int_labels)
    neg = n - pos
    # avoid div by zero
    w_pos = n / (2 * max(1, pos))
    w_neg = n / (2 * max(1, neg))
    return torch.tensor([w_neg, w_pos], dtype=torch.float)

def evaluate(enc, head, dl, device):
    """Run the encoder and head over a data loader and score the predictions.

    Both modules are switched to eval mode and are left in that mode when the
    function returns; callers that keep training must switch back themselves.

    Args:
        enc: Encoder called as ``enc(input_ids=..., attention_mask=...)`` whose
            output exposes ``last_hidden_state``.
        head: Module mapping the hidden state at sequence position 0 to logits.
        dl: Iterable of batches with "input_ids", "attention_mask" and "label".
        device: Device the input tensors are moved to.

    Returns:
        Tuple ``(acc, f1m, rep)``: accuracy, macro-averaged F1, and the text
        classification report.
    """
    enc.eval()
    head.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in dl:
            input_ids = batch["input_ids"].to(device)
            attn = batch["attention_mask"].to(device)
            out = enc(input_ids=input_ids, attention_mask=attn)
            cls = out.last_hidden_state[:, 0, :]
            logits = head(cls)
            pred = torch.argmax(logits, dim=-1)
            y_pred.extend(pred.cpu().tolist())
            y_true.extend(batch["label"].cpu().tolist())

    # Imported here so that scikit-learn is only needed when evaluating.
    from sklearn.metrics import accuracy_score, classification_report, f1_score

    acc = accuracy_score(y_true, y_pred)
    f1m = f1_score(y_true, y_pred, average="macro")
    # Pin the label set: without it the report infers the classes from the
    # data and raises when only one class occurs, because the two
    # target_names no longer match the number of inferred classes.
    rep = classification_report(
        y_true,
        y_pred,
        labels=[0, 1],
        target_names=["NOT FLAGGED", "FLAGGED"],
    )
    return acc, f1m, rep

def main():
    """Fine-tune an encoder plus linear head on JSONL data and dump predictions.

    Parses the command line, loads the train/test JSONL files, trains with
    gradient accumulation (optionally AMP, gradient checkpointing and a
    class-weighted loss), evaluates on the test file after every epoch, saves
    tokenizer/encoder/head to ``--out-dir`` whenever macro F1 improves, and
    finally writes per-row predictions to ``outputs/preds_c.jsonl``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--train", required=True)
    ap.add_argument("--test",  required=True)
    ap.add_argument("--model", default="roberta-base")
    ap.add_argument("--epochs", type=int, default=6)
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--lr", type=float, default=2e-5)
    ap.add_argument("--max-len", type=int, default=256)
    ap.add_argument("--grad-accum", type=int, default=2)
    ap.add_argument("--fp16", action="store_true", help="enable AMP")
    ap.add_argument("--grad-checkpoint", dest="grad_checkpoint", action="store_true", help="enable gradient checkpointing")
    ap.add_argument("--weighted-loss", dest="weighted_loss", action="store_true", help="use class-weighted CE")
    ap.add_argument("--out-dir", default="outputs/roberta_cls")
    args = ap.parse_args()

    # grad_accum is used as a divisor and a modulus below; reject bad values early.
    if args.grad_accum < 1:
        ap.error("--grad-accum must be >= 1")

    os.makedirs(args.out_dir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # data
    tr_rows = load_jsonl(args.train)
    te_rows = load_jsonl(args.test)
    print("---------------------------------------------")
    print(f"{len(tr_rows)} rows loaded from {args.train}")
    print("---------------------------------------------")
    print(f"{len(te_rows)} rows loaded from {args.test}")

    tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)

    tr_ds = JsonlDataset(tr_rows, tok, max_len=args.max_len)
    te_ds = JsonlDataset(te_rows, tok, max_len=args.max_len)
    pin = device == "cuda"
    tr_dl = DataLoader(tr_ds, batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=pin)
    te_dl = DataLoader(te_ds, batch_size=32, shuffle=False, num_workers=0, pin_memory=pin)

    # encoder (safetensors only)
    enc = AutoModel.from_pretrained(args.model, use_safetensors=True)
    if args.grad_checkpoint and hasattr(enc, "gradient_checkpointing_enable"):
        # required by HF when using checkpointing
        if hasattr(enc.config, "use_cache"):
            enc.config.use_cache = False
        enc.gradient_checkpointing_enable()
    enc.to(device)

    hidden = enc.config.hidden_size
    head = SimpleHead(hidden).to(device)

    # loss
    if args.weighted_loss:
        y_int = [1 if r["label"] == "FLAGGED" else 0 for r in tr_rows]
        w = class_weights_from_labels(y_int).to(device)
        crit = torch.nn.CrossEntropyLoss(weight=w)
    else:
        crit = torch.nn.CrossEntropyLoss()

    # optim / sched
    params = list(enc.parameters()) + list(head.parameters())
    optim = torch.optim.AdamW(params, lr=args.lr)
    n_batches = len(tr_dl)
    # One optimizer step per accumulation group, including a trailing partial
    # group, so the schedule length matches the steps actually taken below.
    total_steps = args.epochs * math.ceil(n_batches / args.grad_accum)
    sched = get_linear_schedule_with_warmup(
        optim,
        num_warmup_steps=max(1, int(0.06 * total_steps)),
        num_training_steps=total_steps,
    )

    use_amp = args.fp16 and device == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    best_f1 = -1.0

    for ep in range(1, args.epochs + 1):
        # evaluate() leaves both modules in eval mode, so train mode (dropout)
        # has to be re-enabled at the start of every epoch, not just once.
        enc.train()
        head.train()

        running = 0.0
        optim.zero_grad(set_to_none=True)
        for i, batch in enumerate(tr_dl):
            input_ids = batch["input_ids"].to(device)
            attn = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                out = enc(input_ids=input_ids, attention_mask=attn)
                cls = out.last_hidden_state[:, 0, :]
                logits = head(cls)
                loss = crit(logits, labels)

            running += loss.item()
            # Scale down so accumulated gradients average over the group.
            scaler.scale(loss / args.grad_accum).backward()

            # Also step on the last batch: otherwise gradients of a trailing
            # partial group would be thrown away by the next zero_grad.
            is_last_batch = (i + 1) == n_batches
            if (i + 1) % args.grad_accum == 0 or is_last_batch:
                # grad clip must see unscaled gradients
                scaler.unscale_(optim)
                torch.nn.utils.clip_grad_norm_(params, 1.0)

                scaler.step(optim)
                scaler.update()
                optim.zero_grad(set_to_none=True)
                sched.step()

        print(f"[TRAIN] epoch={ep}/{args.epochs}  loss={running / max(1, n_batches):.4f}")

        # ---- eval each epoch ----
        acc, f1m, rep = evaluate(enc, head, te_dl, device)
        print(f"\n[EVAL] acc={acc:.4f}  f1_macro={f1m:.4f}")
        print(rep)

        if f1m > best_f1:
            best_f1 = f1m
            # save encoder + tokenizer + head
            tok.save_pretrained(args.out_dir)
            enc.save_pretrained(args.out_dir, safe_serialization=True)  # safetensors
            torch.save(head.state_dict(), os.path.join(args.out_dir, "cls_head.pt"))
            print(f"[SAVE] New best f1_macro={best_f1:.4f} → {args.out_dir}")

    # final inference dump (to keep previous behavior)
    # NOTE: this uses the weights from the last epoch, which are not
    # necessarily the best checkpoint saved above.
    preds_path = os.path.join("outputs", "preds_c.jsonl")
    # --out-dir may point outside "outputs", so make sure the directory exists.
    os.makedirs(os.path.dirname(preds_path), exist_ok=True)
    enc.eval()
    head.eval()
    with torch.no_grad(), open(preds_path, "w", encoding="utf-8") as f:
        for batch in te_dl:
            input_ids = batch["input_ids"].to(device)
            attn = batch["attention_mask"].to(device)
            out = enc(input_ids=input_ids, attention_mask=attn)
            cls = out.last_hidden_state[:, 0, :]
            logits = head(cls)
            # One device-to-host transfer per batch instead of one per row.
            probs = torch.softmax(logits, dim=-1)[:, 1].cpu().tolist()
            for p in probs:
                pred = "FLAGGED" if p >= 0.5 else "NOT FLAGGED"
                f.write(json.dumps({"pred": pred, "p_flagged": p}) + "\n")
    print("[DONE] Wrote outputs/preds_c.jsonl")

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
