import argparse
import numpy as np
import torch


### build arguments
parser = argparse.ArgumentParser()

# ---------------------------------------------------------------------------
# NeRF options.
# NOTE: gen_args() unconditionally overwrites --dataset_type, --no_batching,
# --use_viewdirs, --N_samples, --N_importance, --precrop_iters,
# --precrop_frac and --lrate_decay, and sets --half_res (plus --white_bkgd for
# 'FluidPourExtra_baseline') per environment, so command-line values for
# those flags are ignored when the configuration is built via gen_args().
# ---------------------------------------------------------------------------
parser.add_argument("--expname", type=str,
                    help='experiment name')
parser.add_argument("--basedir", type=str, default='./logs/',
                    help='where to store ckpts and logs')
parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                    help='input data directory')

# NeRF network / optimisation options
parser.add_argument("--netdepth", type=int, default=8,
                    help='layers in network')
parser.add_argument("--netwidth", type=int, default=256,
                    help='channels per layer')
parser.add_argument("--netdepth_fine", type=int, default=8,
                    help='layers in fine network')
parser.add_argument("--netwidth_fine", type=int, default=256,
                    help='channels per layer in fine network')
parser.add_argument("--N_rand", type=int, default=32*32*4,
                    help='batch size (number of random rays per gradient step)')
parser.add_argument("--lrate", type=float, default=5e-4,
                    help='learning rate')
parser.add_argument("--lrate_decay", type=int, default=250,
                    help='exponential learning rate decay (in 1000 steps)')
parser.add_argument("--chunk", type=int, default=1024*32,
                    help='number of rays processed in parallel, decrease if running out of memory')
parser.add_argument("--netchunk", type=int, default=1024*64,
                    help='number of pts sent through network in parallel, decrease if running out of memory')
parser.add_argument("--no_batching", action='store_true',
                    help='only take random rays from 1 image at a time')
parser.add_argument("--no_reload", action='store_true',
                    help='do not reload weights from saved ckpt')
parser.add_argument("--ft_path", type=str, default=None,
                    help='specific weights npy file to reload for coarse network')

# NeRF rendering options
parser.add_argument("--N_samples", type=int, default=64,
                    help='number of coarse samples per ray')
parser.add_argument("--N_importance", type=int, default=0,
                    help='number of additional fine samples per ray')
parser.add_argument("--perturb", type=float, default=1.,
                    help='set to 0. for no jitter, 1. for jitter')
parser.add_argument("--use_viewdirs", action='store_true',
                    help='use full 5D input instead of 3D')
parser.add_argument("--i_embed", type=int, default=0,
                    help='set 0 for default positional encoding, -1 for none')
parser.add_argument("--multires", type=int, default=10,
                    help='log2 of max freq for positional encoding (3D location)')
parser.add_argument("--multires_views", type=int, default=4,
                    help='log2 of max freq for positional encoding (2D direction)')
parser.add_argument("--raw_noise_std", type=float, default=0.,
                    help='std dev of noise added to regularize sigma_a output, 1e0 recommended')

parser.add_argument("--render_only", action='store_true',
                    help='do not optimize, reload weights and render out render_poses path')
parser.add_argument("--render_test", action='store_true',
                    help='render the test set instead of render_poses path')
parser.add_argument("--render_factor", type=int, default=0,
                    help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')

# NeRF central-crop warm-up options
parser.add_argument("--precrop_iters", type=int, default=0,
                    help='number of steps to train on central crops')
parser.add_argument("--precrop_frac", type=float,
                    default=.5, help='fraction of img taken for central crops')

# dataset options
parser.add_argument("--dataset_type", type=str, default='llff',
                    help='options: llff / blender / deepvoxels')
parser.add_argument("--testskip", type=int, default=8,
                    help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

## deepvoxels flags
parser.add_argument("--shape", type=str, default='greek',
                    help='options : armchair / cube / greek / vase')

## blender flags
parser.add_argument("--white_bkgd", action='store_true',
                    help='set to render synthetic data on a white bkgd (always use for dvoxels)')
parser.add_argument("--half_res", action='store_true',
                    help='load blender synthetic data at 400x400 instead of 800x800')

## llff flags
parser.add_argument("--factor", type=int, default=8,
                    help='downsample factor for LLFF images')
parser.add_argument("--no_ndc", action='store_true',
                    help='do not use normalized device coordinates (set for non-forward facing scenes)')
parser.add_argument("--lindisp", action='store_true',
                    help='sampling linearly in disparity rather than depth')
parser.add_argument("--spherify", action='store_true',
                    help='set for spherical 360 scenes')
parser.add_argument("--llffhold", type=int, default=8,
                    help='will take every 1/N images as LLFF test set, paper uses 8')

# logging/saving options
parser.add_argument("--i_print",   type=int, default=100,
                    help='frequency of console printout and metric logging')
parser.add_argument("--i_img",     type=int, default=500,
                    help='frequency of tensorboard image logging')
parser.add_argument("--i_weights", type=int, default=10000,
                    help='frequency of weight ckpt saving')
parser.add_argument("--i_validset", type=int, default=2000,
                    help='frequency of validset saving')
parser.add_argument("--i_testset", type=int, default=2000,
                    help='frequency of testset saving')
parser.add_argument("--i_video",   type=int, default=10000000000,
                    help='frequency of render_poses video saving')


# =============== Options for Yunzhu ==================
# (simulation, dynamics learning, auto-decoding and MPPI control)

# PyFleX options
parser.add_argument("--screenWidth", type=int, default=720)
parser.add_argument("--screenHeight", type=int, default=720)


# Experiment / dynamics training options.
# NOTE: gen_args() overwrites --n_rollout, --n_frames, --time_step and
# --act_dim per environment, and --n_his / --n_roll for the 'ae' and 'dec'
# phases.
parser.add_argument("--seed", type=int, default=42)
# NOTE(review): this default is not one of the environments gen_args()
# accepts (it raises "Unsupported env" for it), so --env must be passed
# explicitly. TODO confirm the intended default.
parser.add_argument("--env", default='FluidPourExtra')
parser.add_argument("--phase", default='ae')

parser.add_argument("--n_rollout", type=int, default=0,
                    help="number of rollouts in the dataset")
parser.add_argument("--time_step", type=int, default=300,
                    help="number of time step in each rollout")
parser.add_argument("--n_frames", type=int, default=0,
                    help="number of frames per time step")

parser.add_argument("--prestored", type=int, default=0,
                    help="whether to load prestored embeds during dynamics learning")
parser.add_argument("--n_his", type=int, default=3,
                    help='number of frames used as the current state')
parser.add_argument("--n_roll", type=int, default=2,
                    help='number of frames to predict into the future during training')
parser.add_argument("--n_view", type=int, default=6,
                    help='number of views to sample per time step')

parser.add_argument("--ct_loss", type=int, default=1,
                    help='whether to use time contrastive loss')
parser.add_argument("--nerf_loss", type=int, default=1,
                    help='whether to use nerf reconstruction loss')
parser.add_argument("--auto_loss", type=int, default=0,
                    help='whether to use a conv autoencoder')

# training options for auto-decoding
parser.add_argument("--n_timestep_for_dec", type=int, default=3000,
                    help="the number of time steps for auto-decoding")
parser.add_argument("--n_dec_optim_iter", type=int, default=200,
                    help="the number of optimization steps during test time")
parser.add_argument("--log_per_iter_dec", type=int, default=1)
parser.add_argument("--dec_store_st_idx", type=int, default=0,
                    help="the starting idx for storing the optimized embeds")
parser.add_argument("--dec_store_ed_idx", type=int, default=0,
                    help="the ending idx for storing the optimized embeds")


parser.add_argument("--nf_hidden", type=int, default=256)
parser.add_argument("--act_dim", type=int, default=0)

parser.add_argument("--train_valid_ratio", type=float, default=0.8)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--num_workers", type=int, default=1)

parser.add_argument("--log_per_iter", type=int, default=40)
parser.add_argument("--ckp_per_epoch", type=int, default=10)
parser.add_argument("--ckp_per_iter", type=int, default=10000)

# Folder names. gen_args() rewrites these: --dataf becomes
# '../data/<dataf>_<env>' (or a fixed path for baseline envs), and the other
# three are placed under 'dump/dump_<env>/' with run-specific suffixes.
parser.add_argument("--outf", default='files')
parser.add_argument("--dataf", default='data')
parser.add_argument("--evalf", default='eval')
parser.add_argument("--mpcf", default='mpc')

parser.add_argument("--resume", type=int, default=0)
parser.add_argument("--resume_epoch", type=int, default=0)
parser.add_argument("--resume_iter", type=int, default=0)
parser.add_argument("--n_epoch", type=int, default=1000)

# evaluation options
parser.add_argument("--eval_skip_frame", type=int, default=1)
parser.add_argument("--eval_epoch", type=int, default=-1,
                    help='the index of checkpoint to load, -1 for the current best checkpoint')
parser.add_argument("--eval_iter", type=int, default=-1,
                    help='the index of checkpoint to load, -1 for the current best checkpoint')

# MPPI options
parser.add_argument("--optim_type", default="mppi")
parser.add_argument("--ctrl_init_idx", type=int, default=100)
parser.add_argument("--n_look_ahead", type=int, default=80)
parser.add_argument("--n_sample", type=int, default=1000)
parser.add_argument("--n_update_iter_init", type=int, default=100)
parser.add_argument("--n_update_iter", type=int, default=10)
parser.add_argument("--n_update_delta", type=int, default=2)
parser.add_argument("--beta_filter", type=float, default=0.7)
parser.add_argument("--reward_base", type=float, default=3.)
parser.add_argument("--reward_weight", type=float, default=0.1)

# MPPI optimize for goal observation
parser.add_argument("--goal_camera_dist_offset", type=float, default=0.)
parser.add_argument("--n_optim_iter_goal", type=int, default=500)
parser.add_argument("--lrate_optim_goal", type=float, default=1e-2)


# new params in baseline for CVPR paper
parser.add_argument("--frame_jump", type=int, default=1)




# Per-environment settings applied by gen_args(). Each environment starts from
# _ENV_BASE_CONFIG and then applies its own row of _ENV_OVERRIDES, so a row
# only lists the values that differ from the base. Every key is set verbatim
# as an attribute on the parsed namespace, overwriting command-line values.
_ENV_BASE_CONFIG = {
    'n_rollout': 1000,
    'n_frames': 20,
    'half_res': False,  # the input image is already 180
    'time_step': 300,
    'act_dim': 4,
    'n_view_enc': 4,
    'near': 2.,
    'far': 7.,
}

_ENV_OVERRIDES = {
    # half_res: the input image is 360 x 360, use 180 for training
    'FluidPour': {'n_rollout': 50, 'n_frames': 100, 'half_res': True},
    'FluidManip': {},
    'FluidManipClip': {'far': 8.},
    'FluidManipClip_wKuka': {'far': 9.5},
    'FluidManipClip_wKuka_wColor': {'far': 9.5},
    'FluidShake': {'act_dim': 2},
    'FluidShakeWithIce_1000': {'act_dim': 2},
    'FluidShakeWithIce_wKuka': {'act_dim': 2},
    'FluidShakeWithIce_wKuka_wColor_wGripper': {'act_dim': 2},
    'MassRope': {'time_step': 200, 'act_dim': 3, 'far': 6.5},
    'RigidDrop': {'time_step': 50, 'act_dim': 0, 'far': 6.},
    'RigidFall': {'n_rollout': 800, 'time_step': 80, 'act_dim': 0},
    # half_res: the input image is 360 x 360, use 180 for training
    'GranularManip': {'n_rollout': 30, 'n_frames': 100, 'half_res': True,
                      'time_step': 600, 'act_dim': 7,
                      'near': 4.5, 'far': 12},
    # Baseline envs load data from fixed directories instead of the
    # '../data/<dataf>_<env>' layout used by every other env.
    # NOTE(review): these absolute, user-specific paths silently replace
    # --dataf; consider making them configurable.
    'FluidPour_baseline': {
        'n_rollout': 500, 'n_frames': 6,
        'dataf': '/home/htxue/datasets/data_FluidPour/',
    },
    'FluidPourExtra_baseline': {
        'n_rollout': 500, 'n_frames': 6,
        'dataf': '/home/htxue/datasets/data_FluidPourExtra/',
        'white_bkgd': True,
    },
}


def gen_args(argv=None):
    """Parse the command line and derive the full experiment configuration.

    On top of the raw ``parser`` output this applies the per-environment
    settings from ``_ENV_OVERRIDES``, the per-phase settings, a fixed NeRF
    configuration, and builds the ``dataf`` / ``outf`` / ``evalf`` / ``mpcf``
    folder names.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace: the parsed and post-processed arguments.

    Raises:
        AssertionError: if ``--env`` is not a known environment, or if both
            ``--nerf_loss`` and ``--auto_loss`` are non-zero.
    """
    args = parser.parse_args(argv)

    args.data_names = ['viewMatrix', 'projMatrix', 'action', 'scene_params']

    # Per-environment settings; these overwrite any command-line values.
    # NOTE(review): the parser default --env 'FluidPourExtra' has no entry in
    # _ENV_OVERRIDES, so running without --env raises here.
    if args.env not in _ENV_OVERRIDES:
        raise AssertionError("Unsupported env")
    env_config = dict(_ENV_BASE_CONFIG, **_ENV_OVERRIDES[args.env])
    for key, value in env_config.items():
        setattr(args, key, value)

    # Per-phase settings; an unrecognised phase keeps the command-line values.
    if args.phase in ('ae', 'dec'):
        args.n_his = 1
        args.n_roll = 0
        args.eval_skip_frame = 5
    elif args.phase in ('nn', 'dy', 'mpc'):
        args.eval_skip_frame = 1
        if args.phase == 'mpc':
            args.action_lower_lim = []
            args.action_upper_lim = []

    if args.nerf_loss != 0 and args.auto_loss != 0:
        raise AssertionError("nerf_loss and auto_loss cannot be 1 at the same time")

    # Baseline envs already carry a complete data path (see _ENV_OVERRIDES).
    if 'baseline' not in args.env:
        args.dataf = '../data/' + args.dataf + '_' + args.env

    ### config for NeRF (always overrides the command line)
    args.dataset_type = 'PyFleX'
    args.no_batching = True
    args.use_viewdirs = True
    args.N_samples = 64
    args.N_importance = 128
    args.precrop_iters = 500
    args.precrop_frac = 0.5
    args.lrate_decay = 500000

    # Folder names have the form
    #   dump/dump_<env>/<name>_<phase>_nHis<n_his>[_ct][_nerf][_auto]
    loss_tags = ''.join(
        tag for enabled, tag in ((args.ct_loss, '_ct'),
                                 (args.nerf_loss, '_nerf'),
                                 (args.auto_loss, '_auto'))
        if enabled)
    run_suffix = '_%s_nHis%d%s' % (args.phase, args.n_his, loss_tags)

    # evalf / mpcf are additionally tagged with the checkpoint selected via
    # --eval_epoch / --eval_iter, or '_dybest' when --eval_iter <= -1.
    if args.eval_iter > -1:
        ckpt_suffix = '_dyEpoch_%s_dyIter_%s' % (args.eval_epoch, args.eval_iter)
    else:
        ckpt_suffix = '_dybest'

    dump_dir = 'dump/dump_%s/' % args.env
    args.outf = dump_dir + args.outf + run_suffix
    args.evalf = dump_dir + args.evalf + run_suffix + ckpt_suffix
    args.mpcf = dump_dir + args.mpcf + run_suffix + ckpt_suffix

    return args
