# Default DMBP hyperparameters.
#
# update_DMBP_config() (below) mutates a config dict like this one:
# - it forces "beta_training_mode" to "partial" and "condition_length" to 5;
# - it applies per-environment overrides (denoising loops/steps, epochs, network sizes);
# - it copies CLI values such as "dataset" and "attack_element" from its `args` argument.
DMBP_config = {
    "algo": "ql", # one of "bc", "ql"
    "train_per_epi" : 5000,
    "max_timestep": int(3e5),
    "start_testing": int(0),
    "checkpoint_start": int(2e5),
    "checkpoint_every": int(1e4),

    "gamma": 0.99,
    "tau": 0.005,
    "eta": 1.0,
    "lr_decay": True,
    "max_q_backup": True,
    "step_start_ema": 0, # alternative value: 1000
    "ema_decay": 0.995,
    "update_ema_every": 5,
    "T": 100,

    "beta_schedule": 'vp', # alternative value: 'self-defined2'
    "beta_training_mode": 'all',  # update_DMBP_config() overwrites this with "partial"
    'loss_training_mode': 'no_act2',    # 'normal' or 'noise' or 'no_act' or 'no_act2'
    "predict_epsilon": True,
    "data_usage": 1.0,
    'ms': 'offline',
    'gn': 10.0,

    # Long-term buffer parameters
    "condition_length": 5,  # update_DMBP_config() also (re)sets this to 5
    "T-scheme": "same",  # "random" or "same"

    "non_markovian_step": 6,

    # Attention hyperparameters (overridden per environment for some tasks)
    "attn_hidden_layer": 2,
    "attn_hidden_dim": 128,
    "attn_embed_dim":  128,

    "lr": 1e-4,

    "batch_size": 512,
    "hidden_size": 128,  # overridden per environment for some tasks
    "embed_dim":   128,  # overridden per environment for some tasks
    "reward_tune": "no",

    "attack_ratio": 0.3,
    "attack_scale": 1.0,

    "need_training": True,
    "need_eval": True,

    # Both may be overridden per environment and from CLI args in update_DMBP_config().
    "detect_denoise_loops": 1,
    "detect_denoise_steps": 20,

    "total_epoch": 10,
    "steps_per_epoch": 1000,

    "tn_v_T": 0.3,
    "exp_name": "default",

    "repeat_times": 10,

    "stack_length": 50000,
    
    "attack_element": "transition",  # replaced from CLI args in update_DMBP_config()
}

def update_DMBP_config(env_name, config, args=None):
    """Apply environment- and CLI-specific overrides to a DMBP config dict.

    The dict is modified in place and also returned.

    Args:
        env_name: Environment name. It is lower-cased and matched by substring
            against the task families handled below.
        config: Config dict to update (e.g. ``DMBP_config``); mutated in place.
        args: Optional parsed command-line namespace. When given:

            - ``load_model_path``, ``dataset``, ``dataset_path`` and
              ``attack_element`` are copied into the config;
            - ``dataset`` selects the kitchen "mixed"/"partial" network sizes;
            - ``dn_loops`` / ``dn_steps`` override ``detect_denoise_loops`` /
              ``detect_denoise_steps`` unless they are ``None``.

            When ``args`` is ``None``, only the environment-based overrides
            are applied.

    Returns:
        The same ``config`` object, after updating.
    """
    env = env_name.lower()

    # NOTE(review): the original author was unsure about this; it
    # unconditionally replaces the default "beta_training_mode".
    config["beta_training_mode"] = "partial"
    config["condition_length"] = 5

    # halfcheetah / hopper / walker2d share identical overrides.
    if any(name in env for name in ("halfcheetah", "hopper", "walker2d")):
        config.update({
            "detect_denoise_loops": 5,
            "detect_denoise_steps": 20,
            "lr": 1e-4,
            "total_epoch": 9,
            "steps_per_epoch": 1000,
        })

    # Plain substring match, so e.g. "pen-human-v1" selects this branch.
    if any(name in env for name in ("pen", "hammer", "door", "relocate")):
        config.update({
            "detect_denoise_loops": 10,  # other tried values: 5, 50
            "detect_denoise_steps": 1,   # other tried value: 2
            "attn_hidden_layer": 5,      # default: 2
            "total_epoch": 10,           # other tried value: 20
            "stack_length": 1000,        # other tried value: 5000
            "steps_per_epoch": 1000,
            "attn_hidden_dim": 256,      # default: 128
            "attn_embed_dim": 256,       # default: 128
            "hidden_size": 512,          # default: 128
            "embed_dim": 512,            # default: 128
            "lr": 1e-4,
            "batch_size": 512,
        })

    if "kitchen" in env:
        config.update({
            "detect_denoise_loops": 2,
            "detect_denoise_steps": 25,
            "attn_hidden_layer": 2,
            "attn_hidden_dim": 256,
            "attn_embed_dim": 256,
            "hidden_size": 256,
            "embed_dim": 256,
            "total_epoch": 40,
            "stack_length": 1000,
            "steps_per_epoch": 1000,
            "lr": 1e-4,
        })
        # Dataset-specific tweaks; only the keys that differ from the kitchen
        # defaults above need to be touched (all were just inserted, so key
        # order is unaffected).
        dataset = None if args is None else args.dataset
        if dataset == "mixed":
            config.update({"hidden_size": 512, "embed_dim": 512})
        elif dataset == "partial":
            config.update({
                "attn_hidden_dim": 128,
                "attn_embed_dim": 128,
                "hidden_size": 128,
                "embed_dim": 128,
            })

    if args is not None:
        config.update({
            "load_model_path": args.load_model_path,
            "dataset": args.dataset,
            "dataset_path": args.dataset_path,
            "attack_element": args.attack_element,
        })
        # CLI denoising settings take precedence, but an unset (None) value
        # no longer clobbers the per-environment values chosen above.
        if args.dn_loops is not None:
            config["detect_denoise_loops"] = args.dn_loops
        if args.dn_steps is not None:
            config["detect_denoise_steps"] = args.dn_steps

    return config
