import socket

from diffuser.utils import watch

#------------------------ base ------------------------#

## experiment names are made automatically by labelling folders with
## the args listed here; each entry maps an arg name to the string paired
## with it (presumably the short tag used in the folder name -- confirm
## against `watch`)

args_to_watch = list({
    'prefix': '',
    'horizon': 'H',
    'n_diffusion_steps': 'T',
    ## value kwargs
    'discount': 'd',
}.items())

## value stored under the 'logbase' key of the config below
logbase = 'logs'

## default parameters, keyed by experiment type; 'mo_diffusion' is the
## only entry defined in this section
base = {
    'mo_diffusion': dict(
        ## model
        model='models.TemporalUnet',
        diffusion='models.MOGaussianDiffusion',
        # horizon=32,
        # n_diffusion_steps=20,
        action_weight=10,
        loss_weights=None,
        loss_discount=0.99,
        predict_epsilon=False,
        dim_mults=(1, 4, 8),
        dim=64,
        # dim_mults=(1, 2, 4),  # for Humanoid env
        # dim=256,
        attention=False,
        renderer='utils.MuJoCoRenderer',

        ## dataset (overridden by the D4MORL config)
        # loader='datasets.TrajectoryDataset',
        normalizer='GaussianNormalizer',
        preprocess_fns=[],
        clip_denoised=False,
        use_padding=True,
        max_path_length=1000,

        ## serialization
        logbase=logbase,
        prefix='diffusion/defaults',
        ## experiment name is derived from the watched args
        exp_name=watch(args_to_watch),

        ## training
        loss_type='l1',
        n_steps_per_epoch=10000,
        n_train_steps=1e6,      # overridden by the D4MORL config
        # batch_size=32,        # overridden by the D4MORL config
        # learning_rate=2e-4,   # overridden by the D4MORL config
        gradient_accumulate_every=2,
        ema_decay=0.995,
        save_freq=1e8,          # overridden by the D4MORL config
        sample_freq=20000,
        n_saves=5,
        save_parallel=False,
        n_reference=8,
        bucket=None,
        device='cuda',
        seed=None,
    ),
}


#------------------------ overrides ------------------------#
### please refer to the parameter list for default parameters for every env
