# %%

import json
import random
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
from transformer_lens import HookedTransformer

# Import multiprocessing but don't set start method here
import multiprocessing as mp

# Reimport everything from system_utils to ensure we have the latest version
# import eliciting_contexts.fluent_dreaming.system_prompt_experiments.system_utils as system_utils
# reload(system_utils)
from eliciting_contexts.fluent_dreaming.system_prompt_experiments.system_utils import (
    format_chat,
    get_model_response,
    test_first_token_responses,
)

# %%


# %%
# Module-level run settings: target device string, weight dtype name and the
# identifier of the model to load. The loader calls further down in this file
# take arguments with the same names.
device = "cuda"
dtype = "bfloat16"
model_name = "google/gemma-2-9b-it"

# Remove initial model loading since each process will load its own model
# model = HookedTransformer.from_pretrained(model_name, dtype=dtype, device=device)
# tokenizer = model.tokenizer

# %%
from custom_dreamy.callbacks import ParetoCallback
from custom_dreamy.epo import epo

# import custom_dreamy.runners as runners
# reload(runners)
from custom_dreamy.runners import TlensTokenDiffRunner, TlensTokenRunner

# %%

# Remove token position calculation since it will be done in each process
# token_position = tokenizer.encode("1", add_special_tokens=False)[0]

# Remove runner initialization since it will be done in each process
# runner = TlensTokenRunner(model, tokenizer, token_pos=token_position)
# direct_runner = TlensTokenDiffRunner(
#     model,
#     tokenizer,
#     token_pos_a=token_position,
#     token_pos_b=tokenizer.encode("2", add_special_tokens=False)[0],
# )


# %%
class FrontierSavingCallback(ParetoCallback):
    """ParetoCallback that additionally keeps a text snapshot of each frontier.

    Every non-final invocation first delegates to the parent callback and then
    appends a ``{"iteration": i, "text": ...}`` record to ``self.frontiers``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One record per non-final callback invocation, in call order.
        self.frontiers = []

    def __call__(self, i, state, last_runtime, history, selector, final=False):
        result = super().__call__(i, state, last_runtime, history, selector, final)

        # The final call is passed through without recording a snapshot.
        if final:
            return result

        # 200 log-spaced cross-entropy penalties, reaching one decade beyond
        # the configured [x_penalty_min, x_penalty_max] range on either side.
        log_low = np.log(self.x_penalty_min / 10.0)
        log_high = np.log(self.x_penalty_max * 10.0)
        penalties = torch.exp(torch.linspace(log_low, log_high, 200)).to(
            state.target.device
        )

        # For each penalty, the population member with the lowest combined loss.
        combined = -state.target[None] + penalties[:, None] * state.xentropy[None]
        winners = combined.argmin(dim=1)

        pieces = [f"\nbeginning step {i}, current pareto frontier prompts:\n"]
        previous = None
        for penalty, winner in zip(penalties, winners):
            # Consecutive penalties that select the same member are listed once.
            if winner == previous:
                continue

            if self.fixed_positions is None:
                text = self.tokenizer.decode(state.ids[winner])
            else:
                # Show only the positions that are not marked as fixed.
                self.fixed_positions = self.fixed_positions.to(state.ids.device)
                text = self.tokenizer.decode(state.ids[winner, ~self.fixed_positions])

            last_token = self.tokenizer.decode(state.final_token[winner])
            pieces.append(
                f"penalty={penalty:.2f} xentropy={state.xentropy[winner]:.2f} target={state.target[winner]:.2f} {repr(text + '[' + last_token + ']')}\n"
            )
            previous = winner

        self.frontiers.append({"iteration": i, "text": "".join(pieces)})
        return result


def run_epo_on_system_prompt(
    model,
    tokenizer,
    system_message,
    runner,
    user_message="This is some dummy placeholder text for epo to run on.",
    token_pos=None,
    iters=1,
    population_size=8,
    explore_per_pop=16,
    restart_frequency=None,
    callbacks=None,
    device="cuda",
    verbose=False,
    x_penalty_min=0.1,
    x_penalty_max=10.0,
):
    """
    Run EPO optimization on a given system prompt and user message.

    The chat is formatted with ``format_chat``; every position whose token
    type is not ``"user"`` is marked as fixed, so only the user message is
    left open for optimization. A ``FrontierSavingCallback`` is always added
    to the callbacks so the Pareto frontiers can be returned.

    Args:
        model: The model to run EPO on
        tokenizer: The tokenizer for the model
        system_message: System prompt to use
        runner: Runner object handed to ``epo`` and to the frontier callback
        user_message: User message to optimize (default: placeholder text)
        token_pos: Currently unused by this function
        iters: Number of EPO iterations
        population_size: Population size for EPO
        explore_per_pop: Explore parameter for EPO
        restart_frequency: How often to restart EPO (None = no restart)
        callbacks: Optional list of callback functions for EPO. The list is
            copied, so the caller's list is never modified.
        device: Device to run on. For ``"cuda"``/``"cuda:N"`` strings the
            current CUDA device is set and the model is moved if needed.
        verbose: Currently unused; frontiers are always collected
        x_penalty_min: Lower penalty bound passed to the frontier callback
        x_penalty_max: Upper penalty bound passed to the frontier callback

    Returns:
        history: The value returned by ``epo``
        hyperparams: Hyperparameters used for the optimization
        pareto_frontiers: List of frontier snapshots recorded by the callback
    """
    # Set device properly - make sure it's a valid device
    if isinstance(device, str) and device.startswith("cuda"):
        device_parts = device.split(":")
        if len(device_parts) > 1:
            device_idx = int(device_parts[1])
            if device_idx >= torch.cuda.device_count():
                print(f"Warning: Requested device {device} out of range, using cuda:0 instead")
                device = "cuda:0"
                device_idx = 0
        else:
            device_idx = 0

        # Set current device for operations
        torch.cuda.set_device(device_idx)
        print(f"Set current CUDA device to {device} (index {device_idx})")

        # Compare device objects rather than strings: a bare "cuda" request
        # resolves to index 0 here, and the string "cuda" never equals
        # "cuda:0" even though both name the same device.
        model_device = next(model.parameters()).device
        if model_device != torch.device("cuda", device_idx):
            print(f"Warning: Model is on {model_device}, but requested device is {device}. Moving model to {device}")
            model = model.to(device)
            model_device = next(model.parameters()).device
            print(f"Model is now on device {model_device}")

    # Format the chat and prepare inputs
    _, input_ids, token_type_map = format_chat(tokenizer, system_message, user_message)
    initial_ids = torch.tensor(input_ids).to(device)

    # Fixed-position mask: everything except the user tokens is held fixed
    fixed_positions = [token_type != "user" for token_type in token_type_map]

    # One copy of the prompt per population member
    initial_ids = initial_ids.unsqueeze(0).repeat(population_size, 1)
    seq_len = initial_ids.shape[-1]

    # Work on a private copy so the caller's list does not accumulate a new
    # frontier callback on every call.
    callbacks = [] if callbacks is None else list(callbacks)

    pareto_cb = FrontierSavingCallback(
        runner,
        model,
        tokenizer,
        x_penalty_min,
        x_penalty_max,
        fixed_positions,
    )
    callbacks.append(pareto_cb)

    # Run EPO optimization
    history = epo(
        runner,
        model,
        tokenizer,
        iters=iters,
        initial_ids=initial_ids,
        fixed_positions=fixed_positions,
        population_size=population_size,
        seq_len=seq_len,
        explore_per_pop=explore_per_pop,
        restart_frequency=restart_frequency,
        callbacks=callbacks,
        batch_size=256,
        device=device,
    )
    hyperparams = {
        "model_name": model.cfg.model_name,
        "user_message": user_message,
        "system_message": system_message,
        "population_size": population_size,
        "iters": iters,
        "explore_per_pop": explore_per_pop,
        "restart_frequency": restart_frequency,
    }

    # The frontier callback is always installed, so its snapshots are always available
    return history, hyperparams, pareto_cb.frontiers


# %%
# Conditional statements for EPO testing at different difficulty levels
single_token_updates = [
    "The message contains the letter 'X'",
    "The message contains the digit '7'",
    "The message contains the symbol '@'",
    "The message contains the word 'unicorn'",
    "The message contains an exclamation mark",
    "The message contains the emoji '🚀'",
    "The message contains the word 'quantum'",
    "The message contains the currency symbol '€'",
    "The message contains the word 'amazing'",
    "The message contains the chemical symbol 'Na'",
]

multiple_solution_updates = [
    "The message contains a color",
    "The message contains a fruit",
    "The message contains an animal",
    "The message contains a country name",
    "The message contains a vehicle",
    "The message contains a piece of furniture",
    "The message contains a musical instrument",
    "The message contains a weather phenomenon",
    "The message contains a profession",
    "The message contains a body part",
]

positional_requirements = [
    "The first word is a color",
    "The last word is a number",
    "The third word is a verb",
    "The message starts with the letter 'P'",
    "The message ends with a question mark",
    "The second word is a preposition",
    "The fifth word contains exactly three letters",
    "The message begins with 'Today I'",
    "The message ends with an adverb",
    "The middle word of the message is 'not'",
]

multi_token_updates = [
    "Every word in the message is capitalized",
    "The message contains exactly five words",
    "Every sentence starts with the same letter",
    "All verbs in the message are in past tense",
    "The message contains no adjectives",
    "The message has the same number of words as letters in the alphabet",
    "Every other word in the message is a number",
    "The message forms a palindrome (reads the same backward and forward)",
    "No word in the message contains the letter 'e'",
    "The message contains words in alphabetical order",
]

reasoning_required = [
    "The message describes a logically impossible scenario",
    "The message contains a valid mathematical equation that equals 42",
    "The message contains a syllogism with valid logic",
    "The words in the message form an acrostic that spells 'ALGORITHM'",
    "The message contains a subtle self-contradiction",
    "The message describes the Monty Hall problem correctly",
    "The message contains a valid chess notation sequence",
    "The message describes a recursive concept without using the word 'recursive'",
    "The message contains a true statement about prime numbers",
    "The message implicitly references the prisoner's dilemma without naming it",
]

# Test texts for EPO evaluation
user_messages = [
    "This is some dummy placeholder text for epo to run on.",
    "The current temperature is twenty-three degrees celsius.",
    "Please review the attached document.",
    "Several important updates were discussed at the meeting and we can't wait to see them.",
    "I'm so sleepy.",
]
# %%
# All condition lists, in the same order as `condition_names` below so the
# two can be zipped together.
all_condition_lists = [
    single_token_updates,
    multiple_solution_updates,
    positional_requirements,
    multi_token_updates,
    reasoning_required,
]
# Short label for each entry of `all_condition_lists` (same order).
condition_names = [
    "single_token",
    "multiple_solution",
    "positional",
    "multi_token",
    "reasoning",
]

# Define the five system message types. Each value is the opening line of a
# system prompt and ends with a newline; the bulleted conditions are appended
# to it when the experiment tasks are built. The key names how the conditions
# combine (all / any / none / exactly one / odd number).
# NOTE(review): the "xor" wording "exactly an odd number" reads awkwardly; it
# is left untouched here because the string is part of the prompt.
system_message_types = {
    "all": "You are a helpful assistant that always responds with the digit '1' when all of the following conditions are true, otherwise respond with the digit '2'.\n",
    "any": "You are a helpful assistant that always responds with the digit '1' when any of the following conditions are true, otherwise respond with the digit '2'.\n",
    "none": "You are a helpful assistant that always responds with the digit '1' when none of the following conditions are true, otherwise respond with the digit '2'.\n",
    "one": "You are a helpful assistant that always responds with the digit '1' when exactly one of the following conditions is true, otherwise respond with the digit '2'.\n",
    "xor": "You are a helpful assistant that always responds with the digit '1' when exactly an odd number of the following conditions are true, otherwise respond with the digit '2'.\n",
}


def prepare_for_json(obj):
    """Convert objects to a JSON-serializable format.

    Conversions:
        * NumPy arrays and torch tensors -> nested Python lists
        * NumPy scalars (e.g. ``np.float32``) -> plain Python scalars
        * ``Path``, classes, ``torch.dtype`` and ``torch.device`` -> ``str``
        * dicts, lists and tuples -> converted recursively (tuples become
          lists, which is how JSON would encode them anyway)

    Anything else is returned unchanged.
    """
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # NumPy scalars such as np.int64 are rejected by json.dump
        return obj.item()
    if isinstance(obj, (Path, type, torch.dtype, torch.device)):
        return str(obj)
    if isinstance(obj, dict):
        return {key: prepare_for_json(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [prepare_for_json(item) for item in obj]
    return obj


def save_single_result(result: dict, results_file: Path):
    """Append a single result to the JSON list stored in ``results_file``.

    Only the serializable fields of ``result`` are written; the ``"history"``
    entry is left out. ``"gpu_id"`` is kept when the result carries one.

    The file is read, extended and rewritten as a whole. The new content goes
    to a temporary sibling file that then replaces the target, so an
    interruption during the write does not destroy earlier results. The
    read-modify-write cycle is still not safe for several processes writing
    the same file at once.

    Args:
        result: Experiment result dict with at least the keys read below.
        results_file: Path of the JSON file to create or extend.
    """
    # Convert the result to a serializable format
    serializable_result = {
        "user_message": result["user_message"],
        "system_message_type": result["system_message_type"],
        "condition_type": result["condition_type"],
        "num_conditions": result["num_conditions"],
        "conditions": result["conditions"],
        "frontiers": result["frontiers"],  # Already serializable as we only stored text
        "type": result["type"],
        "hyperparams": {
            k: prepare_for_json(v) for k, v in result["hyperparams"].items()
        },
    }
    if "gpu_id" in result:
        serializable_result["gpu_id"] = result["gpu_id"]

    # Load existing results if file exists
    if results_file.exists():
        with open(results_file, "r", encoding="utf-8") as f:
            existing_results = json.load(f)
    else:
        existing_results = []

    existing_results.append(serializable_result)

    # Write to a temporary file first, then swap it in
    results_file.parent.mkdir(parents=True, exist_ok=True)
    temp_file = results_file.with_name(f"{results_file.name}.tmp")
    with open(temp_file, "w", encoding="utf-8") as f:
        json.dump(existing_results, f, indent=2)
    temp_file.replace(results_file)

    print(f"Saved result to {results_file}")


# Results are written below results/epo_experiments; create the folder
# (and any missing parents) up front.
results_dir = Path("results") / "epo_experiments"
results_dir.mkdir(exist_ok=True, parents=True)

# The file name embeds the start time down to the second.
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
results_file = results_dir.joinpath(f"epo_results_{timestamp}.json")

# In-memory collection of results
all_results = list()
# %%

import multiprocessing as mp
import os
from typing import Any, Dict, List, Optional, Tuple
import queue
import signal
import sys


def get_available_gpus() -> List[int]:
    """Return the indices of CUDA devices that pass a small allocation test.

    A device counts as working when its properties can be queried and a
    10x10 zero tensor can be created on it. Progress and failures are
    printed. Returns an empty list when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available")
        return []

    num_gpus = torch.cuda.device_count()
    print(f"CUDA reports {num_gpus} available devices")

    working = []
    for gpu_index in range(num_gpus):
        try:
            # Querying the properties is the first check
            props = torch.cuda.get_device_properties(gpu_index)
            total_gb = props.total_memory / 1024**3
            print(f"Found GPU {gpu_index}: {props.name} with {total_gb:.2f} GB memory")

            # A tiny allocation is the second check
            with torch.cuda.device(gpu_index):
                probe = torch.zeros((10, 10), device=f"cuda:{gpu_index}")
                del probe
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"GPU {gpu_index} failed verification: {e}")
        else:
            working.append(gpu_index)

    if working:
        print(f"Verified {len(working)} working GPUs: {working}")
    else:
        print("No working GPUs found!")

    return working


def initialize_gpu_resources(gpu_id: int, model_name: str, dtype: str) -> Tuple[HookedTransformer, Any, TlensTokenRunner, TlensTokenDiffRunner]:
    """Initialize model and runners for a specific GPU.

    Args:
        gpu_id: Index of the CUDA device to load the model on.
        model_name: Model identifier passed to ``HookedTransformer.from_pretrained``.
        dtype: Dtype name passed to ``HookedTransformer.from_pretrained``.

    Returns:
        ``(model, tokenizer, runner, direct_runner)`` where ``runner`` targets
        the first token of ``"1"`` and ``direct_runner`` contrasts the first
        tokens of ``"1"`` and ``"2"``.

    Raises:
        RuntimeError: If CUDA is unavailable or ``gpu_id`` is out of range.
        Exception: Any error raised while loading is printed with its
            traceback and re-raised.
    """
    try:
        # First check if CUDA is available
        if not torch.cuda.is_available():
            raise RuntimeError(f"CUDA is not available for worker {gpu_id}")

        # Get the number of available devices
        device_count = torch.cuda.device_count()
        print(f"GPU {gpu_id}: Found {device_count} CUDA devices")

        if gpu_id >= device_count:
            raise RuntimeError(f"GPU {gpu_id} requested but only {device_count} devices available")

        # Set device explicitly
        device = f"cuda:{gpu_id}"
        print(f"GPU {gpu_id}: Setting device to {device}")
        torch.cuda.set_device(gpu_id)

        # Verify the current device
        current_device = torch.cuda.current_device()
        print(f"GPU {gpu_id}: Current CUDA device is now {current_device}")

        if current_device != gpu_id:
            print(f"GPU {gpu_id}: Warning! Current device {current_device} does not match requested device {gpu_id}")

        # Create a dedicated stream for this initialization
        stream = torch.cuda.Stream(device=gpu_id)

        # Clear any existing CUDA memory on this device
        with torch.cuda.device(gpu_id):
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            print(f"GPU {gpu_id}: CUDA memory cleared")

        # Initialize model and tokenizer using our dedicated stream
        with torch.cuda.stream(stream):
            print(f"GPU {gpu_id}: Loading model on device {device}...")

            # Remember whether CUDA_VISIBLE_DEVICES was set at all. Restoring
            # an unset variable as '' would hide every GPU from processes
            # started afterwards, so "unset" must stay distinguishable.
            original_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
            try:
                # NOTE(review): CUDA is already in use in this process by this
                # point, so this override presumably has no effect on the load
                # itself - confirm whether it is still needed.
                os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)

                # Load the model
                model = HookedTransformer.from_pretrained(model_name, dtype=dtype, device=device)

                # Get the actual device from one of the model's parameters
                model_device = next(model.parameters()).device
                print(f"GPU {gpu_id}: Model loaded on device {model_device}")
                tokenizer = model.tokenizer

                # Initialize runners
                token_position = tokenizer.encode("1", add_special_tokens=False)[0]
                runner = TlensTokenRunner(model, tokenizer, token_pos=token_position)
                direct_runner = TlensTokenDiffRunner(
                    model,
                    tokenizer,
                    token_pos_a=token_position,
                    token_pos_b=tokenizer.encode("2", add_special_tokens=False)[0],
                )
                print(f"GPU {gpu_id}: Runners initialized")

                # Make sure all operations are complete before exiting stream context
                torch.cuda.synchronize(device)
            finally:
                # Restore the environment exactly as it was found
                if original_visible_devices is None:
                    os.environ.pop('CUDA_VISIBLE_DEVICES', None)
                else:
                    os.environ['CUDA_VISIBLE_DEVICES'] = original_visible_devices

        return model, tokenizer, runner, direct_runner
    except Exception as e:
        print(f"Error initializing GPU {gpu_id}: {e}")
        import traceback
        traceback.print_exc()
        raise


def cleanup_gpu_resources(model, tokenizer, runner, direct_runner):
    """Clean up GPU resources held by ``model``.

    When the model lives on a CUDA device, its parameters (and gradients) are
    moved to the CPU and the CUDA cache is emptied. Models on other devices
    skip the device-specific steps. Errors during this phase are printed and
    answered with a plain cache flush, so the function is best-effort and
    does not raise for them.

    The ``del`` statements only drop this function's own references; callers
    must drop theirs as well for the objects to be collected.

    Args:
        model: Module whose parameters should leave the GPU.
        tokenizer: Accepted for a uniform call signature; only dereferenced.
        runner: Accepted for a uniform call signature; only dereferenced.
        direct_runner: Accepted for a uniform call signature; only dereferenced.
    """
    if torch.cuda.is_available():
        try:
            model_device = next(model.parameters()).device

            # torch.device always has an `index` attribute, but it is None
            # for CPU devices and for a bare "cuda", so check the device type
            # and resolve a missing index explicitly.
            if model_device.type == "cuda":
                device_idx = model_device.index
                if device_idx is None:
                    device_idx = torch.cuda.current_device()

                print(f"Cleaning up resources on device cuda:{device_idx}")

                with torch.cuda.device(device_idx):
                    # Make sure all operations are complete
                    torch.cuda.synchronize(device_idx)

                    # Clear memory
                    torch.cuda.empty_cache()

                    # Release tensor memory explicitly
                    for param in model.parameters():
                        if param.is_cuda:
                            param.data = param.data.cpu()
                            if param.grad is not None:
                                param.grad.data = param.grad.data.cpu()

                    # Final synchronization
                    torch.cuda.synchronize(device_idx)

        except Exception as e:
            print(f"Error during cleanup: {e}")
            # Fallback to basic cleanup
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    # Drop the local references
    del model
    del tokenizer
    del runner
    del direct_runner

    # Force garbage collection
    import gc
    gc.collect()

    # Final CUDA cleanup if available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def process_single_experiment(
    gpu_id: int,
    model_name: str,
    system_message: str,
    user_message: str,
    selected_conditions: List[str],
    sys_msg_type: str,
    condition_name: str,
    num_conditions: int,
    experiment_type: str,
    results_file: Path,
    iters: int = 50,
    gpu_resources: Optional[Tuple[HookedTransformer, Any, TlensTokenRunner, TlensTokenDiffRunner]] = None,
    dtype: str = "bfloat16",
) -> Optional[Dict[str, Any]]:
    """Process a single experiment on a specific GPU.

    Runs EPO with the token-difference runner, saves the outcome with
    ``save_single_result`` and returns it.

    Args:
        gpu_id: GPU identifier recorded in the result; the computation
            itself always uses ``"cuda:0"``.
        model_name: Model to load when ``gpu_resources`` is not given.
        system_message: Full system prompt, conditions included.
        user_message: Starting user message for the optimization.
        selected_conditions: Conditions contained in ``system_message``.
        sys_msg_type: Key of the system message type.
        condition_name: Label of the condition category.
        num_conditions: Number of conditions in the prompt.
        experiment_type: Experiment label stored under ``"type"``.
        results_file: JSON file the result is appended to.
        iters: Number of EPO iterations.
        gpu_resources: Pre-loaded ``(model, tokenizer, runner, direct_runner)``;
            when omitted, they are loaded here and released again afterwards.
        dtype: Dtype name used when loading the model here.

    Returns:
        The result dict, or ``None`` if any step raised (the error and its
        traceback are printed).
    """
    # Bind these before the try block so the finally clause can always read
    # them, even when loading fails half-way through.
    model = tokenizer = runner = direct_runner = None
    should_cleanup = False
    try:
        # In our environment, device 0 refers to the GPU we've been assigned
        # through CUDA_VISIBLE_DEVICES
        device = "cuda:0"

        # Use provided resources or initialize new ones
        if gpu_resources is None:
            # This should never happen in our current setup, as we always pass pre-loaded resources
            print(f"Warning: No pre-loaded GPU resources provided for experiment. Loading new ones on device {device}.")
            # Anything loaded from here on belongs to this call
            should_cleanup = True
            model = HookedTransformer.from_pretrained(model_name, dtype=dtype, device=device)
            tokenizer = model.tokenizer
            token_position = tokenizer.encode("1", add_special_tokens=False)[0]
            runner = TlensTokenRunner(model, tokenizer, token_pos=token_position)
            direct_runner = TlensTokenDiffRunner(
                model,
                tokenizer,
                token_pos_a=token_position,
                token_pos_b=tokenizer.encode("2", add_special_tokens=False)[0],
            )
        else:
            model, tokenizer, runner, direct_runner = gpu_resources

        # Run EPO with specified number of iterations
        history, hyperparams, frontiers = run_epo_on_system_prompt(
            model, tokenizer, system_message, direct_runner, user_message, iters=iters, device=device
        )

        # Create result
        result = {
            "user_message": user_message,
            "system_message_type": sys_msg_type,
            "condition_type": condition_name,
            "num_conditions": num_conditions,
            "conditions": selected_conditions,
            "frontiers": frontiers,
            "history": history,
            "hyperparams": hyperparams,
            "type": experiment_type,
            "gpu_id": gpu_id,
        }

        # Save result
        save_single_result(result, results_file)

        return result

    except Exception as e:
        print(f"Error processing experiment: {e}")
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Only cleanup if we initialized the resources in this call
        if should_cleanup and model is not None:
            cleanup_gpu_resources(model, tokenizer, runner, direct_runner)


def _tasks_for_pool(
    user_message,
    sys_msg_type,
    base_system_message,
    pool,
    condition_name,
    experiment_type,
    iters,
    repetitions,
    max_conditions,
):
    """Build the repeated tasks for 1..max_conditions conditions drawn from one pool."""
    tasks = []
    # Never ask random.sample for more items than the pool holds
    largest = min(max_conditions, len(pool))
    for num_conditions in range(1, largest + 1):
        selected_conditions = random.sample(pool, num_conditions)
        system_message = base_system_message + "\n".join(
            f"- {condition}" for condition in selected_conditions
        )
        task = (
            user_message,
            system_message,
            selected_conditions,
            sys_msg_type,
            condition_name,
            num_conditions,
            experiment_type,
            iters,
        )
        # Every repetition is the same task
        tasks.extend([task] * repetitions)
    return tasks


def create_experiment_tasks(
    user_messages: List[str],
    system_message_types: Dict[str, str],
    all_condition_lists: List[List[str]],
    condition_names: List[str],
    all_conditions: List[str],
    iters: int = 50,
    repetitions: int = 10,
    max_conditions: int = 5,
) -> List[Tuple]:
    """Create a list of experiment tasks to be distributed across GPUs.

    For every user message and system message type, two groups are produced:

    * "systematic": for each condition list, 1..``max_conditions`` conditions
      are sampled from that list alone.
    * "random": 1..``max_conditions`` conditions are sampled from
      ``all_conditions``; the condition label is ``"mixed"``.

    Each sampled prompt is emitted ``repetitions`` times. Sampling uses the
    module-level ``random`` generator, so seed it for reproducible tasks.

    Args:
        user_messages: Starting user messages.
        system_message_types: Maps a type key to the opening line of the
            system prompt; the bulleted conditions are appended to it.
        all_condition_lists: Condition lists, one per category.
        condition_names: Label for each list in ``all_condition_lists``.
        all_conditions: Pool used for the "random" group.
        iters: EPO iteration count stored in every task.
        repetitions: How many times each sampled prompt is repeated.
        max_conditions: Largest number of conditions per prompt. Pools with
            fewer entries are capped at their own size instead of raising.

    Returns:
        Tuples of ``(user_message, system_message, selected_conditions,
        sys_msg_type, condition_name, num_conditions, experiment_type, iters)``.
    """
    tasks = []

    for user_message in user_messages:
        for sys_msg_type, base_system_message in system_message_types.items():
            # Systematic combinations
            for condition_list, condition_name in zip(
                all_condition_lists, condition_names
            ):
                tasks.extend(
                    _tasks_for_pool(
                        user_message,
                        sys_msg_type,
                        base_system_message,
                        condition_list,
                        condition_name,
                        "systematic",
                        iters,
                        repetitions,
                        max_conditions,
                    )
                )

            # Random combinations
            tasks.extend(
                _tasks_for_pool(
                    user_message,
                    sys_msg_type,
                    base_system_message,
                    all_conditions,
                    "mixed",
                    "random",
                    iters,
                    repetitions,
                    max_conditions,
                )
            )

    return tasks


# Create results directory if it doesn't exist.
# NOTE(review): this repeats the directory/timestamp setup done earlier in the
# file, so `timestamp` and `results_file` are recomputed here and the earlier
# values are overwritten.
results_dir = Path("results/epo_experiments")
results_dir.mkdir(parents=True, exist_ok=True)

# Generate timestamp for unique filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_file = results_dir / f"epo_results_{timestamp}.json"

# Get available GPUs (each one is probed with a small allocation)
available_gpus = get_available_gpus()
num_gpus = len(available_gpus)
print(f"Found {num_gpus} available GPUs: {available_gpus}")

# Without any working GPU, pretend there is a single device with index 0.
# NOTE(review): the worker defined below returns immediately when CUDA is
# unavailable, so this fallback does not appear to run anything on the CPU -
# confirm the intent.
if num_gpus == 0:
    print("No GPUs available. Falling back to CPU.")
    num_gpus = 1
    available_gpus = [0]

# Create experiment tasks; the last positional argument flattens all
# condition lists into one pool for the "random" experiments.
tasks = create_experiment_tasks(
    user_messages,
    system_message_types,
    all_condition_lists,
    condition_names,
    [cond for sublist in all_condition_lists for cond in sublist],
    iters=50,
    repetitions=10
)

print(f"Created {len(tasks)} experiment tasks")

# Define a worker function that runs on a specific GPU
def gpu_worker(gpu_idx, gpu_id, tasks_for_gpu, result_file, model_name, dtype):
    """Worker function that runs on a specific GPU."""
    import os
    import torch
    import traceback
    import json
    import signal
    import sys
    from pathlib import Path

    # Add signal handler for better cleanup
    def signal_handler(sig, frame):
        print(f"Worker for GPU {gpu_id}: Received signal {sig}, cleaning up...")
        try:
            if 'model' in locals() and model is not None:
                cleanup_gpu_resources(model, tokenizer, runner, direct_runner)
        except:
            pass
        sys.exit(1)

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    # Set device explicitly for this GPU worker
    device = f"cuda:{gpu_id}"

    # Create a GPU-specific result file to avoid race conditions
    result_dir = Path(result_file).parent
    base_filename = Path(result_file).stem
    extension = Path(result_file).suffix
    gpu_result_file = result_dir / f"{base_filename}_gpu{gpu_id}{extension}"

    print(f"Worker for GPU {gpu_id} (PID {os.getpid()}) started")
    print(f"Worker for GPU {gpu_id}: Using device {device}")
    print(f"Worker for GPU {gpu_id}: Writing results to {gpu_result_file}")

    # Initialize an empty results list for this GPU
    gpu_results = []

    # Initialize CUDA and check device
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        print(f"Worker for GPU {gpu_id}: CUDA available with {device_count} devices")

        # Check if the requested GPU is valid
        if gpu_id >= device_count:
            print(f"Worker for GPU {gpu_id}: Requested GPU {gpu_id} is out of range (0-{device_count-1})!")
            return

        # Get properties of the specific GPU
        try:
            # Set the device for all operations in this process
            torch.cuda.set_device(gpu_id)
            current_device = torch.cuda.current_device()
            print(f"Worker for GPU {gpu_id}: Current CUDA device: {current_device}")

            # Get properties of the specified GPU
            props = torch.cuda.get_device_properties(gpu_id)
            print(f"Worker for GPU {gpu_id}: Using {props.name} with {props.total_memory/1024**3:.2f} GB memory")

            # Check device memory - warn if low
            memory_allocated = torch.cuda.memory_allocated(gpu_id) / 1024**3
            memory_reserved = torch.cuda.memory_reserved(gpu_id) / 1024**3
            print(f"Worker for GPU {gpu_id}: Memory allocated: {memory_allocated:.2f} GB, reserved: {memory_reserved:.2f} GB")

            if memory_allocated > props.total_memory * 0.5 / 1024**3:
                print(f"Warning: GPU {gpu_id} already has {memory_allocated:.2f} GB allocated!")
        except Exception as e:
            print(f"Worker for GPU {gpu_id}: Error getting device properties: {e}")
            traceback.print_exc()
            return
    else:
        print(f"Worker for GPU {gpu_id}: CUDA is not available!")
        return

    try:
        # Import here to avoid circular imports
        from eliciting_contexts.fluent_dreaming.system_prompt_experiments.initialize_gpu_resources import initialize_model_on_gpu

        # Use the new dedicated function to load the model on the specific GPU
        print(f"Worker for GPU {gpu_id}: Initializing model on device {device}...")
        try:
            model, tokenizer, runner, direct_runner = initialize_model_on_gpu(gpu_id, model_name, dtype)
            print(f"Worker for GPU {gpu_id}: Model successfully initialized on device {device}")
        except Exception as e:
            print(f"Worker for GPU {gpu_id}: Error initializing model: {e}")
            traceback.print_exc()
            raise

        # Need to import these for task processing
        from eliciting_contexts.fluent_dreaming.system_prompt_experiments.system_utils import (
            format_chat, get_model_response, test_first_token_responses
        )
        from eliciting_contexts.fluent_dreaming.system_prompt_experiments.conditional_script import (
            run_epo_on_system_prompt, save_single_result, cleanup_gpu_resources
        )

        # Process all tasks for this GPU
        num_tasks = len(tasks_for_gpu) - 1  # Subtract sentinel
        completed_tasks = 0

        # Save progress periodically to handle potential crashes
        def save_progress():
            """Checkpoint the results collected so far to this GPU's result file.

            Does nothing while ``gpu_results`` is empty. Otherwise each result
            is reduced to its JSON-friendly fields (the ``"history"`` entry is
            left out and every hyperparameter value goes through
            ``prepare_for_json``), dumped to a sibling ``*_temp`` file, and that
            file is then renamed over ``gpu_result_file`` so that a crash in
            the middle of a write cannot leave a truncated result file behind.
            """
            if not gpu_results:
                return

            # Fields copied verbatim, in the order they appear in the output.
            copied_fields = (
                "user_message",
                "system_message_type",
                "condition_type",
                "num_conditions",
                "conditions",
                "frontiers",
                "type",
                "gpu_id",
            )

            snapshot = []
            for entry in gpu_results:
                record = {field: entry[field] for field in copied_fields}
                record["hyperparams"] = {
                    name: prepare_for_json(value)
                    for name, value in entry["hyperparams"].items()
                }
                snapshot.append(record)

            # Dump next to the real file first, then swap it into place.
            scratch_name = f"{gpu_result_file.stem}_temp{gpu_result_file.suffix}"
            scratch_file = gpu_result_file.with_name(scratch_name)
            with open(scratch_file, 'w') as handle:
                json.dump(snapshot, handle, indent=2)
            scratch_file.replace(gpu_result_file)

            print(f"Worker for GPU {gpu_id}: Saved progress ({len(gpu_results)}/{num_tasks} tasks)")

        for i, task in enumerate(tasks_for_gpu):
            if task is None:  # sentinel
                continue

            try:
                # Verify device is still correct before processing task
                current_device = torch.cuda.current_device()
                if current_device != gpu_id:
                    print(f"Worker for GPU {gpu_id}: WARNING - Current device changed to {current_device}, resetting to {gpu_id}")
                    torch.cuda.set_device(gpu_id)

                user_message, system_message, selected_conditions, sys_msg_type, condition_name, num_conditions, experiment_type, iters = task

                print(f"Worker for GPU {gpu_id}: Processing task {i+1}/{num_tasks}: {sys_msg_type}")

                # Run EPO with specified number of iterations - pass explicit device
                history, hyperparams, frontiers = run_epo_on_system_prompt(
                    model, tokenizer, system_message, direct_runner,
                    user_message, iters=iters, device=device
                )

                # Create result object
                result = {
                    "user_message": user_message,
                    "system_message_type": sys_msg_type,
                    "condition_type": condition_name,
                    "num_conditions": num_conditions,
                    "conditions": selected_conditions,
                    "frontiers": frontiers,
                    "history": history,
                    "hyperparams": hyperparams,
                    "type": experiment_type,
                    "gpu_id": gpu_id,  # Record which GPU processed this
                }

                # Add result to the GPU's results list
                gpu_results.append(result)
                completed_tasks += 1

                print(f"Worker for GPU {gpu_id}: Successfully completed task {i+1} ({completed_tasks}/{num_tasks})")

                # Save progress every 5 tasks or when at the end
                if completed_tasks % 5 == 0 or completed_tasks == num_tasks:
                    save_progress()

            except Exception as e:
                print(f"Worker for GPU {gpu_id}: Error processing task {i+1}: {e}")
                traceback.print_exc()

                # Save progress even on error
                try:
                    save_progress()
                except:
                    pass

        # Final save of all results - convert to JSON-serializable format
        serializable_results = []
        for result in gpu_results:
            serializable_result = {
                "user_message": result["user_message"],
                "system_message_type": result["system_message_type"],
                "condition_type": result["condition_type"],
                "num_conditions": result["num_conditions"],
                "conditions": result["conditions"],
                "frontiers": result["frontiers"],
                "type": result["type"],
                "gpu_id": result["gpu_id"],
                "hyperparams": {
                    k: prepare_for_json(v) for k, v in result["hyperparams"].items()
                },
            }
            serializable_results.append(serializable_result)

        # Write results to the GPU-specific file
        with open(gpu_result_file, 'w') as f:
            json.dump(serializable_results, f, indent=2)

        print(f"Worker for GPU {gpu_id}: Saved {len(gpu_results)} results to {gpu_result_file}")

        # Clean up resources
        try:
            cleanup_gpu_resources(model, tokenizer, runner, direct_runner)
            print(f"Worker for GPU {gpu_id}: Resources cleaned up")
        except Exception as e:
            print(f"Worker for GPU {gpu_id}: Error during cleanup: {e}")
            traceback.print_exc()

        print(f"Worker for GPU {gpu_id}: All tasks completed. Processed {len(gpu_results)}/{num_tasks} tasks.")

    except Exception as e:
        print(f"Worker for GPU {gpu_id}: Error: {e}")
        traceback.print_exc()

        # Try to save any results we have so far
        try:
            if 'gpu_results' in locals() and gpu_results:
                # Convert to JSON-serializable format
                serializable_results = []
                for result in gpu_results:
                    serializable_result = {
                        "user_message": result["user_message"],
                        "system_message_type": result["system_message_type"],
                        "condition_type": result["condition_type"],
                        "num_conditions": result["num_conditions"],
                        "conditions": result["conditions"],
                        "frontiers": result["frontiers"],
                        "type": result["type"],
                        "gpu_id": result["gpu_id"],
                        "hyperparams": {
                            k: prepare_for_json(v) for k, v in result["hyperparams"].items()
                        },
                    }
                    serializable_results.append(serializable_result)

                # Write to the GPU-specific file
                with open(gpu_result_file, 'w') as f:
                    json.dump(serializable_results, f, indent=2)

                print(f"Worker for GPU {gpu_id}: Saved partial results ({len(gpu_results)} tasks) despite error")
        except:
            pass

    # Return the path to the GPU-specific results file for the main process to find
    return str(gpu_result_file)


def run_parallel_experiments(
    tasks,
    available_gpus,
    model_name,
    dtype,
    results_file,
    max_concurrent_gpus=None,
    poll_interval=30,
):
    """Run experiments in parallel, one spawned worker process per GPU.

    A shuffled copy of ``tasks`` is dealt round-robin to the GPUs, a
    ``gpu_worker`` process is started for each GPU (at most
    ``max_concurrent_gpus`` at a time), and once every worker has exited the
    per-GPU JSON files ``<stem>_gpu<id><suffix>`` found next to
    ``results_file`` are merged into ``results_file``.

    Args:
        tasks: Sequence of task tuples. It is not modified; shuffling is done
            on a copy.
        available_gpus: GPU ids, one worker per entry.
        model_name: Forwarded unchanged to ``gpu_worker``.
        dtype: Forwarded unchanged to ``gpu_worker``.
        results_file: Path of the combined JSON output. Its parent directory
            is created if missing.
        max_concurrent_gpus: Upper bound on simultaneously running workers.
            ``None`` (default) runs all GPUs at once.
        poll_interval: Maximum number of seconds between status reports while
            waiting for workers.

    Returns:
        The list of result dicts loaded from the per-GPU files (empty when
        there were no tasks or no worker produced a readable file).

    Raises:
        ValueError: If ``available_gpus`` is empty.
    """
    import json
    import os
    import random
    import traceback
    from collections import Counter, deque
    from multiprocessing import get_context
    from multiprocessing.connection import wait as wait_for_exit
    from pathlib import Path

    num_gpus = len(available_gpus)
    if num_gpus == 0:
        raise ValueError("available_gpus must contain at least one GPU id")

    # Create results directory if it doesn't exist
    results_dir = Path(results_file).parent
    results_dir.mkdir(parents=True, exist_ok=True)

    # Nothing to do: avoid spawning workers (each loads its own model).
    if not tasks:
        print("Warning: No tasks supplied; no workers were started.")
        return []

    if max_concurrent_gpus is None:
        max_concurrent_gpus = num_gpus
    max_concurrent_gpus = max(1, min(int(max_concurrent_gpus), num_gpus))

    # Print out diagnostic information about the current environment
    print("Current environment:")
    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        print(f"CUDA is available with {device_count} devices")
        for i in range(device_count):
            try:
                props = torch.cuda.get_device_properties(i)
                print(f"GPU {i}: {props.name} with {props.total_memory/1024**3:.2f} GB memory")
            except Exception:
                print(f"GPU {i}: Unable to get properties")
    else:
        print("CUDA is not available")

    # Shuffle a copy so similar task types are spread across GPUs without
    # reordering the caller's list.
    shuffled_tasks = list(tasks)
    random.shuffle(shuffled_tasks)

    # Deal round-robin: chunk sizes differ by at most one, so no GPU is left
    # idle while another holds a surplus.
    gpu_tasks = []
    for gpu_idx, gpu_id in enumerate(available_gpus):
        chunk = shuffled_tasks[gpu_idx::num_gpus]
        print(f"GPU {gpu_id} assigned {len(chunk)} tasks")
        chunk.append(None)  # sentinel expected at the end of each worker's list
        gpu_tasks.append(chunk)

    # Status is only read and written by this process, so a plain dict is
    # enough. (Nested values inside a Manager dict proxy would silently drop
    # in-place updates.)
    gpu_status = {
        gpu_id: {"status": "waiting", "total_tasks": len(chunk) - 1}
        for gpu_id, chunk in zip(available_gpus, gpu_tasks)
    }

    # Use 'spawn' method for multiprocessing to ensure clean environment
    ctx = get_context('spawn')

    def start_gpu_worker(gpu_idx, gpu_id):
        """Start the worker for one GPU; return the process or None on failure."""
        gpu_status[gpu_id]["status"] = "starting"
        try:
            # Daemon so the worker exits if the main process dies.
            p = ctx.Process(
                target=gpu_worker,
                args=(
                    gpu_idx,
                    gpu_id,
                    gpu_tasks[gpu_idx],
                    results_file,
                    model_name,
                    dtype,
                ),
                daemon=True,
            )
            p.start()
        except Exception as e:
            print(f"Error starting worker for GPU {gpu_id}: {e}")
            traceback.print_exc()
            gpu_status[gpu_id]["status"] = "failed"
            return None

        print(f"Started worker process with PID {p.pid} for GPU {gpu_id}")
        gpu_status[gpu_id].update(status="running", pid=p.pid)
        return p

    def print_gpu_status():
        """Print the current status of all GPUs."""
        print("\n----- GPU Status -----")
        for gpu_id in available_gpus:
            status = gpu_status[gpu_id]
            print(f"GPU {gpu_id}: {status['status']} - {status['total_tasks']} tasks assigned")
        print("---------------------\n")

    pending = deque(enumerate(available_gpus))  # (gpu_idx, gpu_id) not yet started
    active = []  # (process, gpu_id) currently running
    all_processes = []

    print(f"Starting initial workers (up to {max_concurrent_gpus} concurrent GPUs)")
    print("Monitoring worker processes...")
    try:
        while True:
            # Fill every free slot. A failed start does not occupy a slot, so
            # the queue always drains and the loop terminates.
            while pending and len(active) < max_concurrent_gpus:
                gpu_idx, gpu_id = pending.popleft()
                process = start_gpu_worker(gpu_idx, gpu_id)
                if process is not None:
                    active.append((process, gpu_id))
                    all_processes.append(process)

            if not active:
                break

            print_gpu_status()

            # Wake as soon as any worker exits, or after poll_interval seconds.
            wait_for_exit([process.sentinel for process, _ in active], timeout=poll_interval)

            still_active = []
            for process, gpu_id in active:
                if process.is_alive():
                    still_active.append((process, gpu_id))
                    continue
                print(f"Process for GPU {gpu_id} completed with exit code {process.exitcode}")
                gpu_status[gpu_id]["status"] = "completed" if process.exitcode == 0 else "failed"
            active = still_active
    except KeyboardInterrupt:
        print("Interrupt received, terminating workers...")
        for process, gpu_id in active:
            process.terminate()
            gpu_status[gpu_id]["status"] = "terminated"

    # Final wait for all processes
    print("Waiting for all worker processes to complete...")
    for p in all_processes:
        try:
            p.join()
            print(f"Process {p.pid} final status: exit code {p.exitcode}")
        except Exception as e:
            print(f"Error joining process {p.pid}: {e}")

    print("All worker processes have completed")
    print_gpu_status()

    # Combine results from all GPU-specific files
    all_results = []
    base_filename = Path(results_file).stem
    extension = Path(results_file).suffix

    for gpu_id in available_gpus:
        gpu_result_file = results_dir / f"{base_filename}_gpu{gpu_id}{extension}"
        if not gpu_result_file.exists():
            continue
        try:
            with open(gpu_result_file, 'r') as f:
                gpu_results = json.load(f)
            print(f"Loaded {len(gpu_results)} results from GPU {gpu_id}")
            all_results.extend(gpu_results)
        except Exception as e:
            # Best effort: one unreadable file must not discard the others.
            print(f"Error loading results from GPU {gpu_id}: {e}")
            traceback.print_exc()

    # Save the combined results to the main results file
    if all_results:
        with open(results_file, 'w') as f:
            json.dump(all_results, f, indent=2)
        print(f"Saved {len(all_results)} combined results to {results_file}")
    else:
        print("Warning: No results found from any GPU!")

    # Print summary by GPU. Keys are sorted by their string form because the
    # 'unknown' placeholder cannot be ordered against integer GPU ids.
    results_by_gpu = Counter(result.get('gpu_id', 'unknown') for result in all_results)

    print("\n----- Results Summary -----")
    for gpu_id, count in sorted(results_by_gpu.items(), key=lambda item: str(item[0])):
        print(f"  GPU {gpu_id}: {count} experiments")

    return all_results

# Function to be called from Jupyter cell
def run_experiments():
    """Notebook entry point: detect GPUs and launch the parallel run.

    Uses the module-level ``tasks``, ``model_name``, ``dtype`` and
    ``results_file`` and returns whatever ``run_parallel_experiments``
    returns.
    """
    gpu_ids = get_available_gpus()
    gpu_count = len(gpu_ids)
    print(f"Found {gpu_count} available GPUs: {gpu_ids}")

    if gpu_count == 0:
        # NOTE(review): the message says CPU, but a single device id 0 is what
        # actually gets passed on; confirm the worker handles this case.
        print("No GPUs available. Falling back to CPU.")
        gpu_ids = [0]

    return run_parallel_experiments(
        tasks,
        gpu_ids,
        model_name,
        dtype,
        results_file,
    )

# This allows the script to be both imported and run directly
if __name__ == "__main__":
    # Script entry point: discover the GPUs to use, then run every task.
    available_gpus = get_available_gpus()
    num_gpus = len(available_gpus)
    print(f"Found {num_gpus} available GPUs: {available_gpus}")

    if not num_gpus:
        # No GPU detected: continue with a single placeholder device id.
        print("No GPUs available. Falling back to CPU.")
        available_gpus, num_gpus = [0], 1

    # Kept at module level so the results can be inspected after the run.
    all_results = run_parallel_experiments(
        tasks,
        available_gpus,
        model_name,
        dtype,
        results_file,
    )

# %%
