import argparse
import numpy as np
import torch

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
### build arguments
parser = argparse.ArgumentParser()
parser.add_argument('--env', default='Cubli')
parser.add_argument('--stage', default='dy', help="dy: dynamics model")
parser.add_argument('--pstep', type=int, default=2)
parser.add_argument('--random_seed', type=int, default=42)

parser.add_argument('--time_step', type=int, default=0)
parser.add_argument('--dt', type=float, default=1. / 60.)
parser.add_argument('--n_instance', type=int, default=4)

parser.add_argument('--nf_relation', type=int, default=150)
parser.add_argument('--nf_particle', type=int, default=150)
parser.add_argument('--nf_pos', type=int, default=150)
parser.add_argument('--nf_memory', type=int, default=150)
parser.add_argument('--mem_nlayer', type=int, default=2)
parser.add_argument('--nf_effect', type=int, default=150)

parser.add_argument('--outf', default='files')
parser.add_argument('--evalf', default='eval')
parser.add_argument('--dataf', default='data')

parser.add_argument('--eval', type=int, default=0)
parser.add_argument('--verbose_data', type=int, default=0)
parser.add_argument('--verbose_model', type=int, default=0)
parser.add_argument('--eps', type=float, default=1e-6)

# for ablation study
parser.add_argument('--neighbor_radius', type=float, default=-1)
parser.add_argument('--neighbor_k', type=float, default=-1)

parser.add_argument('--n_his', type=int, default=2)
parser.add_argument('--sequence_length', type=int, default=3)

parser.add_argument('--shape_state_dim', type=int, default=14)

# object attributes:
parser.add_argument('--attr_dim', type=int, default=0)

# object state:
parser.add_argument('--state_dim', type=int, default=0)

# relation attr:
parser.add_argument('--relation_dim', type=int, default=0)

# physics parameter
parser.add_argument('--physics_param_range', type=float, nargs=2, default=None)

# width and height for storing vision
parser.add_argument('--vis_width', type=int, default=160)
parser.add_argument('--vis_height', type=int, default=120)


'''
train
'''

parser.add_argument('--n_rollout', type=int, default=0)
parser.add_argument('--train_valid_ratio', type=float, default=0.9)
parser.add_argument('--num_workers', type=int, default=10)
parser.add_argument('--log_per_iter', type=int, default=500)
parser.add_argument('--ckp_per_iter', type=int, default=1000)

parser.add_argument('--n_epoch', type=int, default=200)
parser.add_argument('--beta1', type=float, default=0.9)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--optimizer', default='Adam', help='Adam|SGD')
parser.add_argument('--max_grad_norm', type=float, default=1.0)
parser.add_argument('--batch_size', type=int, default=8)

# data generation
parser.add_argument('--gen_data', type=int, default=0)
parser.add_argument('--gen_stat', type=int, default=0)
parser.add_argument('--gen_vision', type=int, default=0)

parser.add_argument('--resume', type=int, default=0)
parser.add_argument('--resume_epoch', type=int, default=0)
parser.add_argument('--resume_iter', type=int, default=0)

# data augmentation
parser.add_argument('--augment_ratio', type=float, default=0.)


'''
eval
'''
parser.add_argument('--eval_epoch', type=int, default=-1, help='pretrained model')
parser.add_argument('--eval_iter', type=int, default=-1, help='pretrained model')
parser.add_argument('--eval_set', default='demo')

# visualization flog
parser.add_argument('--pyflex', type=int, default=1)
parser.add_argument('--vispy', type=int, default=1)


def gen_args():
    args = parser.parse_args()

    args.data_names = ['positions', 'shape_quats', 'scene_params','velocities']

    if args.env == 'Cubli':
        args.env_idx = 3

        args.n_rollout = 200  # The number of training trajectories
        args.time_step = 149

        # object states:
        # [x, y, z]
        args.state_dim = 3

        # object attr:
        # [particle, floor]
        args.attr_dim = 2

        args.neighbor_radius = 0.5
        args.neighbor_k = 20

        suffix = ''
        if args.n_instance == -1:
            args.n_instance = 1
        else:
            suffix += '_nIns_' + str(args.n_instance)

        args.physics_param_range = (-15., -5.)

        args.outf = 'outf_file'
        args.evalf = 'evaf_file'

        args.mean_p = np.array([0,0,0])
        args.std_p = np.array([1 , 1,1])
        args.mean_d = np.array([0,  0, 0])
        args.std_d = np.array([1, 1, 1])


    args.dataf = 'YOUR DATA' + args.dataf + '_' + args.env




    # data augmentation
    if args.augment_ratio > 0:
        args.outf += '_aug%.2f' % args.augment_ratio
        args.evalf += '_aug%.2f' % args.augment_ratio


    # evaluation checkpoints
    if args.stage in ['dy']:
        'your_name_file'


    return args
