import os
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax

# jax.config.update("jax_compilation_cache_dir", "__jaxcache__")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")

import jax.numpy as jnp
import numpy as np
import pickle
import glob
from tqdm import tqdm
from functools import partial
import time
import open3d as o3d
import optax
import jax.debug as jdb

import sys
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)

import util.model_util as mutil
from train_pointfeat import ColDataset
import util.transform_util as tutil
import util.latent_obj_util as loutil
from util.reconstruction_util import create_fps_fcd_from_oriCORNs
import pybullet as pb
from modules import shakey_module
from util.dotenv_util import REP_CKPT
import einops
import matplotlib.pyplot as plt
import util.structs as structs
# import modules.traj_search_module as traj_search_module
import modules.motion_planner as graph_motion_planner
from modules.ccd import CuroboCCD, TrajOptCCD, OursCCD, ContinuousCollisionCostBase
import util.broad_phase as broad_phase
import modules.cost_module as cost_module

def way_points_to_trajectory(waypnts, resolution, cos_transition=True, indicator=None):
    """Resample a piecewise-linear waypoint path by normalized arc length.

    Args:
        waypnts: waypoints indexed along the first axis. Segment lengths are
            the norms (over the last axis) of consecutive differences.
        resolution: number of samples generated when ``indicator`` is None
            (``jnp.linspace(0, 1, resolution)``).
        cos_transition: if True, the path parameter ``s`` is remapped to
            ``(1 - cos(pi * s)) / 2`` before sampling.
        indicator: optional path parameters in ``[0, 1]`` to sample at; when
            given, ``resolution`` is not used.

    Returns:
        The sampled points, one per entry of the path parameter. The first and
        last samples are always overwritten with the first and last waypoint.
    """
    epsilon = 1e-8

    # Segment lengths as fractions of the total path length.
    seg_len = jnp.linalg.norm(waypnts[1:] - waypnts[:-1], axis=-1)
    seg_len = seg_len / jnp.sum(seg_len).clip(epsilon)
    # Segments that are (numerically) zero-length are dropped and the rest
    # renormalized. The denominator is clipped so that a path whose waypoints
    # all coincide yields zeros here instead of 0/0 = NaN.
    seg_len = jnp.where(seg_len < epsilon * 10, 0, seg_len)
    seg_len = seg_len / jnp.sum(seg_len).clip(epsilon)

    # Cumulative normalized arc length at every waypoint, pinned to [0, 1].
    cum_len = jnp.cumsum(seg_len)
    cum_len = jnp.concatenate([jnp.array([0]), cum_len], 0)
    cum_len = cum_len.at[-1].set(1.0)

    if indicator is None:
        indicator = jnp.linspace(0, 1, resolution)
    if cos_transition:
        indicator = (-jnp.cos(indicator * jnp.pi) + 1) / 2.

    # Index of the segment containing each sample: the number of segment end
    # points that lie strictly before the sample.
    seg_idx = jnp.sum(indicator[..., None] > cum_len[1:], axis=-1)

    # Linear interpolation weights inside the selected segment.
    w_start = (cum_len[seg_idx + 1] - indicator) / seg_len[seg_idx].clip(epsilon)
    w_start = w_start.clip(0., 1.)
    w_end = 1. - w_start

    traj = waypnts[seg_idx] * w_start[..., None] + waypnts[seg_idx + 1] * w_end[..., None]
    # On (near) zero-length segments fall back to the segment's start point.
    traj = jnp.where(seg_len[seg_idx][..., None] < 1e-4, waypnts[seg_idx], traj)
    traj = traj.at[0].set(waypnts[0])
    traj = traj.at[-1].set(waypnts[-1])

    return traj


def evaluate_full_trajectory(control_points, samples_per_segment):
    """
    Fit the interpolation from ``broad_phase`` to the control points and
    evaluate it on a uniform grid over [0, 1].

    control_points: control points; the second-to-last axis indexes the N
        points (the original note gives the shape as (N, 6)).
    samples_per_segment: number of evaluation steps per control-point segment.
    Returns: the output of ``broad_phase.SE3_interpolation_eval`` evaluated at
        (N-1)*samples_per_segment + 1 parameters. Callers in this file either
        unpack it into four values (trajectory, vel, acc, jerk) or take
        element [0] as the trajectory.
    """
    num_ctrl = control_points.shape[-2]
    ctrl = control_points
    if num_ctrl == 2:
        # Only two control points: densify the straight segment first.
        # The evaluation grid below is still sized from the original count.
        ctrl = way_points_to_trajectory(ctrl, samples_per_segment, cos_transition=False)
    spline_coeffs = broad_phase.SE3_interpolation_coeffs(ctrl)
    num_eval = (num_ctrl - 1) * samples_per_segment + 1
    eval_t = jnp.linspace(0, 1, num_eval, endpoint=True)
    return broad_phase.SE3_interpolation_eval(*spline_coeffs, eval_t)


def vis_gradient_callback(inputs):
    """Debug helper: show robot surface points along a batch of trajectories.

    Written to be used as a host callback (a commented-out ``jdb.callback``
    hook in ``TrajectoryOptimizer.line_search`` passes
    ``(x, direction, loss_args, self)``). For every trajectory in the batch an
    Open3D window is opened showing

    * the moving (robot) points of all interpolated configurations in green,
    * the fixed-object points in red,
    * one blue line per point connecting its positions at consecutive
      interpolation steps.

    Args:
        inputs: sequence whose first four entries are
            ``(x, direction, loss_args, self)``; ``direction`` is not used.
            ``x`` is a batch of control-point trajectories, ``loss_args``
            provides ``fixed_oriCORNs`` and ``self`` is the optimizer
            (``base_se2``, ``shakey`` and ``models`` are read from it).
    """
    x, _direction, loss_args, self = inputs[:4]

    fixed_obj = loss_args.fixed_oriCORNs
    # Only the interpolated positions are needed; the remaining outputs
    # (unpacked as vel/acc/jerk elsewhere in this file) are ignored.
    interpolated_trajectory = jax.vmap(partial(evaluate_full_trajectory, samples_per_segment=4))(x)[0]

    # Prepend the robot base SE2 pose to every interpolated configuration.
    interpolated_trajectory = jnp.concatenate([
        jnp.broadcast_to(self.base_se2, (interpolated_trajectory.shape[:-1] + (3,))),
        interpolated_trajectory
    ], axis=-1)

    moving_obj_pqs = self.shakey.FK(interpolated_trajectory, oriCORN_out=False)  # per original note: [NT, NOB, 7]
    moving_obj = self.shakey.link_canonical_oriCORN
    moving_obj_tf = moving_obj.apply_pq_z(moving_obj_pqs, self.models.rot_configs)

    moving_points = moving_obj_tf.fps_tf  # per original note: [NT, NOB, NFP, 3] (plus leading batch axis)
    # Pair every point with its position one interpolation step earlier;
    # the new last axis holds the two end points of each segment.
    moving_points_seq = jnp.stack([moving_points[...,1:,:,:,:], moving_points[...,:-1,:,:,:]], axis=-1)
    moving_points_seq = einops.rearrange(moving_points_seq, '... i j k p q -> ... (i j k) p q')

    # The fixed-object cloud and the single-segment line index do not depend
    # on the batch element, so build them once.
    o3d_fixed = o3d.geometry.PointCloud()
    o3d_fixed.points = o3d.utility.Vector3dVector(fixed_obj.fps_tf.reshape(-1, 3))
    o3d_fixed.paint_uniform_color([1,0,0])
    segment_idx = np.array([[0,1]]).astype(np.int32)

    for batch_idx in range(moving_points_seq.shape[0]):
        o3d_moving = o3d.geometry.PointCloud()
        o3d_moving.points = o3d.utility.Vector3dVector(moving_points[batch_idx].reshape(-1, 3))
        o3d_moving.paint_uniform_color([0,1,0])

        # One two-point LineSet per motion segment.
        lines_seq = []
        for j in range(moving_points_seq.shape[1]):
            end_points = jnp.moveaxis(moving_points_seq[batch_idx, j], -1, -2).reshape(-1, 3)
            lines = o3d.geometry.LineSet(
                points=o3d.utility.Vector3dVector(end_points),
                lines=o3d.utility.Vector2iVector(segment_idx),
            ).paint_uniform_color([0,0,1])
            lines_seq.append(lines)

        o3d.visualization.draw_geometries([o3d_moving, o3d_fixed, *lines_seq])




def print_callback(inputs):
    """Debug callback that prints whatever it is given.

    Referenced only from commented-out ``jdb.callback(print_callback, ...)``
    hooks in this file. The commented-out snippets below are alternative
    debug views (Open3D geometry and matplotlib curves) that index into
    ``inputs``; they are kept for reference and are not executed.
    """
    # inputs[0](*inputs[1], visualize=True)

    # visualization
    # for i in range(10):
    #     o3d_fixed_gt = inputs[3].get_fps_o3d(color=np.array([1,0,0]))
    #     o3d_fixed = inputs[0][i].get_fps_o3d(color=np.array([0,1,0]))
    #     o3d_moving_gt = inputs[-1][i].get_fps_o3d(color=np.array([0.1,0.4,0.2]))
    #     o3d_moving = inputs[1][i].get_fps_o3d()
    #     # draw lines
    #     points_seq = np.array(np.moveaxis(inputs[2][i], -1 ,-2).reshape(-1, 3))
    #     line_idx = np.arange(points_seq.shape[0]//2)
    #     line_idx = np.stack([2*line_idx, 2*line_idx+1], axis=-1).astype(np.int32)
    #     lines = o3d.geometry.LineSet(points=o3d.utility.Vector3dVector(points_seq), lines=o3d.utility.Vector2iVector(line_idx)).paint_uniform_color([0,0,1])
    #     o3d.visualization.draw_geometries([o3d_fixed_gt, o3d_fixed, o3d_moving_gt, o3d_moving, lines])

    # Active behaviour: dump the raw callback payload to stdout.
    print(inputs)

    # plt.figure()
    # for i in range(6):
    #     plt.subplot(6,1,i+1)
    #     plt.plot(np.linspace(0,1,inputs[0].shape[0]), inputs[0][:,i])
    #     plt.plot(np.linspace(0,1,inputs[1][0].shape[0]), inputs[1][0][:,i])
    #     plt.plot(np.linspace(0,1,inputs[2][0].shape[0]), inputs[2][0][:,i])
    # plt.show()

def smooth_col_cost(col_logits, mu=0.1):
    """Smoothed hinge cost on collision logits.

    Piecewise definition (continuous at both joins)::

        x >= 0        ->  x + mu / 2
        -mu < x < 0   ->  (x + mu)^2 / (2 * mu)
        x <= -mu      ->  0

    The boundaries are closed so that ``x == 0`` gives ``mu / 2`` and
    ``x == -mu`` gives ``0``; with strict inequalities on every branch those
    two points fell through and returned the raw logit.

    Args:
        col_logits: array of collision logits (positive means penetration
            according to the linear branch).
        mu: width of the quadratic blending band below zero.

    Returns:
        Array of non-negative costs with the shape of ``col_logits``.
    """
    linear = col_logits + 0.5 * mu
    quadratic = 0.5 / mu * (col_logits + mu) ** 2
    return jnp.where(
        col_logits >= 0,
        linear,
        jnp.where(col_logits > -mu, quadratic, 0.0),
    )

def joint_limit_cost_func(q_t, q_l, q_u, eta):
    """Smooth joint-limit penalty.

    Zero when ``q_t`` is further than ``eta`` inside both limits, quadratic
    within ``eta`` of a limit, and linear (offset by ``eta / 2``) once a
    limit is exceeded. The lower limit takes precedence over the upper one
    when several conditions hold.

    Args:
        q_t: joint value(s).
        q_l: lower limit(s).
        q_u: upper limit(s).
        eta: width of the quadratic band inside each limit.

    Returns:
        Element-wise penalty, broadcast over the inputs.
    """
    below_lower = q_t < q_l
    near_lower = (q_l <= q_t) & (q_t < q_l + eta)
    above_upper = q_t > q_u
    near_upper = (q_u - eta < q_t) & (q_t <= q_u)

    # Build from the lowest-priority case outwards so that later assignments
    # override earlier ones in the same order as a nested conditional.
    cost = jnp.where(near_upper, 0.5 / eta * (q_t - q_u + eta)**2, 0.0)
    cost = jnp.where(above_upper, q_t - q_u + 0.5 * eta, cost)
    cost = jnp.where(near_lower, 0.5 / eta * (q_l - q_t + eta)**2, cost)
    cost = jnp.where(below_lower, q_l - q_t + 0.5 * eta, cost)
    return cost


def zeroth_order_bundled_gradient(cost_fn, q, key, num_samples=32, sigma=0.1):
    """
    Estimate a zeroth-order "bundled" gradient of cost_fn at q from random
    perturbations.

    Args:
      cost_fn: function mapping q -> (cost, aux). ``cost`` has one value per
        leading batch entry of q (or is a scalar for an unbatched q); ``aux``
        is a pytree of arrays.
      q: jnp.array, the joint trajectory; the last two axes are
        (trajectory point, joint).
      key: a jax.random.PRNGKey for randomness.
      num_samples: number of Monte Carlo samples.
      sigma: standard deviation for the Gaussian perturbation.

    Returns:
      grad_est: an array of the same shape as q. Entries that were not
        perturbed (the first and last trajectory point) are exactly zero.
      loss_aux_ptb: ``aux`` of the perturbed evaluations averaged over the
        samples.
    """

    # Cost at the nominal trajectory (its aux output is not used).
    f_q, _ = cost_fn(q)

    # Gaussian perturbations with the shape of q; the end points of the
    # trajectory are kept fixed.
    perturb_shape = (num_samples,) + q.shape
    perturbations = jax.random.normal(key, shape=perturb_shape) * sigma
    perturbations = perturbations.at[...,0,:].set(0.0)
    perturbations = perturbations.at[...,-1,:].set(0.0)

    def eval_cost(perturb):
        return cost_fn(q + perturb)

    # Evaluate all perturbed trajectories in one vmapped call.
    f_q_perturbed, loss_aux_ptb = jax.vmap(eval_cost)(perturbations)

    # Element-wise finite-difference quotient per sample.
    # Unperturbed entries (zero divisor) contribute zero instead of inf/NaN,
    # which would otherwise poison the sample mean.
    # NOTE(review): dividing by each perturbation component is not the usual
    # Gaussian-smoothing estimator ((f(q+d) - f(q)) * d / sigma^2) and has
    # heavy-tailed samples; kept as-is to preserve tuned behaviour.
    cost_diff = (f_q_perturbed - f_q)[...,None,None]
    perturbed = perturbations != 0
    safe_divisor = jnp.where(perturbed, perturbations, 1.0)
    grad_samples = jnp.where(perturbed, cost_diff / safe_divisor, 0.0)

    # Average over all samples to get the bundled gradient.
    grad_est = jnp.mean(grad_samples, axis=0)

    loss_aux_ptb = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), loss_aux_ptb)
    return grad_est, loss_aux_ptb

class TrajectoryOptimizer(object):
    """Trajectory optimizer built on ``cost_module.CostModules``.

    Offers three strategies that can be combined:

    * ``perform_particle``: sampling-based (MPPI-style) update of a Gaussian
      over coarse control-point trajectories.
    * ``perform_gradient``: L-BFGS refinement with a batched line search; the
      gradient is either a zeroth-order bundled estimate or ``jax.grad``.
    * ``perform_search``: graph-based planning through
      ``modules.motion_planner.MotionPlanner``.

    ``perform_both`` chains particle and gradient stages and
    ``perform_multiple_seed`` restarts ``perform_both`` with several seeds.
    """

    def __init__(self, models:mutil.Models, shakey:shakey_module.Shakey, robot_base_pqc=None, 
                 ccd_type='ours',
                 broadphase_type='naive',
                 col_coef=1.0,
                 particle_itr_no=10, gradient_itr_no=5,
                 num_trajectory_points_particle=5, num_trajectory_points_gradient=16,
                 interpolation_num_particle=3,
                 interpolation_num_gradient=3,
                collision_threshold=6.0, reduce_k=16,
                 num_mppi_samples=50,
                 num_seeds=4,
                acc_coef_factor=5, jerk_coef_factor=1,
                 vel_coef=0.02,
                 bundled_samples_num=16,
                 bundled_order=0,
                 linesearch_batch_num=20,
                 se2_bounds=None,
                 curobo_activation_distance=0.040,
                 ):
        """Store the configuration and build the cost module and planner.

        Cost-related arguments (``ccd_type``, ``broadphase_type``,
        ``collision_threshold``, ``col_coef``, ``reduce_k``,
        ``acc_coef_factor``, ``jerk_coef_factor``, ``vel_coef``,
        ``se2_bounds``, ``curobo_activation_distance``) are forwarded
        unchanged to ``cost_module.CostModules``.

        Args:
            models: pretrained model bundle.
            shakey: robot kinematics object.
            robot_base_pqc: optional fixed base pose. When given, its SE2
                part is stored in ``self.base_se2`` and the robot is replaced
                by a copy with ``robot_fixed_pqc`` set.
            particle_itr_no: maximum iterations of the particle stage.
            gradient_itr_no: number of L-BFGS iterations.
            num_trajectory_points_particle: control points in the particle stage.
            num_trajectory_points_gradient: control points in the gradient stage.
            interpolation_num_particle: interpolation count passed to the
                cost in the particle stage.
            interpolation_num_gradient: interpolation count passed to the
                cost in the gradient stage.
            num_mppi_samples: particles per update.
            num_seeds: parallel seeds used by ``perform_multiple_seed``.
            bundled_samples_num: samples for the bundled gradient.
            bundled_order: 0 selects the zeroth-order estimator, anything
                else uses ``jax.grad``.
            linesearch_batch_num: step sizes evaluated per line-search pass.
        """
        # L-BFGS without optax's own line search; ``line_search`` is used instead.
        optimizer = optax.lbfgs(linesearch=None)
        self.optimizer = optimizer

        self.models = models
        self.perform_gradient_jit = jax.jit(self.perform_gradient)
        self.bundled_samples_num = bundled_samples_num
        self.bundled_order = bundled_order
        self.linesearch_batch_num = linesearch_batch_num

        if robot_base_pqc is not None:
            base_se2, _robot_height = tutil.pq2SE2h(robot_base_pqc)
            self.base_se2 = base_se2
            shakey:shakey_module.Shakey = shakey.replace(robot_fixed_pqc=robot_base_pqc)
        else:
            self.base_se2 = None

        self.shakey = shakey

        self.cost_module_cls = cost_module.CostModules(
            models, shakey,
            robot_base_pqc=robot_base_pqc, 
            ccd_type=ccd_type,
            broadphase_type=broadphase_type,
            collision_threshold=collision_threshold,
            col_coef=col_coef,
            reduce_k=reduce_k,
            acc_coef_factor=acc_coef_factor, 
            jerk_coef_factor=jerk_coef_factor,
            vel_coef=vel_coef,
            se2_bounds=se2_bounds,
            curobo_activation_distance=curobo_activation_distance,
        )

        self.ccd_cls = self.cost_module_cls.ccd_cls

        self.graph_motion_planner = graph_motion_planner.MotionPlanner(self.cost_module_cls)

        self.num_seed = num_seeds
        self.particle_itr_no = particle_itr_no
        self.gradient_itr_no = gradient_itr_no
        self.num_trajectory_points_particle = num_trajectory_points_particle
        self.num_trajectory_points_gradient = num_trajectory_points_gradient
        self.interpolation_num_particle = interpolation_num_particle
        self.interpolation_num_gradient = interpolation_num_gradient

        # MPPI params
        self.num_mppi_samples = num_mppi_samples
        self.k_mu = 0.1            # blending factor for the mean update
        # previously tried: beta = 0.003, k_sigma = 0.2, init_cov_scale = 1.5
        self.beta = 0.005          # temperature
        self.k_sigma = 0.15        # decay factor for the variance update
        self.init_cov_scale = 2.0  # initial per-element variance


    def line_search(self, x:jnp.ndarray, direction, loss_aux, loss_args:structs.LossArgs, jkey):
        """Batched two-pass grid line search along ``direction``.

        Each pass evaluates ``linesearch_batch_num`` step scales between the
        current bounds, picks the best one per batch entry and then narrows
        the bounds around it (by 0.1 after the first pass, 0.02 after the
        second).

        Args:
            x: current trajectories; the last two axes are (point, joint).
            direction: update direction with the shape of ``x``.
            loss_aux: unused; replaced by the aux output of the evaluations.
            loss_args: cost arguments forwarded to ``traj_opt_cost``.
            jkey: PRNG key forwarded to ``traj_opt_cost``.

        Returns:
            ``(update, min_idx, loss_aux_selected)``: the scaled direction of
            the best step from the last pass, its index along the scale axis
            (kept as a trailing axis of size 1) and the aux values at that
            index.
        """
        scale_min = jnp.array(1e-6)
        scale_max = jnp.array(1.0)
        zoom_ratio = [0.1, 0.02]

        # jdb.callback(vis_gradient_callback, (x, direction, loss_args, self))

        for zoom in zoom_ratio:
            scales = (scale_max[...,None] - scale_min[...,None])*jnp.linspace(0,1,self.linesearch_batch_num) + scale_min[...,None]
            direction_query = direction[...,None,:,:] * scales[...,None,None]
            x_query = x[...,None,:,:] + direction_query
            loss_batch, loss_aux = self.cost_module_cls.traj_opt_cost(x_query, loss_args, jkey, self.interpolation_num_gradient)
            min_idx = jnp.argmin(loss_batch, axis=-1, keepdims=True)
            # Adding zeros broadcasts ``scales`` to the batch shape of ``min_idx``.
            min_scale = jnp.take_along_axis(scales + jnp.zeros(min_idx.shape), min_idx, axis=-1).squeeze(-1)
            scale_min = jnp.maximum(min_scale-zoom, scale_min)
            scale_max = jnp.minimum(min_scale+zoom, scale_max)
        loss_aux_selected = jax.tree_util.tree_map(lambda leaf: jnp.take_along_axis(leaf, min_idx, axis=-1).squeeze(-1), loss_aux)
        return jnp.take_along_axis(direction_query, min_idx[...,None,None], axis=-3).squeeze(-3), min_idx, loss_aux_selected

    def one_step_lbfgs(self, x, opt_state, loss_args:structs.LossArgs, jkey):
        """One L-BFGS iteration followed by the batched line search.

        The gradient is the zeroth-order bundled estimate when
        ``bundled_order == 0``; otherwise it is ``jax.grad`` of the summed
        cost, averaged over ``bundled_samples_num`` perturbed copies of ``x``
        when more than one sample is requested. Gradients of the first and
        last trajectory point are zeroed and non-finite entries replaced by 0.

        Returns:
            ``(x_new, opt_state_new, loss_aux, min_idx)``.
        """
        jkey, subkey1, subkey2 = jax.random.split(jkey, 3)
        if self.bundled_order == 0:
            grad, loss_aux = zeroth_order_bundled_gradient(partial(self.cost_module_cls.traj_opt_cost, loss_args=loss_args, jkey=subkey1, 
                                                                   interpolation_num=self.interpolation_num_gradient), x, subkey2, 
                                                           num_samples=self.bundled_samples_num, sigma=0.05)
        else:
            def cost_f(*cost_inputs):
                costs = self.cost_module_cls.traj_opt_cost(*cost_inputs, self.interpolation_num_gradient)
                return jnp.sum(costs[0]), costs[1]
            if self.bundled_samples_num == 1:
                grad, loss_aux = jax.grad(cost_f, has_aux=True)(x, loss_args, subkey1)
            else:
                ptb = jax.random.normal(subkey1, shape=(self.bundled_samples_num, *x.shape))*0.05
                x_ptb = x + ptb
                original_outer_shape = x_ptb.shape[:-2]
                # Fold the sample axis into the batch axis for the cost call.
                x_ptb = x_ptb.reshape(-1, x_ptb.shape[-2], x_ptb.shape[-1])
                grad, loss_aux = jax.grad(cost_f, has_aux=True)(x_ptb, loss_args, subkey2)
                # Unfold the sample axis again and average over it.
                # jax.tree_util.tree_map is used because the top-level
                # jax.tree_map alias is deprecated/removed in recent JAX.
                grad, loss_aux = jax.tree_util.tree_map(
                    lambda leaf: jnp.mean(leaf.reshape(original_outer_shape + leaf.shape[1:]), axis=0),
                    (grad, loss_aux))
        # Keep the end points fixed and drop non-finite gradient entries.
        grad = grad.at[...,0,:].set(0)
        grad = grad.at[...,-1,:].set(0)
        grad = jnp.where(jnp.isfinite(grad), grad, 0)
        updates, opt_state = jax.vmap(self.optimizer.update)(grad, opt_state, x)
        updates, min_idx, loss_aux = self.line_search(x, updates, loss_aux, loss_args, jkey)
        x = optax.apply_updates(x, updates)
        return x, opt_state, loss_aux, min_idx

    def perform_gradient(self, init_q, goal_q, loss_args: structs.LossArgs, jkey, warm_start_traj=None):
        """Refine trajectories with ``gradient_itr_no`` L-BFGS iterations.

        Args:
            init_q: start configuration(s).
            goal_q: goal configuration(s).
            loss_args: cost arguments.
            jkey: PRNG key.
            warm_start_traj: optional initial trajectories; when omitted, a
                linear interpolation between ``init_q`` and ``goal_q`` with
                ``num_trajectory_points_gradient`` points is used.

        Returns:
            ``(min_x, final_state)``: the lowest-loss trajectories seen and
            the final loop state (without ``jkey``, ``opt_state`` and ``i``),
            both reshaped to the leading shape of the initial trajectory.
        """
        if warm_start_traj is not None:
            initial_trajectory = warm_start_traj
        else:
            # Create an initial trajectory (shape: [NT, 6])
            initial_trajectory = jnp.linspace(init_q, goal_q, self.num_trajectory_points_gradient)
        
        def cond_fun(state):
            # Early stopping is currently disabled; the loop always runs for
            # ``gradient_itr_no`` iterations. The previously used condition was
            #   (prev_loss == min_loss) & (col_loss < 0.0) & (i > 3), all batch entries.
            return state["i"] < self.gradient_itr_no
        
        # body_fun: perform one update step and update the state.
        def body_fun(state):
            i = state["i"]
            jkey = state["jkey"]
            x = state["x"]
            opt_state = state["opt_state"]
            
            # Update the random key
            jkey, subkey = jax.random.split(jkey)
            
            # Perform one optimization step
            x_new, opt_state_new, loss_aux, min_idx = self.one_step_lbfgs(x, opt_state, loss_args, subkey)
            loss = loss_aux['loss']
            col_loss = loss_aux['collision_logits']
            invalid_mask = loss_aux['invalid_mask']

            # jdb.callback(print_callback, loss_aux)
            
            # Update the best solution if current loss is lower than the best so far.
            improved = loss < state["min_loss"]
            new_min_loss = jnp.where(improved, loss, state["min_loss"])
            new_min_x = jnp.where(improved[...,None,None], x_new, state["min_x"])
            min_col_loss = jnp.where(improved, col_loss, state["min_col_loss"])
            min_invalid_mask = jnp.where(improved, invalid_mask, state["min_invalid_mask"]).astype(jnp.bool_)
            
            # Package the updated state.
            new_state = {
                "i": i + 1,
                "jkey": jkey,
                "x": x_new,
                "opt_state": opt_state_new,
                "prev_loss": state["min_loss"],    # for the next iteration's early-stopping check
                "min_loss": new_min_loss,
                "min_x": new_min_x,
                "loss": loss,
                "min_col_loss": min_col_loss,
                "col_loss": col_loss,
                "min_invalid_mask": min_invalid_mask,
            }
            return new_state
        
        # Flatten all leading axes into a single batch axis for the loop.
        original_outer_shape = initial_trajectory.shape[:-2]
        initial_trajectory = initial_trajectory.reshape(-1, initial_trajectory.shape[-2], initial_trajectory.shape[-1])
        outer_shape = initial_trajectory.shape[:-2]
        dummy_inf_array = jnp.full(outer_shape, jnp.inf)

        init_state = {
            "i": 0,
            "jkey": jkey,
            "x": initial_trajectory,
            "opt_state": jax.vmap(self.optimizer.init)(initial_trajectory),
            "prev_loss": dummy_inf_array,
            "min_loss": dummy_inf_array,
            "min_x": initial_trajectory,
            "loss": dummy_inf_array,     # initial loss (set high so the first step will update it)
            "col_loss": dummy_inf_array, # initial collision loss
            "min_col_loss": dummy_inf_array,
            "min_invalid_mask": jnp.full(outer_shape, True).astype(jnp.bool_),
        }
        
        # Run the while loop.
        final_state = jax.lax.while_loop(cond_fun, body_fun, init_state)

        final_state.pop('jkey')
        final_state.pop('opt_state')
        final_state.pop('i')

        # Restore the caller's leading shape.
        final_state = jax.tree_util.tree_map(lambda leaf: leaf.reshape(original_outer_shape + leaf.shape[1:]), final_state)
        
        # Return the best solution found.
        return final_state["min_x"], final_state



    def one_step_particle(self, Theta_mu, Theta_sigma, loss_args, jkey):
        """
        A single particle-based update step.
        
        Args:
          Theta_mu: Current mean trajectory, shape [..., T, dim].
          Theta_sigma: Current elementwise variance, shape [..., T, dim].
          loss_args: Additional cost arguments.
          jkey: JAX PRNGKey.
          
        Returns:
          Theta_mu_new, Theta_sigma_new, min_candidate, loss_aux: the updated
          mean and variance, the lowest-cost particle and the aux values of
          that particle.
        """
        num_samples = self.num_mppi_samples
        T, dim = Theta_mu.shape[-2:]
        outer_shape = Theta_mu.shape[:-2]
        
        # Sample noise from standard normal; the end points are not perturbed.
        noise_key, subkey = jax.random.split(jkey)
        noise = jax.random.normal(noise_key, shape=outer_shape + (num_samples, T, dim))
        noise = noise.at[...,(0,-1),:].set(0)
        # Generate particles: each candidate = Theta_mu + sqrt(Theta_sigma) * noise.
        candidate = Theta_mu[...,None, :, :] + jnp.sqrt(Theta_sigma)[...,None, :, :] * noise
        # Clip to the joint limits, prepending the SE2 bounds when the state
        # has more entries than the actuated joints.
        if candidate.shape[-1] == self.shakey.num_act_joints:
            candidate = jnp.clip(candidate, self.shakey.q_lower_bound, self.shakey.q_upper_bound)
        else:
            lower_bound = jnp.concat([self.cost_module_cls.se2_bounds[0], self.shakey.q_lower_bound])
            upper_bound = jnp.concat([self.cost_module_cls.se2_bounds[1], self.shakey.q_upper_bound])
            candidate = candidate.clip(lower_bound, upper_bound)
        
        # Smooth every particle through the interpolation and resample it back
        # to ``num_trajectory_points_particle`` points.
        original_outer_shape = candidate.shape[:-2]
        candidate = candidate.reshape(-1, candidate.shape[-2], candidate.shape[-1])
        candidate = jax.vmap(partial(evaluate_full_trajectory, samples_per_segment=4))(candidate)[0]
        candidate = jax.vmap(partial(way_points_to_trajectory, resolution=self.num_trajectory_points_particle, cos_transition=False))(candidate)
        candidate = candidate.reshape(original_outer_shape + candidate.shape[-2:])

        # Evaluate cost for each particle (only the first split key is used).
        keys = jax.random.split(subkey, num_samples)
        costs, loss_auxs = self.cost_module_cls.traj_opt_cost(candidate, loss_args, keys[0], self.interpolation_num_particle)

        # jdb.callback(print_callback, (self.traj_opt_cost, (candidate, loss_args, keys[0], self.interpolation_num_particle)))
        
        # Self-normalized softmax weights: rewards are shifted by their max
        # and scaled by the magnitude of the median gap.
        temperature = self.beta
        c = -costs
        c_max = jnp.max(c, axis=-1, keepdims=True)
        c_scale = jnp.abs(jnp.median(c - c_max, axis=-1, keepdims=True))
        # A zero median gap (e.g. all particles share one cost) would give
        # 0/0 = NaN weights; bound the scale away from zero.
        c_scale = jnp.maximum(c_scale, 1e-12)
        exp_c = jnp.exp((c - c_max)/c_scale/temperature)

        weights = exp_c / jnp.sum(exp_c, axis=-1, keepdims=True)

        Theta_mu_updated = jnp.sum(weights[..., None, None] * candidate, axis=-3)
        
        # Update mean trajectory.
        # Theta_mu_new = (1 - k_mu)*Theta_mu + k_mu * sum_i (w_i * candidate_i)
        Theta_mu_new = (1 - self.k_mu) * Theta_mu + self.k_mu * Theta_mu_updated
        
        # Update variance (elementwise):
        # Theta_sigma_new = (1 - k_sigma)*Theta_sigma + sum_i (w_i * (candidate_i - weighted_mean)^2)
        Theta_sigma_new = (1 - self.k_sigma) * Theta_sigma + jnp.sum(weights[..., None, None] * (candidate - Theta_mu_updated[...,None,:,:])**2, axis=-3)

        # min cost in candidates
        min_idx = jnp.argmin(costs, axis=-1)
        min_candidate = jnp.take_along_axis(candidate, min_idx[...,None,None,None], axis=-3).squeeze(-3)
        loss_aux = jax.tree_util.tree_map(lambda leaf: jnp.take_along_axis(leaf, min_idx[...,None], axis=-1).squeeze(-1), loss_auxs)

        # jdb.callback(print_callback, (Theta_mu, Theta_sigma, candidate, costs, weights, exp_c, Theta_mu_new, Theta_sigma_new))
        
        return Theta_mu_new, Theta_sigma_new, min_candidate, loss_aux


    def perform_particle(self, init_q, goal_q, loss_args, jkey):
        """
        Given an initial and goal state, generate an initial trajectory and then refine it using MPPI.
        The loop runs for at most ``particle_itr_no`` iterations and stops
        early once every variance entry has dropped below 1e-2.

        Returns:
            ``(min_x, final_state)``: the lowest-loss particle seen and the
            final loop state (without ``i``; ``x_particle`` duplicates
            ``min_x``).
        """
        # Create an initial trajectory (shape: [NT, dim])
        initial_trajectory = (goal_q - init_q)[...,None,:]*jnp.linspace(0, 1, self.num_trajectory_points_particle)[...,None] + init_q[...,None,:]

        outer_shape = initial_trajectory.shape[:-2]

        # The loop state contains iteration counter, current nominal trajectory, and best-so-far values.
        dummy_inf_array = jnp.full(outer_shape, jnp.inf)
        init_state = {
            "i": 0,
            "jkey": jkey,
            "Theta_mu": initial_trajectory,
            "Theta_sigma":self.init_cov_scale * jnp.ones_like(initial_trajectory),
            "prev_loss": dummy_inf_array,
            "min_loss": dummy_inf_array,
            "min_x": initial_trajectory,
            "loss": dummy_inf_array,     # initial loss (set high so the first step will update it)
            "col_loss": dummy_inf_array, # initial collision loss
        }

        # Continue iterating while maximum iterations not reached and not meeting an early stopping condition.
        def cond_fun(state):
            i = state["i"]
            # early_stop = (state["prev_loss"] <= state["loss"]) & (state["col_loss"] < 1e-4) & (i > 10)
            early_stop = state["Theta_sigma"] < 1e-2
            early_stop = jnp.all(early_stop)
            return jnp.logical_and(i < self.particle_itr_no, jnp.logical_not(early_stop))

        def body_fun(state):
            i = state["i"]
            jkey = state["jkey"]
            Theta_mu = state["Theta_mu"]
            Theta_sigma = state["Theta_sigma"]

            # Update random key.
            jkey, subkey = jax.random.split(jkey)

            # Perform one MPPI update step.
            Theta_mu, Theta_sigma, best_theta, loss_aux = self.one_step_particle(Theta_mu, Theta_sigma, loss_args, subkey)
            loss = loss_aux['loss']
            col_loss = loss_aux['collision_loss']

            # jdb.callback(print_callback, (Theta_mu, Theta_sigma, best_theta, loss_aux))
            # jdb.callback(print_callback, (Theta_sigma,))

            # Update the best solution if the current loss is lower than the best so far.
            improved = loss < state["min_loss"]
            new_min_loss = jnp.where(improved, loss, state["min_loss"])
            new_min_x = jnp.where(improved[...,None,None], best_theta, state["min_x"])

            new_state = {
                "i": i + 1,
                "jkey": jkey,
                "Theta_mu": Theta_mu,
                "Theta_sigma": Theta_sigma,
                "prev_loss": loss,    # for the next iteration's early-stopping check
                "min_loss": new_min_loss,
                "min_x": new_min_x,
                "loss": loss,
                "col_loss": col_loss,
            }
            return new_state

        final_state = jax.lax.while_loop(cond_fun, body_fun, init_state)
        # Return the best (lowest cost) trajectory found.
        final_state['x_particle'] = final_state['min_x']
        final_state.pop('i')
        # 'jkey' is kept in the state: perform_both continues the key stream from it.
        return final_state["min_x"], final_state
    
    def perform_both(self, init_q, goal_q, loss_args:structs.LossArgs, jkey):
        """Run the particle stage and refine its result with the gradient stage.

        The particle result is interpolated and resampled to
        ``num_trajectory_points_gradient`` points before being used as the
        warm start of ``perform_gradient``.

        Returns:
            ``(x, final_state)`` from ``perform_gradient``, with the resampled
            particle result added under ``'x_particle'``.
        """
        # particle
        x_particle, final_state = self.perform_particle(init_q, goal_q, loss_args, jkey)

        original_outer_shape = x_particle.shape[:-2]
        x_particle = x_particle.reshape(-1, x_particle.shape[-2], x_particle.shape[-1])
        samples_per_segment = int(self.num_trajectory_points_gradient/self.num_trajectory_points_particle)+1
        x_particle = jax.vmap(partial(evaluate_full_trajectory, samples_per_segment=samples_per_segment))(x_particle)[0]
        x_particle = jax.vmap(partial(way_points_to_trajectory, resolution=self.num_trajectory_points_gradient, cos_transition=False))(x_particle)
        x_particle = x_particle.reshape(original_outer_shape + x_particle.shape[-2:])

        # gradient
        jkey, subkey = jax.random.split(final_state['jkey'])
        x, final_state = self.perform_gradient(init_q, goal_q, loss_args, subkey, warm_start_traj=x_particle)
        
        final_state['x_particle'] = x_particle

        return x, final_state
    
    def perform_search(self, init_q, goal_q, loss_args:structs.LossArgs, jkey, node_visualize_func=None):
        """Plan with the graph motion planner, retrying with larger graphs.

        The planner is called with 40000, 80000 and 160000 nodes until it
        returns a path; the path is then resampled to 32 points.

        Returns:
            ``(x, None)`` where ``x`` is the resampled path.

        Raises:
            RuntimeError: if the planner returns no path for any node size.
        """
        x = None
        for node_size in [40000, 80000, 160000]:
            x, _ = self.graph_motion_planner.plan(jkey, init_q, goal_q, loss_args, 
                                                  node_size=node_size, 
                                                  num_neighbors=12, 
                                                  node_visualize_func=node_visualize_func)
            if x is not None:
                break

        if x is None:
            # Without this guard the resampling below fails on ``None`` with
            # an unrelated indexing error.
            raise RuntimeError("graph motion planner found no path for any node size")

        x = way_points_to_trajectory(x, 32, False)

        # x, final_state = self.perform_gradient_jit(init_q, goal_q, loss_args, jkey, warm_start_traj=x)
        return x, None
    
    def perform_multiple_seed(self, init_q:jnp.ndarray, goal_q:jnp.ndarray, loss_args:structs.LossArgs, jkey:jnp.ndarray):
        """Repeat ``perform_both`` with ``num_seed`` parallel seeds.

        Each round runs ``perform_both`` on ``num_seed`` copies of ``init_q``,
        keeps the seed with the lowest ``min_loss`` and merges it into the
        best-so-far state. Rounds continue (at most 7) while the best-so-far
        ``min_invalid_mask`` is still set.

        Returns:
            ``(min_x, final_out)``: the best trajectory and the final loop
            state (``i``, ``jkey``, ``min_loss``, ``min_x``, ``opt_aux_info``).
        """

        def cond_func(state):
            return jnp.logical_and(state['i'] < 7, state['opt_aux_info']['min_invalid_mask'])
        

        def body_func(state):
            i = state['i']
            jkey = state['jkey']
            min_loss = state['min_loss']
            x, final_state = self.perform_both(init_q[None].repeat(self.num_seed, 0), goal_q, loss_args, jkey)
            # Keep only the best seed of this round.
            x = x[jnp.argmin(final_state['min_loss'])]
            final_state = jax.tree_util.tree_map(lambda x_: x_[jnp.argmin(final_state['min_loss'])], final_state)
            cur_loss = final_state['min_loss']

            min_x = jnp.where(cur_loss < min_loss, x, state['min_x'])
            min_final_state = jax.tree_util.tree_map(lambda x_, y_: jnp.where(cur_loss < min_loss, x_, y_), final_state, state['opt_aux_info'])

            min_loss = jnp.where(cur_loss < min_loss, cur_loss, min_loss)
            jkey, _ = jax.random.split(jkey)
            return {'i': i+1, 'jkey': jkey, 'min_loss': min_loss, 'min_x': min_x, 
                    'opt_aux_info': min_final_state}
        
        # Use jax.eval_shape to capture the structure (shapes and dtypes) of the body output.
        dummy_out_shape = jax.eval_shape(lambda *x: partial(self.perform_both, loss_args=loss_args, jkey=jkey)(*x)[1], init_q, goal_q)

        # Create dummy initial values matching the output structure. Ones are
        # used so that the boolean 'min_invalid_mask' starts as True and the
        # loop runs at least once.
        dummy_opt_aux_info = jax.tree_util.tree_map(lambda s: jnp.ones(s.shape, s.dtype), dummy_out_shape)

        initial_state = {'i': 0, 'jkey': jkey, 'min_loss': jnp.inf, 
                         'min_x': jnp.zeros((self.num_trajectory_points_gradient, init_q.shape[-1])), 
                         'opt_aux_info': dummy_opt_aux_info
                         }
        

        final_out = jax.lax.while_loop(cond_func, body_func, initial_state)


        return final_out['min_x'], final_out


if __name__ == "__main__":

    # pyplot is needed for the quick 1-D plot in the demo loop below.
    import matplotlib.pyplot as plt

    # ---- Visual sanity check of SE3_interpolation on random control poses ----
    for _ in range(10):
        # 8 random control poses: position in [-1, 1]^3 followed by a random quaternion.
        ctrl_pos = np.random.uniform(-1, 1, (8, 3,))
        ctrl_quat = tutil.qrand((8,))
        control_pnts = np.concatenate([ctrl_pos, ctrl_quat], axis=-1)[None]

        # Evaluate at 100 evenly spaced parameters in [0, 1], then drop the
        # leading axis that was added with [None] above.
        eval_t = np.linspace(0, 1, 100)
        eval_pqc = SE3_interpolation(control_pnts, eval_t).squeeze(0)
        control_pnts = control_pnts.squeeze(0)

        # Plot the first component of the interpolated result over the samples.
        plt.figure()
        plt.plot(eval_pqc[..., 0])
        plt.show()

        # Draw coordinate frames: small ones for the interpolated poses,
        # larger ones for the control poses.
        intp_frames = [
            o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1).transform(tutil.pq2H(eval_pqc[k]))
            for k in range(eval_pqc.shape[-2])
        ]
        intp_frames.extend(
            o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3).transform(tutil.pq2H(control_pnts[k]))
            for k in range(control_pnts.shape[-2])
        )
        o3d.visualization.draw_geometries(intp_frames)

    # ---- Setup for the trajectory-optimization benchmark ----
    import util.scene_util as scene_util
    import logging

    # Log records at INFO level and above go to a text file.
    logging.basicConfig(filename='traj_opt.log', level=logging.INFO)

    # Pretrained oriCORN models plus the self-collision model for the robot.
    models = mutil.Models().load_pretrained_models()
    models = models.load_self_collision_model('shakey')

    # Robot kinematics built from the URDF.
    shakey = shakey_module.load_urdf_kinematics(
        urdf_dirs="assets/ur5/urdf/shakey_open.urdf",
        models=models,
    )

    # SE(2) placement of the robot base handed to create_pb
    # (presumably x, y, yaw -- confirm against shakey_module).
    base_se2 = np.array([0, -0.15, 3 * np.pi/2])
    robot_pb_uid = shakey.create_pb(se2=base_se2)

    base_pqc = tutil.SE2h2pq(base_se2, np.array(shakey.robot_height))
    traj_optimizer = TrajectoryOptimizer(
        models,
        shakey,
        robot_base_pqc=jnp.concat(base_pqc, axis=-1),
    )
    # perform_particle / perform_both are alternative entry points that can be
    # jitted here instead of perform_multiple_seed.
    traj_opt_jit = jax.jit(traj_optimizer.perform_multiple_seed)

    # Jitted IK solver and the all-zero 6-vector used as its initial guess.
    ik_jit = shakey.get_IK_jit_func(base_pqc, grasp_center_coordinate=True)
    q_zero = np.zeros(6)

    def sample_valid_q(z_range, jkey):
        """Rejection-sample a configuration for which PyBullet reports no contacts.

        Each attempt draws a random gripper pose, solves IK for it with
        ``ik_jit`` starting from ``q_zero``, applies ``ik_q[3:]`` to the
        PyBullet robot and queries its contact points.  Sampling repeats until
        the contact list is empty.

        Args:
            z_range: Pair ``(z_min, z_max)`` bounding the sampled gripper z
                coordinate.  x is drawn from [-1, 1] and y from [0.2, 1].
            jkey: JAX PRNG key that seeds the sampling.

        Returns:
            ``ik_q[3:]`` of the accepted IK solution, wrapped into
            ``[0, 2*pi)``.
        """
        lower = jnp.array([-1, 0.2, z_range[0]])
        upper = jnp.array([1, 1, z_range[1]])

        while True:
            # A JAX key must be consumed only once: derive one subkey per
            # random draw and carry a fresh key into the next attempt, so the
            # position and orientation samples are independent.
            jkey, pos_key, quat_key = jax.random.split(jkey, 3)
            random_gripper_pos = jax.random.uniform(pos_key, (3,), minval=lower, maxval=upper)
            random_gripper_quat = tutil.qrand((), quat_key)
            goal_pq = np.concatenate([random_gripper_pos, random_gripper_quat], axis=0)

            ik_q = ik_jit(q_zero, goal_pq)

            # Accept the sample only if PyBullet finds no contact involving the robot.
            shakey.set_q_pb(robot_pb_uid, ik_q[3:])
            pb.performCollisionDetection()
            col_res = pb.getContactPoints(robot_pb_uid)
            if len(col_res) == 0:
                break

        # Wrap the returned angles into [0, 2*pi).
        return ik_q[3:] % (2 * jnp.pi)

    # Benchmark: one sampled scene and one init/goal pair per seed; the running
    # success rate is printed after every trial.
    success_list = []
    for seed in range(100):
        jkey = jax.random.PRNGKey(seed)

        # Build the scene for this seed (num_objects=0, so random_obj may be None).
        random_obj, environment_obj, pybullet_scene = scene_util.create_table_sampled_scene(
            models=models, num_objects=0, seed=seed)
        if random_obj is None:
            fixed_obj = environment_obj
        else:
            fixed_obj = environment_obj.concat(random_obj, axis=0)
        moving_obj = shakey.link_canonical_oriCORN  # not used further in this loop

        pybullet_scene.reconstruct_scene_in_pybullet(False)

        pb.resetDebugVisualizerCamera(
            cameraDistance=1.79,
            cameraYaw=-443.20,
            cameraPitch=0.26,
            cameraTargetPosition=[-0.05, 0.07, 0.30]
        )
        # Re-create the robot for this scene; sample_valid_q reads this uid.
        robot_pb_uid = shakey.create_pb(se2=base_se2)

        # init, goal sampling: init with gripper z in [0.4, 1.0], goal in [0, 0.4]
        jkey, subkey1, subkey2, subkey3 = jax.random.split(jkey, 4)
        init_q = sample_valid_q([0.4, 1.0], subkey1)
        goal_q = sample_valid_q([0, 0.4], subkey2)

        # Where a joint differs by more than pi, shift the goal by one full turn
        # towards init so the shorter angular route is taken.
        distance_larger_than_pi = jnp.abs(init_q - goal_q) > jnp.pi
        goal_q = jnp.where(distance_larger_than_pi, goal_q - 2*jnp.pi*jnp.sign(goal_q - init_q), goal_q)

        loss_args = structs.LossArgs(fixed_obj)

        print('start traj opt')
        start_time = time.perf_counter()
        x, _ = traj_opt_jit(init_q, goal_q, loss_args, subkey3)
        # JAX dispatches work asynchronously: block on the result so the timing
        # covers the actual computation rather than just the dispatch.  The
        # first call additionally includes JIT compilation.
        x = jax.block_until_ready(x)
        print(f'elapsed time: {time.perf_counter()-start_time:.2f}s')

        # Replay the trajectory in PyBullet and record whether it succeeded.
        success = simulate_in_pb(x, robot_pb_uid, shakey, sleep_time=0.002, evaluate=True)
        success_list.append(success)

        pybullet_scene.clear_scene()

        print(f'success rate: {np.mean(success_list)} / {len(success_list)}')

