import torch

from params_proto.neo_proto import ParamsProto, PrefixProto, Proto

# Central experiment configuration: misc, model, dataset, training,
# validation and evaluation hyperparameters, all as plain class attributes
# on a ParamsProto subclass.
#
# NOTE(review): presumably consumers read these as `Config.<name>` and
# params_proto lets them be overridden (CLI / sweeps) — confirm against
# callers.
#
# Several settings exist in two variants: Gym vs. Unitree-go running. The
# Gym values are active; the Unitree-go values are kept commented out
# beside them. Switch all of them together (horizon, condition_dim,
# include_returns/include_gaits, max_path_length).
class Config(ParamsProto):
    # misc
    seed = 100
    # Chosen once, when this module is imported: CUDA if available, else CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): placeholder path (contains "your_name_hdd"); presumably
    # the root directory for saved weights — set it per machine.
    bucket = '/data/your_name_hdd/dd/weights'
    # Dataset name (presumably a Gym/D4RL dataset id — confirm).
    dataset = 'hopper-medium-expert-v2'

    ## model
    # Components are given as dotted-name strings, not imported classes;
    # presumably resolved to classes by the code consuming this config.
    model = 'models.TemporalUnet'
    diffusion = 'models.GaussianInvDynDiffusion'
    posterior_model = 'models.future.FutureTransformer'
    prior_model = 'models.helpers.DiagGaussian'
    # Horizon: 56 for Unitree-go running (commented out), 100 for Gym (active).
    # Unitree-go-running horizon
    # horizon = 56
    # Gym horizon
    horizon = 100
    n_diffusion_steps = 200
    # NOTE(review): action_weight / loss_weights / loss_discount presumably
    # shape the per-timestep, per-dimension diffusion loss weighting; None
    # presumably means "use the defaults" — confirm in the diffusion model.
    action_weight = 10
    loss_weights = None
    loss_discount = 1
    predict_epsilon = True
    # NOTE(review): dim_mults together with `dim` (below) presumably set the
    # per-level channel multipliers and base width of the temporal U-Net.
    dim_mults = (1, 4, 8)
    returns_condition = True
    calc_energy=False
    dim=128
    # NOTE(review): presumably classifier-free-guidance settings: probability
    # of dropping the condition during training, guidance weight at sampling
    # time, and the target (scaled) return used at test time — verify.
    condition_dropout=0.25
    condition_guidance_w = 1.2
    test_ret=0.9
    # Size of the conditioning vector: 3 for Unitree-go gaits (commented out),
    # 1 for a scalar Gym return (active). Must agree with
    # include_returns / include_gaits below.
    # Unitree-go-running gaits dimension
    # condition_dim=3
    # Gym return dimension
    condition_dim=1
    renderer = 'utils.MuJoCoRenderer'

    ## dataset---------------
    # Active loader is the "future" variant; the plain sequence loader is the
    # commented-out alternative.
    #loader = 'datasets.SequenceDataset'
    loader = 'datasets.FutureSequenceDataset'
    n_workers=16
    normalizer = 'CDFNormalizer'
    # NOTE(review): a mutable list stored on the class, so it is shared by
    # every reader — do not mutate it in place.
    preprocess_fns = []
    clip_denoised = True
    use_padding = True
    # Choose the conditioning signal: returns (Gym) or gaits (Unitree-go).
    # NOTE(review): presumably exactly one should be True — confirm.
    # Selection between return and gaits
    include_returns = True
    include_gaits = False
    discount = 0.99
    # Episode length cap: 250 for Unitree-go running (commented out),
    # 1000 for Gym (active).
    # Unitree-go-running environment
    #max_path_length=250
    # Gym environment
    max_path_length = 1000
    hidden_dim = 256
    # NOTE(review): flags for the inverse-dynamics part; names suggest an
    # autoregressive inverse model and "train only the inverse model" —
    # confirm in the diffusion/trainer code.
    ar_inv = False
    train_only_inv = False
    termination_penalty = -100
    returns_scale = 400.0 # Determined using rewards from the dataset
    # NOTE(review): latent size, plus po / occlude_start_idx whose meaning is
    # defined by the consuming code (not visible here).
    z_dim = 16
    po = 0
    occlude_start_idx = -2

    ## training
    # NOTE(review): resume is presumably a 0/1 flag; trained_time looks like
    # a 'YYYY-MM-DD/HH-MM-SS' run id, presumably used to locate the run to
    # resume — verify.
    resume=0
    trained_time='2023-08-27/18-43-16'
    # mode selects 'new' (active) vs 'old' (commented out below);
    # future_mode / token_mode are integer switches interpreted elsewhere.
    mode = 'new'
    future_mode=1
    token_mode=1
    #mode = 'old'
    # NOTE(review): presumably the weight of a regularizer on the latent z.
    z_reg = 0.01
    # Active trainer is 'Trainer'; the NewTrainer path is the commented-out
    # alternative.
    #trainer = 'utils.NewTraining.NewTrainer'
    trainer = 'Trainer'
    n_steps_per_epoch = 10000
    loss_type = 'l2'
    # NOTE(review): float literal (1e6); callers that need an integer step
    # count must convert it.
    n_train_steps = 1e6
    batch_size = 32
    learning_rate = 2e-4
    gradient_accumulate_every = 2
    ema_decay = 0.995
    # Logging / saving / sampling intervals (presumably in training steps).
    log_freq = 1000
    save_freq = 50000
    sample_freq = 10000
    n_saves = 5
    save_parallel = False
    n_reference = 8
    save_checkpoints = True
    cond_z = 0
    # valid
    valid_freq = 50000
    #eval
    # Checkpoint step to load for evaluation and number of evaluation runs.
    ckpt_step=750000
    num_eval=10
    # NOTE(review): 'mu' presumably means "use the distribution mean instead
    # of sampling" — confirm in the evaluation code.
    sample_mode='mu'