import wandb
wandb.login()

sweep_config = {
    'method': 'grid'
    }
metric = {
    'name': 'val_loss',
    'goal': 'minimize'
    }

sweep_config['metric'] = metric
parameters_dict = {
    'num_layers':{
        'value': 2
    },
    'layer_dims':{
        'value': 250
    },
    'batch_size': {
        'values':  [100, 500, 1000]
    },
    'lr': {
        'values': [1e-2, 1e-3, 1e-4, 1e-5]
    },
    'nw':{
        'value': 0
    },
    'kl':{
        'values': [1e-4, 1e-3, 1e-2]
    },
    'latent': {
        'value': 2
    },
    'time':{
        'value': True
    },
    'emg': {
        'value': False
    },
    'save_best':{
        'value': False
    },
    'seed': {
         'values': [0, 1, 2, 3]
      }
}

sweep_config['parameters'] = parameters_dict
parameters_dict.update({
    'epochs': {
        'value': 500},
    'data_artifact': {
        'value': 'data/synthetic_spiral'
    },
    })
sweep_id = wandb.sweep(sweep_config, project="neighbor_vae")
