from mage.utils import watch

#------------------------ base ------------------------#

# Value stored under the 'logbase' key of both `base` sections.
logbase = "logs/"
# Value stored under train 'exp_name' and plan 'gpt_loadpath' in `base`.
gpt_expname = "vae/vq"

# Pairs passed to `watch(...)` to produce the plan 'exp_name'.
# Presumably (argument name, short label used in the generated name) — TODO confirm
# against `mage.utils.watch`.
args_to_watch = [
    ("prefix", ""),
    ("plan_freq", "freq"),
    ("horizon", "H"),
    ("beam_width", "beam"),
]

# Default configuration, split into a "train" section and a "plan" section.
# Presumably the per-dataset dicts further down override these keys by name —
# TODO confirm against the config loader, which is not part of this file.
base = {
    "train": {
        "model": "VQTransformer",
        "tag": "experiment",
        "state_conditional": True,
        "N": 100,
        "discount": 0.99,
        "n_layer": 8,
        "n_head": 4,

        "n_epochs_ref": 50,
        "n_saves": 3,
        "logbase": logbase,
        # NOTE(review): train is pinned to "cuda:2" while plan uses plain "cuda";
        # confirm the explicit device index is intended.
        "device": "cuda:2",

        "K": 512,
        "n_embd": 128,
        "trajectory_embd": 512,
        "batch_size": 256,
        "learning_rate": 1e-4,
        "lr_decay": False,
        "seed": 45,

        "embd_pdrop": 0.1,
        "resid_pdrop": 0.1,
        "attn_pdrop": 0.1,

        "step": 1,
        "latent_step": 3,
        "subsampled_sequence_length": 25,
        "history_horizon": 0,
        "horizon": 24,
        "return_scale": 1200,
        "termination_penalty": -100,
        "exp_name": gpt_expname,

        # Per-component weights.
        "position_weight": 1,
        "action_weight": 5,
        "reward_weight": 1,
        "value_weight": 1,

        "current_obs_weight": 1,
        "current_action_weight": 1,
        "next_obs_weight": 1,
        "next_action_weight": 0.25,

        "first_action_weight": 0,
        "sum_reward_weight": 0,
        "last_value_weight": 0,
        "suffix": "",

        "normalize": True,
        "normalize_reward": True,
        "max_path_length": 1000,
        "bottleneck": "pooling",
        "masking": "uniform",
        "disable_goal": False,
        "residual": True,
        "ma_update": True,

        "use_action": False,
        "rtg": 1.5,
    },

    "plan": {
        "n_epochs_ref": 50,
        "discrete": False,
        "logbase": logbase,
        # Same string as the train "exp_name".
        "gpt_loadpath": gpt_expname,
        "gpt_epoch": "latest",
        "device": "cuda",
        "renderer": "Renderer",
        "suffix": "0",
        "return_scale": 1200,
        "n_layer": 8,
        "K": 512,
        "n_embd": 128,
        "trajectory_embd": 512,

        "plan_freq": 1,
        "horizon": 32,

        "rounds": 2,
        "nb_samples": 4096,

        "beam_width": 64,
        "n_expand": 4,

        "prob_threshold": 0.05,
        "prob_weight": 5e2,

        "vis_freq": 200,
        # Built by `watch` from the (name, label) pairs in `args_to_watch`.
        "exp_name": watch(args_to_watch),
        "verbose": True,
        "uniform": False,

        "tag": "experiment",
        "seed": 42,
        "normalize": True,
        "use_action": False,

        # Planner
        "test_planner": "beam_prior",
        "rtg": 1.5,
    },
}

#------------------------ hammer / relocate / door / pen / antmaze ------------------------#

# Settings for the hammer datasets; every name below is bound to the same
# dict object (not copies).
# `human_expert_v0` appears to be a typo for `hammer_expert_v0` (the relocate,
# door and pen groups all define cloned/human/expert names). The correct name
# is added here and the old one is kept as an alias for backward compatibility.
hammer_cloned_v0 = hammer_human_v0 = hammer_expert_v0 = human_expert_v0 = {
    'train': {
        "termination_penalty": None,
        "max_path_length": 200,
        'n_epochs_ref': 10,
        'subsampled_sequence_length': 25,
    },
    'plan': {
        'horizon': 24,
    }
}

# Settings for the relocate datasets. The dict is created once and the other
# two names are aliases of that same object.
relocate_cloned_v0 = {
    "train": {
        "termination_penalty": None,
        "max_path_length": 200,
        "n_epochs_ref": 10,
        "subsampled_sequence_length": 25,
    },
    "plan": {
        "horizon": 24,
    },
}
relocate_human_v0 = relocate_cloned_v0
relocate_expert_v0 = relocate_cloned_v0

# Settings for the door datasets. The dict is created once and the other two
# names are aliases of that same object.
door_cloned_v0 = {
    "train": {
        "termination_penalty": None,
        "max_path_length": 200,
        "n_epochs_ref": 10,
        "subsampled_sequence_length": 25,
    },
    "plan": {
        "horizon": 24,
    },
}
door_human_v0 = door_cloned_v0
door_expert_v0 = door_cloned_v0

# Settings for the pen datasets. The dict is created once and the other two
# names are aliases of that same object.
pen_cloned_v0 = {
    "train": {
        "termination_penalty": None,
        "max_path_length": 100,
        "n_epochs_ref": 10,
        "subsampled_sequence_length": 25,
    },
    "plan": {
        # Same value as the `base` plan section's 'prob_weight'.
        "prob_weight": 5e2,
        "horizon": 24,
    },
}
pen_expert_v0 = pen_cloned_v0
pen_human_v0 = pen_cloned_v0

# Settings for the antmaze datasets. The dict is created once; the remaining
# five names are aliases of that same object.
antmaze_large_diverse_v0 = {
    "train": {
        "disable_goal": False,
        "termination_penalty": None,
        "max_path_length": 1001,
        "normalize": False,
        "normalize_reward": False,
        "lr_decay": False,
        "K": 4096,
        "discount": 0.998,
        "subsampled_sequence_length": 16,
        "learning_rate": 2e-4,
    },
    "plan": {
        "iql_value": False,
        "horizon": 15,
        "vis_freq": 200,
        "renderer": "AntMazeRenderer",
    },
}
antmaze_large_play_v0 = antmaze_large_diverse_v0
antmaze_medium_diverse_v0 = antmaze_large_diverse_v0
antmaze_medium_play_v0 = antmaze_large_diverse_v0
antmaze_umaze_v0 = antmaze_large_diverse_v0
antmaze_umaze_diverse_v0 = antmaze_large_diverse_v0