import socket

from diffuser.utils import watch

#------------------------ base ------------------------#

## automatically make experiment names for planning
## by labelling folders with these args

## (config key, folder label) pairs used to name diffusion training runs.
## The plan configs' diffusion_loadpath entries expect training folders named
## like '<prefix>H{horizon}_T{n_diffusion_steps}'.
diffusion_args_to_watch = [
    # key                 label
    ('prefix',            ''),
    ('horizon',           'H'),
    ('n_diffusion_steps', 'T'),
]


## (config key, folder label) pairs used to name planning runs.
## NOTE(review): 'value_horizon' and 'discount' are not set by any plan config
## shown here; presumably `watch` skips keys missing from a run's args --
## confirm in diffuser.utils.
plan_args_to_watch = [
    # key                 label
    ('prefix',            ''),
    # model / planning settings
    ('horizon',           'H'),
    ('n_diffusion_steps', 'T'),
    ('value_horizon',     'V'),
    ('discount',          'd'),
    ('normalizer',        ''),
    ('batch_size',        'b'),
    # conditioning
    ('conditional',       'cond'),
]
## The GaussianDiffusion and CFM batch-size sweeps differ only in the
## generative model class, the folder tag and a batch-size-dependent training
## schedule, so every sweep config is generated from one template below.

## batch_size: (n_steps_per_epoch, n_train_steps, learning_rate)
## n_steps_per_epoch * batch_size (= 3.2e5) and n_train_steps * batch_size
## (= 6.4e7) are constant across sizes, so every run processes the same number
## of samples per epoch and in total; the learning rate grows with batch size.
_batch_schedule = {
    16:  (20000, 4e6,    1e-4),
    32:  (10000, 2e6,    1.5e-4),
    64:  (5000,  1e6,    2e-4),
    128: (2500,  5e5,    3e-4),
    256: (1250,  2.5e5,  4e-4),
    512: (625,   1.25e5, 5.5e-4),
}


def _make_config(family, diffusion_cls, batch_size):
    """Return a fresh {'diffusion': ..., 'plan': ...} config for one sweep run.

    Args:
        family: folder-name stem, e.g. 'diffusion' or 'cfm'. Combined with
            ``batch_size`` it forms the training prefix ('cfm64/'), the
            planning prefix ('plans/cfm64') and the checkpoint load path.
        diffusion_cls: dotted path of the generative model class, e.g.
            'models.GaussianDiffusion' or 'models.CFM'.
        batch_size: training batch size; must be a key of ``_batch_schedule``.

    Returns:
        A new nested dict on every call (including new lists), so mutating one
        config in place can never affect another.

    Raises:
        KeyError: if ``batch_size`` has no entry in ``_batch_schedule``.
    """
    n_steps_per_epoch, n_train_steps, learning_rate = _batch_schedule[batch_size]
    tag = f'{family}{batch_size}'

    return {
        'diffusion': {
            ## model
            'model': 'models.TemporalUnet', # TemporalUnet, ConditionalUnet1D (Not implemented yet)
            'diffusion': diffusion_cls,
            'horizon': 256,
            'n_diffusion_steps': 256,
            'action_weight': 1,
            'loss_weights': None,
            'loss_discount': 1,
            'predict_epsilon': False,
            'dim_mults': (1, 4, 8),
            'renderer': 'utils.Maze2dRenderer',

            ## dataset
            'loader': 'datasets.GoalDataset',
            'termination_penalty': None,
            'normalizer': 'LimitsNormalizer',
            'preprocess_fns': ['maze2d_set_terminals'],
            'clip_denoised': True,
            'use_padding': False,
            'max_path_length': 40000,

            ## serialization
            'logbase': 'logs',
            'prefix': f'{tag}/',
            'exp_name': watch(diffusion_args_to_watch),

            ## training (batch-size dependent values come from _batch_schedule)
            'n_steps_per_epoch': n_steps_per_epoch,
            'loss_type': 'l2',
            'n_train_steps': n_train_steps,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'gradient_accumulate_every': 1,
            'ema_decay': 0.995,
            'save_freq': 50000,
            'sample_freq': 50000,
            'n_saves': 50,
            'save_parallel': False,
            'n_reference': 50,
            'n_samples': 10,
            'bucket': None,
            'device': 'cuda',
        },

        'plan': {
            'batch_size': 1,
            'device': 'cuda',

            ## generative model -- horizon / n_diffusion_steps also select the
            ## checkpoint folder through diffusion_loadpath, so they must match
            ## the values used for training
            'horizon': 256,
            'n_diffusion_steps': 256,
            'normalizer': 'LimitsNormalizer',

            ## serialization
            'vis_freq': 10,
            'logbase': 'logs',
            'prefix': f'plans/{tag}',
            'exp_name': watch(plan_args_to_watch),
            'suffix': 'test',

            'conditional': False,

            ## loading
            ## The braces are deliberately left unformatted: the 'f:' prefix
            ## marks the string for later formatting with the run's own args
            ## (presumably done by the config parser).
            'diffusion_loadpath': 'f:' + tag + '/H{horizon}_T{n_diffusion_steps}',
            'diffusion_epoch': 'latest',
        },
    }


## GaussianDiffusion batch-size sweep
base16 = _make_config('diffusion', 'models.GaussianDiffusion', 16)
base32 = _make_config('diffusion', 'models.GaussianDiffusion', 32)
base64 = _make_config('diffusion', 'models.GaussianDiffusion', 64)
base128 = _make_config('diffusion', 'models.GaussianDiffusion', 128)
base256 = _make_config('diffusion', 'models.GaussianDiffusion', 256)
base512 = _make_config('diffusion', 'models.GaussianDiffusion', 512)

## models.CFM (flow matching) batch-size sweep
cfm16 = _make_config('cfm', 'models.CFM', 16)
cfm32 = _make_config('cfm', 'models.CFM', 32)
cfm64 = _make_config('cfm', 'models.CFM', 64)
cfm128 = _make_config('cfm', 'models.CFM', 128)
cfm256 = _make_config('cfm', 'models.CFM', 256)
cfm512 = _make_config('cfm', 'models.CFM', 512)

#------------------------ overrides ------------------------#

'''
    Episode lengths (in environment steps) for each maze2d maze:
        umaze: 150
        medium: 250
        large: 600
'''

## umaze: shorter horizon and fewer denoising steps. Training ('diffusion') and
## planning ('plan') get identical values because the plan's
## diffusion_loadpath is formatted from its own horizon / n_diffusion_steps,
## which must point at the trained model's folder. The comprehension builds a
## separate inner dict for each stage.
maze2d_umaze_v1 = {
    stage: {'horizon': 128, 'n_diffusion_steps': 64}
    for stage in ('diffusion', 'plan')
}

## large maze: longer horizon. Planning additionally enables the safety
## (CBF) and one-shot-initialization options for flow matching.
maze2d_large_v1 = {
    'diffusion': dict(
        horizon=384,
        n_diffusion_steps=256,  # 256 used for debugging
    ),
    'plan': dict(
        horizon=384,
        n_diffusion_steps=256,  # 256 used for debugging

        # Safety filter on/off (True / False).
        # If you want to use naive FM, you also need to change
        # conditional_sample in cfm.py.
        safety_enabled=True,

        # One-shot initialization on/off (True / False).
        one_shot_enabled=True,

        ## CBF for flow matching
        # Alternative placement previously tried:
        #   order 2 at (5.8, 5.0) and order 4 at (5.3, 2.0), both radius 1.
        obstacles=[
            dict(order=2, center=(5.6, 4.8), radius=1),
            dict(order=4, center=(5.1, 1.8), radius=1),
        ],
        cbf_solver='closed_form',  # options: 'qp', 'closed_form'
        cbf_method='relax',        # options: 'robust', 'relax', 'time'
        robust_term=0.01,          # values tried: 0.01, 0.1
        relax_threshold=0.99,      # alternative: 0.9999

        # set suffix
        suffix='test',
    ),
}
