from absl import flags
from ml_collections import config_flags


# Global absl flag registry; the flags defined below are read as
# `args.<flag_name>` once absl has parsed argv.
args = flags.FLAGS

# --- Run / experiment setup -------------------------------------------------
flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.')
# Presumably a placeholder so a stray `-f <file>` argument (such as the one
# IPython/Jupyter kernels pass) does not trigger absl's unknown-flag error.
# Previously registered with help=None, which renders as "(no help available)".
flags.DEFINE_string(
    'f', None,
    "Placeholder that accepts a stray '-f <file>' argument "
    "(e.g. from a Jupyter kernel).")
flags.DEFINE_string('v_update', 'expectile_loss',
                    'Value update function.[expectile_loss, rkl_loss]')
# NOTE(review): help says TensorBoard logging dir, but the name/default suggest
# a model/checkpoint output dir -- confirm against the caller and fix the help.
flags.DEFINE_string('save_dir', './saved_models/', 'Tensorboard logging dir.')
# NOTE(review): help says "Epoch logging dir." but this looks like an
# experiment/run name -- confirm against the caller and fix the help.
flags.DEFINE_string('exp_name', 'test', 'Epoch logging dir.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('eval_episodes', 10,
                     'Number of episodes used for evaluation.')
flags.DEFINE_integer('eval_interval', 10000, 'Eval interval.')
flags.DEFINE_integer('batch_size', 1024, 'Mini batch size.')
flags.DEFINE_boolean('double', True, 'Use double q-learning')
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
flags.DEFINE_integer('sample_random_times', 0,
                     'Number of random actions to add to smooth dataset')

# --- Critic / value-loss options --------------------------------------------
flags.DEFINE_boolean('grad_pen', False,
                     'Add a gradient penalty to critic network')
# Float literal for consistency with the other DEFINE_float flags (absl would
# coerce the former int `1` to 1.0 anyway).
flags.DEFINE_float('lambda_gp', 1.0, 'Gradient penalty coefficient')
flags.DEFINE_float('max_clip', 7., 'Loss clip value')
flags.DEFINE_float('alpha', 1.0, 'Alpha for RKL loss')
flags.DEFINE_integer('num_v_updates', 1, 'Number of value updates per iter')
flags.DEFINE_boolean('log_loss', False, 'Use log gumbel loss')

# --- Data / misc toggles ----------------------------------------------------
flags.DEFINE_boolean('noise', False, 'Add noise to actions')
flags.DEFINE_float('noise_std', 0.1, 'Noise std for actions')
flags.DEFINE_boolean('state_norm', False, 'Normalize states')
flags.DEFINE_boolean('use_wandb', False, 'Use wandb')
flags.DEFINE_boolean('update_Q_inference', False, 'Update Q inference')

# General parameters
# Optimizer learning rates, one flag per network.
flags.DEFINE_float(name='actor_lr', default=3e-4, help='Actor learning rate')
flags.DEFINE_float(name='value_lr', default=3e-4, help='Value learning rate')
flags.DEFINE_float(name='critic_lr', default=3e-4,
                   help='Critic learning rate')
flags.DEFINE_float(name='disc_lr', default=1e-4,
                   help='Discriminator learning rate')

# Regularization and learning-dynamics settings.
flags.DEFINE_float(name='weight_decay', default=0.001, help='Weight decay')
flags.DEFINE_float(name='actor_temperature', default=1.0,
                   help='Actor temperature')
flags.DEFINE_float(name='dropout_rate', default=0.0, help='Dropout rate')
flags.DEFINE_float(name='tau', default=0.001,
                   help='Tau for soft target updates')
flags.DEFINE_float(name='expectile', default=0.7,
                   help='Expectile for IQ loss')
flags.DEFINE_boolean(name='layernorm', default=False, help='Use layernorm')
flags.DEFINE_float(name='discount', default=0.99, help='Discount factor')

# Network size (presumably applied to the MLPs built elsewhere -- confirm).
flags.DEFINE_integer(name='hidden_size', default=256, help='Hidden size')
flags.DEFINE_integer(name='num_layers', default=3, help='Number of layers')

# offline IL only
# Default None; how callers treat an unset size is not visible here.
flags.DEFINE_integer(name='expert_dataset_size', default=None,
                     help='Expert dataset size')

# Discriminator training and dataset-mixing settings.
flags.DEFINE_integer(name='num_disc_train', default=int(1e5),
                     help='Number of discriminator train steps')
flags.DEFINE_float(name='reward_gap', default=5.0,
                   help='Reward gap for good dataset')
flags.DEFINE_float(name='scale_mix', default=1.0,
                   help='Scale for mix dataset')
flags.DEFINE_float(name='clip_threshold', default=0.6,
                   help='Clip threshold for bad_disc')

# absl parses DEFINE_list values as comma-separated lists of *strings*, so the
# *_size_list entries presumably need int() conversion at the call site.
flags.DEFINE_list(name='bad_name_list', default=None,
                  help='List of bad dataset names')
flags.DEFINE_list(name='bad_size_list', default=None,
                  help='List of bad dataset sizes')

flags.DEFINE_list(name='mixed_name_list', default=None,
                  help='List of mixed dataset names')
flags.DEFINE_list(name='mixed_size_list', default=None,
                  help='List of mixed dataset sizes')
# NOTE(review): is_bad_list reuses bad_name_list's help text verbatim; these
# may actually be per-mixed-dataset indicators -- confirm and clarify the help.
flags.DEFINE_list(name='is_good_list', default=None,
                  help='List of good dataset names')
flags.DEFINE_list(name='is_bad_list', default=None,
                  help='List of bad dataset names')

