import itertools
import os
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
from typing import List, Tuple, Optional, Dict, Any
import csv
import pickle
from datetime import datetime
import time
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

import fsspec
import hydra
import lightning as L
import omegaconf
import rich.syntax
import rich.tree

import dataloader
import diffusion
import utils

# Import FK steering functions from main_fkd
from main_fkd import (
    compute_entropy_reward, 
    compute_potential_scores, 
    resample_particles, 
    fk_steering_mdlm,
    _load_from_checkpoint
)


def set_seeds_for_sampling(seed_value: int):
    """Seed the torch and NumPy global random number generators.

    The same ``seed_value`` is handed, in order, to ``torch.manual_seed``,
    ``torch.cuda.manual_seed``, ``torch.cuda.manual_seed_all`` and
    ``np.random.seed``. Python's built-in ``random`` module is not seeded
    here.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed_value)


def aggressive_memory_cleanup(gpu_id: int = None):
    """Run garbage collection and release cached CUDA memory.

    Sequence: ``gc.collect()``, then (only when CUDA is available) optionally
    switch to ``gpu_id``, empty the CUDA cache and synchronize, then a second
    ``gc.collect()``.

    Args:
        gpu_id: device to make current before clearing the cache; when
            ``None`` the current device is left unchanged.
    """
    import gc

    def _release_cuda_cache():
        # Nothing to do on machines without CUDA.
        if not torch.cuda.is_available():
            return
        if gpu_id is not None:
            torch.cuda.set_device(gpu_id)
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    # Collect once before and once after touching the CUDA cache.
    gc.collect()
    _release_cuda_cache()
    gc.collect()


def load_model_on_gpu(config, tokenizer, gpu_id: int):
    """Load the checkpointed model onto ``cuda:{gpu_id}`` and ready it for sampling.

    Steps: make ``gpu_id`` the current CUDA device, load the model through
    ``_load_from_checkpoint``, move it to that device, reset its
    ``gen_ppl_metric``, handle EMA, and put ``backbone`` and ``noise`` into
    eval mode.

    Args:
        config: experiment configuration; ``config.eval.disable_ema`` is read here.
        tokenizer: tokenizer forwarded to ``_load_from_checkpoint``.
        gpu_id: index of the CUDA device to load onto.

    Returns:
        The loaded model.
    """
    torch.cuda.set_device(gpu_id)
    target_device = f'cuda:{gpu_id}'

    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
    model = model.to(target_device)
    model.gen_ppl_metric.reset()

    def _sampling_parameters():
        # A fresh iterator is needed per call: a chain is exhausted after one pass.
        return itertools.chain(
            model.backbone.parameters(),
            model.noise.parameters())

    if config.eval.disable_ema:
        model.ema = None

    if model.ema:
        # Presumably backs up the live weights and swaps in the EMA weights
        # (inferred from the method names) -- TODO confirm against the EMA class.
        model.ema.store(_sampling_parameters())
        model.ema.copy_to(_sampling_parameters())

    for submodule in (model.backbone, model.noise):
        submodule.eval()

    return model


def generate_global_seeds(total_runs: int, base_seed: int, num_particles: int) -> List[List[int]]:
    """Build the per-run, per-particle seed table.

    Run ``i`` starts at ``base_seed + i * 10000`` and its particles receive
    consecutive seeds from that start. The NumPy global RNG is re-seeded
    along the way (first with ``base_seed``, then with each run's start
    seed), although no random numbers are drawn here.

    Returns:
        A list with ``total_runs`` entries, each a list of ``num_particles`` seeds.
    """
    np.random.seed(base_seed)

    run_stride = 10000
    first_seeds = [base_seed + run_idx * run_stride for run_idx in range(total_runs)]

    seed_table = []
    for first_seed in first_seeds:
        np.random.seed(first_seed)
        seed_table.append([first_seed + offset for offset in range(num_particles)])

    return seed_table


@torch.no_grad()
def best_of_n_sampling_gpu(
    model,
    num_particles: int,
    num_steps: int,
    shared_seeds: List[int],
    eps: float = 1e-5
) -> Tuple[torch.Tensor, float, List[float], List[torch.Tensor]]:
    """
    Best-of-n sampling using the same seeds as FK steering.
    Optimized for single GPU execution.

    One sample is generated independently per entry of ``shared_seeds``; the
    sample with the lowest U_denoise (absolute trapezoidal integral of its
    entropy trajectory over the time grid) is returned as the best one.
    The whole routine runs under ``torch.no_grad()`` since only forward
    passes are needed.

    Args:
        model: diffusion model exposing the sampling helpers used below.
        num_particles: accepted for interface parity; the number of samples
            generated is ``len(shared_seeds)``.
        num_steps: number of denoising steps.
        shared_seeds: one RNG seed per sample.
        eps: final time of the grid, which runs from 1 down to ``eps``.

    Returns:
        ``(best_sample, best_u_denoise, all_u_denoise, all_samples)``.

    Raises:
        ValueError: if ``shared_seeds`` is empty.
    """
    if len(shared_seeds) == 0:
        # Fail early with a clear message instead of numpy's empty-argmin error.
        raise ValueError("shared_seeds must contain at least one seed")

    device = model.device
    seq_len = model.config.model.length

    # The time grid does not depend on the seed, so build it once.
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = (1 - eps) / num_steps

    all_samples = []
    all_u_denoise = []

    # Generate samples independently using shared seeds
    for seed in shared_seeds:
        # Set seed for reproducible generation
        set_seeds_for_sampling(seed)

        # Initialize single particle from prior
        particle = model._sample_prior(1, seq_len).to(device)

        # Track entropy for this particle
        entropy_trajectory = []
        p_x0_cache = None

        # Generation loop
        for step in range(num_steps):
            t = timesteps[step] * torch.ones(1, 1, device=device)

            # Compute entropy for current state
            current_entropy = model._compute_mask_conditional_entropy(particle, t)
            entropy_trajectory.append(current_entropy.item())

            # Generate next state
            if model.sampler == 'ddpm':
                particle = model._ddpm_update(particle, t, dt)
            elif model.sampler == 'ddpm_cache':
                p_x0_cache, particle_next = model._ddpm_caching_update(
                    particle, t, dt, p_x0=p_x0_cache)
                # Drop the cache whenever the state changed or the model is
                # time-conditioned.
                if (not torch.allclose(particle_next, particle)
                        or model.time_conditioning):
                    p_x0_cache = None
                particle = particle_next
            else:
                particle = model._analytic_update(particle, t, dt)

        # Final denoising step
        if model.config.sampling.noise_removal:
            t_final = timesteps[-1] * torch.ones(particle.shape[0], 1, device=device)
            final_entropy = model._compute_mask_conditional_entropy(particle, t_final)
            entropy_trajectory.append(final_entropy.item())

            if model.sampler == 'analytic':
                particle = model._denoiser_update(particle, t_final)
            else:
                unet_conditioning = model.noise(t_final)[0]
                particle = model.forward(particle, unet_conditioning).argmax(dim=-1)

        # Compute U_denoise (total denoising uncertainty); the trajectory has
        # one entry fewer than the grid when noise removal is disabled.
        time_points = timesteps[:len(entropy_trajectory)].cpu().numpy()
        u_denoise = abs(np.trapz(entropy_trajectory, x=time_points))

        all_samples.append(particle)
        all_u_denoise.append(u_denoise)

    # Select best sample (lowest U_denoise)
    best_idx = np.argmin(all_u_denoise)
    best_sample = all_samples[best_idx]
    best_u_denoise = all_u_denoise[best_idx]

    return best_sample, best_u_denoise, all_u_denoise, all_samples


def _fk_score_particles(model, particles, t, num_particles):
    """Score particles one at a time; return stacked ``(entropies, rewards)``."""
    entropies = []
    rewards = []
    with torch.no_grad():
        # Individual processing to avoid batch size issues.
        for i in range(num_particles):
            particle = particles[i:i + 1]
            t_particle = t[i:i + 1]
            entropies.append(model._compute_mask_conditional_entropy(particle, t_particle))
            rewards.append(compute_entropy_reward(model, particle, t_particle))
    # Remove the trailing dimension left by the batch-of-one results.
    return torch.stack(entropies).squeeze(-1), torch.stack(rewards).squeeze(-1)


def _fk_advance_particle(model, particle, t, dt):
    """Advance a single particle by one reverse step with the model's sampler."""
    # 'ddpm_cache' deliberately uses the plain (uncached) DDPM update here.
    if model.sampler in ('ddpm', 'ddpm_cache'):
        return model._ddpm_update(particle, t, dt)
    return model._analytic_update(particle, t, dt)


def _fk_denoise_particle(model, particle, t):
    """Apply the final noise-removal step to a single particle."""
    if model.sampler == 'analytic':
        return model._denoiser_update(particle, t)
    unet_conditioning = model.noise(t)[0]
    return model.forward(particle, unet_conditioning).argmax(dim=-1)


@torch.no_grad()
def fk_steering_mdlm_with_seeds_gpu(
    model,
    num_particles: int,
    num_steps: int,
    resample_interval: int,
    lambda_weight: float,
    potential_type: str,
    shared_seeds: List[int],
    eps: float = 1e-5,
    memory_efficient: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
    """
    FK steering algorithm using shared seeds for initialization.
    Optimized for single GPU execution.

    Particles are initialised one per seed, advanced step by step, and every
    ``resample_interval`` steps (except after the last step) resampled with
    probabilities given by a softmax over the potential scores of their
    rewards. Runs entirely under ``torch.no_grad()``.

    Args:
        model: diffusion model exposing the sampling helpers used below.
        num_particles: number of particles processed at every step.
        num_steps: number of denoising steps.
        resample_interval: resample after every this many steps.
        lambda_weight: forwarded to ``compute_potential_scores``.
        potential_type: forwarded to ``compute_potential_scores``.
        shared_seeds: one RNG seed per initial particle.
        eps: final time of the grid, which runs from 1 down to ``eps``.
        memory_efficient: accepted for interface compatibility; not used.

    Returns:
        ``(best_sample, particles, detailed_logs)`` where ``best_sample`` is
        the particle with the highest final reward (kept as a batch of one).
    """
    device = model.device
    seq_len = model.config.model.length

    # Initialize particles using shared seeds
    initial_particles = []
    for seed in shared_seeds:
        set_seeds_for_sampling(seed)
        initial_particles.append(model._sample_prior(1, seq_len).to(device))
    particles = torch.cat(initial_particles, dim=0)

    # Initialize detailed tracking
    detailed_logs = {
        'particle_entropies': np.zeros((num_steps + 1, num_particles)),
        'particle_rewards': np.zeros((num_steps + 1, num_particles)),
        'particle_ids': np.arange(num_particles)[None, :].repeat(num_steps + 1, axis=0),
        'resampling_steps': [],
        'resampling_info': [],
        'timesteps': [],
        'particle_lineage': [list(range(num_particles))],
    }

    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = (1 - eps) / num_steps

    # FK steering algorithm
    for step in range(num_steps):
        t_val = timesteps[step]
        t = t_val * torch.ones(num_particles, 1, device=device)

        detailed_logs['timesteps'].append(t_val.item())

        # Entropy and reward of the current state
        current_entropies, current_rewards = _fk_score_particles(
            model, particles, t, num_particles)
        detailed_logs['particle_entropies'][step] = current_entropies.detach().cpu().numpy()
        detailed_logs['particle_rewards'][step] = current_rewards.detach().cpu().numpy()

        # Generate next state for all particles
        particles = torch.cat(
            [
                _fk_advance_particle(model, particles[i:i + 1], t[i:i + 1], dt)
                for i in range(num_particles)
            ],
            dim=0,
        )

        # Resampling logic
        on_resample_boundary = (step + 1) % resample_interval == 0
        if on_resample_boundary and step < num_steps - 1:
            t_resample = timesteps[step + 1] * torch.ones(num_particles, 1, device=device)

            resample_entropies, resample_rewards = _fk_score_particles(
                model, particles, t_resample, num_particles)

            potential_scores = compute_potential_scores(
                resample_rewards, lambda_weight, potential_type
            )

            pre_resample_entropies = resample_entropies.detach().cpu().numpy()
            pre_resample_rewards = resample_rewards.detach().cpu().numpy()

            probabilities = F.softmax(potential_scores, dim=0)
            selected_indices = torch.multinomial(probabilities, num_particles, replacement=True)
            particles = particles[selected_indices]

            current_lineage = detailed_logs['particle_lineage'][-1]
            new_lineage = [current_lineage[idx.item()] for idx in selected_indices]
            detailed_logs['particle_lineage'].append(new_lineage)

            detailed_logs['resampling_steps'].append(step + 1)
            detailed_logs['resampling_info'].append({
                'selected_indices': selected_indices.cpu().numpy(),
                'probabilities': probabilities.detach().cpu().numpy(),
                'pre_resample_entropies': pre_resample_entropies,
                'pre_resample_rewards': pre_resample_rewards,
                'particle_lineage': new_lineage.copy()
            })
        elif step == 0 or not on_resample_boundary:
            # NOTE(review): when the last step falls on a resample boundary
            # (and is not step 0) no lineage entry is appended, so
            # 'particle_lineage' can hold fewer than num_steps + 1 entries.
            # Kept as in the original -- confirm this is intended.
            detailed_logs['particle_lineage'].append(detailed_logs['particle_lineage'][-1].copy())

    detailed_logs['timesteps'].append(eps)

    # Final entropy / reward logging happens in both modes; only the way the
    # final time tensor is built differs.
    noise_removal = model.config.sampling.noise_removal
    if noise_removal:
        t_final = timesteps[-1] * torch.ones(particles.shape[0], 1, device=device)
    else:
        t_final = eps * torch.ones(particles.shape[0], 1, device=device)

    final_entropies, final_rewards = _fk_score_particles(
        model, particles, t_final, num_particles)
    detailed_logs['particle_entropies'][num_steps] = final_entropies.detach().cpu().numpy()
    detailed_logs['particle_rewards'][num_steps] = final_rewards.detach().cpu().numpy()

    # Final denoising step
    if noise_removal:
        particles = torch.cat(
            [
                _fk_denoise_particle(model, particles[i:i + 1], t_final[i:i + 1])
                for i in range(num_particles)
            ],
            dim=0,
        )

    # Select best sample based on final entropy.
    # NOTE(review): with noise removal enabled the particles scored here have
    # already been denoised -- confirm the reward is still discriminative.
    t_final = eps * torch.ones(particles.shape[0], 1, device=device)
    final_rewards_for_selection = [
        compute_entropy_reward(model, particles[i:i + 1], t_final[i:i + 1])
        for i in range(num_particles)
    ]

    final_rewards = torch.stack(final_rewards_for_selection).squeeze(-1)  # Remove last dimension
    best_idx = torch.argmax(final_rewards)
    best_sample = particles[best_idx:best_idx+1]

    # Add summary statistics to logs
    detailed_logs['final_rewards'] = final_rewards.detach().cpu().numpy()
    detailed_logs['best_particle_idx'] = best_idx.item()
    detailed_logs['num_particles'] = num_particles
    detailed_logs['num_steps'] = num_steps
    detailed_logs['resample_interval'] = resample_interval
    detailed_logs['lambda_weight'] = lambda_weight

    return best_sample, particles, detailed_logs


def compute_single_generative_perplexity(
    model, 
    text_sample: str, 
    eval_model_name: str = "gpt2-large"
) -> float:
    """
    Score one text sample with an external causal LM and return its perplexity.

    The evaluation model and tokenizer are loaded from ``eval_model_name`` on
    every call and placed on ``model.device``. The text is truncated to
    ``model.config.model.length`` tokens.

    Returns:
        ``exp(mean NLL)`` over the counted target tokens, or ``inf`` when no
        target token is counted.
    """
    # Load evaluation model and tokenizer
    scorer = AutoModelForCausalLM.from_pretrained(eval_model_name)
    scorer = scorer.eval().to(model.device)
    scorer_tokenizer = AutoTokenizer.from_pretrained(eval_model_name)

    # Fall back to EOS as the padding token when none is defined.
    if scorer_tokenizer.pad_token is None:
        scorer_tokenizer.pad_token = scorer_tokenizer.eos_token
        scorer_tokenizer.pad_token_id = scorer_tokenizer.eos_token_id

    # Tokenize the sample
    encoded = scorer_tokenizer(
        text_sample,
        return_tensors='pt',
        return_attention_mask=True,
        truncation=True,
        padding=True,
        max_length=model.config.model.length
    )
    token_ids = encoded['input_ids'].to(model.device)
    attn_mask = encoded['attention_mask'].to(model.device)
    eos_id = scorer_tokenizer.eos_token_id

    with torch.no_grad():
        # Move the vocabulary axis to dim 1, as F.cross_entropy expects.
        logits = scorer(token_ids, attention_mask=attn_mask)[0].transpose(-1, -2)

        # Next-token negative log likelihood: position i predicts token i + 1.
        token_nll = F.cross_entropy(
            logits[..., :-1],
            token_ids[..., 1:],
            reduction='none'
        )

        # Count every non-EOS target, plus targets at positions where exactly
        # one EOS has been seen so far.
        is_eos = token_ids == eos_id
        one_eos_seen = is_eos.cumsum(-1) == 1
        counted = one_eos_seen[..., 1:] | ~is_eos[..., 1:]

        nll_sum = (token_nll * counted).sum()
        n_counted = counted.sum()

        if n_counted > 0:
            return torch.exp(nll_sum / n_counted).item()
        return float('inf')


def gpu_worker_best_of_n(
    gpu_id: int,
    config: omegaconf.DictConfig,
    tokenizer,
    run_assignments: List[Tuple[int, List[int]]],
    results_queue: mp.Queue
):
    """
    Process entry point: Best-of-N sampling for a batch of runs on one GPU.

    For every assigned run the best sample (lowest U_denoise) and one
    randomly chosen sample are decoded, scored for generative perplexity and
    pushed to ``results_queue`` as a single dict. Any exception is printed
    with its traceback and swallowed, so runs that did not finish produce no
    queue entry.

    Args:
        gpu_id: GPU device ID
        config: experiment configuration
        tokenizer: tokenizer object
        run_assignments: list of (run_id, shared_seeds) tuples assigned to this GPU
        results_queue: queue to store results
    """
    try:
        # Bind to the GPU, clear leftovers, then load the model once.
        torch.cuda.set_device(gpu_id)
        aggressive_memory_cleanup(gpu_id)
        model = load_model_on_gpu(config, tokenizer, gpu_id)

        print(f"GPU {gpu_id}: Loaded model, processing {len(run_assignments)} runs")

        for run_id, shared_seeds in run_assignments:
            print(f"GPU {gpu_id}: Processing Best-of-N run {run_id}")

            sampling_output = best_of_n_sampling_gpu(
                model=model,
                num_particles=config.fk_steering.num_particles,
                num_steps=config.sampling.steps,
                shared_seeds=shared_seeds
            )
            bestn_sample, bestn_u_denoise, candidate_scores, candidates = sampling_output

            bestn_text = model.tokenizer.decode(bestn_sample[0])

            # Random baseline: one candidate picked uniformly, seeded with the
            # run's first seed so the choice is repeatable.
            random.seed(shared_seeds[0])
            random_idx = random.randint(0, len(candidates) - 1)
            random_text = model.tokenizer.decode(candidates[random_idx][0])
            random_u_denoise = candidate_scores[random_idx]

            # Generative perplexity of both selections.
            ppl_model_name = config.eval.gen_ppl_eval_model_name_or_path
            bestn_ppl = compute_single_generative_perplexity(
                model, bestn_text, ppl_model_name)
            random_ppl = compute_single_generative_perplexity(
                model, random_text, ppl_model_name)

            results_queue.put({
                'run_id': run_id,
                'gpu_id': gpu_id,
                'method': 'best_of_n',
                'sample': bestn_text,
                'U_denoise': bestn_u_denoise,
                'perplexity': bestn_ppl,
                'all_u_denoise': candidate_scores,
                'shared_seeds': shared_seeds,
                # Random inference results
                'random_sample': random_text,
                'random_U_denoise': random_u_denoise,
                'random_perplexity': random_ppl,
                'random_selected_idx': random_idx
            })
            print(f"GPU {gpu_id}: Completed Best-of-N run {run_id}")
            print(f"  Best-of-N: U_denoise: {bestn_u_denoise:.4f}, PPL: {bestn_ppl:.4f}")
            print(f"  Random (idx {random_idx}): U_denoise: {random_u_denoise:.4f}, PPL: {random_ppl:.4f}")

            # Free memory between runs.
            aggressive_memory_cleanup(gpu_id)

        # Drop the model before the final cleanup.
        del model
        aggressive_memory_cleanup(gpu_id)
        print(f"GPU {gpu_id}: Completed all Best-of-N runs, cleaned up memory")

    except Exception as e:
        print(f"GPU {gpu_id} Best-of-N Error: {e}")
        import traceback
        traceback.print_exc()
        # Clean up on error
        aggressive_memory_cleanup(gpu_id)


def gpu_worker_fk_steering(
    gpu_id: int,
    config: omegaconf.DictConfig,
    tokenizer,
    run_assignments: List[Tuple[int, List[int]]],
    results_queue: mp.Queue
):
    """
    Process entry point: FK steering for a batch of runs on one GPU.

    For every assigned run the steered best sample is decoded, its U_denoise
    is computed from the logged entropies of the best particle index, its
    generative perplexity is scored, and one result dict is pushed to
    ``results_queue``. Any exception is printed with its traceback and
    swallowed, so runs that did not finish produce no queue entry.

    Args:
        gpu_id: GPU device ID
        config: experiment configuration
        tokenizer: tokenizer object
        run_assignments: list of (run_id, shared_seeds) tuples assigned to this GPU
        results_queue: queue to store results
    """
    try:
        torch.cuda.set_device(gpu_id)

        # Clean up, pause briefly, and clean up again before loading the model.
        aggressive_memory_cleanup(gpu_id)
        import time
        time.sleep(2)
        aggressive_memory_cleanup(gpu_id)

        model = load_model_on_gpu(config, tokenizer, gpu_id)

        print(f"GPU {gpu_id}: Loaded model for FK steering, processing {len(run_assignments)} runs")

        for run_id, shared_seeds in run_assignments:
            print(f"GPU {gpu_id}: Processing FK steering run {run_id}")

            aggressive_memory_cleanup(gpu_id)

            steering_output = fk_steering_mdlm_with_seeds_gpu(
                model=model,
                num_particles=config.fk_steering.num_particles,
                num_steps=config.sampling.steps,
                resample_interval=config.fk_steering.resample_interval,
                lambda_weight=config.fk_steering.lambda_weight,
                potential_type=config.fk_steering.potential_type,
                shared_seeds=shared_seeds
            )
            fkd_sample, all_fkd_samples, fkd_detailed_logs = steering_output

            # U_denoise: absolute trapezoidal integral of the entropy column
            # belonging to the best particle index, over the logged time grid.
            time_grid = np.array(fkd_detailed_logs['timesteps'])
            best_column = fkd_detailed_logs['best_particle_idx']
            best_entropies = fkd_detailed_logs['particle_entropies'][:, best_column]
            fkd_u_denoise = abs(np.trapz(best_entropies, x=time_grid))

            fkd_text = model.tokenizer.decode(fkd_sample[0])

            fkd_ppl = compute_single_generative_perplexity(
                model, fkd_text, config.eval.gen_ppl_eval_model_name_or_path
            )

            results_queue.put({
                'run_id': run_id,
                'gpu_id': gpu_id,
                'method': 'fk_steering',
                'sample': fkd_text,
                'U_denoise': fkd_u_denoise,
                'perplexity': fkd_ppl,
                'detailed_logs': fkd_detailed_logs,
                'shared_seeds': shared_seeds
            })
            print(f"GPU {gpu_id}: Completed FK steering run {run_id}, U_denoise: {fkd_u_denoise:.4f}, PPL: {fkd_ppl:.4f}")

            # Free memory between runs.
            aggressive_memory_cleanup(gpu_id)

        # Drop the model before the final cleanup.
        del model
        aggressive_memory_cleanup(gpu_id)
        print(f"GPU {gpu_id}: Completed all FK steering runs, cleaned up memory")

    except Exception as e:
        print(f"GPU {gpu_id} FK steering Error: {e}")
        import traceback
        traceback.print_exc()
        # Clean up on error
        aggressive_memory_cleanup(gpu_id)


def distribute_runs_to_gpus(total_runs: int, num_gpus: int = 8) -> Dict[int, List[int]]:
    """Split run indices ``0..total_runs-1`` into contiguous per-GPU chunks.

    Every GPU receives ``total_runs // num_gpus`` runs and the first
    ``total_runs % num_gpus`` GPUs receive one extra. GPUs that get no runs
    still appear in the mapping with an empty list.

    Returns:
        Mapping from GPU id to the list of run indices assigned to it.
    """
    base_share, leftover = divmod(total_runs, num_gpus)

    assignment = {}
    next_run = 0
    for gpu_id in range(num_gpus):
        share = base_share + 1 if gpu_id < leftover else base_share
        chunk_end = next_run + share
        assignment[gpu_id] = list(range(next_run, chunk_end))
        next_run = chunk_end

    return assignment


def _collect_worker_results(results_queue, processes, expected, poll_seconds=1.0):
    """Drain up to ``expected`` results, giving up once every worker has exited."""
    from queue import Empty  # raised by Queue.get() when the timeout elapses

    collected = []
    workers_exited = False
    while len(collected) < expected:
        try:
            collected.append(results_queue.get(timeout=poll_seconds))
        except Empty:
            if workers_exited:
                # A full timed read came up empty after every worker had
                # already exited: nothing more will arrive.
                break
            workers_exited = not any(p.is_alive() for p in processes)
    return collected


def _run_evaluation_phase(worker_fn, phase_name, config, tokenizer,
                          all_run_seeds, gpu_distribution):
    """Launch one worker process per GPU that has runs and gather their results."""
    expected = len(all_run_seeds)
    results_queue = mp.Queue()
    processes = []

    for gpu_id, run_ids in gpu_distribution.items():
        if not run_ids:
            # No work for this GPU: skip spawning a process that would only
            # load the model and exit.
            continue
        run_assignments = [(run_id, all_run_seeds[run_id]) for run_id in run_ids]

        p = mp.Process(
            target=worker_fn,
            args=(gpu_id, config, tokenizer, run_assignments, results_queue)
        )
        p.start()
        processes.append(p)

    # Drain the queue before joining: a worker cannot exit while its queued
    # results are still unread.
    results = _collect_worker_results(results_queue, processes, expected)

    # Wait for all processes to complete
    for p in processes:
        p.join()

    if len(results) < expected:
        raise RuntimeError(
            f"{phase_name}: only {len(results)} of {expected} runs returned a "
            f"result before all worker processes exited; see the worker "
            f"error output above"
        )
    return results


def run_parallel_evaluation(
    config: omegaconf.DictConfig,
    tokenizer,
    all_run_seeds: List[List[int]],
    num_gpus: int = 8
) -> Tuple[List[dict], List[dict]]:
    """
    Run parallel evaluation for both Best-of-N and FK steering.

    The two methods run as consecutive phases. In each phase one worker
    process is started per GPU that has runs assigned, and results are read
    from a queue until every run has reported or every worker has exited.

    Args:
        config: experiment configuration passed to the workers.
        tokenizer: tokenizer object passed to the workers.
        all_run_seeds: per-run list of particle seeds; its length is the
            number of runs.
        num_gpus: number of GPUs to distribute the runs over.

    Returns:
        bestn_results: list of Best-of-N results
        fkd_results: list of FK steering results
        Both lists are sorted by ``run_id``.

    Raises:
        RuntimeError: if a phase ends with fewer results than runs (workers
            print and swallow their own exceptions, so a failed worker would
            otherwise leave this function waiting forever).
    """
    total_runs = len(all_run_seeds)

    # Distribute runs to GPUs
    gpu_distribution = distribute_runs_to_gpus(total_runs, num_gpus)

    print(f"Distributing {total_runs} runs across {num_gpus} GPUs:")
    for gpu_id, run_ids in gpu_distribution.items():
        print(f"  GPU {gpu_id}: runs {run_ids}")

    # Phase 1: Best-of-N evaluation
    print(f"\n=== Phase 1: Best-of-N Evaluation ===")
    start_time = time.time()
    bestn_results = _run_evaluation_phase(
        gpu_worker_best_of_n, "Best-of-N", config, tokenizer,
        all_run_seeds, gpu_distribution)
    bestn_time = time.time() - start_time
    print(f"Best-of-N evaluation completed in {bestn_time:.2f} seconds")

    # Wait additional time to ensure all GPU memory is freed
    print("Waiting 5 seconds for GPU memory cleanup...")
    time.sleep(5)

    # Phase 2: FK steering evaluation
    print(f"\n=== Phase 2: FK Steering Evaluation ===")
    start_time = time.time()
    fkd_results = _run_evaluation_phase(
        gpu_worker_fk_steering, "FK steering", config, tokenizer,
        all_run_seeds, gpu_distribution)
    fkd_time = time.time() - start_time
    print(f"FK steering evaluation completed in {fkd_time:.2f} seconds")
    print(f"Total evaluation time: {bestn_time + fkd_time:.2f} seconds")

    # Sort results by run_id
    bestn_results.sort(key=lambda x: x['run_id'])
    fkd_results.sort(key=lambda x: x['run_id'])

    return bestn_results, fkd_results


def combine_results(bestn_results: List[dict], fkd_results: List[dict]) -> List[dict]:
    """Merge Best-of-N and FK steering results into one record per run.

    Records are produced in the order of ``bestn_results``; the matching FK
    result is looked up by ``run_id`` (a missing one raises ``KeyError``).
    """
    # Index FK results by run id for direct lookup.
    fkd_by_run = {}
    for fkd_entry in fkd_results:
        fkd_by_run[fkd_entry['run_id']] = fkd_entry

    def _merge(bestn):
        run_id = bestn['run_id']
        fkd = fkd_by_run[run_id]
        return {
            'run_id': run_id,
            'shared_seeds': bestn['shared_seeds'],
            'bestn_sample': bestn['sample'],
            'fkd_sample': fkd['sample'],
            'random_sample': bestn['random_sample'],
            'bestn_U_denoise': bestn['U_denoise'],
            'fkd_U_denoise': fkd['U_denoise'],
            'random_U_denoise': bestn['random_U_denoise'],
            'bestn_ppl': bestn['perplexity'],
            'fkd_ppl': fkd['perplexity'],
            'random_ppl': bestn['random_perplexity'],
            'random_selected_idx': bestn['random_selected_idx'],
            'all_bestn_u_denoise': bestn['all_u_denoise'],
            'fkd_detailed_logs': fkd['detailed_logs']
        }

    return [_merge(bestn) for bestn in bestn_results]


def save_results_to_csv(all_results: List[dict], output_dir: str):
    """Write the comparison table as CSV and pickle the full result records.

    Two timestamped files are created in ``output_dir`` (created if needed):
    ``fk_vs_bestn_comparison_<timestamp>.csv`` with one row per run, and
    ``detailed_logs_<timestamp>.pkl`` holding ``all_results`` unchanged.
    Numeric columns are formatted with six decimals. Each ``*_vs_*`` column
    is ``baseline - method``, so a positive value means the method scored
    lower than the baseline.
    """
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_file = os.path.join(output_dir, f"fk_vs_bestn_comparison_{timestamp}.csv")

    # Columns copied from the record and formatted as fixed-point numbers.
    metric_columns = [
        'bestn_U_denoise', 'fkd_U_denoise', 'random_U_denoise',
        'bestn_ppl', 'fkd_ppl', 'random_ppl',
    ]
    # (output column, baseline key, method key): value is baseline - method.
    delta_columns = [
        ('bestn_vs_random_u_denoise', 'random_U_denoise', 'bestn_U_denoise'),
        ('fkd_vs_random_u_denoise', 'random_U_denoise', 'fkd_U_denoise'),
        ('bestn_vs_random_ppl', 'random_ppl', 'bestn_ppl'),
        ('fkd_vs_random_ppl', 'random_ppl', 'fkd_ppl'),
        ('fkd_vs_bestn_u_denoise', 'bestn_U_denoise', 'fkd_U_denoise'),
        ('fkd_vs_bestn_ppl', 'bestn_ppl', 'fkd_ppl'),
    ]
    fieldnames = [
        'run_id', 'shared_seeds', 'random_selected_idx',
        'bestn_sample', 'fkd_sample', 'random_sample',
        *metric_columns,
        *(column for column, _, _ in delta_columns),
    ]

    def _as_row(result):
        row = {
            'run_id': result['run_id'],
            'shared_seeds': str(result['shared_seeds']),
            'random_selected_idx': result['random_selected_idx'],
            'bestn_sample': result['bestn_sample'],
            'fkd_sample': result['fkd_sample'],
            'random_sample': result['random_sample'],
        }
        for column in metric_columns:
            row[column] = f"{result[column]:.6f}"
        for column, baseline_key, method_key in delta_columns:
            row[column] = f"{result[baseline_key] - result[method_key]:.6f}"
        return row

    with open(csv_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for result in all_results:
            writer.writerow(_as_row(result))

    print(f"\nResults saved to: {csv_file}")

    # Also save detailed logs
    logs_file = os.path.join(output_dir, f"detailed_logs_{timestamp}.pkl")
    with open(logs_file, 'wb') as f:
        pickle.dump(all_results, f)
    print(f"Detailed logs saved to: {logs_file}")


def print_summary_statistics(all_results: List[dict]):
    """Print summary statistics of the three-way comparison.

    For each metric (U_denoise and generative perplexity) this prints the
    mean and standard deviation per method, the paired per-run differences
    between methods, and how many runs each comparison "won". A final
    section totals the wins of Best-of-n and FK steering against the random
    baseline and names the method with more wins.

    Args:
        all_results: One dict per run. Each dict must provide the keys
            ``bestn_U_denoise``, ``fkd_U_denoise``, ``random_U_denoise``,
            ``bestn_ppl``, ``fkd_ppl`` and ``random_ppl`` with numeric values.

    Raises:
        KeyError: If a run dict lacks one of the keys listed above.
    """
    num_runs = len(all_results)
    print(f"\n=== Summary Statistics (N={num_runs}) ===")

    if not all_results:
        # np.mean / np.std on an empty list emit RuntimeWarnings and yield
        # nan, which would print "nan ± nan" rows and a meaningless "Tie".
        print("No results to summarize.")
        return

    # Extract every series up front so that a malformed result dict fails
    # before any section has been printed.
    u_denoise = _collect_metric(all_results, 'U_denoise')
    ppl = _collect_metric(all_results, 'ppl')

    separator = '=' * 60

    print(f"\n{separator}")
    _print_metric_section("U_denoise (lower is better):", u_denoise, num_runs)

    print(f"\n{separator}")
    _print_metric_section("Generative Perplexity (lower is better):", ppl, num_runs)

    print(f"\n{separator}")
    print("Summary:")
    print("  Methods ranked by performance (best to worst):")

    # Overall ranking: wins against the random baseline, summed over both
    # metrics, hence the 2 * num_runs denominator.
    bestn_total_wins = (_count_wins(u_denoise['bestn_vs_random']) +
                        _count_wins(ppl['bestn_vs_random']))
    fkd_total_wins = (_count_wins(u_denoise['fkd_vs_random']) +
                      _count_wins(ppl['fkd_vs_random']))

    print(f"  Best-of-n total wins vs Random: {bestn_total_wins}/{2 * num_runs}")
    print(f"  FK steering total wins vs Random: {fkd_total_wins}/{2 * num_runs}")

    if fkd_total_wins > bestn_total_wins:
        print("  Overall best: FK steering")
    elif bestn_total_wins > fkd_total_wins:
        print("  Overall best: Best-of-n")
    else:
        print("  Overall: Tie between FK steering and Best-of-n")


def _count_wins(deltas: List[float]) -> int:
    """Return how many paired differences are strictly positive (ties are not wins)."""
    return sum(1 for delta in deltas if delta > 0)


def _collect_metric(all_results: List[dict], metric: str) -> Dict[str, List[float]]:
    """Gather one metric's per-method values and paired differences across runs.

    Differences are oriented as ``baseline - method`` so that a positive
    value means the method scored lower (better) than its baseline.
    """
    bestn = [r[f'bestn_{metric}'] for r in all_results]
    fkd = [r[f'fkd_{metric}'] for r in all_results]
    random_baseline = [r[f'random_{metric}'] for r in all_results]
    return {
        'bestn': bestn,
        'fkd': fkd,
        'random': random_baseline,
        'bestn_vs_random': [r - b for r, b in zip(random_baseline, bestn)],
        'fkd_vs_random': [r - f for r, f in zip(random_baseline, fkd)],
        'fkd_vs_bestn': [b - f for b, f in zip(bestn, fkd)],
    }


def _print_metric_section(title: str, series: Dict[str, List[float]], num_runs: int) -> None:
    """Print the per-method and paired-difference lines for a single metric."""
    print(title)
    print(f"  Random:       {np.mean(series['random']):.4f} ± {np.std(series['random']):.4f}")
    print(f"  Best-of-n:    {np.mean(series['bestn']):.4f} ± {np.std(series['bestn']):.4f}")
    print(f"  FK steering:  {np.mean(series['fkd']):.4f} ± {np.std(series['fkd']):.4f}")

    # Labels carry their own trailing padding so the numbers line up.
    comparisons = (
        ("  Best-of-n vs Random: ", series['bestn_vs_random']),
        ("  FK vs Random:        ", series['fkd_vs_random']),
        ("  FK vs Best-of-n:     ", series['fkd_vs_bestn']),
    )
    for label, deltas in comparisons:
        print(f"{label}{np.mean(deltas):.4f} ± {np.std(deltas):.4f} "
              f"(wins: {_count_wins(deltas)}/{num_runs})")


@hydra.main(version_base=None, config_path='configs', config_name='config_fk')
def main(config):
    """Main entry point for optimized FK steering vs Best-of-n comparison.

    Steps, in order: force the 'spawn' multiprocessing start method, make
    sure an ``fk_steering`` config section exists (installing defaults when
    it does not), validate its parameters, generate per-run seeds, run the
    parallel evaluation, then combine, save and summarize the results.

    Args:
        config: Configuration object supplied by the ``hydra.main``
            decorator. Fields read here: ``fk_steering``,
            ``sampling.num_sample_batches``, ``sampling.steps``,
            ``eval.gen_ppl_eval_model_name_or_path`` and ``seed``.

    Raises:
        ValueError: If ``num_particles`` or ``resample_interval`` is below 1,
            if ``lambda_weight`` is not positive, or if
            ``sampling.num_sample_batches`` is not a multiple of the GPU count.
    """
    # Number of GPUs the evaluation is spread over. Kept in one place so the
    # divisibility check, the banner and the launch call cannot drift apart.
    num_gpus = 8

    # Set multiprocessing start method
    mp.set_start_method('spawn', force=True)

    # Ensure FK steering configuration exists
    if not hasattr(config, 'fk_steering') or config.fk_steering is None:
        print("Warning: FK steering configuration not found. Using default values.")
        # Hydra hands over its composed config in struct mode, where assigning
        # a key that is not already present raises. open_dict lifts that
        # restriction for the duration of the block.
        with omegaconf.open_dict(config):
            config.fk_steering = omegaconf.DictConfig({
                'num_particles': 8,
                'resample_interval': 50,
                'lambda_weight': 5.0,
                'potential_type': 'max'
            })

    # Validate FK steering parameters
    fk_config = config.fk_steering
    if fk_config.num_particles < 1:
        raise ValueError(f"num_particles must be >= 1, got {fk_config.num_particles}")
    if fk_config.resample_interval < 1:
        raise ValueError(f"resample_interval must be >= 1, got {fk_config.resample_interval}")
    if fk_config.lambda_weight <= 0:
        raise ValueError(f"lambda_weight must be > 0, got {fk_config.lambda_weight}")

    # Set number of comparison runs
    num_comparison_runs = config.sampling.num_sample_batches

    # Validate that num_comparison_runs is a multiple of the GPU count
    if num_comparison_runs % num_gpus != 0:
        raise ValueError(
            f"num_sample_batches must be a multiple of {num_gpus}, got {num_comparison_runs}")

    print("=== Three-way Comparison: FK Steering vs Best-of-n vs Random Inference ===")
    print("Configuration:")
    print(f"  num_particles: {fk_config.num_particles}")
    print(f"  resample_interval: {fk_config.resample_interval}")
    print(f"  lambda_weight: {fk_config.lambda_weight}")
    print(f"  potential_type: {fk_config.potential_type}")
    print(f"  num_steps: {config.sampling.steps}")
    print(f"  num_comparison_runs: {num_comparison_runs}")
    print(f"  eval_model: {config.eval.gen_ppl_eval_model_name_or_path}")
    print(f"  num_gpus: {num_gpus}")
    print("  Methods: FK Steering, Best-of-N, Random Inference")
    print("  Metrics: U_denoise, Perplexity")

    L.seed_everything(config.seed)

    # NOTE(review): `logger` is not used again in this function; the call is
    # kept in case utils.get_logger has set-up side effects. Confirm.
    logger = utils.get_logger(__name__)
    tokenizer = dataloader.get_tokenizer(config)

    # Generate global seed mapping for all runs
    print(f"\nGenerating seed mapping for {num_comparison_runs} runs...")
    all_run_seeds = generate_global_seeds(num_comparison_runs, config.seed, fk_config.num_particles)

    print("Sample seed mapping (first 3 runs):")
    for i in range(min(3, len(all_run_seeds))):
        print(f"  Run {i}: {all_run_seeds[i]}")

    # Run parallel evaluation
    start_time = time.time()
    bestn_results, fkd_results = run_parallel_evaluation(
        config, tokenizer, all_run_seeds, num_gpus=num_gpus
    )
    total_time = time.time() - start_time

    # Combine results
    combined_results = combine_results(bestn_results, fkd_results)

    # Save and analyze results
    if combined_results:
        output_dir = "comparison_results"
        save_results_to_csv(combined_results, output_dir)
        print_summary_statistics(combined_results)

        print("\n=== Three-way Comparison Complete ===")
        print(f"Successfully completed {len(combined_results)} runs")
        print(f"Total execution time: {total_time:.2f} seconds")
        print(f"Average time per run: {total_time / len(combined_results):.2f} seconds")

        # Performance comparison. The 120 seconds per run is a hard-coded
        # rough estimate of sequential cost, not a measurement.
        estimated_sequential_time = len(combined_results) * 120  # Rough estimate
        speedup = estimated_sequential_time / total_time
        print(f"Estimated speedup: {speedup:.1f}x")
    else:
        print("No successful runs completed!")


# Script entry point: run the comparison only when this file is executed
# directly. main() is called without arguments because the hydra.main
# decorator on it supplies the config.
if __name__ == '__main__':
    main() 
