"""
Machinery for running MPC with CEM, based on saved-visual/dmbrl/controllers/VIS_MPC.py
"""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import os
import tensorflow as tf
import numpy as np
import scipy.stats as stats
from GNS.fabric_vsf.vismpc.SV2P import SV2P
from dotmap import DotMap
import sys
from GNS.camera_utils import get_world_coor_from_image
import cv2
import scipy

# TODO: add visualizations of the planned action sequence

class VISMPC():
    """Visual MPC controller that plans with the cross-entropy method (CEM).

    Candidate action sequences (``horizon`` actions of ``act_dim`` values each,
    flattened into one vector) are drawn from a truncated Gaussian, rolled out
    through the SV2P model and scored with ``cost_fn``; the Gaussian is refit
    to the lowest-cost ("elite") samples on every CEM iteration.
    """

    def __init__(self, cost_fn, data_dir='vismpc/sv2p_data_cloth', model_dir='vismpc/sv2p_model_cloth', horizon=5, batch=False, viz=None,
        num_elites=50, population_size=200, num_iters=5, log_dir=None):
        """
        cost_fn: callable that scores one predicted trajectory (or, when
            ``batch`` is True, the whole batch of predictions); lower is better.
        data_dir / model_dir: SV2P data and checkpoint directories.
        horizon: planning horizon (number of actions per sequence).
        batch: if True, ``cost_fn`` is called once on all predictions.
        viz: optional visualizer; ``viz.set_context(obs)`` is called per plan.
        num_elites, population_size, num_iters: CEM hyper-parameters.
        log_dir: directory for the per-iteration best-prediction images
            written by ``_run_cem``; nothing is written when None.
        """
        params = DotMap()
        params.name = 'cloth'
        params.model_dir = model_dir
        params.data_dir = data_dir
        params.popsize = population_size  # must match _run_cem's popsize!
        params.nparts = 1
        params.plan_hor = horizon
        params.adim = 8
        params.stochastic_model = True
        # Drop our own CLI arguments before building SV2P (presumably so its
        # flag parsing does not see them -- confirm against SV2P).
        sys.argv = sys.argv[:1]
        self.model = SV2P(params)
        self.cost_fn = cost_fn
        self.act_dim = params.adim
        # Per-dimension action bounds. They need act_dim entries so that the
        # warm start and variance below have plan_hor * act_dim entries, which
        # is the width of the samples drawn in _run_cem. The previous 4-wide
        # bounds with act_dim=8 did not broadcast.
        # TUNE CEM VARIANCE: -0.4/0.4 on the move components worked better for
        # smoothing, -0.7/0.7 better for folding.
        self.ac_lb = -np.ones(self.act_dim)
        self.ac_ub = np.ones(self.act_dim)
        self.plan_hor = horizon
        self.batch = batch
        self.prev_sol = np.tile((self.ac_lb + self.ac_ub) / 2, [self.plan_hor])
        # / 16 works better for smoothing, / 8 better for folding
        self.init_var = np.tile(np.square(self.ac_ub - self.ac_lb) / 16, [self.plan_hor])
        self.viz = viz

        self.num_iters = num_iters
        self.num_elites = num_elites
        self.popsize = population_size
        self.log_dir = log_dir

    def reset(self):
        """Reset the warm-start mean to the centre of the action bounds."""
        self.prev_sol = np.tile((self.ac_lb + self.ac_ub) / 2, [self.plan_hor])

    def set_cost_function(self, cost_fn):
        """Replace the cost function used to score predictions."""
        self.cost_fn = cost_fn

    def get_next_action(self, obs, timestep=None, ep=None):
        """Plan with CEM from ``obs`` and return the first action of the plan.

        The remaining planned actions, followed by one all-zero action, become
        the warm-start mean for the next call.
        """
        soln = self._run_cem(obs, mean=self.prev_sol, var=self.init_var, timestep=timestep, num_iters=self.num_iters, ep=ep)
        self.prev_sol = np.concatenate([np.copy(soln)[self.act_dim:], np.zeros(self.act_dim)])
        return soln[:self.act_dim]

    def _run_cem(self, obs, mean, var, num_iters=10, timestep=None, ep=None, alpha=0.1):
        """Refine a flat action-sequence distribution with CEM and return its mean.

        obs: current observation, forwarded to the model. ``obs.shape[2]`` is
            used as the channel count of the saved images.
        mean, var: initial per-dimension mean and variance, each of length
            plan_hor * act_dim.
        num_iters: number of CEM iterations; 0 returns ``mean`` unchanged.
        timestep, ep: only used to name the saved images.
        alpha: smoothing factor. Each refit is
            ``alpha * old + (1 - alpha) * elite statistic``.
        """
        num_elites, popsize = self.num_elites, self.popsize
        print('running cem with num iters {}, popsize {}'.format(num_iters, popsize))
        # Standard normal truncated to +-2 std; rescaled per dimension below.
        X = stats.truncnorm(-2, 2, loc=np.zeros_like(mean), scale=np.ones_like(mean))
        if self.viz:
            self.viz.set_context(obs)
        for i in range(num_iters):
            # NOTE: samples are not clipped to [ac_lb, ac_ub]. A variance
            # constraint that kept them in bounds is currently disabled.
            samples = X.rvs(size=[popsize, self.plan_hor * self.act_dim]) * np.sqrt(var) + mean
            print("samples.shape: ", samples.shape)
            costs, pred_trajs = self._predict_and_eval(obs, samples, timestep=timestep)
            print("CEM Iteration: ", i, "Cost (mean, std): ", np.mean(costs), ",", np.std(costs))
            order = np.argsort(costs)
            elites = samples[order[:num_elites]]

            if self.log_dir is not None:
                # Save the best sample's prediction reshaped to
                # (plan_hor * 56, 56, C), presumably the frames stacked vertically.
                save_img = pred_trajs[order[0]].reshape((self.plan_hor * 56, 56, obs.shape[2]))
                cv2.imwrite(os.path.join(self.log_dir, 'cem_{}_{}_{}.png'.format(ep, timestep, i)), save_img)

            # Refit mean/var to the elites, smoothed with the previous values.
            new_mean = np.mean(elites, axis=0)
            new_var = np.var(elites, axis=0)
            mean = alpha * mean + (1 - alpha) * new_mean
            var = alpha * var + (1 - alpha) * new_var
        return mean

    def _predict_and_eval(self, obs, ac_seqs, timestep=None):
        """Predict outcomes for candidate action sequences and score them.

        ac_seqs: candidates, reshaped to (-1, plan_hor, act_dim) before
            prediction (np.reshape also yields the ndarray SV2P presumably needs).
        Returns:
            (costs as an np.ndarray, pred_trajs[:, 0]). Each non-batch cost is
            ``cost_fn(pred_trajs[k][0])``.
        """
        ac_seqs = np.reshape(ac_seqs, [-1, self.plan_hor, self.act_dim])
        pred_trajs = self.model.predict(obs, ac_seqs)
        if self.batch:
            costs = self.cost_fn(pred_trajs[:, 0])
        else:
            costs = [self.cost_fn(traj[0]) for traj in pred_trajs]
        return np.array(costs), pred_trajs[:, 0]



# NOTE: do only 1 step prediction, sample points from the cloth mask
class VISMPC_MASK_ONE_STEP():
    """One-step visual MPC that chooses pick-and-move actions via SV2P predictions.

    Pick points are sampled on (or snapped onto) the cloth mask and paired with
    sampled move vectors. Each candidate action is rolled through
    ``self.model.predict`` and scored with ``cost_fn`` (lower is better).
    ``get_next_action`` does a single round of random shooting, while
    ``get_next_action_cem_original`` refits Gaussian sampling distributions
    over several CEM iterations.
    """

    def __init__(self, cost_fn, data_dir='vismpc/sv2p_data_cloth', model_dir='vismpc/sv2p_model_cloth', horizon=1, batch=False, viz=None,
        num_elites=50, population_size=200, num_iters=5, log_dir=None, adim=4):
        """
        cost_fn: callable scoring one predicted frame (or, when ``batch`` is
            True, the whole batch of predictions); lower is better.
        data_dir / model_dir: SV2P data and checkpoint directories.
        horizon: planning horizon passed to SV2P (1 for one-step prediction).
        batch: if True, ``cost_fn`` is called once on all predictions.
        viz: optional visualizer (stored; not used by the methods here).
        num_elites, population_size, num_iters: sampling / CEM hyper-parameters.
        log_dir: stored; not written to by the methods here.
        adim: action dimension (4 = one pick + move, 8 = two picks + moves).
        """
        params = DotMap()
        params.name = 'cloth'
        params.model_dir = model_dir
        params.data_dir = data_dir
        params.popsize = population_size  # must match _run_cem's popsize!
        params.nparts = 1
        params.plan_hor = horizon
        params.adim = adim
        params.stochastic_model = True
        # Drop our own CLI arguments before building SV2P (presumably so its
        # flag parsing does not see them -- confirm against SV2P).
        sys.argv = sys.argv[:1]
        self.model = SV2P(params)
        self.cost_fn = cost_fn
        self.act_dim = params.adim
        # Per-dimension action bounds, sized to act_dim so that prev_sol and
        # init_var have plan_hor * act_dim entries for any adim.
        # TUNE CEM VARIANCE: -0.4/0.4 on the move components worked better for
        # smoothing, -0.7/0.7 better for folding.
        self.ac_lb = -np.ones(self.act_dim)
        self.ac_ub = np.ones(self.act_dim)
        self.plan_hor = horizon
        self.batch = batch
        self.prev_sol = np.tile((self.ac_lb + self.ac_ub) / 2, [self.plan_hor])
        # / 16 works better for smoothing, / 8 better for folding
        self.init_var = np.tile(np.square(self.ac_ub - self.ac_lb) / 16, [self.plan_hor])
        self.viz = viz

        self.num_iters = num_iters
        self.num_elites = num_elites
        self.popsize = population_size
        self.log_dir = log_dir

    def reset(self):
        """Reset the warm-start mean to the centre of the action bounds."""
        self.prev_sol = np.tile((self.ac_lb + self.ac_ub) / 2, [self.plan_hor])

    def set_cost_function(self, cost_fn):
        """Replace the cost function used to score predictions."""
        self.cost_fn = cost_fn

    def get_next_action(self, obs, mask, matrix_world_to_camera, pull_distance_max, x_threshold, z_threshold,
        timestep=None, ep=None):
        """Return the best of ``popsize`` random pick-and-pull actions.

        Pick pixels come from ``sample_uv(mask)``. Pull vectors come from
        ``sample_delta_move``, which keeps the pulled world point inside the
        x/z thresholds. ``ep`` is unused.

        Returns:
            (best normalized action, its (u, v) pick pixel,
             its predicted frame reshaped to (56, 56, C)).
        """
        normalized_uv, uv = self.sample_uv(mask)
        # NOTE(review): ``mask`` is passed as the ``depth`` argument and
        # sample_delta_move reads depth values from it. Confirm that callers
        # pass a depth image here.
        normalized_delta_move = self.sample_delta_move(
            matrix_world_to_camera, uv, mask, pull_distance_max, x_threshold, z_threshold)

        normalized_action = np.concatenate([normalized_uv, normalized_delta_move], axis=1)
        costs, pred_trajs = self._predict_and_eval(obs, normalized_action, timestep=timestep)
        best = np.argsort(costs)[0]
        save_img = pred_trajs[best].reshape((56, 56, obs.shape[2]))
        return normalized_action[best], uv[best], save_img

    def clip_sampled_move(self, sampled_move, unormalized_pick_uv, action_range, pixel_low, pixel_high):
        """Clip normalized moves so that the place pixel stays inside the image.

        sampled_move: (N, 2) moves in [-1, 1], in units of ``action_range[1]``.
        unormalized_pick_uv: (N, 2) integer pick pixels.
        pixel_low, pixel_high: inclusive pixel bounds for the place point.

        Returns:
            The (N, 2) moves that actually reach the clipped place pixel,
            re-normalized by ``action_range[1]``.
        """
        longest_move = action_range[1]
        # Denormalize to whole pixels; astype truncates toward zero.
        unormalized_move = (sampled_move * longest_move).astype("int")
        after_move_uv = np.clip(unormalized_pick_uv + unormalized_move, pixel_low, pixel_high)
        movement = after_move_uv - unormalized_pick_uv
        return movement / longest_move

    def convert_sampled_theta_and_distance_to_normalized_movement(
        self, sampled_theta, sampled_distance, unormalized_pick_uv, action_range, pixel_low, pixel_high
    ):
        """Turn polar (theta, distance) samples into normalized (u, v) moves.

        sampled_theta: angles in radians; flattened to (N,).
        sampled_distance: (N,) fractions of ``action_range[1]``, expected in [0, 1].
        unormalized_pick_uv: (N, 2) pick pixels.

        The place pixel is ``pick + distance * (sin(theta), cos(theta))``. It is
        clipped to [pixel_low, pixel_high] and then rounded. The returned (N, 2)
        array is the resulting move divided by ``action_range[1]``.

        Raises:
            AssertionError: if a normalized move falls outside [-1, 1].
        """
        longest_move = action_range[1]
        unormalized_move = longest_move * sampled_distance
        sampled_theta = sampled_theta.flatten()

        unormalized_place_u = np.rint(
            np.clip(
                unormalized_pick_uv[:, 0] + unormalized_move * np.sin(sampled_theta),
                pixel_low, pixel_high
            )
        )
        unormalized_place_v = np.rint(
            np.clip(
                unormalized_pick_uv[:, 1] + unormalized_move * np.cos(sampled_theta),
                pixel_low, pixel_high
            )
        )

        move_u = (unormalized_place_u - unormalized_pick_uv[:, 0]) / longest_move
        move_v = (unormalized_place_v - unormalized_pick_uv[:, 1]) / longest_move
        # np.alltrue was removed in NumPy 2.0; np.all is the supported spelling.
        assert np.all(-1 <= move_u) and np.all(move_u <= 1)
        assert np.all(-1 <= move_v) and np.all(move_v <= 1)
        return np.stack([move_u, move_v], axis=1)

    def get_next_action_cem_original(self, obs, mask, matrix_world_to_camera, action_range, pixel_low,
        pixel_high, timestep=None, adim=4):
        """Choose a pick-and-move action with CEM over one-step predictions.

        Each iteration does the following:
        1. Sample ``popsize`` pick points and moves from independent Gaussians
           clipped to [-1, 1].
        2. Snap the pick points onto the mask with ``convert_uv(...,
           col_row=False)``, so pick coordinates are in argwhere (row, col) order.
        3. Clip the moves so the place pixel stays in [pixel_low, pixel_high].
        4. Score the actions and refit the Gaussians to the ``num_elites``
           cheapest ones.

        adim: 4 (layout [pick, move]) or 8 (layout [pick1, move1, pick2, move2]).
        matrix_world_to_camera: unused.

        Returns:
            (best normalized action of the last iteration, None,
             its predicted frame reshaped to (56, 56, C)).

        Raises:
            ValueError: if ``adim`` is not 4 or 8, or ``self.num_iters < 1``.
        """
        if adim not in (4, 8):
            raise ValueError("adim must be 4 or 8, got {}".format(adim))
        if self.num_iters < 1:
            raise ValueError("num_iters must be >= 1, got {}".format(self.num_iters))

        half = adim // 2
        # Columns of the flat action that hold pick uv / move components.
        uv_cols = [0, 1, 4, 5] if adim == 8 else [0, 1]
        move_cols = [2, 3, 6, 7] if adim == 8 else [2, 3]
        uv_mean, uv_std = np.zeros(half), np.ones(half)
        move_mean, move_std = np.zeros(half), np.ones(half)

        print("pop size is: ", self.popsize, flush=True)
        for cem_iter in range(self.num_iters):
            print("cem iter {}".format(cem_iter), flush=True)
            sampled_uv = np.clip(np.random.normal(uv_mean, uv_std, size=(self.popsize, half)), -1, 1)
            sampled_move = np.clip(np.random.normal(move_mean, move_std, size=(self.popsize, half)), -1, 1)

            # First picker: snap to the mask, then keep the place point in the image.
            pick_1_uv, unormalized_pick_1_uv = self.convert_uv(sampled_uv[:, :2], mask, col_row=False)
            sampled_move[:, :2] = self.clip_sampled_move(
                sampled_move[:, :2], unormalized_pick_1_uv, action_range, pixel_low, pixel_high
            )
            parts = [pick_1_uv, sampled_move[:, :2]]

            if adim == 8:
                # Second picker, same treatment.
                pick_2_uv, unormalized_pick_2_uv = self.convert_uv(sampled_uv[:, 2:], mask, col_row=False)
                sampled_move[:, 2:] = self.clip_sampled_move(
                    sampled_move[:, 2:], unormalized_pick_2_uv, action_range, pixel_low, pixel_high
                )
                parts += [pick_2_uv, sampled_move[:, 2:]]

            normalized_action = np.concatenate(parts, axis=1)
            costs, pred_trajs = self._predict_and_eval(obs, normalized_action, timestep=timestep)

            sort_idx = np.argsort(costs)
            elite_actions = normalized_action[sort_idx[:self.num_elites]]
            uv_mean = np.mean(elite_actions[:, uv_cols], axis=0)
            uv_std = np.std(elite_actions[:, uv_cols], axis=0)
            move_mean = np.mean(elite_actions[:, move_cols], axis=0)
            move_std = np.std(elite_actions[:, move_cols], axis=0)

        best = sort_idx[0]
        save_img = pred_trajs[best].reshape((56, 56, obs.shape[2]))
        return normalized_action[best], None, save_img

    def sample_delta_move(self, matrix_world_to_camera, uvs, depth, pull_distance_max, x_threshold, z_threshold):
        """Sample one pull vector per pick pixel by rejection sampling.

        For each (u, v) in ``uvs``, the pick pixel is back-projected with
        ``get_world_coor_from_image`` using ``depth[v][u]``. Then
        (delta_x, delta_z) is drawn uniformly from
        [-pull_distance_max, pull_distance_max)^2 until the moved point lies
        strictly inside (-x_threshold, x_threshold) x (-z_threshold, z_threshold).

        Returns:
            An (N, 2) array of (delta_x, delta_z) divided by ``pull_distance_max``.

        Raises:
            ValueError: if, for some pick point, no pull within
                ``pull_distance_max`` can satisfy the thresholds. The sampler
                would otherwise never terminate.
        """
        imsize = depth.shape[0]
        delta_xs = []
        delta_zs = []
        for uv in uvs:
            u, v = uv[0], uv[1]
            d = depth[v][u]
            picker_pos = get_world_coor_from_image(u, v, d, imsize, imsize, matrix_world_to_camera)
            picker_x, picker_z = picker_pos[0], picker_pos[2]

            # Feasible delta intervals. Rejection sampling can only terminate
            # if both of them have positive length.
            lo_x = max(-pull_distance_max, -x_threshold - picker_x)
            hi_x = min(pull_distance_max, x_threshold - picker_x)
            lo_z = max(-pull_distance_max, -z_threshold - picker_z)
            hi_z = min(pull_distance_max, z_threshold - picker_z)
            if not (lo_x < hi_x and lo_z < hi_z):
                raise ValueError(
                    "no pull within pull_distance_max={} keeps picker at (x={}, z={}) inside "
                    "x_threshold={}, z_threshold={}".format(
                        pull_distance_max, picker_x, picker_z, x_threshold, z_threshold))

            while True:
                delta_x = np.random.uniform(-pull_distance_max, pull_distance_max)
                delta_z = np.random.uniform(-pull_distance_max, pull_distance_max)
                new_x = picker_x + delta_x
                new_z = picker_z + delta_z
                if -x_threshold < new_x < x_threshold and -z_threshold < new_z < z_threshold:
                    break

            delta_xs.append(delta_x)
            delta_zs.append(delta_z)

        delta_move = np.column_stack([delta_xs, delta_zs])
        return delta_move / pull_distance_max

    def sample_uv(self, mask):
        """Sample ``popsize`` pick pixels on the mask.

        Pixels are drawn uniformly from the mask's bounding box, grown by a
        6-pixel margin and clamped to [0, imsize - 1). Each is then snapped to
        the nearest nonzero mask pixel.

        Returns:
            (normalized_uvs, uvs). ``uvs`` is a (popsize, 2) integer array of
            (u, v) = (column, row). ``normalized_uvs`` is
            ``2 * uvs / (imsize - 1) - 1``.

        Raises:
            ValueError: if the mask has no nonzero pixels.
        """
        imsize = mask.shape[0]
        idxs = np.argwhere(mask)
        if len(idxs) == 0:
            raise ValueError("sample_uv: mask has no nonzero pixels")
        # argwhere yields (row, col); swap into (u, v) = (col, row).
        uvs = idxs.copy()
        uvs[:, 0] = idxs[:, 1]
        uvs[:, 1] = idxs[:, 0]

        bb_margin = 6
        lb_u, ub_u = int(np.min(uvs[:, 0])), int(np.max(uvs[:, 0]))
        lb_v, ub_v = int(np.min(uvs[:, 1])), int(np.max(uvs[:, 1]))
        # randint's upper bound is exclusive.
        u = np.random.randint(max(lb_u - bb_margin, 0), min(ub_u + bb_margin, imsize - 1), size=self.popsize)
        v = np.random.randint(max(lb_v - bb_margin, 0), min(ub_v + bb_margin, imsize - 1), size=self.popsize)

        candidates = np.vstack([u, v]).T  # popsize x 2
        nearest = np.argmin(scipy.spatial.distance.cdist(candidates, uvs), axis=1)
        uvs = uvs[nearest]
        normalized_uvs = 2 * uvs.astype(np.float64) / (imsize - 1) - 1
        return normalized_uvs, uvs

    def convert_uv(self, sampled_uv, mask, col_row=True):
        """Snap normalized pick samples onto the nearest nonzero mask pixel.

        sampled_uv: (N, 2) values in [-1, 1], mapped to pixels with
            ``int((s + 1) * (imsize - 1) / 2)``.
        col_row: if True, mask pixels are compared and returned as (col, row).
            Otherwise they use argwhere's (row, col) order.

        Returns:
            (normalized_uvs, uvs): the snapped integer pixels and their
            ``2 * uv / (imsize - 1) - 1`` normalization.

        Raises:
            ValueError: if the mask has no nonzero pixels.
        """
        imsize = mask.shape[0]
        idxs = np.argwhere(mask)
        if len(idxs) == 0:
            raise ValueError("convert_uv: mask has no nonzero pixels")
        uvs = idxs.copy()
        if col_row:
            uvs[:, 0] = idxs[:, 1]
            uvs[:, 1] = idxs[:, 0]

        unnormalized_sampled_uv = ((sampled_uv + 1) * (imsize - 1) / 2).astype('int')
        nearest = np.argmin(scipy.spatial.distance.cdist(unnormalized_sampled_uv, uvs), axis=1)
        uvs = uvs[nearest]
        normalized_uvs = 2 * uvs.astype(np.float64) / (imsize - 1) - 1
        return normalized_uvs, uvs

    def _predict_and_eval(self, obs, ac_seqs, timestep=None):
        """Predict outcomes for candidate actions and score them.

        ac_seqs: candidates, reshaped to (-1, plan_hor, act_dim) before
            prediction (np.reshape also yields the ndarray SV2P presumably needs).
        Returns:
            (costs as an np.ndarray, pred_trajs[:, 0]). Each non-batch cost is
            ``cost_fn(pred_trajs[k][0])``.
        """
        ac_seqs = np.reshape(ac_seqs, [-1, self.plan_hor, self.act_dim])
        pred_trajs = self.model.predict(obs, ac_seqs)
        if self.batch:
            costs = self.cost_fn(pred_trajs[:, 0])
        else:
            costs = [self.cost_fn(traj[0]) for traj in pred_trajs]
        return np.array(costs), pred_trajs[:, 0]

