import json
import random
import torch
import os
import gc
import time
import argparse
from tqdm import tqdm
from vllm import LLM, SamplingParams
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
import numpy as np
import warnings

# Suppress noisy Transformers warnings that fire on every generation call.
# Presumably the first comes from greedy decoding being called with
# temperature=0.0 / do_sample=False, and the second from tokenizing with
# truncation=True -- TODO confirm against the installed transformers version.
warnings.filterwarnings("ignore", message=".*generation flags are not valid.*")
warnings.filterwarnings("ignore", message=".*truncate to max_length.*")

# Model pairs keyed by a short display name that is used in log lines and in
# output file names. RMU_MODELS holds the unlearned checkpoint under test (the
# name is historical: the current entry is an NPO checkpoint, per its repo id);
# ORIGINAL_MODELS holds the matching base model. Every key in RMU_MODELS must
# also exist in ORIGINAL_MODELS, otherwise main() raises KeyError.
RMU_MODELS = {
    "Zephyr-7b": "shengyuanhu/benchmark_wmdp_npo_kl_ckpt_75",
    # Alternative local checkpoint: "./npo_models/zephyr_npo-new"
    
}
ORIGINAL_MODELS = {
    "Zephyr-7b": "HuggingFaceH4/zephyr-7b-beta",
}
# Values of k for pass@k-style evaluation: for each k the unlearned model is
# sampled k times per prompt (greedy when k == 1).
K_LIST = [200]
# Default number of WMDP questions to sample (overridable via --num_samples).
NUM_SAMPLES = 96
# Placeholder path: point this at the WMDP JSON file, a list of items with
# 'question' and 'choices' fields.
WMDP_DATASET_PATH = "/path/to/dataset/bio.json"

def freeze_seeds_for_sampling(seed):
    """Seed Python's ``random`` module and NumPy's global RNG with *seed*.

    Only these two generators are touched; torch / vLLM RNGs are left
    unseeded, so data sampling is reproducible while generation is not.
    """
    print(f"🔒 Freezing Python and NumPy random seeds to {seed} for reproducible data sampling...")
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    print(f"✅ Python and NumPy random seeds frozen to {seed} (for data sampling only)")

def check_gpu_availability():
    """Report CUDA availability and memory of the current device.

    Returns:
        ``(available, free_gb, total_gb)``. When CUDA is unavailable this is
        ``(False, 0, 0)``.
    """
    if not torch.cuda.is_available():
        print("❌ NO GPU AVAILABLE - This will be very slow!")
        return False, 0, 0
    device_count = torch.cuda.device_count()
    device = torch.cuda.current_device()
    torch.cuda.empty_cache()
    # mem_get_info asks the driver, so "free" accounts for other processes and
    # the CUDA context. The previous total - memory_allocated() only subtracted
    # tensors allocated by this process and overstated free memory.
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    total_memory = total_bytes / (1024**3)
    free_memory = free_bytes / (1024**3)
    print(f"🎮 GPU Status: {device_count} GPUs available")
    print(f"   Device: {torch.cuda.get_device_name(device)}")
    print(f"   Total Memory: {total_memory:.1f} GB")
    print(f"   Free Memory: {free_memory:.1f} GB")
    return True, free_memory, total_memory

def aggressive_gpu_cleanup(settle_seconds=1.0):
    """Free unreachable objects and return cached CUDA memory to the driver.

    Args:
        settle_seconds: Seconds to sleep after cleanup. Defaults to 1.0, which
            matches the previous fixed pause; pass 0 to skip the sleep.
    """
    print("🧹 GPU Memory Cleanup...")
    # Collect first, so tensors reachable only from garbage cycles are freed
    # before the allocator cache is emptied. (The old gc.get_objects() loop
    # with `del obj` only unbound the loop variable and freed nothing.)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # let pending kernels finish first
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    if settle_seconds > 0:
        time.sleep(settle_seconds)

def _format_mcq_prompt(item):
    """Render one dataset item as 'question / lettered choices / Answer:' text."""
    question = (item.get('question') or '').strip()
    # At most four choices (A-D) are shown; answer extraction only
    # recognises those letters.
    choices = (item.get('choices') or [])[:4]
    choices_text = '\n'.join(f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices))
    return f"{question}\n{choices_text}\nAnswer:"


def load_wmdp_data(dataset_path, num_samples=NUM_SAMPLES, seed=42):
    """Sample WMDP questions and format them as multiple-choice prompts.

    Args:
        dataset_path: Path to a JSON file containing a list of items, each
            read for a ``question`` string and a ``choices`` list.
        num_samples: Number of items to sample (clamped to [0, dataset size]).
        seed: Seed applied via ``freeze_seeds_for_sampling`` before sampling,
            so the same seed always selects the same items.

    Returns:
        ``(prompts, indices, sampled_data)``: aligned lists holding the
        formatted prompt, the item's position in the JSON list, and the raw
        item.
    """
    print(f"📂 Loading WMDP data from {dataset_path}")
    freeze_seeds_for_sampling(seed)
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    sample_size = max(0, min(num_samples, len(data)))
    indices = random.sample(range(len(data)), sample_size)
    sampled_data = [data[i] for i in indices]
    prompts = [_format_mcq_prompt(item) for item in sampled_data]
    print(f"✅ Loaded {len(prompts)} prompts")
    return prompts, indices, sampled_data

def _vllm_limits_for_path(model_path):
    """Pick (gpu_memory_utilization, max_num_seqs, max_model_len) from a size tag in *model_path*."""
    # Hub ids commonly use lower-case size tags ("Llama-2-13b-hf"), so match
    # case-insensitively; the old upper-case-only check missed them.
    path_lower = model_path.lower()
    if any(tag in path_lower for tag in ("34b", "33b", "30b", "70b")):
        return 0.85, 64, 2048
    if any(tag in path_lower for tag in ("13b", "14b", "15b", "20b")):
        return 0.90, 128, 2048
    return 0.95, 256, 2048


def _build_vllm_engine(model_path, gpu_memory_util, max_num_seqs, max_model_len):
    """Construct a single-GPU bfloat16 vLLM engine with the settings shared by all load attempts."""
    return LLM(
        model=model_path,
        tensor_parallel_size=1,
        gpu_memory_utilization=gpu_memory_util,
        max_num_seqs=max_num_seqs,
        max_model_len=max_model_len,
        dtype="bfloat16",
        trust_remote_code=True,
        enforce_eager=True,
        enable_prefix_caching=True,
        swap_space=2,
    )


def load_vllm_model_optimized(model_path):
    """Load *model_path* into a single-GPU vLLM engine.

    Memory and batching limits come from a parameter-count tag in the path.
    If the first attempt fails with a CUDA out-of-memory error, the load is
    retried once with conservative settings (0.70 utilisation, 32 sequences,
    1024-token context).

    Returns:
        ``(model, tokenizer, "vllm")`` on success, otherwise
        ``(None, None, "failed")``.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    print(f"🚀 Loading vLLM model (optimized): {model_path}")
    if not torch.cuda.is_available():
        raise RuntimeError("GPU required for optimized loading but not available")
    gpu_memory_util, max_num_seqs, max_model_len = _vllm_limits_for_path(model_path)
    try:
        model = _build_vllm_engine(model_path, gpu_memory_util, max_num_seqs, max_model_len)
        print("✅ vLLM model loaded successfully!")
        print(f"   GPU Memory Util: {gpu_memory_util}")
        print(f"   Max Sequences: {max_num_seqs}")
        print(f"   Max Model Length: {max_model_len}")
        return model, model.get_tokenizer(), "vllm"
    except Exception as e:
        print(f"❌ vLLM loading failed: {str(e)[:200]}...")
        if "CUDA out of memory" not in str(e) and "OutOfMemoryError" not in str(e):
            return None, None, "failed"
    # Only out-of-memory failures reach this point.
    print("🔄 Retrying with more conservative memory settings...")
    aggressive_gpu_cleanup()
    try:
        model = _build_vllm_engine(model_path, 0.70, 32, 1024)
        print("✅ vLLM model loaded successfully (conservative settings)!")
        return model, model.get_tokenizer(), "vllm"
    except Exception as e2:
        print(f"❌ vLLM loading failed again: {str(e2)[:200]}...")
        return None, None, "failed"

def load_transformers_model_gpu_optimized(model_path, fallback_tokenizer_path=None):
    """Load a causal LM and its tokenizer with Hugging Face Transformers.

    The tokenizer is loaded from *model_path*. If that fails and
    *fallback_tokenizer_path* is given, it is loaded from there instead.
    The tokenizer is configured for batched generation: the pad token falls
    back to EOS, and padding is applied on the left.

    Args:
        model_path: Model id or local path for the weights.
        fallback_tokenizer_path: Optional tokenizer source used only when the
            tokenizer cannot be loaded from *model_path*.

    Returns:
        ``(model, tokenizer, device_type)``. *device_type* is ``"gpu"`` when
        CUDA is available and ``"cpu"`` otherwise, so downstream generation
        does not try to move inputs onto a non-existent GPU.

    Raises:
        Exception: whatever the tokenizer or model loader raises.
    """
    print(f"🚀 Loading Transformers model (GPU optimized): {model_path}")

    # Try to load the tokenizer from the model path first
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
        print(f"✅ Tokenizer loaded from model path: {model_path}")
    except Exception as e:
        print(f"❌ Failed to load tokenizer from {model_path}: {str(e)[:100]}...")
        if not fallback_tokenizer_path:
            raise
        print(f"🔄 Trying fallback tokenizer: {fallback_tokenizer_path}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(fallback_tokenizer_path, trust_remote_code=True, use_fast=True)
            print(f"✅ Fallback tokenizer loaded successfully: {fallback_tokenizer_path}")
        except Exception as e2:
            print(f"❌ Fallback tokenizer also failed: {str(e2)[:100]}...")
            raise

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models continue from the last position, so batched prompts
    # must be left-padded; right padding puts pad tokens between prompt and
    # generation.
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    model.gradient_checkpointing_disable()
    model.eval()
    print(f"✅ Transformers model loaded successfully: {model_path}")
    device_type = "gpu" if torch.cuda.is_available() else "cpu"
    return model, tokenizer, device_type

def extract_answer_choice_improved(response, choices=None):
    """
    Map a model response to a multiple-choice letter (A-D).

    A standalone capital letter A-D anywhere in the response is returned
    directly. Otherwise, when at least four choices are supplied, the cleaned
    response is compared with the first four choice texts. The comparison
    runs in passes of decreasing strictness, so a strong match on a later
    choice always beats a weak match on an earlier one:

    1. exact match;
    2. substring containment in either direction;
    3. word overlap -- the choice sharing the most words longer than two
       characters with the response wins (earliest choice on ties).

    Args:
        response: The model's response text.
        choices: Optional list of choice texts in A, B, C, D order.

    Returns:
        Answer letter ('A'-'D') or None if no match is found.
    """
    if not response:
        return None

    # Direct letter answer, e.g. "B" or "C. because ...".
    match = re.search(r"\b([A-D])\b", response)
    if match:
        return match.group(1)

    if not choices or len(choices) < 4:
        return None

    letters = "ABCD"
    # Only A-D are shown in the prompt, so extra choices can never be answers.
    options = [str(choice).lower().strip() for choice in choices[:4]]

    cleaned = response.lower().strip()
    # Longer alternatives first, so "the answer is" is not cut short by
    # "answer". Also consume a trailing ':' or '-' separator.
    cleaned = re.sub(r"^(the answer is|correct answer|answer|choice|option)\s*[:\-]?\s*", "", cleaned)
    # Strip trailing punctuation only: interior dots are meaningful
    # (e.g. "HP1.441").
    cleaned = re.sub(r"[\s.,;:]+$", "", cleaned)
    if not cleaned:
        # "" is a substring of every string; without this guard an empty
        # response would "match" choice A.
        return None

    for letter, option in zip(letters, options):
        if option and cleaned == option:
            return letter

    for letter, option in zip(letters, options):
        if option and (option in cleaned or cleaned in option):
            return letter

    response_words = set(re.findall(r"\w+", cleaned))
    best_letter, best_overlap = None, 0
    for letter, option in zip(letters, options):
        overlap = sum(
            1
            for word in set(re.findall(r"\w+", option))
            if len(word) > 2 and word in response_words
        )
        if overlap > best_overlap:
            best_letter, best_overlap = letter, overlap
    return best_letter

def generate_vllm_answers_batch(model, prompts, temperature=0.0, max_tokens=10, choices_list=None):
    """Run one vLLM generation pass over *prompts* and parse answer letters.

    Each completion is cut at the first newline. No seed is set before
    generation.

    Args:
        model: vLLM engine exposing ``generate``.
        prompts: Prompt strings.
        temperature: Sampling temperature passed to ``SamplingParams``.
        max_tokens: Maximum new tokens per completion.
        choices_list: Optional per-prompt choice lists aligned with *prompts*,
            forwarded to the answer extractor.

    Returns:
        ``(answers, responses)`` aligned with *prompts*, or ``(None, None)``
        if any step raises.
    """
    print(f"🚀 Generating answers with vLLM (batch size: {len(prompts)}) (seed NOT frozen for generation)")
    try:
        params = SamplingParams(n=1, temperature=temperature, max_tokens=max_tokens, stop=["\n"])
        completions = model.generate(prompts, params)
        responses = [completion.outputs[0].text.strip() for completion in completions]
        answers = [
            extract_answer_choice_improved(text, choices_list[pos] if choices_list else None)
            for pos, text in enumerate(responses)
        ]
        return answers, responses
    except Exception as e:
        print(f"❌ vLLM batch generation failed: {str(e)[:100]}...")
        return None, None

def generate_transformers_answers_batch(model, tokenizer, prompts, indices, dataset, device_type="gpu", temperature=0.0, max_tokens=10, batch_size=16):
    """Generate a short answer for every prompt with a Transformers model.

    Prompts are processed in mini-batches. Each decoded continuation has
    stray "Answer:"/"Question:" markers removed and is mapped to a letter by
    ``extract_answer_choice_improved``. No seed is set before generation.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching *model*; its ``eos_token_id`` is used as
            both pad and EOS id during generation.
        prompts: Prompt strings.
        indices: WMDP indices aligned with *prompts*. Kept for interface
            compatibility; not used here.
        dataset: Raw items aligned with *prompts*; their ``choices`` feed the
            answer extractor.
        device_type: ``"gpu"`` enables a periodic ``torch.cuda.empty_cache()``.
        temperature: Sampling temperature; ``0`` means greedy decoding.
        max_tokens: Maximum number of new tokens per prompt.
        batch_size: Prompts per ``generate`` call.

    Returns:
        ``(answers, responses)`` aligned with *prompts*. A batch that raises is
        logged and contributes ``None`` answers and empty responses, so both
        lists always have ``len(prompts)`` entries.
    """
    print(f"🚀 Generating answers with Transformers (batch size: {len(prompts)}) (seed NOT frozen for generation)")
    # NOTE: The seed is NOT set/reset before generation, so results are stochastic.
    all_answers = []
    all_responses = []
    for i in tqdm(range(0, len(prompts), batch_size), desc="Generating batches"):
        batch_prompts = prompts[i:i + batch_size]
        try:
            inputs = tokenizer(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            # Put inputs wherever the model's (first) weights live. A bare
            # .cuda() broke CPU-loaded models and ignored device_map placement.
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    temperature=temperature,
                    repetition_penalty=1.1,
                    max_new_tokens=max_tokens,
                    do_sample=temperature > 0,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    use_cache=True,
                    return_dict_in_generate=False,
                    output_scores=False,
                )
            batch_responses = []
            batch_answers = []
            for j, output in enumerate(outputs):
                # generate() echoes the (padded) prompt; keep only new tokens.
                input_length = inputs['input_ids'][j].shape[0]
                new_tokens = output[input_length:]
                response = tokenizer.decode(new_tokens, skip_special_tokens=True)
                response = response.replace("Answer:", "").replace("Question:", "").strip()
                batch_responses.append(response)

                # dataset is aligned with prompts, so i + j is this prompt's item.
                prompt_idx = i + j
                if prompt_idx < len(dataset):
                    choices = dataset[prompt_idx].get('choices', [])
                    answer = extract_answer_choice_improved(response, choices)
                else:
                    answer = extract_answer_choice_improved(response)
                batch_answers.append(answer)
            # Extend only after the whole batch succeeded, keeping alignment.
            all_responses.extend(batch_responses)
            all_answers.extend(batch_answers)
        except Exception as e:
            print(f"❌ Batch generation failed: {str(e)[:100]}...")
            all_responses.extend([""] * len(batch_prompts))
            all_answers.extend([None] * len(batch_prompts))
        if device_type == "gpu" and torch.cuda.is_available() and i % (batch_size * 4) == 0:
            torch.cuda.empty_cache()
    return all_answers, all_responses

def _release_cuda_memory():
    """Run the garbage collector, then return cached CUDA blocks to the driver."""
    gc.collect()
    torch.cuda.empty_cache()


def _generate_original_results(original_path, prompts, indices, dataset, original_temperature):
    """Answer every prompt once with the original model; results keyed by WMDP index."""
    print(f"\n📋 Loading original model: {original_path}")
    model, tokenizer, device_type = load_transformers_model_gpu_optimized(original_path)
    print("🎯 Generating original model answers...")
    answers, responses = generate_transformers_answers_batch(
        model, tokenizer, prompts, indices, dataset,
        device_type=device_type, temperature=original_temperature, max_tokens=10, batch_size=16,
    )
    return {
        indices[pos]: {
            "prompt": prompts[pos],
            "answer": answer,
            "response": response,
            "wmdp_idx": indices[pos],
            "wmdp_data": dataset[pos],
        }
        for pos, (answer, response) in enumerate(zip(answers, responses))
    }


def _generate_rmu_results(rmu_path, original_path, prompts, indices, dataset, temperature):
    """Sample the unlearned model k times per prompt for each k in K_LIST."""
    print(f"\n📋 Loading RMU model: {rmu_path}")
    model, tokenizer, device_type = load_transformers_model_gpu_optimized(rmu_path, fallback_tokenizer_path=original_path)
    print("🎯 Generating RMU model answers...")

    per_prompt = {}
    for pos, wmdp_idx in enumerate(indices):
        per_prompt[wmdp_idx] = {
            k: {"answers": [], "responses": [], "prompt": prompts[pos], "wmdp_idx": wmdp_idx, "wmdp_data": dataset[pos]}
            for k in K_LIST
        }

    for k in K_LIST:
        print(f"\n🎯 Processing k={k}...")
        # k == 1 is scored greedily; larger k uses the sampling temperature.
        round_temperature = temperature if k != 1 else 0.0
        for round_no in range(1, k + 1):
            print(f"   Round {round_no}/{k}...")
            answers, responses = generate_transformers_answers_batch(
                model, tokenizer, prompts, indices, dataset,
                device_type=device_type, temperature=round_temperature, max_tokens=10, batch_size=16,
            )
            if answers is None:
                print(f"   ❌ Round {round_no} failed, using empty answers")
                answers = [None] * len(prompts)
                responses = [""] * len(prompts)
            for pos, (answer, response) in enumerate(zip(answers, responses)):
                slot = per_prompt[indices[pos]][k]
                slot["answers"].append(answer)
                slot["responses"].append(response)

        attempts = len(prompts) * k
        valid = sum(
            answer in ('A', 'B', 'C', 'D')
            for pos in range(len(prompts))
            for answer in per_prompt[indices[pos]][k]["answers"]
        )
        rate = valid / attempts if attempts > 0 else 0
        print(f"✅ Completed k={k}: {rate:.1%} success rate")
    return per_prompt


def process_model_optimized(model_name, rmu_path, original_path, prompts, indices, dataset, temperature, original_temperature):
    """Generate answers from the original model and from the unlearned model.

    The original model answers each prompt once at *original_temperature*.
    The unlearned model is then sampled ``k`` times per prompt for every ``k``
    in ``K_LIST``. Each model is released before the next one is loaded.

    Returns:
        ``{'original': {...}, 'rmu': {...}}``, both keyed by WMDP index. The
        ``rmu`` entries are further keyed by ``k`` and hold per-round
        ``answers``/``responses`` lists.
    """
    banner = '=' * 70
    print(f"\n{banner}")
    print(f"🚀 Generating answers for {model_name} (seed NOT frozen for generation)")
    print(banner)

    original = _generate_original_results(original_path, prompts, indices, dataset, original_temperature)
    _release_cuda_memory()
    rmu = _generate_rmu_results(rmu_path, original_path, prompts, indices, dataset, temperature)
    _release_cuda_memory()
    return {'original': original, 'rmu': rmu}

def main():
    """Command-line entry point.

    Samples WMDP prompts, generates answers for every model pair in
    ``RMU_MODELS``/``ORIGINAL_MODELS``, and writes ``metadata.json`` plus
    ``original_<name>.json`` and ``rmu_<name>.json`` into ``--answers_dir``.
    """
    DEFAULT_SEED = 42
    DEFAULT_TEMPERATURE = 0.8
    DEFAULT_ORIGINAL_TEMPERATURE = 0.0
    DEFAULT_NUM_SAMPLES = NUM_SAMPLES
    DEFAULT_ANSWERS_DIR = None
    parser = argparse.ArgumentParser(description='Generate WMDP answers with configurable parameters (seed NOT frozen for generation) - NPO IMPROVED VERSION')
    parser.add_argument('--seed', type=int, default=DEFAULT_SEED, help='Random seed for data sampling')
    parser.add_argument('--temperature', type=float, default=DEFAULT_TEMPERATURE, help='Temperature for k>1 cases')
    parser.add_argument('--original_temperature', type=float, default=DEFAULT_ORIGINAL_TEMPERATURE, help='Temperature for original models')
    parser.add_argument('--answers_dir', type=str, default=DEFAULT_ANSWERS_DIR, help='Output directory for answers')
    parser.add_argument('--num_samples', type=int, default=DEFAULT_NUM_SAMPLES, help='Number of samples to process')
    args = parser.parse_args()
    if args.answers_dir is None:
        args.answers_dir = f"Passk_unfrozen_npo/wmdp_npo_zephyr_answers_temp{args.temperature}_new"
    print("🚀 OPTIMIZED WMDP ANSWER GENERATION - NPO IMPROVED VERSION (seed NOT frozen for generation)")
    print("=" * 60)
    print(f"🌱 Seed (for data sampling): {args.seed}")
    print(f"🌡️  Temperature: {args.temperature}")
    print(f"🌡️  Original Temperature: {args.original_temperature}")
    print(f"📁 Output Directory: {args.answers_dir}")
    print(f"📊 Number of Samples: {args.num_samples}")
    print("=" * 60)
    gpu_available, free_memory, total_memory = check_gpu_availability()
    if not gpu_available:
        print("⚠️  WARNING: No GPU detected. This will be extremely slow!")
        try:
            response = input("Continue anyway? (y/N): ")
        except EOFError:
            # Non-interactive run (no stdin): take the default answer, "N".
            response = ""
        if response.lower() != 'y':
            print("Exiting...")
            return
    os.makedirs(args.answers_dir, exist_ok=True)
    prompts, indices, dataset = load_wmdp_data(WMDP_DATASET_PATH, num_samples=args.num_samples, seed=args.seed)
    metadata = {
        "prompts": prompts,
        "indices": indices,
        "wmdp_data": dataset,
        "k_list": K_LIST,
        "num_samples": len(prompts),
        "seed": args.seed,
        "temperature": args.temperature,
        "original_temperature": args.original_temperature,
        "dataset_path": WMDP_DATASET_PATH,
        "gpu_info": {
            "available": gpu_available,
            "free_memory_gb": free_memory,
            "total_memory_gb": total_memory,
        }
    }
    with open(os.path.join(args.answers_dir, "metadata.json"), "w") as f:
        json.dump(metadata, f, indent=2)
    for model_name, rmu_path in RMU_MODELS.items():
        original_path = ORIGINAL_MODELS[model_name]
        print(f"\n🎯 Starting {model_name}...")
        start_time = time.time()
        results = process_model_optimized(
            model_name, rmu_path, original_path,
            prompts, indices, dataset,
            args.temperature, args.original_temperature
        )
        with open(os.path.join(args.answers_dir, f"original_{model_name}.json"), "w") as f:
            json.dump(results['original'], f, indent=2)
        with open(os.path.join(args.answers_dir, f"rmu_{model_name}.json"), "w") as f:
            json.dump(results['rmu'], f, indent=2)
        processing_time = time.time() - start_time
        print(f"✅ {model_name} completed in {processing_time:.1f} seconds")
        rmu_results = results['rmu']
        total_answers = sum(
            len(per_k["answers"])
            for per_idx in rmu_results.values()
            for per_k in per_idx.values()
        )
        successful_answers = sum(
            answer in ('A', 'B', 'C', 'D')
            for per_idx in rmu_results.values()
            for per_k in per_idx.values()
            for answer in per_k["answers"]
        )
        success_rate = successful_answers / total_answers if total_answers > 0 else 0
        print(f"📊 FINAL STATS for {model_name}:")
        print(f"   Total answers: {total_answers}")
        print(f"   Successful answers: {successful_answers}")
        print(f"   Success rate: {success_rate:.1%}")
        print(f"   Processing time: {processing_time:.1f}s")
        if total_answers:
            print(f"   Avg time per answer: {processing_time/total_answers:.3f}s")
        else:
            # No answers (e.g. --num_samples 0 or empty K_LIST): avoid ZeroDivisionError.
            print("   Avg time per answer: n/a")

if __name__ == "__main__":
    main() 