import numpy as np
import os
import torch
import argparse
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))

@torch.no_grad()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--root_dir", type=str, required=True)
    parser.add_argument("--dataset_name", type=str, required=True, choices=['beavertails', 'real_toxicity', 'ultrasafety'])
    parser.add_argument("--split", type=str, required=True)

    # performance knobs (no sharding)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--prefix_batch_size", type=int, default=512)
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--clf_max_length", type=int, default=512)
    parser.add_argument("--prefix_stride", type=int, default=1)

    args = parser.parse_args()

    # ------------------ device & matmul modes ------------------
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    torch.backends.cuda.matmul.allow_tf32 = True

    # ------------------ tokenizers & model ------------------
    generator_tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    if generator_tokenizer.pad_token is None:
        generator_tokenizer.pad_token = generator_tokenizer.eos_token

    clf_model_name = "cardiffnlp/twitter-roberta-base-offensive"
    classifier = AutoModelForSequenceClassification.from_pretrained(clf_model_name).eval().to(device)
    clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name, use_fast=True)

    if args.compile:
        try:
            classifier = torch.compile(classifier, mode="max-autotune")
        except Exception:
            pass

    use_amp = True
    amp_dtype = torch.float32

    # ------------------ IO paths (original naming) ------------------
    data_dir = f"{args.root_dir}/{args.dataset_name}_llm_generated/{args.model_name.replace('/', '_')}"
    os.makedirs(data_dir, exist_ok=True)

    responses_path = os.path.join(data_dir, f"{args.split}_responses.pt")
    prompts_path   = os.path.join(data_dir, f"{args.split}_prompts.pt")
    masks_path     = os.path.join(data_dir, f"{args.split}_masks.pt")
    labels_path    = os.path.join(data_dir, f"{args.split}_is_safes.pt")

    if not (os.path.exists(responses_path) and os.path.exists(prompts_path) and os.path.exists(masks_path)):
        print(f"Data files not found in {data_dir} for split {args.split}")
        return

    responses = torch.load(responses_path)
    prompts   = torch.load(prompts_path)
    masks     = torch.load(masks_path)
    labels    = torch.load(labels_path)

    N = len(responses)
    all_prefix_scores = []
    all_final_scores = []
    label_matches = np.zeros(N)
    label_indices = np.full(N, np.nan)

    batch_size = args.batch_size
    stride = max(1, args.prefix_stride)

    for i in tqdm(range(0, N, batch_size), desc=f"Processing {args.split}"):
        bi = range(i, min(i + batch_size, N))
        batch_prompts   = [prompts[j]   for j in bi]
        batch_responses = [responses[j] for j in bi]
        batch_masks     = [masks[j]     for j in bi]

        # ---- Fast prefix construction via offset mapping ----
        all_prefixes_in_batch = []
        num_prefixes_per_response = []

        for prompt, response, mask in zip(batch_prompts, batch_responses, batch_masks):
            num_response_tokens = int(mask.sum().item())
            if num_response_tokens == 0 or len(response) == 0:
                num_prefixes_per_response.append(0)
                continue

            enc = generator_tokenizer(
                response,
                return_offsets_mapping=True,
                add_special_tokens=False
            )
            offsets = enc["offset_mapping"]
            if len(offsets) > num_response_tokens:
                offsets = offsets[:num_response_tokens]

            resp_prefixes = []
            for t in range(1, len(offsets) + 1, stride):
                end_char = offsets[t - 1][1]
                prefix = response[:end_char] or "."
                resp_prefixes.append(prefix)

            num_prefixes_per_response.append(len(resp_prefixes))
            all_prefixes_in_batch.extend(resp_prefixes)

        if not all_prefixes_in_batch:
            all_prefix_scores.extend([[] for _ in bi])
            continue

        # ---- Classify prefixes in large micro-batches ----
        final_scores_flat = []
        for k in range(0, len(all_prefixes_in_batch), args.prefix_batch_size):
            chunk = all_prefixes_in_batch[k : k + args.prefix_batch_size]
            enc = clf_tokenizer(
                chunk,
                return_tensors="pt",
                truncation=True,
                max_length=args.clf_max_length,
                padding=True
            )
            enc = {k2: v.to(device, non_blocking=True) for k2, v in enc.items()}
            with torch.cuda.amp.autocast(enabled=use_amp, dtype=amp_dtype):
                logits = classifier(**enc).logits

            # p(safe) = 1 - sigmoid(logit_bad - logit_safe)
            logit_safe = logits[:, 0]
            logit_bad  = logits[:, 1]
            p_bad = torch.sigmoid(logit_bad - logit_safe)
            p_safe = (1.0 - p_bad).detach().float().cpu().tolist()
            final_scores_flat.extend(p_safe)

        # ---- Scatter back, compute unsafe index ----
        pos = 0
        for j, npfx in enumerate(num_prefixes_per_response):
            scores = final_scores_flat[pos : pos + npfx]
            all_prefix_scores.append(scores)
            pos += npfx

            global_idx = i + j
            if scores:
                unsafe_idx = next((idx for idx, s in enumerate(scores) if s < 0.5), np.nan)
                label_indices[global_idx] = unsafe_idx
                if np.isnan(unsafe_idx):
                    label_matches[global_idx] = (labels[global_idx] >= 0.5)
                else:
                    label_matches[global_idx] = (scores[int(unsafe_idx)] < 0.5)
            all_final_scores.append(scores[-1] if scores else 1.0)

    # ------------------ Save with original filenames ------------------
    save_path = os.path.join(data_dir, f"{args.split}_prefix_scores.pt")
    torch.save(all_prefix_scores, save_path)
    torch.save(all_final_scores, os.path.join(data_dir, f"{args.split}_final_scores.pt"))

    label_indices_save_path = os.path.join(data_dir, f"{args.split}_unsafe_indices.pt")
    torch.save(label_indices, label_indices_save_path)

    print(f"Saved prefix scores to {save_path}", flush=True)
    print(f"Saved label indices to {label_indices_save_path}", flush=True)
    print(f"Label matches: {label_matches.mean():.6f}", flush=True)
    finite = np.isfinite(label_indices)
    if finite.any():
        print(f"Average unsafe index: {np.nanmean(label_indices):.2f} +/- {np.nanstd(label_indices):.2f}", flush=True)

if __name__ == "__main__":
    main()
