#!/usr/bin/env python3
import argparse, os, json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

def load_jsonl(p):
    """Read a JSON Lines file and return its records as a list.

    Blank / whitespace-only lines (e.g. a trailing empty line) are skipped
    instead of being handed to ``json.loads``, which would reject them. The
    file is decoded as UTF-8; a leading byte-order mark, if present, is
    dropped so the first record still parses.

    Args:
        p: path to the ``.jsonl`` file.

    Returns:
        List of decoded JSON values, one per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(p, "r", encoding="utf-8-sig") as f:
        return [json.loads(ln) for ln in f if ln.strip()]

class JsonlDataset(Dataset):
    """Map-style dataset that tokenizes one JSONL row per access.

    Each row is expected to carry ``"text"``, ``"text_id"`` and ``"label"``;
    ``"dataset"`` is optional and falls back to an empty string.
    """

    def __init__(self, rows, tok, max_len=128):
        # rows: sequence of dicts; tok: tokenizer callable; max_len: pad/truncate length.
        self.rows, self.tok, self.max_len = rows, tok, max_len

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        row = self.rows[i]
        encoded = self.tok(
            row["text"],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dim of size 1; drop it so
        # the DataLoader can stack samples itself.
        sample = {}
        for key, tensor in encoded.items():
            sample[key] = tensor.squeeze(0)
        # Pass-through metadata, later written next to each prediction.
        sample.update(
            text_id=row["text_id"],
            dataset=row.get("dataset", ""),
            true=row["label"],
        )
        return sample

class SimpleRobertaHead(torch.nn.Module):
    """Two-way classification head: dropout followed by a linear projection.

    The attribute names ``dropout`` and ``head`` fix the state-dict keys
    (``head.weight`` / ``head.bias``) that checkpoints are loaded into.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=0.1)
        self.head = torch.nn.Linear(in_features=hidden_size, out_features=2)

    def forward(self, cls):
        """Project features whose last dim is ``hidden_size`` to 2 logits."""
        regularized = self.dropout(cls)
        return self.head(regularized)

def _to_jsonable(value):
    """Return *value* in a form ``json.dumps`` accepts.

    DataLoader's default collate turns numeric per-sample fields (e.g. integer
    labels or ids) into tensors, and indexing such a batch yields tensors that
    ``json`` cannot serialize; unwrap those to plain Python values. Anything
    else (e.g. strings) is returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        return value.item() if value.dim() == 0 else value.tolist()
    return value


def _load_head_state(head, path, device):
    """Load classifier-head weights from *path* into *head*.

    Accepts a state dict keyed ``head.weight`` / ``head.bias`` or a bare
    ``weight`` / ``bias`` dict, which is remapped onto ``head.*``.

    Raises:
        RuntimeError: if any parameter of *head* is absent from the checkpoint.
            Loading with ``strict=False`` alone would leave such parameters
            randomly initialized and silently produce garbage predictions.
    """
    # NOTE(review): torch.load unpickles the file; only point --model-dir at
    # trusted checkpoints.
    sd = torch.load(path, map_location=device)
    if "weight" in sd and "head.weight" not in sd:
        sd = {"head.weight": sd["weight"], "head.bias": sd["bias"]}
    result = head.load_state_dict(sd, strict=False)
    if result.missing_keys:
        raise RuntimeError(
            f"{path}: checkpoint is missing head parameters {result.missing_keys}"
        )


def main():
    """Score a JSONL test set with the fine-tuned encoder + classification head.

    Writes one JSON object per input row to ``--out`` containing the row's
    ``text_id``, ``dataset`` and ``true`` label, the predicted label
    (``"FLAGGED"`` when the class-1 probability is >= ``--thresh``, otherwise
    ``"NOT FLAGGED"``) and that probability as ``p_flagged``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--test", required=True)
    ap.add_argument("--model-dir", default="outputs/roberta_cls")
    ap.add_argument("--out", default="outputs/preds_c.jsonl")
    ap.add_argument("--thresh", type=float, default=0.5)
    ap.add_argument("--batch-size", type=int, default=2)     # memory-safe default
    ap.add_argument("--max-len", type=int, default=128)      # match training
    ap.add_argument("--device", choices=["cuda","cpu","auto"], default="auto")
    ap.add_argument("--fp16", action="store_true", help="half precision on CUDA")
    args = ap.parse_args()

    # pick device
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    use_amp = device == "cuda" and args.fp16

    # load encoder + tokenizer + head
    tok = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    enc = AutoModel.from_pretrained(args.model_dir, use_safetensors=True)
    if device == "cuda":
        dtype = torch.float16 if args.fp16 else torch.float32
        enc = enc.to(device=device, dtype=dtype)
    else:
        enc = enc.to(device)

    hidden = enc.config.hidden_size
    head = SimpleRobertaHead(hidden).to(device)  # stays float32 even with --fp16
    _load_head_state(head, os.path.join(args.model_dir, "cls_head.pt"), device)

    enc.eval()
    head.eval()
    torch.set_grad_enabled(False)  # inference only for the rest of the script

    rows = load_jsonl(args.test)
    ds = JsonlDataset(rows, tok, max_len=args.max_len)
    dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False)

    # Create the directory that will actually hold --out (if it has one).
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(args.out, "w", encoding="utf-8") as f:
        for batch in dl:
            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attn = batch["attention_mask"].to(device, non_blocking=True)

            if use_amp:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    out = enc(input_ids=input_ids, attention_mask=attn)
            else:
                out = enc(input_ids=input_ids, attention_mask=attn)

            # With --fp16 the encoder emits float16 activations while the head
            # is float32 and runs outside autocast; a float16 input to a
            # float32 Linear raises a dtype mismatch, so upcast first.
            cls = out.last_hidden_state[:, 0, :].float()
            logits = head(cls)
            # Column 1 is the FLAGGED class; one host transfer per batch.
            probs = torch.softmax(logits, dim=-1)[:, 1].cpu().tolist()

            for text_id, dataset, true, p in zip(
                batch["text_id"], batch["dataset"], batch["true"], probs
            ):
                pred = "FLAGGED" if p >= args.thresh else "NOT FLAGGED"
                f.write(json.dumps({
                    "text_id": _to_jsonable(text_id),
                    "dataset": _to_jsonable(dataset),
                    "true": _to_jsonable(true),
                    "pred": pred,
                    "p_flagged": p
                }) + "\n")

            # release per-batch tensors and cached CUDA blocks (memory-safety
            # choice; costs some speed per batch)
            del out, cls, logits
            if device == "cuda":
                torch.cuda.empty_cache()

    print(f"[DONE] wrote {args.out}")


if __name__ == "__main__":
    main()
