import jax.numpy as jnp
import pickle
import numpy as np
import os
import glob
import open3d as o3d
import jax
import matplotlib.pyplot as plt
import einops
from functools import partial
from tqdm import tqdm
from dataclasses import replace
from pathlib import Path
from scipy.spatial import cKDTree
from scipy.spatial.transform import Rotation as sciR
import datetime
import copy

# Setup import path
import sys
BASEDIR = Path(__file__).parent.parent
if str(BASEDIR) not in sys.path:
    sys.path.insert(0, str(BASEDIR))

from util.model_util import Models
import util.transform_util as tutil
import util.latent_obj_util as loutil
import util.structs as structs

def invtransform_pnt(pnt, scale, pos, quat):
    """Map posed points back into a shape's canonical (unscaled) frame.

    This is the inverse of ``transform_pcd``: first undo the pose
    (``pos``/``quat``), then divide by ``scale``. A point axis is inserted into
    ``pos``/``quat`` so a single pose broadcasts over all points.
    """
    inv_pos, inv_quat = tutil.pq_inv(pos[..., None, :], quat[..., None, :])
    unposed = tutil.pq_action(inv_pos, inv_quat, pnt)
    return unposed / scale

def transform_pcd(pcd, scale, pos, quat):
    """Place canonical points in the world: scale first, then apply the pose.

    A point axis is inserted into ``pos``/``quat`` so a single pose broadcasts
    over all points of the cloud.
    """
    scaled_pcd = pcd * scale
    return tutil.pq_action(pos[..., None, :], quat[..., None, :], scaled_pcd)

def brute_force_closest_pair(pcd1, pcd2):
    """Find the closest point pair between two point clouds by exhaustive search.

    Intended for 2-D inputs ``pcd1`` (N, D) and ``pcd2`` (M, D).

    Returns:
        (min_dist, idx_a, idx_b, pcd1[idx_a], pcd2[idx_b], flat_dist), where
        ``flat_dist`` is the row-major flattened (N, M) distance matrix, so
        ``flat_dist[i * M + j]`` is the distance between ``pcd1[i]`` and ``pcd2[j]``.
    """
    offsets = pcd1[..., :, None, :] - pcd2[..., None, :, :]
    pair_dist = jnp.linalg.norm(offsets, axis=-1)
    n_cols = pair_dist.shape[-1]
    flat_dist = pair_dist.reshape(-1)
    best = jnp.argmin(flat_dist)
    # Row/column of the flat index in the (N, M) matrix.
    idx_a, idx_b = jnp.divmod(best, n_cols)
    return flat_dist[best], idx_a, idx_b, pcd1[idx_a], pcd2[idx_b], flat_dist

class ColDataGen(object):
    """Generate pairwise collision labels for randomly scaled/posed shapes.

    Call ``init_data`` (or ``gen_col_data_with_mesh``) before ``generate_data``:
    it loads per-shape SDF samples, builds the neighbour tables used by the
    two-level point-based collision check, and JIT-compiles that check.
    """

    def __init__(self, ntree_width=1000, ncumpair=8, surface_pnts_num=128000, query_num=80000):
        """
        Args:
            ntree_width: number of surface points per shape used as level-1
                (coarse) samples; each owns a neighbourhood of level-2 points.
            ncumpair: a sampled object pair/scale is reused until more than
                ``ncumpair`` samples have been accepted for it.
            surface_pnts_num: maximum number of surface points kept per shape.
            query_num: maximum number of SDF query points kept per shape.
        """
        self.ntree_width = ntree_width
        self.ncumpair = ncumpair

        self.surface_pnts_num = surface_pnts_num
        self.query_num = query_num

        # JIT-compiled scale-then-pose transform for point clouds.
        self.transform_pcd_jit = jax.jit(transform_pcd)

    def build_data(self, sdf_dir, gen_data_func=None):
        """Load (or generate) SDF samples for one shape and build its neighbour table.

        Args:
            sdf_dir: path to a pickled dict with keys ``'surface_points'``,
                ``'query_points'`` and ``'signed_distance'``.
            gen_data_func: fallback called as ``gen_data_func(sdf_dir)`` when the
                pickle cannot be loaded; must return a dict with the same keys.

        Returns:
            (surface_pnts, query_pnts, signed_distance, tree_idx, mean_dist):
            ``surface_pnts`` is a random subset of at most ``surface_pnts_num``
            surface points; ``tree_idx[i]`` holds indices (into ``surface_pnts``)
            of the nearest neighbours of level-1 sample ``i`` with the first
            (closest) hit dropped; ``mean_dist`` is the mean distance from the
            level-1 samples to their next 4 neighbours.

        Raises:
            The original loading error if the file cannot be unpickled and no
            ``gen_data_func`` was given.
        """
        try:
            # NOTE: pickle can execute arbitrary code; only load trusted asset files.
            with open(sdf_dir, 'rb') as f:
                sdf_data = pickle.load(f)
        except Exception:
            # Missing file, or not a pickle at all (e.g. a raw mesh path): generate it.
            if gen_data_func is None:
                raise
            sdf_data = gen_data_func(sdf_dir)
        surface_pnts = np.random.permutation(sdf_data['surface_points'])[:self.surface_pnts_num]
        query_pnts = sdf_data['query_points'][:self.query_num]
        signed_distance = sdf_data['signed_distance'][:self.query_num]

        def build_tree(pcd):
            # pcd was randomly permuted, so its first ntree_width points are a
            # random set of level-1 samples.
            npcd = pcd.shape[0]
            tree = cKDTree(pcd)
            pcd_sample = pcd[:self.ntree_width]
            # Query 4x the average number of points per level-1 sample so the
            # neighbourhoods overlap. Clamp to >= 2: k == 0 is invalid and k == 1
            # makes cKDTree drop the neighbour axis, breaking the slicing below.
            k = max(int(4.0 * (npcd // self.ntree_width)), 2)
            dist_to_sample, idx_to_sample = tree.query(pcd_sample, k=k)
            # Column 0 is the closest hit (normally the sample itself), so skip it.
            mean_dist = np.mean(dist_to_sample[..., 1:5])
            return idx_to_sample[..., 1:], mean_dist

        tree_idx, mean_dist = build_tree(surface_pnts)

        return surface_pnts, query_pnts, signed_distance, tree_idx, mean_dist

    def init_data(self, sdf_dirs, gen_data_func=None):
        """Load every shape in ``sdf_dirs`` and stack the per-shape arrays on device.

        All shapes must yield arrays of identical shape so they can be stacked.
        Also JIT-compiles ``point_based_collision_patch``.
        """
        surface_pnts_list = []
        query_pnts_list = []
        signed_distance_list = []
        tree_list = []
        mean_dist_list = []
        for sdf_dir in tqdm(sdf_dirs):
            surface_pnts, query_pnts, signed_distance, tree_idx, mean_dist = self.build_data(sdf_dir, gen_data_func)
            surface_pnts_list.append(surface_pnts)
            query_pnts_list.append(query_pnts)
            signed_distance_list.append(signed_distance)
            tree_list.append(tree_idx)
            mean_dist_list.append(mean_dist)

        self.surface_pnts_list = jnp.array(surface_pnts_list)
        self.query_pnts_list = jnp.array(query_pnts_list)
        self.signed_distance_list = jnp.array(signed_distance_list)
        self.tree_list = jnp.array(tree_list)
        self.mean_dist_list = jnp.array(mean_dist_list)

        self.point_based_collision_jit = jax.jit(self.point_based_collision_patch)

    def query_sdf(self, obj_idx, qpnt_canonical):
        """Return the stored signed distance of the query point nearest to ``qpnt_canonical``.

        This is a nearest-sample lookup in the shape's canonical frame (no
        interpolation).
        """
        sdf = self.signed_distance_list[obj_idx]
        query_pnts = self.query_pnts_list[obj_idx]
        dist_to_queries = jnp.linalg.norm(query_pnts - qpnt_canonical, axis=-1)
        return sdf[jnp.argmin(dist_to_queries)].squeeze(-1)

    def point_based_collision_patch(self, obj_idx, scale, pos, quat, jkey):
        """Two-level point-cloud collision check between posed shapes A and B.

        Level 1 brute-forces the closest pair among the ``ntree_width`` sample
        points of each shape. Level 2 repeats the search inside the precomputed
        neighbourhoods of that pair. Collision is declared if the level-2
        distance is below a scale-aware threshold, or if either closest point
        has a negative SDF value in the other shape.

        Args:
            obj_idx: indices of shapes A and B, shape (2,).
            scale: per-shape uniform scale, shape (2, 1, 1) as built by ``generate_data``.
            pos: per-shape position, shape (2, 3).
            quat: per-shape orientation quaternion (format defined by ``tutil``).
            jkey: PRNG key used to pick a random in-contact point pair.

        Returns:
            Tuple of length-``n_cull`` arrays sorted by distance:
            (collision flag, level-2 min distance, cpntB - cpntA, cpntA, cpntB,
            number of level-2 pairs closer than the threshold).
        """
        pcd12 = self.surface_pnts_list[obj_idx]
        # Level 1: the first ntree_width points are the coarse samples.
        pcd12_l1 = pcd12[:, :self.ntree_width]
        pcd12_l1_tf = transform_pcd(pcd12_l1, scale, pos, quat)
        tree12 = self.tree_list[obj_idx]
        n_l1 = pcd12_l1_tf.shape[-2]

        n_cull = 1
        dist_spAB_flat_l1 = brute_force_closest_pair(pcd12_l1_tf[0], pcd12_l1_tf[1])[-1]
        # Flat indices (into the n_l1 x n_l1 distance matrix) of the n_cull closest level-1 pairs.
        in_contact_idx_l1 = jnp.argsort(dist_spAB_flat_l1)[:n_cull]

        def body_func(carry):
            itr, col_res_pad, non_col_cnt = carry
            # Recover the level-1 pair: row = sample of A, column = sample of B.
            Aidx_ = in_contact_idx_l1[itr] // n_l1
            Bidx_ = in_contact_idx_l1[itr] % n_l1

            # Level 2: gather and pose each sample's precomputed neighbourhood.
            pcd1_l2 = pcd12[0, tree12[0, Aidx_]]
            pcd2_l2 = pcd12[1, tree12[1, Bidx_]]
            pcd12_l2_tf = transform_pcd(jnp.stack([pcd1_l2, pcd2_l2], axis=0), scale, pos, quat)
            n_l2 = pcd12_l2_tf.shape[-2]

            min_dist_l2, _, _, cpntA, cpntB, dist_spAB_flat = brute_force_closest_pair(
                pcd12_l2_tf[0], pcd12_l2_tf[1]
            )

            # Threshold: the larger of the two shapes' scaled mean point spacing.
            dist_threshold = jnp.max(self.mean_dist_list[obj_idx] * scale.squeeze(-1).squeeze(-1))
            in_contact_mask = dist_spAB_flat < dist_threshold
            contact_no = jnp.sum(in_contact_mask)

            # SDF check: express each closest point in the OTHER shape's canonical frame.
            qpnt_canonical1 = invtransform_pnt(cpntA, scale[1], pos[1], quat[1])
            qpnt_canonical2 = invtransform_pnt(cpntB, scale[0], pos[0], quat[0])
            sdf1 = self.query_sdf(obj_idx[1], qpnt_canonical1)
            sdf2 = self.query_sdf(obj_idx[0], qpnt_canonical2)

            sdf_col = jnp.logical_or(sdf1 < 0, sdf2 < 0)
            col_res = jnp.logical_or(min_dist_l2 < dist_threshold, sdf_col)

            # Pick one random in-contact level-2 pair (index -1 when there is none).
            # NOTE: the same jkey is reused on every iteration; harmless while n_cull == 1.
            perm_idx = jax.random.permutation(jkey, jnp.arange(in_contact_mask.shape[0]))
            in_contact_idx = jnp.where(in_contact_mask[perm_idx], size=1, fill_value=-1)[0]
            in_contact_idx = jnp.where(in_contact_idx != -1, perm_idx[in_contact_idx], -1)
            in_contact_idx1 = in_contact_idx // n_l2
            in_contact_idx2 = in_contact_idx % n_l2
            found = in_contact_idx[..., None] != -1
            in_contact_pnt1 = jnp.where(found, pcd12_l2_tf[0, in_contact_idx1], 0).squeeze(0)
            in_contact_pnt2 = jnp.where(found, pcd12_l2_tf[1, in_contact_idx2], 0).squeeze(0)

            # When contacts exist, report the random in-contact pair instead of the strict closest pair.
            has_contact = contact_no > 0
            cpntA = jnp.where(has_contact, in_contact_pnt1, cpntA)
            cpntB = jnp.where(has_contact, in_contact_pnt2, cpntB)

            col_res_pad = (
                col_res_pad[0].at[itr].set(col_res),
                col_res_pad[1].at[itr].set(min_dist_l2),
                col_res_pad[2].at[itr].set(cpntB - cpntA),
                col_res_pad[3].at[itr].set(cpntA),
                col_res_pad[4].at[itr].set(cpntB),
                col_res_pad[5].at[itr].set(contact_no),
            )

            # Count consecutive non-colliding candidates for early termination.
            non_col_cnt = jnp.where(col_res, 0, non_col_cnt + 1)
            return itr + 1, col_res_pad, non_col_cnt

        def cond_func(carry):
            itr, _, non_col_cnt = carry
            # Stop after all candidates, or after two consecutive misses.
            return jnp.logical_and(itr < in_contact_idx_l1.shape[0], non_col_cnt < 2)

        col_res_pad = (
            jnp.zeros((n_cull,), dtype=jnp.bool_),
            jnp.zeros((n_cull,), dtype=jnp.float32),
            jnp.zeros((n_cull, 3), dtype=jnp.float32),
            jnp.zeros((n_cull, 3), dtype=jnp.float32),
            jnp.zeros((n_cull, 3), dtype=jnp.float32),
            jnp.zeros((n_cull,), dtype=jnp.int32),
        )
        itr, col_res_pad, _ = jax.lax.while_loop(cond_func, body_func, (0, col_res_pad, 0))
        # Sort filled slots by distance; unfilled slots (index >= itr) go last.
        sorted_idx = jnp.argsort(jnp.where(jnp.arange(n_cull) < itr, col_res_pad[1], jnp.inf))
        return jax.tree_util.tree_map(lambda x: x[sorted_idx], col_res_pad)

    def generate_data(self, jkey, ndata=10000, target_indices=None, visualize=False)->structs.ColDataset:
        """Sample random shape pairs/poses and label them with ``point_based_collision_jit``.

        Requires ``init_data`` to have been called.

        Args:
            jkey: JAX PRNG key, split once per attempt for the collision check.
                Object indices, scales and positions come from NumPy's global RNG.
            ndata: number of accepted samples to return (must be positive).
            target_indices: if given, object B is always drawn from this list.
            visualize: show each accepted sample in a blocking Open3D window.

        Returns:
            A ``structs.ColDataset`` holding ``ndata`` samples.

        Raises:
            ValueError: if ``ndata`` is not positive.
            RuntimeError: if ``ndata`` samples are not accepted within the attempt budget.
        """
        if ndata <= 0:
            raise ValueError(f"ndata must be positive, got {ndata}")

        obj_AB_idx = []
        obj_AB_scale = []
        obj_AB_pos = []
        obj_AB_quat = []
        col_res_list = []

        max_attempts = 10000000
        pbar = tqdm(range(max_attempts), desc="Generating data")
        same_pair_cnt = 0
        positive_cnt = 0
        total_cnt = 0
        for _ in pbar:
            jkey, subkey = jax.random.split(jkey)
            if same_pair_cnt > self.ncumpair:
                same_pair_cnt = 0

            # Select a new object pair and scale when the current one is used up.
            if same_pair_cnt == 0:
                obj_idx = np.random.randint(0, self.nshape, size=(2,))
                if target_indices is not None:
                    obj_idx[1] = np.random.choice(target_indices)

                random_scale = np.random.uniform(0.02, 1.2, size=(2, 1, 1))

            # Random orientations; a 100-point subset estimates each rotated AABB.
            rand_quat = tutil.qrand_np((2,))
            pcd12 = self.transform_pcd_jit(self.surface_pnts_list[obj_idx, :100],
                                           random_scale,
                                           np.zeros((2, 3)),
                                           rand_quat)
            aabb_max = pcd12.max(axis=-2)
            aabb_min = pcd12.min(axis=-2)
            aabb_size = aabb_max - aabb_min
            aabb_center = (aabb_max + aabb_min) / 2
            # Offset B uniformly within +-1.2x the mean AABB extent, then shift both
            # shapes so A's AABB is centred at the origin and B's at the offset.
            aabb_size_mean = np.mean(aabb_size, axis=0) * 1.2
            random_pos2 = np.random.uniform(-aabb_size_mean, aabb_size_mean, size=(3,))
            rand_pos = np.stack([np.zeros_like(random_pos2), random_pos2], axis=0)
            rand_pos = rand_pos - aabb_center

            col_res = self.point_based_collision_jit(obj_idx, random_scale, rand_pos, rand_quat, subkey)
            is_positive = bool(jnp.any(col_res[0]))

            # Rejection filter: while positives are below 40%, drop non-colliding
            # samples whose distance exceeds 0.2 * the larger scale.
            if (total_cnt > self.ncumpair // 10 and not is_positive
                    and positive_cnt / total_cnt < 0.40
                    and col_res[1][0] > 0.2 * np.max(random_scale)):
                continue

            if visualize:
                self._visualize_sample(obj_idx, random_scale, rand_pos, rand_quat, col_res)

            same_pair_cnt += 1
            if is_positive:
                positive_cnt += 1
            total_cnt += 1

            # np.array copies, so stored samples are independent of later updates.
            obj_AB_idx.append(np.array(obj_idx))
            obj_AB_scale.append(np.array(random_scale))
            obj_AB_pos.append(np.array(rand_pos))
            obj_AB_quat.append(np.array(rand_quat))
            col_res_list.append(jax.tree_util.tree_map(np.array, col_res))

            # Done: stack the accepted samples into a dataset.
            if total_cnt == ndata:
                col_res_stacked = jax.tree_util.tree_map(lambda *x: np.stack(x, axis=0), *col_res_list)
                return structs.ColDataset(
                    obj_idx=np.stack(obj_AB_idx, axis=0),
                    obj_scale=np.stack(obj_AB_scale, axis=0),
                    obj_pos=np.stack(obj_AB_pos, axis=0),
                    obj_quat=np.stack(obj_AB_quat, axis=0),
                    col_gt=col_res_stacked[0],
                    distance=col_res_stacked[1],
                    min_direction=col_res_stacked[2]
                )

            pbar.set_description(f"PR: {positive_cnt / total_cnt:.3f}, Total: {total_cnt}")

        raise RuntimeError(
            f"Only {total_cnt} of {ndata} samples accepted after {max_attempts} attempts."
        )

    def _visualize_sample(self, obj_idx, random_scale, rand_pos, rand_quat, col_res):
        """Print one labelled sample and show it in a blocking Open3D window.

        Draws 1000 surface points of A (colour depends on the first result
        slot's collision flag) and of B, the mesh at
        ``self.sdf_dirs_for_col_dataset[-1]`` scaled/posed as object B, and small
        spheres at the contact points of every colliding result slot.

        NOTE(review): ``sdf_dirs_for_col_dataset`` is only set by
        ``gen_col_data_with_mesh``, and that mesh is B's geometry only when B is
        the last shape (as arranged there).
        """
        print(col_res)
        print(random_scale)
        pcd12 = self.transform_pcd_jit(self.surface_pnts_list[obj_idx, :1000], random_scale, rand_pos, rand_quat)
        pcd1 = o3d.geometry.PointCloud()
        pcd1.points = o3d.utility.Vector3dVector(pcd12[0])
        pcd2 = o3d.geometry.PointCloud()
        pcd2.points = o3d.utility.Vector3dVector(pcd12[1])
        pcd1.paint_uniform_color([1, 0.706, 0] if col_res[0][0] else [1, 0, 0])
        pcd2.paint_uniform_color([0, 0.651, 0.929])

        objB_o3d = o3d.io.read_triangle_mesh(self.sdf_dirs_for_col_dataset[-1])
        objB_o3d.compute_vertex_normals()
        objB_o3d.scale(random_scale[1, 0, 0], np.zeros(3))
        objB_o3d.transform(tutil.pq2H(rand_pos[1], rand_quat[1]))

        # One sphere at cpntA and one at cpntB for every colliding slot.
        markers = []
        for cp in range(col_res[0].shape[0]):
            if not col_res[0][cp]:
                continue
            for pnt, color in ((col_res[3][cp], [1, 0.706, 0]), (col_res[4][cp], [0, 0.651, 0.929])):
                sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.01)
                sphere.compute_vertex_normals()
                sphere.translate(pnt)
                sphere.paint_uniform_color(color)
                markers.append(sphere)

        o3d.visualization.draw_geometries([objB_o3d, pcd1, pcd2, *markers])

    def gen_col_data_with_mesh(self, models: Models, mesh_path, create_obj_dataset_func, ndata=1000, visualize=False):
        """Build a collision dataset between selected assets and one extra mesh.

        Collects the SDF paths of assets whose relative path contains "ur5" or
        "RobotBimanualV4" (each asset at most once), appends ``mesh_path`` as the
        last shape, loads everything with ``init_data`` (``create_obj_dataset_func``
        generates data for paths that cannot be unpickled), and samples pairs
        whose object B is always the mesh.

        Returns:
            (dataset, sdf_dirs_for_col_dataset)
        """
        source_ds_names = ["ur5", "RobotBimanualV4"]
        path_util = models.asset_path_util
        sdf_dirs_for_col_dataset = [
            path_util.sdf_path_by_idx(path_util.get_obj_id(sdfdir))
            for sdfdir in path_util.rel_sdf_paths
            if any(sdn in sdfdir for sdn in source_ds_names)
        ]
        sdf_dirs_for_col_dataset.append(mesh_path)

        self.sdf_dirs_for_col_dataset = sdf_dirs_for_col_dataset

        self.init_data(sdf_dirs_for_col_dataset, create_obj_dataset_func)

        dataset = self.generate_data(jax.random.PRNGKey(0), ndata=ndata,
                                     target_indices=[self.nshape - 1], visualize=visualize)
        return dataset, sdf_dirs_for_col_dataset

    @property
    def nshape(self):
        """Number of shapes loaded by ``init_data``."""
        return len(self.surface_pnts_list)

if __name__ == '__main__':
    # Example: build a collision dataset between the robot assets and a room mesh.

    from modules.encoder import create_obj_dataset
    import util.model_util as mutil

    models = mutil.Models().load_pretrained_models()

    col_data_gen = ColDataGen()

    mesh_path = 'assets/room/raw/room_no_floor_v2.obj'

    # gen_col_data_with_mesh returns (dataset, list of shape paths used).
    dataset, sdf_dirs = col_data_gen.gen_col_data_with_mesh(models, mesh_path, create_obj_dataset)

    n_samples = dataset.col_gt.shape[0]
    n_positive = int(np.any(dataset.col_gt, axis=-1).sum())
    print(f"Generated {n_samples} samples over {len(sdf_dirs)} shapes ({n_positive} in collision).")
