"""
Script for running VisMPC and analytic policies in sim.
"""
import subprocess
import pkg_resources
import numpy as np
import argparse
import os
from os.path import join
import sys
import time
import logging
import pickle
import datetime
import cv2
import json
import math
import copy

# from gym_cloth.envs import ClothEnv
from collections import defaultdict
from GNS.fabric_vsf.vismpc.mpc import VISMPC, VISMPC_MASK_ONE_STEP
# from GNS.fabric_vsf.vismpc.cost_functions import coverage, L2, SSIM
from GNS.fabric_vsf.vismpc.visualize import Viz
from softgym.registered_env import env_arg_dict
from softgym.registered_env import SOFTGYM_ENVS
# from GNS.fabric_vsf.scripts.collect_data_softgym import denormalize_action
# from GNS.contrastive_forward_model.cfm.evaluate_planning_softgym import add_arrow, set_picker_underground
import pyflex
from GNS.camera_utils import get_world_coor_from_image, get_matrix_world_to_camera, project_to_image
from chester import logger
from softgym.envs.bimanual_env import BimanualEnv
from softgym.envs.bimanual_tshirt import BimanualTshirtEnv

# Print NumPy arrays compactly: wide lines, fixed-point floats, 10 items per edge.
np.set_printoptions(edgeitems=10, linewidth=180, suppress=True)

# Names of the available demonstrator / analytic policies.  Adi: 'oracle_reveal'
# is a demonstrator policy that reveals occluded corners.
POLICIES = [
    'oracle',
    'harris',
    'wrinkle',
    'highest',
    'random',
    'oracle_reveal',
]

# Angle conversion factors.
RAD_TO_DEG = 180. / np.pi
DEG_TO_RAD = np.pi / 180.

# Drawing colours, ordered for OpenCV's BGR channel convention.
BLUE = (255, 0, 0)
GREEN = (0, 255, 0)
RED = (0, 0, 255)

class Policy(object):
    """Base class for cloth-manipulation policies.

    Subclasses implement :meth:`get_action`.  :meth:`set_env_cfg` stores the
    environment and configuration object the policy acts in.
    """

    def __init__(self):
        pass

    def get_action(self, obs, t):
        """Return the action for observation ``obs`` at time step ``t``.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError()

    def set_env_cfg(self, env, cfg):
        """Remember the environment and configuration used by this policy."""
        self.env = env
        self.cfg = cfg

    def _data_delta(self, pt, targx, targy, shrink=True, shrink_factor=0.90):
        """Given pt and target locations, return info needed for action.

        Assumes DELTA actions.

        Args:
            pt: object with ``x`` and ``y`` attributes (the current point).
            targx, targy: target location.
            shrink: if True, scale the delta ``(dx, dy)`` by ``shrink_factor``.
            shrink_factor: scale applied to the delta when ``shrink`` is True
                (default 0.90, the previously hard-coded value).

        Returns:
            ``(x, y, cx, cy, dx, dy, dist)`` where ``x, y`` are the current
            point, ``cx, cy`` are ``x, y`` mapped from [0, 1] to [-1, 1] (i.e.
            *expanded*, despite being used where a "clipped" value is
            expected), ``dx, dy`` is the (possibly shrunk) offset to the
            target, and ``dist`` is the Euclidean distance to the target,
            computed before any shrinking.
        """
        x, y = pt.x, pt.y
        cx = (x - 0.5) * 2.0
        cy = (y - 0.5) * 2.0
        dx = targx - x
        dy = targy - y
        dist = np.sqrt((x - targx) ** 2 + (y - targy) ** 2)
        # Sometimes we grab the top, and can 'over-pull' toward a background
        # corner, so the pull is reduced a bit.  0.95 was used for true
        # corners; pulling a corner 'inwards' may want a smaller value such
        # as 0.9 (the default).  Tune via ``shrink_factor``.
        if shrink:
            dx *= shrink_factor
            dy *= shrink_factor
        return (x, y, cx, cy, dx, dy, dist)


class VisualMPCPolicy(Policy):
    """Policy whose actions come from a visual MPC planner.

    The planner (``VISMPC_MASK_ONE_STEP``) is only created in
    :meth:`set_env_cfg`, so that method must be called before
    :meth:`get_action`.

    Args:
        args: experiment arguments; ``args.adim`` is forwarded to the planner.
        num_elites, num_iters, population_size: planner hyper-parameters,
            forwarded unchanged.
        planning_method: ``'random_shooting'`` or ``'cem_original'``; selects
            which planner routine :meth:`get_action` calls.
    """

    # Values of ``planning_method`` that get_action can dispatch on.
    PLANNING_METHODS = ('random_shooting', 'cem_original')

    def __init__(self, args, num_elites, num_iters, population_size, planning_method='random_shooting'):
        super().__init__()
        self.args = args
        self.num_elites = num_elites
        self.num_iters = num_iters
        self.population_size = population_size
        self.planning_method = planning_method

    def set_env_cfg(self, env, cfg, model_name, data_name, cost_fn):
        """Store env/cfg, build the MPC planner and attach it to ``env`` as ``env.mpc``.

        ``model_name`` and ``data_name`` are passed to the planner as strings.
        The ``cost_fn`` argument is NOT used: the planner is built with an
        identity placeholder cost, presumably to be replaced through
        ``self.mpc.set_cost_function`` before planning.
        """
        self.env = env
        self.cfg = cfg
        viz = None

        placeholder_cost = lambda x: x
        self.mpc = VISMPC_MASK_ONE_STEP(placeholder_cost, '{}'.format(data_name), '{}'.format(model_name), viz=viz,
            num_elites=self.num_elites, num_iters=self.num_iters, population_size=self.population_size,
            log_dir=logger.get_dir(), adim=self.args.adim)
        self.env.mpc = self.mpc

    def get_action(self, obs, mask, matrix_world_to_camera, pull_distance_max, x_threshold,
        z_threshold, t, ep=0, adim=None):
        """Plan and return the next action for observation ``obs`` at step ``t``.

        The planner is reset when ``t == 0``.  The planner routine's return
        value is passed through unchanged.  ``ep`` is only forwarded for
        ``'random_shooting'``; ``adim`` only for ``'cem_original'``.

        Raises:
            ValueError: if ``self.planning_method`` is not in PLANNING_METHODS.
        """
        if t == 0:
            self.mpc.reset()

        if self.planning_method == 'random_shooting':
            return self.mpc.get_next_action(obs, mask, matrix_world_to_camera, pull_distance_max,
                x_threshold, z_threshold, timestep=t, ep=ep)
        if self.planning_method == 'cem_original':
            return self.mpc.get_next_action_cem_original(obs, mask, matrix_world_to_camera, pull_distance_max,
                x_threshold, z_threshold, timestep=t, adim=adim)
        raise ValueError("Unknown planning_method {!r}; expected one of {}".format(
            self.planning_method, self.PLANNING_METHODS))

def denormalize_action_old(action, imsize, pull_distance_max):
    """Legacy conversion of a normalized 4-D action into pixel/world units.

    ``action`` is ``(u, v, delta_x, delta_z)``; every component is first
    clipped to [-1, 1].  ``u`` and ``v`` are mapped to integer pixel indices
    in ``[0, imsize - 1]`` (by truncation), and the two deltas are scaled by
    ``pull_distance_max``.

    Returns:
        list ``[u, v, delta_x, delta_z]``.
    """
    u_norm, v_norm, dx_norm, dz_norm = np.clip(action, a_min=-1, a_max=1)
    upper = imsize - 1

    def to_pixel(coord):
        # Map [-1, 1] -> [0, imsize - 1], truncate, then clamp to the image.
        pixel = int((coord + 1.) * (imsize - 1) / 2.)
        return min(max(0, pixel), upper)

    return [
        to_pixel(u_norm),
        to_pixel(v_norm),
        dx_norm * pull_distance_max,
        dz_norm * pull_distance_max,
    ]

def _denormalize_pick_place(pick_norm, move_norm, imsize, max_move):
    """Convert one normalized (pick, move) pair into pixel pick/place coordinates."""
    pick_uv = (pick_norm + 1.) * (imsize - 1) / 2.
    pick_uv = np.rint(pick_uv)
    # Builtin ``int`` replaces the ``np.int`` alias (deprecated in NumPy 1.20,
    # removed in 1.24); the resulting dtype is the same.
    pick_uv = np.clip(pick_uv, 0, imsize - 1).astype(int)
    place_uv = np.rint(pick_uv + move_norm * max_move)
    return pick_uv, place_uv


def denormalize_action(action, imsize, action_range):
    """Convert a normalized pick-and-place action into pixel coordinates.

    The action is clipped to [-1, 1].  Each arm is encoded as
    ``(pick_u, pick_v, move_u, move_v)``: the pick point is mapped from
    [-1, 1] to ``[0, imsize - 1]``, rounded and clipped to the image; the
    place point is ``pick + move * action_range[1]``, rounded but NOT clipped
    (it may fall outside the image).  ``action_range[0]`` is not used.

    Args:
        action: 4 values (one arm) or 8 values (two arms).
        imsize: image side length in pixels.
        action_range: sequence whose element 1 is the maximum move in pixels.

    Returns:
        ``(pick_uv, place_uv)`` for a 4-D action, or
        ``(pick_uv, place_uv, pick_uv_2, place_uv_2)`` for an 8-D action.
        Pick points are integer arrays; place points are float arrays.

    Raises:
        ValueError: if ``action`` does not have 4 or 8 elements.
    """
    if len(action) not in (4, 8):
        raise ValueError("Expected a 4-D or 8-D action, got {} elements".format(len(action)))

    action = np.clip(action, a_min=-1, a_max=1)
    max_move = action_range[1]

    if len(action) == 4:
        return _denormalize_pick_place(action[:2], action[2:], imsize, max_move)

    pick_uv, place_uv = _denormalize_pick_place(action[0:2], action[2:4], imsize, max_move)
    pick_uv_2, place_uv_2 = _denormalize_pick_place(action[4:6], action[6:], imsize, max_move)
    return pick_uv, place_uv, pick_uv_2, place_uv_2

def add_arrow(image, start_uv, after_uv):
    """Draw an arrow from ``start_uv`` to ``after_uv`` onto ``image`` in place.

    Points are indexed as ``(col, row)``: element 0 is the column, element 1
    the row.  The arrow uses colour (255, 0, 0) and thickness 1.  A black
    2x2 marker covering the start pixel and its upper/left neighbours is
    stamped afterwards, clamped to the image; it is skipped when the start
    point lies outside the image.
    """
    start_row, start_col = int(start_uv[1]), int(start_uv[0])
    end_row, end_col = int(after_uv[1]), int(after_uv[0])
    cv2.arrowedLine(image, (start_col, start_row), (end_col, end_row), (255, 0, 0), 1)

    height, width = image.shape[:2]
    if 0 <= start_row < height and 0 <= start_col < width:
        # Clamp the lower slice bounds at 0: a start of -1 would wrap to the
        # end of the axis and yield an empty slice, hiding the marker for
        # points on the first row/column.
        row_lo = max(start_row - 1, 0)
        col_lo = max(start_col - 1, 0)
        image[row_lo:start_row + 1, col_lo:start_col + 1, :] = (0, 0, 0)

def set_picker_underground():
    """Set shapes 0 and 1 (presumably the two pickers) to coordinate -1, then step.

    pyflex's shape-state buffer is viewed as rows of 14 values per shape; for
    rows 1 and 0, columns 0:3 and 7:10 are overwritten with -1, the buffer is
    written back and one simulation step is taken.
    NOTE(review): columns 0:3 look like the current position; the meaning of
    columns 7:10 in pyflex's layout is not established here - confirm.
    """
    state_table = pyflex.get_shape_states().reshape(-1, 14)
    for shape_idx in (1, 0):
        state_table[shape_idx, :3] = -1
        state_table[shape_idx, 7:10] = -1
    pyflex.set_shape_states(state_table.flatten())
    pyflex.step()

def L2(traj, goal_img):
    """Return the Euclidean distance between ``traj`` and ``goal_img``.

    Computed as ``np.linalg.norm(traj - goal_img)``: the difference is formed
    with NumPy broadcasting, then the 2-norm is taken over *all* of its
    elements.  The result is a single scalar.  It is neither averaged nor
    restricted to the last image of a trajectory.
    """
    return np.linalg.norm(traj - goal_img)

def get_rgbd(obs, camera_height, camera_width, depth_max):
    """Render the pyflex sensor and append its depth channel to ``obs``.

    The sensor buffer is reshaped to ``(camera_height, camera_width, 4)`` and
    flipped vertically.  Channel 3 is taken as raw depth, scaled to 0..255 by
    ``depth_max``, clipped and truncated to uint8, then stacked behind the
    channels of ``obs``.

    Returns:
        ``(rgbd, depth_ori)``: the stacked image as float32, and the unscaled
        depth channel of the flipped sensor buffer.
    """
    sensor = np.array(pyflex.render_sensor()).reshape(camera_height, camera_width, 4)
    # Vertical flip of the rendered buffer (first axis reversed).
    sensor = np.flipud(sensor)
    raw_depth = sensor[:, :, 3]
    scaled = np.expand_dims(raw_depth, axis=2) / depth_max * 255
    depth_u8 = np.clip(scaled, a_min=0, a_max=255).astype(np.uint8)
    stacked = np.dstack([obs, depth_u8])
    return stacked.astype(np.float32), raw_depth

def get_rgbd_goal(rgb, depth_ori, depth_max):
    """Stack a goal colour image and its depth map into one float32 RGB-D array.

    The channel axis of ``rgb`` is reversed (presumably BGR from cv2.imread
    to RGB).  ``depth_ori`` is scaled to 0..255 by ``depth_max``, clipped and
    truncated to uint8, and appended as the fourth channel.
    """
    flipped = np.flip(rgb, axis=2)
    scaled = np.expand_dims(depth_ori, axis=2) / depth_max * 255
    depth_u8 = np.clip(scaled, a_min=0, a_max=255).astype(np.uint8)
    return np.dstack([flipped, depth_u8]).astype(np.float32)

# Goal sets evaluated per cloth type: (goal directory, {goal-name prefix: number
# of steps}).  Step ``t`` of goal ``name`` is read from ``<dir>/<name><t>.png``
# (colour), ``<dir>/<name><t>_depth.png`` (depth) and
# ``<dir>/particles/<name><t>.npy`` (goal particle positions).
_GOAL_SETS = {
    'square': (
        './data/local/0.45goals_correct/mult_step_cam_0.45/',
        {
            'vsf_all_corn_in_': 4,
            'vsf_double_rect_': 2,
            'vsf_double_tri_': 2,
            'vsf_opp_corn_in_': 2,
            'vsf_two_side_horz_': 2,
            'vsf_two_side_vert_': 2,
        },
    ),
    'rectangular': (
        './data/local/0.45goals_correct/rect_cam_0.45/',
        {
            'vsf_horz_fold_': 1,
            'vsf_one_corn_in_': 1,
            'vsf_two_side_horz_': 2,
            'vsf_two_side_vert_': 2,
            'vsf_vert_fold_': 1,
        },
    ),
    'tshirt': (
        './data/local/0.45goals_correct/tsh_cam_0.45/',
        {
            'vsf_across_horz_': 1,
            'vsf_across_vert_': 1,
            'vsf_three_step_': 3,
        },
    ),
}

# Side length (pixels) of the square camera images; goal images are resized to it.
_IMAGE_SIZE = 56


def _build_env(args):
    """Build the bimanual softgym environment matching ``args.cloth``.

    Raises:
        ValueError: if ``args.cloth`` is not 'square', 'rectangular' or 'tshirt'.
    """
    if args.cloth in ('square', 'rectangular'):
        return BimanualEnv(use_depth=True,
                           use_cached_states=False,
                           horizon=2,
                           use_desc=False,
                           action_repeat=1,
                           headless=True,
                           render=True,
                           camera_width=_IMAGE_SIZE,
                           camera_height=_IMAGE_SIZE,
                           camera_param_height=args.camera_param_height,
                           rect=(args.cloth == 'rectangular'))
    if args.cloth == 'tshirt':
        # NOTE(review): unlike BimanualEnv, camera_param_height and render are
        # not passed here - confirm this is intended.
        return BimanualTshirtEnv(use_depth=True,
                                 use_cached_states=False,
                                 use_desc=False,
                                 horizon=2,
                                 action_repeat=1,
                                 headless=True,
                                 camera_height=_IMAGE_SIZE,
                                 camera_width=_IMAGE_SIZE)
    raise ValueError("Unsupported cloth type {!r}; expected one of {}".format(
        args.cloth, sorted(_GOAL_SETS)))


def _load_goal(goal_dir, goal_name, t, depth_max):
    """Load step ``t`` of goal ``goal_name`` from ``goal_dir``.

    Returns:
        ``(goal_img, particle_pos_goal)``: ``goal_img`` is the float32 RGB-D
        image built by ``get_rgbd_goal`` from the resized goal colour and depth
        PNGs; ``particle_pos_goal`` is the first three columns of the saved
        particle array.

    Raises:
        FileNotFoundError: if a goal PNG cannot be read (cv2.imread returns None).
    """
    depth_file = os.path.join(goal_dir, '{}{}_depth.png'.format(goal_name, t))
    rgb_file = os.path.join(goal_dir, "{}{}.png".format(goal_name, t))
    particle_file = os.path.join(goal_dir, 'particles', "{}{}.npy".format(goal_name, t))

    goal_depth = cv2.imread(depth_file)
    if goal_depth is None:
        raise FileNotFoundError("Could not read goal depth image: {}".format(depth_file))
    # Only the first channel of the depth PNG is used, rescaled from 0..255 to 0..1.
    goal_depth = cv2.resize(goal_depth[:, :, 0] / 255., (_IMAGE_SIZE, _IMAGE_SIZE))

    goal_rgb = cv2.imread(rgb_file)
    if goal_rgb is None:
        raise FileNotFoundError("Could not read goal image: {}".format(rgb_file))
    goal_rgb = cv2.resize(goal_rgb, (_IMAGE_SIZE, _IMAGE_SIZE))

    goal_img = get_rgbd_goal(goal_rgb, goal_depth, depth_max)
    particle_pos_goal = np.load(particle_file)[:, :3]
    return goal_img, particle_pos_goal


def _draw_action_arrows(image, action, single_arm):
    """Draw the pick->place arrow of each active arm onto ``image`` in place.

    ``action`` is the denormalized action as sent to ``env.step``: elements
    0-1 are the first arm's (pick, place) and, for two arms, the remaining
    elements are the second arm's (pick, place).
    """
    arm_actions = [action[:2]] if single_arm else [action[:2], action[2:]]
    for pick_uv, place_uv in arm_actions:
        pick_u, pick_v = pick_uv
        place_u, place_v = place_uv
        # The components are passed to add_arrow in (v, u) order.
        # NOTE(review): confirm this matches the intended pixel convention.
        add_arrow(image, (pick_v, pick_u), (place_v, place_u))


def run(args, policy, model_name, data_name, cost_fn='L2'):
    """Evaluate a visual-MPC ``policy`` on the fixed goal set for ``args.cloth``.

    For each goal in the cloth's goal set, the environment is reset and the
    policy is queried once per goal step.  The planning cost for a step is the
    L2 distance to that step's RGB-D goal image.  Per-step comparison images
    and one trajectory strip per goal are written to the logger directory,
    together with ``eval_result.json`` (final reward per goal, mean, std) and
    ``all_goal_all_reward.json`` (every per-step reward).

    Args:
        args: experiment arguments (reads camera_param_height, depth_max,
            small_action, cloth, adim, single_arm).
        policy: object exposing ``set_env_cfg``, ``get_action`` and
            ``mpc.set_cost_function`` (e.g. VisualMPCPolicy).
        model_name, data_name: forwarded to ``policy.set_env_cfg``.
        cost_fn: forwarded to ``policy.set_env_cfg``.  The cost actually used
            for planning is installed per step via ``policy.mpc.set_cost_function``.

    Raises:
        ValueError: if ``args.cloth`` is not a supported cloth type.
        FileNotFoundError: if a goal image is missing.
    """
    # The goal sets were presumably rendered with camera height / depth_max
    # 0.45 (cf. the '0.45' goal directory names); only that setup is supported.
    assert args.camera_param_height == 0.45
    assert args.depth_max == 0.45
    if args.small_action:
        print("using small action!")
        action_range = [0, 23]
    else:  # use large action
        print("using large action!")
        action_range = [0, 45]
    # Passed to the planner as its x / z thresholds.
    pixel_low, pixel_high = 5, 51

    env = _build_env(args)
    print("softgym env built!")

    policy.set_env_cfg(env, None, model_name, data_name, cost_fn)
    print("policy env config set!")

    rewards = dict()
    final_rewards = []
    cam_pos, cam_angle = env.get_camera_params()
    matrix_world_to_camera = get_matrix_world_to_camera(cam_pos, cam_angle)
    eval_folder = logger.get_dir()
    goal_dir, goal_names = _GOAL_SETS[args.cloth]

    for goal_name, steps in goal_names.items():
        print("running episode {}...".format(goal_name), flush=True)
        images = []
        rewards[goal_name] = []

        env.reset()
        set_picker_underground()
        pyflex.step()
        obs = env.get_image(env.camera_width, env.camera_height)

        for t in range(steps):
            goal_img, particle_pos_goal = _load_goal(goal_dir, goal_name, t, args.depth_max)
            print("goal image loaded!", flush=True)
            # goal_img holds RGB (get_rgbd_goal reversed cv2's BGR order), while
            # cv2.imwrite expects BGR, so reverse the colour channels again.
            cv2.imwrite(os.path.join(eval_folder, "{}_goal_rgb.png".format(goal_name)),
                        goal_img[:, :, :3][:, :, ::-1])
            cv2.imwrite(os.path.join(eval_folder, "{}_goal_depth.png".format(goal_name)), goal_img[:, :, 3])

            # The closure reads goal_img when the planner evaluates it, i.e. during
            # this step's get_action call, so it always uses the current goal.
            goal_cost_fn = lambda traj: L2(traj, goal_img)
            policy.mpc.set_cost_function(goal_cost_fn)

            rgbd, depth = get_rgbd(obs, env.camera_height, env.camera_width, args.depth_max)
            action, _uv, predicted_img = policy.get_action(rgbd, depth, matrix_world_to_camera, action_range,
                pixel_low, pixel_high, t=t, adim=args.adim)

            old_obs = obs.copy()

            action = denormalize_action(action, env.camera_width, action_range)
            if args.single_arm:  # just add two elements, they will not be used
                action = [action[0], action[1], action[0], action[1]]

            _, _, _, _ = env.step(action, pickplace=True, on_table=False,
                single_arm=args.single_arm)
            rewards[goal_name].append(env.compute_reward(goal_pos=particle_pos_goal))

            set_picker_underground()
            pyflex.step()
            obs = env.get_image(env.camera_width, env.camera_height)

            cv2.imwrite(os.path.join(eval_folder, 'compare_{}_{}.png'.format(goal_name, t)),
                np.concatenate([predicted_img[:, :, :3], obs], axis=0)[:, :, ::-1])

            _draw_action_arrows(old_obs, action, args.single_arm)
            images.append(old_obs)
            images.append(obs)
            images.append(goal_img[:, :, :3])

        # Save the trajectory strip: (pre-action obs with arrows, post-action obs, goal) per step.
        reward = rewards[goal_name][-1]
        final_rewards.append(reward)
        image_traj = np.concatenate(images, axis=1)
        cv2.imwrite(os.path.join(eval_folder, "traj_{}.png".format(goal_name)), image_traj[:, :, ::-1])

        print("traj {} final reward {}".format(goal_name, reward), flush=True)
        logger.record_tabular("final reward", reward)
        logger.dump_tabular()

    results = dict()
    results['reward'] = [float(x) for x in final_rewards]
    results['mean_reward'] = np.mean(final_rewards).item()
    results['std_reward'] = np.std(final_rewards).item()

    with open(os.path.join(eval_folder, "eval_result.json"), 'w') as fh:
        json.dump(results, fh)

    # Rewards may be NumPy scalars (e.g. float32), which json cannot serialize;
    # convert them the same way as results['reward'] above.
    all_rewards = {name: [float(r) for r in goal_rewards] for name, goal_rewards in rewards.items()}
    with open(os.path.join(eval_folder, "all_goal_all_reward.json"), 'w') as fh:
        json.dump(all_rewards, fh)

def get_default_args():
    """Return an argparse Namespace holding only the declared default values.

    The command line is not consulted: ``parse_args`` is given an empty list.
    """
    parser = argparse.ArgumentParser()
    defaults = (
        ("--max_episodes", int, 10),
        ("--seed", int, None),
        ("--model_path", str, "/data/pure_random"),
        ("--data_path", str, "/data/pure_random"),
        ("--pick_and_place_num", int, 10),
    )
    for flag, arg_type, default in defaults:
        parser.add_argument(flag, type=arg_type, default=default)
    return parser.parse_args([])


def run_task(vv, log_dir, exp_name):
    """Experiment entry point: configure logging, build the policy and run the evaluation.

    ``vv`` is a dict of variant values merged over ``get_default_args()``.
    It must provide every attribute read below or in ``run`` that has no
    default (num_elites, population_size, num_iters, planning_method, ...).
    The merged arguments are written to ``variant.json`` in the log directory.
    """
    args = get_default_args()
    args.__dict__.update(**vv)

    logger.configure(dir=log_dir, exp_name=exp_name)
    logdir = logger.get_dir()
    assert logdir is not None
    os.makedirs(logdir, exist_ok=True)

    with open(os.path.join(logdir, 'variant.json'), 'w') as f:
        json.dump(args.__dict__, f, indent=2, sort_keys=True)

    policy = VisualMPCPolicy(args, num_elites=args.num_elites, population_size=args.population_size,
        num_iters=args.num_iters, planning_method=args.planning_method)

    # args.seed equals vv['seed'] when the variant provides one.  Otherwise it
    # is the parser default (None, which makes NumPy seed itself from fresh
    # entropy) instead of raising KeyError.
    np.random.seed(args.seed)

    run(args, policy, args.model_path, args.data_path)
