#!/usr/bin/env python3
import subprocess
import sys
import os
import time
import logging
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple
import signal
import tempfile
import shutil
import atexit
from queue import Queue, Empty

@dataclass()
class ExperimentTask:
    """A single queued training run.

    The runner turns ``cmd_args`` into command-line overrides for the training
    script and launches it with ``conda run -n <conda_env>``.
    """

    # Override key -> value; an empty-string value means the key is emitted on its own.
    cmd_args: Dict[str, Any]
    # Label used in the runner's log lines.
    exp_name: str
    # Model identifier, e.g. 'gift' or 'vcip'.
    model: str
    # Name of the conda environment the run is launched in.
    conda_env: str

class SimpleExperimentRunner:
    """Simple experiment runner that multiplexes several training runs per GPU.

    Each experiment is launched as its own
    ``conda run -n <env> python -u <train_script> key=value ...`` subprocess,
    with ``CUDA_VISIBLE_DEVICES`` pinned to a single physical GPU. Tasks are
    assigned round-robin to one queue per GPU. Each queue is drained by
    ``tasks_per_gpu`` worker threads.
    """

    # Per-task wall-clock limit in seconds (4 hours).
    TASK_TIMEOUT_S = 14400

    # Max epochs per model and dataset family. Shared by the main-comparison
    # and K-parameter studies so that the two stay in sync.
    EPOCHS_DICT = {
        'gift': {'mimic': 30, 'tumor': 15},
        'vcip': {'mimic': 100, 'tumor': 100},
        'rmsn': {'mimic': 150, 'tumor': 100},
        'crn': {'mimic': 150, 'tumor': 100},
        'ct': {'mimic': 150, 'tumor': 100},
        'actin': {'mimic': 300, 'tumor': 150},
    }

    # GIFT component ablations (DR / recover / action_diff). These are not run by
    # default. To run them, pass
    # ``ablation_configs=SimpleExperimentRunner.GIFT_COMPONENT_ABLATIONS`` to
    # ``experiment_gift_ablation``.
    GIFT_COMPONENT_ABLATIONS = {
        "full_model": {
            "model.sac_params.DR": True,
            "model.sac_params.recover": True,
            "model.sac_params.action_diff": True
        },
        "no_dr": {
            "model.sac_params.DR": False,
            "model.sac_params.recover": True,
            "model.sac_params.action_diff": True
        },
        "no_recover": {
            "model.sac_params.DR": True,
            "model.sac_params.recover": False,
            "model.sac_params.action_diff": True
        },
        "no_action_diff": {
            "model.sac_params.DR": True,
            "model.sac_params.recover": True,
            "model.sac_params.action_diff": False
        },
        # "only_dr": {
        #     "model.sac_params.DR": True,
        #     "model.sac_params.recover": False,
        #     "model.sac_params.action_diff": False
        # }
    }

    def __init__(
        self,
        train_script: str = "experiments/train.py",
        num_gpus: int = None,
        tasks_per_gpu: int = 2,
        enable_print: bool = False,
        allowed_gpus: List[int] = None,
    ):
        """Configure the runner.

        Args:
            train_script: Path of the training script passed to ``python -u``.
            num_gpus: Number of GPUs (ids ``0..num_gpus-1``). Auto-detected when
                None. Ignored when ``allowed_gpus`` is given.
            tasks_per_gpu: Number of worker threads (concurrent runs) per GPU.
            enable_print: If True, child output is inherited; otherwise it is discarded.
            allowed_gpus: Explicit physical GPU ids to use. Takes precedence
                over ``num_gpus``.
        """
        self.train_script = train_script
        self.base_seeds = [10, 101, 1010, 10101, 101010]
        self.logger = self._setup_logger()
        self.tasks_per_gpu = tasks_per_gpu
        self.enable_print = enable_print
        if allowed_gpus is not None:
            self.allowed_gpus = list(allowed_gpus)
            self.num_gpus = len(self.allowed_gpus)
        else:
            self.num_gpus = self._detect_gpus() if num_gpus is None else num_gpus
            self.allowed_gpus = list(range(self.num_gpus))
        self.max_parallel_tasks = self.num_gpus * self.tasks_per_gpu
        # Temporary directory, removed at interpreter exit.
        self.temp_dir = tempfile.mkdtemp(prefix="exp_logs_")
        atexit.register(self._cleanup)
        # One task queue per GPU index (an index into self.allowed_gpus).
        self.gpu_queues = [Queue() for _ in range(self.num_gpus)]
        # Per-batch statistics, guarded by print_lock.
        self.completed_tasks = 0
        self.failed_tasks = 0
        self.total_tasks = 0
        self.start_time = None
        self.print_lock = threading.Lock()
        # Thread management.
        self.shutdown_event = threading.Event()
        self.active_threads = []
        self.logger.info(f"{self.tasks_per_gpu} tasks per GPU in parallel using {self.num_gpus} GPUs")
        self.logger.info(f"Total Parallels: {self.max_parallel_tasks}")

    def _detect_gpus(self) -> int:
        """Return the number of visible GPUs, or 1 if detection fails.

        Tries ``nvidia-smi -L`` first, then ``torch.cuda.device_count()``.
        """
        try:
            result = subprocess.run(['nvidia-smi', '-L'],
                                    capture_output=True, text=True, timeout=10)
            if result.returncode == 0:
                count = sum(1 for line in result.stdout.splitlines() if 'GPU' in line)
                if count > 0:
                    return count
        except (OSError, subprocess.SubprocessError):
            # nvidia-smi is missing or hung: fall through to the torch probe.
            pass
        try:
            import torch
            if torch.cuda.is_available():
                return torch.cuda.device_count()
        except Exception:
            # torch is optional; any import or CUDA init failure means "unknown".
            pass
        self.logger.warning("Unable to detect GPU, use 1")
        return 1

    def _setup_logger(self):
        """Configure root logging at INFO and return this module's logger."""
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        return logging.getLogger(__name__)

    def _cleanup(self):
        """Remove the temporary directory (registered with atexit)."""
        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _get_conda_env(self, model: str) -> str:
        """Return the conda environment for *model*: 'vcip' for vcip/actin, else 'ct'."""
        if model in ['vcip', 'actin']:
            return 'vcip'
        return 'ct'

    @staticmethod
    def _standard_datasets(mimic_size: int = 500, tumor_size: int = 1000) -> List[Dict[str, Any]]:
        """Return the mimic + tumor (gamma 2/3/4) dataset specs used by the baseline studies.

        Each spec has the dataset override (``config``), the size override key
        (``size_param``) and its value. Tumor specs also carry ``gamma``.
        """
        datasets = [{'type': 'mimic', 'config': '+dataset=mimic',
                     'size_param': 'dataset.max_number', 'size_value': mimic_size}]
        for gamma in (2, 3, 4):
            datasets.append({'type': f'tumor_gamma_{gamma}', 'config': '+dataset=tumor', 'gamma': gamma,
                             'size_param': 'dataset.num_patients.train', 'size_value': tumor_size})
        return datasets

    def _build_command(self, task: "ExperimentTask") -> List[str]:
        """Build the ``conda run`` argv for *task*.

        A ``cmd_args`` entry whose value is ``''`` is emitted as the bare key
        (e.g. ``"+dataset=mimic"``). Every other entry becomes ``key=value``.
        """
        cmd = ["conda", "run", "-n", task.conda_env, "python", "-u", self.train_script]
        cmd.extend(key if value == '' else f"{key}={value}" for key, value in task.cmd_args.items())
        return cmd

    def _record_finished(self, success: bool) -> str:
        """Count one finished task (any outcome) and return a "(done/total)" tag.

        Must be called with ``self.print_lock`` held.
        """
        self.completed_tasks += 1
        if not success:
            self.failed_tasks += 1
        return f"({self.completed_tasks}/{self.total_tasks})"

    def _run_single_task(self, task: "ExperimentTask", gpu_id: int, worker_id: int) -> bool:
        """Run one experiment as a subprocess pinned to a single GPU.

        Args:
            task: The experiment to launch.
            gpu_id: Index into ``self.allowed_gpus`` (not the physical id).
            worker_id: Worker index on that GPU, used only for log labels.

        Returns:
            True if the subprocess exited with code 0. False on a non-zero exit,
            a timeout (``TASK_TIMEOUT_S``), or any launch error. Every outcome
            advances the progress counter.
        """
        worker_name = f"GPU-{gpu_id}-W{worker_id}"
        with self.print_lock:
            self.logger.info(f"🚀 {worker_name} started: {task.exp_name}")

        cmd = self._build_command(task)
        env = os.environ.copy()
        env.update({
            'CUDA_VISIBLE_DEVICES': str(self.allowed_gpus[gpu_id]),
            'PYTHONUNBUFFERED': '1',
            'OMP_NUM_THREADS': '1',
        })
        # Verbose mode inherits our stdout/stderr (None). Silent mode discards
        # child output completely, to avoid pipe problems.
        output = None if self.enable_print else subprocess.DEVNULL

        start_time = time.time()
        try:
            result = subprocess.run(cmd, stdout=output, stderr=output, env=env,
                                    timeout=self.TASK_TIMEOUT_S)
        except subprocess.TimeoutExpired:
            # NOTE(review): subprocess.run kills the direct child (`conda run`) on
            # timeout; whether the python grandchild also dies depends on conda.
            with self.print_lock:
                progress = self._record_finished(success=False)
                self.logger.error(f"⏰ {progress} {worker_name} timeout: {task.exp_name}")
            return False
        except Exception as e:
            with self.print_lock:
                progress = self._record_finished(success=False)
                self.logger.error(f"💥 {progress} {worker_name} Exception: {task.exp_name} - {e}")
            return False

        success = result.returncode == 0
        elapsed = time.time() - start_time
        with self.print_lock:
            progress = self._record_finished(success)
            if success:
                self.logger.info(f"✅ {progress} {worker_name} completed: {task.exp_name} [{elapsed: .1f} s]")
            else:
                self.logger.error(f"❌ {progress} {worker_name} failed: {task.exp_name} [Exit code: {result.returncode}]")
        return success

    def _gpu_worker_thread(self, gpu_id: int, worker_id: int):
        """Worker loop: pull tasks from GPU ``gpu_id``'s queue until told to stop.

        Exits on a ``None`` sentinel, when ``shutdown_event`` is set, or after an
        unexpected error. ``task_done()`` is called for every dequeued item so
        that ``Queue.join()`` in ``_run_batch`` can return.
        """
        worker_name = f"GPU-{gpu_id}-W{worker_id}"
        task_queue = self.gpu_queues[gpu_id]
        success_count = 0
        task_count = 0
        try:
            while not self.shutdown_event.is_set():
                try:
                    # Short timeout so the shutdown flag is re-checked regularly.
                    task = task_queue.get(timeout=2)
                    if task is None:  # Stop sentinel from _run_batch
                        task_queue.task_done()
                        break

                    task_count += 1
                    try:
                        if self._run_single_task(task, gpu_id, worker_id):
                            success_count += 1
                    finally:
                        # Mark done whether the task succeeded, failed or raised.
                        task_queue.task_done()

                    # Brief pause between consecutive tasks on this worker.
                    time.sleep(1)
                except Empty:
                    continue
                except Exception as e:
                    with self.print_lock:
                        self.logger.error(f"💥 {worker_name} Thread Exception: {e}")
                    break
        finally:
            with self.print_lock:
                self.logger.info(f"🏁 {worker_name} Thread Exit: {success_count}/{task_count} Success")

    def _run_batch(self, experiments: List[Tuple], batch_name: str):
        """Run a batch of experiments and block until every task has finished.

        Args:
            experiments: ``(cmd_args, exp_name, model)`` tuples, assigned
                round-robin to the GPU queues.
            batch_name: Label used in log output.

        Raises:
            ValueError: if there are experiments but no GPUs are configured.
        """
        if experiments and self.num_gpus <= 0:
            raise ValueError("No GPUs configured; cannot schedule experiments")
        self.total_tasks = len(experiments)
        self.completed_tasks = 0
        self.failed_tasks = 0
        self.start_time = time.time()
        self.shutdown_event.clear()

        self.logger.info(f"🎯 Start {batch_name}")
        self.logger.info(f"📊 Total tasks: {self.total_tasks}")

        # Drain leftovers (e.g. unconsumed None sentinels from an earlier batch)
        # so that join() below only waits for this batch's tasks.
        for task_queue in self.gpu_queues:
            while True:
                try:
                    task_queue.get_nowait()
                except Empty:
                    break
                task_queue.task_done()

        # Round-robin assignment of tasks to GPU queues.
        for i, (cmd_args, exp_name, model) in enumerate(experiments):
            task = ExperimentTask(cmd_args, exp_name, model, self._get_conda_env(model))
            self.gpu_queues[i % self.num_gpus].put(task)

        for gpu_id, task_queue in enumerate(self.gpu_queues):
            self.logger.info(f"GPU- {gpu_id}: {task_queue.qsize()} tasks, {self.tasks_per_gpu} parallel worker threads")

        # Start the workers, one second apart.
        self.active_threads = []
        for gpu_id in range(self.num_gpus):
            for worker_id in range(self.tasks_per_gpu):
                thread = threading.Thread(
                    target=self._gpu_worker_thread,
                    args=(gpu_id, worker_id),
                    name=f"GPU-{gpu_id}-Worker-{worker_id}",
                    daemon=True
                )
                thread.start()
                time.sleep(1)
                self.active_threads.append(thread)

        try:
            # join() returns once task_done() was called for every queued item.
            for task_queue in self.gpu_queues:
                task_queue.join()
            self.logger.info(f"✅ All {self.total_tasks} tasks completed")
        except KeyboardInterrupt:
            self.logger.info("Interrupt signal received, stopping...")
        finally:
            self.shutdown_event.set()
            # One stop sentinel per worker. The queues are unbounded, so put() never blocks.
            for task_queue in self.gpu_queues:
                for _ in range(self.tasks_per_gpu):
                    task_queue.put(None)

            for thread in self.active_threads:
                thread.join(timeout=10)
                if thread.is_alive():
                    self.logger.warning(f"Thread {thread.name} exited unsuccessfully")

        elapsed_time = time.time() - self.start_time
        self.logger.info(f"🏁 {batch_name} done!")
        self.logger.info(f"️ Total elapsed time: {elapsed_time: .1f} s")
        self.logger.info(f"📊 Completion rate: {self.completed_tasks}/{self.total_tasks}")
        self.logger.info(f"❌ Failed tasks: {self.failed_tasks}")
        if elapsed_time > 0:
            self.logger.info(f"🚀 Average speed: {self.total_tasks/elapsed_time * 60: .1f} tasks/min")

    def _generate_mimic_data_for_main_comparison(self, models, size_value=500, seed=10):
        """Pre-generate mimic data for each model by running ``generate_data.py``.

        The script is looked up in this file's directory and is run in each
        model's conda environment.

        Raises:
            RuntimeError: if generation exits non-zero for any model.
        """
        script_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "generate_data.py")
        for model in models:
            cmd = [
                "conda", "run", "-n", self._get_conda_env(model), "python", script_path,
                "+dataset=mimic",
                f"+model={model}",
                f"+hparam/{model}=mimic",
                f"exp.seed={seed}",
                f"dataset.max_number={size_value}",
                f"model.name={model}",
                "exp.test=False",
                "exp.logging=False"
            ]
            print(f"[run_experiments.py] Pre-generated dataset: {' '.join(cmd)}")
            result = subprocess.run(cmd)
            if result.returncode != 0:
                raise RuntimeError(f"Data generation failed: {model}")

    def experiment_main_comparison(self, seeds=None, mimic_tasks_per_gpu: int = 10):
        """Main model comparison on mimic and tumor (gamma 2/3/4).

        Mimic data is pre-generated for every seed. The mimic batch then runs
        with ``mimic_tasks_per_gpu`` workers per GPU (default 10). The tumor
        batch runs with the runner's own ``tasks_per_gpu``, which is restored
        even if the mimic batch raises. Non-gift models run once per
        ``exp.optimize_by_step`` value (False, True). Tumor gamma=4 additionally
        runs with ``exp.test=True``.
        """
        if seeds is None:
            seeds = self.base_seeds
        epochs_dict = self.EPOCHS_DICT
        models = ['gift', 'vcip', 'rmsn', 'crn', 'ct', 'actin']
        datasets = self._standard_datasets(mimic_size=500, tumor_size=1000)
        mimic_size = next(d['size_value'] for d in datasets if d['type'] == 'mimic')

        # Pre-generate the mimic data for every seed before training starts.
        for seed in seeds:
            self._generate_mimic_data_for_main_comparison(models, size_value=mimic_size, seed=seed)

        mimic_experiments = []
        tumor_experiments = []
        for dataset in datasets:
            is_tumor = 'tumor' in dataset['type']
            data_name = 'tumor' if is_tumor else 'mimic'
            load_data = not is_tumor
            target = tumor_experiments if is_tumor else mimic_experiments
            # Only the gamma = 4 tumor dataset is also run with exp.test=True.
            test_values_to_run = ['False', 'True'] if dataset.get('gamma') == 4 else ['False']
            for model in models:
                epoch = epochs_dict[model][data_name]
                for seed in seeds:
                    for test_value in test_values_to_run:
                        cmd_args = {
                            dataset['config']: '',
                            '+model': model,
                            'exp.exp_name': 'main_comparison',
                            'exp.seed': seed,
                            'model.name': model,
                            'exp.test': test_value,
                            'exp.logging': 'False',
                            'exp.max_epochs': epoch,
                            'exp.load_data': load_data,
                            dataset['size_param']: dataset['size_value'],
                        }
                        if 'gamma' in dataset:
                            cmd_args[f'+hparam/{model}/tumor'] = f"{dataset['gamma']}*"
                            cmd_args['dataset.coeff'] = dataset['gamma']
                        else:
                            cmd_args[f'+hparam/{model}'] = 'mimic'

                        # Runner-side label. It must be unique per task so that log lines are traceable.
                        exp_name = f"main_comparison_{dataset['type']}_{model}_seed_{seed}"
                        if test_value == 'True':
                            exp_name += "_test_true"

                        if model == 'gift':
                            # gift has no optimize_by_step variants.
                            target.append((cmd_args, exp_name, model))
                        else:
                            for value in ['False', 'True']:
                                target.append((
                                    {**cmd_args, 'exp.optimize_by_step': value},
                                    f"{exp_name}_optimize_by_step_{value}",
                                    model,
                                ))

        original_tasks_per_gpu = self.tasks_per_gpu
        self.logger.info(f"Toggle tasks_per_gpu = {mimic_tasks_per_gpu} Run mimic experiment")
        self.tasks_per_gpu = mimic_tasks_per_gpu
        try:
            self._run_batch(mimic_experiments, "Main Model Comparison (mimic)")
        finally:
            self.tasks_per_gpu = original_tasks_per_gpu
        self.logger.info(f"Recover tasks_per_gpu = {self.tasks_per_gpu} and run tumor experiment")
        self._run_batch(tumor_experiments, "Main Model Comparison (tumor)")

    def experiment_goal_threshold(self, seeds=None):
        """GIFT goal-threshold study: sweep ``model.her_params.target_hit_ratio``."""
        if seeds is None:
            seeds = self.base_seeds
        goal_thresholds = {
            'mimic': [0.25, 0.3, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75],
            'tumor': [0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9]
            # 'tumor': [1e-5, 1e-4, 1e-3, 1e-2]
        }

        datasets = [
            {'type': 'mimic', 'config': '+dataset=mimic', 'hparam': '+hparam/gift=mimic', 'size_param': 'dataset.max_number', 'size_value': 500},
            {'type': 'tumor_gamma_2', 'config': '+dataset=tumor', 'hparam': '+hparam/gift/tumor=2*', 'size_param': 'dataset.num_patients.train', 'size_value': 1000, 'coeff': 2},
            {'type': 'tumor_gamma_3', 'config': '+dataset=tumor', 'hparam': '+hparam/gift/tumor=3*', 'size_param': 'dataset.num_patients.train', 'size_value': 1000, 'coeff': 3},
            {'type': 'tumor_gamma_4', 'config': '+dataset=tumor', 'hparam': '+hparam/gift/tumor=4*', 'size_param': 'dataset.num_patients.train', 'size_value': 1000, 'coeff': 4},
        ]

        experiments = []
        for dataset in datasets:
            data_name = 'tumor' if 'tumor' in dataset['type'] else 'mimic'
            epoch = self.EPOCHS_DICT['gift'][data_name]
            for threshold in goal_thresholds[data_name]:
                for seed in seeds:
                    cmd_args = {
                        dataset['config']: '',
                        '+model': 'gift',
                        dataset['hparam']: '',
                        'exp.exp_name': 'goal_threshold_study',
                        'exp.seed': seed,
                        'model.name': 'gift',
                        'exp.test': 'False',
                        'model.her_params.target_hit_ratio': threshold,
                        'exp.max_epochs': epoch,
                        'exp.load_data': 'True',
                        'exp.logging': 'False',
                        dataset['size_param']: dataset['size_value'],
                    }
                    if 'coeff' in dataset:
                        cmd_args['dataset.coeff'] = dataset['coeff']

                    exp_name = f"goal_threshold_{threshold}_{dataset['type']}_seed_{seed}"
                    experiments.append((cmd_args, exp_name, 'gift'))
        self._run_batch(experiments, "GIFT Goal Threshold Study")

    def experiment_train_size_study(self, seeds=None):
        """Training-set-size study on tumor (gamma = 2) for several models."""
        if seeds is None:
            seeds = self.base_seeds
        models = ['gift', 'vcip', 'actin', 'crn']
        train_sizes = [100, 200, 500, 1000, 2000]

        # NOTE: a mimic variant of this study (dataset.max_number = size) is
        # currently disabled; only tumor runs are scheduled.
        experiments = []
        for model in models:
            for size in train_sizes:
                for seed in seeds:
                    cmd_args = {
                        '+dataset': 'tumor',
                        '+model': model,
                        f'+hparam/{model}/tumor': '2*',
                        'exp.exp_name': 'train_size_study',
                        'exp.seed': seed,
                        'model.name': model,
                        'exp.test': 'False',
                        'exp.logging': 'False',
                        'dataset.num_patients.train': size,
                        'dataset.coeff': 2
                    }
                    exp_name = f"train_size_{size}_tumor_gamma_2_{model}_seed_{seed}"
                    experiments.append((cmd_args, exp_name, model))

        self._run_batch(experiments, "Train Size Study")

    def experiment_baseline_k_study(self, seeds=None):
        """Baseline sensitivity to the sample size K (``exp.sample_size``).

        The first K value runs with ``exp.load_model=False``. The remaining K
        values run with ``exp.load_model=True`` (presumably reusing the models
        from the first batch). The exception is vcip, which always runs with
        ``exp.load_model=False``.
        """
        if seeds is None:
            seeds = self.base_seeds

        models = ['actin', 'rmsn', 'vcip', 'ct', 'crn']
        # models = ['vcip', 'ct', 'crn']
        # k_values = [20, 40, 60, 80, 100, 120, 140, 160]
        k_values = [10, 20, 30, 50, 80, 100, 150, 200, 250, 300]
        epochs_dict = self.EPOCHS_DICT
        datasets = self._standard_datasets(mimic_size=500, tumor_size=1000)

        def build_experiments(k_values_sub, load=False):
            experiments = []
            for model in models:
                # Per-model value: rebinding `load` itself would leak False to
                # every model after vcip.
                model_load = False if model == 'vcip' else load
                for dataset in datasets:
                    is_tumor = 'tumor' in dataset['type']
                    load_data = not is_tumor
                    epoch = epochs_dict[model]['tumor' if is_tumor else 'mimic']
                    for k_val in k_values_sub:
                        for seed in seeds:
                            cmd_args = {
                                dataset['config']: '',
                                '+model': model,
                                'exp.exp_name': 'k_parameter_study',
                                'exp.seed': seed,
                                'model.name': model,
                                'exp.test': 'False',
                                'exp.sample_size': k_val,
                                'exp.logging': 'False',
                                'exp.load_model': model_load,
                                'exp.load_data': load_data,
                                dataset['size_param']: dataset['size_value'],
                                'exp.max_epochs': epoch,
                            }
                            if 'gamma' in dataset:
                                cmd_args[f'+hparam/{model}/tumor'] = f"{dataset['gamma']}*"
                                cmd_args['dataset.coeff'] = dataset['gamma']
                            else:
                                cmd_args[f'+hparam/{model}'] = 'mimic'

                            base_name = f"k_parameter_{k_val}_{dataset['type']}_{model}_seed_{seed}"
                            for value in ['False', 'True']:
                                experiments.append((
                                    {**cmd_args, 'exp.optimize_by_step': value},
                                    f"{base_name}_optimize_by_step_{value}",
                                    model,
                                ))
            return experiments

        experiments = build_experiments(k_values[:1])
        if experiments:
            self.logger.info(f"Example experiment: {experiments[0]}")
        self._run_batch(experiments, "Baseline K Parameter Study")
        experiments = build_experiments(k_values[1:], load=True)
        self._run_batch(experiments, "Baseline K Parameter Study")

    def experiment_gift_ablation(self, seeds=None, ablation_configs=None):
        """GIFT ablation study on mimic and tumor (gamma 2/3/4).

        Args:
            seeds: Seeds to run. Defaults to ``self.base_seeds``.
            ablation_configs: Mapping of config name -> override dict merged into
                each run's arguments. Defaults to the CQL comparison below. See
                ``GIFT_COMPONENT_ABLATIONS`` for the DR/recover/action_diff set.
        """
        if seeds is None:
            seeds = self.base_seeds
        if ablation_configs is None:
            ablation_configs = {
                # "full_model": {
                #     "model.sac_params.DR": True,
                #     "model.her_params.k_future": 5,
                # },
                # "no_dr": {
                #     "model.sac_params.DR": False,
                #     "model.her_params.k_future": 5,
                # },
                # "no_her": {
                #     "model.sac_params.DR": True,
                #     "model.her_params.k_future": 0,
                # },
                "with_cql": {
                    "model.sac_params.DR": False,
                    "model.her_params.k_future": 5,
                    "model.baserl": 'CQL',
                }
            }

        datasets = [
            {'type': 'mimic', 'config': '+dataset=mimic', 'hparam': '+hparam/gift=mimic', 'size_param': 'dataset.max_number', 'size_value': 500},
            {'type': 'tumor_gamma_2', 'config': '+dataset=tumor', 'hparam': '+hparam/gift/tumor=2*', 'size_param': 'dataset.num_patients.train', 'size_value': 1000, 'coeff': 2},
            {'type': 'tumor_gamma_3', 'config': '+dataset=tumor', 'hparam': '+hparam/gift/tumor=3*', 'size_param': 'dataset.num_patients.train', 'size_value': 1000, 'coeff': 3},
            {'type': 'tumor_gamma_4', 'config': '+dataset=tumor', 'hparam': '+hparam/gift/tumor=4*', 'size_param': 'dataset.num_patients.train', 'size_value': 1000, 'coeff': 4},
        ]
        # To (re)generate the mimic data first, uncomment:
        # for seed in seeds:
        #     self._generate_mimic_data_for_main_comparison(['gift'], size_value=500, seed=seed)

        experiments = []
        for dataset in datasets:
            # mimic preloads data; tumor does not.
            load_data = 'True' if 'mimic' in dataset['type'] else 'False'
            data_name = 'tumor' if 'tumor' in dataset['type'] else 'mimic'
            epoch = self.EPOCHS_DICT['gift'][data_name]
            for config_name, config_overrides in ablation_configs.items():
                for seed in seeds:
                    cmd_args = {
                        dataset['config']: '',
                        '+model': 'gift',
                        dataset['hparam']: '',
                        'exp.exp_name': 'ablation_study',
                        'exp.name': config_name,
                        'exp.seed': seed,
                        'model.name': 'gift',
                        'exp.max_epochs': epoch,
                        'exp.test': 'False',
                        'exp.logging': 'False',
                        'exp.load_data': load_data,
                        dataset['size_param']: dataset['size_value'],
                        **config_overrides
                    }
                    if 'coeff' in dataset:
                        cmd_args['dataset.coeff'] = dataset['coeff']

                    # Runner-side log label; kept detailed for tracking.
                    runner_log_name = f"gift_ablation_{config_name}_{dataset['type']}_seed_{seed}"
                    experiments.append((cmd_args, runner_log_name, 'gift'))

        self._run_batch(experiments, "GIFT Ablation Study")

    def experiment_complexity_study(self, seeds=None):
        """Complexity study: one-epoch mimic runs (1000 samples) for every model."""
        if seeds is None:
            seeds = self.base_seeds[:3]
        models = ['gift', 'vcip', 'rmsn', 'crn', 'ct', 'actin']
        experiments = []
        for model in models:
            for seed in seeds:
                cmd_args = {
                    '+dataset': 'mimic',
                    '+model': model,
                    f'+hparam/{model}': 'mimic',
                    'exp.exp_name': 'complexity_study',
                    'exp.seed': seed,
                    'model.name': model,
                    'exp.test': 'False',
                    'exp.logging': 'False',
                    'exp.max_epochs': 1,
                    'dataset.max_number': 1000
                }
                exp_name = f"complexity_{model}_seed_{seed}"
                experiments.append((cmd_args, exp_name, model))
        self._run_batch(experiments, "Complexity Study")

def main():
    """Command-line entry point: parse arguments and run the selected experiment(s).

    GPU selection: ``--allowed-gpus`` wins if given. Otherwise ``--num-gpus N``
    uses GPUs 0..N-1. If neither is given, GPUs 0,1,2,3 are used.
    Exits with status 1 if an experiment raises, and with 128 + signal number
    on SIGINT/SIGTERM.
    """
    import argparse

    # GPU ids used when neither --allowed-gpus nor --num-gpus is supplied.
    default_gpus = [0, 1, 2, 3]

    def signal_handler(signum, frame):
        print("\nReceived interrupt signal, exiting...")
        # Raises SystemExit in the main thread; 128 + signum is the shell convention.
        sys.exit(128 + signum)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    parser = argparse.ArgumentParser(description="Simple experimental runner - supports multiplexing per GPU")
    parser.add_argument("--experiment", "-e",
                        choices=["main_comparison", "goal_threshold", "train_size",
                                 "baseline_k", "gift_ablation", "complexity", "all"],
                        default="all", help="Select Experiment Type")
    parser.add_argument("--seeds", nargs="+", type=int,
                        default=[10, 101, 1010, 10101, 101010],
                        help="Random seed list")
    parser.add_argument("--num-gpus", type=int, default=None,
                        help="Use GPUs 0..N-1 (ignored when --allowed-gpus is given)")
    parser.add_argument("--tasks-per-gpu", type=int, default=2,
                        help="Number of parallel tasks per GPU (default 2)")
    parser.add_argument("--enable-print", action="store_true", default=False,
                        help="Enable verbose printout (disabled by default to avoid Broken pipe)")
    parser.add_argument('--allowed-gpus', type=str, default=None,
                        help='Comma separated list of GPU ids to use, e.g. "1,2,3" '
                             '(default: 0,1,2,3 unless --num-gpus is given)')

    args = parser.parse_args()

    if args.allowed_gpus is not None:
        try:
            allowed_gpus = [int(tok) for tok in args.allowed_gpus.split(',') if tok.strip()]
        except ValueError:
            parser.error(f"--allowed-gpus must be comma separated integers, got {args.allowed_gpus!r}")
        if not allowed_gpus:
            parser.error("--allowed-gpus must list at least one GPU id")
    elif args.num_gpus is None:
        allowed_gpus = list(default_gpus)
    else:
        # Let the runner use GPUs 0..num_gpus-1.
        allowed_gpus = None

    try:
        runner = SimpleExperimentRunner(
            num_gpus=args.num_gpus,
            tasks_per_gpu=args.tasks_per_gpu,
            enable_print=args.enable_print,
            allowed_gpus=allowed_gpus,
        )

        print("🎯 Start running the experiment")
        print(f"{args.tasks_per_gpu} tasks per GPU in parallel 🔥 using {runner.num_gpus} GPUs")
        print(f"⚡ Total Parallel Tasks: {runner.max_parallel_tasks}")
        print(f"📊 Seeds: {args.seeds}")

        # Insertion order is the order used for "all".
        experiment_methods = {
            "main_comparison": runner.experiment_main_comparison,
            "goal_threshold": runner.experiment_goal_threshold,
            "train_size": runner.experiment_train_size_study,
            "baseline_k": runner.experiment_baseline_k_study,
            "gift_ablation": runner.experiment_gift_ablation,
            "complexity": runner.experiment_complexity_study,
        }
        selected = list(experiment_methods) if args.experiment == "all" else [args.experiment]
        for name in selected:
            experiment_methods[name](seeds=args.seeds)

        print("🏁 All experiments completed!")
    except Exception as e:
        print(f"Program error: {e}")
        import traceback
        traceback.print_exc()
        # Report failure to the calling shell / scheduler.
        sys.exit(1)

if __name__ == "__main__":
    main()
