import json
import random
import torch
import os
import argparse
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import warnings
import numpy as np
import time
import gc

# Process-wide runtime settings, applied once when this module is executed.
# Suppress every Python warning globally (this hides library deprecation and
# configuration warnings as well, so re-enable them when debugging).
warnings.filterwarnings("ignore")
# Presumably turns off the tokenizers library's own thread pool — TODO confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# NOTE(review): both CUDA-related variables are set after `import torch`;
# presumably they are only read when CUDA is first initialised — confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

# Model mapping: display name -> identifier passed to `from_pretrained`.
# RMU_MODELS lists the checkpoints to evaluate; ORIGINAL_MODELS lists the
# baseline for each of them. main() iterates over RMU_MODELS and looks up the
# same key in ORIGINAL_MODELS, so every key here must exist in both dicts.
# NOTE(review): the active entry's name suggests an NPO checkpoint rather than
# an RMU one — confirm the naming is intentional.
RMU_MODELS = {
    "Zephyr-7b": "shengyuanhu/benchmark_wmdp_npo_kl_ckpt_75",
    # Alternative checkpoints (swap one in for the entry above):
    # "Zephyr-7b": "OPTML-Group/SimNPO-WMDP-zephyr-7b-beta"
    # "Zephyr-7b": "shengyuanhu/wmdp_unlearn_rmu_150_1200_6.5_zephyr",
}
ORIGINAL_MODELS = {
    "Zephyr-7b": "HuggingFaceH4/zephyr-7b-beta",
}

# Default parameters
# NOTE(review): in the code visible in this file only DEFAULT_K_LIST and
# WMDP_DATASET_PATH are read; the other DEFAULT_* values are not referenced
# (the __main__ block hard-codes its own settings) — confirm before relying
# on them.
# Fallback list of k values used by main() when args.k_list is empty.
DEFAULT_K_LIST = [5, 10]
DEFAULT_SEED = 42
DEFAULT_NUM_SAMPLES = 96
DEFAULT_TEMPERATURE = 0.2
DEFAULT_ORIGINAL_TEMPERATURE = 0.0
DEFAULT_BATCH_SIZE = 32
DEFAULT_MAX_TOKENS = 256
# JSON dataset that main() hands to load_wmdp_data(). The value looks like a
# placeholder — presumably it must be edited to the real location before use.
WMDP_DATASET_PATH = "/path/to/dataset/bio.json"

def freeze_seeds_for_sampling(seed):
    """Seed the global Python ``random`` and NumPy generators with ``seed``.

    Only these two generators are touched, which makes the data sampling
    repeatable; no torch seed is set here.
    """
    print(f"🔒 Freezing Python and NumPy random seeds to {seed} for reproducible data sampling...")
    # Same order as before: Python's generator first, then NumPy's.
    for apply_seed in (random.seed, np.random.seed):
        apply_seed(seed)
    print(f"✅ Python and NumPy random seeds frozen to {seed} (for data sampling only)")

def get_num_gpus():
    """Return the number of CUDA devices reported by PyTorch."""
    visible_gpu_count = torch.cuda.device_count()
    return visible_gpu_count

def load_wmdp_data(dataset_path, num_samples, seed):
    """Load the WMDP JSON file and draw a reproducible random subset of it.

    Args:
        dataset_path: Path to a JSON file whose top level is a sequence of
            items; each item is read through ``.get('question')`` and
            ``.get('choices')``.
        num_samples: Maximum number of items to draw. If the file holds fewer
            items, all of them are used.
        seed: Seed applied to the Python and NumPy RNGs before sampling.

    Returns:
        A ``(prompts, indices, sampled_data)`` tuple of equal-length lists:
        the prompt text per sampled item, the item's position in the file,
        and the raw item itself.
    """
    print(f"📂 Loading WMDP data from {dataset_path}")
    freeze_seeds_for_sampling(seed)
    # JSON is UTF-8 by specification; do not depend on the platform's locale
    # encoding, which would mis-decode non-ASCII questions on some systems.
    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    sample_size = min(num_samples, len(data))
    indices = random.sample(range(len(data)), sample_size)
    sampled_data = [data[i] for i in indices]

    prompts = []
    for item in sampled_data:
        question = item.get('question', '').strip()
        choices = item.get('choices', [])
        if choices:
            # Label the options A., B., C., ... one per line below the question.
            choices_text = '\n'.join(
                f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)
            )
            prompts.append(f"{question}\n{choices_text}")
        else:
            prompts.append(question)
    print(f"✅ Loaded {len(prompts)} prompts")
    return prompts, indices, sampled_data

def load_transformers_model_optimized(model_name, dtype=torch.bfloat16, device="cuda",
                                      tokenizer_name="HuggingFaceH4/zephyr-7b-beta"):
    """Load a causal language model for inference plus a generation-ready tokenizer.

    Args:
        model_name: Identifier or path handed to
            ``AutoModelForCausalLM.from_pretrained``.
        dtype: Torch dtype used for the model weights.
        device: When exactly ``"cuda"`` the model is loaded with
            ``device_map="auto"`` and gradient checkpointing is disabled;
            any other value loads the model without a device map.
        tokenizer_name: Identifier or path handed to
            ``AutoTokenizer.from_pretrained``. The tokenizer is chosen
            independently of ``model_name``; the default is the
            Zephyr-7b-beta tokenizer.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    print(f"🚀 Loading optimized Transformers model: {model_name}")
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # repository; only point this at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=tokenizer_name,
        trust_remote_code=True,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Batched generation with a decoder-only model needs the padding on the
    # left so that every prompt ends right where generation starts. Setting it
    # explicitly is a no-op for tokenizers that already default to left.
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map="auto" if device == "cuda" else None,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    if device == "cuda":
        model.gradient_checkpointing_disable()
    model.eval()
    print(f"✅ Optimized model loaded successfully: {model_name}")
    return model, tokenizer

def format_prompts_for_transformers(prompts):
    """Wrap each prompt's question stem in a ``Question: ...\\nAnswer:`` template.

    Everything from the first ``"\\nA."`` onwards (the multiple-choice options)
    is dropped; a prompt without that marker is used whole. The stem is
    stripped of surrounding whitespace before it is inserted.
    """
    def _question_stem(text):
        # partition() returns the whole text as the head when the marker is
        # absent, so one expression covers both cases.
        stem, _, _ = text.partition('\nA.')
        return stem.strip()

    return [f"Question: {_question_stem(text)}\nAnswer:" for text in prompts]

def generate_batch_responses_unfrozen(model, tokenizer, prompts, temperature, max_tokens, batch_size, top_p, device="cuda"):
    """
    Generate one response per prompt, running the model batch by batch.

    NOTE: The seed is NOT set/reset before generation, so results are stochastic
    whenever sampling is active (temperature > 0).

    Args:
        model: Object exposing ``generate(**inputs, generation_config=...)``.
        tokenizer: Callable tokenizer that also provides ``decode`` and
            ``eos_token_id``.
        prompts: Raw prompts; they are passed through
            ``format_prompts_for_transformers`` before tokenization.
        temperature: Sampling temperature. Values <= 0 select greedy decoding.
        max_tokens: Upper bound on newly generated tokens per prompt.
        batch_size: Number of prompts per ``generate`` call; must be >= 1.
        top_p: Nucleus-sampling threshold; only used when sampling.
        device: Inputs are moved to the device only when this is exactly
            ``"cuda"``.

    Returns:
        A list of response strings aligned with ``prompts``. Every prompt of a
        batch that raised during generation is represented by ``""``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A negative step would make the batch loop silently produce nothing,
        # leaving callers with fewer responses than prompts.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    print(f"🚀 Generating {len(prompts)} responses in batches of {batch_size} (seed NOT frozen for generation)")
    formatted_prompts = format_prompts_for_transformers(prompts)
    all_responses = []

    # One config for both modes. Sampling parameters are only supplied when
    # sampling is active: in greedy mode they are unused, and temperature=0 is
    # outside the range the generation config documents as valid.
    sampling = temperature > 0
    config_kwargs = dict(
        repetition_penalty=1.1,
        max_new_tokens=max_tokens,
        do_sample=sampling,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )
    if sampling:
        config_kwargs.update(temperature=temperature, top_p=top_p)
    generation_config = GenerationConfig(**config_kwargs)

    for i in tqdm(range(0, len(formatted_prompts), batch_size), desc="Generating batches"):
        batch_prompts = formatted_prompts[i:i + batch_size]
        try:
            inputs = tokenizer(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            if device == "cuda":
                inputs = {k: v.to(device) for k, v in inputs.items()}
            # DO NOT set any seed here!
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    generation_config=generation_config,
                    return_dict_in_generate=False,
                    output_scores=False,
                )
            # All rows of a padded batch share one width, so the generated
            # continuation of every row starts at the same column.
            prompt_width = inputs['input_ids'].shape[1]
            batch_responses = []
            for output in outputs:
                response = tokenizer.decode(output[prompt_width:], skip_special_tokens=True)
                # Remove template markers the model may have echoed back.
                response = response.replace("Answer:", "").replace("Question:", "").strip()
                batch_responses.append(response)
            all_responses.extend(batch_responses)
        except Exception as e:
            # Deliberate best effort: a failing batch is reported and replaced
            # by empty strings so the output stays aligned with the prompts.
            print(f"❌ Batch generation failed: {str(e)[:100]}...")
            all_responses.extend([""] * len(batch_prompts))
        # Release cached CUDA blocks every fourth batch.
        if i % (batch_size * 4) == 0 and device == "cuda":
            torch.cuda.empty_cache()

    success_rate = sum(1 for r in all_responses if r) / len(all_responses) if all_responses else 0
    print(f"📊 Generation complete: {success_rate:.1%} success rate")
    return all_responses

def generate_k_responses_unfrozen(model, tokenizer, prompts, k, temperature, max_tokens, batch_size, top_p, device="cuda"):
    """
    Run ``k`` independent generation rounds over the full prompt list.

    NOTE: The seed is NOT set/reset before generation, so results are stochastic.

    Returns a list with ``k`` entries; entry ``r`` is the list of responses
    produced in round ``r`` by ``generate_batch_responses_unfrozen``.
    """
    print(f"🎯 Generating {k} responses for {len(prompts)} prompts (seed NOT frozen for generation)")
    rounds = []
    for round_idx in range(k):
        print(f"   Round {round_idx + 1}/{k}")
        # DO NOT set any seed here!
        round_output = generate_batch_responses_unfrozen(
            model, tokenizer, prompts, temperature, max_tokens, batch_size, top_p, device
        )
        rounds.append(round_output)

        # Housekeeping only on rounds 0, 5, 10, ...
        if round_idx % 5 != 0:
            continue
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
    return rounds

def process_model_unfrozen(model_name, rmu_path, original_path, prompts, indices, dataset, k_list, temperature, original_temperature, max_tokens, batch_size, responses_dir, top_p, device="cuda"):
    """Generate and save responses for the original model and its RMU counterpart.

    The original model produces one response per prompt at
    ``original_temperature``. The RMU model produces ``k`` responses per prompt
    for every ``k`` in ``k_list`` (temperature 0.0 when ``k == 1``, otherwise
    ``temperature``). Each model is loaded, used and released in turn.

    Two files are written into ``responses_dir``:
    ``original_{model_name}.json`` and ``rmu_{model_name}.json``.

    Returns:
        ``{'original': {...}, 'rmu': {...}}``, both keyed by the dataset index
        of each prompt; the RMU entries are further keyed by ``k``.
    """
    def _free_memory():
        # Collect garbage and, on CUDA, hand cached blocks back to the driver.
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

    print(f"\n{'='*70}")
    print(f"🚀 Generating responses for {model_name} (seed NOT frozen for generation)")
    print(f"{'='*70}")

    # ---- Original model: a single pass over the prompts ----
    print(f"\n📋 Loading original model: {original_path}")
    orig_model, orig_tokenizer = load_transformers_model_optimized(original_path, device=device)
    print("🎯 Generating original model responses...")
    baseline_outputs = generate_batch_responses_unfrozen(
        orig_model, orig_tokenizer, prompts,
        temperature=original_temperature, max_tokens=max_tokens, batch_size=batch_size, top_p=top_p, device=device
    )
    orig_responses = {
        indices[pos]: {
            "prompt": prompts[pos],
            "response": text,
            "wmdp_idx": indices[pos],
            "wmdp_data": dataset[pos]
        }
        for pos, text in enumerate(baseline_outputs)
    }
    with open(f"{responses_dir}/original_{model_name}.json", "w") as out_file:
        json.dump(orig_responses, out_file, indent=2)
    del orig_model, orig_tokenizer
    _free_memory()

    # ---- RMU model: k rounds for every requested k ----
    print(f"\n📋 Loading RMU model: {rmu_path}")
    rmu_model, rmu_tokenizer = load_transformers_model_optimized(rmu_path, device=device)
    print("🎯 Generating RMU model responses...")
    rmu_responses = {}
    for k in k_list:
        started_at = time.time()
        print(f"\n🎯 Processing k={k}...")
        # A single response per prompt is always generated at temperature 0.0.
        round_temperature = temperature if k != 1 else 0.0
        rounds = generate_k_responses_unfrozen(
            rmu_model, rmu_tokenizer, prompts, k, round_temperature, max_tokens, batch_size, top_p, device
        )
        print(f"   Organizing {len(prompts)} prompts × {k} responses...")
        for prompt_idx, prompt_text in enumerate(prompts):
            wmdp_idx = indices[prompt_idx]
            # Transpose: collect this prompt's answer from each round, padding
            # with "" if a round came back short.
            per_round = [
                rounds[r][prompt_idx] if prompt_idx < len(rounds[r]) else ""
                for r in range(k)
            ]
            rmu_responses.setdefault(wmdp_idx, {})[k] = {
                "responses": per_round,
                "prompt": prompt_text,
                "wmdp_idx": wmdp_idx,
                "wmdp_data": dataset[prompt_idx]
            }
        elapsed = time.time() - started_at
        print(f"✅ Completed k={k} in {elapsed:.1f}s")
    with open(f"{responses_dir}/rmu_{model_name}.json", "w") as out_file:
        json.dump(rmu_responses, out_file, indent=2)
    del rmu_model, rmu_tokenizer
    _free_memory()

    return {'original': orig_responses, 'rmu': rmu_responses}

def main(args):
    """Run the full response-generation pipeline for every configured model.

    Steps: resolve the k values and output directory, print the run
    configuration, sample the prompts, write ``metadata.json``, then generate
    and save responses for each entry of ``RMU_MODELS``.

    Args:
        args: Namespace providing ``seed``, ``temperature``, ``top_p``,
            ``original_temperature``, ``responses_dir`` (``None`` selects an
            auto-generated directory name and is then overwritten on
            ``args``), ``num_samples``, ``batch_size``, ``max_tokens``,
            ``k_list`` (comma-separated string, empty for the default list)
            and ``device``.
    """
    if args.k_list:
        k_list = [int(k.strip()) for k in args.k_list.split(',')]
    else:
        k_list = DEFAULT_K_LIST
    if args.responses_dir is None:
        args.responses_dir = f"Passk/wmdp_rmu_zephyr_responses_seed{args.seed}_temp{args.temperature}_samples{args.num_samples}"

    print("🚀 PARAMETERIZED WMDP RESPONSE GENERATION (seed NOT frozen for generation)")
    print("=" * 60)
    print(f"🌱 Seed (for data sampling): {args.seed}")
    print(f"🌡️  Temperature: {args.temperature}")
    print(f"🌡️  Original Temperature: {args.original_temperature}")
    print(f"🎲 Top-p: {args.top_p}")
    print(f"📁 Output Directory: {args.responses_dir}")
    print(f"📊 Number of Samples: {args.num_samples}")
    print(f"🔢 Batch Size: {args.batch_size}")
    print(f"📝 Max Tokens: {args.max_tokens}")
    print(f"🎯 K Values: {k_list}")
    print(f"💻 Device: {args.device}")
    print("=" * 60)
    start_time = time.time()
    num_gpus = get_num_gpus()
    print(f"🚀 Using {num_gpus} GPUs" if args.device == "cuda" else f"🚀 Using {args.device}")
    os.makedirs(args.responses_dir, exist_ok=True)
    prompts, indices, dataset = load_wmdp_data(WMDP_DATASET_PATH, args.num_samples, args.seed)

    # Record every setting that influences generation, including top_p, so a
    # run can be reconstructed from its metadata file alone.
    metadata = {
        "prompts": prompts,
        "indices": indices,
        "wmdp_data": dataset,
        "k_list": k_list,
        "num_samples": len(prompts),
        "seed": args.seed,
        "temperature": args.temperature,
        "original_temperature": args.original_temperature,
        "top_p": args.top_p,
        "batch_size": args.batch_size,
        "max_tokens": args.max_tokens,
        "device": args.device,
        "dataset_path": WMDP_DATASET_PATH,
        "gpu_info": {
            "num_gpus": num_gpus,
            "device": args.device,
        }
    }
    with open(f"{args.responses_dir}/metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)

    for model_name in RMU_MODELS:
        rmu_path = RMU_MODELS[model_name]
        original_path = ORIGINAL_MODELS[model_name]
        print(f"\n🎯 Starting {model_name}...")
        model_start_time = time.time()
        # The responses are persisted inside process_model_unfrozen; its
        # return value is not needed here.
        process_model_unfrozen(
            model_name, rmu_path, original_path,
            prompts, indices, dataset, k_list,
            args.temperature, args.original_temperature,
            args.max_tokens, args.batch_size, args.responses_dir, args.top_p, args.device
        )
        model_time = time.time() - model_start_time
        print(f"✅ {model_name} completed in {model_time:.1f}s")
    total_time = time.time() - start_time
    print(f"\n🎉 All models completed in {total_time:.1f}s")
    print(f"📁 Results saved to: {args.responses_dir}")

if __name__ == "__main__":
    # Script entry point: run main() once for every (temperature, top_p)
    # combination of the two lists below, each into its own output directory.
    # NOTE(review): temperature 0.0 selects the greedy branch of
    # generate_batch_responses_unfrozen, so the 200 rounds requested through
    # k_list are presumably near-identical — confirm this is intended.
    temperature_list = [0.0]
    top_p_list = [0.0]
    sweep = [(t, p) for t in temperature_list for p in top_p_list]
    for temperature, top_p in sweep:
        run_settings = {
            "seed": 42,
            "temperature": temperature,
            "top_p": top_p,
            "original_temperature": 0.0,
            "responses_dir": f"Passk_unfrozen_npo/seed42_temp{temperature}_top_p{top_p}",
            "num_samples": 96,
            "batch_size": 32,
            "max_tokens": 256,
            "k_list": "200",
            "device": "cuda",
        }
        args = argparse.Namespace(**run_settings)
        main(args)