"""
eval program
"""

import os
import time
import sys
import copy

import multiprocessing as mp


import argparse
import numpy as np
import torch
import random
import pickle
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import matplotlib.pylab as plt

from config import gen_args
from data import Dynamics_Dataset
from data import prepare_input, get_scene_info, get_env_group, prepare_input_boundary_free
from models import Model, ChamferLoss
from utils import make_graph, check_gradient, set_seed, AverageMeter, get_lr, Tee, dy_loss
from utils import count_parameters, my_collate, ChamferLoss, capture_scene_image, get_image_to_tensor_balanced
from tqdm import tqdm
import imageio
from chamferdist import ChamferDistance


# args = gen_args()

# Directory of the training run to evaluate; the saved `args.npy` is read from
# this directory further down in the script.
eval_path = (
    '/home/htxue/data/mit/visual_dynamics/VGPL-Dynamics-Prior/dump_new/'
    'FluidPourExtra_0317/'
    'pstep2_chamferatio1_files_dy_nHis4_aug0.02_nroll500_lr0.0001_alignstep2_modeonly_shape'
)

# Rollout indices to evaluate. The active list is the one the author labelled
# "val"; the commented alternatives are kept as quick toggles.
val_rollout = [97, 297, 397, 450, 497]
# val_rollout = [391, 491, 492, 495]
# val_rollout = [397]
# "train" rollouts:
# val_rollout = [0, 100, 200, 300]

# Checkpoints to evaluate, as (epoch, iteration) pairs.
# pour:
eval_target_list = [(1, 3000)]
# granular:
# eval_target_list = [(14, 1000)]

# Number of start frames sampled per rollout.
val_frames_per_rollout = 1
# When True, separate "_pred" / "_gt" images and a side-by-side video are also produced.
more_image = True


# Restore the argument object stored as `args.npy` under `eval_path` and
# override the fields that differ at evaluation time.
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects -- only
# point eval_path at runs you trust.
args = np.load(os.path.join(eval_path, "args.npy"), allow_pickle=True).item()
# args.add_norm_vector = False
# args.state_dim = 3
args.eval = 1
args.rolling_num = 45
args.augment_ratio = 0

# First frame index (before adding n_his) from which a rollout window may
# start. Previously this name was only bound for 'pour*' / 'granular*'
# environments, so any other env (e.g. 'shake', which the evaluation loop
# renders) crashed later with a NameError. Default to 0, the smallest value
# that keeps the history indices non-negative.
st_idx_min = 0
if 'pour' in args.env:
    st_idx_min = 10
if 'granular' in args.env:
    st_idx_min = 25
    args.rolling_num = 30


# args.n_rollout = 100
# args.time_step = 10
### training


# load training data
# NOTE(review): nothing is loaded between the two time.time() calls, so the
# printed duration is ~0; the datasets are actually read inside the
# evaluation loop below. Kept so the script's console output is unchanged.
time_st_load = time.time()

print("DATASET LOADED ---- TIME CMD = ", time.time() - time_st_load)

def _load_particle_states(args, rollout_ids):
    """Load the particle trajectories of the requested rollouts.

    Only the rollouts in ``rollout_ids`` are read (the evaluation loop never
    indexes any other rollout).

    Returns:
        dict mapping rollout index -> ``np.stack`` of that rollout's states
        (per the original note: time_step * N * state_dim -- TODO confirm).
    """
    if args.using_gt:
        p_data_path = os.path.join(args.root, 'dataset/GT_particle_fps/{}/fps_{}/'.format(args.env, args.fps))
    else:
        p_data_path = os.path.join(args.root, 'dataset/particle_fps_new/{}/fps_{}/'.format(args.env, args.fps))

    # With `same_order`, the index file of rollout 0 (its first entry) is
    # applied to every rollout, so it is read once instead of per rollout.
    shared_order = None
    if args.using_gt and args.same_order:
        shared_order = np.load(p_data_path + "0.npy")[0]

    particle_states = {}
    for rollout_idx in rollout_ids:
        if args.using_gt:
            info_path = os.path.join(args.raw_data_path, str(rollout_idx), "info.p")
            # NOTE(review): pickle.load runs arbitrary code; raw_data_path must be trusted.
            with open(info_path, 'rb') as info_file:
                info = pickle.load(info_file)['particles']
            if args.same_order:
                states = info[:, shared_order, :]
            else:
                # Per-frame sampled particle indices for this rollout.
                dump = np.load(p_data_path + "{}.npy".format(rollout_idx))
                batch_idx = np.array(range(0, dump.shape[0])).repeat(dump.shape[1])
                states = info[batch_idx, dump.flatten(), :].reshape(dump.shape[0], dump.shape[1], -1)
        else:
            states = torch.load(p_data_path + "/{}.bin".format(rollout_idx))
        particle_states[rollout_idx] = np.stack(states)
    return particle_states


def _load_shape_states(args):
    """Load the sampled shape (container) states, indexable by rollout index.

    Raises:
        ValueError: for an "*extra*" environment other than 'pour_extra',
            for which no shape data location is defined.
    """
    if 'extra' not in args.env:
        shape_data_path = os.path.join(args.root, f'dataset/box_sampling/{args.env}/sampling_8_interval_{args.interval}/')
        # per the original note: 500 * 300 * n_s * 3 -- TODO confirm
        return np.load(shape_data_path + "pkg.npy")

    if args.env != 'pour_extra':
        raise ValueError("No shape data is defined for env %s" % args.env)

    if args.add_norm_vector or args.boundary_free:
        shape_data_path = os.path.join(args.root,
                                       f'dataset/box_sampling_with_norm/{args.env}/sampling_8_interval_0.15/')
    else:
        shape_data_path = os.path.join(args.root, f'dataset/box_sampling/{args.env}/sampling_8_interval_0.15/')

    # The data is split over five package files; concatenate them rollout by rollout.
    shape_states = []
    for i in range(5):
        shape_states.extend(np.load(shape_data_path + f"pkg{i}.npy"))
    return shape_states


def _render_step(args, gt_ps, pred_np, n_p, n_s, out_prefix, more_image):
    """Render one rollout step to ``<out_prefix>.png`` (+ ``_pred`` / ``_gt``).

    ``gt_ps`` holds the ground-truth particles followed by the shape points,
    ``pred_np`` the predicted particles. Environments other than the ones
    listed below are not rendered.
    """
    if args.env in ['pour', 'pour_extra']:
        angle = 60
        capture_scene_image(np.concatenate([gt_ps, pred_np]), angle=angle,
                            output_pth=out_prefix + ".png",
                            color=['b'] * n_p + ['g'] * n_s + ['r'] * n_p)
        if more_image:
            capture_scene_image(np.concatenate([gt_ps[n_p:], pred_np]), angle=angle,
                                output_pth=out_prefix + "_pred.png",
                                color=['g'] * n_s + ['r'] * n_p)
            capture_scene_image(np.concatenate([gt_ps]), angle=angle,
                                output_pth=out_prefix + "_gt.png",
                                color=['b'] * n_p + ['g'] * n_s)
    elif args.env == 'shake':
        # The 64 / 208 colour group sizes are hard-coded for this environment.
        capture_scene_image(
            np.concatenate([gt_ps, pred_np]),
            dset='shake',
            angle=150,
            output_pth=out_prefix + ".png",
            color=['lime'] * 64 + ['b'] * args.fps + ['g'] * 208 + ['orange'] * 64 + ['r'] * args.fps)
        if more_image:
            capture_scene_image(
                gt_ps,
                dset='shake',
                angle=150,
                output_pth=out_prefix + "_gt.png",
                color=['lime'] * 64 + ['b'] * args.fps + ['g'] * 208)
            capture_scene_image(
                np.concatenate([gt_ps[64 + args.fps:, :], pred_np]),
                dset='shake',
                angle=150,
                output_pth=out_prefix + "_pred.png",
                color=['g'] * 208 + ['orange'] * 64 + ['r'] * args.fps)
    elif args.env == 'granular_push':
        capture_scene_image(np.concatenate([gt_ps, pred_np]), angle=100,
                            output_pth=out_prefix + ".png",
                            color=['b'] * n_p + ['g'] * n_s + ['r'] * n_p, dset='granular_push')
        if more_image:
            capture_scene_image(np.concatenate([gt_ps[n_p:], pred_np]), angle=100,
                                output_pth=out_prefix + "_pred.png",
                                color=['g'] * n_s + ['r'] * n_p, dset='granular_push', h=30)
            capture_scene_image(np.concatenate([gt_ps]), angle=100,
                                output_pth=out_prefix + "_gt.png",
                                color=['b'] * n_p + ['g'] * n_s, dset='granular_push', h=30)


def _load_frame(path, image2tensor):
    """Read a rendered image, turn all-zero pixels white, return an H x W x C array."""
    image = imageio.imread(path)
    image[(image == 0).all(-1), :] = 255
    return image2tensor(image).permute(1, 2, 0).cpu().numpy()


# Evaluate every requested checkpoint: roll the dynamics model forward from
# sampled start frames, score each step with a symmetric Chamfer distance and
# write per-window videos plus a loss.npy.
for eval_epc, eval_iter in tqdm(eval_target_list):
    print(f'Start evaluating epc{eval_epc}')
    args.eval_epoch, args.eval_iter = eval_epc, eval_iter
    # args.interval = 0.15
    set_seed(args.random_seed)

    # create the model and load the checkpoint
    use_gpu = torch.cuda.is_available()
    model = Model(args, use_gpu)
    print("model_kp #params: %d" % count_parameters(model))

    if args.eval_epoch < 0:
        model_name = 'net_best.pth'
    else:
        model_name = 'model/net_epoch_%d_iter_%d.pth' % (args.eval_epoch, args.eval_iter)

    model_path = os.path.join(args.outf, model_name)
    print("Loading network from %s" % model_path)

    if args.stage != 'dy':
        # The exception used to be constructed but never raised, which let an
        # unloaded (randomly initialised) model be evaluated silently.
        raise AssertionError("Unsupported stage %s, using other evaluation scripts" % args.stage)

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    pretrained_dict = torch.load(model_path, map_location=None if use_gpu else 'cpu')
    model_dict = model.state_dict()
    # only load parameters in dynamics_predictor
    pretrained_dict = {
        k: v for k, v in pretrained_dict.items()
        if 'dynamics_predictor' in k and k in model_dict}
    model.load_state_dict(pretrained_dict, strict=False)
    print(f"Model load from {model_path} --- successfully")

    if use_gpu:
        model = model.cuda()
    model.eval()

    # log args
    print(args)

    crit = ChamferDistance()

    #  -----------------load dataset (GT)------------------------- #
    particle_states = _load_particle_states(args, val_rollout)
    shape_states = _load_shape_states(args)
    #  -----------------load dataset (GT)------------------------- #

    # start rolling validation
    for roll_idx in val_rollout:
        # Earliest start frame leaves room for n_his history frames; the
        # latest one leaves room for rolling_num predicted frames.
        frame_idx_min = args.n_his + st_idx_min
        frame_idx_max = args.time_step // args.frame_jump - args.rolling_num - 1
        print(frame_idx_min, frame_idx_max, val_frames_per_rollout)
        val_frame_st_idx = random.sample(range(frame_idx_min, frame_idx_max + 1), val_frames_per_rollout)

        o = args.outf + "/" + f"eval/roll_{roll_idx}/epoch{args.eval_epoch}iter{args.eval_iter}/"
        os.makedirs(o, exist_ok=True)

        for st_idx in val_frame_st_idx:
            print(f"TEST BEGIN : roll [{roll_idx}] , frame [{st_idx}] -> [{st_idx + args.rolling_num - 1}]")
            savefig_pth = os.path.join(o, f"{st_idx}_{st_idx+args.rolling_num}")
            os.makedirs(savefig_pth, exist_ok=True)

            # Losses of this window only; sharing one list across windows made
            # every later loss.npy also contain the earlier windows' losses.
            loss_list = []

            # Ground-truth history that seeds the rollout.
            data_p_l = []  # len = n_his, using raw particle order, used in all rollings
            data_s_l = []  # len = n_his, using raw shape    order
            for i in range(st_idx - args.n_his, st_idx):
                particles = particle_states[roll_idx][i * args.frame_jump]
                if args.add_norm_vector:
                    # pad the particle state with zeros in place of a norm vector
                    particles = np.concatenate([particles, np.zeros(particles.shape)], -1)
                data_p_l.append(particles)
                data_s_l.append(shape_states[roll_idx][i * args.frame_jump])
            n_p, n_s = data_p_l[0].shape[0], data_s_l[0].shape[0]

            for test_rolling_idx, frame_idx in enumerate(range(st_idx, st_idx + args.rolling_num)):
                shape_hist = np.stack(data_s_l)
                # history of particles followed by the first 3 channels of the shape points
                state_cur = torch.FloatTensor(np.concatenate([np.stack(data_p_l), shape_hist[:, :, :3]], 1))
                state_cur = state_cur.unsqueeze(0)

                if not args.boundary_free:
                    attr, _, Rr_cur, Rs_cur = prepare_input(state_cur[0, -1, ...].numpy(), n_p, n_s, args)
                else:
                    attr, _, Rr_cur, Rs_cur = prepare_input_boundary_free(
                        state_cur[0, -1, ...].numpy(), n_p, n_s, args, norm=shape_hist[-1, :, 3:])
                    # Replace every history frame by the particle state returned by
                    # prepare_input_boundary_free (original note: state dim 3 -> 6).
                    state_cur_expand = []
                    for i in range(state_cur.shape[1]):
                        _, particle_new, _, _ = prepare_input_boundary_free(
                            state_cur[0, i].numpy(), n_p, n_s, args, norm=shape_hist[i, :, 3:])
                        state_cur_expand.append(particle_new)
                    state_cur = torch.stack(state_cur_expand).unsqueeze(0)

                if use_gpu:
                    attr = attr.cuda()
                    Rr_cur = Rr_cur.cuda()
                    Rs_cur = Rs_cur.cuda()
                    state_cur = state_cur.cuda()

                # memory: B x mem_nlayer x (n_particle + n_shape) x nf_memory
                # for now, only used as a placeholder
                memory_init = model.init_memory(1, n_s+n_p)

                attr = attr.unsqueeze(0)
                Rr_cur = Rr_cur.unsqueeze(0)
                Rs_cur = Rs_cur.unsqueeze(0)

                group_gt = get_env_group(args, [n_p], [n_s], scene_params=np.array([[0, 0, 0]]), use_gpu=use_gpu)

                inputs = [attr, state_cur, Rr_cur, Rs_cur, group_gt]

                # Pure inference: no autograd graph is needed.
                with torch.no_grad():
                    pred_pos, pred_motion_norm = model.predict_dynamics(inputs)

                # keep the particle predictions only
                pred_pos = pred_pos[:, :n_p, :]
                pred_cpu = pred_pos.detach().cpu()
                pred_np = pred_cpu[0].numpy()

                gt_particles = particle_states[roll_idx][frame_idx*args.frame_jump]
                shape_now = shape_states[roll_idx][frame_idx*args.frame_jump]
                gt_pos = torch.from_numpy(gt_particles).unsqueeze(0)

                # symmetric Chamfer distance between prediction and ground truth
                loss = (crit(pred_cpu, gt_pos) + crit(gt_pos, pred_cpu)) / 2
                loss_list.append(loss.item())

                gt_ps = np.concatenate([gt_particles, shape_now[:, :3]])

                # Slide the history window: drop the oldest frame, append the
                # prediction (particles) and the ground-truth shape state.
                if args.add_norm_vector:
                    next_p = np.concatenate([pred_np, np.zeros(pred_np.shape)], -1)
                else:
                    next_p = pred_np
                data_p_l = data_p_l[1:] + [next_p]
                data_s_l = data_s_l[1:] + [shape_now]

                _render_step(args, gt_ps, pred_np, n_p, n_s,
                             savefig_pth + f"/{test_rolling_idx}", more_image)

            ## gen video:
            image2tensor = get_image_to_tensor_balanced()

            frame = []
            frame_divide = []

            for i in range(args.rolling_num):
                frame.append(_load_frame(savefig_pth + "/" + str(i) + ".png", image2tensor))

                if more_image:
                    image_pred = _load_frame(savefig_pth + "/" + str(i) + "_pred.png", image2tensor)
                    image_gt = _load_frame(savefig_pth + "/" + str(i) + "_gt.png", image2tensor)
                    # prediction and ground truth side by side
                    frame_divide.append(np.concatenate([image_pred, image_gt], 1))

            frames = np.stack(frame)
            print(frames.shape)
            imageio.mimwrite(
                savefig_pth + "/" + 'video.mp4', (frames * 255).astype(np.uint8), quality=8
            )
            if more_image:
                frames_dvide = np.stack(frame_divide)
                print(frames_dvide.shape)
                imageio.mimwrite(
                    savefig_pth + "/" + 'video_divide.mp4', (frames_dvide * 255).astype(np.uint8), quality=8
                )

            np.save(savefig_pth + "/" + 'loss.npy', np.array(loss_list))

            # Drop the intermediate frames once the videos are written. Done
            # in-process rather than through a shell command built from the path.
            for file_name in os.listdir(savefig_pth):
                if file_name.endswith('.png'):
                    os.remove(os.path.join(savefig_pth, file_name))






