# Default hyperparameter set: four nested sub-dicts ("posterior", "prior",
# "policy", "kl") followed by flat top-level settings.
# NOTE(review): the meaning of each key is defined by the consuming code,
# which is not visible from this file; the section labels below are inferred
# from the key names and should be confirmed against the consumers.
policy_embedding_config = {
    # Settings under "posterior"; the only section carrying "cat_size" and
    # "class_size".
    "posterior": {
        "n_layer": 4,
        "n_head": 8,
        "n_embd": 512,
        "dropout": 0,
        "bias": True,
        "cat_size": 16,
        "class_size": 16,
    },

    # Settings under "prior"; smaller sizes than "posterior".
    "prior": {
        "n_layer": 2,
        "n_head": 4,
        "n_embd": 256,
        "dropout": 0,
        "bias": True,
    },

    # Settings under "policy"; same sizes as "posterior" plus actor options.
    "policy": {
        "n_layer": 4,
        "n_head": 8,
        "n_embd": 512,
        "dropout": 0,
        "bias": True,
        "actor_type": "stochastic",
        "max_action": 1,
        "log_std_bounds": [-2, 5],
        "z_condition": True,
    },

    # Settings under "kl" (presumably KL-loss weighting; verify in the
    # training code).
    "kl": {
        "kl_loss_coef": 1,
        "kl_balance_scale": 0.8,
        "use_free_nats": True,
        "free_nats": 1,
    },

    # Input-related settings (presumably image/sequence geometry).
    "img_size": 256,
    "patch_size": 32,
    "seq_len": 8,
    "use_fourier": False,

    # Presumably optimizer settings.
    # NOTE(review): both betas are 0.9; confirm this is intentional if they
    # feed an Adam-style optimizer.
    "weight_decay": 1e-4,
    "warmup_steps": 5000,
    "lr": 3e-5,
    "betas": (0.9, 0.9),

    # Presumably training-loop settings.
    "n_train_steps": 100000,
    "batch_size": 32,
    "gradient_accumulate_every": 1,
}


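# NOTE: disabled (commented-out) config, kept for reference only. Its values
# are currently identical to the active policy_embedding_config.
# policy_embedding_config_256 = {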
#     'posterior': {
#         'n_layer': 4,
#         'n_head': 8,
#         'n_embd': 512,
#         'dropout': 0,
#         'bias': True,
#         'cat_size': 16,
#         'class_size': 16
#     },

#     'prior': {
#         'n_layer': 2,
#         'n_head': 4,
#         'n_embd': 256,
#         'dropout': 0,
#         'bias': True
#     },

#     'policy': {
#         'n_layer': 4,
#         'n_head': 8,
#         'n_embd': 512,
#         'dropout': 0,
#         'bias': True,
#         'actor_type': 'stochastic',
#         'max_action': 1,
#         'log_std_bounds': [-2, 5],
#         'z_condition': True,
#     },

#     'kl': {
#         'kl_loss_coef': 1,
#         'kl_balance_scale': 0.8,
#         'use_free_nats': True,
#         'free_nats': 1
#     },

#     'img_size': 256,
#     'patch_size': 32,
#     'seq_len': 8,
#     'use_fourier': False,

#     'weight_decay': 1e-4,
#     'warmup_steps': 5000,
#     'lr': 3e-5,
#     'betas': (0.9, 0.9),

#     'n_train_steps': 100000,
#     'batch_size': 32,
#     'gradient_accumulate_every': 1,
# }


# Larger hyperparameter set with the same layout as policy_embedding_config.
# The only differences are the n_layer / n_head / n_embd sizes of the
# "posterior", "prior" and "policy" sections and a longer "n_train_steps".
# NOTE(review): the meaning of each key is defined by the consuming code,
# which is not visible from this file; the section labels below are inferred
# from the key names and should be confirmed against the consumers.
oxe_policy_embedding_config = {
    # Settings under "posterior"; the only section carrying "cat_size" and
    # "class_size".
    "posterior": {
        "n_layer": 6,
        "n_head": 12,
        "n_embd": 768,
        "dropout": 0,
        "bias": True,
        "cat_size": 16,
        "class_size": 16,
    },

    # Settings under "prior"; smaller sizes than "posterior".
    "prior": {
        "n_layer": 4,
        "n_head": 8,
        "n_embd": 512,
        "dropout": 0,
        "bias": True,
    },

    # Settings under "policy"; same sizes as "posterior" plus actor options.
    "policy": {
        "n_layer": 6,
        "n_head": 12,
        "n_embd": 768,
        "dropout": 0,
        "bias": True,
        "actor_type": "stochastic",
        "max_action": 1,
        "log_std_bounds": [-2, 5],
        "z_condition": True,
    },

    # Settings under "kl" (presumably KL-loss weighting; verify in the
    # training code).
    "kl": {
        "kl_loss_coef": 1,
        "kl_balance_scale": 0.8,
        "use_free_nats": True,
        "free_nats": 1,
    },

    # Input-related settings (presumably image/sequence geometry).
    "img_size": 256,
    "patch_size": 32,
    "seq_len": 8,
    "use_fourier": False,

    # Presumably optimizer settings.
    # NOTE(review): both betas are 0.9; confirm this is intentional if they
    # feed an Adam-style optimizer.
    "weight_decay": 1e-4,
    "warmup_steps": 5000,
    "lr": 3e-5,
    "betas": (0.9, 0.9),

    # Presumably training-loop settings.
    "n_train_steps": 300000,
    "batch_size": 32,
    "gradient_accumulate_every": 1,
}

