from trajectory.utils import watch

#------------------------ base ------------------------#

## log root shared by the 'train' and 'plan' sections of `base`
logbase = "~/logs/"
## name of the training experiment; also used as the planner's load path
gpt_expname = "vae/vq"

## planning experiment names are built automatically: folders are
## labelled using the (argument, label) pairs listed here
args_to_watch = [
    ("prefix", ""),
    ("plan_freq", "freq"),
    ("horizon", "H"),
    ("beam_width", "beam"),
]

## default hyperparameters, split into a 'train' and a 'plan' section;
## the per-dataset dictionaries later in this file list only overrides
base = {
    "train": {
        "model": "VQTransformer",
        "tag": "experiment",
        "discrete": False,
        "dimension_as_token": False,
        "N": 100,
        "discount": 0.99,
        "n_layer": 4,
        "n_head": 4,

        ## reference epoch count for a dataset with 1M samples;
        ## n_epochs = 1M / dataset_size * n_epochs_ref
        "n_epochs_ref": 20,
        "n_saves": 3,
        "logbase": logbase,
        "device": "cuda",

        "K": 512,
        "latent_step": 3,
        "n_embd": 128,
        "trajectory_embd": 512,
        "batch_size": 512,
        "learning_rate": 2e-4,
        "lr_decay": False,
        "seed": 42,

        "embd_pdrop": 0.1,
        "resid_pdrop": 0.1,
        "attn_pdrop": 0.1,

        "step": 1,
        "subsampled_sequence_length": 25,
        "termination_penalty": -100,
        "exp_name": gpt_expname,

        "discretizer": "QuantileDiscretizer",
        "position_weight": 1,
        "action_weight": 5,
        "reward_weight": 1,
        "value_weight": 1,

        "first_action_weight": 0,
        "sum_reward_weight": 0,
        "last_value_weight": 0,

        "normalize": True,
        "normalize_reward": True,
        "max_path_length": 1000,
        "bottleneck": "pooling",
        "disable_goal": False,
        "residual": True,
        "ma_update": True,
    },
    "plan": {
        "discrete": False,
        "logbase": logbase,
        ## the planner loads the model saved under the training experiment name
        "gpt_loadpath": gpt_expname,
        "gpt_epoch": "latest",
        "device": "cuda",
        "renderer": "Renderer",
        "iql_value": False,

        "plan_freq": 1,
        "horizon": 21,

        "rounds": 2,
        "nb_samples": 4096,

        "beam_width": 128,
        "n_expand": 2,

        "k_obs": 1,
        "k_act": None,
        "cdf_obs": None,
        "cdf_act": 0.6,
        "percentile": "mean",

        "max_context_transitions": 5,
        "prefix_context": True,
        "prob_threshold": 0.05,
        "prob_weight": 5e2,

        "vis_freq": 200,
        ## experiment name generated from the arguments in `args_to_watch`
        "exp_name": watch(args_to_watch),
        "prefix": "plans/defaults/",
        "suffix": "0",
        "verbose": True,
        "uniform": False,

        ## planner
        "test_planner": "sample_prior",
    },
}

#------------------------ locomotion ------------------------#

## for all halfcheetah environments, you can reduce the planning horizon and beam width without
## affecting performance. good for speed and sanity.

## medium and medium-replay are bound to the same dictionary object;
## its 'plan' section is currently empty
halfcheetah_medium_replay_v2 = {"plan": {}}
halfcheetah_medium_v2 = halfcheetah_medium_replay_v2

## medium-expert gets its own dictionary, also with an empty 'plan' section
halfcheetah_medium_expert_v2 = {"plan": {}}

## if you leave the dictionary empty, it will use the base parameters
## one shared empty dictionary bound to all three names
walker2d_medium_v2 = {}
hopper_medium_v2 = walker2d_medium_v2
hopper_medium_expert_v2 = walker2d_medium_v2

## hopper and walker2d are a little more sensitive to planning hyperparameters;
## proceed with caution when reducing the horizon or increasing the planning frequency

## medium-replay: only 'n_epochs_ref' is overridden (30 instead of the base 20)
hopper_medium_replay_v2 = {
    "train": {"n_epochs_ref": 30},
}

## no overrides
walker2d_medium_expert_v2 = {}

## medium-replay: only 'n_epochs_ref' is overridden (30 instead of the base 20)
walker2d_medium_replay_v2 = {
    "train": {"n_epochs_ref": 30},
}

## the three ant names share one dictionary whose 'train' section is empty
ant_random_v2 = {"train": {}}
ant_medium_replay_v2 = ant_random_v2
ant_medium_v2 = ant_random_v2

## hammer datasets (cloned / human / expert) share one override dictionary:
## 'termination_penalty' is set to None and 'max_path_length' to 200
hammer_cloned_v0 = hammer_human_v0 = hammer_expert_v0 = {
    'train': {
        "termination_penalty": None,
        "max_path_length": 200,
    },
}

## `human_expert_v0` was the previous (misspelled) name of the hammer-expert
## entry; kept as an alias so that existing references keep working
human_expert_v0 = hammer_expert_v0

## relocate datasets share one override dictionary:
## 'termination_penalty' is set to None and 'max_path_length' to 200
relocate_expert_v0 = {
    "train": {
        "termination_penalty": None,
        "max_path_length": 200,
    },
}
relocate_human_v0 = relocate_expert_v0
relocate_cloned_v0 = relocate_expert_v0

## door datasets share one override dictionary:
## 'termination_penalty' is set to None and 'max_path_length' to 200
door_expert_v0 = {
    "train": {
        "termination_penalty": None,
        "max_path_length": 200,
    },
}
door_human_v0 = door_expert_v0
door_cloned_v0 = door_expert_v0

## pen datasets share one override dictionary covering both the
## 'train' and the 'plan' section
pen_human_v0 = {
    "train": {
        "termination_penalty": None,
        "max_path_length": 100,
        "n_epochs_ref": 10,
        "subsampled_sequence_length": 25,
    },
    "plan": {
        "prob_weight": 5e2,
        "horizon": 30,
    },
}
pen_expert_v0 = pen_human_v0
pen_cloned_v0 = pen_human_v0

## umaze, medium and large antmaze names are all bound to this one
## override dictionary
antmaze_umaze_v0 = {
    "train": {
        "disable_goal": False,
        "termination_penalty": None,
        "max_path_length": 1001,
        "normalize": False,
        "normalize_reward": False,
        "lr_decay": False,
        "n_embd": 128,
        "n_layer": 4,
        "K": 8192,
        "latent_step": 3,
        "discount": 0.998,
        "subsampled_sequence_length": 16,
    },
    "plan": {
        "iql_value": False,
        "horizon": 21,
        "vis_freq": 200,
        "renderer": "AntMazeRenderer",
    },
}
antmaze_medium_play_v0 = antmaze_umaze_v0
antmaze_medium_diverse_v0 = antmaze_umaze_v0
antmaze_large_play_v0 = antmaze_umaze_v0
antmaze_large_diverse_v0 = antmaze_umaze_v0


## both antmaze-ultra names are bound to this one override dictionary
antmaze_ultra_play_v0 = {
    "train": {
        "disable_goal": False,
        "termination_penalty": None,
        "max_path_length": 1001,
        "normalize": False,
        "normalize_reward": False,
        "n_epochs_ref": 20,
        "lr_decay": False,
        "K": 8192,
        "n_embd": 128,
        "n_layer": 4,
        "latent_step": 3,
        "discount": 0.998,
        "batch_size": 512,
        "subsampled_sequence_length": 16,
    },
    "plan": {
        "iql_value": False,
        "horizon": 21,
        "vis_freq": 200,
        "renderer": "AntMazeRenderer",
    },
}
antmaze_ultra_diverse_v0 = antmaze_ultra_play_v0
