import jax.numpy as jnp
import pickle
import numpy as np
import os
import glob
import open3d as o3d
import jax
import matplotlib.pyplot as plt
import einops
from functools import partial
from tqdm import tqdm
from dataclasses import replace
from pathlib import Path
# import kdtree from scipy
from scipy.spatial import cKDTree
from scipy.spatial.transform import Rotation as sciR
import datetime

# Setup import path
import sys
BASEDIR = Path(__file__).parent.parent
if str(BASEDIR) not in sys.path:
    sys.path.insert(0, str(BASEDIR))

# from util.model_util import SDFDecoder
from util.model_util import Models
import util.transform_util as tutil
import util.latent_obj_util as loutil

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
DS_BASE_DIR = 'col_data_patch5'  # directory the generated chunks are written to
DS_TYPE = 'train'
START_IDX = 0  # offset added to the running sample count in chunk file names
NTREE_WIDTH = 1000  # number of leading surface points used as coarse (level-1) samples
NCUMPAIR = 20  # the sampling loop re-draws the object pair once its counter exceeds this
DATE_ID = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# Optional list of SDF paths; when set, the first object of every pair is
# drawn from this list instead of from all of `sdf_dirs`.
target_sdf_list = None

# One pickle path per line. Blank lines are skipped so that a stray empty
# line does not turn into open('') further down.
with open('dataset/sdf_dirs.txt', 'r') as f:
    sdf_dirs = [line.strip() for line in f if line.strip()]
if not sdf_dirs:
    raise ValueError('dataset/sdf_dirs.txt does not list any SDF file')

target_sdf_idx = []
if target_sdf_list is not None:
    for tsdf in target_sdf_list:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if tsdf not in sdf_dirs:
            raise ValueError(f'{tsdf!r} is not listed in dataset/sdf_dirs.txt')
        target_sdf_idx.append(sdf_dirs.index(tsdf))

surface_pnts_list = []
query_pnts_list = []
signed_distance_list = []
for sdd in tqdm(sdf_dirs):
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # sdf_dirs.txt at files produced by a trusted pipeline.
    with open(sdd, 'rb') as f:
        sdf_data = pickle.load(f)
    query_pnts_list.append(sdf_data['query_points'])
    signed_distance_list.append(sdf_data['signed_distance'])
    # Shuffle the surface points so that any leading slice is a random subset.
    surface_pnts_list.append(np.random.permutation(sdf_data['surface_points']))

# Truncate every surface cloud to the smallest count so they can be stacked.
sp_cnt = np.array([sp.shape[0] for sp in surface_pnts_list]).min()
surface_pnts_list = np.stack([sp[:sp_cnt] for sp in surface_pnts_list], axis=0)
# NOTE(review): stacking assumes every object stores the same number of
# query points / signed distances -- np.stack raises otherwise.
query_pnts_list = np.stack(query_pnts_list, axis=0)
signed_distance_list = np.stack(signed_distance_list, axis=0)

jkey = jax.random.PRNGKey(0)

def build_tree(pcd, width=None):
    """Precompute nearest-neighbour indices for the leading points of a cloud.

    The first ``width`` points of ``pcd`` act as coarse samples. For each of
    them the ``4 * (len(pcd) // width)`` nearest points of ``pcd`` are looked
    up with a k-d tree.

    Args:
        pcd: array of points, one point per row (indexed along axis 0).
        width: number of leading points used as samples. Defaults to the
            module-level ``NTREE_WIDTH``.

    Returns:
        A tuple ``(neighbor_idx, mean_dist)``. ``neighbor_idx`` holds, per
        sample, the indices into ``pcd`` of its nearest neighbours with the
        first hit (the sample itself, or an exact duplicate) dropped.
        ``mean_dist`` is the mean over all samples of the distances in
        result columns 1..4, i.e. up to four nearest neighbours.

    Raises:
        ValueError: if ``pcd`` has fewer than ``width`` points, because the
            neighbour count would then be zero.
    """
    width = NTREE_WIDTH if width is None else width
    npcd = pcd.shape[0]
    if npcd < width:
        raise ValueError(
            f'point cloud has {npcd} points but at least {width} are required')
    # Same value as the original int(4.0 * (npcd // width)); capped at npcd so
    # the query never asks for more neighbours than there are points.
    n_neighbors = min(4 * (npcd // width), npcd)
    tree = cKDTree(pcd)
    pcd_sample = pcd[:width]
    dist_to_sample, idx_to_sample = tree.query(pcd_sample, k=n_neighbors)
    mean_dist = np.mean(dist_to_sample[..., 1:5])
    return idx_to_sample[..., 1:], mean_dist

# Precompute the neighbour table and the spacing statistic of every object.
tree_list, mean_dist_list = [], []
for tree_, md_ in map(build_tree, tqdm(surface_pnts_list)):
    tree_list.append(tree_)
    mean_dist_list.append(md_)

# Move the lookup tables onto JAX arrays; the collision routines below index
# them with `obj_idx`.
mean_dist_list = jnp.array(mean_dist_list)
tree_list = jnp.array(tree_list)
surface_pnts_list = jnp.array(surface_pnts_list)
signed_distance_list = jnp.array(signed_distance_list)
query_pnts_list = jnp.array(query_pnts_list)


def query_sdf(obj_idx, qpnt_canonical):
    """Return the stored signed distance of the sample nearest to a query point.

    Args:
        obj_idx: index selecting the object in the module-level
            ``query_pnts_list`` / ``signed_distance_list`` tables.
        qpnt_canonical: query point, broadcast against the stored query points.

    Returns:
        The signed-distance entry of the stored query point with the smallest
        Euclidean distance to ``qpnt_canonical``, with its last axis squeezed.
    """
    # NOTE(review): the final squeeze(-1) assumes each signed-distance entry
    # carries a trailing singleton axis -- confirm against the pickle layout.
    offsets = query_pnts_list[obj_idx] - qpnt_canonical
    nearest = jnp.argmin(jnp.linalg.norm(offsets, axis=-1), axis=-1)
    return signed_distance_list[obj_idx][nearest].squeeze(-1)


def invtransform_pnt(pnt, scale, pos, quat):
    """Undo ``transform_pcd``: apply the inverse pose, then divide by the scale.

    Args:
        pnt: point(s) in the posed frame.
        scale: scale factor that was applied before the pose.
        pos: position part of the pose; a point axis is inserted before its
            last axis.
        quat: quaternion part of the pose; expanded the same way as ``pos``.

    Returns:
        ``pnt`` mapped through the inverse pose and divided by ``scale``.
    """
    # presumably pq_inv returns the (pos, quat) of the inverse rigid transform
    # -- confirm in util.transform_util.
    inv_pose = tutil.pq_inv(pos[..., None, :], quat[..., None, :])
    return tutil.pq_action(*inv_pose, pnt) / scale

def transform_pcd(pcd, scale, pos, quat):
    """Scale a point cloud and then apply the pose ``(pos, quat)`` to it.

    Args:
        pcd: point cloud(s) with the coordinates on the last axis.
        scale: factor multiplied onto ``pcd`` (broadcast) before posing.
        pos: position part of the pose; a point axis is inserted before its
            last axis so it broadcasts over the points.
        quat: quaternion part of the pose; expanded the same way as ``pos``.

    Returns:
        The scaled cloud after ``tutil.pq_action`` with the given pose.
    """
    scaled = pcd * scale
    # presumably pq_action rotates by `quat` and then translates by `pos`
    # -- confirm in util.transform_util.
    return tutil.pq_action(pos[..., None, :], quat[..., None, :], scaled)

def brute_force_closest_pair(pcd1, pcd2):
    """Find the closest pair between two point sets by exhaustive search.

    Args:
        pcd1: points of the first set, coordinates on the last axis.
        pcd2: points of the second set, coordinates on the last axis.

    Returns:
        A tuple ``(min_dist, Aidx, Bidx, pntA, pntB, flat_dists)``: the
        smallest pairwise distance, the indices of the closest pair in
        ``pcd1`` / ``pcd2``, the two points themselves, and all pairwise
        distances flattened with the ``pcd2`` index varying fastest.
    """
    # NOTE(review): the distance matrix is flattened completely, so the index
    # arithmetic below is only meaningful for un-batched (N, D) inputs.
    pairwise = jnp.linalg.norm(pcd1[..., :, None, :] - pcd2[..., None, :, :], axis=-1)
    n_cols = pairwise.shape[-1]
    flat_dists = pairwise.ravel()
    best = jnp.argmin(flat_dists, axis=-1)
    # Row index -> point of pcd1, column index -> point of pcd2.
    row, col = jnp.divmod(best, n_cols)
    return flat_dists[best], row, col, pcd1[row], pcd2[col], flat_dists

# @jax.jit
def point_based_collision(obj_idx, scale, pos, quat, jkey):
    """Two-level closest-point collision test between a pair of objects.

    Level 1 brute-forces the closest pair among the first ``NTREE_WIDTH``
    surface points of each object. Level 2 repeats the search inside the
    precomputed neighbourhoods (``tree_list``) of that pair. A pair collides
    when the refined minimum distance falls below a spacing-based threshold,
    or when either closest point has a negative signed distance in the other
    object's SDF.

    Args:
        obj_idx: indices of the two objects (object A first, object B second).
        scale: per-object scale; its two trailing singleton axes are squeezed
            for the threshold computation.
        pos: per-object position, indexed ``pos[0]`` / ``pos[1]``.
        quat: per-object orientation quaternion, indexed like ``pos``.
        jkey: JAX PRNG key used to pick up to eight random in-contact pairs.

    Returns:
        A tuple ``(col_res, min_dist, cpntB - cpntA, cpntA, cpntB, contact_no,
        in_contact_pnt1, in_contact_pnt2, in_contact_mask_selected)``. The
        sampled contact points are zeroed in slots where
        ``in_contact_mask_selected`` is False.
    """
    pcd12 = surface_pnts_list[obj_idx]
    tree12 = tree_list[obj_idx]

    # Level 1: brute-force closest pair over the coarse samples.
    pcd12_l1_tf = transform_pcd(pcd12[:, :NTREE_WIDTH], scale, pos, quat)
    _, Aidx, Bidx, _, _, _ = brute_force_closest_pair(pcd12_l1_tf[0], pcd12_l1_tf[1])

    # Level 2: brute-force again inside the neighbourhoods of the coarse pair.
    # (tree12[0, Aidx, k] / tree12[1, Bidx, k] map a level-2 index k back to an
    # index into the full surface cloud, should that ever be needed.)
    pcd1_l2 = pcd12[0, tree12[0, Aidx]]
    pcd2_l2 = pcd12[1, tree12[1, Bidx]]
    pcd12_l2 = jnp.stack([pcd1_l2, pcd2_l2], axis=0)
    pcd12_l2_tf = transform_pcd(pcd12_l2, scale, pos, quat)

    min_dist, _, _, cpntA, cpntB, dist_spAB_flat = brute_force_closest_pair(pcd12_l2_tf[0], pcd12_l2_tf[1])

    dist_threshold = jnp.max(mean_dist_list[obj_idx]*scale.squeeze(-1).squeeze(-1))

    in_contact_mask = dist_spAB_flat < dist_threshold
    # Randomly select up to 8 in-contact pairs: permute the flat indices, take
    # the first masked hits, then map them back through the permutation.
    perm_idx = jax.random.permutation(jkey, jnp.arange(in_contact_mask.shape[0]))
    in_contact_idx = jnp.where(in_contact_mask[perm_idx], size=8, fill_value=-1)[0]
    in_contact_idx = jnp.where(in_contact_idx!=-1, perm_idx[in_contact_idx], -1)
    # The flat pair index is a * n_l2 + b, where n_l2 is the number of level-2
    # points per object (axis -2). The original divided by shape[-1], i.e. the
    # coordinate axis, which decoded the wrong points.
    n_l2 = pcd12_l2_tf.shape[-2]
    in_contact_idx1 = in_contact_idx // n_l2
    in_contact_idx2 = in_contact_idx % n_l2
    in_contact_pnt1 = pcd12_l2_tf[0, in_contact_idx1]
    in_contact_pnt2 = pcd12_l2_tf[1, in_contact_idx2]
    in_contact_mask_selected = in_contact_idx != -1
    # Zero out the slots that were only padding (fill_value == -1).
    in_contact_pnt1 = jnp.where(in_contact_mask_selected[..., None], in_contact_pnt1, 0)
    in_contact_pnt2 = jnp.where(in_contact_mask_selected[..., None], in_contact_pnt2, 0)

    contact_no = jnp.sum(in_contact_mask)

    # SDF query: each closest point is expressed in the other object's
    # canonical frame and looked up in that object's SDF samples.
    qpnt_canonical1 = invtransform_pnt(cpntA, scale[1], pos[1], quat[1])  # A point -> canonical frame of B
    qpnt_canonical2 = invtransform_pnt(cpntB, scale[0], pos[0], quat[0])  # B point -> canonical frame of A
    sdf1 = query_sdf(obj_idx[1], qpnt_canonical1)  # obj B frame
    sdf2 = query_sdf(obj_idx[0], qpnt_canonical2)  # obj A frame

    col_res = min_dist < dist_threshold  # collision if true
    sdf_col = jnp.logical_or(sdf1 < 0, sdf2 < 0)
    # Fall back to the SDF sign test when the distance test says "no collision".
    col_res = jnp.where(col_res, col_res, sdf_col)

    return col_res, min_dist, cpntB - cpntA, cpntA, cpntB, contact_no, in_contact_pnt1, in_contact_pnt2, in_contact_mask_selected



@jax.jit
def point_based_collision_patch(obj_idx, scale, pos, quat, jkey):
    """Patch-wise collision test over the closest coarse point pairs.

    The ``n_cull`` closest coarse (level-1) point pairs are visited in order
    of increasing distance. For each pair, the neighbourhoods stored in
    ``tree_list`` are brute-forced to refine the closest points, and a
    collision flag is derived from the refined distance and from the SDF sign
    at the closest points. The loop stops early after two consecutive
    non-colliding candidates.

    Args:
        obj_idx: indices of the two objects (object A first, object B second).
        scale: per-object scale; its two trailing singleton axes are squeezed
            for the threshold computation.
        pos: per-object position, indexed ``pos[0]`` / ``pos[1]``.
        quat: per-object orientation quaternion, indexed like ``pos``.
        jkey: JAX PRNG key; a distinct sub-key is derived per candidate.

    Returns:
        A tuple of six arrays with leading length ``n_cull``:
        ``(col_res, min_dist, cpntB - cpntA, cpntA, cpntB, contact_no)``.
        Evaluated candidates come first, sorted by ``min_dist`` ascending;
        slots that were never evaluated keep their zero initialisation.
    """
    pcd12 = surface_pnts_list[obj_idx]
    pcd12_l1 = pcd12[:,:NTREE_WIDTH]
    pcd12_l1_tf = transform_pcd(pcd12_l1, scale, pos, quat)
    tree12 = tree_list[obj_idx]

    n_cull = 256
    n_l1 = pcd12_l1_tf.shape[-2]  # coarse points per object

    # Level 1: all coarse pair distances; keep the n_cull closest pairs.
    dist_spAB_flat_l1 = brute_force_closest_pair(pcd12_l1_tf[0], pcd12_l1_tf[1])[-1]
    in_contact_idx_l1 = jnp.argsort(dist_spAB_flat_l1)[:n_cull]

    # Loop-invariant contact threshold derived from the scaled point spacing.
    dist_threshold = jnp.max(mean_dist_list[obj_idx]*scale.squeeze(-1).squeeze(-1))

    def body_func(carry):
        itr, col_res_pad, non_col_cnt = carry
        # Decode the flat coarse pair index (a * n_l1 + b).
        Aidx = in_contact_idx_l1[itr] // n_l1
        Bidx = in_contact_idx_l1[itr] % n_l1

        # Level 2: brute-force inside the neighbourhoods of the coarse pair.
        pcd1_l2 = pcd12[0, tree12[0, Aidx]]
        pcd2_l2 = pcd12[1, tree12[1, Bidx]]
        pcd12_l2 = jnp.stack([pcd1_l2, pcd2_l2], axis=0)
        pcd12_l2_tf = transform_pcd(pcd12_l2, scale, pos, quat)
        n_l2 = pcd12_l2_tf.shape[-2]

        min_dist, _, _, cpntA, cpntB, dist_spAB_flat = brute_force_closest_pair(pcd12_l2_tf[0], pcd12_l2_tf[1])

        in_contact_mask = dist_spAB_flat < dist_threshold
        contact_no = jnp.sum(in_contact_mask)

        # SDF query: each closest point is expressed in the other object's
        # canonical frame and looked up in that object's SDF samples.
        qpnt_canonical1 = invtransform_pnt(cpntA, scale[1], pos[1], quat[1])  # A point -> canonical frame of B
        qpnt_canonical2 = invtransform_pnt(cpntB, scale[0], pos[0], quat[0])  # B point -> canonical frame of A
        sdf1 = query_sdf(obj_idx[1], qpnt_canonical1)  # obj B frame
        sdf2 = query_sdf(obj_idx[0], qpnt_canonical2)  # obj A frame

        col_res = min_dist < dist_threshold  # collision if true
        sdf_col = jnp.logical_or(sdf1 < 0, sdf2 < 0)
        # Fall back to the SDF sign test when the distance test says "no collision".
        col_res = jnp.where(col_res, col_res, sdf_col)

        # Randomly pick one in-contact pair. The key is folded with the
        # iteration index so each candidate gets its own permutation instead
        # of reusing the same key (and permutation) on every iteration.
        perm_key = jax.random.fold_in(jkey, itr)
        perm_idx = jax.random.permutation(perm_key, jnp.arange(in_contact_mask.shape[0]))
        in_contact_idx = jnp.where(in_contact_mask[perm_idx], size=1, fill_value=-1)[0]
        in_contact_idx = jnp.where(in_contact_idx!=-1, perm_idx[in_contact_idx], -1)
        in_contact_idx1 = in_contact_idx // n_l2
        in_contact_idx2 = in_contact_idx % n_l2
        in_contact_pnt1 = pcd12_l2_tf[0, in_contact_idx1]
        in_contact_pnt2 = pcd12_l2_tf[1, in_contact_idx2]
        in_contact_pnt1 = jnp.where(in_contact_idx[...,None]!=-1, in_contact_pnt1, 0).squeeze(0)
        in_contact_pnt2 = jnp.where(in_contact_idx[...,None]!=-1, in_contact_pnt2, 0).squeeze(0)

        # When the patch has contacts, report the random in-contact pair
        # instead of the single closest pair.
        cpntA = jnp.where(contact_no>0, in_contact_pnt1, cpntA)
        cpntB = jnp.where(contact_no>0, in_contact_pnt2, cpntB)

        col_res_pad = (col_res_pad[0].at[itr].set(col_res), col_res_pad[1].at[itr].set(min_dist),
                       col_res_pad[2].at[itr].set(cpntB - cpntA), col_res_pad[3].at[itr].set(cpntA),
                       col_res_pad[4].at[itr].set(cpntB), col_res_pad[5].at[itr].set(contact_no))

        # Count consecutive non-colliding candidates; a collision resets it.
        non_col_cnt = jnp.where(col_res, 0, non_col_cnt+1)

        return itr+1, col_res_pad, non_col_cnt

    def cond_func(carry):
        itr, _, non_col_cnt = carry
        return jnp.logical_and(itr < in_contact_idx_l1.shape[0], non_col_cnt<2)

    # jnp.bool_ is available in every JAX release; jnp.bool only in newer ones.
    col_res_pad = (jnp.zeros((n_cull,), dtype=jnp.bool_), jnp.zeros((n_cull,), dtype=jnp.float32),
                   jnp.zeros((n_cull, 3), dtype=jnp.float32), jnp.zeros((n_cull, 3), dtype=jnp.float32),
                   jnp.zeros((n_cull, 3), dtype=jnp.float32), jnp.zeros((n_cull, ), dtype=jnp.int32))

    itr, col_res_pad, _ = jax.lax.while_loop(cond_func, body_func, (0, col_res_pad, 0))

    # Sort evaluated slots by min_dist; unevaluated slots get an inf key and
    # therefore end up last.
    sorted_idx = jnp.argsort(jnp.where(jnp.arange(n_cull) < itr, col_res_pad[1], jnp.inf))
    col_res_pad = jax.tree_util.tree_map(lambda x: x[sorted_idx], col_res_pad)

    return col_res_pad

transform_pcd_jit = jax.jit(transform_pcd)

os.makedirs(DS_BASE_DIR, exist_ok=True)


def _save_chunk(sample_count):
    """Pickle the currently buffered samples, named after the running sample count."""
    with open(f'{DS_BASE_DIR}/{DATE_ID}_{START_IDX+sample_count}.pkl', 'wb') as f:
        pickle.dump((obj_AB_idx, obj_AB_scale, obj_AB_pos, obj_AB_quat, col_res_list), f)


# Buffers for the chunk currently being accumulated.
obj_AB_idx = []
obj_AB_scale = []
obj_AB_pos = []
obj_AB_quat = []
col_res_list = []
positive_cnt = 0   # accepted samples with at least one colliding candidate
total_cnt = 0      # all accepted samples
same_pair_cnt = 0  # accepted samples drawn with the current object pair
for _ in (pbar:=tqdm(range(10000000))):
    jkey, subkey = jax.random.split(jkey)
    if same_pair_cnt > NCUMPAIR:
        same_pair_cnt = 0

    # Draw a new object pair (and scales) whenever the per-pair counter is reset.
    if same_pair_cnt == 0:
        obj_idx = np.random.randint(0, len(sdf_dirs), size=(2,))
        if target_sdf_list is not None:
            obj_idx[0] = np.random.choice(target_sdf_idx)
        random_scale = np.random.uniform(0.02, 1.2, size=(2,1,1))

    rand_quat = tutil.qrand_np((2,))
    # Bounding boxes from the first 100 surface points, posed at the origin.
    pcd12 = transform_pcd_jit(surface_pnts_list[obj_idx,:100], random_scale, np.zeros((2,3)), rand_quat)
    aabb_max = pcd12.max(-2)
    aabb_min = pcd12.min(-2)
    # Safe-bound calculation: the second object's offset is sampled within
    # +/- 1.2x the mean box size, and both objects are shifted by their box centre.
    aabb_size = aabb_max - aabb_min
    aabb_center = (aabb_max + aabb_min)/2
    aabb_size_mean = np.mean(aabb_size, axis=0)*1.2
    random_pos2 = np.random.uniform(-aabb_size_mean, aabb_size_mean, size=(3,))
    rand_pos = np.stack([np.zeros_like(random_pos2), random_pos2], axis=0)
    rand_pos = rand_pos - aabb_center

    col_res = point_based_collision_patch(obj_idx, random_scale, rand_pos, rand_quat, subkey)
    # Convert once to a Python bool so the counters stay plain Python numbers.
    is_positive = bool(jnp.any(col_res[0]))

    # While positives are under-represented (< 40 %), reject negatives whose
    # closest distance is large relative to the object scale.
    if (total_cnt > NCUMPAIR//10 and not is_positive
            and positive_cnt/total_cnt < 0.40
            and col_res[1][0] > 0.2*np.max(random_scale)):
        continue

    # # visualize in open3d
    # print(col_res)
    # pcd12 = transform_pcd_jit(surface_pnts_list[obj_idx,:1000], random_scale, rand_pos, rand_quat)
    # pcd1 = o3d.geometry.PointCloud()
    # pcd1.points = o3d.utility.Vector3dVector(pcd12[0])
    # pcd2 = o3d.geometry.PointCloud()
    # pcd2.points = o3d.utility.Vector3dVector(pcd12[1])
    # # colorize
    # pcd1.paint_uniform_color([1, 0.706, 0])
    # pcd2.paint_uniform_color([0, 0.651, 0.929])

    # pnts_o3d_list = []
    # for cp in range(col_res[0].shape[0]):
    #     if col_res[0][cp]:
    #         closest_points1 = o3d.geometry.TriangleMesh.create_sphere(radius=0.01)
    #         closest_points1.compute_vertex_normals()
    #         closest_points1.translate(col_res[3][cp])
    #         closest_points1.paint_uniform_color([1, 0.706, 0])
    #         closest_points2 = o3d.geometry.TriangleMesh.create_sphere(radius=0.01)
    #         closest_points2.compute_vertex_normals()
    #         closest_points2.translate(col_res[4][cp])
    #         closest_points2.paint_uniform_color([0, 0.651, 0.929])
    #         pnts_o3d_list.append(closest_points1)
    #         pnts_o3d_list.append(closest_points2)

    # o3d.visualization.draw_geometries([pcd1, pcd2, *pnts_o3d_list])
    # # visualize in open3d

    same_pair_cnt += 1
    positive_cnt += is_positive
    total_cnt += 1
    obj_AB_idx.append(np.array(obj_idx))
    obj_AB_scale.append(np.array(random_scale))
    obj_AB_pos.append(np.array(rand_pos))
    obj_AB_quat.append(np.array(rand_quat))
    col_res_save = jax.tree_util.tree_map(lambda x: np.array(x), col_res)
    col_res_list.append(col_res_save)
    # Write a chunk every 2000 accepted samples and start a fresh buffer.
    if total_cnt%2000 == 0:
        _save_chunk(total_cnt)
        obj_AB_idx = []
        obj_AB_scale = []
        obj_AB_pos = []
        obj_AB_quat = []
        col_res_list = []

    pbar.set_description(f"PR: {positive_cnt/total_cnt:.3f}")

# Flush the samples accumulated since the last full chunk; without this the
# final partial chunk would be dropped when the loop ends.
if col_res_list:
    _save_chunk(total_cnt)


    

