import os
import jax

# jax.config.update("jax_compilation_cache_dir", "__jaxcache__")

import jax.numpy as jnp
import numpy as np
from functools import partial
import time
import open3d as o3d
import optax
import jax.debug as jdb

import sys
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)

import util.model_util as mutil
import util.transform_util as tutil
import util.latent_obj_util as loutil
from util.reconstruction_util import create_fps_fcd_from_oriCORNs
from modules import shakey_module
from util.dotenv_util import REP_CKPT
import einops
import matplotlib.pyplot as plt
import util.structs as structs
# import modules.traj_search_module as traj_search_module
from modules.ccd import CuroboCCD, TrajOptCCD, OursCCD, ContinuousCollisionCostBase
import util.broad_phase as broad_phase

def way_points_to_trajectory(waypnts, resolution, cos_transition=True):
    """Resample a piecewise-linear waypoint path into ``resolution`` points.

    The path is parameterised by normalised arc length (segment lengths
    divided by the total length), and each output sample is a linear blend
    of the two waypoints that bound the segment it falls in.

    Args:
        waypnts: waypoint array; axis 0 indexes the waypoints and the last
            axis holds the coordinates (assumed shape ``(N, D)``).
        resolution: number of samples in the returned trajectory.
        cos_transition: if True, the sample parameters are
            ``(1 - cos(pi * t)) / 2`` for uniform ``t`` in [0, 1] instead of
            uniform, which places samples more densely near both ends.

    Returns:
        Array with ``resolution`` rows. The first and last rows are set
        exactly to ``waypnts[0]`` and ``waypnts[-1]``.
    """
    epsilon = 1e-8
    wp_len = jnp.linalg.norm(waypnts[1:] - waypnts[:-1], axis=-1)
    wp_len = wp_len/jnp.sum(wp_len).clip(epsilon)
    # Drop numerically-zero segments, then renormalise. The denominator is
    # clipped so that a fully degenerate path (all waypoints identical, sum
    # of lengths == 0) yields zeros instead of NaN.
    wp_len = jnp.where(wp_len<epsilon*10, 0, wp_len)
    wp_len = wp_len/jnp.sum(wp_len).clip(epsilon)
    wp_len_cumsum = jnp.cumsum(wp_len)
    wp_len_cumsum = jnp.concatenate([jnp.array([0]),wp_len_cumsum], 0)
    # Pin the final knot to exactly 1 so the last sample never overshoots it.
    wp_len_cumsum = wp_len_cumsum.at[-1].set(1.0)
    indicator = jnp.linspace(0, 1, resolution)
    if cos_transition:
        indicator = (-jnp.cos(indicator*jnp.pi)+1)/2.
    # Index of the segment each sample lies in: number of knots strictly passed.
    included_idx = jnp.sum(indicator[...,None] > wp_len_cumsum[1:], axis=-1)

    # Fraction of the segment still ahead of the sample (1 at segment start).
    upper_residual = (wp_len_cumsum[included_idx+1] - indicator)/wp_len[included_idx].clip(epsilon)
    upper_residual = upper_residual.clip(0.,1.)
    bottom_residual = 1.-upper_residual

    traj = waypnts[included_idx] * upper_residual[...,None] + waypnts[included_idx+1] * bottom_residual[...,None]
    # On (near) zero-length segments the blend weights are meaningless; use
    # the segment's start waypoint directly.
    traj = jnp.where(wp_len[included_idx][...,None] < 1e-4, waypnts[included_idx], traj)
    traj = traj.at[0].set(waypnts[0])
    traj = traj.at[-1].set(waypnts[-1])

    return traj


def evaluate_full_trajectory(control_points, samples_per_segment):
    """
    Interpolate a set of control points with the broad_phase SE3 helpers.

    control_points: array whose second-to-last axis indexes the N control
        points (the callers' comments describe it as shape (N, 6)).
    samples_per_segment: number of evaluation samples per control-point segment.
    Returns: whatever broad_phase.SE3_interpolation_eval returns when evaluated
        at (N-1)*samples_per_segment+1 parameters evenly spaced over [0, 1].
        Callers in this file unpack it as (trajectory, vel, acc, jerk).
    """
    num_ctrl = control_points.shape[-2]
    # The sample count is always derived from the ORIGINAL number of control
    # points, even when a two-point input is densified below.
    num_eval = (num_ctrl - 1) * samples_per_segment + 1

    if num_ctrl == 2:
        # Only a start and a goal: expand them into a uniformly spaced
        # straight-line set of control points before fitting.
        control_points = way_points_to_trajectory(
            control_points, samples_per_segment, cos_transition=False
        )

    interp_coeffs = broad_phase.SE3_interpolation_coeffs(control_points)
    eval_params = jnp.linspace(0, 1, num_eval, endpoint=True)
    return broad_phase.SE3_interpolation_eval(*interp_coeffs, eval_params)


def vis_gradient_callback(inputs):
    """Show the interpolated robot motion next to the fixed scene in Open3D.

    For every batch entry of ``x`` this opens one blocking Open3D window
    containing the moving surface points (green), the fixed-object points
    (red) and one blue line segment per surface point between consecutive
    interpolated time steps.

    Args:
        inputs: sequence whose first four entries are
            ``(x, direction, loss_args, self)``:
            x: control points; vmapped over the leading axis, so it carries a
                batch axis in front of the per-trajectory control points.
            direction: accepted for signature compatibility; not used here.
            loss_args: object providing ``fixed_oriCORNs``.
            self: object providing ``base_se2``, ``shakey`` and ``models``
                (presumably a ``CostModules`` instance — confirm with caller).
    """
    x, _direction, loss_args, self = inputs[:4]

    fixed_obj = loss_args.fixed_oriCORNs
    interpolated_trajectory, vel, acc, jerk = jax.vmap(partial(evaluate_full_trajectory, samples_per_segment=4))(x)

    # add robot base se2 (skipped when no base pose was configured, matching
    # the guard used in CostModules.traj_opt_cost)
    if self.base_se2 is not None:
        interpolated_trajectory = jnp.concatenate([
            jnp.broadcast_to(self.base_se2, (interpolated_trajectory.shape[:-1] + (3,))),
            interpolated_trajectory
        ], axis=-1)

    moving_obj_pqs = self.shakey.FK(interpolated_trajectory, oriCORN_out=False) # [NT, NOB, 7] per batch entry
    moving_obj = self.shakey.link_canonical_oriCORN # [NOB, ]

    moving_obj_tf = moving_obj.apply_pq_z(moving_obj_pqs, self.models.rot_configs)

    moving_points = moving_obj_tf.fps_tf # [NT, NOB, NFP, 3] per batch entry
    # Pair every point with its position at the neighbouring time step; the
    # trailing axis of size 2 holds the two segment end points.
    moving_points_seq = jnp.stack([moving_points[...,1:,:,:,:], moving_points[...,:-1,:,:,:]], axis=-1)
    # Flatten (time, object, point) into one axis of independent segments.
    moving_points_seq = einops.rearrange(moving_points_seq, '... i j k p q -> ... (i j k) p q')

    # The fixed scene and the single-edge index do not depend on the batch
    # entry, so they are built once and reused for every window.
    fixed_pnts = fixed_obj.fps_tf.reshape(-1, 3)
    o3d_fixed = o3d.geometry.PointCloud()
    o3d_fixed.points = o3d.utility.Vector3dVector(fixed_pnts)
    o3d_fixed.paint_uniform_color([1,0,0])
    line_idx = np.array([[0,1]]).astype(np.int32)

    for batch_idx in range(moving_points_seq.shape[0]):
        o3d_moving = o3d.geometry.PointCloud()
        o3d_moving.points = o3d.utility.Vector3dVector(moving_points[batch_idx].reshape(-1, 3))
        o3d_moving.paint_uniform_color([0,1,0])

        # line sequence: one two-point LineSet per segment
        lines_seq = []
        for j in range(moving_points_seq.shape[1]):
            points_seq = moving_points_seq[batch_idx, j]
            # [3, 2] -> [2, 3]: rows become the two end points of the segment
            points_seq = jnp.moveaxis(points_seq, -1, -2).reshape(-1, 3)
            lines = o3d.geometry.LineSet(points=o3d.utility.Vector3dVector(points_seq), lines=o3d.utility.Vector2iVector(line_idx)).paint_uniform_color([0,0,1])
            lines_seq.append(lines)

        o3d.visualization.draw_geometries([o3d_moving, o3d_fixed, *lines_seq])




def print_callback(inputs):
    """Debug callback: invoke ``inputs[0]``, print ``inputs`` and plot six curves.

    ``inputs[0]`` is first called as ``inputs[0](*inputs[1], visualize=True)``.
    Afterwards, for each of the first six columns, ``inputs[0]``,
    ``inputs[1][0]`` and ``inputs[2][0]`` are plotted against a uniform
    [0, 1] axis in a 6-row matplotlib figure, which is then shown.

    NOTE(review): ``inputs[0]`` is used both as a callable and as an array
    (``.shape`` / ``[:, i]``) — confirm what callers actually pass here.
    """

    def _plot_column(curve, column):
        # x axis is a uniform [0, 1] grid with one tick per row of the curve
        plt.plot(np.linspace(0,1,curve.shape[0]), curve[:,column])

    inputs[0](*inputs[1], visualize=True)

    print(inputs)

    plt.figure()
    for column in range(6):
        plt.subplot(6,1,column+1)
        _plot_column(inputs[0], column)
        _plot_column(inputs[1][0], column)
        _plot_column(inputs[2][0], column)
    plt.show()


def smooth_col_cost(col_logits, mu=0.1):
    """Smoothed hinge penalty on collision logits.

    Piecewise definition (continuous at both joints):
        col_logits >= 0        -> col_logits + 0.5 * mu        (linear)
        -mu < col_logits < 0   -> 0.5 / mu * (col_logits + mu)**2  (quadratic)
        col_logits <= -mu      -> 0

    Args:
        col_logits: array of logits; positive values are penalised linearly.
        mu: width of the quadratic transition band below zero.

    Returns:
        Array of non-negative costs with the same shape as ``col_logits``.
        NaN inputs stay NaN because no branch condition selects them.
    """
    # The bounds are inclusive so that the exact values 0 and -mu fall into a
    # branch; with strict inequalities they fell through and returned the raw
    # logit (0 at the upper joint, the negative value -mu at the lower one).
    col_cost = jnp.where(col_logits >= 0, col_logits + 0.5*mu, col_logits)
    col_cost = jnp.where(jnp.logical_and(0>col_logits, col_logits>-mu), 0.5/mu*(col_logits+mu)**2, col_cost)
    col_cost = jnp.where(col_logits<=-mu, 0, col_cost)
    return col_cost

def joint_limit_cost_func(q_t, q_l, q_u, eta):
    """Smoothed penalty for joint values near or beyond their limits.

    Zero in the interior of ``[q_l + eta, q_u - eta]``, quadratic inside the
    ``eta``-wide band next to each limit, and linear (offset by ``0.5 * eta``)
    once a limit is violated. When several cases apply, the priority is:
    below lower limit, lower band, above upper limit, upper band.

    Args:
        q_t: joint values.
        q_l: lower limits (broadcast against ``q_t``).
        q_u: upper limits (broadcast against ``q_t``).
        eta: width of the quadratic band inside each limit.

    Returns:
        Element-wise cost, broadcast over the inputs.
    """
    under = q_l - q_t   # > 0 when below the lower limit
    over = q_t - q_u    # > 0 when above the upper limit

    in_upper_band = (q_u - eta < q_t) & (q_t <= q_u)
    in_lower_band = (q_l <= q_t) & (q_t < q_l + eta)

    # Apply the cases from lowest to highest priority; each later `where`
    # overrides the earlier ones wherever its condition holds.
    cost = jnp.where(in_upper_band, 0.5 / eta * (over + eta)**2, 0.0)
    cost = jnp.where(q_t > q_u, over + 0.5 * eta, cost)
    cost = jnp.where(in_lower_band, 0.5 / eta * (under + eta)**2, cost)
    cost = jnp.where(q_t < q_l, under + 0.5 * eta, cost)
    return cost


class CostModules(object):
    """Cost terms used for trajectory optimisation of a ``Shakey`` robot.

    Holds a continuous-collision-detection (CCD) backend selected by
    ``ccd_type`` together with the weights of the smoothness, self-collision,
    joint-limit and plane penalties that ``traj_opt_cost`` sums into one loss.
    """

    def __init__(self, models:mutil.Models, shakey:shakey_module.Shakey, 
                 robot_base_pqc=None, 
                 ccd_type='ours',
                 broadphase_type='naive',
                 collision_threshold=4.0,
                 col_coef=1.0,
                 reduce_k=16,
                 acc_coef_factor=5, 
                 jerk_coef_factor=1,
                 vel_coef=0.02,
                 curobo_activation_distance=0.040,
                 se2_bounds=None,
                 ):
        """Store the models/robot and build the CCD backend.

        Args:
            models: model bundle; kept as ``self.models`` and forwarded to
                ``OursCCD``. ``traj_opt_cost`` also uses its
                ``self_collision_checker_model`` and ``apply``.
            shakey: robot module. When ``robot_base_pqc`` is given it is
                replaced by a copy with ``robot_fixed_pqc=robot_base_pqc``.
            robot_base_pqc: optional base pose; converted with
                ``tutil.pq2SE2h`` and its SE2 part stored as ``self.base_se2``
                (``None`` when not given).
            ccd_type: ``'curobo'``, ``'trajopt'``, ``'ours'`` or ``'stamp'``.
                ``'stamp'`` uses ``OursCCD`` with ``is_continuous=False``.
            broadphase_type: forwarded to ``OursCCD``.
            collision_threshold: forwarded to ``OursCCD``.
            col_coef: collision weight passed to the CCD backend; also used
                as the self-collision weight.
            reduce_k: forwarded to ``OursCCD``.
            acc_coef_factor: ``acc_coef = vel_coef * 1e-3 * acc_coef_factor``.
            jerk_coef_factor: ``jerk_coef = vel_coef * 1e-7 * jerk_coef_factor``.
            vel_coef: weight of the mean squared velocity term.
            curobo_activation_distance: forwarded to ``CuroboCCD``.
            se2_bounds: stored as ``self.se2_bounds``; not used in this class
                body as visible here.

        Raises:
            NotImplementedError: if ``ccd_type`` is not one of the names above.
        """
        
        self.se2_bounds = se2_bounds
        self.optimizer = optax.lbfgs(linesearch=None)

        self.models = models
        self.col_coef = col_coef
        self.collision_threshold = collision_threshold

        if robot_base_pqc is not None:
            # Only the SE2 part is kept; the height component is not needed here.
            self.base_se2, _ = tutil.pq2SE2h(robot_base_pqc)
            shakey = shakey.replace(robot_fixed_pqc=robot_base_pqc)
        else:
            self.base_se2 = None

        self.shakey = shakey

        self.ccd_type = ccd_type
        self.broadphase_type = broadphase_type

        # Width of the smooth transition band used for the self-collision logit.
        self.self_collision_threshold = 0.5
        self.self_collision_coef = self.col_coef
        
        # Also reused as the weight of the plane-collision penalty.
        self.joint_limit_coef = 200.0

        self.vel_coef = vel_coef
        
        self.acc_coef = self.vel_coef*1e-3*acc_coef_factor
        self.jerk_coef = self.vel_coef*1e-7*jerk_coef_factor

        self.ccd_cls: ContinuousCollisionCostBase = None

        if self.ccd_type == 'curobo':
            self.ccd_cls = CuroboCCD(col_coef, activation_distance=curobo_activation_distance, rot_configs=None)
        elif self.ccd_type == 'trajopt':
            self.ccd_cls = TrajOptCCD(col_coef)
        elif self.ccd_type == 'ours' or self.ccd_type == 'stamp':
            self.ccd_cls = OursCCD(
                models,
                collision_threshold,
                col_coef,
                reduce_k,
                is_continuous=(self.ccd_type == 'ours'),
                broad_phase_cls=broad_phase.BroadPhaseWarp(),
                broadphase_type=self.broadphase_type
            )
        else:
            raise NotImplementedError(f"{self.ccd_type} is not implemented")

    def traj_opt_cost(self, trajectory: jnp.ndarray, 
                  loss_args:structs.LossArgs, jkey: jnp.ndarray, interpolation_num:int, 
                  ccd=True,
                  visualize=False):
        '''
        Evaluate the total trajectory cost and its individual terms.

        Shape legend used in the comments below:
        NI, number_of_interpolation_for_closest_point_search
        NBP, number_of_broadphase_pair_result
        NOB, number of moving objects
        NT, number of trajectory points

        Args:
            trajectory: with ``ccd=True``, control points ``[..., NT, D]``
                (leading axes are flattened for evaluation and restored on
                return); with ``ccd=False``, single configurations
                ``[..., D]``, each treated as a two-step constant path.
            loss_args: provides ``fixed_oriCORNs``, ``moving_oriCORNs``,
                ``ee_to_obj_pq``, ``fixed_moving_idx_pair``, ``plane_params``
                and, for the cuRobo backend, ``moving_spheres`` / ``mesh_ids``.
                Each of the first, second, fourth and fifth may be ``None``.
            jkey: random key forwarded to ``OursCCD``.
            interpolation_num: samples per control-point segment; also
                forwarded to ``OursCCD``.
            ccd: if True, interpolate the control points and add velocity,
                acceleration and jerk penalties; otherwise those are zero.
            visualize: forwarded to the CCD backend. When True and
                ``loss_args.fixed_oriCORNs`` is not None, the raw backend
                result is returned immediately instead of ``(loss, loss_aux)``.

        Returns:
            ``(loss, loss_aux)``, every leaf reshaped to the leading shape of
            ``trajectory``. ``loss_aux`` holds the individual loss terms, the
            backend aux entries whose shape matches the collision loss,
            ``collision_logits``, ``invalid_mask``, ``collision_binary`` and
            ``plane_col_binary``.

        Raises:
            NotImplementedError: if pairwise (relative) CCD is requested for a
                backend other than ``OursCCD`` / ``CuroboCCD``.
            ValueError: if neither ``fixed_oriCORNs`` nor any
                ``fixed_moving_idx_pair`` entry is available.
        '''

        fixed_obj = loss_args.fixed_oriCORNs
        if ccd:
            original_outer_shape = trajectory.shape[:-2]
            trajectory = trajectory.reshape(-1, trajectory.shape[-2], trajectory.shape[-1])
            interpolated_trajectory, vel, acc, jerk = jax.vmap(partial(evaluate_full_trajectory, samples_per_segment=interpolation_num))(trajectory)

            # Smoothness terms ignore the first and last interpolated sample.
            vel, acc, jerk = vel[...,1:-1,:], acc[...,1:-1,:], jerk[...,1:-1,:]

            vel_loss = self.vel_coef*jnp.mean(jnp.mean(jnp.square(vel), axis=-1), axis=(-1,))
            acc_loss = self.acc_coef*jnp.mean(jnp.mean(jnp.square(acc), axis=-1), axis=(-1,))
            jerk_loss = self.jerk_coef*jnp.mean(jnp.mean(jnp.square(jerk), axis=-1), axis=(-1,))
        else:
            original_outer_shape = trajectory.shape[:-1]
            # Duplicate each configuration so it looks like a 2-step path.
            interpolated_trajectory = trajectory[...,None,:].repeat(2, axis=-2)
            vel_loss = jnp.zeros(original_outer_shape, dtype=jnp.float32)
            acc_loss = jnp.zeros(original_outer_shape, dtype=jnp.float32)
            jerk_loss = jnp.zeros(original_outer_shape, dtype=jnp.float32)

        # add robot base se2 (only when the trajectory holds joint values alone)
        if self.base_se2 is not None and interpolated_trajectory.shape[-1] == self.shakey.num_act_joints:
            interpolated_trajectory = jnp.concatenate([
                jnp.broadcast_to(self.base_se2, (interpolated_trajectory.shape[:-1] + (3,))),
                interpolated_trajectory
            ], axis=-1)

        if loss_args.moving_oriCORNs is not None:
            # Robot links plus the object(s) attached to the end effector(s).
            moving_obj_link = self.shakey.link_canonical_oriCORN
            moving_obj_in_hand = loss_args.moving_oriCORNs
            if moving_obj_in_hand.ndim == 0:
                moving_obj_in_hand = moving_obj_in_hand[None]
            moving_obj = moving_obj_link.concat(moving_obj_in_hand, axis=0)

            moving_obj_link_pqs = self.shakey.FK(interpolated_trajectory, oriCORN_out=False)
            ee_indices = self.shakey.ee_idx
            ee_indices = np.array(ee_indices)
            if ee_indices.ndim == 0:
                ee_indices = ee_indices[None]
            # In-hand object pose = end-effector link pose composed with ee_to_obj_pq.
            moving_obj_pqs = jnp.concatenate([
                moving_obj_link_pqs,
                tutil.pq_multi(
                    moving_obj_link_pqs[..., ee_indices, :],
                    loss_args.ee_to_obj_pq,
                )
            ], axis=-2)
        else:
            moving_obj_pqs = self.shakey.FK(interpolated_trajectory, oriCORN_out=False) # [NT, NOB, 7]
            moving_obj = self.shakey.link_canonical_oriCORN # [NOB, ]
        

        def relative_ccd(fixed_idx, moving_idx):
            """CCD between two groups of moving objects.

            Poses of the ``moving_idx`` objects are expressed in the frame of
            each ``fixed_idx`` object, the backend is evaluated once per
            reference object, and the results are reduced over that axis
            (sum for the loss, max for every aux leaf).
            """
            fixed_obj_bimanual = moving_obj[fixed_idx]
            moving_obj_pqs_tmp = moving_obj_pqs[...,moving_idx,:]
            fixed_obj_pqs = moving_obj_pqs[...,fixed_idx,:]
            
            moving_obj_pqs_ccd = tutil.pq_multi(tutil.pq_inv(fixed_obj_pqs[...,None,:]), moving_obj_pqs_tmp[...,None,:,:]) # (NT, NOB1, NOB2, 7)
            moving_obj_ccd = moving_obj[moving_idx]
            
            args = ()
            vmap_axes = ()
            if isinstance(self.ccd_cls, OursCCD):
                args = (jkey, interpolation_num, visualize)
                vmap_axes = (None, None, None)
                # Map over the reference-object axis of the poses (-3) and of
                # the reference objects themselves (0).
                ccd_cls_vmap = jax.vmap(self.ccd_cls, (None,-3,0,*vmap_axes))
            elif isinstance(self.ccd_cls, CuroboCCD):
                # Plain Python loop over the reference objects for this backend.
                def ccd_cls_vmap(moving_obj_ccd, moving_obj_pqs_ccd, fixed_obj_bimanual):
                    ccd_res_list = []
                    for i in range(fixed_obj_bimanual.shape[0]):
                        ccd_res = self.ccd_cls(
                            moving_obj_ccd,
                            moving_obj_pqs_ccd[..., i, :, :],
                            fixed_obj_bimanual[i],
                            moving_idx,
                            fixed_idx[i],
                            loss_args.moving_spheres,
                            loss_args.mesh_ids,
                            visualize=visualize,
                        )
                        ccd_res_list.append(ccd_res)
                    ccd_res = jax.tree_util.tree_map(lambda *x: jnp.stack(x, axis=0), *ccd_res_list)
                    return ccd_res
            else:
                # Without this branch `ccd_cls_vmap` would be unbound below.
                raise NotImplementedError(
                    f"relative CCD is not implemented for {type(self.ccd_cls).__name__}")

            ccd_res = ccd_cls_vmap(
                moving_obj_ccd,
                moving_obj_pqs_ccd,
                fixed_obj_bimanual[:,None],
                *args,
            )

            ccd_aux = jax.tree_util.tree_map(lambda x: jnp.max(x, axis=0), ccd_res[1])
            ccd_res = (jnp.sum(ccd_res[0], axis=0), ccd_aux)
            return ccd_res
        
        ccd_res_list = []
        if fixed_obj is not None:
            args = ()
            if isinstance(self.ccd_cls, OursCCD):
                args = (jkey, interpolation_num, visualize)
            elif isinstance(self.ccd_cls, CuroboCCD):
                args = (None, None, loss_args.moving_spheres, loss_args.mesh_ids, visualize)

            ccd_res = self.ccd_cls(
                moving_obj,
                moving_obj_pqs,
                fixed_obj,
                *args,
            )
            ccd_res_list.append(ccd_res)
            if visualize:
                # Visualisation mode: hand back the raw backend result.
                return ccd_res
        
        if loss_args.fixed_moving_idx_pair is not None:
            for fixed_idx, moving_idx in loss_args.fixed_moving_idx_pair:
                ccd_res = relative_ccd(fixed_idx, moving_idx)
                ccd_res_list.append(ccd_res)

        if not ccd_res_list:
            # tree_map below needs at least one result to stack.
            raise ValueError(
                "traj_opt_cost needs loss_args.fixed_oriCORNs or a non-empty "
                "loss_args.fixed_moving_idx_pair to evaluate the collision cost")

        # Merge all collision results: sum the losses, max over every aux leaf.
        ccd_res = jax.tree_util.tree_map(lambda *x: jnp.stack(x, axis=0), *ccd_res_list)
        ccd_aux = jax.tree_util.tree_map(lambda x: jnp.max(x, axis=0), ccd_res[1])
        ccd_res = (jnp.sum(ccd_res[0], axis=0), ccd_aux)
        
        collision_loss, collision_aux_info = ccd_res
        collision_binary = collision_aux_info['collision_binary']
        # Reduce the trailing axes that collision_binary has beyond collision_loss.
        aggregate_axes = -(np.arange(collision_binary.ndim - collision_loss.ndim) + 1).astype(int)
        collision_binary = jnp.any(collision_binary, axis=aggregate_axes.tolist()).astype(jnp.bool)

        ## add self collision cost
        if self.models.self_collision_checker_model is not None:
            self_collision_logit = self.models.apply('self_collision_checker', interpolated_trajectory[...,-self.shakey.num_act_joints:]).squeeze(-1)
            self_collision_loss = self.self_collision_coef*jnp.sum(smooth_col_cost(self_collision_logit, mu=self.self_collision_threshold), axis=(-1,))
            self_collision_logit = mutil.aggregate_cost(self_collision_logit, axes=(-1,))
        else:
            self_collision_logit = -jnp.ones_like(collision_loss)
            self_collision_loss = jnp.zeros_like(collision_loss)

        # Joint limits are checked on the trailing num_act_joints entries only
        # (i.e. without the prepended base SE2 values).
        joint_limit_loss = joint_limit_cost_func(interpolated_trajectory[...,-self.shakey.num_act_joints:], 
                                                 self.shakey.q_lower_bound, self.shakey.q_upper_bound, 0.01)
        # Signed limit violation: positive once a joint is outside its bounds.
        joint_limit_logit = self.shakey.q_lower_bound - interpolated_trajectory[...,-self.shakey.num_act_joints:]
        joint_limit_logit = jnp.maximum(joint_limit_logit, interpolated_trajectory[...,-self.shakey.num_act_joints:] - self.shakey.q_upper_bound)
        joint_limit_logit = jnp.max(joint_limit_logit, axis=(-1,-2))
        joint_limit_loss = self.joint_limit_coef*jnp.sum(joint_limit_loss, axis=(-1,-2))

        # add wall collision
        if loss_args.plane_params is not None:
            moving_obj_radius = moving_obj.mean_fps_dist.at[0].set(0.01) # set base radius
            moving_obj_fps_tf = tutil.pq_action(moving_obj_pqs[...,None,:], moving_obj.fps_tf)
            # plane_params[..., :3] is dotted with the points and subtracted
            # from plane_params[..., 3]; presumably (normal, offset) — confirm.
            plane_col_logit = loss_args.plane_params[...,3]-jnp.einsum('...i,...ji->...j', moving_obj_fps_tf, loss_args.plane_params[...,:3]) # [NB, NT, NOB, NFPS, NPLN]
            plane_col_logit = plane_col_logit+moving_obj_radius[...,None,None] # [NB, NT, NOB, NFPS, NPLN]
            plane_col_binary = plane_col_logit > -0.015
            # NOTE(review): weighted with joint_limit_coef, not a dedicated coefficient.
            plane_col_loss = self.joint_limit_coef*jnp.sum(smooth_col_cost(plane_col_logit, mu=0.07), axis=(-1,-2, -3, -4))
            plane_col_loss += 1*jnp.sum(plane_col_binary, axis=(-1,-2,-3,-4))
            plane_col_logit = mutil.aggregate_cost(plane_col_logit, axes=(-1,-2, -3, -4))
            plane_col_binary = jnp.any(plane_col_binary, axis=(-1,-2, -3, -4)).astype(jnp.bool)
        else:
            plane_col_logit = -jnp.ones_like(collision_loss)
            plane_col_loss = jnp.zeros_like(collision_loss)

        # Keep only aux entries that have one value per trajectory, so the
        # final reshape to original_outer_shape is valid for every leaf.
        valid_key = []
        for key in collision_aux_info:
            if collision_aux_info[key].shape == collision_loss.shape:
                valid_key.append(key)
        collision_aux_info = {key: collision_aux_info[key] for key in valid_key}

        if 'collision_logits' not in collision_aux_info:
            collision_aux_info['collision_logits'] = -1.0
        # collision_aux_info['collision_logits'] = jnp.maximum(collision_aux_info['collision_logits'], self_collision_logit)
        collision_aux_info['collision_logits'] = jnp.maximum(collision_aux_info['collision_logits'], joint_limit_logit)
        
        invalid_mask = collision_binary
        # invalid_mask = jnp.logical_or(collision_binary, self_collision_logit > -1.0)
        invalid_mask = jnp.logical_or(invalid_mask, joint_limit_logit > -0.01)

        if loss_args.plane_params is not None:
            collision_aux_info['collision_logits'] = jnp.maximum(collision_aux_info['collision_logits'], plane_col_logit)
            invalid_mask = jnp.logical_or(invalid_mask, plane_col_binary)
        else:
            plane_col_binary = jnp.zeros_like(collision_binary)

        loss = collision_loss + vel_loss + self_collision_loss + acc_loss + jerk_loss + joint_limit_loss + plane_col_loss
        loss_aux = {
            'loss': loss,
            'collision_loss': collision_loss,
            **collision_aux_info,
            'vel_loss': vel_loss,
            'self_collision_loss': self_collision_loss,
            'acc_loss': acc_loss,
            'jerk_loss': jerk_loss,
            'plane_col_loss': plane_col_loss,
            'joint_limit_loss': joint_limit_loss,
            'invalid_mask': invalid_mask,
            'collision_binary': collision_binary,
            'plane_col_binary': plane_col_binary,
        }

        # Restore the caller's leading (batch) shape on every output leaf.
        loss, loss_aux = jax.tree_util.tree_map(lambda x: x.reshape(original_outer_shape), (loss, loss_aux))

        return loss, loss_aux