# %%
import jax.numpy as jnp
import jax
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
import copy
import os
from scipy.spatial import KDTree, cKDTree
import astar_module
import time
from scipy.stats.qmc import Halton
import einops

def nn(qpnt_, pnts_list, csq_, k=None):
    """Find the nearest stored point(s) to a query among the first ``csq_`` slots.

    Args:
        qpnt_: Query point(s); must broadcast against ``pnts_list``.
        pnts_list: Candidate points. The second-to-last axis indexes the
            points, the last axis holds the coordinates.
        csq_: Number of leading slots of ``pnts_list`` that are in use. Slots
            with index ``>= csq_`` are given the sentinel distance ``1e5`` so
            they are never preferred over a real point.
        k: If ``None`` return only the closest index; otherwise return the
            ``k`` closest indices together with their distances.

    Returns:
        ``argmin`` index array when ``k`` is ``None``; otherwise a tuple
        ``(idx, dist)`` whose last axis has length ``k`` and is sorted by
        increasing distance.
    """
    npd = pnts_list.shape[-2]
    dist_ = jnp.where(jnp.arange(npd) < csq_, jnp.linalg.norm(pnts_list - qpnt_, axis=-1), 1e5)
    if k is None:
        return jnp.argmin(dist_, axis=-1)
    idx = jnp.argsort(dist_, axis=-1)[..., :k]
    # Gather along the point axis. Plain ``dist_[idx]`` is only correct for a
    # 1-D ``dist_``; with batched queries it would index the leading axis.
    return idx, jnp.take_along_axis(dist_, idx, axis=-1)

def draw_plots(qpt_, pnts_list, pnts_list2=None, figname=None):
    """Plot the obstacle map, the start/goal markers and up to two polylines.

    The function reads the module-level names ``pc`` (obstacle patch
    collection), ``init_pnt`` and ``gpt``; they must exist before it is called.

    Args:
        qpt_: Optional 2-D array of points drawn as yellow dots (columns 0 and
            1 are used as x and y); pass ``None`` to skip them.
        pnts_list: Sequence of 2-D points drawn as a green polyline.
        pnts_list2: Optional second sequence drawn as a red polyline.
        figname: When given the figure is saved to this path, otherwise it is
            shown interactively. The figure is closed in both cases.
    """
    def _draw_polyline(points, colour):
        # Every vertex except the last gets a segment to its successor
        # followed by a dot marker on the vertex itself.
        for here, there in zip(points[:-1], points[1:]):
            segment = np.stack([here, there], 0)
            plt.plot(segment[:, 0], segment[:, 1], colour)
            plt.plot(here[0], here[1], colour + 'o')

    plt.figure(figsize=[5, 5])
    axes = plt.gca()
    # A patch collection can only live on one axes, hence the deep copy.
    axes.add_collection(copy.deepcopy(pc))
    plt.plot(init_pnt[0], init_pnt[1], 'ro')
    plt.plot(gpt[0], gpt[1], 'bo')

    if qpt_ is not None:
        plt.plot(qpt_[:, 0], qpt_[:, 1], 'yo')

    _draw_polyline(pnts_list, 'g')
    if pnts_list2 is not None:
        _draw_polyline(pnts_list2, 'r')

    axes.set(xlim=(-1, 1), ylim=(-1, 1))

    if figname is not None:
        plt.savefig(figname)
    else:
        plt.show()
    plt.close()


def Halton_uniform_sample_with_bound(nsample, lower_bound, upper_bound, halton_sampler=None):
    """Draw ``nsample`` Halton points strictly inside an axis-aligned box.

    Points are generated in the cube whose side equals the longest box edge
    (centred on the box) and the ones falling outside the box are rejected.

    Args:
        nsample: Number of points to return.
        lower_bound: Array of per-axis lower limits, shape ``(ndim,)``.
        upper_bound: Array of per-axis upper limits, shape ``(ndim,)``.
        halton_sampler: Optional ``scipy.stats.qmc.Halton`` instance to draw
            from (its internal state is advanced). A fresh unscrambled
            sampler is created when omitted.

    Returns:
        Array of shape ``(nsample, ndim)``.

    Raises:
        ValueError: If any axis has a non-positive range.
    """
    ndim = lower_bound.shape[-1]
    if halton_sampler is None:
        halton_sampler = Halton(d=ndim, scramble=False)
    bound_range = upper_bound - lower_bound
    longest = np.max(bound_range)
    shortest = np.min(bound_range)
    if not shortest > 0:
        raise ValueError('upper_bound must be strictly greater than lower_bound on every axis')
    center = (lower_bound + upper_bound) / 2

    batch_size = int(nsample * longest / shortest * 1.5)
    accepted = []
    n_accepted = 0
    # The acceptance rate is the box/cube volume ratio, which can be far below
    # shortest/longest in 3+ dimensions, so one batch is not always enough.
    # Keep drawing from the same sequence until enough points are accepted.
    while True:
        samples = halton_sampler.random(n=batch_size)
        samples = (samples - 0.5) * longest + center
        valid_mask = np.all(samples > lower_bound, axis=-1) & np.all(samples < upper_bound, axis=-1)
        samples = samples[valid_mask]
        accepted.append(samples)
        n_accepted += samples.shape[0]
        if n_accepted >= nsample:
            break
    return np.concatenate(accepted, axis=0)[:nsample]

# %%

def sampler_heuristic(jkey, start, goal, one_batch_size, upper_bound, lower_bound):
    """Sample points in a ball centred midway between ``start`` and ``goal``.

    Unit-ball samples are split into the component along the start-goal axis
    and the perpendicular remainder; both are scaled by the start-goal
    distance (factor 1.0 each), shifted to the midpoint and clipped to the
    bounds.

    Args:
        jkey: JAX PRNG key.
        start: Start state, last axis of length ``ndim``.
        goal: Goal state, same shape as ``start``.
        one_batch_size: Number of samples to draw.
        upper_bound: Upper clipping limit.
        lower_bound: Lower clipping limit.

    Returns:
        Array of shape ``(one_batch_size, ndim)``.
    """
    ndim = start.shape[-1]
    axis_vec = start - goal
    axis_len = jnp.linalg.norm(axis_vec, axis=-1, keepdims=True)
    axis_unit = axis_vec / axis_len

    # Directions use the incoming key, radii use a key split off from it.
    directions = jax.random.normal(jkey, shape=(one_batch_size, ndim))
    _, radius_key = jax.random.split(jkey)
    # u ** (1 / ndim) makes the radius distribution uniform over the ball volume.
    radii = jax.random.uniform(radius_key, shape=(one_batch_size, 1)) ** (1 / ndim)
    directions = directions / jnp.linalg.norm(directions, axis=-1, keepdims=True)
    ball = directions * radii

    along = jnp.sum(axis_unit * ball, axis=-1, keepdims=True) * axis_unit
    across = ball - along

    placed = along * axis_len * 1.0 + across * axis_len * 1.0 + (start + goal) * 0.5
    return placed.clip(lower_bound, upper_bound)


def PRM_node_only(jkey, start, goal, node_size, k, upper_bound, lower_bound, col_check, path_check,
                  state_scale=None, col_args=None, node_one_batch_size=1000, path_one_batch_size=1000, node_visualize_func=None):
    """Plan from ``start`` to ``goal`` with a batched probabilistic roadmap.

    Pipeline:
      1. Sample nodes uniformly inside the bounds in batches and run
         ``col_check`` on each batch until more than ``node_size``
         collision-free nodes were seen (at most 10000 batches).
      2. Keep the collision-free nodes and append ``start`` and ``goal``.
      3. Trim the node set to a multiple of the path-check batch size. The
         earliest samples are dropped; ``start``/``goal`` sit at the end and
         are therefore always kept.
      4. Connect every node to its ``k`` nearest neighbours (KD-tree) plus the
         goal, and run ``path_check`` on all candidate edges in batches.
      5. Run ``astar_module.astar`` on graphs with an increasing connection
         radius until a path is found.

    Args:
        jkey: JAX PRNG key.
        start: Start state, shape ``(ndim,)``.
        goal: Goal state, shape ``(ndim,)``.
        node_size: Target number of collision-free samples.
        k: Number of nearest neighbours per node (must be >= 1).
        upper_bound: Per-axis upper sampling limit.
        lower_bound: Per-axis lower sampling limit.
        col_check: Callable ``col_check(nodes, col_args, key)``. The first two
            returned items are used as (in-collision flags, collision cost),
            one entry per node (a trailing singleton axis is squeezed).
        path_check: Callable ``path_check(node_pairs, col_args, key)`` where
            ``node_pairs`` has shape ``(batch, n_edges, 2, ndim)``. The first
            two returned items are used as (edge-blocked flags, edge cost),
            each of shape ``(batch, n_edges)``.
        state_scale: Optional multiplier applied to the nodes before distances
            are measured by the KD-tree.
        col_args: Opaque argument forwarded to ``col_check``/``path_check``.
        node_one_batch_size: Number of samples per ``col_check`` call.
        path_one_batch_size: Approximate number of edges per ``path_check``
            call; ``path_one_batch_size // k`` nodes are processed per call.
        node_visualize_func: Optional callback invoked once per retained node.

    Returns:
        ``(path_positions, None)`` where ``path_positions`` is the list of
        node coordinates in the order returned by ``astar_module.astar``, or
        ``(None, None)`` when no path is found for any connection radius.
    """
    prm_start_t = time.time()
    ndim = start.shape[-1]

    _, jkey = jax.random.split(jkey)

    # ---- 1. sample nodes and collision-check them in batches ----
    col_check_start_t = time.time()
    node_one_batch_size = np.minimum(node_one_batch_size, node_size)
    col_res = []
    col_cost = []
    nodes = []
    nnode = 0
    for col_itr in range(10000):
        print(f'col itr num {col_itr}')
        jkey, subkey1, subkey2 = jax.random.split(jkey, 3)
        nodes_ = jax.random.uniform(subkey1, shape=(node_one_batch_size, ndim), minval=lower_bound, maxval=upper_bound)
        col_res_, col_cost_ = col_check(nodes_, col_args, subkey2)[:2]
        col_res.append(col_res_)
        col_cost.append(col_cost_)
        nodes.append(nodes_)
        nnode += jnp.sum(jnp.logical_not(col_res_))
        if nnode > node_size:
            break
    nodes = jnp.concatenate(nodes, axis=0)
    col_res = jnp.concatenate(col_res, axis=0)
    col_cost = jnp.concatenate(col_cost, axis=0)
    if col_res.ndim == 2:
        col_res = col_res.squeeze(-1)
    if col_cost.ndim == 2:
        col_cost = col_cost.squeeze(-1)
    # Block so the timing below measures the actual device work.
    nodes = jax.block_until_ready(nodes)
    col_res = jax.block_until_ready(col_res)
    col_cost = jax.block_until_ready(col_cost)
    col_check_end_t = time.time()

    create_graph_start_t = time.time()

    # ---- 2. keep collision-free nodes, append start and goal ----
    conversion_start_t = time.time()
    valid_idx = jnp.where(jnp.logical_not(col_res))
    nodes_jp = nodes[valid_idx]
    col_cost = col_cost[valid_idx]
    col_res = col_res[valid_idx]

    nodes_jp = jnp.concatenate([nodes_jp, start[None], goal[None]], 0)
    col_cost = jnp.concatenate([col_cost, jnp.array([0]), jnp.array([0])], 0)
    col_res = jnp.concatenate([col_res, jnp.array([False]), jnp.array([False])], 0)

    if node_visualize_func is not None:
        for ii in range(nodes_jp.shape[0]):
            node_visualize_func(nodes_jp[ii])

    print(f'valid number {col_res.shape[0]}')
    # Euclidean distance to the goal in unscaled coordinates, weighted by 5.
    heuristic = jnp.linalg.norm(nodes_jp - goal, axis=-1)*5  # (N, )
    heuristic = jax.block_until_ready(heuristic)
    conversion_end_t = time.time()

    # ---- 3. trim to a whole number of path-check batches ----
    n_nodes = int(nodes_jp.shape[0])
    # At least one node per batch, and never more than exist: otherwise the
    # batch count would be zero, ``x[-0:]`` would keep every node, and the
    # path-check loop below would produce nothing to concatenate.
    path_one_batch_per_node = min(max(path_one_batch_size // k, 1), n_nodes)
    path_itr_no = n_nodes // path_one_batch_per_node
    new_node_no = path_itr_no * path_one_batch_per_node
    nodes_jp, col_res, col_cost, heuristic = jax.tree_util.tree_map(lambda x: x[-new_node_no:], (nodes_jp, col_res, col_cost, heuristic))

    if state_scale is not None:
        nodes_scaled_np = np.array(nodes_jp*state_scale)
    else:
        nodes_scaled_np = np.array(nodes_jp)
    col_cost = np.array(col_cost)
    col_res = np.array(col_res)
    heuristic = np.array(heuristic)

    # ---- 4a. k-nearest-neighbour graph ----
    nn_start_t = time.time()
    kdtree = cKDTree(nodes_scaled_np)
    # start and goal were appended last and survive the trimming above.
    start_idx = len(nodes_scaled_np) - 2
    goal_idx = len(nodes_scaled_np) - 1
    # Column 0 of the query result is the node itself. Never ask for more
    # neighbours than there are nodes: the tree pads missing neighbours with
    # index ``n`` and distance ``inf``, which would leak into the graph.
    n_neighbors = min(k + 1, len(nodes_scaled_np))
    graph_dist_origin, graph_idx_origin = kdtree.query(nodes_scaled_np, k=n_neighbors)
    nn_dist_estimate = np.mean(graph_dist_origin[:, 1:ndim+2])
    # Add the goal as an extra candidate neighbour of every node.
    goal_column = np.full((graph_idx_origin.shape[0], 1), goal_idx, dtype=graph_idx_origin.dtype)
    graph_idx_origin = np.concatenate([graph_idx_origin, goal_column], axis=-1)
    graph_dist_origin = np.concatenate([graph_dist_origin, np.linalg.norm(nodes_scaled_np - nodes_scaled_np[goal_idx], axis=-1, keepdims=True)], axis=-1)
    nn_end_t = time.time()

    print(f'nn_dist_estimate {nn_dist_estimate}')

    # ---- 4b. edge validity check in batches ----
    path_validity_check_start_t = time.time()
    path_validity_res = []
    path_cost = []
    for pitr in range(path_itr_no):
        print(f'path itr num {pitr}')
        batch = slice(pitr*path_one_batch_per_node, (pitr+1)*path_one_batch_per_node)
        nodes_start = nodes_jp[batch]
        nodes_start = einops.repeat(nodes_start, '... n f -> ... n r f', r=graph_idx_origin.shape[-1]-1)
        # (batch, n_edges, 2, ndim): edge start and edge end stacked on axis -2.
        nodes_pair = jnp.stack([nodes_start, nodes_jp[graph_idx_origin[batch, 1:]]], axis=-2)
        # NOTE(review): the same key is passed to every batch; split it here
        # if path_check is stochastic and batches must be independent.
        path_validity_res_, path_cost_ = path_check(nodes_pair, col_args, jkey)[:2]
        path_validity_res.append(path_validity_res_)
        path_cost.append(path_cost_)
    path_validity_res = np.concatenate(path_validity_res, axis=0)
    path_cost = np.concatenate(path_cost, axis=0)

    path_validity_end_t = time.time()

    def build_graph(nn_dist):
        """Return edge arrays with every edge longer than ``nn_dist`` disabled."""
        graph_dist = graph_dist_origin.copy()
        graph_idx = graph_idx_origin.copy()
        graph_idx[graph_dist > nn_dist] = -1

        # Disabled edges get a large sentinel distance.
        graph_dist[graph_idx == -1] = 1e5

        # Drop column 0 (the node itself).
        graph_dist = graph_dist[:, 1:]
        graph_idx = graph_idx[:, 1:]
        col_res_edge = np.where(graph_idx != -1, path_validity_res, True)
        col_cost_edge = np.where(graph_idx != -1, path_cost, 1e5)
        return graph_dist, graph_idx, col_res_edge, col_cost_edge

    create_graph_end_t = time.time()

    # ---- 5. A* with a growing connection radius ----
    for nn_factor in [2.0, 4.0, 8.0, 16.0, 32.0]:
        graph_dist, graph_idx, invalid_path_mask, col_cost_edge = build_graph(nn_dist_estimate*nn_factor)

        edge_cost = col_cost_edge*0.02 + graph_dist

        astar_start_t = time.time()
        idx_path = astar_module.astar(
            start_idx,
            goal_idx,
            nodes_scaled_np,
            graph_idx,
            invalid_path_mask,
            edge_cost,
            heuristic
        )
        if idx_path is None:
            print(f'no path found run again / current nn_factor {nn_factor}')
            continue
        else:
            print(f'solution found / current nn_factor {nn_factor}')
            break
    if idx_path is None:
        print('no path found')
        return None, None
    path_positions = [nodes_jp[idx] for idx in idx_path]
    astar_end_t = time.time()
    prm_end_t = time.time()

    print(f'prm time {prm_end_t - prm_start_t:.4f} / col check time {col_check_end_t - col_check_start_t:.4f} / create graph time {create_graph_end_t - create_graph_start_t:.4f} / nn_time {nn_end_t-nn_start_t:.4f} / astar time {astar_end_t - astar_start_t:.4f} / conversion time {conversion_end_t - conversion_start_t:.4f} / path validity check time {path_validity_end_t - path_validity_check_start_t:.4f}')

    return path_positions, None


# @partial(jax.jit, static_argnums=[1,2])
def way_points_to_trajectory(waypnts, resolution, cos_transition=True):
    """Resample a piecewise-linear waypoint path into ``resolution`` points.

    The path is parameterised by normalised arc length in ``[0, 1]`` and
    sampled at ``resolution`` parameter values, linearly interpolating between
    the two waypoints that bracket each value.

    Args:
        waypnts: Array of shape ``(n_waypoints, ndim)`` with at least two rows.
        resolution: Number of output points.
        cos_transition: If true, the parameter values follow
            ``(1 - cos(pi * t)) / 2`` instead of being evenly spaced, which
            places samples more densely near both ends of the path.

    Returns:
        Array of shape ``(resolution, ndim)``. The first and last rows are
        exactly the first and last waypoints.
    """
    epsilon = 1e-8
    wp_len = jnp.linalg.norm(waypnts[1:] - waypnts[:-1], axis=-1)
    wp_len = wp_len/jnp.sum(wp_len).clip(epsilon)
    # Treat (near-)duplicate consecutive waypoints as zero-length segments.
    wp_len = jnp.where(wp_len < epsilon*10, 0, wp_len)
    # Clip here as well: when every segment is zero-length the sum is 0 and an
    # unguarded division would turn the whole trajectory into NaN.
    wp_len = wp_len/jnp.sum(wp_len).clip(epsilon)
    wp_len_cumsum = jnp.cumsum(wp_len)
    wp_len_cumsum = jnp.concatenate([jnp.array([0]), wp_len_cumsum], 0)
    # Pin the end of the parameter range so rounding cannot leave it below 1.
    wp_len_cumsum = wp_len_cumsum.at[-1].set(1.0)
    indicator = jnp.linspace(0, 1, resolution)
    if cos_transition:
        indicator = (-jnp.cos(indicator*jnp.pi)+1)/2.
    # Index of the segment containing each sample: the number of segment end
    # points that lie strictly before it.
    included_idx = jnp.sum(indicator[..., None] > wp_len_cumsum[1:], axis=-1)

    # Fraction of the segment still ahead of the sample; this is the weight of
    # the segment's first waypoint.
    upper_residual = (wp_len_cumsum[included_idx+1] - indicator)/wp_len[included_idx].clip(epsilon)
    upper_residual = upper_residual.clip(0., 1.)
    bottom_residual = 1.-upper_residual

    traj = waypnts[included_idx] * upper_residual[..., None] + waypnts[included_idx+1] * bottom_residual[..., None]
    # On (almost) zero-length segments skip the interpolation entirely.
    traj = jnp.where(wp_len[included_idx][..., None] < 1e-4, waypnts[included_idx], traj)
    traj = traj.at[0].set(waypnts[0])
    traj = traj.at[-1].set(waypnts[-1])

    return traj



def mppi_trajectory_optimization(initial_trajectory, num_itr, num_samples, cost_func, sigma=1.0, lambda_=1.0, rng_key=None):
    """
    MPPI-based trajectory optimization using JAX.

    Each iteration perturbs the current trajectory with Gaussian noise,
    evaluates the cost of every perturbed trajectory and moves the trajectory
    by the softmin-weighted average of the noise.

    Args:
        initial_trajectory (jnp.ndarray): Initial trajectory of shape (M, N).
        num_itr (int): Number of iterations.
        num_samples (int): Number of samples per iteration.
        cost_func (function): Cost function. It is called with a batch of
            shape (num_samples, M, N) and must return shape (num_samples,);
            it is also called with a single (M, N) trajectory for logging.
        sigma (float, optional): Standard deviation of the noise. Defaults to 1.0.
        lambda_ (float, optional): Temperature parameter for weighting. Defaults to 1.0.
        rng_key (jax.random.PRNGKey, optional): Random key for reproducibility.
            Defaults to ``PRNGKey(0)`` when None.

    Returns:
        jnp.ndarray: Optimized trajectory of shape (M, N).
    """
    M, N = initial_trajectory.shape

    if rng_key is None:
        rng_key = jax.random.PRNGKey(0)

    def mppi_iteration(trajectory, rng_key):
        # Perturb the current trajectory.
        rng_key, subkey = jax.random.split(rng_key)
        noise = jax.random.normal(subkey, shape=(num_samples, M, N)) * sigma  # (num_samples, M, N)
        sampled_trajectories = trajectory + noise  # (num_samples, M, N)

        costs = cost_func(sampled_trajectories)  # (num_samples,)

        # Softmin weights; subtracting the minimum keeps exp() from underflowing
        # for every sample at once.
        min_cost = jnp.min(costs)
        exp_costs = jnp.exp(- (costs - min_cost) / lambda_)
        weights = exp_costs / jnp.sum(exp_costs + 1e-10)

        weighted_noise = jnp.einsum('i,ijk->jk', weights, noise)  # (M, N)
        updated_trajectory = trajectory + weighted_noise

        # This body is traced by lax.fori_loop, so a plain print() would run
        # once at trace time and show tracers. jax.debug.print emits the real
        # values on every iteration.
        jax.debug.print(
            'original cost: {orig} / best cost: {best} / updated cost: {upd}',
            orig=cost_func(trajectory),
            best=min_cost,
            upd=cost_func(updated_trajectory),
        )

        return updated_trajectory, rng_key

    def body_fun(_, carry):
        return mppi_iteration(*carry)

    refined_trajectory, _ = jax.lax.fori_loop(0, num_itr, body_fun, (initial_trajectory, rng_key))

    print(f'original cost: {cost_func(initial_trajectory)} / final cost: {cost_func(refined_trajectory)}')

    return refined_trajectory



# %%
if __name__ == '__main__':
    # Demo: plan through random axis-aligned boxes inside [-1, 1]^2 with the
    # PRM above, then smooth the resulting waypoints with a few Adam steps.
    jkey = jax.random.PRNGKey(0)
    np.random.seed(1)
    npd = 400
    nbox = 40

    # Each row of box_ep is [x_min, y_min, x_max, y_max].
    box_pos = np.random.uniform(low=-1, high=1, size=(nbox, 2))
    box_hscale = np.random.uniform(low=0.01, high=0.15, size=(nbox, 2))
    box_ep = jnp.concatenate([box_pos - box_hscale, box_pos + box_hscale], axis=-1)
    box_vis = [Rectangle(bx[:2], *list(bx[2:] - bx[:2])) for bx in box_ep]
    # ``pc``, ``init_pnt`` and ``gpt`` are read as globals by draw_plots.
    pc = PatchCollection(box_vis, facecolor=np.array([0.5,0.5,0.5]), alpha=1.0,
                            edgecolor=np.array([0,0,0]))

    def collision_checker(qpnts, jkey):
        """Check points of shape (..., 2) against every box.

        Returns ``(col_res, col_cost)`` with the shape of ``qpnts`` minus its
        last axis. ``col_res`` is True where a point lies strictly inside any
        box. ``col_cost`` is the largest signed distance-to-boundary over all
        boxes (positive inside a box). ``jkey`` is unused.
        """
        box_epcd = box_ep
        if len(qpnts.shape) >= 2:
            # Insert singleton axes after the box axis so the boxes broadcast
            # against every leading axis of qpnts.
            for _ in range(len(qpnts.shape) - len(box_epcd.shape)):
                box_epcd = box_epcd[:,None]
            box_epcd = box_epcd[:,None]
        boxes_cd = jnp.all(jnp.concatenate([qpnts < box_epcd[...,2:], qpnts > box_epcd[...,:2]], axis=-1), axis=-1)
        col_cost = jnp.minimum(jnp.min(box_epcd[...,2:] - qpnts, axis=-1), jnp.min(qpnts - box_epcd[...,:2], axis=-1))
        col_cost = jnp.max(col_cost, axis=0)
        col_res = jnp.any(boxes_cd, axis=0)
        return col_res, col_cost

    def init_nodes():
        """Pick collision-free start/goal points and allocate the node buffers."""
        pnts_list = jnp.zeros((npd, 2))
        parent_id = -2*jnp.ones((npd,), dtype=jnp.int32)

        # Prefer the corners; fall back to random points if a corner is blocked.
        init_pnt = np.array([-1., -1.])
        while collision_checker(init_pnt, jkey)[0]:
            init_pnt = np.random.uniform(-1, 1, size=(2,))

        gpt = np.array([1., 1.])
        while collision_checker(gpt, jkey)[0]:
            gpt = np.random.uniform(-1, 1, size=(2,))

        pnts_list = pnts_list.at[0].set(init_pnt)
        parent_id = parent_id.at[0].set(-1)

        return pnts_list, parent_id, init_pnt, gpt

    def path_check(jkey, qpt_, node_pnts_, col_res_no=100):
        """Collision-check the straight segments from ``qpt_`` to ``node_pnts_``.

        The segment is sampled at ``col_res_no`` evenly spaced points. Returns
        (any sample in collision, summed collision cost) per segment.
        """
        interporation_factor = np.linspace(0, 1, num=col_res_no)
        for _ in range(len(node_pnts_.shape) - len(interporation_factor.shape)):
            interporation_factor = interporation_factor[...,None]
        interporation_factor = interporation_factor[...,None]
        cdqpnts_ = qpt_ + interporation_factor * (node_pnts_ - qpt_)
        cd_res_, col_cost = collision_checker(cdqpnts_, jkey)
        cd_res_ = jnp.any(cd_res_, axis=0)
        col_cost = jnp.sum(col_cost, axis=0)
        return cd_res_, col_cost

    # PRM_node_only calls its checkers as f(data, col_args, key); adapt the
    # demo's checkers, which take no col_args, to that convention.
    jit_collision_checker = jax.jit(collision_checker)
    jit_path_check = jax.jit(path_check)

    def prm_col_check(nodes_, col_args, key):
        return jit_collision_checker(nodes_, key)

    def prm_path_check(node_pairs, col_args, key):
        # node_pairs[..., 0, :] is the edge start, node_pairs[..., 1, :] its end.
        return jit_path_check(key, node_pairs[..., 0, :], node_pairs[..., 1, :])

    lower_bound = jnp.array([-1., -1.])
    upper_bound = jnp.array([1., 1.])

    pnts_list, parent_id, init_pnt, gpt = init_nodes()

    path_res, aux = PRM_node_only(jkey, init_pnt, gpt, 20000, 10, upper_bound, lower_bound,
                                  prm_col_check, prm_path_check, node_one_batch_size=2000)
    if path_res is None:
        raise SystemExit('PRM failed to find a path')
    # PRM_node_only currently returns None as its second item.
    nodes = aux[0] if aux is not None else None

    draw_plots(nodes, path_res)

    def traj_cost(traj):
        """Sum of squared step lengths plus a small collision penalty."""
        _, col_cost = collision_checker(traj, jkey)

        # velocity loss
        vel = traj[...,1:,:] - traj[...,:-1,:]
        vel_loss = jnp.sum(jnp.linalg.norm(vel, axis=-1)**2, axis=-1)

        # NOTE(review): jnp.sum(col_cost) reduces over every axis, so for a
        # batch of trajectories all samples share one collision term.
        return vel_loss + jnp.sum(col_cost) * 0.01

    path_rf = jnp.stack(path_res)

    grad_func = jax.jit(jax.grad(traj_cost))
    import optax

    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(path_rf)
    for i in range(10):
        grad = grad_func(path_rf)
        updates, opt_state = optimizer.update(grad, opt_state)
        path_rf = optax.apply_updates(path_rf, updates)
        # Keep the end points fixed, then re-space the waypoints evenly.
        path_rf = path_rf.at[0].set(path_res[0])
        path_rf = path_rf.at[-1].set(path_res[-1])
        path_rf = way_points_to_trajectory(path_rf, resolution=path_rf.shape[0], cos_transition=False)
        print(f'cost {traj_cost(path_rf)}')

    # path_rf = mppi_trajectory_optimization(path_rf, num_itr=10000, num_samples=5000, cost_func=traj_cost, sigma=0.0005, lambda_=0.0001, rng_key=jkey)

    # draw_plots(None, path_res, path_rf, figname='tmp/PRM_figure.png')
    draw_plots(None, path_res, path_rf)
