import yaml
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
import torch
import time
import argparse
import pandas as pd
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
from pytorch_lightning.loggers import TensorBoardLogger, MLFlowLogger
import importlib
import sys
import os
import pickle
import warnings
import json
from pathlib import Path
from hydra.utils import get_original_cwd
from datetime import datetime
import contextlib

#Import Custom Module
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.utils.utils import set_seed, AlphaRise as AlphaRise_actin
from src.utils.utils import get_absolute_path, to_float, count_parameters, repeat_static
from src.baselines.cts.utils import AlphaRise, FilteringMlFlowLogger
from src.utils.helper_functions import check_csv, write_csv as write_csv_vcip
from src.utils.helper_functions import write_complexity_info
from src.gift.utils.utils import write_csv
from src.gift.utils.evaluator import evaluate_and_log_case_studies
from src.gift.train_helper import train as train_gift
from experiments.analysis_modules.case_analyzer import CaseAnalyzer
import re
import pickle

# Module-level CaseAnalyzer, constructed at import time.
# 'results' is a relative path, so what it resolves to depends on the working
# directory in effect when the analyzer uses it — TODO confirm against CaseAnalyzer.
analyzer = CaseAnalyzer(results_base_path='results')

# Install a global "ignore" filter: all warnings raised after this point are suppressed.
warnings.filterwarnings("ignore")

def delete_train_log(log_dir, args):
    """Remove ``train.log`` from ``log_dir`` and print the outcome.

    Args:
        log_dir: Directory expected to contain ``train.log``.
        args: Accepted for call compatibility; not used by this function.

    Nothing is raised for a missing file or a failed removal; both cases are
    only reported on stdout.
    """
    target = os.path.join(log_dir, "train.log")

    # Nothing to do when the file is absent.
    if not os.path.exists(target):
        print(f"File does not exist: {target}")
        return

    try:
        os.remove(target)
        print(f"Successfully deleted: {target}")
    except OSError as err:
        print(f"Delete failed: {err}")

class UnifiedTrainer:
    """Unified trainer class with support for all model types"""

    def __init__(self, args: DictConfig):
        """Prepare directories, logging, seeding and data for one training run.

        Args:
            args: Hydra/OmegaConf configuration. The fields read here are
                ``model.name``, ``exp.exp_name`` (optional), ``exp.test``,
                ``exp.seed`` and, for non-GIFT models, ``exp.optimize_by_step``.

        Side effects:
            * Creates the result, log, Hydra-log and checkpoint directories under
              ``<original cwd>/results/<exp>/<dataset>/<model>[/<param_suffix>]``.
            * Changes the process working directory to the ``hydra_logs`` folder.
            * For non-GIFT models, writes ``args.exp.result_dir``.
            * Seeds the random number generators via ``set_seed``.
        """
        self.args = args
        self.model_type = args.model.name.lower()
        # Directory the program was launched from (before Hydra changed cwd)
        self.original_cwd = get_original_cwd()
        # Experiment / dataset identifiers used to build the directory layout
        self.exp_name = str(args.exp.exp_name) if hasattr(args.exp, 'exp_name') else 'default_experiment'
        self.dataset_name = self._get_dataset_name()
        self.test_setting = str(args.exp.test)
        self.param_suffix = self._build_param_suffix()

        # Layout: results/exp_name/dataset/model[/param_suffix]/
        self.results_base_dir = os.path.join(self.original_cwd, 'results')
        model_root = os.path.join(
            self.results_base_dir, self.exp_name,
            self.dataset_name, self.model_type
        )
        if self.param_suffix:
            self.model_base_dir = os.path.join(model_root, self.param_suffix)
        else:
            self.model_base_dir = model_root
        os.makedirs(self.model_base_dir, exist_ok=True)

        # hydra_logs becomes the working directory further below
        self.hydra_logs_dir = os.path.join(self.model_base_dir, 'hydra_logs')
        self.hydra_cwd = self.hydra_logs_dir

        # Raw result files; non-GIFT models get a step-wise / episode-wise subfolder
        self.csv_dir = os.path.join(self.model_base_dir, 'raw_results')
        if 'gift' in self.model_type:
            self.result_dir = self.csv_dir
        else:
            subdir = 'step_wise' if self.args.exp.optimize_by_step else 'episode_wise'
            self.result_dir = os.path.join(self.csv_dir, subdir)
            self.args.exp.result_dir = os.path.join(self.csv_dir, subdir)

        self.logs_dir = os.path.join(self.model_base_dir, 'logs')
        self.aggregated_results_path = os.path.join(self.model_base_dir, 'aggregated_results.json')

        # For the k-parameter study the checkpoint folder sits above the
        # param_suffix level (results/exp/dataset/model/checkpoints/<seed>).
        if self.exp_name == "k_parameter_study":
            self.model_checkpoint_dir = os.path.join(model_root, 'checkpoints', str(self.args.exp.seed))
        else:
            self.model_checkpoint_dir = os.path.join(self.hydra_logs_dir, 'checkpoints', str(self.args.exp.seed))

        # Create the necessary directories
        for directory in (self.csv_dir, self.result_dir, self.logs_dir,
                          self.hydra_logs_dir, self.model_checkpoint_dir):
            os.makedirs(directory, exist_ok=True)
        os.chdir(self.hydra_logs_dir)

        # Logging must be configured before anything below tries to log
        self.logger = self._setup_logging()
        self.optimize_interventions = check_csv(self.csv_dir, args)

        # Seed RNGs; deterministic mode is requested only when 'gift' appears in the
        # configured model name.
        # NOTE(review): this checks args.model.name as written, not the lower-cased
        # self.model_type used elsewhere — confirm that is intended.
        deterministic = 'gift' in args.model.name
        set_seed(args.exp.seed, deterministic)

        # Load (or build) the dataset and log a few summary statistics
        self.dataset_collection = self._load_dataset()
        try:
            self.logger.info(f"val mean: {self.dataset_collection.val_f.data_original['outputs'][:, -1, :].mean()}")
            self.logger.info(f"train mean: {self.dataset_collection.train_f.data_original['outputs'][:, -1, :].mean()}")
            self.logger.info(f"current_treatments: {self.dataset_collection.train_f.data_original['current_treatments'].mean()}")
            # This second 'current_treatments' line reports the standard deviation
            self.logger.info(f"current_treatments: {self.dataset_collection.train_f.data_original['current_treatments'].std()}")
        except Exception:
            # Fall back to the processed data when `data_original` is unavailable.
            # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
            self.logger.info('no val data_original')
            self.logger.info(f"val mean: {self.dataset_collection.val_f.data['outputs'][:, -1, :].mean()}")
            self.logger.info(f"train mean: {self.dataset_collection.train_f.data['outputs'][:, -1, :].mean()}")
            self.logger.info(f"current_treatments: {self.dataset_collection.train_f.data['current_treatments'].mean()}")
            self.logger.info(f"current_treatments: {self.dataset_collection.train_f.data['current_treatments'].std()}")

        # Record path information
        self.logger.info(f"Original CWD: {self.original_cwd}")
        self.logger.info(f"Hydra CWD: {self.hydra_cwd}")
        self.logger.info(f"Experiment: {self.exp_name}")
        self.logger.info(f"Dataset: {self.dataset_name}")
        self.logger.info(f"Model: {self.model_type}")
        self.logger.info(f"Test setting: {self.test_setting}")
        self.logger.info(f"Param suffix: {self.param_suffix}")
        self.logger.info(f"Results dir: {self.model_base_dir}")
        self.logger.info(f"Checkpoint dir: {self.model_checkpoint_dir}")
        self.logger.info(f"Logs dir: {self.logs_dir}")

    def _setup_logging(self):
        """Simplified log configuration, write-only files without redirection"""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_filename = f"seed_{self.args.exp.seed}_{timestamp}.log"
        log_filepath = os.path.join(self.logs_dir, log_filename)
        #Create an independent logger to avoid conflicts with other processes
        logger_name = f"trainer_{self.args.exp.seed}_{timestamp}_{os.getpid()}"
        logger = logging.getLogger(logger_name)
        #Clear possible handlers
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            handler.close()
        logger.setLevel(logging.INFO)
        #Use file handler only, do not use console handler
        try:
            file_handler = logging.FileHandler(log_filepath, encoding='utf-8')
            file_handler.setLevel(logging.INFO)
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
            logger.propagate = False
            self.log_filepath = log_filepath
            #Use print instead of logger.info to avoid looping issues
            logger.info(f"[PID:{os.getpid()}] Logging setup complete: {log_filepath}")
        except Exception as e:
            logger.info(f"[PID:{os.getpid()}] Logging setup failed: {e}")
            return None
        return logger

    def _get_dataset_name(self):
        """Get a standardized dataset name"""
        dataset_name = self.args.dataset.name
        if dataset_name == 'tumor_generator':
            #tumor dataset, using the gamma parameter
            coeff = getattr(self.args.dataset, 'coeff', 0)
            return f'tumor/gamma_{coeff}'
        elif dataset_name == 'mimic':
            return 'mimic'
        else:
            #Other datasets
            return dataset_name

    def _build_param_suffix(self):
        """The build parameter suffix is used for directory differentiation, and special processing is performed according to the experimental type."""
        param_parts = []
        #Special treatment according to the name of the experiment
        if self.exp_name == "goal_threshold_study":
            #Gift's goal_threshold study
            if (hasattr(self.args.model, 'her_params') and
                    hasattr(self.args.model.her_params, 'target_hit_ratio') and
                    self.args.model.her_params.target_hit_ratio is not None):
                threshold = self.args.model.her_params.target_hit_ratio
                param_parts.append(f"gt_{threshold}")
        elif self.exp_name == "k_parameter_study":
            #K-parameter study of Baseline
            if hasattr(self.args.exp, 'sample_size') and self.args.exp.sample_size is not None:
                param_parts.append(f"k_{self.args.exp.sample_size}")

        elif self.exp_name == "train_size_study":
            #Training Set Size Study
            train_size = None
            if hasattr(self.args.dataset, 'max_number') and self.args.dataset.max_number is not None:
                train_size = self.args.dataset.max_number
            elif (hasattr(self.args.dataset, 'num_patients') and
                    hasattr(self.args.dataset.num_patients, 'train') and
                    self.args.dataset.num_patients.train is not None):
                train_size = self.args.dataset.num_patients.train
            if train_size is not None:
                param_parts.append(f"size_{train_size}")
        
        #--- Core Modification Points ---
        elif self.exp_name == "ablation_study":
            #Gift ablation experiment
            #Obtain the run name explicitly from the `exp.name` parameter (e.g., 'full_model', 'no_dr')
            #This parameter is set in the command line by run_experiments.py
            if hasattr(self.args.exp, 'name') and self.args.exp.name:
                param_parts.append(self.args.exp.name)
            else:
                #If exp.name is not provided, use an explicit alternate name to facilitate debugging
                param_parts.append("unnamed_ablation_run")
        #--- End of Modification ---

        elif self.exp_name == "main_comparison":
            #Primary comparison experiment without additional parameter suffixes
            pass
        elif self.exp_name == "complexity_study":
            #Comparison of time and space complexity experiments
            #Specific complexity analysis parameters can be added as needed
            if hasattr(self.args.exp, 'complexity_type'):
                param_parts.append(f"complexity_{self.args.exp.complexity_type}")
        elif self.exp_name == "hyperparameter_sensitivity":
            #Hyperparametric sensitivity analysis
            if hasattr(self.args.exp, 'learning_rate'):
                param_parts.append(f"lr_{self.args.exp.learning_rate}")
            if hasattr(self.args.exp, 'batch_size'):
                param_parts.append(f"bs_{self.args.exp.batch_size}")
        else:
            #Other experimental types, using common parameter identification logic
            #Goal_threshold parameter for gift
            if (hasattr(self.args.model, 'her_params') and
                    hasattr(self.args.model.her_params, 'goal_threshold') and
                    self.args.model.her_params.goal_threshold is not None):
                threshold = self.args.model.her_params.goal_threshold
                param_parts.append(f"gt_{threshold}")
            #K-parameters for Baseline
            if hasattr(self.args.exp, 'sample_size') and self.args.exp.sample_size is not None:
                param_parts.append(f"k_{self.args.exp.sample_size}")
            #Other possible parameters
            if hasattr(self.args.exp, 'alpha') and self.args.exp.alpha is not None:
                param_parts.append(f"alpha_{self.args.exp.alpha}")
        
        param_parts.append(f"shift_{self.test_setting}")
        return '_'.join(param_parts) if param_parts else None

    def _load_dataset(self):
        """Return the dataset collection for this run, using a pickle cache.

        The cache lives at
        ``<original cwd>/data/processed/<dataset>/<model>/seed_<seed>.pkl``.
        When ``exp.load_data`` (default True) is set and that file exists it is
        unpickled; otherwise the dataset is instantiated from ``args.dataset``,
        processed, and written to the cache. The result is passed through
        ``to_float`` and, for gift/vcip/actin with ``dataset.static_size > 0``
        and 2-D static features, through ``repeat_static``.

        Also writes ``args.exp.processed_data_dir``.
        """
        cache_dir = os.path.join(
            os.path.join(self.original_cwd, 'data/processed'),
            self.dataset_name,
            self.model_type,
        )
        cache_file = os.path.join(cache_dir, f"seed_{self.args.exp.seed}.pkl")

        # Expose the resolved directory through the configuration
        self.args.exp.processed_data_dir = cache_dir

        use_cache = self.args.exp.get('load_data', True) and os.path.exists(cache_file)
        if not use_cache:
            self.logger.info(f"Creating new dataset and saving to {cache_file}")
            os.makedirs(cache_dir, exist_ok=True)
            collection = instantiate(self.args.dataset, _recursive_=True)
            # Encoder-style preprocessing for crn/ct/rmsn, multi-style otherwise
            if self.model_type in ['crn', 'ct', 'rmsn']:
                self.logger.info("Processing data with encoder method...")
                collection.process_data_encoder()
            else:
                self.logger.info("Processing data with multi method...")
                collection.process_data_multi()
            with open(cache_file, 'wb') as handle:
                pickle.dump(collection, handle)
            self.logger.info(f"Dataset saved to {cache_file}")
        else:
            # NOTE(review): unpickling runs code embedded in the file, so this is
            # only safe while the cache directory is trusted.
            with open(cache_file, 'rb') as handle:
                collection = pickle.load(handle)
            self.logger.info(f"Loaded existing dataset from {cache_file}")

        collection = to_float(collection)

        # Static features: 2-D static features are repeated for these model types
        needs_static_check = (
            self.model_type in ['gift', 'vcip', 'actin']
            and self.args.dataset.get('static_size', 0) > 0
        )
        if needs_static_check:
            static_rank = len(collection.train_f.data['static_features'].shape)
            if static_rank == 2:
                collection = repeat_static(collection)
                self.logger.info("Static features repeated for sequence compatibility")
        return collection

    def _setup_model_dimensions(self):
        """Setting Model Dimension Parameters"""
        self.args.model.dim_outcomes = self.dataset_collection.train_f.data['outputs'].shape[-1]
        self.args.model.dim_treatments = self.dataset_collection.train_f.data['current_treatments'].shape[-1]
        self.args.model.dim_vitals = self.dataset_collection.train_f.data['vitals'].shape[-1] if self.dataset_collection.has_vitals else 0
        self.args.model.dim_static_features = self.dataset_collection.train_f.data['static_features'].shape[-1]
        self.logger.info(f"Model dimensions set:")
        self.logger.info(f" - Outcomes: {self.args.model.dim_outcomes}")
        self.logger.info(f" - Treatments: {self.args.model.dim_treatments}")
        self.logger.info(f" - Vitals: {self.args.model.dim_vitals}")
        self.logger.info(f" - Static features: {self.args.model.dim_static_features}")

    def _setup_logger(self):
        """Build the experiment loggers handed to the Lightning trainers.

        Returns:
            tuple: ``(tensorboard_logger, mlflow_logger)``. The MLflow logger is
            ``None`` unless ``exp.logging`` is truthy; when it is created, the
            mean of the training outputs is logged to it as ``ymean``.
        """
        # TensorBoard files go under ./tensorboard relative to the current working directory
        board = TensorBoardLogger(save_dir='.', name='tensorboard', version='')

        if not self.args.exp.get('logging', False):
            return board, None

        mlflow_logger = MLFlowLogger(
            experiment_name=f'{self.exp_name}/{self.dataset_name}/{self.model_type}',
            tracking_uri="http://localhost:5000",
        )
        outputs_mean = self.dataset_collection.train_f.data['outputs'].mean()
        mlflow_logger.log_metrics({'ymean': outputs_mean})
        return board, mlflow_logger

    def train_single_model(self):
        """Build, then load or train, a single-network model (ACTIN, VCIP, CT, etc.).

        The model class is resolved from the ``_target_`` entry of the model
        config (``model.multi._target_`` for ``ct``). If ``model.ckpt`` exists in
        the checkpoint directory and ``exp.load_model`` is truthy, its state
        dict is loaded; otherwise the model is trained with a Lightning trainer
        and its state dict is saved to that path.

        Returns:
            The loaded or freshly trained model.
        """
        self.logger.info(f"Starting training for {self.model_type.upper()} model...")

        # Resolve the model class from its dotted path
        target_cfg = self.args.model.multi if self.model_type == 'ct' else self.args.model
        module_path, class_name = target_cfg._target_.rsplit('.', 1)
        model_class = getattr(importlib.import_module(module_path), class_name)
        self.logger.info(f"Model class: {class_name}")
        self.logger.info(f"Module path: {module_path}")

        # vcip and ct take (args, data); the other models take (data, args)
        if self.model_type in ('vcip', 'ct'):
            model = model_class(self.args, self.dataset_collection)
        else:
            model = model_class(self.dataset_collection, self.args)
        count_parameters(model, logger=self.logger)

        ckpt_file = os.path.join(self.model_checkpoint_dir, 'model.ckpt')
        if os.path.exists(ckpt_file) and self.args.exp.get('load_model', False):
            state = torch.load(ckpt_file)
            model.load_state_dict(state)
            self.logger.info(f"Loaded existing model from {ckpt_file}")
            return model

        self.logger.info("Training new model...")
        # Only the TensorBoard logger is used on this path
        board_logger, _unused_mlflow = self._setup_logger()

        # NOTE(review): with filename='model' this callback appears to target the
        # same model.ckpt path that torch.save overwrites below — confirm intended.
        best_ckpt = pl.callbacks.ModelCheckpoint(
            dirpath=self.model_checkpoint_dir,
            monitor='val_loss',
            filename='model',
            save_top_k=1,
            mode='min'
        )
        callbacks = [best_ckpt]
        if self.model_type == 'actin':
            callbacks.append(AlphaRise_actin())

        self.logger.info(f"Setting up trainer with {self.args.exp.max_epochs} epochs...")
        trainer_kwargs = {
            'logger': board_logger,
            'max_epochs': self.args.exp.max_epochs,
            'callbacks': callbacks,
            'precision': 32,
        }
        if self.model_type in ('vcip', 'actin'):
            trainer_kwargs['devices'] = self.args.exp.gpus
        else:
            # NOTE(review): eval() executes the configured string; acceptable only
            # while the configuration is trusted.
            trainer_kwargs['gpus'] = eval(str(self.args.exp.gpus))
        trainer = pl.Trainer(**trainer_kwargs)

        self.logger.info("Starting training process...")
        trainer.fit(model)
        torch.save(model.state_dict(), ckpt_file)
        self.logger.info(f"Training completed. Model saved to {ckpt_file}")
        return model

    def train_encoder_decoder(self):
        """Load or train an encoder and, optionally, a decoder (CRN, EDCT).

        For each of the two networks: if ``<name>.ckpt`` exists in the
        checkpoint directory and ``exp.load_model`` is truthy, its state dict is
        loaded; otherwise the network is trained with a Lightning trainer and
        its state dict is saved. The decoder is skipped when
        ``model.train_decoder`` is falsy (default True).

        Returns:
            tuple: ``(encoder, decoder)``; ``decoder`` is ``None`` when skipped.
        """
        self.logger.info(f"Training encoder-decoder model: {self.model_type.upper()}")
        self._setup_model_dimensions()

        # Experiment loggers are created lazily, the first time something is
        # actually trained. Previously `mlf_logger` was bound only in the
        # encoder-training branch, so "encoder loaded, decoder trained" raised
        # UnboundLocalError.
        mlf_logger = None
        loggers_ready = False

        # --- Encoder ---
        self.logger.info("Initializing encoder...")
        encoder = instantiate(self.args.model.encoder, self.args, self.dataset_collection, _recursive_=False)
        encoder_path = os.path.join(self.model_checkpoint_dir, 'encoder.ckpt')
        if os.path.exists(encoder_path) and self.args.exp.get('load_model', False):
            checkpoint = torch.load(encoder_path)
            encoder.load_state_dict(checkpoint)
            self.logger.info(f"Loaded existing encoder from {encoder_path}")
        else:
            self.logger.info("Training encoder...")
            _, mlf_logger = self._setup_logger()
            loggers_ready = True
            encoder_trainer = pl.Trainer(
                # NOTE(review): eval() executes the configured string; acceptable
                # only while the configuration is trusted.
                gpus=eval(str(self.args.exp.gpus)),
                logger=mlf_logger,
                max_epochs=self.args.exp.max_epochs,
                callbacks=[AlphaRise(rate=self.args.exp.alpha_rate)],
                terminate_on_nan=True,
                precision=32
            )
            encoder_trainer.fit(encoder)
            torch.save(encoder.state_dict(), encoder_path)
            self.logger.info(f"Encoder training completed. Saved to {encoder_path}")

        # --- Decoder (optional) ---
        if not self.args.model.get('train_decoder', True):
            return encoder, None

        self.logger.info("Initializing decoder...")
        decoder = instantiate(self.args.model.decoder, self.args, encoder, self.dataset_collection, _recursive_=False)
        decoder_path = os.path.join(self.model_checkpoint_dir, 'decoder.ckpt')
        if os.path.exists(decoder_path) and self.args.exp.get('load_model', False):
            checkpoint = torch.load(decoder_path)
            decoder.load_state_dict(checkpoint)
            self.logger.info(f"Loaded existing decoder from {decoder_path}")
        else:
            self.logger.info("Training decoder...")
            if not loggers_ready:
                # Encoder came from a checkpoint, so the loggers were never built
                _, mlf_logger = self._setup_logger()
            decoder_trainer = pl.Trainer(
                gpus=eval(str(self.args.exp.gpus)),
                logger=mlf_logger,
                max_epochs=self.args.exp.max_epochs,
                callbacks=[AlphaRise(rate=self.args.exp.alpha_rate)],
                terminate_on_nan=True,
                precision=32
            )
            decoder_trainer.fit(decoder)
            torch.save(decoder.state_dict(), decoder_path)
            self.logger.info(f"Decoder training completed. Saved to {decoder_path}")
        return encoder, decoder

    def train_rmsn(self):
        """Load or train the RMSN components in dependency order.

        Components are handled in the order propensity_treatment,
        propensity_history, encoder, decoder. The encoder is built from the two
        propensity networks and the decoder from the encoder. Each component is
        loaded from ``<component>.ckpt`` when that file exists and
        ``exp.load_model`` is truthy; otherwise it is trained and its state dict
        saved. The decoder is skipped when ``model.train_decoder`` is falsy.

        Returns:
            dict: component name -> model, for every component processed.
        """
        self.logger.info("Training RMSN model...")
        self._setup_model_dimensions()

        models = {}
        component_names = ['propensity_treatment', 'propensity_history', 'encoder', 'decoder']
        # One (empty) callback list per component
        callbacks = [[] for _ in component_names]
        gradient_clip_vals = [
            self.args.model.propensity_treatment.max_grad_norm,
            self.args.model.propensity_history.max_grad_norm,
            self.args.model.encoder.max_grad_norm,
            self.args.model.decoder.max_grad_norm,
        ]
        _, mlf_logger = self._setup_logger()

        def build(name):
            # Instantiate one component, wiring in the components it depends on
            if name == 'encoder':
                return instantiate(self.args.model.encoder, self.args,
                                   models['propensity_treatment'],
                                   models['propensity_history'],
                                   self.dataset_collection, _recursive_=False)
            if name == 'decoder':
                return instantiate(self.args.model.decoder, self.args,
                                   models['encoder'],
                                   self.dataset_collection, _recursive_=False)
            return instantiate(getattr(self.args.model, name),
                               self.args, self.dataset_collection, _recursive_=False)

        for idx, component in enumerate(component_names):
            if component == 'decoder' and not self.args.model.get('train_decoder', True):
                continue
            self.logger.info(f"Processing component: {component}")
            component_path = os.path.join(self.model_checkpoint_dir, f'{component}.ckpt')
            reuse = os.path.exists(component_path) and self.args.exp.get('load_model', False)

            if reuse:
                self.logger.info(f"Loading existing {component} from {component_path}")
                model = build(component)
                model.load_state_dict(torch.load(component_path))
                models[component] = model
                continue

            self.logger.info(f"Training {component}...")
            model = build(component)
            trainer = pl.Trainer(
                # NOTE(review): eval() executes the configured string; acceptable
                # only while the configuration is trusted.
                gpus=eval(str(self.args.exp.gpus)),
                logger=mlf_logger,
                max_epochs=self.args.exp.max_epochs,
                callbacks=callbacks[idx],
                gradient_clip_val=gradient_clip_vals[idx],
                terminate_on_nan=True,
                precision=32
            )
            trainer.fit(model)
            torch.save(model.state_dict(), component_path)
            models[component] = model
            self.logger.info(f"{component} training completed. Saved to {component_path}")
        return models

    def train_gift(self):
        """Run GIFT training through the module-level ``train_gift`` helper.

        Returns:
            tuple: ``(model, metrics, complexity_info)`` as produced by the helper.
        """
        self.logger.info("Training GIFT model...")
        # Inside the method body the bare name refers to the imported helper, not this method
        outcome = train_gift(self.dataset_collection, self.args, logger=self.logger)
        model, metrics, complexity_info = outcome
        self.logger.info("GIFT training completed.")
        return model, metrics, complexity_info

    def _save_aggregated_results(self, model, metrics):
        """Save aggregate results to a JSON file"""
        self.logger.info("Saving aggregated results...")
        try:
            #Collect experimental configuration information
            config_info = {
                'exp_name': self.exp_name,
                'dataset_coeff': getattr(self.args.dataset, 'coeff', None),
                'goal_threshold': None,
                'baseline_k': getattr(self.args.exp, 'k', None),
                'train_size': None,
                'max_epochs': self.args.exp.max_epochs,
                'gpus': str(self.args.exp.gpus),
                'test_setting': self.test_setting,
                'learning_rate': getattr(self.args.exp, 'learning_rate', None),
                'batch_size': getattr(self.args.exp, 'batch_size', None),
                'alpha': getattr(self.args.exp, 'alpha', None),
            }
            #Get training set size
            if hasattr(self.args.dataset, 'max_number'):
                config_info['train_size'] = self.args.dataset.max_number
            elif hasattr(self.args.dataset, 'num_patients') and hasattr(self.args.dataset.num_patients, 'train'):
                config_info['train_size'] = self.args.dataset.num_patients.train
            #Try to get goal_threshold
            if hasattr(self.args.model, 'her_params') and hasattr(self.args.model.her_params, 'goal_threshold'):
                config_info['goal_threshold'] = self.args.model.her_params.goal_threshold
            #Collect ablation experiment related configurations
            if self.exp_name == "ablation_study":
                config_info['ablation_config'] = {
                    'use_her': getattr(self.args.model.her_params, 'use_her', True) if hasattr(self.args.model, 'her_params') else True,
                    'use_goal_conditioning': getattr(self.args.model, 'use_goal_conditioning', True),
                    'use_history': getattr(self.args.model, 'use_history', True),
                    'reward_function': getattr(self.args.model, 'reward_function', 'default'),
                    'use_experience_replay': getattr(self.args.model, 'use_experience_replay', True),
                }
            #Gather dataset information
            dataset_info = {
                'train_size': len(self.dataset_collection.train_f.data['outputs']),
                'has_vitals': getattr(self.dataset_collection, 'has_vitals', False),
                'dim_outcomes': self.dataset_collection.train_f.data['outputs'].shape[-1],
                'dim_treatments': self.dataset_collection.train_f.data['current_treatments'].shape[-1],
                'dim_static_features': self.dataset_collection.train_f.data['static_features'].shape[-1],
            }
            #Gather model information
            model_info = {
                'model_type': self.model_type,
                'parameters': count_parameters(model) if model else None,
            }
            #Gather Path Information
            path_info = {
                'original_cwd': self.original_cwd,
                'hydra_cwd': self.hydra_cwd,
                'checkpoint_dir': self.model_checkpoint_dir,
                'csv_dir': self.csv_dir,
                'logs_dir': self.logs_dir,
                'log_file': self.log_filepath,
            }
            #Build aggregate results
            aggregated_data = {
                'experiment_metadata': {
                    'exp_name': self.exp_name,
                    'model_type': self.model_type,
                    'dataset_name': self.dataset_name,
                    'param_suffix': self.param_suffix,
                    'seed': self.args.exp.seed,
                    'timestamp': datetime.now().isoformat(),
                    'config_name': getattr(self.args, 'config_name', None),
                },
                'paths': path_info,
                'configuration': config_info,
                'dataset_info': dataset_info,
                'model_info': model_info,
                'metrics': metrics,
                'summary_statistics': self._calculate_summary_statistics(metrics)
            }
            #Save to JSON file
            with open(self.aggregated_results_path, 'w') as f:
                json.dump(aggregated_data, f, indent=2, default=str)
            self.logger.info(f"Aggregated results saved to {self.aggregated_results_path}")
        except Exception as e:
            self.logger.info(f"Failed to save aggregated results: {e}")
            self.logger.error(f"Failed to save aggregated results: {e}")

    def _calculate_summary_statistics(self, metrics):
        """Calculate summary statistics for metrics"""
        try:
            import numpy as np
            summary = {}
            #Extract success_rate and avg_rmse for all tau
            success_rates = []
            avg_rmses = []
            for tau, tau_metrics in metrics.items():
                if isinstance(tau_metrics, dict):
                    if 'success_rate' in tau_metrics:
                        success_rates.append(tau_metrics['success_rate'])
                    if 'avg_rmse' in tau_metrics:
                        avg_rmses.append(tau_metrics['avg_rmse'])
            if success_rates:
                summary['success_rate'] = {
                    'mean': float(np.mean(success_rates)),
                    'std': float(np.std(success_rates)),
                    'max': float(np.max(success_rates)),
                    'min': float(np.min(success_rates)),
                    'final': success_rates[-1] if success_rates else None, #Value of the last tau
                }
            if avg_rmses:
                summary['avg_rmse'] = {
                    'mean': float(np.mean(avg_rmses)),
                    'std': float(np.std(avg_rmses)),
                    'max': float(np.max(avg_rmses)),
                    'min': float(np.min(avg_rmses)),
                    'final': avg_rmses[-1] if avg_rmses else None, #Value of the last tau
                }
            return summary
        except Exception as e:
            self.logger.info(f"Failed to calculate summary statistics: {e}")
            self.logger.warning(f"Failed to calculate summary statistics: {e}")
            return {}

    def plot_case(self, model):
        PLANNING_HORIZON = 6 
        model_name=self.model_type.upper()
        all_case_study_results = {}
        if model_name == 'GIFT':
            all_case_study_results, used_time = evaluate_and_log_case_studies(
                model,
                dataset_collection=self.dataset_collection,
                config=self.args,
                logger=self.logger,
                max_tau=PLANNING_HORIZON,
                case_study_results=all_case_study_results,  #Incoming dictionary to aggregate results
                model_name=model_name,
            )
        else:
            all_case_study_results, used_time = model.evaluate_and_log_case_studies(
                dataset_collection=self.dataset_collection,
                config=self.args,
                logger=self.logger,
                model_name=model_name,
                max_tau=PLANNING_HORIZON,
                case_study_results=all_case_study_results  #Incoming dictionary to aggregate results
            )
        path = os.path.join(self.result_dir, 'case_study_trajectories.pkl')
        if self.args.exp.seed == 10:
            with open(path, 'wb') as f:
                pickle.dump(all_case_study_results, f)
        # analyzer.plot_outcome_trajectories(all_case_study_results, 10, path)
        return used_time

    def run_training(self):
        """Run the training process matching ``self.model_type`` and save results.

        ``'gift'`` uses the dedicated ``train_gift`` path, which also returns the
        metrics and complexity info. ``'crn'`` and ``'rmsn'`` train an
        encoder/decoder pair and continue with the decoder when one is present;
        any other model type is trained as a single model. Non-gift models are
        then evaluated with ``model.evaluate`` when ``args.exp.optimize_by_step``
        is set, otherwise with ``model.optimize_interventions``.

        Returns:
            tuple: ``(model, metrics)``.

        Raises:
            Exception: Any failure is logged, its traceback printed, and the
                original exception re-raised unchanged.
        """
        banner = "=" * 60
        self.logger.info(banner)
        self.logger.info("STARTING TRAINING PROCESS")
        self.logger.info(f"Model: {self.model_type.upper()}")
        self.logger.info(f"Dataset: {self.dataset_name}")
        self.logger.info(f"Experiment: {self.exp_name}")
        self.logger.info(f"Seed: {self.args.exp.seed}")
        self.logger.info(f"Working Directory: {os.getcwd()}")
        self.logger.info(banner)
        start_time = time.time()
        num_episodes = 2000
        sample = self.args.exp.sample_size
        try:
            if self.model_type == 'gift':
                model, metrics, complexity_info = self.train_gift()
                write_csv(self.csv_dir, metrics, self.args)
                self._save_aggregated_results(model, metrics)
            else:
                # Unified path for every model except gift.
                if self.model_type == 'crn':
                    encoder, decoder = self.train_encoder_decoder()
                    model = decoder if decoder else encoder
                elif self.model_type == 'rmsn':
                    models = self.train_rmsn()
                    model = models['decoder'] if 'decoder' in models else models['encoder']
                    encoder = models['encoder']
                else:  # single model of actin, vcip, ct, etc.
                    model = self.train_single_model()
                    encoder = None

                train_time = time.time() - start_time
                self.logger.info("Starting evaluation...")
                # Evaluation gets its own timer; start_time must stay untouched
                # so the final summary reports the duration of the whole run.
                eval_start = time.time()
                if self.args.exp.optimize_by_step:
                    metrics = model.evaluate(
                        self.dataset_collection,
                        self.args,
                        num_episodes=num_episodes,
                        sample=sample,
                        logger=self.logger,
                    )
                    test_time = time.time() - eval_start
                    write_csv(self.csv_dir, metrics, self.args)
                    self._save_aggregated_results(model, metrics)
                else:
                    metrics = model.optimize_interventions(
                        encoder=encoder,
                        num_iterations=sample,
                        logger=self.logger,
                        learning_rate=self.args.exp.action_learning_rate,
                    )
                    test_time = time.time() - eval_start
                    write_csv_vcip(metrics, self.csv_dir, self.args)
                self.logger.info(f"Evaluation time: {test_time:.2f} seconds")

                complexity_info = model.complexity_info
                complexity_info['train_time'] = train_time

            # The case study returns its own elapsed time, which is recorded
            # as test_time for the main comparison experiment only.
            if self.args.exp.exp_name == "main_comparison":
                complexity_info['test_time'] = self.plot_case(model)

            complexity_info['optimize_by_step'] = self.args.exp.optimize_by_step

            write_complexity_info(complexity_info, self.csv_dir, self.args)
            # Save CSV results (for compatibility)
            self.logger.info(f"CSV results saved to {self.csv_dir}")

            # Wall-clock time since the start banner: training, evaluation and
            # result writing combined.
            total_time = time.time() - start_time
            self.logger.info(banner)
            self.logger.info("TRAINING COMPLETED SUCCESSFULLY!")
            self.logger.info(f"Model: {self.model_type.upper()}")
            self.logger.info(f"Dataset: {self.dataset_name}")
            self.logger.info(f"Training time: {total_time:.2f} seconds")
            self.logger.info(f"Results saved to: {self.model_base_dir}")
            self.logger.info(f"Log file: {self.log_filepath}")
            self.logger.info(banner)
            return model, metrics
        except Exception as e:
            # Report failures at ERROR level so they are not filtered out with
            # routine INFO messages.
            self.logger.error(f"❌ TRAINING FAILED: {e}")
            import traceback
            traceback.print_exc()
            # Bare raise keeps the original traceback intact.
            raise

# Custom OmegaConf resolvers, registered when this module is imported.
# replace=True overwrites any resolver already registered under the same name.
for _resolver_name, _resolver_fn in (
    ("toint", lambda value: int(value)),
    ("subtract", lambda left, right: left - right),
    ("sum", lambda left, right: left + right),
):
    OmegaConf.register_new_resolver(_resolver_name, _resolver_fn, replace=True)

@hydra.main(config_name='config.yaml', config_path='../configs/', version_base=None)
def main(args: DictConfig):
    """Hydra entry point: build a ``UnifiedTrainer`` from ``args`` and run it.

    Args:
        args: Configuration composed by Hydra.

    Returns:
        tuple: ``(model, metrics)`` as returned by ``UnifiedTrainer.run_training``.
    """
    # Turn struct mode off on the config before it is handed to the trainer.
    OmegaConf.set_struct(args, False)

    unified_trainer = UnifiedTrainer(args)
    delete_train_log(unified_trainer.original_cwd, args)

    # Dump the fully resolved configuration through the trainer's logger.
    unified_trainer.logger.info("Configuration:")
    resolved_yaml = OmegaConf.to_yaml(args, resolve=True)
    unified_trainer.logger.info(resolved_yaml)

    trained_model, final_metrics = unified_trainer.run_training()
    unified_trainer.logger.info("All operations completed!")
    return trained_model, final_metrics

# Script entry point: main() is wrapped by @hydra.main, so it is called here
# without arguments when the file is executed directly.
if __name__ == "__main__":
    main()
