import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging
import numpy as np
import copy
import os
import argparse
from accelerate import Accelerator
import time
import torch.multiprocessing as mp
import threading
from queue import Queue
from concurrent.futures import ThreadPoolExecutor
import math
import gc
import json
import sys

# Global runtime configuration: only surface transformers errors, and permit
# TensorFloat-32 for CUDA matmuls (faster, slightly lower precision).
logging.set_verbosity_error()
torch.backends.cuda.matmul.allow_tf32 = True

# --- Command-line interface ---
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='', help='Hugging Face model id or local path of the base model')
parser.add_argument('--hf_cache_dir', type=str, default='', help='Cache directory passed to from_pretrained (empty: library default)')
parser.add_argument('--mixed_precision', type=str, default='fp16', choices=['no', 'fp16', 'bf16'], help='Mixed precision mode')
parser.add_argument('--gpu_threads', type=int, default=4, help='Number of parallel threads per GPU')
parser.add_argument('--verbose', action='store_true', help='Print verbose logs')
parser.add_argument('--aggressive_gc', action='store_true', help='Perform aggressive garbage collection')
parser.add_argument('--eval_interval', type=int, default=20, help='Interval for evaluating best/worst models')
parser.add_argument('--weight_sample_interval', type=int, default=10, help='Sample interval for weight tracking')
parser.add_argument('--data_sample', type=int, default=1000, help='Number of data samples to use for training')
parser.add_argument('--reward_type', type=str, default='grpo_reward', help='Type of reward function to use')
parser.add_argument('--resume', type=str, default=None, help='Path to a checkpoint directory to resume training from')
parser.add_argument('--start_iteration', type=int, default=0, help='Starting iteration number when resuming training')
args = parser.parse_args()

# Fail fast with a usage message: with neither option the model path used by
# main() would be an empty string, which only errors much later inside
# from_pretrained().
if not args.model_name and not args.resume:
    parser.error("one of --model_name or --resume is required")

# Select the reward function. Only the countdown-task GRPO reward is
# implemented; any other value is rejected up front.
if args.reward_type == 'grpo_reward':
    from countdown_task import reward_function
    print("Using original GRPO reward function")
else:
    raise ValueError(f"Invalid reward type: {args.reward_type}, please choose from: grpo_reward")


# --- Evolution Strategies hyperparameters ---
# How many ES generations (outer iterations) to run.
NUM_ITERATIONS = 1000
# How many seeded perturbations ("mutants") are scored per generation.
POPULATION_SIZE = 30
# Standard deviation of the Gaussian weight perturbations -- choose carefully.
SIGMA = 0.001
# Learning rate: scales the averaged, reward-weighted noise added to the weights.
ALPHA = 0.0005

# --- Generation settings forwarded to model.generate() ---
max_new_tokens = 1024  # cap on newly generated tokens per prompt
do_sample = False      # False disables sampling during generation


# --- Load Dataset from JSON File ---
# The file is read as UTF-8 explicitly so loading does not depend on the
# platform's default text encoding.
data_path = os.path.join(os.path.dirname(__file__), 'data/countdown/countdown_samples.json')
with open(data_path, 'r', encoding='utf-8') as f:
    data_json = json.load(f)

# Keep only the first --data_sample records and convert each one into a
# (context, target) tuple, the format consumed by process_seed().
dataset = [(item['context'], item['target']) for item in data_json[:args.data_sample]]

def force_memory_cleanup():
    """Run the Python garbage collector and, when CUDA exists, release cached GPU memory.

    On CUDA machines this also empties the caching allocator, collects IPC
    handles and blocks until all queued work on the current device finishes.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()

def _decode_sequence(tokenizer, token_ids):
    """Decode one generated id sequence to text, dropping special tokens.

    Some tokenizers raise TypeError from ``decode`` (presumably on ids they
    cannot map); in that case fall back to an id->token->string conversion
    that discards ``None`` tokens.
    """
    try:
        return tokenizer.decode(token_ids, skip_special_tokens=True)
    except TypeError:
        tokens = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=True)
        return tokenizer.convert_tokens_to_string([t for t in tokens if t is not None])


def _parse_countdown_example(inp_text, tgt_text):
    """Extract the available numbers and the integer target of one countdown example.

    The numbers are read from the first ``[...]`` span of the prompt. Both
    space-separated (``[3 7 25]``) and comma-separated (``[3, 7, 25]``) lists
    are accepted; tokens that are not plain digit strings are skipped.

    Args:
        inp_text: Prompt text given to the model.
        tgt_text: Target value from the dataset (a digit string or an int).

    Returns:
        ``(numbers, target)``: ``numbers`` is a list of ints, or ``None`` when
        the prompt has no ``[...]`` span; ``target`` is an int, or ``None``
        when ``tgt_text`` is not a non-negative integer.
    """
    numbers = None
    start_idx = inp_text.find("[")
    # Look for the closing bracket only after the opening one.
    end_idx = inp_text.find("]", start_idx + 1) if start_idx != -1 else -1
    if end_idx != -1:
        # Treat commas as separators so "[3, 7, 25]" is not reduced to [25].
        numbers_str = inp_text[start_idx + 1:end_idx].replace(",", " ")
        numbers = [int(n) for n in numbers_str.split() if n.isdigit()]

    target = None
    # str() so an int target from the JSON file does not crash on .isdigit().
    tgt_str = str(tgt_text).strip()
    if tgt_str.isdigit():
        target = int(tgt_str)
    return numbers, target


def evaluate_model(model, tokenizer, input_text, target_text, accelerator, seed_idx=None, thread_id=None, verbose=False, return_text=False):
    """
    Generate responses for a single prompt or a batch of prompts and score them.

    Prompts are left-padded and passed to ``generate`` on the unwrapped model
    under ``torch.inference_mode()``, using the module-level ``max_new_tokens``
    and ``do_sample`` settings. Each decoded output is trimmed to the text after
    the last ``"assistant:"`` marker (if present) and scored with
    ``reward_function``.

    Args:
        model: Model (possibly accelerate-prepared) used for generation.
        tokenizer: Tokenizer matching ``model``; it needs a pad token for
            batched padding.
        input_text: A prompt string, or a list of prompt strings.
        target_text: The target for ``input_text`` (a list when batched).
        accelerator: ``Accelerator`` whose device receives the inputs.
        seed_idx: Population index; only used in verbose logging.
        thread_id: Worker thread id; thread 0 on the main process prints the
            first decoded text and reward when ``verbose``.
        verbose: Print progress information.
        return_text: Also return the decoded text(s).

    Returns:
        For a string input, the single reward (plus its decoded text when
        ``return_text``); for a list input, the list of rewards (plus the list
        of decoded texts when ``return_text``).
    """
    if verbose:
        print(f"Process {accelerator.process_index} Thread {thread_id} evaluating seed {seed_idx}")

    # Normalize single and batch inputs to lists.
    is_batch = isinstance(input_text, list)
    input_texts = input_text if is_batch else [input_text]
    target_texts = target_text if is_batch else [target_text]

    # Left padding keeps every prompt flush against its generated continuation.
    tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, padding_side="left")
    input_ids = tokenized_inputs["input_ids"].to(accelerator.device)
    attention_mask = tokenized_inputs["attention_mask"].to(accelerator.device)

    with torch.inference_mode():
        unwrapped_model = accelerator.unwrap_model(model)
        outputs = unwrapped_model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=do_sample)
        if torch.cuda.is_available():
            torch.cuda.synchronize(accelerator.device)

    # For a causal LM each output row holds prompt + completion, so the decoded
    # text includes the prompt; the "assistant:" split below isolates the reply.
    generated_texts = [_decode_sequence(tokenizer, outputs[i]) for i in range(len(input_texts))]

    # Free the GPU tensors before scoring.
    del input_ids, attention_mask, outputs
    torch.cuda.empty_cache()

    if verbose and accelerator.is_main_process and thread_id == 0:
        print(generated_texts[0] if generated_texts else "")

    rewards = []
    for gen_text, tgt_text, inp_text in zip(generated_texts, target_texts, input_texts):
        numbers, target = _parse_countdown_example(inp_text, tgt_text)

        # Extract only the model's response part.
        model_response = gen_text
        if "assistant:" in gen_text:
            model_response = gen_text.split("assistant:")[-1].strip()

        reward_result = reward_function(model_response, numbers, target)
        rewards.append(reward_result["reward"])

    if verbose and accelerator.is_main_process and thread_id == 0:
        print(rewards[0] if rewards else 0)

    # Mirror the input shape in the return value.
    if return_text:
        return (rewards, generated_texts) if is_batch else (rewards[0], generated_texts[0])
    return rewards if is_batch else rewards[0]

def _add_seeded_noise(model, seed, scale):
    """Add ``scale * N(0, 1)`` noise to every parameter of ``model`` in place.

    A fresh ``torch.Generator`` seeded with ``seed`` is created per parameter,
    so a given (seed, parameter) pair always reproduces the same noise. This
    recipe must stay identical to the noise regeneration in the ES update step,
    which relies on it to reconstruct each population member's perturbation.
    """
    for param in model.parameters():
        gen = torch.Generator(device=param.device)
        gen.manual_seed(int(seed))
        noise = torch.randn(
            param.shape,
            generator=gen,
            device=param.device,
            dtype=param.dtype
        )
        param.data.add_(scale * noise)


def process_seed(seed_args):
    """Score one population member: perturb the weights, evaluate, then restore.

    Intended to run inside a thread pool; ``model`` should be the copy assigned
    to this worker thread, because its weights are modified in place.

    Args:
        seed_args: Tuple ``(seed_idx, seed, model, tokenizer, accelerator,
            thread_id, verbose)``.

    Returns:
        ``(seed_idx, average_reward)`` where ``average_reward`` is the mean
        reward over every example in ``dataset``.
    """
    seed_idx, seed, model, tokenizer, accelerator, thread_id, verbose = seed_args

    if verbose:
        print(f"Process {accelerator.process_index} Thread {thread_id} processing seed {seed_idx} (value: {seed})")

    original_model = accelerator.unwrap_model(model)
    _add_seeded_noise(original_model, seed, SIGMA)
    try:
        # Ensure the perturbation is applied before evaluation starts.
        if torch.cuda.is_available():
            torch.cuda.synchronize(accelerator.device)

        # Evaluate all prompts with the perturbed weights in one batch.
        input_texts = [input_text for input_text, _ in dataset]
        target_texts = [target_text for _, target_text in dataset]
        rewards = evaluate_model(model, tokenizer, input_texts, target_texts, accelerator,
                                 seed_idx=seed_idx, thread_id=thread_id, verbose=verbose)
        total_reward = sum(rewards)
    finally:
        # Always undo the perturbation, even if evaluation raised, so this model
        # copy is never left perturbed.
        # NOTE(review): in fp16/bf16, (w + s*n) - s*n is not guaranteed to be
        # bit-exact, so small rounding drift may accumulate over iterations.
        _add_seeded_noise(original_model, seed, -SIGMA)
        if torch.cuda.is_available():
            torch.cuda.synchronize(accelerator.device)

    average_reward = total_reward / len(dataset)

    force_memory_cleanup()

    if verbose:
        print(f"Process {accelerator.process_index} Thread {thread_id} completed seed {seed_idx} with reward {average_reward:.4f}")

    return seed_idx, average_reward


# --- Main Evolution Strategies Loop ---
def _is_distributed():
    """Return True when a torch.distributed process group is initialized.

    Single-process runs never initialize a process group, and collectives
    raise in that case, so callers skip them.
    """
    return torch.distributed.is_available() and torch.distributed.is_initialized()


def _load_model_copies(accelerator, model_path, hf_cache_dir):
    """Load ``args.gpu_threads`` independent model copies on this process's GPU.

    Each worker thread perturbs its own copy in place, so the copies must be
    separate objects.
    """
    if args.mixed_precision == 'fp16':
        torch_dtype = torch.float16
    elif args.mixed_precision == 'bf16':
        torch_dtype = torch.bfloat16
    else:
        torch_dtype = torch.float32

    model_list = []
    for model_index in range(args.gpu_threads):
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                cache_dir=hf_cache_dir,
                # Local (per-node) GPU index, matching accelerator.device; the
                # global process_index would point past the last GPU on
                # multi-node runs.
                device_map={"": accelerator.local_process_index},
                torch_dtype=torch_dtype,
                # attn_implementation="flash_attention_2"
            )
        except Exception as e:
            error_msg = f"Error loading model from {model_path}: {e}"
            if accelerator.is_main_process:
                print(error_msg)
            raise RuntimeError(error_msg) from e
        model_list.append(model)

        if accelerator.is_main_process and model_index == 0:
            print(f"Model {model_index+1}/{args.gpu_threads} loaded successfully")
    return model_list


def _sync_seeds(accelerator):
    """Draw POPULATION_SIZE seeds on the main process and share them with all processes."""
    if accelerator.is_main_process:
        if args.verbose:
            print(f"Main process {accelerator.process_index} generating seeds")
        seeds = np.random.randint(0, 2**31, size=POPULATION_SIZE, dtype=np.int64).tolist()
        seeds_tensor = torch.tensor(seeds, dtype=torch.long, device=accelerator.device)
    else:
        if args.verbose:
            print(f"Worker process {accelerator.process_index} waiting for seeds")
        seeds_tensor = torch.zeros(POPULATION_SIZE, dtype=torch.long, device=accelerator.device)

    if _is_distributed():
        torch.distributed.broadcast(seeds_tensor, src=0)
    return seeds_tensor.cpu().tolist()


def _evaluate_local_seeds(accelerator, model_list, tokenizer, seeds):
    """Score this process's share of the seeds; return ``[(seed_idx, reward), ...]``.

    Seeds are assigned round-robin by index (seed_idx modulo num_processes) and
    processed in batches of at most ``args.gpu_threads``. Each batch runs one
    thread per seed, and thread ``i`` uses ``model_list[i]``.
    """
    local_seeds = [
        (seed_idx, seed)
        for seed_idx, seed in enumerate(seeds)
        if seed_idx % accelerator.num_processes == accelerator.process_index
    ]
    if args.verbose:
        print(f"Process {accelerator.process_index} assigned {len(local_seeds)} seeds: {[idx for idx, _ in local_seeds]}")

    local_rewards = []
    batch_size = max(1, min(args.gpu_threads, len(local_seeds)))
    for batch_start in range(0, len(local_seeds), batch_size):
        batch_seeds = local_seeds[batch_start:batch_start + batch_size]
        thread_args = [
            (seed_idx, seed, model_list[thread_id], tokenizer, accelerator, thread_id, args.verbose)
            for thread_id, (seed_idx, seed) in enumerate(batch_seeds)
        ]
        with ThreadPoolExecutor(max_workers=len(batch_seeds)) as executor:
            local_rewards.extend(executor.map(process_seed, thread_args))

        # Clean up between batches.
        force_memory_cleanup()
    return local_rewards


def _gather_rewards(accelerator, local_rewards):
    """Combine per-process rewards into the full POPULATION_SIZE-long list on every process.

    Each seed index is scored by exactly one process and all other slots are
    zero, so a SUM all-reduce yields the complete reward vector.
    """
    all_rewards = torch.zeros(POPULATION_SIZE, device=accelerator.device)
    for seed_idx, reward in local_rewards:
        all_rewards[seed_idx] = reward
    if _is_distributed():
        torch.distributed.all_reduce(all_rewards, op=torch.distributed.ReduceOp.SUM)
    rewards = all_rewards.cpu().tolist()
    del all_rewards
    return rewards


def _apply_es_update(original_model, seeds, rewards_normalized):
    """Apply ``param += ALPHA * mean_i(r_i * noise_i)`` to every parameter in place.

    ``noise_i`` is regenerated from ``seeds[i]`` with the same per-parameter
    generator recipe used to perturb the weights in process_seed.
    """
    for param in original_model.parameters():
        gen = torch.Generator(device=param.device)
        update = torch.zeros_like(param)
        for seed, r_norm in zip(seeds, rewards_normalized):
            gen.manual_seed(int(seed))
            noise = torch.randn(
                param.shape,
                generator=gen,
                device=param.device,
                dtype=param.dtype
            )
            noise.mul_(float(r_norm))
            update.add_(noise)
            del noise
        update.div_(POPULATION_SIZE)
        param.data.add_(ALPHA * update)
        torch.cuda.empty_cache()


def _sync_model_copies(accelerator, model_list, original_model):
    """Overwrite the weights of model_list[1:] with those of ``original_model``."""
    for other in model_list[1:]:
        other_model = accelerator.unwrap_model(other)
        for name, param in other_model.named_parameters():
            param.data.copy_(original_model.get_parameter(name).data)


def _write_training_config(save_dir, config):
    """Write ``config`` as indented JSON to ``<save_dir>/training_config.json``."""
    config_path = os.path.join(save_dir, "training_config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)


# --- Main Evolution Strategies Loop ---
def main():
    """Run Evolution Strategies fine-tuning of a causal LM on the countdown dataset.

    Each iteration, the main process draws POPULATION_SIZE seeds and shares
    them with all processes. Every process scores its share of the seeds, the
    rewards are summed across processes and z-scored, and the weights move
    along the reward-weighted average of the seeded noise with step size ALPHA.
    model_list[0] holds the master weights; the other per-thread copies are
    overwritten from it after every update. A checkpoint is saved every 100
    iterations and the final model at the end, on the main process only.
    """
    accelerator = Accelerator(mixed_precision=args.mixed_precision)

    if accelerator.is_main_process:
        print(f"Total processes: {accelerator.num_processes}, GPU threads per process: {args.gpu_threads}")
        print(f"Using mixed precision: {args.mixed_precision}")
        print(f"Population size: {POPULATION_SIZE}, Iterations: {NUM_ITERATIONS}")
        print(f"Sigma: {SIGMA}, Alpha: {ALPHA}")
        print(f"Aggressive GC: {args.aggressive_gc}")

    # Determine model path and starting iteration. Trailing slashes are
    # stripped so "ckpt_dir/" still yields a non-empty base name.
    model_path = args.resume if args.resume else args.model_name
    base_model_name = model_path.rstrip('/').split('/')[-1]
    start_iteration = args.start_iteration if args.resume else 0

    if accelerator.is_main_process:
        if args.resume:
            print(f"Resuming training from checkpoint: {model_path}")
            print(f"Starting from iteration: {start_iteration}")
        else:
            print(f"Starting new training with model: {model_path}")

    # An empty --hf_cache_dir (the CLI default) would otherwise be forwarded
    # literally as cache_dir=''; None lets Hugging Face use its default cache.
    hf_cache_dir = args.hf_cache_dir or None

    model_list = _load_model_copies(accelerator, model_path, hf_cache_dir)

    # Prefer the base model's tokenizer; fall back to the checkpoint directory
    # (checkpoints written below include the tokenizer) when resuming without
    # --model_name.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name or model_path, use_fast=False, cache_dir=hf_cache_dir)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        if accelerator.is_main_process:
            print(f"Set tokenizer.pad_token = tokenizer.eos_token ({tokenizer.eos_token})")

    if accelerator.is_main_process:
        print("Model loaded successfully")

    for model in model_list:
        model.eval()  # Turn off dropout, etc.
        # NOTE(review): prepare() may patch the module in place and return a
        # wrapper. The wrapper was never stored (the original code only rebound
        # the loop variable), and all later code goes through unwrap_model, so
        # the return value is intentionally unused here.
        accelerator.prepare(model)

    force_memory_cleanup()

    # Master copy whose weights receive the ES update. It is bound before the
    # loop so the final save also works when no iteration runs.
    original_model = accelerator.unwrap_model(model_list[0])

    training_start_time = time.time()

    for iteration in range(start_iteration, NUM_ITERATIONS):
        iter_start_time = time.time()
        force_memory_cleanup()

        if args.verbose:
            print(f"Process {accelerator.process_index} starting iteration {iteration + 1}/{NUM_ITERATIONS}")

        seeds = _sync_seeds(accelerator)
        if args.verbose:
            print(f"Process {accelerator.process_index} received seeds")

        local_rewards = _evaluate_local_seeds(accelerator, model_list, tokenizer, seeds)
        rewards = _gather_rewards(accelerator, local_rewards)
        force_memory_cleanup()

        # Z-score the rewards; 1e-8 avoids division by zero when all rewards tie.
        rewards_tensor = np.array(rewards, dtype=np.float32)
        rewards_normalized = (rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-8)

        if args.verbose:
            print(f"Process {accelerator.process_index} updating model weights")
        _apply_es_update(original_model, seeds, rewards_normalized)
        _sync_model_copies(accelerator, model_list, original_model)

        # Synchronize to ensure weight updates are complete.
        if torch.cuda.is_available():
            torch.cuda.synchronize(accelerator.device)

        force_memory_cleanup()

        iter_time = time.time() - iter_start_time

        mean_reward = rewards_tensor.mean().item()
        min_reward = rewards_tensor.min().item()
        max_reward = rewards_tensor.max().item()

        del rewards_tensor, rewards_normalized
        force_memory_cleanup()

        if accelerator.is_main_process:
            print(f"Iteration {iteration + 1}/{NUM_ITERATIONS}, Time: {iter_time:.2f}s, Mean: {mean_reward:.2f}, Min: {min_reward:.2f}, Max: {max_reward:.2f}")
            print(f"GPU Memory: {torch.cuda.memory_allocated() / 1024**2:.2f}MB allocated, {torch.cuda.max_memory_allocated() / 1024**2:.2f}MB peak")

        # Save checkpoint every 100 iterations.
        if (iteration + 1) % 100 == 0 and accelerator.is_main_process:
            checkpoint_dir = f"iter{iteration + 1}_{base_model_name}_pop{POPULATION_SIZE}_{args.mixed_precision}_{args.data_sample}samples_{args.reward_type}_{max_new_tokens}tokens_countdown"
            print(f"Saving checkpoint at iteration {iteration + 1} to {checkpoint_dir}...")
            try:
                original_model.save_pretrained(checkpoint_dir)
                tokenizer.save_pretrained(checkpoint_dir)
                # Save training configuration for resuming.
                _write_training_config(checkpoint_dir, {
                    "iteration": iteration + 1,
                    "population_size": POPULATION_SIZE,
                    "sigma": SIGMA,
                    "alpha": ALPHA,
                    "model_name": args.model_name,
                    "mixed_precision": args.mixed_precision,
                    "reward_type": args.reward_type,
                    "max_new_tokens": max_new_tokens,
                    "data_sample": args.data_sample,
                })
                print("Checkpoint saved successfully.")
            except Exception as e:
                # Best effort: a failed checkpoint must not abort training.
                print(f"Error saving checkpoint: {e}")
                print("Continuing training without saving checkpoint.")

    total_time = time.time() - training_start_time

    # Save the fine-tuned model weights.
    if accelerator.is_main_process:
        print(f"Training completed in {total_time:.2f}s ({total_time/60:.2f} minutes)")

        # Never report a negative count if start_iteration exceeds NUM_ITERATIONS.
        total_iterations_completed = max(0, NUM_ITERATIONS - start_iteration)

        # Include starting iteration in the name when resuming.
        if args.resume:
            resume_info = f"_resumed_from_{start_iteration}_total_{total_iterations_completed}"
        else:
            resume_info = f"_total_{total_iterations_completed}"

        save_dir = f"ft_{base_model_name}_es_pop{POPULATION_SIZE}_iter{NUM_ITERATIONS}{resume_info}_{args.mixed_precision}_{args.data_sample}samples_{args.reward_type}_{max_new_tokens}tokens_countdown"
        print(f"Saving model to {save_dir}...")
        try:
            original_model.save_pretrained(save_dir)
            tokenizer.save_pretrained(save_dir)
            _write_training_config(save_dir, {
                "initial_model": args.model_name,
                "resumed_from": args.resume if args.resume else None,
                "start_iteration": start_iteration,
                "end_iteration": NUM_ITERATIONS,
                "total_iterations_completed": total_iterations_completed,
                "population_size": POPULATION_SIZE,
                "sigma": SIGMA,
                "alpha": ALPHA,
                "mixed_precision": args.mixed_precision,
                "reward_type": args.reward_type,
                "max_new_tokens": max_new_tokens,
                "data_sample": args.data_sample,
                "training_time_seconds": total_time,
            })
            print("Model saved successfully.")
        except Exception as e:
            print(f"Error saving final model: {e}")
            # Try to save to a backup location.
            backup_dir = f"backup_model_{int(time.time())}"
            print(f"Attempting to save to backup location: {backup_dir}")
            try:
                original_model.save_pretrained(backup_dir)
                print(f"Model saved to backup location: {backup_dir}")
            except Exception as e2:
                print(f"Failed to save to backup location: {e2}")
                print("Training completed but model could not be saved.")

if __name__ == "__main__":
    import warnings

    # PYTHONWARNINGS is only read when an interpreter starts, so setting it here
    # affects child interpreters (e.g. spawned processes) but not this one.
    # Install the equivalent filter in the current process as well.
    os.environ["PYTHONWARNINGS"] = "ignore"
    warnings.simplefilter("ignore")
    mp.set_start_method('spawn', force=True)
    main()