from typing import Optional
import numpy as np

def get_iql_train_configs(env_name, algo, train_type, random_seed):
    if env_name == 'hopper-medium-v2':
        class TrainConfig:
            """IQL training configuration for ``hopper-medium-v2``.

            Values are bound when this branch executes; ``random_seed``,
            ``train_type`` and ``algo`` are read from the enclosing call.
            """

            # --- Experiment -----------------------------------------------
            device: str = "cuda"
            env_sim_name: str = "hopper"
            level: str = "medium"
            env: str = f"{env_sim_name}-{level}-v2"  # Gym environment id
            seed: int = random_seed  # Seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # Seed of the evaluation environment
            eval_freq: int = 2_000  # Evaluate every this many time steps
            save_freq: int = 10_000
            n_episodes: int = 20  # Episodes per evaluation round
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # '' means "load nothing" (offline runs); any other train_type loads
            # the IQL offline checkpoint stored under ../offline for this seed.
            if 'offline' in train_type:
                load_model: str = ''
            else:
                load_model: str = f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_998000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE weights ("" doesn't load)

            # --- IQL ------------------------------------------------------
            actor_dropout: float = 0.0  # Dropout in the actor network
            if 'incremental' in train_type:
                buffer_size: int = 20_000  # Replay buffer size for incremental runs
            else:
                buffer_size: int = 10_000_000  # Replay buffer size
            vae_hidden_dim: int = 400  # Hidden width of the VAE network
            batch_size: int = 256  # Batch size shared by all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.001  # Target-network update rate
            beta: float = 3.0  # Inverse temperature: small -> BC, large -> maximize Q
            iql_tau: float = 0.7  # Coefficient of the asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range used to clip the noise
            iql_deterministic: bool = True  # Use a deterministic actor
            # Trainer switches; their semantics live in the training code.
            use_off_policy: bool = False
            use_q_update: bool = True
            change_lr: bool = True
            normalize: bool = True  # Normalize states
            normalize_reward: bool = True  # Normalize rewards
            vf_lr: float = 3e-4  # Value-function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate

            # --- Wandb logging --------------------------------------------
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'hopper-medium-expert-v2':
        class TrainConfig:
            """IQL training configuration for ``hopper-medium-expert-v2``.

            Values are bound when this branch executes; ``random_seed``,
            ``train_type`` and ``algo`` are read from the enclosing call.
            """

            # --- Experiment -----------------------------------------------
            device: str = "cuda"
            env_sim_name: str = "hopper"
            level: str = "medium-expert"
            env: str = f"{env_sim_name}-{level}-v2"  # Gym environment id
            seed: int = random_seed  # Seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # Seed of the evaluation environment
            eval_freq: int = 2_000  # Evaluate every this many time steps
            save_freq: int = 10_000
            n_episodes: int = 20  # Episodes per evaluation round
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # '' means "load nothing" (offline runs); any other train_type loads
            # the IQL offline checkpoint stored under ../offline for this seed.
            # Previously used alternative:
            #   f"../offline/checkpoints/{env}/checkpoint_999999.pt"
            if 'offline' in train_type:
                load_model: str = ''
            else:
                load_model: str = f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_998000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE weights ("" doesn't load)

            # --- IQL ------------------------------------------------------
            actor_dropout: float = 0.0  # Dropout in the actor network
            if 'incremental' in train_type:
                buffer_size: int = 20_000  # Replay buffer size for incremental runs
            else:
                buffer_size: int = 10_000_000  # Replay buffer size
            vae_hidden_dim: int = 400  # Hidden width of the VAE network
            batch_size: int = 256  # Batch size shared by all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target-network update rate
            beta: float = 6.0  # Inverse temperature: small -> BC, large -> maximize Q
            iql_tau: float = 0.5  # Coefficient of the asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range used to clip the noise
            iql_deterministic: bool = False  # Use a deterministic actor
            # Trainer switches; their semantics live in the training code.
            use_off_policy: bool = False
            use_q_update: bool = True
            change_lr: bool = True
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize rewards
            vf_lr: float = 3e-4  # Value-function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate

            # --- Wandb logging --------------------------------------------
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'hopper-medium-replay-v2':
        class TrainConfig:
            """IQL training configuration for ``hopper-medium-replay-v2``.

            Values are bound when this branch executes; ``random_seed``,
            ``train_type`` and ``algo`` are read from the enclosing call.
            """

            # --- Experiment -----------------------------------------------
            device: str = "cuda"
            env_sim_name: str = "hopper"
            level: str = "medium-replay"
            env: str = f"{env_sim_name}-{level}-v2"  # Gym environment id
            seed: int = random_seed  # Seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # Seed of the evaluation environment
            eval_freq: int = 2_000  # Evaluate every this many time steps
            save_freq: int = 10_000
            n_episodes: int = 100  # Episodes per evaluation round
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # '' means "load nothing" (offline runs); any other train_type loads
            # the IQL offline checkpoint stored under ../offline for this seed.
            if 'offline' in train_type:
                load_model: str = ''
            else:
                load_model: str = f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_998000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE weights ("" doesn't load)

            # --- IQL ------------------------------------------------------
            actor_dropout: float = 0.0  # Dropout in the actor network
            if 'incremental' in train_type:
                buffer_size: int = 20_000  # Replay buffer size for incremental runs
            else:
                buffer_size: int = 10_000_000  # Replay buffer size
            vae_hidden_dim: int = 400  # Hidden width of the VAE network
            batch_size: int = 256  # Batch size shared by all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.001  # Target-network update rate
            beta: float = 3.0  # Inverse temperature: small -> BC, large -> maximize Q
            iql_tau: float = 0.7  # Coefficient of the asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range used to clip the noise
            iql_deterministic: bool = True  # Use a deterministic actor
            # Trainer switches; their semantics live in the training code.
            use_off_policy: bool = False
            use_q_update: bool = True
            change_lr: bool = True
            normalize: bool = True  # Normalize states
            normalize_reward: bool = True  # Normalize rewards
            vf_lr: float = 3e-4  # Value-function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate

            # --- Wandb logging --------------------------------------------
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'halfcheetah-medium-v2':
        class TrainConfig:
            """IQL training configuration for ``halfcheetah-medium-v2``.

            Values are bound when this branch executes; ``random_seed``,
            ``train_type`` and ``algo`` are read from the enclosing call.
            """

            # --- Experiment -----------------------------------------------
            device: str = "cuda"
            env_sim_name: str = "halfcheetah"
            level: str = "medium"
            env: str = f"{env_sim_name}-{level}-v2"  # Gym environment id
            seed: int = random_seed  # Seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # Seed of the evaluation environment
            eval_freq: int = 2_000  # Evaluate every this many time steps
            save_freq: int = 10_000
            n_episodes: int = 100  # Episodes per evaluation round
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # '' means "load nothing" (offline runs); any other train_type loads
            # the IQL offline checkpoint stored under ../offline for this seed.
            # Previously used alternative for non-offline runs:
            #   f"checkpoints/IQL_incremental-{seed}/{env}/checkpoint_490000.pt"
            if 'offline' in train_type:
                load_model: str = ''
            else:
                load_model: str = f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_998000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE weights ("" doesn't load)

            # --- IQL ------------------------------------------------------
            actor_dropout: float = 0.0  # Dropout in the actor network
            if 'incremental' in train_type:
                buffer_size: int = 20_000  # Replay buffer size for incremental runs
            else:
                buffer_size: int = 10_000_000  # Replay buffer size
            vae_hidden_dim: int = 400  # Hidden width of the VAE network
            batch_size: int = 256  # Batch size shared by all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target-network update rate
            beta: float = 3.0  # Inverse temperature: small -> BC, large -> maximize Q
            iql_tau: float = 0.7  # Coefficient of the asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range used to clip the noise
            iql_deterministic: bool = False  # Use a deterministic actor
            # Trainer switches; their semantics live in the training code.
            use_off_policy: bool = False
            use_q_update: bool = True
            change_lr: bool = True
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize rewards
            vf_lr: float = 3e-4  # Value-function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate

            # --- Wandb logging --------------------------------------------
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'halfcheetah-expert-v2':
        class TrainConfig:
            """IQL training configuration for ``halfcheetah-expert-v2``.

            Values are bound when this branch executes; ``random_seed``,
            ``train_type`` and ``algo`` are read from the enclosing call.
            """

            # --- Experiment -----------------------------------------------
            device: str = "cuda"
            env_sim_name: str = "halfcheetah"
            level: str = "expert"
            env: str = f"{env_sim_name}-{level}-v2"  # Gym environment id
            seed: int = random_seed  # Seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # Seed of the evaluation environment
            eval_freq: int = 2_000  # Evaluate every this many time steps
            save_freq: int = 10_000
            n_episodes: int = 20  # Episodes per evaluation round
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # '' means "load nothing" (offline runs); any other train_type loads
            # the IQL offline checkpoint stored under ../offline for this seed.
            if 'offline' in train_type:
                load_model: str = ''
            else:
                load_model: str = f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_998000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE weights ("" doesn't load)
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_99999.pt"  # Embedding weights ("" doesn't load)

            # --- IQL ------------------------------------------------------
            actor_dropout: float = 0.0  # Dropout in the actor network
            if 'incremental' in train_type:
                buffer_size: int = 20_000  # Replay buffer size for incremental runs
            else:
                buffer_size: int = 10_000_000  # Replay buffer size
            vae_hidden_dim: int = 400  # Hidden width of the VAE network
            batch_size: int = 256  # Batch size shared by all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target-network update rate
            beta: float = 0.5  # Inverse temperature: small -> BC, large -> maximize Q
            iql_tau: float = 0.7  # Coefficient of the asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range used to clip the noise
            iql_deterministic: bool = False  # Use a deterministic actor
            # Trainer switches; their semantics live in the training code.
            use_off_policy: bool = False
            use_q_update: bool = True
            change_lr: bool = True
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize rewards
            vf_lr: float = 3e-4  # Value-function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            # Every other config defines vae_lr (and this one sets vae_file too);
            # without it, reading config.vae_lr fails only for this environment.
            vae_lr: float = 5e-4  # VAE learning rate

            # --- Wandb logging --------------------------------------------
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'halfcheetah-medium-expert-v2':
        class TrainConfig:
            """IQL training configuration for ``halfcheetah-medium-expert-v2``.

            Values are bound when this branch executes; ``random_seed``,
            ``train_type`` and ``algo`` are read from the enclosing call.
            """

            # --- Experiment -----------------------------------------------
            device: str = "cuda"
            env_sim_name: str = "halfcheetah"
            level: str = "medium-expert"
            env: str = f"{env_sim_name}-{level}-v2"  # Gym environment id
            seed: int = random_seed  # Seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # Seed of the evaluation environment
            eval_freq: int = 2_000  # Evaluate every this many time steps
            save_freq: int = 10_000
            n_episodes: int = 100  # Episodes per evaluation round
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # '' means "load nothing" (offline runs); any other train_type loads
            # the IQL offline checkpoint stored under ../offline for this seed.
            # Previously used alternative for non-offline runs:
            #   f"checkpoints/IQL_incremental-{seed}/{env}/checkpoint_490000.pt"
            if 'offline' in train_type:
                load_model: str = ''
            else:
                load_model: str = f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_998000.pt"
            # Previously used alternative VAE:
            #   f"../../vae_checkpoints/{env}/checkpoint_399999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE weights ("" doesn't load)
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_99999.pt"  # Embedding weights ("" doesn't load)

            # --- IQL ------------------------------------------------------
            actor_dropout: float = 0.0  # Dropout in the actor network
            if 'incremental' in train_type:
                buffer_size: int = 20_000  # Replay buffer size for incremental runs
            else:
                buffer_size: int = 10_000_000  # Replay buffer size
            vae_hidden_dim: int = 400  # Hidden width of the VAE network
            batch_size: int = 256  # Batch size shared by all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target-network update rate
            beta: float = 3.0  # Inverse temperature: small -> BC, large -> maximize Q
            iql_tau: float = 0.9  # Coefficient of the asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range used to clip the noise
            iql_deterministic: bool = False  # Use a deterministic actor
            # Trainer switches; their semantics live in the training code.
            use_off_policy: bool = False
            use_q_update: bool = True
            change_lr: bool = True
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize rewards
            vf_lr: float = 3e-4  # Value-function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate (was assigned twice)

            # --- Wandb logging --------------------------------------------
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'halfcheetah-medium-replay-v2':
        class TrainConfig:
            """IQL training configuration for ``halfcheetah-medium-replay-v2``.

            Values are bound when this branch executes; ``random_seed``,
            ``train_type`` and ``algo`` are read from the enclosing call.
            """

            # --- Experiment -----------------------------------------------
            device: str = "cuda"
            env_sim_name: str = "halfcheetah"
            level: str = "medium-replay"
            env: str = f"{env_sim_name}-{level}-v2"  # Gym environment id
            seed: int = random_seed  # Seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # Seed of the evaluation environment
            eval_freq: int = 2_000  # Evaluate every this many time steps
            save_freq: int = 10_000
            n_episodes: int = 20  # Episodes per evaluation round
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # '' means "load nothing" (offline runs); any other train_type loads
            # the IQL offline checkpoint stored under ../offline for this seed.
            # Previously used alternative for non-offline runs:
            #   f"checkpoints/IQL_incremental-{seed}/{env}/checkpoint_490000.pt"
            if 'offline' in train_type:
                load_model: str = ''
            else:
                load_model: str = f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_998000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE weights ("" doesn't load)
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Embedding weights ("" doesn't load)

            # --- IQL ------------------------------------------------------
            actor_dropout: float = 0.0  # Dropout in the actor network
            if 'incremental' in train_type:
                buffer_size: int = 20_000  # Replay buffer size for incremental runs
            else:
                buffer_size: int = 10_000_000  # Replay buffer size
            vae_hidden_dim: int = 400  # Hidden width of the VAE network
            batch_size: int = 256  # Batch size shared by all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target-network update rate
            beta: float = 0.5  # Inverse temperature: small -> BC, large -> maximize Q
            iql_tau: float = 0.7  # Coefficient of the asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range used to clip the noise
            iql_deterministic: bool = False  # Use a deterministic actor
            # Trainer switches; their semantics live in the training code.
            use_off_policy: bool = False
            use_q_update: bool = True
            change_lr: bool = True
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize rewards
            vf_lr: float = 3e-4  # Value-function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate

            # --- Wandb logging --------------------------------------------
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'walker2d-medium-v2':
        class TrainConfig:
            """IQL training configuration for ``walker2d-medium-v2``.

            Values are bound when this branch executes; ``random_seed``,
            ``train_type`` and ``algo`` are read from the enclosing call.
            """

            # --- Experiment -----------------------------------------------
            device: str = "cuda"
            env_sim_name: str = "walker2d"
            level: str = "medium"
            env: str = f"{env_sim_name}-{level}-v2"  # Gym environment id
            seed: int = random_seed  # Seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # Seed of the evaluation environment
            eval_freq: int = 2_000  # Evaluate every this many time steps
            save_freq: int = 10_000
            n_episodes: int = 100  # Episodes per evaluation round
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # '' means "load nothing" (offline runs); any other train_type loads
            # the IQL offline checkpoint stored under ../offline for this seed.
            if 'offline' in train_type:
                load_model: str = ''
            else:
                load_model: str = f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_998000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE weights ("" doesn't load)

            # --- IQL ------------------------------------------------------
            actor_dropout: float = 0.0  # Dropout in the actor network
            if 'incremental' in train_type:
                buffer_size: int = 20_000  # Replay buffer size for incremental runs
            else:
                buffer_size: int = 10_000_000  # Replay buffer size
            vae_hidden_dim: int = 400  # Hidden width of the VAE network
            batch_size: int = 256  # Batch size shared by all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target-network update rate
            beta: float = 3.0  # Inverse temperature: small -> BC, large -> maximize Q
            iql_tau: float = 0.7  # Coefficient of the asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range used to clip the noise
            iql_deterministic: bool = False  # Use a deterministic actor
            # Trainer switches; their semantics live in the training code.
            use_off_policy: bool = False
            use_q_update: bool = True
            change_lr: bool = True
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize rewards
            vf_lr: float = 3e-4  # Value-function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate

            # --- Wandb logging --------------------------------------------
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'walker2d-medium-expert-v2':
        class TrainConfig:
            """IQL training configuration for ``walker2d-medium-expert-v2``.

            Values are bound when this branch executes; ``random_seed``,
            ``train_type`` and ``algo`` are read from the enclosing call.
            """

            # --- Experiment -----------------------------------------------
            device: str = "cuda"
            env_sim_name: str = "walker2d"
            level: str = "medium-expert"
            env: str = f"{env_sim_name}-{level}-v2"  # Gym environment id
            seed: int = random_seed  # Seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # Seed of the evaluation environment
            eval_freq: int = 2_000  # Evaluate every this many time steps
            save_freq: int = 10_000
            n_episodes: int = 100  # Episodes per evaluation round
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # '' means "load nothing" (offline runs); any other train_type loads
            # the IQL offline checkpoint stored under ../offline for this seed.
            if 'offline' in train_type:
                load_model: str = ''
            else:
                load_model: str = f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_998000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE weights ("" doesn't load)

            # --- IQL ------------------------------------------------------
            actor_dropout: float = 0.0  # Dropout in the actor network
            if 'incremental' in train_type:
                buffer_size: int = 20_000  # Replay buffer size for incremental runs
            else:
                buffer_size: int = 10_000_000  # Replay buffer size
            vae_hidden_dim: int = 400  # Hidden width of the VAE network
            batch_size: int = 256  # Batch size shared by all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target-network update rate
            beta: float = 3.0  # Inverse temperature: small -> BC, large -> maximize Q
            iql_tau: float = 0.7  # Coefficient of the asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range used to clip the noise
            iql_deterministic: bool = False  # Use a deterministic actor
            # Trainer switches; their semantics live in the training code.
            use_off_policy: bool = False
            use_q_update: bool = True
            change_lr: bool = True
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize rewards
            vf_lr: float = 3e-4  # Value-function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate

            # --- Wandb logging --------------------------------------------
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'walker2d-medium-replay-v2':
        class TrainConfig:
            """IQL training configuration for ``walker2d-medium-replay-v2``.

            Values are bound when this branch executes; ``random_seed``,
            ``train_type`` and ``algo`` are read from the enclosing call.
            """

            # --- Experiment -----------------------------------------------
            device: str = "cuda"
            env_sim_name: str = "walker2d"
            level: str = "medium-replay"
            env: str = f"{env_sim_name}-{level}-v2"  # Gym environment id
            seed: int = random_seed  # Seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # Seed of the evaluation environment
            eval_freq: int = 2_000  # Evaluate every this many time steps
            save_freq: int = 10_000
            n_episodes: int = 100  # Episodes per evaluation round
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # '' means "load nothing" (offline runs); any other train_type loads
            # the IQL offline checkpoint stored under ../offline for this seed.
            if 'offline' in train_type:
                load_model: str = ''
            else:
                load_model: str = f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_998000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE weights ("" doesn't load)

            # --- IQL ------------------------------------------------------
            actor_dropout: float = 0.0  # Dropout in the actor network
            if 'incremental' in train_type:
                buffer_size: int = 20_000  # Replay buffer size for incremental runs
            else:
                buffer_size: int = 10_000_000  # Replay buffer size
            vae_hidden_dim: int = 400  # Hidden width of the VAE network
            batch_size: int = 256  # Batch size shared by all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target-network update rate
            beta: float = 3.0  # Inverse temperature: small -> BC, large -> maximize Q
            iql_tau: float = 0.7  # Coefficient of the asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range used to clip the noise
            iql_deterministic: bool = False  # Use a deterministic actor
            # Trainer switches; their semantics live in the training code.
            use_off_policy: bool = False
            use_q_update: bool = True
            change_lr: bool = True
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize rewards
            vf_lr: float = 3e-4  # Value-function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate

            # --- Wandb logging --------------------------------------------
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'antmaze-large-play-v2':
        class TrainConfig:
            """IQL training configuration for ``antmaze-large-play-v2``.

            Values are bound when this branch executes; ``random_seed``,
            ``train_type`` and ``algo`` are read from the enclosing call.
            """

            # --- Experiment -----------------------------------------------
            device: str = "cuda"
            env_sim_name: str = "antmaze"
            level: str = "large-play"
            env: str = f"{env_sim_name}-{level}-v2"  # Gym environment id
            seed: int = random_seed  # Seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # Seed of the evaluation environment
            eval_freq: int = 5_000  # Evaluate every this many time steps
            save_freq: int = 10_000
            n_episodes: int = 100  # Episodes per evaluation round
            max_timesteps: int = 400_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # '' means "load nothing" (offline runs); any other train_type loads
            # the IQL offline checkpoint stored under ../offline for this seed.
            if 'offline' in train_type:
                load_model: str = ''
            else:
                load_model: str = f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_995000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE weights ("" doesn't load)

            # --- IQL ------------------------------------------------------
            actor_dropout: float = 0.0  # Dropout in the actor network
            if 'incremental' in train_type:
                buffer_size: int = 20_000  # Replay buffer size for incremental runs
            else:
                buffer_size: int = 10_000_000  # Replay buffer size
            vae_hidden_dim: int = 400  # Hidden width of the VAE network
            batch_size: int = 256  # Batch size shared by all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target-network update rate
            beta: float = 10.0  # Inverse temperature: small -> BC, large -> maximize Q
            iql_tau: float = 0.9  # Coefficient of the asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range used to clip the noise
            iql_deterministic: bool = False  # Use a deterministic actor
            # Trainer switches; their semantics live in the training code.
            use_off_policy: bool = False
            use_q_update: bool = True
            change_lr: bool = True
            normalize: bool = True  # Normalize states
            normalize_reward: bool = True  # Normalize rewards
            vf_lr: float = 3e-4  # Value-function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate

            # --- Wandb logging --------------------------------------------
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'antmaze-large-diverse-v2':
        class TrainConfig:
            """IQL training configuration for ``antmaze-large-diverse-v2``.

            Values are bound when this branch executes; ``random_seed``,
            ``train_type`` and ``algo`` are read from the enclosing call.
            """

            # --- Experiment -----------------------------------------------
            device: str = "cuda"
            env_sim_name: str = "antmaze"
            level: str = "large-diverse"
            env: str = f"{env_sim_name}-{level}-v2"  # Gym environment id
            seed: int = random_seed  # Seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # Seed of the evaluation environment
            eval_freq: int = 5_000  # Evaluate every this many time steps
            save_freq: int = 10_000
            n_episodes: int = 100  # Episodes per evaluation round
            max_timesteps: int = 400_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # '' means "load nothing" (offline runs); any other train_type loads
            # the IQL offline checkpoint stored under ../offline for this seed.
            if 'offline' in train_type:
                load_model: str = ''
            else:
                load_model: str = f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_995000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE weights ("" doesn't load)

            # --- IQL ------------------------------------------------------
            actor_dropout: float = 0.0  # Dropout in the actor network
            if 'incremental' in train_type:
                buffer_size: int = 50_000  # Replay buffer size for incremental runs
            else:
                buffer_size: int = 10_000_000  # Replay buffer size
            vae_hidden_dim: int = 400  # Hidden width of the VAE network
            batch_size: int = 256  # Batch size shared by all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target-network update rate
            beta: float = 10.0  # Inverse temperature: small -> BC, large -> maximize Q
            iql_tau: float = 0.9  # Coefficient of the asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range used to clip the noise
            iql_deterministic: bool = False  # Use a deterministic actor
            # Trainer switches; their semantics live in the training code.
            use_off_policy: bool = False
            use_q_update: bool = True
            change_lr: bool = True
            normalize: bool = True  # Normalize states
            normalize_reward: bool = True  # Normalize rewards
            vf_lr: float = 3e-4  # Value-function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate

            # --- Wandb logging --------------------------------------------
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'antmaze-medium-play-v2':
        class TrainConfig:
            """IQL training configuration for ``antmaze-medium-play-v2``.

            Defined inside ``get_iql_train_configs`` so ``random_seed``,
            ``train_type`` and ``algo`` are baked into class attributes.
            ``train_type`` containing ``'offline'`` leaves ``load_model`` empty;
            containing ``'incremental'`` shrinks ``buffer_size``.
            """

            # Experiment
            device: str = "cuda"
            env_sim_name: str = "antmaze"
            level: str = "medium-play"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_freq: int = 5_000  # How often (time steps) we evaluate
            save_freq: int = 10_000  # Checkpoint save interval (unit presumably steps -- confirm in training loop)
            n_episodes: int = 100  # How many episodes run during evaluation
            max_timesteps: int = 400_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # Empty for offline runs; otherwise the checkpoint written by the offline run with the same seed.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_995000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE checkpoint file, "" doesn't load
            # IQL
            actor_dropout: float = 0.0  # Dropout in actor network
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 50_000  # Replay buffer size (small for incremental runs)
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            batch_size: int = 256  # Batch size for all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target network update rate
            beta: float = 10.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
            iql_tau: float = 0.9  # Coefficient for asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range to clip noise
            iql_deterministic: bool = False  # Use deterministic actor
            # NOTE(review): the meaning of the three flags below is defined by the
            # training loop (not visible here); descriptions are inferred from names -- confirm.
            use_off_policy: bool = False  # presumably enables off-policy updates
            use_q_update: bool = True  # presumably enables the Q-function update
            change_lr: bool = True  # presumably changes learning rates between phases
            normalize: bool = True  # Normalize states
            normalize_reward: bool = True  # Normalize reward
            vf_lr: float = 3e-4  # V function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'antmaze-medium-diverse-v2':
        class TrainConfig:
            """IQL training configuration for ``antmaze-medium-diverse-v2``.

            Defined inside ``get_iql_train_configs`` so ``random_seed``,
            ``train_type`` and ``algo`` are baked into class attributes.
            ``train_type`` containing ``'offline'`` leaves ``load_model`` empty;
            containing ``'incremental'`` shrinks ``buffer_size``.
            """

            # Experiment
            device: str = "cuda"
            env_sim_name: str = "antmaze"
            level: str = "medium-diverse"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_freq: int = 5_000  # How often (time steps) we evaluate
            save_freq: int = 10_000  # Checkpoint save interval (unit presumably steps -- confirm in training loop)
            n_episodes: int = 100  # How many episodes run during evaluation
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # Empty for offline runs; otherwise the checkpoint written by the offline run with the same seed.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_995000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE checkpoint file, "" doesn't load
            # IQL
            actor_dropout: float = 0.0  # Dropout in actor network
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 50_000  # Replay buffer size (small for incremental runs)
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            batch_size: int = 256  # Batch size for all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target network update rate
            beta: float = 10.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
            iql_tau: float = 0.9  # Coefficient for asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range to clip noise
            iql_deterministic: bool = False  # Use deterministic actor
            # NOTE(review): the meaning of the three flags below is defined by the
            # training loop (not visible here); descriptions are inferred from names -- confirm.
            use_off_policy: bool = False  # presumably enables off-policy updates
            use_q_update: bool = True  # presumably enables the Q-function update
            change_lr: bool = True  # presumably changes learning rates between phases
            normalize: bool = True  # Normalize states
            normalize_reward: bool = True  # Normalize reward
            vf_lr: float = 3e-4  # V function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'antmaze-umaze-v2':
        class TrainConfig:
            """IQL training configuration for ``antmaze-umaze-v2``.

            Defined inside ``get_iql_train_configs`` so ``random_seed``,
            ``train_type`` and ``algo`` are baked into class attributes.
            ``train_type`` containing ``'offline'`` leaves ``load_model`` empty;
            containing ``'incremental'`` shrinks ``buffer_size``.
            """

            # Experiment
            device: str = "cuda"
            env_sim_name: str = "antmaze"
            level: str = "umaze"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_freq: int = 5_000  # How often (time steps) we evaluate
            save_freq: int = 10_000  # Checkpoint save interval (unit presumably steps -- confirm in training loop)
            n_episodes: int = 100  # How many episodes run during evaluation
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # Empty for offline runs; otherwise the checkpoint written by the offline run with the same seed.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_995000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE checkpoint file, "" doesn't load
            # IQL
            actor_dropout: float = 0.0  # Dropout in actor network
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 50_000  # Replay buffer size (small for incremental runs)
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            batch_size: int = 256  # Batch size for all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target network update rate
            beta: float = 10.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
            iql_tau: float = 0.9  # Coefficient for asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range to clip noise
            iql_deterministic: bool = False  # Use deterministic actor
            # NOTE(review): the meaning of the three flags below is defined by the
            # training loop (not visible here); descriptions are inferred from names -- confirm.
            use_off_policy: bool = False  # presumably enables off-policy updates
            use_q_update: bool = True  # presumably enables the Q-function update
            change_lr: bool = True  # presumably changes learning rates between phases
            normalize: bool = True  # Normalize states
            normalize_reward: bool = True  # Normalize reward
            vf_lr: float = 3e-4  # V function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'antmaze-umaze-diverse-v2':
        class TrainConfig:
            """IQL training configuration for ``antmaze-umaze-diverse-v2``.

            Defined inside ``get_iql_train_configs`` so ``random_seed``,
            ``train_type`` and ``algo`` are baked into class attributes.
            ``train_type`` containing ``'offline'`` leaves ``load_model`` empty;
            containing ``'incremental'`` shrinks ``buffer_size``.
            """

            # Experiment
            device: str = "cuda"
            env_sim_name: str = "antmaze"
            level: str = "umaze-diverse"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_freq: int = 5_000  # How often (time steps) we evaluate
            save_freq: int = 10_000  # Checkpoint save interval (unit presumably steps -- confirm in training loop)
            n_episodes: int = 100  # How many episodes run during evaluation
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # Empty for offline runs; otherwise the checkpoint written by the offline run with the same seed.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_995000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # VAE checkpoint file, "" doesn't load
            # IQL
            actor_dropout: float = 0.0  # Dropout in actor network
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 50_000  # Replay buffer size (small for incremental runs)
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            batch_size: int = 256  # Batch size for all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target network update rate
            beta: float = 10.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
            iql_tau: float = 0.9  # Coefficient for asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range to clip noise
            iql_deterministic: bool = False  # Use deterministic actor
            # NOTE(review): the meaning of the three flags below is defined by the
            # training loop (not visible here); descriptions are inferred from names -- confirm.
            use_off_policy: bool = False  # presumably enables off-policy updates
            use_q_update: bool = True  # presumably enables the Q-function update
            change_lr: bool = True  # presumably changes learning rates between phases
            normalize: bool = True  # Normalize states
            normalize_reward: bool = True  # Normalize reward
            vf_lr: float = 3e-4  # V function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'pen-cloned-v1':
        class TrainConfig:
            """IQL training configuration for ``pen-cloned-v1``.

            Defined inside ``get_iql_train_configs`` so ``random_seed``,
            ``train_type`` and ``algo`` are baked into class attributes.
            ``train_type`` containing ``'offline'`` leaves ``load_model`` empty
            and enables actor dropout; containing ``'incremental'`` shrinks
            ``buffer_size``.
            """

            # Experiment
            device: str = "cuda"
            env_sim_name: str = "pen"
            level: str = "cloned"
            env: str = f"{env_sim_name}-{level}-v1"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_freq: int = 5_000  # How often (time steps) we evaluate
            save_freq: int = 10_000  # Checkpoint save interval (unit presumably steps -- confirm in training loop)
            n_episodes: int = 10  # How many episodes run during evaluation
            max_timesteps: int = 500_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # Empty for offline runs; otherwise the checkpoint written by the offline run with the same seed.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_995000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_19999.pt"  # VAE checkpoint file, "" doesn't load
            # IQL
            actor_dropout: float = 0.1 if "offline" in train_type else 0.0  # Actor dropout: 0.1 for offline runs, 0.0 otherwise
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 20_000  # Replay buffer size (small for incremental runs)
            vae_hidden_dim: int = 100  # hidden dimension of vae network
            batch_size: int = 256  # Batch size for all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target network update rate
            beta: float = 3.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
            iql_tau: float = 0.8  # Coefficient for asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range to clip noise
            iql_deterministic: bool = False  # Use deterministic actor
            # NOTE(review): the meaning of the three flags below is defined by the
            # training loop (not visible here); descriptions are inferred from names -- confirm.
            use_off_policy: bool = False  # presumably enables off-policy updates
            use_q_update: bool = True  # presumably enables the Q-function update
            change_lr: bool = True  # presumably changes learning rates between phases
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize reward (disabled for this task)
            vf_lr: float = 3e-4  # V function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'relocate-human-v1':
        class TrainConfig:
            """IQL training configuration for ``relocate-human-v1``.

            Defined inside ``get_iql_train_configs`` so ``random_seed``,
            ``train_type`` and ``algo`` are baked into class attributes.
            ``train_type`` containing ``'offline'`` leaves ``load_model`` empty
            and enables actor dropout; containing ``'incremental'`` shrinks
            ``buffer_size``.
            """

            # Experiment
            device: str = "cuda"
            env_sim_name: str = "relocate"
            level: str = "human"
            env: str = f"{env_sim_name}-{level}-v1"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_freq: int = 5_000  # How often (time steps) we evaluate
            save_freq: int = 10_000  # Checkpoint save interval (unit presumably steps -- confirm in training loop)
            n_episodes: int = 10  # How many episodes run during evaluation
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # Empty for offline runs; otherwise the checkpoint written by the offline run with the same seed.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_995000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_19999.pt"  # VAE checkpoint file, "" doesn't load
            # IQL
            actor_dropout: float = 0.1 if "offline" in train_type else 0.0  # Actor dropout: 0.1 for offline runs, 0.0 otherwise
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 20_000  # Replay buffer size (small for incremental runs)
            vae_hidden_dim: int = 100  # hidden dimension of vae network
            batch_size: int = 256  # Batch size for all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target network update rate
            beta: float = 3.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
            iql_tau: float = 0.8  # Coefficient for asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range to clip noise
            iql_deterministic: bool = False  # Use deterministic actor
            # NOTE(review): the meaning of the three flags below is defined by the
            # training loop (not visible here); descriptions are inferred from names -- confirm.
            use_off_policy: bool = False  # presumably enables off-policy updates
            use_q_update: bool = True  # presumably enables the Q-function update
            change_lr: bool = True  # presumably changes learning rates between phases
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize reward (disabled for this task)
            vf_lr: float = 3e-4  # V function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'relocate-cloned-v1':
        class TrainConfig:
            """IQL training configuration for ``relocate-cloned-v1``.

            Defined inside ``get_iql_train_configs`` so ``random_seed``,
            ``train_type`` and ``algo`` are baked into class attributes.
            ``train_type`` containing ``'offline'`` leaves ``load_model`` empty
            and enables actor dropout; containing ``'incremental'`` shrinks
            ``buffer_size``.
            """

            # Experiment
            device: str = "cuda"
            env_sim_name: str = "relocate"
            level: str = "cloned"
            env: str = f"{env_sim_name}-{level}-v1"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_freq: int = 5_000  # How often (time steps) we evaluate
            save_freq: int = 10_000  # Checkpoint save interval (unit presumably steps -- confirm in training loop)
            n_episodes: int = 10  # How many episodes run during evaluation
            max_timesteps: int = 500_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # Empty for offline runs; otherwise the checkpoint written by the offline run with the same seed.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_995000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_19999.pt"  # VAE checkpoint file, "" doesn't load
            # IQL
            actor_dropout: float = 0.1 if "offline" in train_type else 0.0  # Actor dropout: 0.1 for offline runs, 0.0 otherwise
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 20_000  # Replay buffer size (small for incremental runs)
            vae_hidden_dim: int = 100  # hidden dimension of vae network
            batch_size: int = 256  # Batch size for all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target network update rate
            beta: float = 3.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
            iql_tau: float = 0.8  # Coefficient for asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range to clip noise
            iql_deterministic: bool = False  # Use deterministic actor
            # NOTE(review): the meaning of the three flags below is defined by the
            # training loop (not visible here); descriptions are inferred from names -- confirm.
            use_off_policy: bool = False  # presumably enables off-policy updates
            use_q_update: bool = True  # presumably enables the Q-function update
            change_lr: bool = True  # presumably changes learning rates between phases
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize reward (disabled for this task)
            vf_lr: float = 3e-4  # V function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'door-human-v1':
        class TrainConfig:
            """IQL training configuration for ``door-human-v1``.

            Defined inside ``get_iql_train_configs`` so ``random_seed``,
            ``train_type`` and ``algo`` are baked into class attributes.
            ``train_type`` containing ``'offline'`` leaves ``load_model`` empty
            and enables actor dropout; containing ``'incremental'`` shrinks
            ``buffer_size``.
            """

            # Experiment
            device: str = "cuda"
            env_sim_name: str = "door"
            level: str = "human"
            env: str = f"{env_sim_name}-{level}-v1"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_freq: int = 5_000  # How often (time steps) we evaluate
            save_freq: int = 10_000  # Checkpoint save interval (unit presumably steps -- confirm in training loop)
            n_episodes: int = 10  # How many episodes run during evaluation
            max_timesteps: int = 1_000_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # Empty for offline runs; otherwise the checkpoint written by the offline run with the same seed.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_995000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_19999.pt"  # VAE checkpoint file, "" doesn't load
            # IQL
            actor_dropout: float = 0.1 if "offline" in train_type else 0.0  # Actor dropout: 0.1 for offline runs, 0.0 otherwise
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 20_000  # Replay buffer size (small for incremental runs)
            vae_hidden_dim: int = 100  # hidden dimension of vae network
            batch_size: int = 256  # Batch size for all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target network update rate
            beta: float = 3.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
            iql_tau: float = 0.8  # Coefficient for asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range to clip noise
            iql_deterministic: bool = False  # Use deterministic actor
            # NOTE(review): the meaning of the three flags below is defined by the
            # training loop (not visible here); descriptions are inferred from names -- confirm.
            use_off_policy: bool = False  # presumably enables off-policy updates
            use_q_update: bool = True  # presumably enables the Q-function update
            change_lr: bool = True  # presumably changes learning rates between phases
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize reward (disabled for this task)
            vf_lr: float = 3e-4  # V function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'door-cloned-v1':
        class TrainConfig:
            """IQL training configuration for ``door-cloned-v1``.

            Defined inside ``get_iql_train_configs`` so ``random_seed``,
            ``train_type`` and ``algo`` are baked into class attributes.
            ``train_type`` containing ``'offline'`` leaves ``load_model`` empty
            and enables actor dropout; containing ``'incremental'`` shrinks
            ``buffer_size``.
            """

            # Experiment
            device: str = "cuda"
            env_sim_name: str = "door"
            level: str = "cloned"
            env: str = f"{env_sim_name}-{level}-v1"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_freq: int = 5_000  # How often (time steps) we evaluate
            save_freq: int = 10_000  # Checkpoint save interval (unit presumably steps -- confirm in training loop)
            n_episodes: int = 10  # How many episodes run during evaluation
            max_timesteps: int = 500_000
            offline_iterations: int = 300_000  # Number of offline updates
            online_iterations: int = 200_002  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/IQL_{train_type}/{env}'  # Save path
            # Empty for offline runs; otherwise the checkpoint written by the offline run with the same seed.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/IQL_offline-{seed}/{env}/checkpoint_995000.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_19999.pt"  # VAE checkpoint file, "" doesn't load
            # IQL
            actor_dropout: float = 0.1 if "offline" in train_type else 0.0  # Actor dropout: 0.1 for offline runs, 0.0 otherwise
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 20_000  # Replay buffer size (small for incremental runs)
            vae_hidden_dim: int = 100  # hidden dimension of vae network
            batch_size: int = 256  # Batch size for all networks
            discount: float = 0.99  # Discount factor
            tau: float = 0.005  # Target network update rate
            beta: float = 3.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
            iql_tau: float = 0.8  # Coefficient for asymmetric loss
            expl_noise: float = 0.03  # Std of Gaussian exploration noise
            noise_clip: float = 0.5  # Range to clip noise
            iql_deterministic: bool = False  # Use deterministic actor
            # NOTE(review): the meaning of the three flags below is defined by the
            # training loop (not visible here); descriptions are inferred from names -- confirm.
            use_off_policy: bool = False  # presumably enables off-policy updates
            use_q_update: bool = True  # presumably enables the Q-function update
            change_lr: bool = True  # presumably changes learning rates between phases
            normalize: bool = True  # Normalize states
            normalize_reward: bool = False  # Normalize reward (disabled for this task)
            vf_lr: float = 3e-4  # V function learning rate
            qf_lr: float = 3e-4  # Critic learning rate
            actor_lr: float = 3e-4  # Actor learning rate
            vae_lr: float = 5e-4  # VAE learning rate
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    else:
        raise NotImplementedError
    return TrainConfig


def get_awac_train_configs(env_name, algo, train_type, random_seed):
    if env_name == 'halfcheetah-medium-expert-v2':
        class TrainConfig:
            """AWAC settings for ``halfcheetah-medium-expert-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "halfcheetah"
            level: str = "medium-expert"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 20000  # Replay buffer size
            normalize_reward: bool = False
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.3333  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = True  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            # NOTE(review): no vae_lr is set in this branch although several sibling
            # branches define one -- confirm the trainer does not read it.
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'halfcheetah-medium-replay-v2':
        class TrainConfig:
            """AWAC settings for ``halfcheetah-medium-replay-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "halfcheetah"
            level: str = "medium-replay"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # 2M here (10M in most other branches); much smaller when train_type
            # contains 'incremental'.
            buffer_size: int = 2_000_000 if 'incremental' not in train_type else 20_000  # Replay buffer size
            normalize_reward: bool = False
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.3333  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = True  # presumably toggles deterministic torch ops -- TODO confirm
            vae_lr: float = 5e-4  # VAE learning rate
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'halfcheetah-medium-v2':
        class TrainConfig:
            """AWAC settings for ``halfcheetah-medium-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "halfcheetah"
            level: str = "medium"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 50000  # Replay buffer size
            normalize_reward: bool = False
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.3333  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = True  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            # NOTE(review): no vae_lr is set in this branch although several sibling
            # branches define one -- confirm the trainer does not read it.
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'hopper-medium-expert-v2':
        class TrainConfig:
            """AWAC settings for ``hopper-medium-expert-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "hopper"
            level: str = "medium-expert"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            # NOTE(review): unlike the other branches, the active vae_file reads from
            # ../../vae_checkpoints/{env}/ instead of ../../sas_vae_checkpoints/{env}-0/
            # (that path is kept commented out) -- confirm this is intended.
            # vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            vae_file: str = f"../../vae_checkpoints/{env}/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 50000  # Replay buffer size
            normalize_reward: bool = False
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.3333  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = True  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            # NOTE(review): no vae_lr is set in this branch although several sibling
            # branches define one -- confirm the trainer does not read it.
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'hopper-medium-replay-v2':
        class TrainConfig:
            """AWAC settings for ``hopper-medium-replay-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "hopper"
            level: str = "medium-replay"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 20000  # Replay buffer size
            normalize_reward: bool = False
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.3333  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = True  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            # NOTE(review): no vae_lr is set in this branch although several sibling
            # branches define one -- confirm the trainer does not read it.
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'hopper-medium-v2':
        class TrainConfig:
            """AWAC settings for ``hopper-medium-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "hopper"
            level: str = "medium"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 20000  # Replay buffer size
            normalize_reward: bool = False
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.3333  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = True  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            # NOTE(review): no vae_lr is set in this branch although several sibling
            # branches define one -- confirm the trainer does not read it.
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'walker2d-medium-expert-v2':
        class TrainConfig:
            """AWAC settings for ``walker2d-medium-expert-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "walker2d"
            level: str = "medium-expert"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 20000  # Replay buffer size
            normalize_reward: bool = False
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.3333  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = True  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            # NOTE(review): no vae_lr is set in this branch although several sibling
            # branches define one -- confirm the trainer does not read it.
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'walker2d-medium-replay-v2':
        class TrainConfig:
            """AWAC settings for ``walker2d-medium-replay-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "walker2d"
            level: str = "medium-replay"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 20000  # Replay buffer size
            normalize_reward: bool = False
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.3333  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = True  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            # NOTE(review): no vae_lr is set in this branch although several sibling
            # branches define one -- confirm the trainer does not read it.
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'walker2d-medium-v2':
        class TrainConfig:
            """AWAC settings for ``walker2d-medium-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "walker2d"
            level: str = "medium"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 20000  # Replay buffer size
            normalize_reward: bool = False
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.3333  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = True  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            # NOTE(review): no vae_lr is set in this branch although several sibling
            # branches define one -- confirm the trainer does not read it.
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'antmaze-umaze-v2':
        class TrainConfig:
            """AWAC settings for ``antmaze-umaze-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "antmaze"
            level: str = "umaze"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 50_000  # Replay buffer size
            normalize_reward: bool = True
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.1  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = False  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            vae_lr: float = 5e-4  # VAE learning rate
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'antmaze-umaze-diverse-v2':
        class TrainConfig:
            """AWAC settings for ``antmaze-umaze-diverse-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "antmaze"
            level: str = "umaze-diverse"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 50_000  # Replay buffer size
            normalize_reward: bool = True
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.1  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = False  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            vae_lr: float = 5e-4  # VAE learning rate
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'antmaze-medium-diverse-v2':
        class TrainConfig:
            """AWAC settings for ``antmaze-medium-diverse-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "antmaze"
            level: str = "medium-diverse"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 50_000  # Replay buffer size
            normalize_reward: bool = True
            batch_size: int = 256  # Batch size for all networks
            # NOTE(review): 4e5 here vs 1e6 in every other branch -- confirm intended.
            max_timesteps: int = int(4e5)
            vae_lr: float = 5e-4  # VAE learning rate
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.1  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = False  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'antmaze-medium-play-v2':
        class TrainConfig:
            """AWAC settings for ``antmaze-medium-play-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "antmaze"
            level: str = "medium-play"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 50_000  # Replay buffer size
            normalize_reward: bool = True
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            vae_lr: float = 5e-4  # VAE learning rate
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.1  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = False  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'antmaze-large-diverse-v2':
        class TrainConfig:
            """AWAC settings for ``antmaze-large-diverse-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "antmaze"
            level: str = "large-diverse"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 50_000  # Replay buffer size
            normalize_reward: bool = True
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.1  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = False  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            # NOTE(review): no vae_lr is set in this branch although the other
            # antmaze branches define one -- confirm the trainer does not read it.
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'antmaze-large-play-v2':
        class TrainConfig:
            """AWAC settings for ``antmaze-large-play-v2``.

            Every value is a class attribute computed once when this class body
            runs; ``random_seed``, ``train_type`` and ``algo`` are read from the
            arguments of the enclosing ``get_awac_train_configs`` call.
            """
            # Experiment
            device: str = "cuda"
            env_sim_name: str = "antmaze"
            level: str = "large-play"
            env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
            seed: int = random_seed  # Sets Gym, PyTorch and Numpy seeds
            eval_seed: int = random_seed  # Eval environment seed
            eval_frequency: int = int(1e3)  # How often (time steps) we evaluate
            save_freq: int = int(1e5)
            n_test_episodes: int = 10  # How many episodes run during evaluation
            offline_iterations: int = int(1e6)  # Number of offline updates
            online_iterations: int = int(2e5+2)  # Number of online updates
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'  # Save path
            # '' whenever train_type contains the substring 'offline'; otherwise a
            # path into ../offline/checkpoints/AWAC_offline-{seed}/{env}/.
            load_model: str = '' if 'offline' in train_type else f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # Model load file name, "" doesn't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # Model load file name, "" doesn't load
            # AWAC hyperparameters
            # Much smaller buffer when train_type contains 'incremental'.
            buffer_size: int = 10_000_000 if 'incremental' not in train_type else 50_000  # Replay buffer size
            normalize_reward: bool = True
            batch_size: int = 256  # Batch size for all networks
            max_timesteps: int = int(1e6)
            discount: float = 0.99  # Discount factor
            awac_lambda: float = 0.1  # presumably the AWAC advantage temperature -- TODO confirm
            deterministic_torch: bool = False  # presumably toggles deterministic torch ops -- TODO confirm
            hidden_dim: int = 256
            learning_rate: float = 0.0003
            vae_lr: float = 5e-4  # VAE learning rate
            # NOTE(review): gamma repeats `discount` (both 0.99); which one the
            # trainer reads is not visible here -- keep the two in sync.
            gamma: float = 0.99
            num_train_ops: int = 1000000
            # seed: 42
            tau: float = 0.005
            # test_seed: 69
            vae_hidden_dim: int = 400  # hidden dimension of vae network
            latent_dim: int = 1
            # Wandb logging
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'relocate-cloned-v1':
        class TrainConfig:
            """AWAC run configuration for ``relocate-cloned-v1``.

            Defined inside the enclosing function, so it reads that function's
            ``random_seed`` (seeds), ``train_type`` (save/load paths and replay
            buffer size) and ``algo`` (W&B project name).
            """

            # ---- experiment / environment ----
            device: str = "cuda"
            env_sim_name: str = "relocate"
            level: str = "cloned"
            env: str = f"{env_sim_name}-{level}-v1"  # gym id of the environment
            seed: int = random_seed  # seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # seed of the evaluation environment
            eval_frequency: int = 1_000  # steps between evaluations
            save_freq: int = 100_000  # steps between checkpoints
            n_test_episodes: int = 10  # episodes per evaluation
            offline_iterations: int = 1_000_000  # number of offline updates
            online_iterations: int = 200_002  # number of online updates

            # ---- checkpoints and pretrained components ----
            # Directory this run saves its checkpoints to.
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'
            # Empty for offline runs; any other train_type loads the final
            # offline AWAC checkpoint stored under the same seed.
            load_model: str = (
                f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
                if 'offline' not in train_type
                else ''
            )
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_19999.pt"  # "" means: don't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # "" means: don't load

            # ---- AWAC / replay buffer ----
            # "incremental" runs use a small buffer, every other run a 10M one.
            buffer_size: int = 20_000 if 'incremental' in train_type else 10_000_000
            normalize_reward: bool = True
            batch_size: int = 256  # shared by all networks
            max_timesteps: int = 1_000_000
            discount: float = 0.99
            awac_lambda: float = 0.1
            deterministic_torch: bool = False
            hidden_dim: int = 256
            learning_rate: float = 3e-4
            vae_lr: float = 5e-4  # learning rate of the VAE
            gamma: float = 0.99  # NOTE(review): same value as `discount`; confirm which one the trainer reads
            num_train_ops: int = 1_000_000
            tau: float = 0.005
            vae_hidden_dim: int = 400  # width of the VAE hidden layers
            latent_dim: int = 1

            # ---- Weights & Biases logging ----
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'relocate-expert-v1':
        class TrainConfig:
            """AWAC run configuration for ``relocate-expert-v1``.

            Defined inside the enclosing function, so it reads that function's
            ``random_seed`` (seeds), ``train_type`` (save/load paths and replay
            buffer size) and ``algo`` (W&B project name).
            """

            # ---- experiment / environment ----
            device: str = "cuda"
            env_sim_name: str = "relocate"
            level: str = "expert"
            env: str = f"{env_sim_name}-{level}-v1"  # gym id of the environment
            seed: int = random_seed  # seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # seed of the evaluation environment
            eval_frequency: int = 1_000  # steps between evaluations
            save_freq: int = 100_000  # steps between checkpoints
            n_test_episodes: int = 10  # episodes per evaluation
            offline_iterations: int = 1_000_000  # number of offline updates
            online_iterations: int = 200_002  # number of online updates

            # ---- checkpoints and pretrained components ----
            # Directory this run saves its checkpoints to.
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'
            # Empty for offline runs; any other train_type loads the final
            # offline AWAC checkpoint stored under the same seed.
            load_model: str = (
                f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
                if 'offline' not in train_type
                else ''
            )
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # "" means: don't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # "" means: don't load

            # ---- AWAC / replay buffer ----
            # "incremental" runs use a smaller buffer, every other run a 10M one.
            buffer_size: int = 500_000 if 'incremental' in train_type else 10_000_000
            normalize_reward: bool = True
            batch_size: int = 256  # shared by all networks
            max_timesteps: int = 1_000_000
            discount: float = 0.99
            awac_lambda: float = 0.1
            deterministic_torch: bool = False
            hidden_dim: int = 256
            learning_rate: float = 3e-4
            gamma: float = 0.99  # NOTE(review): same value as `discount`; confirm which one the trainer reads
            num_train_ops: int = 1_000_000
            tau: float = 0.005
            vae_hidden_dim: int = 400  # width of the VAE hidden layers
            latent_dim: int = 1

            # ---- Weights & Biases logging ----
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'relocate-human-v1':
        class TrainConfig:
            """AWAC run configuration for ``relocate-human-v1``.

            Defined inside the enclosing function, so it reads that function's
            ``random_seed`` (seeds), ``train_type`` (save/load paths and replay
            buffer size) and ``algo`` (W&B project name).
            """

            # ---- experiment / environment ----
            device: str = "cuda"
            env_sim_name: str = "relocate"
            level: str = "human"
            env: str = f"{env_sim_name}-{level}-v1"  # gym id of the environment
            seed: int = random_seed  # seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # seed of the evaluation environment
            eval_frequency: int = 1_000  # steps between evaluations
            save_freq: int = 100_000  # steps between checkpoints
            n_test_episodes: int = 10  # episodes per evaluation
            offline_iterations: int = 1_000_000  # number of offline updates
            online_iterations: int = 200_002  # number of online updates

            # ---- checkpoints and pretrained components ----
            # Directory this run saves its checkpoints to.
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'
            # Empty for offline runs; any other train_type loads the final
            # offline AWAC checkpoint stored under the same seed.
            load_model: str = (
                f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
                if 'offline' not in train_type
                else ''
            )
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # "" means: don't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # "" means: don't load

            # ---- AWAC / replay buffer ----
            # "incremental" runs use a smaller buffer, every other run a 10M one.
            buffer_size: int = 500_000 if 'incremental' in train_type else 10_000_000
            normalize_reward: bool = True
            batch_size: int = 256  # shared by all networks
            max_timesteps: int = 1_000_000
            discount: float = 0.99
            awac_lambda: float = 0.1
            deterministic_torch: bool = False
            hidden_dim: int = 256
            learning_rate: float = 3e-4
            gamma: float = 0.99  # NOTE(review): same value as `discount`; confirm which one the trainer reads
            num_train_ops: int = 1_000_000
            tau: float = 0.005
            vae_hidden_dim: int = 400  # width of the VAE hidden layers
            latent_dim: int = 1

            # ---- Weights & Biases logging ----
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'pen-cloned-v1':
        class TrainConfig:
            """AWAC run configuration for ``pen-cloned-v1``.

            Defined inside the enclosing function, so it reads that function's
            ``random_seed`` (seeds), ``train_type`` (save/load paths and replay
            buffer size) and ``algo`` (W&B project name).
            """

            # ---- experiment / environment ----
            device: str = "cuda"
            env_sim_name: str = "pen"
            level: str = "cloned"
            env: str = f"{env_sim_name}-{level}-v1"  # gym id of the environment
            seed: int = random_seed  # seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # seed of the evaluation environment
            eval_frequency: int = 1_000  # steps between evaluations
            save_freq: int = 100_000  # steps between checkpoints
            n_test_episodes: int = 10  # episodes per evaluation
            offline_iterations: int = 1_000_000  # number of offline updates
            online_iterations: int = 200_002  # number of online updates

            # ---- checkpoints and pretrained components ----
            # Directory this run saves its checkpoints to.
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'
            # Empty for offline runs; any other train_type loads the final
            # offline AWAC checkpoint stored under the same seed.
            load_model: str = (
                f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
                if 'offline' not in train_type
                else ''
            )
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_19999.pt"  # "" means: don't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # "" means: don't load

            # ---- AWAC / replay buffer ----
            # "incremental" runs use a small buffer, every other run a 10M one.
            buffer_size: int = 20_000 if 'incremental' in train_type else 10_000_000
            normalize_reward: bool = True
            batch_size: int = 256  # shared by all networks
            max_timesteps: int = 1_000_000
            discount: float = 0.99
            awac_lambda: float = 0.1
            deterministic_torch: bool = False
            hidden_dim: int = 256
            learning_rate: float = 3e-4
            vae_lr: float = 5e-4  # learning rate of the VAE
            gamma: float = 0.99  # NOTE(review): same value as `discount`; confirm which one the trainer reads
            num_train_ops: int = 1_000_000
            tau: float = 0.005
            vae_hidden_dim: int = 400  # width of the VAE hidden layers
            latent_dim: int = 1

            # ---- Weights & Biases logging ----
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'pen-human-v1':
        class TrainConfig:
            """AWAC run configuration for ``pen-human-v1``.

            Defined inside the enclosing function, so it reads that function's
            ``random_seed`` (seeds), ``train_type`` (save/load paths and replay
            buffer size) and ``algo`` (W&B project name).
            """

            # ---- experiment / environment ----
            device: str = "cuda"
            env_sim_name: str = "pen"
            level: str = "human"
            env: str = f"{env_sim_name}-{level}-v1"  # gym id of the environment
            seed: int = random_seed  # seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # seed of the evaluation environment
            eval_frequency: int = 1_000  # steps between evaluations
            save_freq: int = 100_000  # steps between checkpoints
            n_test_episodes: int = 10  # episodes per evaluation
            offline_iterations: int = 1_000_000  # number of offline updates
            online_iterations: int = 200_002  # number of online updates

            # ---- checkpoints and pretrained components ----
            # Directory this run saves its checkpoints to.
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'
            # Empty for offline runs; any other train_type loads the final
            # offline AWAC checkpoint stored under the same seed.
            load_model: str = (
                f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
                if 'offline' not in train_type
                else ''
            )
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # "" means: don't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # "" means: don't load

            # ---- AWAC / replay buffer ----
            # "incremental" runs use a smaller buffer, every other run a 10M one.
            buffer_size: int = 500_000 if 'incremental' in train_type else 10_000_000
            normalize_reward: bool = True
            batch_size: int = 256  # shared by all networks
            max_timesteps: int = 1_000_000
            discount: float = 0.99
            awac_lambda: float = 0.1
            deterministic_torch: bool = False
            hidden_dim: int = 256
            learning_rate: float = 3e-4
            gamma: float = 0.99  # NOTE(review): same value as `discount`; confirm which one the trainer reads
            num_train_ops: int = 1_000_000
            tau: float = 0.005
            vae_hidden_dim: int = 400  # width of the VAE hidden layers
            latent_dim: int = 1

            # ---- Weights & Biases logging ----
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'pen-expert-v1':
        class TrainConfig:
            """AWAC run configuration for ``pen-expert-v1``.

            Defined inside the enclosing function, so it reads that function's
            ``random_seed`` (seeds), ``train_type`` (save/load paths and replay
            buffer size) and ``algo`` (W&B project name).
            """

            # ---- experiment / environment ----
            device: str = "cuda"
            env_sim_name: str = "pen"
            level: str = "expert"
            env: str = f"{env_sim_name}-{level}-v1"  # gym id of the environment
            seed: int = random_seed  # seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # seed of the evaluation environment
            eval_frequency: int = 1_000  # steps between evaluations
            save_freq: int = 100_000  # steps between checkpoints
            n_test_episodes: int = 10  # episodes per evaluation
            offline_iterations: int = 1_000_000  # number of offline updates
            online_iterations: int = 200_002  # number of online updates

            # ---- checkpoints and pretrained components ----
            # Directory this run saves its checkpoints to.
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'
            # Empty for offline runs; any other train_type loads the final
            # offline AWAC checkpoint stored under the same seed.
            load_model: str = (
                f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
                if 'offline' not in train_type
                else ''
            )
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # "" means: don't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # "" means: don't load

            # ---- AWAC / replay buffer ----
            # "incremental" runs use a smaller buffer, every other run a 10M one.
            buffer_size: int = 500_000 if 'incremental' in train_type else 10_000_000
            normalize_reward: bool = True
            batch_size: int = 256  # shared by all networks
            max_timesteps: int = 1_000_000
            discount: float = 0.99
            awac_lambda: float = 0.1
            deterministic_torch: bool = False
            hidden_dim: int = 256
            learning_rate: float = 3e-4
            gamma: float = 0.99  # NOTE(review): same value as `discount`; confirm which one the trainer reads
            num_train_ops: int = 1_000_000
            tau: float = 0.005
            vae_hidden_dim: int = 400  # width of the VAE hidden layers
            latent_dim: int = 1

            # ---- Weights & Biases logging ----
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'door-cloned-v1':
        class TrainConfig:
            """AWAC run configuration for ``door-cloned-v1``.

            Defined inside the enclosing function, so it reads that function's
            ``random_seed`` (seeds), ``train_type`` (save/load paths and replay
            buffer size) and ``algo`` (W&B project name).
            """

            # ---- experiment / environment ----
            device: str = "cuda"
            env_sim_name: str = "door"
            level: str = "cloned"
            env: str = f"{env_sim_name}-{level}-v1"  # gym id of the environment
            seed: int = random_seed  # seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # seed of the evaluation environment
            eval_frequency: int = 1_000  # steps between evaluations
            save_freq: int = 100_000  # steps between checkpoints
            n_test_episodes: int = 10  # episodes per evaluation
            offline_iterations: int = 1_000_000  # number of offline updates
            online_iterations: int = 200_002  # number of online updates

            # ---- checkpoints and pretrained components ----
            # Directory this run saves its checkpoints to.
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'
            # Empty for offline runs; any other train_type loads the final
            # offline AWAC checkpoint stored under the same seed.
            load_model: str = (
                f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
                if 'offline' not in train_type
                else ''
            )
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_19999.pt"  # "" means: don't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # "" means: don't load

            # ---- AWAC / replay buffer ----
            # "incremental" runs use a small buffer, every other run a 10M one.
            buffer_size: int = 20_000 if 'incremental' in train_type else 10_000_000
            normalize_reward: bool = True
            batch_size: int = 256  # shared by all networks
            max_timesteps: int = 1_000_000
            discount: float = 0.99
            awac_lambda: float = 0.1
            deterministic_torch: bool = False
            hidden_dim: int = 256
            learning_rate: float = 3e-4
            vae_lr: float = 5e-4  # learning rate of the VAE
            gamma: float = 0.99  # NOTE(review): same value as `discount`; confirm which one the trainer reads
            num_train_ops: int = 1_000_000
            tau: float = 0.005
            vae_hidden_dim: int = 400  # width of the VAE hidden layers
            latent_dim: int = 1

            # ---- Weights & Biases logging ----
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'door-human-v1':
        class TrainConfig:
            """AWAC run configuration for ``door-human-v1``.

            Defined inside the enclosing function, so it reads that function's
            ``random_seed`` (seeds), ``train_type`` (save/load paths and replay
            buffer size) and ``algo`` (W&B project name).
            """

            # ---- experiment / environment ----
            device: str = "cuda"
            env_sim_name: str = "door"
            level: str = "human"
            env: str = f"{env_sim_name}-{level}-v1"  # gym id of the environment
            seed: int = random_seed  # seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # seed of the evaluation environment
            eval_frequency: int = 1_000  # steps between evaluations
            save_freq: int = 100_000  # steps between checkpoints
            n_test_episodes: int = 10  # episodes per evaluation
            offline_iterations: int = 1_000_000  # number of offline updates
            online_iterations: int = 200_002  # number of online updates

            # ---- checkpoints and pretrained components ----
            # Directory this run saves its checkpoints to.
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'
            # Empty for offline runs; any other train_type loads the final
            # offline AWAC checkpoint stored under the same seed.
            load_model: str = (
                f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
                if 'offline' not in train_type
                else ''
            )
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # "" means: don't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # "" means: don't load

            # ---- AWAC / replay buffer ----
            # "incremental" runs use a smaller buffer, every other run a 10M one.
            buffer_size: int = 500_000 if 'incremental' in train_type else 10_000_000
            normalize_reward: bool = True
            batch_size: int = 256  # shared by all networks
            max_timesteps: int = 1_000_000
            discount: float = 0.99
            awac_lambda: float = 0.1
            deterministic_torch: bool = False
            hidden_dim: int = 256
            learning_rate: float = 3e-4
            gamma: float = 0.99  # NOTE(review): same value as `discount`; confirm which one the trainer reads
            num_train_ops: int = 1_000_000
            tau: float = 0.005
            vae_hidden_dim: int = 400  # width of the VAE hidden layers
            latent_dim: int = 1

            # ---- Weights & Biases logging ----
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'door-expert-v1':
        class TrainConfig:
            """AWAC run configuration for ``door-expert-v1``.

            Defined inside the enclosing function, so it reads that function's
            ``random_seed`` (seeds), ``train_type`` (save/load paths and replay
            buffer size) and ``algo`` (W&B project name).
            """

            # ---- experiment / environment ----
            device: str = "cuda"
            env_sim_name: str = "door"
            level: str = "expert"
            env: str = f"{env_sim_name}-{level}-v1"  # gym id of the environment
            seed: int = random_seed  # seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # seed of the evaluation environment
            eval_frequency: int = 1_000  # steps between evaluations
            save_freq: int = 100_000  # steps between checkpoints
            n_test_episodes: int = 10  # episodes per evaluation
            offline_iterations: int = 1_000_000  # number of offline updates
            online_iterations: int = 200_002  # number of online updates

            # ---- checkpoints and pretrained components ----
            # Directory this run saves its checkpoints to.
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'
            # Empty for offline runs; any other train_type loads the final
            # offline AWAC checkpoint stored under the same seed.
            load_model: str = (
                f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
                if 'offline' not in train_type
                else ''
            )
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # "" means: don't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # "" means: don't load

            # ---- AWAC / replay buffer ----
            # "incremental" runs use a smaller buffer, every other run a 10M one.
            buffer_size: int = 500_000 if 'incremental' in train_type else 10_000_000
            normalize_reward: bool = True
            batch_size: int = 256  # shared by all networks
            max_timesteps: int = 1_000_000
            discount: float = 0.99
            awac_lambda: float = 0.1
            deterministic_torch: bool = False
            hidden_dim: int = 256
            learning_rate: float = 3e-4
            gamma: float = 0.99  # NOTE(review): same value as `discount`; confirm which one the trainer reads
            num_train_ops: int = 1_000_000
            tau: float = 0.005
            vae_hidden_dim: int = 400  # width of the VAE hidden layers
            latent_dim: int = 1

            # ---- Weights & Biases logging ----
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'hammer-cloned-v1':
        class TrainConfig:
            """AWAC run configuration for ``hammer-cloned-v1``.

            Defined inside the enclosing function, so it reads that function's
            ``random_seed`` (seeds), ``train_type`` (save/load paths and replay
            buffer size) and ``algo`` (W&B project name).

            NOTE(review): the relocate/pen/door ``*-cloned-v1`` configs in this
            function use the ``checkpoint_19999.pt`` VAE, a 20_000 incremental
            buffer and define ``vae_lr``. This one uses ``checkpoint_399999.pt``,
            a 500_000 buffer and no ``vae_lr``. Confirm this is intentional.
            """

            # ---- experiment / environment ----
            device: str = "cuda"
            env_sim_name: str = "hammer"
            level: str = "cloned"
            env: str = f"{env_sim_name}-{level}-v1"  # gym id of the environment
            seed: int = random_seed  # seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # seed of the evaluation environment
            eval_frequency: int = 1_000  # steps between evaluations
            save_freq: int = 100_000  # steps between checkpoints
            n_test_episodes: int = 10  # episodes per evaluation
            offline_iterations: int = 1_000_000  # number of offline updates
            online_iterations: int = 200_002  # number of online updates

            # ---- checkpoints and pretrained components ----
            # Directory this run saves its checkpoints to.
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'
            # Empty for offline runs; any other train_type loads the final
            # offline AWAC checkpoint stored under the same seed.
            load_model: str = (
                f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
                if 'offline' not in train_type
                else ''
            )
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # "" means: don't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # "" means: don't load

            # ---- AWAC / replay buffer ----
            # "incremental" runs use a smaller buffer, every other run a 10M one.
            buffer_size: int = 500_000 if 'incremental' in train_type else 10_000_000
            normalize_reward: bool = True
            batch_size: int = 256  # shared by all networks
            max_timesteps: int = 1_000_000
            discount: float = 0.99
            awac_lambda: float = 0.1
            deterministic_torch: bool = False
            hidden_dim: int = 256
            learning_rate: float = 3e-4
            gamma: float = 0.99  # NOTE(review): same value as `discount`; confirm which one the trainer reads
            num_train_ops: int = 1_000_000
            tau: float = 0.005
            vae_hidden_dim: int = 400  # width of the VAE hidden layers
            latent_dim: int = 1

            # ---- Weights & Biases logging ----
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'hammer-human-v1':
        class TrainConfig:
            """AWAC run configuration for ``hammer-human-v1``.

            Defined inside the enclosing function, so it reads that function's
            ``random_seed`` (seeds), ``train_type`` (save/load paths and replay
            buffer size) and ``algo`` (W&B project name).
            """

            # ---- experiment / environment ----
            device: str = "cuda"
            env_sim_name: str = "hammer"
            level: str = "human"
            env: str = f"{env_sim_name}-{level}-v1"  # gym id of the environment
            seed: int = random_seed  # seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # seed of the evaluation environment
            eval_frequency: int = 1_000  # steps between evaluations
            save_freq: int = 100_000  # steps between checkpoints
            n_test_episodes: int = 10  # episodes per evaluation
            offline_iterations: int = 1_000_000  # number of offline updates
            online_iterations: int = 200_002  # number of online updates

            # ---- checkpoints and pretrained components ----
            # Directory this run saves its checkpoints to.
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'
            # Empty for offline runs; any other train_type loads the final
            # offline AWAC checkpoint stored under the same seed.
            load_model: str = (
                f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
                if 'offline' not in train_type
                else ''
            )
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # "" means: don't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # "" means: don't load

            # ---- AWAC / replay buffer ----
            # "incremental" runs use a smaller buffer, every other run a 10M one.
            buffer_size: int = 500_000 if 'incremental' in train_type else 10_000_000
            normalize_reward: bool = True
            batch_size: int = 256  # shared by all networks
            max_timesteps: int = 1_000_000
            discount: float = 0.99
            awac_lambda: float = 0.1
            deterministic_torch: bool = False
            hidden_dim: int = 256
            learning_rate: float = 3e-4
            gamma: float = 0.99  # NOTE(review): same value as `discount`; confirm which one the trainer reads
            num_train_ops: int = 1_000_000
            tau: float = 0.005
            vae_hidden_dim: int = 400  # width of the VAE hidden layers
            latent_dim: int = 1

            # ---- Weights & Biases logging ----
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    elif env_name == 'hammer-expert-v1':
        class TrainConfig:
            """AWAC run configuration for ``hammer-expert-v1``.

            Defined inside the enclosing function, so it reads that function's
            ``random_seed`` (seeds), ``train_type`` (save/load paths and replay
            buffer size) and ``algo`` (W&B project name).
            """

            # ---- experiment / environment ----
            device: str = "cuda"
            env_sim_name: str = "hammer"
            level: str = "expert"
            env: str = f"{env_sim_name}-{level}-v1"  # gym id of the environment
            seed: int = random_seed  # seeds Gym, PyTorch and NumPy
            eval_seed: int = random_seed  # seed of the evaluation environment
            eval_frequency: int = 1_000  # steps between evaluations
            save_freq: int = 100_000  # steps between checkpoints
            n_test_episodes: int = 10  # episodes per evaluation
            offline_iterations: int = 1_000_000  # number of offline updates
            online_iterations: int = 200_002  # number of online updates

            # ---- checkpoints and pretrained components ----
            # Directory this run saves its checkpoints to.
            checkpoints_path: Optional[str] = f'checkpoints/AWAC_{train_type}/{env}'
            # Empty for offline runs; any other train_type loads the final
            # offline AWAC checkpoint stored under the same seed.
            load_model: str = (
                f"../offline/checkpoints/AWAC_offline-{seed}/{env}/checkpoint_999999.pt"
                if 'offline' not in train_type
                else ''
            )
            vae_file: str = f"../../sas_vae_checkpoints/{env}-0/checkpoint_399999.pt"  # "" means: don't load
            em_file: str = f"embedding_checkpoints/{env}/checkpoint_199999.pt"  # "" means: don't load

            # ---- AWAC / replay buffer ----
            # "incremental" runs use a smaller buffer, every other run a 10M one.
            buffer_size: int = 500_000 if 'incremental' in train_type else 10_000_000
            normalize_reward: bool = True
            batch_size: int = 256  # shared by all networks
            max_timesteps: int = 1_000_000
            discount: float = 0.99
            awac_lambda: float = 0.1
            deterministic_torch: bool = False
            hidden_dim: int = 256
            learning_rate: float = 3e-4
            gamma: float = 0.99  # NOTE(review): same value as `discount`; confirm which one the trainer reads
            num_train_ops: int = 1_000_000
            tau: float = 0.005
            vae_hidden_dim: int = 400  # width of the VAE hidden layers
            latent_dim: int = 1

            # ---- Weights & Biases logging ----
            project: str = f"{algo}"
            group: str = f"{env_sim_name}-{level}"
            name: str = f"{train_type}"
    else:
        raise NotImplementedError

    return TrainConfig
