import argparse
import numpy as np
import torch


### build arguments
def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    ``type=bool`` must not be used with argparse: it would call
    ``bool("False")``, which is ``True`` for every non-empty string, so the
    option could never be switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# NOTE: several of the values below (e.g. n_rollout, time_step, state_dim,
# attr_dim, neighbor_k, physics_param_range) are overwritten per environment
# by gen_args().
parser = argparse.ArgumentParser()
parser.add_argument('--env', default='RigidFall')
parser.add_argument('--stage', default='dy', help="dy: dynamics model")
parser.add_argument('--pstep', type=int, default=2)
parser.add_argument('--random_seed', type=int, default=42)

parser.add_argument('--time_step', type=int, default=0)
parser.add_argument('--dt', type=float, default=1. / 60.)
parser.add_argument('--n_instance', type=int, default=-1)

parser.add_argument('--nf_relation', type=int, default=150)
parser.add_argument('--nf_particle', type=int, default=150)
parser.add_argument('--nf_pos', type=int, default=150)
parser.add_argument('--nf_memory', type=int, default=150)
parser.add_argument('--mem_nlayer', type=int, default=2)
parser.add_argument('--nf_effect', type=int, default=150)

parser.add_argument('--outf', default='files')
parser.add_argument('--evalf', default='eval')
parser.add_argument('--dataf', default='data')

parser.add_argument('--eval', type=int, default=0)
parser.add_argument('--verbose_data', type=int, default=0)
parser.add_argument('--verbose_model', type=int, default=0)
parser.add_argument('--eps', type=float, default=1e-6)

# for ablation study
parser.add_argument('--neighbor_radius', type=float, default=-1)
parser.add_argument('--neighbor_k', type=float, default=-1)

# use a flexible number of frames for each training iteration
parser.add_argument('--n_his', type=int, default=4)
parser.add_argument('--sequence_length', type=int, default=5)

# shape state:
# [x, y, z, x_last, y_last, z_last, quat(4), quat_last(4)]
parser.add_argument('--shape_state_dim', type=int, default=14)

# object attributes:
parser.add_argument('--attr_dim', type=int, default=0)

# object state:
parser.add_argument('--state_dim', type=int, default=0)

# relation attr:
parser.add_argument('--relation_dim', type=int, default=0)

# physics parameter
parser.add_argument('--physics_param_range', type=float, nargs=2, default=None)

# width and height for storing vision
parser.add_argument('--vis_width', type=int, default=160)
parser.add_argument('--vis_height', type=int, default=120)
parser.add_argument('--date', type=str, default="0929")


# ------------------------------------------------------------------ train

parser.add_argument('--n_rollout', type=int, default=0)
parser.add_argument('--train_valid_ratio', type=float, default=0.8)
parser.add_argument('--num_workers', type=int, default=10)
parser.add_argument('--log_per_iter', type=int, default=50)
parser.add_argument('--ckp_per_iter', type=int, default=1000)

parser.add_argument('--n_epoch', type=int, default=1000)
parser.add_argument('--beta1', type=float, default=0.9)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--optimizer', default='Adam', help='Adam|SGD')
parser.add_argument('--max_grad_norm', type=float, default=1.0)
parser.add_argument('--batch_size', type=int, default=16)

# data generation
parser.add_argument('--gen_data', type=int, default=0)
parser.add_argument('--gen_stat', type=int, default=0)
parser.add_argument('--gen_vision', type=int, default=0)

parser.add_argument('--resume', type=int, default=0)
parser.add_argument('--resume_epoch', type=int, default=-1)
parser.add_argument('--resume_iter', type=int, default=-1)

# data augmentation
parser.add_argument('--augment_ratio', type=float, default=0.02)


# ------------------------------------------------------------------- eval
parser.add_argument('--eval_epoch', type=int, default=-1, help='pretrained model')
parser.add_argument('--eval_iter', type=int, default=-1, help='pretrained model')
parser.add_argument('--eval_set', default='demo')
parser.add_argument('--rolling_num', default=5, type=int)

# visualization flags
parser.add_argument('--pyflex', type=int, default=1)
parser.add_argument('--vispy', type=int, default=1)


# -------------------------------------------------------- new in our work
parser.add_argument('--fps', type=int, default=300)
parser.add_argument('--using_gt', type=int, default=0)
parser.add_argument('--root', type=str, default='./')
# Empty means "use the per-environment default location chosen in gen_args()".
parser.add_argument('--raw_data_path', type=str,
                    default='')
# Accepts True/False, 1/0, yes/no, ... (see _str2bool); type=bool would treat
# every non-empty string, including "False", as True.
parser.add_argument('--trans_invar', type=_str2bool, default=True)
parser.add_argument('--mode', type=str, default='only_shape')
parser.add_argument('--align_step', type=int, default=-2)
parser.add_argument('--vis', type=int, default=50)
# Neighbour distance between container and particles; for the fluid/granular
# envs gen_args() replaces a negative value with 3 * --neighbor_radius.
parser.add_argument('--dist_cp', type=float, default=0)
parser.add_argument('--frame_jump', type=int, default=5)
parser.add_argument('--chamfer_ratio', type=float, default=1)
parser.add_argument('--squared_chamfer', type=int, default=0)
parser.add_argument('--interval', type=float, default=0.1)
parser.add_argument('--same_order', type=int, default=0)
parser.add_argument('--coalition_loss', action="store_true")
parser.add_argument('--coalition_bar', default=0.6, type=float)
parser.add_argument('--emd_loss_ratio', default=0.0, type=float)

# multistep training
parser.add_argument('--train_extra_steps', action="store_true", help="in training process, using loss accumulated in multi-rolling")
parser.add_argument('--train_extra_steps_num', type=int, default=1)
parser.add_argument('--time_watch', action='store_true', help='output the time consumption')


# add norm vector to container
parser.add_argument('--add_norm_vector', action='store_true', help='add norm vector information to the container')

# using boundary free mode
parser.add_argument("--boundary_free", action='store_true', help='open the boundary free mode')
parser.add_argument("--coal_weight", type=float, default=1)


# Notes on args.mode (the default, 'only_shape', is not described here):
#  - wo residual  : net(t-3, t-2, t-1, 0) -> t
#  - emd residual : net(d(t-3), d(t-2), d(t-1), 0), d(t-1) = emd_(t)(t-1) - p(t-1)
#  - container max : ?



def _instance_suffix(args, default):
    """Resolve ``args.n_instance`` and return the matching dump-path suffix.

    When ``--n_instance`` was left at -1 it is replaced by ``default`` and no
    suffix is produced; an explicitly given value is kept and tagged as
    ``_nIns_<n>``.
    """
    if args.n_instance == -1:
        args.n_instance = default
        return ''
    return '_nIns_' + str(args.n_instance)


def _dated_prefix(args, name):
    """Return ``dump_new/<name>[_gt]_<date>[_same_order]/`` for ``args``.

    ``_same_order`` is only considered for ground-truth runs (``--using_gt``).
    """
    if not args.using_gt:
        return f'dump_new/{name}_{args.date}/'
    if args.same_order:
        return f'dump_new/{name}_gt_{args.date}_same_order/'
    return f'dump_new/{name}_gt_{args.date}/'


def _hparam_tags(args, include_sq_chamfer=True):
    """Return the hyper-parameter tags embedded in fluid/granular dump paths.

    Each tag ends with ``_`` so the result can be concatenated onto a prefix.
    Only non-default values are tagged, except ``chamfer_ratio``, which is
    tagged whenever it is non-zero (its CLI default is 1).
    """
    tags = ''
    if include_sq_chamfer and args.squared_chamfer:
        tags += 'SQchamfer_'
    if args.frame_jump != 5:
        tags += f"framejump{args.frame_jump}_"
    if args.nf_effect != 150:
        tags += f"effect{args.nf_effect}_"
    if args.chamfer_ratio != 0:
        tags += f"chamferatio{args.chamfer_ratio}_"
    return tags


def _set_output_dirs(args, pref, suffix):
    """Wrap ``args.outf`` / ``args.evalf`` as ``<pref><name>_<stage><suffix>``."""
    args.outf = pref + args.outf + '_' + args.stage + suffix
    args.evalf = pref + args.evalf + '_' + args.stage + suffix


def _setup_particle_env(args, raw_data_path, env_idx, attr_dim):
    """Apply the settings shared by the fluid / granular environments."""
    # Use the built-in dataset location only as a fallback, so that an explicit
    # --raw_data_path is honoured instead of being silently overwritten.
    if not args.raw_data_path:
        args.raw_data_path = raw_data_path
    args.env_idx = env_idx
    args.time_step = 300

    # object states: [x, y, z]
    args.state_dim = 3
    args.attr_dim = attr_dim

    # A negative --dist_cp is derived from --neighbor_radius. These envs do not
    # set neighbor_radius themselves, so this uses the CLI value (default -1).
    if args.dist_cp < 0:
        args.dist_cp = args.neighbor_radius * 3
    args.neighbor_k = 20

    args.physics_param_range = (-15., -5.)

    # Zero means / unit stds (presumably: no normalization for these envs).
    args.mean_p = np.array([0, 0, 0])
    args.std_p = np.array([1, 1, 1])
    args.mean_d = np.array([0, 0, 0])
    args.std_d = np.array([1, 1, 1])


def gen_args(argv=None):
    """Parse the command line and derive the environment-specific configuration.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv``, exactly as before; passing a list lets
            callers build a configuration without touching the real CLI.

    Returns:
        The parsed ``argparse.Namespace``, with per-environment fields
        (env_idx, time_step, state_dim, attr_dim, neighbor settings,
        mean/std arrays, ...) filled in, and ``outf`` / ``evalf`` / ``dataf``
        expanded into full relative paths.

    Raises:
        AssertionError: if ``--env`` is not one of the supported environments.
    """
    args = parser.parse_args(argv)

    args.data_names = ['positions', 'shape_quats', 'scene_params']

    if args.env == 'RigidFall':
        args.env_idx = 3

        args.n_rollout = 5000
        args.time_step = 121

        # object states: [x, y, z]
        args.state_dim = 3
        # object attr: [particle, floor]
        args.attr_dim = 2

        args.neighbor_radius = 0.08
        args.neighbor_k = 20

        suffix = _instance_suffix(args, default=3)
        args.physics_param_range = (-15., -5.)
        _set_output_dirs(args, 'dump/dump_RigidFall/', suffix)

        args.mean_p = np.array([0.14778039, 0.15373468, 0.10396217])
        args.std_p = np.array([0.27770899, 0.13548609, 0.15006677])
        args.mean_d = np.array([-1.91248869e-05, -2.05043765e-03, 2.10580908e-05])
        args.std_d = np.array([0.00468072, 0.00703023, 0.00304786])

    elif args.env == 'MassRope':
        args.env_idx = 9

        args.n_rollout = 3000
        args.time_step = 201

        # object states: [x, y, z]
        args.state_dim = 3
        # object attr: [particle, pin]
        args.attr_dim = 2

        args.neighbor_radius = 0.25
        args.neighbor_k = -1

        suffix = _instance_suffix(args, default=2)
        args.physics_param_range = (0.25, 1.2)
        _set_output_dirs(args, 'dump/dump_MassRope/', suffix)

        args.mean_p = np.array([0.06443707, 1.09444374, 0.04942945])
        args.std_p = np.array([0.45214754, 0.29002383, 0.41175843])
        args.mean_d = np.array([-0.00097918, -0.00033966, -0.00080952])
        args.std_d = np.array([0.02086366, 0.0145161, 0.01856096])

    elif args.env == 'pour':
        # object attr: [particle, shape]
        _setup_particle_env(args, '/home/htxue/datasets/data_FluidPour/',
                            env_idx=3, attr_dim=2)
        suffix = _instance_suffix(args, default=1)
        pref = (_dated_prefix(args, 'FluidPour')
                + f"pstep{args.pstep}_"
                + _hparam_tags(args))
        _set_output_dirs(args, pref, suffix)

    elif args.env == 'shake':
        # object attr: [particle_fluid, particle_cube, shape]
        _setup_particle_env(args, '/home/htxue/datasets/data_FluidShake/',
                            env_idx=3, attr_dim=3)
        # Two instances (fluid + cube) are hard-wired; --n_instance is ignored.
        args.n_instance = 2
        suffix = '_nIns_' + str(args.n_instance)
        # NOTE: unlike the other envs, pstep is appended last and no SQchamfer
        # tag is used; kept as-is so existing dump directories still resolve.
        pref = (_dated_prefix(args, 'FluidShake')
                + _hparam_tags(args, include_sq_chamfer=False)
                + f"pstep{args.pstep}_")
        _set_output_dirs(args, pref, suffix)

    elif args.env == 'pour_extra':
        # object attr: [particle, shape]
        _setup_particle_env(args, '/home/htxue/datasets/data_FluidPourExtra/',
                            env_idx=3, attr_dim=2)
        suffix = _instance_suffix(args, default=1)
        pref = ('dump_new/FluidPourExtra_gt/' if args.using_gt
                else f'dump_new/FluidPourExtra_{args.date}/')
        pref += f"pstep{args.pstep}_" + _hparam_tags(args)
        _set_output_dirs(args, pref, suffix)

    elif args.env == 'granular_push':
        # object attr: [particle, shape]
        _setup_particle_env(args, '/home/htxue/datasets/data_GranularPushExtra/',
                            env_idx=4, attr_dim=2)
        suffix = _instance_suffix(args, default=1)
        pref = ('dump_new/GranularPush_gt/' if args.using_gt
                else f'dump_new/GranularPush_{args.date}/')
        pref += f"pstep{args.pstep}_" + _hparam_tags(args)
        _set_output_dirs(args, pref, suffix)

    elif args.env == 'shake_extra':
        # object attr: [particle, shape, boxes]
        _setup_particle_env(args, '/home/htxue/datasets/data_FluidShakeExtra_new/',
                            env_idx=5, attr_dim=3)
        # Four instances are hard-wired; --n_instance is ignored.
        args.n_instance = 4
        suffix = '_nIns_' + str(args.n_instance)
        pref = ('dump_new/FluidShakeExtra_gt/' if args.using_gt
                else f'dump_new/FluidShakeExtra_{args.date}/')
        pref += f"pstep{args.pstep}_" + _hparam_tags(args)
        _set_output_dirs(args, pref, suffix)

    else:
        raise AssertionError("Unsupported env")

    # Both options append 3 extra features to the per-particle state.
    if args.add_norm_vector or args.boundary_free:
        args.state_dim += 3

    # path to data
    args.dataf = 'data/' + args.dataf + '_' + args.env

    # n_his
    args.outf += '_nHis%d' % args.n_his
    args.evalf += '_nHis%d' % args.n_his

    # data augmentation
    if args.augment_ratio > 0:
        args.outf += '_aug%.2f' % args.augment_ratio
        args.evalf += '_aug%.2f' % args.augment_ratio

    # evaluation checkpoints
    if args.stage in ['dy']:
        if args.eval_epoch > -1:
            args.evalf += '_dyEpoch_' + str(args.eval_epoch)
            args.evalf += '_dyIter_' + str(args.eval_iter)
        else:
            args.evalf += '_dyEpoch_best'

        args.evalf += '_%s' % args.eval_set

    return args
