# Default hyperparameters for DMBP.
# update_DMBP_config() (below) applies per-environment and command-line
# overrides by calling .update() on the dict it is given, i.e. in place --
# pass a copy if these module-level defaults must stay unchanged.
DMBP_config = {
    "algo": "ql", # one of: "bc", "ql"
    "train_per_epi" : 5000,
    "max_timestep": int(3e5),
    "start_testing": int(0),
    "checkpoint_start": int(2e5),
    "checkpoint_every": int(1e4),

    "gamma": 0.99,
    "tau": 0.005,
    "eta": 1.0,
    "lr_decay": True,
    "max_q_backup": True,
    "step_start_ema": 0, #1000,
    "ema_decay": 0.995,
    "update_ema_every": 5,
    "T": 100,  # presumably the number of diffusion timesteps -- confirm

    "beta_schedule": 'vp',
    "beta_training_mode": 'partial',
    'loss_training_mode': 'no_act2',
    "predict_epsilon": True,
    "data_usage": 1.0,
    # NOTE(review): meaning of 'ms' / 'gn' is not evident from this file
    # ('gn' is possibly a gradient-norm clip value) -- confirm and document.
    'ms': 'offline',
    'gn': 10.0,

    # Long Term Buffer Parameter Definition
    "condition_length": 5,
    "T-scheme": "same",  # "random" or "same"

    "non_markovian_step": 6,

    # Attention Hyperparameters
    # (raised to 5 layers / 256 / 256 for pen/hammer/door/relocate and
    # kitchen environments in update_DMBP_config())
    "attn_hidden_layer": 2,
    "attn_hidden_dim": 128,
    "attn_embed_dim":  128,

    "lr": 1e-4,

    "batch_size": 512,
    "hidden_size": 128,  # 512 for pen/hammer/door/relocate and kitchen
    "embed_dim":   128,  # 512 for pen/hammer/door/relocate and kitchen
    "reward_tune": "no",

    "attack_ratio": 0.3,
    "attack_scale": 1.0,

    "need_training": True,
    "need_eval": True,

    # Defaults only: pen/hammer/door/relocate and kitchen use 10 loops x 1
    # step, and truthy args.denoisng_loops / args.denoisng_steps override
    # both in update_DMBP_config().
    "detect_denoise_loops": 5,
    "detect_denoise_steps": 20,

    "total_epoch": 10,  # 40 for pen/hammer/door/relocate
    "steps_per_epoch": 1000,

    "tn_v_T": 0.3,
    "exp_name": "default",

    "repeat_times": 10,

    # Overridden in update_DMBP_config(): 30000 for hopper/walker2d/
    # halfcheetah, 5000 for pen/hammer/door/relocate and kitchen.
    "stack_length": 50000,

    # NOTE(review): the kitchen overrides in update_DMBP_config() add
    # "ambient_epochs" / "naive_epochs", which have no default here.
    # "start_ambient": 1000
}

def update_DMBP_config(env_name, config, args=None):
    """Apply environment-specific and command-line overrides to ``config``.

    The dict is updated in place and also returned, so callers that want to
    keep a module-level defaults dict untouched should pass a copy.

    Args:
        env_name: Environment / dataset name. Each environment family is
            detected with a case-insensitive substring test.
        config: Hyperparameter dict to update (mutated in place).
        args: Optional namespace of parsed command-line options. When given,
            it must provide ``load_model_path``, ``dataset`` and
            ``dataset_path``. ``denoisng_loops`` / ``denoisng_steps`` are
            optional and are applied only when truthy. When ``args`` is
            ``None``, only the environment overrides are applied.

    Returns:
        The same ``config`` object, after all overrides.

    Raises:
        AttributeError: if ``args`` is given but lacks one of the required
            attributes listed above.
    """
    name = env_name.lower()

    if any(env in name for env in ('hopper', 'walker2d', 'halfcheetah')):
        config.update({
            "stack_length": 30000,
        })

    # Presumably the Adroit hand tasks.
    # NOTE(review): plain substring matching means 'pen' also matches any
    # name containing "open"; tighten the test if such environments are used.
    if any(item in name for item in ('pen', 'hammer', 'door', 'relocate')):
        config.update({
            "detect_denoise_loops": 10,
            "detect_denoise_steps": 1,
            "attn_hidden_layer": 5,
            "attn_hidden_dim": 256,
            "attn_embed_dim":  256,
            "hidden_size": 512,
            "embed_dim":   512,
            "total_epoch": 40,
            "stack_length": 5000,
        })

    if 'kitchen' in name:
        config.update({
            "detect_denoise_loops": 10,
            "detect_denoise_steps": 1,
            "attn_hidden_layer": 5,    # 2
            "attn_hidden_dim": 256,    # 128
            "attn_embed_dim":  256,    # 128
            "hidden_size": 512,        # 128
            "embed_dim":   512,        # 128
            # NOTE(review): these two keys have no default in the base config.
            "ambient_epochs": 10,         # 20
            "naive_epochs":   30,         # 20
            "stack_length": 5000,      # 5000
        })
        # Disabled variant for kitchen-complete datasets. If re-enabled it
        # must guard against ``args is None`` before reading args.dataset.
        # if 'complete' in args.dataset:
        #     config.update({
        #         "detect_denoise_loops": 10,
        #         "detect_denoise_steps": 1,
        #         "attn_hidden_layer": 2,
        #         "attn_hidden_dim": 128,
        #         "attn_embed_dim":  128,
        #         "hidden_size": 128,
        #         "embed_dim":   128,
        #         "total_epoch": 100,
        #         "stack_length": 5000,
        #     })

    if args is None:
        # The signature allows omitting args; previously this path crashed
        # on args.load_model_path. Only environment overrides apply here.
        return config

    config.update({
        "load_model_path": args.load_model_path,
        "dataset": args.dataset,
        "dataset_path": args.dataset_path,
    })

    # Optional command-line overrides. The attribute names are read exactly
    # as spelled here ("denoisng"), so they must match the caller's argument
    # definitions. Falsy values (None, 0) keep the current setting.
    loops = getattr(args, "denoisng_loops", None)
    if loops:
        config["detect_denoise_loops"] = loops

    steps = getattr(args, "denoisng_steps", None)
    if steps:
        config["detect_denoise_steps"] = steps

    return config