import os, sys
import cv2
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

import pdb
from progressbar import ProgressBar

from nerf_helpers import *
from config import gen_args
from data import DynamicsDataset
from models import DynamicsModel
from utils import set_seed, count_parameters, get_lr, AverageMeter
from utils import rand_float, rand_int, visualize_point_cloud, ChamferLoss

from physics_engine import FluidManipEngine, FluidShakeWithIceEngine
from physics_engine import FluidManipFullEngine, FluidShakeWithIceFullEngine


class Planner(object):
    """Sampling-based action-sequence optimizer on top of a learned dynamics model.

    Every optimization round samples perturbed action sequences around the
    current estimate, rolls them out through the dynamics model, scores the
    last predicted state against a goal state, and merges the samples into a
    new estimate (MPPI by default; a CEM variant is provided as well).

    Attributes read from ``args``: ``n_his``, ``beta_filter``, ``env`` and
    ``reward_weight``.
    """

    # Environments whose action is 4-D: 3 position components + 1 angle.
    _POUR_ENVS = ('FluidManipClip', 'FluidManipClip_wKuka_wColor')
    # Environments whose action is 2-D.
    _SHAKE_ENVS = ('FluidShakeWithIce_1000', 'FluidShakeWithIce_wKuka_wColor_wGripper')

    def __init__(self, args):
        self.args = args
        self.n_his = args.n_his

    def trajectory_optimization(
        self,
        state_cur,      # [1, n_his, state_dim] (torch tensor, see model_rollout)
        state_goal,     # [state_dim]
        model_dy,       # the learned dynamics model
        act_seq,        # [n_his + n_look_ahead - 1, action_dim]
        n_sample,
        n_look_ahead,
        n_update_iter,
        action_lower_lim,
        action_upper_lim,
        action_lower_delta_lim,
        action_upper_delta_lim,
        use_gpu,
        reward_scale_factor=1.):
        """Refine ``act_seq`` for ``n_update_iter`` sample/rollout/score/merge rounds.

        Returns:
            The refined action sequence, shape
            [n_his + n_look_ahead - 1, action_dim]. When ``n_update_iter`` is 0
            the input ``act_seq`` is returned untouched.
        """
        for i in range(n_update_iter):

            act_seqs = self.sample_action_sequences(
                act_seq, n_sample,
                action_lower_lim, action_upper_lim,
                action_lower_delta_lim, action_upper_delta_lim)

            state_seqs = self.model_rollout(
                state_cur, model_dy, act_seqs, n_look_ahead, use_gpu)

            reward_seqs = reward_scale_factor * self.evaluate_traj(state_seqs, state_goal)

            print('update_iter %d/%d, max: %.4f, mean: %.4f, std: %.4f' % (
                i, n_update_iter, np.max(reward_seqs), np.mean(reward_seqs), np.std(reward_seqs)))

            act_seq = self.optimize_action(act_seqs, reward_seqs)

        # act_seq: [n_his + n_look_ahead - 1, action_dim]
        return act_seq

    def _sample_action_noise(self, n_sample):
        """Draw one [n_sample, action_dim] Gaussian perturbation for ``args.env``.

        Raises:
            ValueError: if ``args.env`` has no noise model. Previously this
                surfaced as an unbound (or silently stale) ``noise_sample``.
        """
        env = self.args.env

        if env in self._POUR_ENVS:
            sigma_pos = 0.002
            noise_pos = np.random.normal(0, sigma_pos, (n_sample, 3))      # position

            sigma_angle = 0.01
            noise_angle = np.random.normal(0, sigma_angle, (n_sample, 1))  # angle

            return np.concatenate([noise_pos, noise_angle], -1)

        if env in self._SHAKE_ENVS:
            sigma_pos = 0.005
            return np.random.normal(0, sigma_pos, (n_sample, 2))

        raise ValueError("unsupported env for action noise sampling: %s" % env)

    def sample_action_sequences(
        self,
        init_act_seq,   # [n_his + n_look_ahead - 1, action_dim]
        n_sample,       # number of action trajectories to sample
        action_lower_lim,
        action_upper_lim,
        action_lower_delta_lim,
        action_upper_delta_lim,
        noise_type='normal'):
        """Sample ``n_sample`` perturbed copies of ``init_act_seq``.

        Temporally filtered Gaussian noise (mixing factor ``args.beta_filter``)
        is added to the step-to-step action deltas. Deltas and the resulting
        absolute actions are clipped to the given limits.

        Returns:
            Array of shape [n_sample, n_his + n_look_ahead - 1, action_dim].

        Raises:
            ValueError: on an unknown ``noise_type`` or an unsupported env.
        """
        if noise_type != "normal":
            raise ValueError("unknown noise type: %s" % (noise_type))

        init_act_seq = np.asarray(init_act_seq)
        action_dim = init_act_seq.shape[-1]
        beta_filter = self.args.beta_filter

        # act_seqs: [n_sample, N, action_dim]
        # act_seqs_delta: [n_sample, N - 1, action_dim]
        act_seqs = np.repeat(init_act_seq[None], n_sample, axis=0)
        act_seqs_delta = np.repeat(
            (init_act_seq[1:] - init_act_seq[:-1])[None], n_sample, axis=0)

        # low-pass filtered noise carried across time steps: [n_sample, action_dim]
        act_residual = np.zeros([n_sample, action_dim])

        # Only future actions get noise: init_act_seq[:(n_his - 1)] are past
        # actions and act_seq[n_his - 1] is the action being optimized for the
        # current step, which is reached through delta index n_his - 2.
        # The start index is clamped at 0: with n_his == 1 a start of -1 would
        # wrap around and overwrite act_seqs[:, 0] from the *last* step. In
        # that case the first action has no predecessor and stays unperturbed.
        first_delta = max(self.n_his - 2, 0)

        for i in range(first_delta, init_act_seq.shape[0] - 1):
            noise_sample = self._sample_action_noise(n_sample)

            act_residual = beta_filter * noise_sample + act_residual * (1. - beta_filter)
            act_seqs_delta[:, i] += act_residual

            # clip delta lim
            act_seqs_delta[:, i] = np.clip(
                act_seqs_delta[:, i], action_lower_delta_lim, action_upper_delta_lim)

            act_seqs[:, i + 1] = act_seqs[:, i] + act_seqs_delta[:, i]

            # clip absolute lim
            act_seqs[:, i + 1] = np.clip(
                act_seqs[:, i + 1], action_lower_lim, action_upper_lim)

        # act_seqs: [n_sample, -1, action_dim]
        return act_seqs

    def model_rollout(
        self,
        state_cur,      # [1, n_his, state_dim]
        model_dy,       # the learned dynamics model
        act_seqs_np,    # [n_sample, -1, action_dim]
        n_look_ahead,
        use_gpu):
        """Roll all sampled action sequences through ``model_dy``.

        The single history window ``state_cur`` is shared by every sample; at
        each step the newest prediction is appended to the window and the
        oldest entry dropped.

        Returns:
            NumPy array [n_sample, n_look_ahead, state_dim] of predicted states.
        """
        _, n_his, state_dim = state_cur.shape

        act_seqs = torch.FloatTensor(act_seqs_np)
        if use_gpu:
            act_seqs = act_seqs.cuda()
        n_sample = act_seqs.shape[0]

        # states_cur: [n_sample, n_his, state_dim]
        states_cur = state_cur.expand([n_sample, -1, -1])

        n_step_avail = act_seqs.shape[1] - n_his + 1
        assert n_look_ahead == n_step_avail, \
            "n_look_ahead (%d) does not match the action horizon (%d)" % (
                n_look_ahead, n_step_avail)

        states_pred_list = []

        # The result leaves as a NumPy array, so no autograd graph is needed.
        with torch.no_grad():
            for i in range(min(n_look_ahead, n_step_avail)):
                act_cur = act_seqs[:, i:i + n_his]

                # states_pred: [n_sample, state_dim]
                states_pred = model_dy.dynamics_prediction(states_cur, act_cur)
                states_cur = torch.cat([states_cur[:, 1:], states_pred[:, None]], 1)

                states_pred_list.append(states_pred)

            # states_pred_tensor: [n_sample, n_look_ahead, state_dim]
            states_pred_tensor = torch.stack(states_pred_list, dim=1)

        return states_pred_tensor.detach().cpu().numpy()

    def evaluate_traj(
        self,
        state_seqs,     # [n_sample, n_look_ahead, state_dim]
        state_goal,     # [state_dim]
    ):
        """Score each rollout by the negative squared distance of its final state to the goal."""
        final_err = state_seqs[:, -1] - state_goal

        # reward_seqs: [n_sample]
        return -np.sum(final_err ** 2, 1)

    def optimize_action_CEM(    # Cross Entropy Method (CEM)
        self,
        act_seqs,       # [n_sample, -1, action_dim]
        reward_seqs,    # [n_sample]
        n_elite=5       # number of best samples to average
    ):
        """Return the mean of the ``n_elite`` highest-reward action sequences.

        Raises:
            ValueError: if ``n_elite`` is smaller than 1.
        """
        if n_elite < 1:
            raise ValueError("n_elite must be >= 1, got %d" % n_elite)

        # ascending sort: the best samples sit at the end
        elite_idx = np.argsort(reward_seqs)[-n_elite:]

        # [-1, action_dim]
        return np.mean(act_seqs[elite_idx, :, :], 0)

    def optimize_action(   # Model-Predictive Path Integral (MPPI)
        self,
        act_seqs,       # [n_sample, -1, action_dim]
        reward_seqs     # [n_sample]
    ):
        """Merge the samples with softmax weights ``exp(reward_weight * reward)``.

        The exponent is shifted by its maximum before exponentiation. The
        normalized weights are mathematically the same as with the previous
        mean shift (up to the ``eps`` term), but can no longer overflow to
        inf/nan when ``reward_weight`` or the reward spread is large.
        """
        logits = self.args.reward_weight * np.asarray(reward_seqs, dtype=np.float64)
        logits = logits - np.max(logits)

        # [n_sample, 1, 1]
        weights = np.exp(logits).reshape(-1, 1, 1)

        # [-1, action_dim]
        eps = 1e-8
        return (weights * act_seqs).sum(0) / (np.sum(weights) + eps)



def _infer_device(model):
    """Return the device of ``model``'s first parameter.

    Falls back to CUDA when available (CPU otherwise) if the model exposes no
    parameters.
    """
    params_fn = getattr(model, 'parameters', None)
    first_param = next(iter(params_fn()), None) if callable(params_fn) else None
    if first_param is not None:
        return first_param.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _render_embed_image(args, model, state_embed, camera_info):
    """Render ``state_embed`` with the model's decoder and return a uint8 image.

    The last axis is reversed before returning (the result is handed to
    ``cv2.imwrite`` by the caller).
    """
    with torch.no_grad():
        if args.nerf_loss == 1:
            rgb, _ = model.render_imgs(state_embed[:, 0], camera_info)
            rgb_np = rgb.detach().cpu().numpy().clip(0., 1.) * 255
            return rgb_np.astype(np.uint8)[..., ::-1][0, 0]

        if args.auto_loss == 1:
            rgb = model.decode_img(state_embed, camera_info)
            rgb_np = rgb[0, 0, 0].permute(1, 2, 0).detach().cpu().numpy().clip(0., 1.)
            return (rgb_np * 255).astype(np.uint8)[..., ::-1]

    raise AssertionError("Unsupported loss combination")


def _optimize_state_embed(args, model, engine, state_embed, imgs, poses, state_embed_gt=None):
    """Refine ``state_embed`` by gradient descent on an image reconstruction loss.

    Writes ``optim_target.png``, ``optim_before.png`` and ``optim_after.png``
    into ``args.mpcf`` and returns the optimized embedding (a leaf tensor).
    """
    use_nerf = args.nerf_loss == 1
    if not (use_nerf or args.auto_loss == 1):
        # Without a decoder there is nothing to render or to optimize against.
        raise AssertionError("Unsupported loss combination")

    # Fresh float32 leaf tensor, on the same device, that Adam can update.
    state_embed = state_embed.detach().clone().float().requires_grad_(True)

    optimizer_emb = torch.optim.Adam(
        [state_embed], lr=args.lrate_optim_goal, betas=(0.9, 0.999))

    projMatrix = engine.get_projMatrix()
    focal = projMatrix[0, 0]
    # NOTE(review): the render size is hard-coded to 180x180 here and is not
    # derived from `imgs` -- confirm against the caller's render settings.
    H, W = 180, 180
    focal = focal * 0.5 * W
    hwf = torch.FloatTensor([H, W, focal])

    camera_info = {
        'hwf': hwf,
        'poses': poses[:, 0] if use_nerf else poses,
        'near': args.near + args.goal_camera_dist_offset,
        'far': args.far + args.goal_camera_dist_offset}

    # store the target image (first view)
    print("Render the target image ...")
    target_np = imgs[0, 0, 0].permute(1, 2, 0).detach().cpu().numpy() * 255.
    target_np = target_np.astype(np.uint8)[..., ::-1]
    cv2.imwrite(os.path.join(args.mpcf, 'optim_target.png'), target_np)

    # render the image before optimization
    print("Render the image before optimization ...")
    cv2.imwrite(
        os.path.join(args.mpcf, 'optim_before.png'),
        _render_embed_image(args, model, state_embed, camera_info))

    for iter_optim in range(args.n_optim_iter_goal):

        # Build the loss inside the grad-enabled scope so the graph exists
        # even if a caller has disabled autograd around this function.
        with torch.set_grad_enabled(True):
            if use_nerf:
                rgb, _, target_s = model.render_rays(
                    state_embed, camera_info, imgs)
            else:
                rgb = model.decode_img(state_embed, camera_info)
                target_s = imgs

            img_loss = torch.mean((rgb - target_s) ** 2)

        optimizer_emb.zero_grad()
        img_loss.backward()
        optimizer_emb.step()

        # Project back to unit norm along the feature axis, matching the
        # dim=2 normalization used when the embedding is first computed. For
        # an embedding whose first two axes have size 1 this equals the
        # previous whole-tensor norm.
        with torch.no_grad():
            state_embed.div_(state_embed.norm(dim=2, keepdim=True) + 1e-8)

        if iter_optim % 5 == 0:
            if state_embed_gt is not None:
                print('[%d/%d] Loss: %.6f, Dist: %.6f' % (
                    iter_optim, args.n_optim_iter_goal, img_loss.item(),
                    torch.mean((state_embed - state_embed_gt) ** 2).item()))
            else:
                # no reference embedding supplied: report the loss only
                print('[%d/%d] Loss: %.6f' % (
                    iter_optim, args.n_optim_iter_goal, img_loss.item()))

    # render the image after optimization
    print("Render the image after optimization ...")
    cv2.imwrite(
        os.path.join(args.mpcf, 'optim_after.png'),
        _render_embed_image(args, model, state_embed, camera_info))

    return state_embed


def calc_state_embed(
    args, model, imgs, poses, engine=None,
    optim_embed=False, state_embed_gt=None):
    """Encode multi-view images into a state embedding.

    Args:
        args: config namespace. Reads ``half_res``, ``nerf_loss``,
            ``auto_loss`` and ``ct_loss``; when ``optim_embed`` is set also
            ``lrate_optim_goal``, ``n_optim_iter_goal``, ``near``, ``far``,
            ``goal_camera_dist_offset`` and ``mpcf``.
        model: provides ``encode_img`` and, depending on the loss flags,
            ``encode_state`` / ``render_imgs`` / ``render_rays`` /
            ``decode_img``.
        imgs: sequence of H x W x C images in the 0..255 range, one per view.
        poses: one pose matrix per view (presumably camera-to-world --
            TODO confirm against the caller).
        engine: simulator exposing ``get_projMatrix()``; required when
            ``optim_embed`` is True.
        optim_embed: if True, refine the embedding by minimizing an image
            reconstruction loss and dump before/after renders to ``args.mpcf``.
        state_embed_gt: optional reference embedding, used only to log the
            distance to it during optimization.

    Returns:
        ``(state_embed, img_embeds)``; ``state_embed`` is detached.

    Raises:
        ValueError: if ``optim_embed`` is True but ``engine`` is None.
        AssertionError: for an unsupported combination of loss flags.
    """
    if optim_embed and engine is None:
        raise ValueError("engine is required when optim_embed is True")

    # Follow the model's device rather than assuming CUDA is present.
    device = _infer_device(model)

    # scale to [0, 1]; optionally downsample each view by 2
    imgs = (np.array(imgs) / 255.).astype(np.float32)

    H, W, _ = imgs[0].shape
    if args.half_res:
        imgs = np.array([
            cv2.resize(img, (W // 2, H // 2), interpolation=cv2.INTER_AREA)
            for img in imgs])

    # poses: [1, 1, n_view, ...]; imgs: [1, 1, n_view, C, H, W]
    poses = torch.as_tensor(
        np.asarray(poses), dtype=torch.float32)[None, None, ...].to(device)
    imgs = torch.as_tensor(
        imgs, dtype=torch.float32).permute(0, 3, 1, 2)[None, None, ...].to(device)

    with torch.no_grad():
        img_embeds = model.encode_img(imgs)

        if args.nerf_loss == 1 or args.auto_loss == 1:
            _, N, n_view_enc = img_embeds.size()[:3]
            mask_nodes = torch.ones(1, N, n_view_enc, device=device)
            state_embed = model.encode_state(
                img_embeds,
                poses,
                mask_nodes)

        elif args.nerf_loss == 0 and args.auto_loss == 0 and args.ct_loss == 1:
            # contrastive-only model: average the per-view embeddings (axis 2)
            # and normalize to unit length along the feature axis
            state_embed = torch.mean(img_embeds, 2)
            state_embed_norm = state_embed.norm(dim=2, keepdim=True) + 1e-8
            state_embed = state_embed / state_embed_norm

        else:
            raise AssertionError("Unsupported loss combination")

    if optim_embed:
        state_embed = _optimize_state_embed(
            args, model, engine, state_embed, imgs, poses, state_embed_gt)

    return state_embed.detach(), img_embeds



def eval():

    args = gen_args()

    # set_seed(args.seed)
    # set_seed((round(time.time() * 1000)) % 2**32)

    # candidate seed: 42, 17
    set_seed(42)

    use_gpu = torch.cuda.is_available()

    train_view = 0
    inter_view = 0
    extra_view = 1

    assert train_view + inter_view + extra_view == 1

    if extra_view:
        args.goal_camera_dist_offset = -1.5

    if args.env == 'FluidManipClip':
        engine = FluidManipEngine(args)

        cam_dis = 5.
        cam_height = 1.5
        cam_pitch_angle = 0.
        cam_pitch_angle_offset = -20.
        cam_yaw_offset = 0.

    elif args.env == 'FluidShakeWithIce_1000':
        engine = FluidShakeWithIceEngine(args)

        cam_dis = 3.5
        cam_height = 1.6
        cam_pitch_angle = -22.
        cam_pitch_angle_offset = -10.
        cam_yaw_offset = 180.

    elif args.env == 'FluidManipClip_wKuka_wColor':
        engine = FluidManipFullEngine(args)

        table_height = 1.2
        cam_dis = 5.
        cam_height = 1.5 + table_height
        cam_pitch_angle = 0.
        cam_pitch_angle_offset = -20.
        cam_yaw_offset = 0.

    elif args.env == 'FluidShakeWithIce_wKuka_wColor_wGripper':
        engine = FluidShakeWithIceFullEngine(args)

        table_height = 0.1
        cam_dis = 3.5
        cam_height = 1.6 + table_height
        cam_pitch_angle = -22.
        cam_pitch_angle_offset = -10.
        cam_yaw_offset = 180.

    else:
        raise AssertionError("Unsupported env: %s" % args.env)




    ### make log dir
    os.system('mkdir -p ' + args.mpcf)

    log_fout = open(os.path.join(args.mpcf, 'log.txt'), 'w')

    print(args)
    print(args, file=log_fout)


    ### create model
    model = DynamicsModel(args)
    print("model #params: %d" % count_parameters(model))

    # resume training of a saved model (if given)
    if args.eval_epoch == -1:
        model_path = os.path.join(args.outf.replace('mpc', 'dy'), 'net_best.pth')

        '''
        pretrained_dict = torch.load(model_path)
        model_dict = model.state_dict()

        # only load parameters in encoder-decoder
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() \
            if ('img_encoder' in k or 'tf_encoder' in k or 'decoder' in k) and k in model_dict}
        model.load_state_dict(pretrained_dict, strict=False)
        '''

    else:
        model_path = os.path.join(args.outf.replace('mpc', 'dy'), 'net_epoch_%d_iter_%d.pth' % (
            args.eval_epoch, args.eval_iter))

    print("Loading saved ckp from %s" % model_path)
    pretrained_dict = torch.load(model_path)
    model.load_state_dict(pretrained_dict)

    if args.nerf_loss == 1:
        # !!! need to double-check these parameters
        near = args.near
        far = args.far

        bds_dict = {
            'near' : near,
            'far' : far,
        }
        model.decoder.render_kwargs_train.update(bds_dict)
        model.decoder.render_kwargs_test.update(bds_dict)

    if use_gpu:
        model = model.cuda()

    model.train(False)

    # Misc
    img2mse = lambda x, y : torch.mean((x - y) ** 2)
    mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
    to8b = lambda x : (255 * np.clip(x, 0, 1)).astype(np.uint8)
    chamfer_loss = ChamferLoss()


    ### rollout to generate the goal state

    engine.init()
    scene_params = engine.scene_params.copy()
    context = engine.context.copy()


    '''
    video_path = os.path.join(args.mpcf, 'viz_eval_mpc_goal.avi')
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(video_path, fourcc, 30, (360, 360))

    out_dir = os.path.join(args.mpcf, 'mpc_goal')
    os.system('mkdir -p ' + out_dir)
    '''

    actions_gt = []

    # adjust the camera pose
    rad = np.deg2rad(22.5 + cam_yaw_offset)
    camPos_ref = np.array([np.sin(rad) * cam_dis, cam_height, np.cos(rad) * cam_dis])
    camAngle_ref = np.array([rad, np.deg2rad(cam_pitch_angle), 0.])


    for i in range(args.ctrl_init_idx + args.n_look_ahead):

        if args.env in ['FluidManipClip', 'FluidManipClip_wKuka_wColor']:
            n_stay_still = 40
            nominal_angle = 65
            max_angle = 110

            pourer_lim_x = engine.pourer_lim_x
            pourer_lim_z = engine.pourer_lim_z

            if i < n_stay_still:
                angle_cur = 0.
                pourer_angle_delta = 0.
                pourer_pos_delta = np.zeros(3)
                pourer_pos = engine.pourer_pos.copy()
            else:
                # pourer x position
                scale = 0.002
                pourer_pos_delta[0] += rand_float(-scale, scale) - (pourer_pos[0] - np.sum(pourer_lim_x) / 2.) * scale
                pourer_pos_delta[0] = np.clip(pourer_pos_delta[0], -0.01, 0.01)
                pourer_pos[0] += pourer_pos_delta[0]
                pourer_pos[0] = np.clip(pourer_pos[0], pourer_lim_x[0], pourer_lim_x[1])

                # pourer z position
                scale = 0.003
                pourer_pos_delta[2] += rand_float(-scale, scale) - (pourer_pos[2] - np.sum(pourer_lim_z) / 2.) * scale
                pourer_pos_delta[2] = np.clip(pourer_pos_delta[2], -0.01, 0.01)
                pourer_pos[2] += pourer_pos_delta[2]
                pourer_pos[2] = np.clip(pourer_pos[2], pourer_lim_z[0], pourer_lim_z[1])

                # pourer angle
                scale = 0.3
                angle_idx_cur = i - n_stay_still
                pourer_angle_delta += rand_float(-scale, scale) - (angle_cur - nominal_angle) * 0.002
                pourer_angle_delta = np.clip(pourer_angle_delta, -1.2, 1.2)
                angle_cur += pourer_angle_delta

            pourer_angle = np.deg2rad(max(-angle_cur, -max_angle))

        elif args.env in ['FluidShakeWithIce_1000', 'FluidShakeWithIce_wKuka_wColor_wGripper']:

            if i == 0:
                dx_box = 0.
                dz_box = 0.
                scale = 0.09
                dt = 1. / 60.

                x_box = engine.x_box
                z_box = engine.z_box

            x_box += dx_box * dt
            dx_box += rand_float(-scale, scale) - x_box * scale

            z_box += dz_box * dt
            dz_box += rand_float(-scale, scale) - z_box * scale



        # render
        # img, _, _ = engine.render_img(camPos_ref, camAngle_ref)
        # out.write(img.astype(np.uint8))
        # cv2.imwrite(os.path.join(out_dir, "%d.png" % i), img.astype(np.uint8))


        if i == args.ctrl_init_idx + args.n_look_ahead - 1:

            img, _, _ = engine.render_img()
            cv2.imwrite(os.path.join(args.mpcf, 'viz_mpc_goal.png'), img)

            imgs = []
            poses = []

            cam_dis_adjust = cam_dis + \
                    args.goal_camera_dist_offset * np.cos(np.deg2rad(-cam_pitch_angle))
            cam_height_adjust = cam_height + \
                    args.goal_camera_dist_offset * np.sin(np.deg2rad(-cam_pitch_angle))

            if extra_view == 1:
                cam_height_adjust += 0.4

            # for j in range(args.n_view_enc):
            for j in range(1):
                if train_view == 1:
                    rad_adjust = np.deg2rad(rand_int(0, 20) * 9. + 4.5 + cam_yaw_offset)
                    # rad_adjust = np.deg2rad(4 * 9. + 4.5 + cam_yaw_offset)
                elif inter_view == 1 or extra_view == 1:
                    rad_adjust = np.deg2rad(rand_float(4.5, 180. - 4.5) + cam_yaw_offset)
                    # rad_adjust = np.deg2rad(160 + cam_yaw_offset)
                    # rad_adjust = np.deg2rad(120 + cam_yaw_offset)
                    # rad_adjust = np.deg2rad(80 + cam_yaw_offset)

                # rad_adjust = np.deg2rad(180. - 22.5 + cam_yaw_offset)
                # rad_adjust = np.deg2rad(22.5 + cam_yaw_offset)
                # rad = np.deg2rad(j * 45 + 22.5 + cam_yaw_offset)
                camPos = np.array([
                    np.sin(rad_adjust) * cam_dis_adjust,
                    cam_height_adjust,
                    np.cos(rad_adjust) * cam_dis_adjust])
                camAngle = np.array([
                    rad_adjust,
                    np.deg2rad(cam_pitch_angle + (cam_pitch_angle_offset if extra_view == 1 else 0.)),
                    0.])

                img, viewMatrix, projMatrix = engine.render_img(
                    camPos, camAngle, width=180, height=180, BGR2RGB=True)

                cv2.imwrite(os.path.join(args.mpcf, 'viz_mpc_goal_%d.png' % j), img[..., ::-1])

                focal = projMatrix[0, 0]

                imgs.append(img)
                poses.append(np.linalg.inv(np.transpose(viewMatrix)))


            imgs_gt = []
            poses_gt = []

            for j in range(args.n_view_enc):
                # for j in range(1):
                # rad = np.deg2rad(rand_float(0., 180.))
                rad = np.deg2rad(j * 45 + 22.5 + cam_yaw_offset)
                camPos = np.array([np.sin(rad) * cam_dis, cam_height, np.cos(rad) * cam_dis])
                camAngle = np.array([rad, np.deg2rad(cam_pitch_angle), 0.])

                img, viewMatrix, projMatrix = engine.render_img(
                    camPos, camAngle, width=180, height=180, BGR2RGB=True)

                cv2.imwrite(os.path.join(args.mpcf, 'viz_mpc_goal_gt_%d.png' % j), img[..., ::-1])

                focal = projMatrix[0, 0]

                imgs_gt.append(img)
                poses_gt.append(np.linalg.inv(np.transpose(viewMatrix)))


        else:
            # apply action
            if args.env in ['FluidManipClip', 'FluidManipClip_wKuka_wColor']:
                action = np.concatenate([pourer_pos, np.array([pourer_angle])])
            elif args.env in ['FluidShakeWithIce_1000', 'FluidShakeWithIce_wKuka_wColor_wGripper']:
                action = np.array([x_box, z_box])

            actions_gt.append(action)

            engine.set_action(action)
            engine.step()


    # out.release()


    '''
    ### rollout the simulator to reproduce the goal image

    video_path = os.path.join(args.mpcf, 'viz_eval_repo.avi')
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(video_path, fourcc, 30, (360, 360))

    engine.init(scene_params, context)
    for i in range(args.ctrl_init_idx + args.n_look_ahead - 1):

        img, _, _ = engine.render_img(camPos_ref, camAngle_ref)
        out.write(img.astype(np.uint8))

        engine.set_action(actions_gt[i])
        engine.step()

        if i == args.ctrl_init_idx + args.n_look_ahead - 2:
            img, _, _ = engine.render_img()
            cv2.imwrite(os.path.join(args.mpcf, 'viz_mpc_goal_repo.png'), img)

    out.release()
    exit(0)
    '''

    particles_goal = engine.get_state() # particles_goal: n_particles x 3
    # visualize_point_cloud(particles_goal, store_path='test.avi')

    actions_gt = np.array(actions_gt)
    assert actions_gt.shape[0] == args.ctrl_init_idx + args.n_look_ahead - 1

    state_embed_gt, _ = calc_state_embed(
        args, model, imgs_gt, poses_gt)

    if args.ct_loss == 1 and args.nerf_loss == 1 and extra_view == 1:
        optim_embed = True
    else:
        optim_embed = False
    # optim_embed = False
    state_embed_goal, _ = calc_state_embed(
        args, model, imgs, poses, engine, optim_embed=optim_embed,
        state_embed_gt=state_embed_gt)


    print('state_embed_goal.shape', state_embed_goal.size())
    print('actions_gt.shape', actions_gt.shape)



    ### generate initial state_embed
    state_embeds = []
    actions = []

    engine.init(scene_params, context)

    for i in range(args.ctrl_init_idx - args.n_his):
        actions.append(actions_gt[i])
        engine.set_action(actions_gt[i])
        engine.step()

    for i in range(args.ctrl_init_idx - args.n_his, args.ctrl_init_idx):
        imgs = []
        poses = []

        # for j in range(args.n_view_enc):
        for j in range(1):
            # rad = np.deg2rad(rand_float(0., 180.))
            rad = np.deg2rad(j * 45 + 22.5 + cam_yaw_offset)
            # rad = np.deg2rad(180. - 22.5 + cam_yaw_offset)
            camPos = np.array([np.sin(rad) * cam_dis, cam_height, np.cos(rad) * cam_dis])
            camAngle = np.array([rad, np.deg2rad(cam_pitch_angle), 0.])

            img, viewMatrix, projMatrix = engine.render_img(
                camPos, camAngle, width=180, height=180, BGR2RGB=True)

            imgs.append(img)
            poses.append(np.linalg.inv(np.transpose(viewMatrix)))

        state_embed, _ = calc_state_embed(args, model, imgs, poses)

        state_embeds.append(state_embed)

        if i < args.ctrl_init_idx - 1:
            engine.set_action(actions_gt[i])
            engine.step()
            actions.append(actions_gt[i])


    """
    # used for visualizing the distance between the current embedding and the goal embedding

    video_path = os.path.join(args.mpcf, 'viz_eval_embed_dist.avi')
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(video_path, fourcc, 30, (900, 360))

    imgs_record = []
    img_embeds_record = []

    for i in range(args.ctrl_init_idx - args.n_his, args.ctrl_init_idx + args.n_look_ahead - 1):
        imgs = []
        poses = []

        for j in range(args.n_view_enc):
            # rad = np.deg2rad(rand_float(0., 180.))
            rad = np.deg2rad(j * 45 + 22.5 + cam_yaw_offset)
            camPos = np.array([np.sin(rad) * cam_dis, cam_height, np.cos(rad) * cam_dis])
            camAngle = np.array([rad, np.deg2rad(cam_pitch_angle), 0.])

            img, viewMatrix, projMatrix = engine.render_img(
                camPos, camAngle, width=180, height=180, BGR2RGB=True)

            '''
            print(img.shape)
            cv2.imwrite('test.png', img)
            cv2.imshow("image", img)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
            '''

            imgs.append(img)
            poses.append(np.linalg.inv(np.transpose(viewMatrix)))

            if j == 0:
                img, _, _ = engine.render_img(camPos, camAngle)
                imgs_record.append(img)

        with torch.set_grad_enabled(False):
            # state_embed: B x N x nf_hidden (B = N = 1)
            # img_embed: B x N x n_view x nf_hidden (B = N = 1)
            state_embed, img_embed = calc_state_embed(args, model, imgs, poses)
            img_embeds_record.append(img_embed.data.cpu().numpy()[0, 0])

        state_embeds.append(state_embed)

        engine.set_action(actions_gt[i])
        engine.step()
        actions.append(actions_gt[i])


    anchor = img_embeds_record[-1][0:1]
    positive = img_embeds_record[-1][1:].reshape(-1, args.nf_hidden)
    negative = np.array(img_embeds_record[:-1]).reshape(-1, args.nf_hidden)
    print(anchor.shape)
    print(positive.shape)
    print(negative.shape)

    print(np.mean((anchor - positive)**2, -1))
    print(np.mean((anchor - negative)**2, -1).reshape(-1, 4))


    for i in range(len(state_embeds)):
        dist = torch.mean((state_embeds[i] - state_embeds[-1])**2).item()
        print(i, dist)

        fig, ax  = plt.subplots(1, 1, figsize=(1.8, 3.6), dpi=200)
        ax.bar([0], [dist], width=0.5)
        plt.ylim([0., 0.8])
        plt.xlim([-0.5, 0.5])
        plt.xlabel('Distance to Goal', fontsize=12)
        plt.tight_layout(pad=0)

        fig.canvas.draw()
        frame = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
        frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        frame = cv2.resize(frame, (180, 360), interpolation=cv2.INTER_AREA)

        frame = np.concatenate([imgs_record[-1], imgs_record[i], frame], 1).astype(np.uint8)

        out.write(frame)

        plt.close()

    out.release()

    exit(0)
    """


    """
    ### rollout and render using the ground truth action sequence

    video_path = os.path.join(args.mpcf, 'viz_eval_mpc_rollout.avi')
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(video_path, fourcc, 30, (180, 180))

    out_dir = os.path.join(args.mpcf, 'mpc_rollout')
    os.system('mkdir -p ' + out_dir)


    H, W = 180, 180
    focal = focal * 0.5 * W
    hwf = torch.FloatTensor([H, W, focal])
    print(hwf)

    # adjust the camera pose
    rad = np.deg2rad(20. + cam_yaw_offset)
    camPos = np.array([np.sin(rad) * cam_dis, cam_height, np.cos(rad) * cam_dis])
    camAngle = np.array([rad, np.deg2rad(cam_pitch_angle), 0.])
    img, viewMatrix, projMatrix = engine.render_img(camPos, camAngle)
    pose_ori = torch.FloatTensor(np.linalg.inv(np.transpose(viewMatrix)))

    print(len(state_embeds))
    print(len(actions_gt))
    state_cur = torch.cat(state_embeds[:args.n_his], 1)

    st_idx = args.ctrl_init_idx
    ed_idx = args.ctrl_init_idx + args.n_look_ahead

    for i in range(st_idx, ed_idx):

        '''
        engine.set_action(actions_gt[i - 1])
        engine.step()
        img, _, _ = engine.render_img()

        if i == ed_idx - 1:
            cv2.imwrite('test.png', img)
        '''

        action_cur = actions_gt[i-args.n_his:i]
        action_cur = torch.FloatTensor(action_cur)[None, ...].cuda()

        with torch.set_grad_enabled(False):
            state_pred = model.dynamics_prediction(state_cur, action_cur)
            state_cur = torch.cat([state_cur[:, 1:], state_pred[:, None]], 1)

            loss_dy_cur = F.mse_loss(state_pred, state_embeds[args.n_his + i - st_idx][0])
            print('%d / %d: %.6f' % (i, ed_idx, loss_dy_cur.item()))

            camera_info = {
                'hwf': hwf,
                'poses': pose_ori[None, None, ...].cuda(),
                'near': near,
                'far': far}

            rgb, extras = model.render_imgs(state_pred, camera_info)
            # rgb, extras = model.render_imgs(state_embeds[args.n_his + i - st_idx][0], camera_info)

        # store the data
        rgb_np = rgb.data.cpu().numpy().clip(0., 1.) * 255
        rgb_np = rgb_np.astype(np.uint8)[..., ::-1][0, 0]

        out.write(rgb_np)
        cv2.imwrite(os.path.join(out_dir, '%d.png' % i), rgb_np)

    out.release()
    exit(0)
    """


    ### MPPI

    # important variables
    # state_embeds: list [n_his], (1, 1, 256)
    # state_embed_goal: (1, 1, 256)
    # actions: list [n_his],

    planner = Planner(args)

    # replicate the action to generate the initial sequence
    for i in range(args.n_look_ahead):
        # actions.append(actions[-1] + (actions_gt[-1] - actions_gt[0]) / args.n_look_ahead)
        actions.append(actions[-1])
    # actions = np.array(actions_gt)
    actions = np.array(actions)

    action_lower_lim = np.min(actions_gt[args.ctrl_init_idx-1:], 0)
    action_upper_lim = np.max(actions_gt[args.ctrl_init_idx-1:], 0)
    action_lim_range = action_upper_lim - action_lower_lim
    action_lower_lim -= action_lim_range * 0.1
    action_upper_lim += action_lim_range * 0.1
    print('action_lower_lim', action_lower_lim)
    print('action_upper_lim', action_upper_lim)

    action_lower_delta_lim = np.min(actions_gt[args.ctrl_init_idx:] - actions_gt[args.ctrl_init_idx-1:-1], 0)
    action_upper_delta_lim = np.max(actions_gt[args.ctrl_init_idx:] - actions_gt[args.ctrl_init_idx-1:-1], 0)
    action_delta_lim_range = action_upper_delta_lim - action_lower_delta_lim
    action_lower_delta_lim -= action_delta_lim_range * 0.1
    action_upper_delta_lim += action_delta_lim_range * 0.1
    print('action_lower_delta_lim', action_lower_delta_lim)
    print('action_upper_delta_lim', action_upper_delta_lim)


    st_idx = args.ctrl_init_idx
    ed_idx = args.ctrl_init_idx + args.n_look_ahead

    for i in range(st_idx, ed_idx):

        print("\n### Step %d/%d" % (i, ed_idx))

        # optimize the action sequence
        embed_cur = torch.cat(state_embeds[-args.n_his:], 1)

        if i == st_idx or i % args.n_update_delta == 0:
            # update the action sequence every n_update_delta iterations
            with torch.set_grad_enabled(False):
                assert len(actions[i - args.n_his:]) == \
                        args.n_look_ahead - (i - args.ctrl_init_idx) + args.n_his - 1

                actions[i-args.n_his:] = planner.trajectory_optimization(
                    state_cur=embed_cur,
                    state_goal=state_embed_goal[0, 0].data.cpu().numpy(),
                    model_dy=model,
                    act_seq=actions[i-args.n_his:],
                    n_sample=args.n_sample,
                    n_look_ahead=args.n_look_ahead - (i - args.ctrl_init_idx),
                    n_update_iter=args.n_update_iter_init if i == st_idx else args.n_update_iter,
                    action_lower_lim=action_lower_lim,
                    action_upper_lim=action_upper_lim,
                    action_lower_delta_lim=action_lower_delta_lim,
                    action_upper_delta_lim=action_upper_delta_lim,
                    use_gpu=use_gpu)

        # execute the action in the simulator
        engine.set_action(actions[i - 1])
        engine.step()

        print(len(actions_gt[i - 1:]), actions_gt[i - 1])
        print(len(actions[i - 1:]), actions[i - 1])



        # new observation embedding
        imgs = []
        poses = []

        # for j in range(args.n_view_enc):
        for j in range(1):
            rad = np.deg2rad(j * 45 + 22.5 + cam_yaw_offset)
            # rad = np.deg2rad(180. - 22.5 + cam_yaw_offset)
            camPos = np.array([np.sin(rad) * cam_dis, cam_height, np.cos(rad) * cam_dis])
            camAngle = np.array([rad, np.deg2rad(cam_pitch_angle), 0.])

            img, viewMatrix, projMatrix = engine.render_img(
                camPos, camAngle, width=180, height=180, BGR2RGB=True)

            imgs.append(img)
            poses.append(np.linalg.inv(np.transpose(viewMatrix)))

        state_embed, _ = calc_state_embed(args, model, imgs, poses)
        state_embeds.append(state_embed)





    ### Generate execution video by replaying the optimized action sequence

    out_dir = os.path.join(args.mpcf, 'mpc_mppi')
    os.system('mkdir -p ' + out_dir)

    video_path = os.path.join(args.mpcf, 'viz_eval_mpc_mppi.avi')
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(video_path, fourcc, 30, (360, 360))

    engine.init(scene_params, context)

    # adjust the camera pose
    rad = np.deg2rad(22.5 + cam_yaw_offset)
    # rad = np.deg2rad(180. - 22.5 + cam_yaw_offset)
    camPos = np.array([np.sin(rad) * cam_dis, cam_height, np.cos(rad) * cam_dis])
    camAngle = np.array([rad, np.deg2rad(cam_pitch_angle), 0.])

    for i in range(args.ctrl_init_idx + args.n_look_ahead - 1):

        img, _, _ = engine.render_img(camPos, camAngle)
        out.write(img.astype(np.uint8))

        cv2.imwrite(os.path.join(out_dir, '%d.png' % i), img)

        engine.set_action(actions[i])
        engine.step()

        if i == args.ctrl_init_idx + args.n_look_ahead - 2:
            img, _, _ = engine.render_img(camPos, camAngle)
            out.write(img.astype(np.uint8))

            cv2.imwrite(os.path.join(out_dir, '%d.png' % (i + 1)), img)

            camPos = np.array([
                np.sin(rad_adjust) * cam_dis_adjust,
                cam_height_adjust,
                np.cos(rad_adjust) * cam_dis_adjust])
            camAngle = np.array([
                rad_adjust,
                np.deg2rad(cam_pitch_angle + (cam_pitch_angle_offset if extra_view == 1 else 0.)),
                0.])

            img, _, _ = engine.render_img(camPos, camAngle)
            cv2.imwrite(os.path.join(args.mpcf, 'viz_mpc_result.png'), img)

    out.release()

    particles_cur = engine.get_state()  # particles_cur: n_particles x 3


    print(actions_gt[-1])
    print(actions[-1])

    if args.env == 'FluidManipClip':
        print('position dist:', np.sqrt(np.sum((actions_gt[-1, :3] - actions[-1, :3])**2)))
        print('angle dist:', np.sqrt(np.sum((actions_gt[-1, 3] - actions[-1, 3])**2)))
        print('chamfer dist:', chamfer_loss(
            torch.FloatTensor(particles_cur),
            torch.FloatTensor(particles_goal)).item())

    elif args.env == 'FluidShakeWithIce_1000':
        print('position dist:', np.sqrt(np.sum((actions_gt[-1] - actions[-1])**2)))

        cube_pos_cur = np.mean(particles_cur[-125:], 0)
        cube_pos_goal = np.mean(particles_goal[-125:], 0)
        print('cube dist:', np.sqrt(np.sum((cube_pos_cur - cube_pos_goal)**2)))

        print('chamfer dist:', chamfer_loss(
            torch.FloatTensor(particles_cur[:-125]),
            torch.FloatTensor(particles_goal[:-125])).item())


    visualize_point_cloud(
        particles_cur,
        particles_goal,
        store_path=os.path.join(args.mpcf, 'viz_mpc_result.avi'))








# Script entry point: runs only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    # Set the process-wide default tensor type to a CUDA float tensor, so
    # tensors created without an explicit type/device default to the GPU.
    # This must happen before the evaluation routine runs.
    # NOTE(review): presumably requires a CUDA-capable device — confirm.
    # NOTE(review): `torch.set_default_tensor_type` is deprecated in recent
    # PyTorch releases; confirm the targeted torch version before changing it.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # Invoke the evaluation routine. It is called with no arguments, so this
    # is presumably the module-level `eval` defined earlier in this file
    # (which shadows the Python builtin of the same name) — verify.
    eval()
