import os
import torch
import argparse
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoModelForSequenceClassification
import torch.nn.functional as F

# Per-checkpoint integer dimension, keyed by Hugging Face model id.
# NOTE(review): presumably the hidden-state width consumed by a downstream
# probe; nothing in the visible part of this file reads it -- confirm with
# the consumer before changing values.
INPUT_DIM_MAP = {
    # 4096-wide checkpoints
    "mistralai/Ministral-8B-Instruct-2410": 4096,
    "meta-llama/Llama-2-7b-hf": 4096,
    # 1536-wide checkpoint
    "Qwen/Qwen2-1.5B": 1536,
    # 4096-wide checkpoints (continued; insertion order kept as-is)
    "EleutherAI/pythia-6.9b": 4096,
    "tiiuae/falcon-7b-instruct": 4096,
}

def _parse_args():
    """Parse and validate the command-line options for one shard of the job."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--split", type=str, required=True)
    parser.add_argument("--layer_idx", type=int, default=20)
    parser.add_argument("--num_shards", type=int, default=1)
    parser.add_argument("--shard_id", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=12)
    parser.add_argument("--root_dir", type=str, required=True)
    args = parser.parse_args()

    # Fail early with a readable message instead of a ZeroDivisionError /
    # ValueError (or a silently wrong slice for negative ids) further down.
    if args.num_shards < 1:
        parser.error("--num_shards must be >= 1")
    if not 0 <= args.shard_id < args.num_shards:
        parser.error("--shard_id must satisfy 0 <= shard_id < num_shards")
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")
    return args


def _shard_bounds(total_len, num_shards, shard_id):
    """Return the half-open ``(start, end)`` range of ``shard_id`` using ceil-sized shards."""
    per_shard = (total_len + num_shards - 1) // num_shards
    start = min(shard_id * per_shard, total_len)
    return start, min(start + per_shard, total_len)


@torch.no_grad()
def main():
    """Generate responses for one shard of BeaverTails prompts and save features.

    For every prompt in the shard the script:

    1. greedily generates up to 128 new tokens with ``--model_name``;
    2. re-encodes ``prompt + "\\n" + response`` and takes the hidden states of
       layer ``--layer_idx`` from the same model;
    3. scores the response text with ``cardiffnlp/twitter-roberta-base-offensive``
       and derives a binary label (``1`` when the class-1 probability is < 0.5).

    Samples whose re-encoded text has no tokens beyond the prompt are skipped.

    Six files are written to
    ``<root_dir>/beavertails_llm_generated/<model_name with '/' -> '_'>/``,
    each named ``<split>_<kind>_shard<shard_id>.pt``: ``prompts``,
    ``responses``, ``embeddings`` (response-token hidden states, float32, CPU),
    ``masks`` (0/1 over the valid tokens, 1 on response positions), ``labels``
    and ``logits`` (the per-sample class-1 probabilities).

    The whole function runs under ``torch.no_grad()``.
    """
    args = _parse_args()

    # Map shards round-robin onto the visible GPUs. Without any GPU the modulo
    # below would divide by zero, so fall back to CPU explicitly.
    if torch.cuda.is_available() and torch.cuda.device_count() > 0:
        device = torch.device(f"cuda:{args.shard_id % torch.cuda.device_count()}")
        torch.cuda.set_device(device)
    else:
        print(f"[Shard {args.shard_id}] No CUDA device found; running on CPU")
        device = torch.device("cpu")
    on_cuda = device.type == "cuda"

    save_dir = f"{args.root_dir}/beavertails_llm_generated/{args.model_name.replace('/', '_')}"
    os.makedirs(save_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # Left padding keeps every prompt flush against the generated tokens, so
    # the continuation of each row starts at the same column.
    tokenizer.padding_side = "left"

    bf16_models = ("openai/gpt-oss-20b", "tiiuae/falcon-7b-instruct")
    dtype = torch.bfloat16 if args.model_name in bf16_models else torch.float32
    # NOTE(security): trust_remote_code=True executes Python shipped in the
    # model repository; only use it with model ids you trust.
    generator = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=dtype,
        trust_remote_code=True,
    ).eval().to(device)

    # Response-level classifier used to derive the labels.
    clf_model_name = "cardiffnlp/twitter-roberta-base-offensive"
    classifier = AutoModelForSequenceClassification.from_pretrained(clf_model_name).eval().to(device)
    clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)

    dataset = load_dataset("PKU-Alignment/BeaverTails", split=args.split)
    # Column access avoids materialising every full row just to read one field.
    prompts = list(dataset["prompt"])
    start_idx, end_idx = _shard_bounds(len(prompts), args.num_shards, args.shard_id)
    prompt_subset = prompts[start_idx:end_idx]

    batch_size = args.batch_size
    max_new_tokens = 128
    max_length = 512
    # The original script disables the KV cache during generation for Falcon only.
    generation_use_cache = args.model_name != "tiiuae/falcon-7b-instruct"

    embeddings, masks, labels = [], [], []
    all_prompts, all_responses, all_scores = [], [], []

    for batch_start in tqdm(range(0, len(prompt_subset), batch_size), desc=f"Shard {args.shard_id}"):
        batch_prompts = prompt_subset[batch_start:batch_start + batch_size]

        # ----- Generation -----
        input_enc = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(device)
        prompt_width = input_enc.input_ids.shape[1]
        # Token count of each (possibly truncated) prompt; reused below instead
        # of re-tokenizing every prompt once more per sample.
        prompt_lens = input_enc.attention_mask.sum(dim=1).tolist()

        outputs = generator.generate(
            input_ids=input_enc.input_ids,
            attention_mask=input_enc.attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            output_scores=False,
            use_cache=generation_use_cache,
        )

        # Decode only the newly generated ids. Slicing the decoded string by
        # len(prompt) breaks whenever the prompt was truncated or the
        # tokenizer round-trip does not reproduce the prompt text exactly.
        new_tokens = outputs.sequences[:, prompt_width:]
        responses = [
            text.strip()
            for text in tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
        ]
        full_texts = [p + "\n" + r for p, r in zip(batch_prompts, responses)]

        # Free generation intermediates before the heavier featurizer pass.
        del outputs, new_tokens, input_enc
        if on_cuda:
            torch.cuda.empty_cache()

        # ----- Featurizer pass (same model), padded only to the longest row -----
        enc = tokenizer(
            full_texts,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=max_length,
        ).to(device)

        # Temporarily disable the KV cache for this forward pass; restore the
        # setting even if the forward pass raises.
        orig_use_cache = generator.config.use_cache
        generator.config.use_cache = False
        try:
            feats = generator(
                input_ids=enc.input_ids,
                attention_mask=enc.attention_mask,
                output_hidden_states=True,
            )
        finally:
            generator.config.use_cache = orig_use_cache

        h = feats.hidden_states[args.layer_idx]  # indexed per sample below

        for i, response in enumerate(responses):
            token_mask = enc.attention_mask[i].bool()
            valid_len = int(token_mask.sum().item())
            prompt_len = prompt_lens[i]

            # Nothing beyond the prompt survived (empty response or the prompt
            # alone fills max_length): skip before doing any further work.
            if prompt_len >= valid_len:
                continue

            # NOTE: prompt_len comes from tokenizing the prompt alone, so the
            # prompt/response boundary is approximate if the tokenizer merges
            # tokens across the inserted "\n".
            response_mask = torch.zeros(valid_len, dtype=torch.long)
            response_mask[prompt_len:] = 1

            h_i = h[i][token_mask].to(torch.float32).cpu()

            clf_input = clf_tokenizer(
                response,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512,
            ).to(device)
            logits = classifier(**clf_input).logits
            # Probability of classifier class index 1 (presumably "offensive"
            # for this checkpoint -- confirm against its model card).
            score = torch.softmax(logits, dim=-1)[0][1].item()
            label = int(score < 0.5)

            # Boolean indexing copies, so only the response rows are kept
            # alive (and serialized) rather than the whole h_i storage.
            embeddings.append(h_i[response_mask.bool()])
            masks.append(response_mask)
            labels.append(label)
            all_responses.append(response)
            all_prompts.append(batch_prompts[i])
            all_scores.append(score)

            del h_i

        del enc, feats, h
        if on_cuda:
            torch.cuda.empty_cache()

    # "logits" keeps its historical file name but now holds the collected
    # scores; previously a never-populated list was written here.
    artifacts = {
        "prompts": all_prompts,
        "responses": all_responses,
        "embeddings": embeddings,
        "masks": masks,
        "labels": labels,
        "logits": all_scores,
    }
    for kind, payload in artifacts.items():
        torch.save(payload, os.path.join(save_dir, f"{args.split}_{kind}_shard{args.shard_id}.pt"))

    print(f"[Shard {args.shard_id}] Saved outputs to {save_dir}")

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
