import faulthandler
import os
import signal
import tempfile
import warnings
import logging
import time
from logging import Logger
from typing import Dict, List, Optional, Type, Union

import numpy as np
import ray
from ray.rllib.algorithms import ppo
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.typing import MultiAgentPolicyConfigDict
from ray.rllib.execution.rollout_ops import (
    standardize_fields,
    synchronous_parallel_sample,
)
from sacred import SETTINGS as sacred_settings
from sacred import Experiment

from occupancy_measures.agents.bc import BC
from occupancy_measures.agents.generate_safe_policy import SafePolicyGenerationAlgorithm
from occupancy_measures.agents.learned_reward_algorithm import LearnedRewardAlgorithm
from occupancy_measures.agents.orpo import ORPO, ORPOPolicy
from occupancy_measures.envs.learned_reward_wrapper import LearnedRewardWrapperConfig
from occupancy_measures.models.glucose_models import normalize_obs
from occupancy_measures.models.reward_model import RewardModelConfig
from occupancy_measures.utils.os_utils import available_cpu_count
from occupancy_measures.utils.training_utils import (  # convert_to_msgpack_checkpoint,
    build_logger_creator,
    load_algorithm_config,
    load_policies_from_checkpoint,
)

from occupancy_measures.experiments.glucose_experiments import create_glucose_config
from occupancy_measures.experiments.pandemic_experiments import create_pandemic_config
from occupancy_measures.experiments.tomato_experiments import create_tomato_config
from occupancy_measures.experiments.traffic_experiments import create_traffic_config
from pandemic_simulator.environment.pandemic_env import PandemicPolicyGymEnv
from occupancy_measures.envs.tomato_environment import create_simple_example,Tomato_Environment
from flow.utils.registry import make_create_env
from bgp.simglucose.envs.simglucose_gym_env import SimglucoseEnv
from occupancy_measures.envs.glucose_true_rew_wrapper import GlucoseWrapper
from utils.exploration_utils import SoftmaxTempExploration, BetaTempExploration

import learn_reward.unique_id_state as unique_id_state

if unique_id_state.state["is_ensemble"]:
    from reward_modeling.reward_wrapper_ensemble import RewardWrapper
else:
    from reward_modeling.reward_wrapper_pbrr import RewardWrapper


import sys
# Redirect stdout to stderr so that all output (prints and stderr) ends up in
# SLURM's .error file.
sys.stdout = sys.stderr

# Imports for timeout handling / system monitoring.
# NOTE(review): TimeoutError and partial are not used in the visible part of
# this file; `from concurrent.futures import TimeoutError` also shadows the
# builtin name — confirm they are still needed.
from concurrent.futures import TimeoutError
from functools import partial
import psutil

# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Silence DeprecationWarnings globally (third-party libraries emit many).
warnings.filterwarnings("ignore", category=DeprecationWarning)

def create_env_pandemic(config):
    """Create a pandemic environment wrapped in ``RewardWrapper``.

    The reward model name is read from ``config["reward_model"]`` (falling back
    to ``"custom_pandemic_sas"``). ``unique_id`` and ``modify_proxy_reward`` are
    taken from the shared ``unique_id_state.state`` dict, not from ``config``.
    """
    pandemic_env = PandemicPolicyGymEnv(config)
    reward_model_name = config.get("reward_model", "custom_pandemic_sas")
    shared_state = unique_id_state.state
    return RewardWrapper(
        pandemic_env,
        reward_model=reward_model_name,
        unique_id=shared_state["unique_id"],
        modify_proxy_reward=shared_state["modify_proxy_reward"],
    )

def create_env_tomato(config):
    """Create a tomato gridworld environment wrapped in ``RewardWrapper``.

    The reward model name is read from ``config["reward_model"]`` (falling back
    to ``"custom_tomato"``). ``unique_id`` and ``modify_proxy_reward`` are taken
    from the shared ``unique_id_state.state`` dict, not from ``config``.
    """
    tomato_env = Tomato_Environment(config)
    reward_model_name = config.get("reward_model", "custom_tomato")
    shared_state = unique_id_state.state
    return RewardWrapper(
        tomato_env,
        reward_model=reward_model_name,
        unique_id=shared_state["unique_id"],
        modify_proxy_reward=shared_state["modify_proxy_reward"],
    )

def create_env_traffic(config, flow_params, reward_specification, reward_fun, reward_scale):
    """Create a Flow traffic environment wrapped in ``RewardWrapper``.

    ``flow_params``, ``reward_specification``, ``reward_fun`` and
    ``reward_scale`` are forwarded to Flow's ``make_create_env``. The
    environment name it returns is not needed here. The reward model name is
    read from ``config["reward_model"]`` (falling back to
    ``"custom_traffic_sas"``). ``unique_id`` and ``modify_proxy_reward`` come
    from the shared ``unique_id_state.state`` dict.
    """
    env_factory, _ = make_create_env(
        params=flow_params,
        reward_specification=reward_specification,
        reward_fun=reward_fun,
        reward_scale=reward_scale,
    )
    traffic_env = env_factory()
    reward_model_name = config.get("reward_model", "custom_traffic_sas")
    shared_state = unique_id_state.state
    return RewardWrapper(
        traffic_env,
        reward_model=reward_model_name,
        unique_id=shared_state["unique_id"],
        modify_proxy_reward=shared_state["modify_proxy_reward"],
    )

def create_env_glucose(config):
    """Create a Simglucose environment with reward-model and true-reward wrappers.

    The base ``SimglucoseEnv`` is wrapped first in ``RewardWrapper`` (reward
    model from ``config["reward_model"]``, default ``"custom_glucose_sas"``;
    ``unique_id`` / ``modify_proxy_reward`` from ``unique_id_state.state``) and
    then in ``GlucoseWrapper``.
    """
    glucose_env = SimglucoseEnv(config)
    reward_model_name = config.get("reward_model", "custom_glucose_sas")
    shared_state = unique_id_state.state
    reward_wrapped = RewardWrapper(
        glucose_env,
        reward_model=reward_model_name,
        unique_id=shared_state["unique_id"],
        modify_proxy_reward=shared_state["modify_proxy_reward"],
    )
    return GlucoseWrapper(reward_wrapped)


# Sacred experiment for this script; git info is not recorded.
ex = Experiment("PBRR", save_git_info=False)
# Captured functions receive a mutable (not read-only) config.
sacred_settings.CONFIG.READ_ONLY_CONFIG = False


# On SIGUSR1, dump the Python tracebacks of all threads to stderr
# (e.g. `kill -USR1 <pid>` to inspect a hung job).
faulthandler.register(signal.SIGUSR1)


@ex.config
def env_config():
    # Sacred config scope: each local below becomes a config entry.
    # Environment to run; the branches in this file handle "tomato",
    # "pandemic", "glucose" and "traffic".
    env_to_run = "tomato"  # noqa: F841
    # Path components later joined with os.path.sep into the experiment name
    # (see common_config).
    experiment_parts = [env_to_run]  # noqa: F841


# Register each environment's sacred config on `ex`, passing the
# RewardWrapper-based env factory defined above as `custom_rm`.
# Registration order matters: sacred applies config scopes in the order added.
create_glucose_config(ex,use_custom_rm=True,custom_rm=create_env_glucose)
create_pandemic_config(ex,use_custom_rm=True,custom_rm=create_env_pandemic)
create_tomato_config(ex, use_custom_rm=True,custom_rm=create_env_tomato)
create_traffic_config(ex, use_custom_rm=True,custom_rm=create_env_traffic)



# Lower bound used when sampling initial occupancy-measure divergence
# coefficients in common_config (keeps them strictly positive).
EPS = 1e-9


@ex.config
def common_config(  # noqa: C901
    env_to_run,
    config,
    env_config,
    num_training_iters,
    experiment_parts,
    _log,
):
    # Sacred config scope. Every local variable assigned here becomes a config
    # entry, and values fixed from the command line take precedence over the
    # assignments below. The local names are therefore part of the experiment's
    # configuration interface and must not be renamed.
    # `config` (presumably the RLlib AlgorithmConfig produced by the
    # environment config scopes registered above) is mutated in place.
    num_cpus = available_cpu_count()  # noqa: F841

    # Algorithm to run; must be one of the names listed in the assert.
    exp_algo = "PPO"
    assert exp_algo in [
        "PPO",
        "ORPO",
        "ORPO_HIGH_TEMP_ACT_SAMP",
        "BC",
        "SafePolicyGenerationAlgorithm",
        "RewardAlgorithm",
    ]

    def restore_default_params(config=config, env_to_run=env_to_run):
        # Overwrite `config` in place with per-environment hyperparameters.
        # Called from the BC, SafePolicyGenerationAlgorithm and RewardAlgorithm
        # branches below.
        env_config_updates: dict
        custom_model_config_updates: dict
        if env_to_run == "pandemic":
            config.rollout_fragment_length = 193
            config.train_batch_size = max(
                config.rollout_fragment_length * config.num_rollout_workers,
                config.rollout_fragment_length,
            )
            config.vf_clip_param = np.inf
            config.num_sgd_iter = 10
            config.grad_clip = None

            custom_model_config_updates = {
                "use_history_for_disc": True,
                "discriminator_state_dim": 0,
                "history_range": (-24, 0),
            }
            config.model["custom_model_config"].update(custom_model_config_updates)
        elif env_to_run == "traffic":
            config.entropy_coeff_schedule = [[0, 0], [1000000.0, 0]]
            config.gamma = 0.999
            config.num_sgd_iter = 5
            config.rollout_fragment_length = 4000
            config.train_batch_size = max(
                config.rollout_fragment_length * config.num_rollout_workers,
                config.rollout_fragment_length,
            )

            env_config_updates = {"reward_scale": 1}
            config.env_config.update(env_config_updates)
        elif env_to_run == "glucose":
            config.env = "glucose_env_multiagent"
            # NOTE: superseded by the entropy_coeff_schedule assignment below.
            config.entropy_coeff_schedule = [[0, 0.01], [1000000.0, 0.01]]

            custom_model_config_updates = {
                "discriminator_state_dim": 0,
                "use_cgm_for_obs": False,
                "use_history_for_disc": True,
            }
            config.model["custom_model_config"].update(custom_model_config_updates)

            config.lr = 1e-3
            config.num_envs_per_worker = 1
            config.grad_clip = 0.1
            config.vf_clip_param = np.inf
            config.entropy_coeff_schedule = [[0, 0], [1000000.0, 0]]
            config.rollout_fragment_length = 5760
            # At least one fragment per env, scaled by the number of workers.
            config.train_batch_size = max(
                config.rollout_fragment_length
                * config.num_rollout_workers
                * config.num_envs_per_worker,
                config.rollout_fragment_length * config.num_envs_per_worker,
            )
            config.sgd_minibatch_size = 1024
            config.kl_target = 0.01
            config.num_sgd_iter = 8

    # Seed
    seed = 0
    config.seed = seed

    # Logging
    save_freq = 25  # noqa: F841
    log_dir = "data/logs"  # noqa: F841
    checkpoint_to_load_current_policy = None  # noqa: F841
    checkpoint_to_load_policies = None  # noqa: F841
    policy_ids_to_load = None  # noqa: F841
    policy_id_to_load_current_policy = None  # noqa: F841
    checkpoint_path = None  # noqa: F841
    experiment_parts.append(exp_algo)
    experiment_parts.append(config.env_config["reward_fun"])
    if "fcnet_hiddens" in config.model:
        model_string = "model_" + "-".join(
            str(width) for width in config.model["fcnet_hiddens"]
        )
        experiment_parts.append(model_string)
    experiment_tag: Optional[str] = None
    if experiment_tag is not None:
        experiment_parts.append(experiment_tag)
    config.metrics_num_episodes_for_smoothing = 1

    # Evaluation (workers and interval are only enabled for BC / RewardAlgorithm)
    evaluation_num_workers = (
        4 if exp_algo == "BC" or exp_algo == "RewardAlgorithm" else 0
    )
    evaluation_interval = (
        25 if exp_algo == "BC" or exp_algo == "RewardAlgorithm" else None
    )
    evaluation_duration = max(evaluation_num_workers, 1)
    evaluation_duration_unit = "episodes"
    evaluation_explore = True
    evaluation_sample_timeout_s = 600
    evaluation_config = {
        "input": "sampler",
        "explore": evaluation_explore,
    }
    # Stay on RLlib's legacy (ModelV2 / Policy) stack.
    config._enable_rl_module_api = False
    config._enable_learner_api = False
    config.enable_connectors = False

    AlgorithmClass: Type[Algorithm]

    if exp_algo == "PPO":
        AlgorithmClass = ppo.PPO  # noqa: F841
        # Set to True (and provide reward_model_checkpoint) to train PPO on a
        # learned reward model instead of the proxy reward.
        use_learned_reward = False
        if use_learned_reward:
            learned_reward_str = "using_learned_reward"
            if "reward_model_width" in config.model["custom_model_config"]:
                learned_reward_str += "_w" + str(
                    config.model["custom_model_config"]["reward_model_width"]
                )
            if "reward_model_depth" in config.model["custom_model_config"]:
                learned_reward_str += "_d" + str(
                    config.model["custom_model_config"]["reward_model_depth"]
                )
            experiment_parts.append(learned_reward_str)
            # Must be overridden; the empty default triggers the error below.
            reward_model_checkpoint = ""
            if env_to_run == "tomato":
                config.env = "tomato_env_multiagent"
            elif env_to_run == "glucose":
                config.env = "glucose_env_multiagent"

            if reward_model_checkpoint == "":
                _log.error(
                    "Please specify a valid checkpoint from which a reward model can be loaded!"
                )
                assert False

            assert (
                config.env_config["reward_fun"] == "proxy"
            ), "The learned reward function replaces the proxy reward."
            # Wrap the original env so its reward comes from the learned model.
            wrapper_env_config: LearnedRewardWrapperConfig = {
                "env": config.env,
                "env_config": config.env_config,
                "reward_fn_checkpoint": reward_model_checkpoint,
            }
            config.env_config = wrapper_env_config
            config.env = "learned_reward_wrapper"

    if "ORPO" in exp_algo:
        AlgorithmClass = ORPO  # noqa: F841
        # At least one safe policy per checkpoint given in
        # checkpoint_to_load_policies.
        num_safe_policies = 0
        if checkpoint_to_load_policies is not None and num_safe_policies < len(
            checkpoint_to_load_policies
        ):
            num_safe_policies = len(checkpoint_to_load_policies)

        if policy_ids_to_load is not None:
            assert isinstance(policy_ids_to_load, list) and all(
                isinstance(pid, list) for pid in policy_ids_to_load
            ), (
                "'policy_ids_to_load' must be a list of lists where each list contains the particular policy ids "
                "to load from the corresponding checkpoint given within 'checkpoint_to_load_policies'"
            )
            if len(policy_ids_to_load) < num_safe_policies:
                num_policy_ids_given = len(policy_ids_to_load)
                _log.warning(
                    f"Only {num_policy_ids_given} of the policies specified in 'checkpoint_to_load_policies' "
                    "will have a specific policy retrieved as specified in the respective list within 'policy_ids_to_load'"
                )
                # Pad with None so there is one entry per safe policy.
                # NOTE(review): when policy_ids_to_load is fixed from the
                # command line, sacred may treat it as read-only inside this
                # config scope; confirm the padding takes effect.
                policy_ids_to_load.extend(
                    [None] * (num_safe_policies - num_policy_ids_given)
                )

        discriminator_state_info_key = None
        discriminator_num_sgd_iter = None
        update_safe_policy_freq = None
        action_dist_divergence_coeff = None
        action_dist_divergence_type = "kl"
        train_discriminator_first = True
        num_extra_repeated_safe_policy_batches = 1
        discriminator_reward_clip = float("inf")
        wgan_grad_clip = 0.01
        wgan_grad_penalty_weight = None
        wasserstein_distance_subtract_mean_safe_policy_score = False
        split_om_kl = False
        occupancy_measure_kl_target: List[float] = []
        use_squared_kl_adaptive_coefficient = False

        # Per-safe-policy config overrides, keyed by safe-policy index.
        safe_policy_specific_params: Dict = {}  # noqa: F841

        # Environment-specific ORPO defaults.
        if env_to_run == "tomato":
            config.env = "tomato_env_multiagent"
            discriminator_reward_clip = 1000
        elif env_to_run == "pandemic":
            weights_string = "weights_" + "_".join(
                str(coef) for coef in config.env_config["proxy_reward_fun"]._weights
            )
            experiment_parts.append(weights_string)
            discriminator_reward_clip = 100
            discriminator_num_sgd_iter = 2
        elif env_to_run == "glucose":
            config.env = "glucose_env_multiagent"
            discriminator_reward_clip = 1e10
        elif env_to_run == "traffic":
            discriminator_reward_clip = 1

        ORPO_updates = {
            "discriminator_state_info_key": discriminator_state_info_key,
            "discriminator_num_sgd_iter": discriminator_num_sgd_iter,
            "update_safe_policy_freq": update_safe_policy_freq,
            "action_dist_divergence_coeff": action_dist_divergence_coeff,
            "action_dist_divergence_type": action_dist_divergence_type,
            "train_discriminator_first": train_discriminator_first,
            "num_extra_repeated_safe_policy_batches": num_extra_repeated_safe_policy_batches,
            "discriminator_reward_clip": discriminator_reward_clip,
            "wgan_grad_clip": wgan_grad_clip,
            "wgan_grad_penalty_weight": wgan_grad_penalty_weight,
            "wasserstein_distance_subtract_mean_safe_policy_score": wasserstein_distance_subtract_mean_safe_policy_score,
            "split_om_kl": split_om_kl,
            "use_squared_kl_adaptive_coefficient": use_squared_kl_adaptive_coefficient,
        }
        config.update_from_dict(ORPO_updates)
        # One divergence coefficient and type per safe policy.
        om_divergence_coeffs: List[Union[int, float]] = [0] * num_safe_policies
        om_divergence_type = ["kl"] * num_safe_policies
        assert set(om_divergence_type).issubset(
            set(
                [
                    "kl",
                    "tv",
                    "chi2",
                    "sqrt_chi2",
                    "wasserstein",
                    "safe_policy_confidence",
                ]
            )
        )
        # Fraction of rollout workers that sample with a safe policy.
        percent_safe_policy = 0.5
        # Build the experiment-name component describing the regularization.
        if occupancy_measure_kl_target:
            assert len(occupancy_measure_kl_target) == num_safe_policies
            # With KL targets, start from a random positive coefficient.
            om_divergence_coeffs = [np.random.uniform(EPS, 1)] * num_safe_policies
            om_divergence_coeffs_str = "_".join(
                f"om-kl-target-{coeff}" for coeff in occupancy_measure_kl_target
            )
        elif action_dist_divergence_coeff is not None and not split_om_kl:
            om_divergence_coeffs_str = (
                f"action-{action_dist_divergence_type}-{action_dist_divergence_coeff}"
            )
        else:
            om_divergence_coeffs_str = "_".join(
                f"{dist}-{coeff}"
                for dist, coeff in zip(om_divergence_type, om_divergence_coeffs)
            )
        if split_om_kl:
            om_divergence_coeffs_str += "_split-om"

        if update_safe_policy_freq is not None:
            om_divergence_coeffs_str += "_update-" + str(update_safe_policy_freq)
        if num_extra_repeated_safe_policy_batches > 1:
            om_divergence_coeffs_str += "_extra_discriminator_training-" + str(
                num_extra_repeated_safe_policy_batches
            )
        experiment_parts.append(om_divergence_coeffs_str)

        # Every safe policy inherits the main config's grad_clip.
        for i in range(num_safe_policies):
            if i not in safe_policy_specific_params:
                safe_policy_specific_params[i] = {}
            safe_policy_specific_params[i]["grad_clip"] = config.grad_clip

        use_learned_reward = False
        if use_learned_reward:
            learned_reward_str = "using_learned_reward"
            if "reward_model_width" in config.model["custom_model_config"]:
                learned_reward_str += "_w" + str(
                    config.model["custom_model_config"]["reward_model_width"]
                )
            if "reward_model_depth" in config.model["custom_model_config"]:
                learned_reward_str += "_d" + str(
                    config.model["custom_model_config"]["reward_model_depth"]
                )
            experiment_parts.append(learned_reward_str)
            # Must be overridden; the empty default triggers the error below.
            reward_model_checkpoint = ""
            if reward_model_checkpoint == "":
                _log.error(
                    "Please specify a valid checkpoint from which a reward model can be loaded!"
                )
                assert False
            assert (
                config.env_config["reward_fun"] == "proxy"
            ), "The learned reward function replaces the proxy reward!"
            wrapper_env_config = {
                "env": config.env,
                "env_config": config.env_config,
                "reward_fn_checkpoint": reward_model_checkpoint,
            }
            config.env_config = wrapper_env_config
            config.env = "learned_reward_wrapper"

        if split_om_kl:
            config.model["custom_model_config"]["use_action_for_disc"] = False
            if action_dist_divergence_coeff is not None:
                config.action_dist_divergence_coeff = action_dist_divergence_coeff
            else:
                config.action_dist_divergence_coeff = om_divergence_coeffs[0]
        
        # Higher-temperature action sampling during exploration.
        if "HIGH_TEMP_ACT_SAMP" in exp_algo:
            config.update_from_dict({
                "explore": True,
                "exploration_config": {"type": BetaTempExploration, "temperature": 2.0}
            })

        policies, policy_mapping_fn, policies_to_train = create_multiagent(
            config,
            percent_safe_policy,
            num_safe_policies,
            om_divergence_type,
            om_divergence_coeffs,
            occupancy_measure_kl_target,
            safe_policy_specific_params,
            checkpoint_to_load_policies,
            _log=_log,
        )
        config.multi_agent(
            policies=policies,
            policy_mapping_fn=policy_mapping_fn,
            policies_to_train=policies_to_train,
        )

    elif exp_algo == "BC":
        AlgorithmClass = BC  # noqa: F841
        num_rollout_workers = 0
        entropy_coeff = 0
        input = ""
        validation_prop = 0
        action_info_key: Optional[Union[List[str], str]] = []
        restore_default_params(config, env_to_run)
        # Drop every config key that BC's default config does not know about.
        for config_key in list(config.keys()):
            if config_key not in BC.get_default_config():
                _log.info(f"removing {config_key} from config")
                if config_key == "lambda":
                    config_key = "lambda_"
                delattr(config, config_key)

        use_learned_reward = False
        if use_learned_reward:
            learned_reward_str = "using_learned_reward"
            if "reward_model_width" in config.model["custom_model_config"]:
                learned_reward_str += "_w" + str(
                    config.model["custom_model_config"]["reward_model_width"]
                )
            if "reward_model_depth" in config.model["custom_model_config"]:
                learned_reward_str += "_d" + str(
                    config.model["custom_model_config"]["reward_model_depth"]
                )
            experiment_parts.append(learned_reward_str)
            # Must be overridden; the empty default triggers the error below.
            reward_model_checkpoint = ""
            if env_to_run == "tomato":
                config.env = "tomato_env_multiagent"
            elif env_to_run == "glucose":
                config.env = "glucose_env_multiagent"

            if reward_model_checkpoint == "":
                _log.error(
                    "Please specify a valid checkpoint from which a reward model can be loaded!"
                )
                assert False

            assert (
                config.env_config["reward_fun"] == "proxy"
            ), "The learned reward function replaces the proxy reward."
            wrapper_env_config = {
                "env": config.env,
                "env_config": config.env_config,
                "reward_fn_checkpoint": reward_model_checkpoint,
            }
            config.env_config = wrapper_env_config
            config.env = "learned_reward_wrapper"

        config.update_from_dict(
            {
                "num_rollout_workers": num_rollout_workers,
                "entropy_coeff": entropy_coeff,
                "input": input,
                "validation_prop": validation_prop,
                "evaluation_num_workers": evaluation_num_workers,
                "evaluation_interval": evaluation_interval,
                "evaluation_duration": evaluation_duration,
                "evaluation_duration_unit": evaluation_duration_unit,
                "evaluation_config": evaluation_config,
                "evaluation_sample_timeout_s": evaluation_sample_timeout_s,
                "action_info_key": action_info_key,
                "multiagent": {
                    "policies": {"safe_policy0": PolicySpec()},
                    "policy_mapping_fn": lambda agent_id, **kwargs: "safe_policy0",
                },
            }
        )

    elif exp_algo == "SafePolicyGenerationAlgorithm":
        AlgorithmClass = SafePolicyGenerationAlgorithm  # noqa: F841
        restore_default_params(config, env_to_run)
        # Drop every config key the algorithm's default config does not know.
        for config_key in list(config.keys()):
            if config_key not in SafePolicyGenerationAlgorithm.get_default_config():
                _log.info(f"removing {config_key} from config")
                if config_key == "lambda":
                    config_key = "lambda_"
                delattr(config, config_key)

        config.env_config["use_safe_policy_actions"] = True
        use_learned_reward = False
        if use_learned_reward:
            learned_reward_str = "using_learned_reward"
            if "reward_model_width" in config.model["custom_model_config"]:
                learned_reward_str += "_w" + str(
                    config.model["custom_model_config"]["reward_model_width"]
                )
            if "reward_model_depth" in config.model["custom_model_config"]:
                learned_reward_str += "_d" + str(
                    config.model["custom_model_config"]["reward_model_depth"]
                )
            experiment_parts.append(learned_reward_str)
            # Must be overridden; the empty default triggers the error below.
            reward_model_checkpoint = ""
            if env_to_run == "tomato":
                config.env = "tomato_env_multiagent"
            elif env_to_run == "glucose":
                config.env = "glucose_env_multiagent"

            if reward_model_checkpoint == "":
                _log.error(
                    "Please specify a valid checkpoint from which a reward model can be loaded!"
                )
                assert False

            assert (
                config.env_config["reward_fun"] == "proxy"
            ), "The learned reward function replaces the proxy reward."
            wrapper_env_config = {
                "env": config.env,
                "env_config": config.env_config,
                "reward_fn_checkpoint": reward_model_checkpoint,
            }
            config.env_config = wrapper_env_config
            config.env = "learned_reward_wrapper"

        safe_policy_action_dist_input_info_key = None
        safe_policy_action_log_std = -3
        categorical_eps = 0.9
        config.update_from_dict(
            {
                "safe_policy_action_dist_input_info_key": safe_policy_action_dist_input_info_key,
                "safe_policy_action_log_std": safe_policy_action_log_std,
                "categorical_eps": categorical_eps,
                "multiagent": {
                    "policies": {"safe_policy0": PolicySpec()},
                    "policy_mapping_fn": lambda agent_id, **kwargs: "safe_policy0",
                },
            }
        )

    elif exp_algo == "RewardAlgorithm":
        restore_default_params(config, env_to_run)
        AlgorithmClass = LearnedRewardAlgorithm  # noqa: F841
        # Drop every config key the algorithm's default config does not know.
        for config_key in list(config.keys()):
            if config_key not in LearnedRewardAlgorithm.get_default_config():
                _log.info(f"removing {config_key} from config")
                if config_key == "lambda":
                    config_key = "lambda_"
                delattr(config, config_key)

        num_rollout_workers = 0
        input = ""
        gamma = 1.0
        noise_prob = 0.0
        action_info_key = []
        rew_clip = 50
        obs_normalization_func = None
        if env_to_run == "tomato":
            config.env = "tomato_env_multiagent"
        elif env_to_run == "glucose":
            config.env = "glucose_env_multiagent"
            obs_normalization_func = normalize_obs

        max_seq_len = 20
        reward_model_width = 32
        reward_model_depth = 2

        # The reward model replaces whatever model config was set before.
        custom_model_config: RewardModelConfig = {
            "reward_model_depth": reward_model_depth,
            "reward_model_width": reward_model_width,
            "normalize_obs": obs_normalization_func,
        }
        model_config = {
            "max_seq_len": max_seq_len,
            "custom_model": "reward_model",
            "custom_model_config": custom_model_config,
        }
        config.model = model_config

        model_string = (
            "model_" + str(reward_model_width) + "-by-" + str(reward_model_depth)
        )
        experiment_parts.append(model_string)

        config.update_from_dict(
            {
                "num_rollout_workers": num_rollout_workers,
                "input": input,
                "gamma": gamma,
                "noise_prob": noise_prob,
                "evaluation_num_workers": evaluation_num_workers,
                "evaluation_interval": evaluation_interval,
                "evaluation_duration": evaluation_duration,
                "evaluation_duration_unit": evaluation_duration_unit,
                "evaluation_config": evaluation_config,
                "evaluation_sample_timeout_s": evaluation_sample_timeout_s,
                "action_info_key": action_info_key,
                "rew_clip": rew_clip,
                "multiagent": {
                    "policies": {"safe_policy0": PolicySpec()},
                    "policy_mapping_fn": lambda agent_id, **kwargs: "safe_policy0",
                },
            }
        )

    # Results are stored under <log_dir>/<env>/<algo>/.../seed_<seed>.
    experiment_parts.append(f"seed_{seed}")
    experiment_name = os.path.sep.join(experiment_parts)  # noqa: F841
    _log.info("Saving experiment results to " + experiment_name)

    ray_init_kwargs = {}  # noqa: F841


def create_multiagent(
    config,
    percent_safe_policy,
    num_safe_policies,
    om_divergence_type,
    om_divergence_coeffs,
    occupancy_measure_kl_target,
    safe_policy_specific_params,
    checkpoint_to_load_policies,
    _log: Logger,
):
    """Build the multi-agent policy setup used by ORPO.

    Creates one ``ORPOPolicy`` spec per safe policy (``safe_policy0``,
    ``safe_policy1``, ...) plus a ``"current"`` policy. The per-safe-policy
    divergence settings are stored on ``config``. The returned mapping
    function assigns the first ``percent_safe_policy`` fraction of rollout
    workers round-robin to the safe policies and every other worker to
    ``"current"``.

    Args:
        config: Algorithm config. Mutated in place: ``safe_policy_ids``,
            ``om_divergence_type``, ``om_divergence_coeffs`` and (only if
            non-empty) ``occupancy_measure_kl_target`` are set on it.
        percent_safe_policy: Fraction of rollout workers sampling with a safe
            policy.
        num_safe_policies: Number of safe policies to create.
        om_divergence_type: Divergence type per safe policy, in order.
        om_divergence_coeffs: Divergence coefficient per safe policy, in order.
        occupancy_measure_kl_target: Optional KL target per safe policy.
        safe_policy_specific_params: Maps safe-policy index to config overrides
            merged into that policy's config.
        checkpoint_to_load_policies: Optional checkpoint paths. The i-th safe
            policy starts from the algorithm config stored in the i-th
            checkpoint when one is given.
        _log: Logger used for warnings.

    Returns:
        ``(policies, policy_mapping_fn, policies_to_train)``.

    Raises:
        ValueError: If fewer than one rollout worker would be assigned to each
            safe policy.
    """
    policies: MultiAgentPolicyConfigDict = {}
    safe_policy_ids = [f"safe_policy{i}" for i in range(num_safe_policies)]
    config.safe_policy_ids = safe_policy_ids
    config.om_divergence_type = dict(zip(safe_policy_ids, om_divergence_type))
    config.om_divergence_coeffs = dict(zip(safe_policy_ids, om_divergence_coeffs))
    if occupancy_measure_kl_target:
        config.occupancy_measure_kl_target = dict(
            zip(safe_policy_ids, occupancy_measure_kl_target)
        )

    for i, policy_name in enumerate(safe_policy_ids):
        # Checkpoints are loaded into the first len(checkpoint_to_load_policies)
        # safe policies; the remaining safe policies start from `config`.
        if checkpoint_to_load_policies is not None and i < len(
            checkpoint_to_load_policies
        ):
            policy_config = load_algorithm_config(checkpoint_to_load_policies[i])
            policy_config.update_from_dict(config.copy())
            # The discriminator model settings must come from this experiment,
            # not from the checkpoint. Keys the experiment sets are copied
            # over; keys it does not set are removed instead of raising
            # KeyError.
            custom_model_config = policy_config.model["custom_model_config"]
            experiment_custom_model_config = config.model.get(
                "custom_model_config", {}
            )
            for key in [
                "discriminator_width",
                "discriminator_depth",
                "discriminator_state_dim",
                "use_action_for_disc",
                "use_history_for_disc",
                "time_dim",
                "history_range",
            ]:
                if key not in custom_model_config:
                    continue
                if key in experiment_custom_model_config:
                    custom_model_config[key] = experiment_custom_model_config[key]
                else:
                    del custom_model_config[key]
        else:
            policy_config = config.copy()
        if i in safe_policy_specific_params:
            policy_config = Algorithm.merge_algorithm_configs(
                policy_config,
                safe_policy_specific_params[i],
                _allow_unknown_configs=True,
            )
        # Per-policy configs never spawn their own rollout/evaluation workers.
        policy_config.rollouts(num_rollout_workers=0)
        policy_config.offline_data(input_="sampler")
        policy_config.evaluation(evaluation_num_workers=0)
        policy_config = policy_config.update_from_dict({"__policy_id": policy_name})
        policies[policy_name] = PolicySpec(
            policy_class=ORPOPolicy,
            config=policy_config,
        )

    # No explicit class/config: RLlib uses the algorithm's defaults.
    policies["current"] = PolicySpec(
        policy_class=None,
        config=None,
    )

    num_rollout_workers: int = config.num_rollout_workers
    if num_safe_policies >= 1:
        workers_per_safe_policy = (
            percent_safe_policy * num_rollout_workers
        ) / num_safe_policies
        if workers_per_safe_policy < 1:
            raise ValueError(
                "Too few workers for the number of safe policies. "
                "Increase num_rollout_workers/percent_safe_policy or decrease num_safe_policies."
            )
        if workers_per_safe_policy != int(workers_per_safe_policy):
            _log.warning(
                "Safe policies are not evenly divided among workers "
                f"({workers_per_safe_policy:.1f} workers per safe policy)."
            )

    def policy_mapping_fn(
        agent_id,
        episode,
        worker: RolloutWorker,
        percent_safe_policy=percent_safe_policy,
        safe_policy_ids=safe_policy_ids,
        num_rollout_workers=num_rollout_workers,
        **kwargs,
    ):
        # Remote workers are 1-indexed; shift so the first one is index 0.
        worker_index = worker.worker_index - 1
        if (
            num_safe_policies >= 1
            and worker_index < num_rollout_workers * percent_safe_policy
        ):
            safe_policy_index = worker_index % num_safe_policies
            return safe_policy_ids[safe_policy_index]
        return "current"

    policies_to_train = safe_policy_ids + ["current"]

    return policies, policy_mapping_fn, policies_to_train


@ex.automain
def main(
    config,
    log_dir,
    ray_init_kwargs,
    experiment_name,
    AlgorithmClass: Type[Algorithm],
    exp_algo,
    num_training_iters,
    save_freq,
    checkpoint_to_load_policies: Optional[List[str]],
    policy_ids_to_load: Optional[List[str]],
    policy_id_to_load_current_policy: Optional[str],
    checkpoint_to_load_current_policy: Optional[str],
    checkpoint_path: Optional[str],
    num_cpus: int,
    _log: Logger,
):
    """Train an RLlib algorithm, save a final checkpoint, and sample an evaluation batch.

    Steps:
      1. Start Ray with a fresh temporary directory and build ``AlgorithmClass``.
      2. Optionally load individual policies from checkpoints (the safe/reference
         policies and/or the ``current`` policy), then optionally restore a full
         algorithm checkpoint from ``checkpoint_path``.
      3. Run ``num_training_iters`` training iterations. Save a checkpoint whenever
         ``algorithm.iteration`` is a multiple of ``save_freq``; periodic saving is
         disabled when ``save_freq`` is falsy.
      4. Save a final checkpoint and collect an evaluation batch with
         ``synchronous_parallel_sample``. The batch size depends on the environment
         and on whether reference policies or an ensemble are used.

    Returns:
        A tuple ``(result, checkpoint, eval_batch)``:
          - ``result``: the value returned by the last ``algorithm.train()`` call,
            or ``None`` if no iterations ran.
          - ``checkpoint``: the value returned by the final ``algorithm.save()``.
          - ``eval_batch``: the batch returned by ``synchronous_parallel_sample``.

    Raises:
        ValueError: if ``config["env"]`` does not mention a known environment
            (glucose, tomato, pandemic, traffic). This is checked before training
            starts.
    """
    temp_dir = tempfile.mkdtemp()  # Ray's _temp_dir; not removed by this function.
    # Soft per-step time budget in seconds. A slower training step is logged as
    # a warning but is NOT interrupted.
    TRAINING_TIMEOUT = 300  # 5 minutes

    # Base evaluation budget per environment. The glucose entry is 40 episodes of
    # 5760 steps; the other entries presumably follow the same
    # <steps per episode> * <episodes> pattern. This is resolved before any
    # training so that an unknown env fails fast instead of after training.
    env_name = config["env"]
    if "glucose" in env_name:
        max_env_steps = 5760 * 40
    elif "tomato" in env_name:
        max_env_steps = 100 * 20
    elif "pandemic" in env_name:
        max_env_steps = 193 * 80
    elif "traffic" in env_name:
        max_env_steps = 4000 * 20
    else:
        raise ValueError(
            f"Unknown environment {env_name!r}: cannot determine the evaluation "
            "sample budget (expected glucose, tomato, pandemic or traffic)."
        )

    # psutil is optional and is only used for diagnostic logging.
    try:
        import psutil
    except ImportError:
        psutil = None
        _log.warning("psutil is not installed; system stats will not be logged.")
    # Reuse one Process handle. With interval=None, cpu_percent() reports usage
    # since the previous call on the same handle, so a fresh handle per call
    # would always report 0.0.
    process = psutil.Process() if psutil is not None else None

    def log_system_stats():
        """Log RSS memory, CPU percentage and open-file count for this process.

        This is a no-op when psutil is unavailable. psutil errors (e.g. access
        denied) are logged as warnings rather than aborting training.
        """
        if process is None:
            return
        try:
            memory_info = process.memory_info()
            _log.info(f"Memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
            _log.info(f"CPU usage: {process.cpu_percent()}%")
            _log.info(f"Number of open files: {len(process.open_files())}")
        except psutil.Error as e:
            _log.warning(f"Could not collect system stats: {e}")

    def train_with_timeout(algorithm, timeout):
        """Run one ``algorithm.train()`` step and log how long it took.

        ``timeout`` is a soft limit in seconds: exceeding it logs a warning, but
        the step is not interrupted. Exceptions are logged and re-raised.
        """
        start_time = time.time()
        start_clock = time.monotonic()  # Monotonic clock for durations.
        _log.info(f"Starting training step at {start_time}")
        try:
            result = algorithm.train()
        except Exception as e:
            elapsed = time.monotonic() - start_clock
            _log.error(f"Training step failed after {elapsed:.2f} seconds: {str(e)}")
            raise
        elapsed = time.monotonic() - start_clock
        _log.info(f"Training step completed in {elapsed:.2f} seconds")
        if elapsed > timeout:
            _log.warning(
                f"Training step took {elapsed:.2f} seconds, exceeding the "
                f"{timeout} second budget."
            )
        return result

    os.environ["RAY_AIR_NEW_PERSISTENCE_MODE"] = "0"
    ray.init(
        num_cpus=num_cpus,
        ignore_reinit_error=True,
        include_dashboard=False,
        _temp_dir=temp_dir,
        log_to_driver=True,
        **ray_init_kwargs,
    )
    algorithm = AlgorithmClass(
        config,
        logger_creator=build_logger_creator(log_dir, experiment_name),
    )

    try:
        if checkpoint_to_load_policies is not None:
            for i, policy_checkpoint in enumerate(checkpoint_to_load_policies):
                # ORPO runs load the i-th checkpoint into the i-th safe policy;
                # other algorithms load into the default policy.
                if "ORPO" in exp_algo:
                    target_policy_id = config.safe_policy_ids[i]
                else:
                    target_policy_id = DEFAULT_POLICY_ID
                _log.info(
                    f"loading policy {target_policy_id} from {policy_checkpoint}..."
                )
                load_policies_from_checkpoint(
                    policy_checkpoint,
                    algorithm,
                    policy_ids=(
                        policy_ids_to_load[i] if policy_ids_to_load is not None else None
                    ),
                    # Bind the target id as a default argument to avoid the
                    # late-binding-closure pitfall.
                    policy_mapping_fn=lambda policy_id, target=target_policy_id: target,
                )
        if checkpoint_to_load_current_policy is not None:
            _log.info(f"loading policy current from {checkpoint_to_load_current_policy}...")
            load_policies_from_checkpoint(
                checkpoint_to_load_current_policy,
                algorithm,
                policy_ids=policy_id_to_load_current_policy,
                policy_mapping_fn=lambda policy_id: (
                    "current" if "ORPO" in exp_algo else DEFAULT_POLICY_ID
                ),
            )

        if checkpoint_path is not None:
            _log.info(f"Restoring checkpoint at {checkpoint_path}")
            algorithm.restore(checkpoint_path)

        result = None
        for _ in range(num_training_iters):
            _log.info(f"Starting training iteration {algorithm.iteration}")
            log_system_stats()  # Log system stats before training
            result = train_with_timeout(algorithm, TRAINING_TIMEOUT)
            # Periodic checkpointing; disabled when save_freq is 0/None.
            if save_freq and algorithm.iteration % save_freq == 0:
                periodic_checkpoint = algorithm.save()
                _log.info(f"Saved checkpoint to {periodic_checkpoint}")
            # NOTE(review): fixed pause between iterations; its purpose is not
            # evident here — confirm it is still needed.
            time.sleep(5)

        checkpoint = algorithm.save()
        # convert_to_msgpack_checkpoint(checkpoint, algorithm.logdir, AlgorithmClass)
        _log.info(f"Saved final checkpoint to {checkpoint}")

        if len(config["safe_policy_ids"]) > 0:
            # Collect more trajectories when regularizing to a reference policy:
            # synchronous_parallel_sample splits the collected trajectories
            # across all policies.
            max_env_steps = max_env_steps * 2

        if unique_id_state.state["is_ensemble"]:
            # Collect even more trajectories for ensemble-based methods that
            # select the pairs with the highest variance.
            max_env_steps = max_env_steps * 2
            if len(config["safe_policy_ids"]) == 0:
                max_env_steps = max_env_steps * 2

        # Evaluation sampling with the environment-dependent budget computed above.
        eval_batch = synchronous_parallel_sample(
            worker_set=algorithm.workers, max_env_steps=max_env_steps
        )
    finally:
        # Always release the algorithm's workers, even if training or sampling fails.
        algorithm.stop()

    return result, checkpoint, eval_batch

