import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import os, sys
from functools import partial

BASEDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)
import util.transform_util as tutil


def min_max_normalize(points):
    """
    Rescale each coordinate axis of 'points' into [0, 1].

    The minimum and maximum are taken over the point axis (axis -2), so each
    of the d coordinates is scaled independently and any leading batch axes
    are normalized separately. A 1e-9 epsilon keeps a constant axis from
    dividing by zero (such an axis maps to 0).

    points: shape (..., M, d)
    returns: same shape as 'points'
    """
    lower = jnp.min(points, axis=-2, keepdims=True)
    upper = jnp.max(points, axis=-2, keepdims=True)
    extent = upper - lower
    return (points - lower) / (extent + 1e-9)

def morton_code_1D(x, bits=16):
    """
    Quantize coordinates in [0, 1] onto the integer grid [0, 2**bits - 1].

    Adds 0.5 before truncating to int (round-half-up for non-negative
    values), then clips anything outside the grid. Every operation is
    elementwise, so 'x' may be an array of any shape (it must support
    .astype). Example: bits=16 maps [0, 1] onto [0, 65535].
    """
    top = 2**bits - 1
    scaled = x * top + 0.5
    return jnp.clip(scaled.astype(int), 0, top)

def z_order(points, bits=11):
    """
    Map points in [0, 1]^d to uint32 Z-order (Morton) codes by bit interleaving.

    Each coordinate is quantized to 'bits' bits with morton_code_1D. Bit b of
    dimension j is then written to bit position b * d + j of the code.

    The code is a uint32, so only bit positions < 32 are representable. When
    bits * d > 32 (e.g. the default bits=11 with d=3 needs 33 bits), the
    high-order bits that do not fit are dropped. Distinct cells can then share
    a code; choose bits <= 32 // d to avoid that.

    points: shape (..., d), values expected in [0, 1] (out-of-range values are
            clipped by the quantizer)
    returns: shape (...,) uint32 codes
    """
    outer_shape = points.shape[:-1]
    d = points.shape[-1]
    flat = jnp.reshape(points, (-1, d))
    # morton_code_1D is elementwise, so quantize the whole array at once.
    # Work in uint32 so a shift into bit 31 never goes through a sign bit.
    coords = morton_code_1D(flat, bits).astype(jnp.uint32)  # (M, d)

    codes = jnp.zeros(flat.shape[0], dtype=jnp.uint32)
    for bit_i in range(bits):
        for dim_i in range(d):
            pos = bit_i * d + dim_i
            if pos >= 32:
                continue  # does not fit in a uint32 code; dropped (see docstring)
            bit_val = (coords[:, dim_i] >> bit_i) & 1
            codes = codes | (bit_val << pos)
    return jnp.reshape(codes, outer_shape)


# -----------------------------------------------
# 3. Approximate search using Z-order "distance"
# -----------------------------------------------
def nearest_search_in_1D(X_zcodes_sorted, query_zcodes, sorted_indices=None, A_points_sorted=None, B_points_sorted=None):
    """
    For each query code, find the nearest code in the sorted array X_zcodes_sorted.

    Let i = searchsorted(X, q) (side='left'). The nearest 1-D neighbour of q is
    at one of the two bracketing positions:
      - i - 1, the largest code < q;
      - i, the smallest code >= q.
    Both are clipped into range and the closer one is kept. Ties go to the
    lower position.

    Args:
      X_zcodes_sorted: shape (N,), ascending codes to search in (non-empty).
      query_zcodes: shape (Q,), codes to look up.
      sorted_indices: optional shape (Q,) permutation with query_zcodes[p]
        belonging to original query sorted_indices[p] (the argsort that
        produced query_zcodes). When given, both outputs are returned in
        original query order instead of query_zcodes order.
      A_points_sorted: optional shape (N, d) points aligned with X_zcodes_sorted.
      B_points_sorted: optional shape (Q, d) points aligned with query_zcodes.
        When both point arrays are given, 'dist' is the Euclidean distance
        between each query point and its matched X point, not the code gap.

    Returns:
      nearest_indices: shape (Q,), positions into X_zcodes_sorted (not indices
        into the unsorted X).
      dist: shape (Q,), absolute code difference, or the Euclidean distance
        when point arrays are given.
    """
    n = X_zcodes_sorted.shape[0]
    insertion_idx = jnp.searchsorted(X_zcodes_sorted, query_zcodes)

    # Bracketing candidates. Clipping makes queries before the first / after
    # the last code compare against the boundary element.
    idx_below = jnp.clip(insertion_idx - 1, 0, n - 1)
    idx_above = jnp.clip(insertion_idx, 0, n - 1)

    def _abs_diff(a, b):
        # max - min instead of abs(a - b): Z-order codes are unsigned, and a
        # plain subtraction wraps around when a < b.
        return jnp.maximum(a, b) - jnp.minimum(a, b)

    dist_below = _abs_diff(X_zcodes_sorted[idx_below], query_zcodes)
    dist_above = _abs_diff(X_zcodes_sorted[idx_above], query_zcodes)

    below_closer = dist_below <= dist_above
    nearest_indices = jnp.where(below_closer, idx_below, idx_above)
    dist = jnp.where(below_closer, dist_below, dist_above)

    if A_points_sorted is not None and B_points_sorted is not None:
        dist = jnp.linalg.norm(A_points_sorted[nearest_indices] - B_points_sorted, axis=-1)

    if sorted_indices is not None:
        # Invert the permutation: inverse[sorted_indices[p]] = p, so indexing
        # with 'inverse' puts each result back in its original query slot.
        inverse = jnp.zeros_like(sorted_indices).at[sorted_indices].set(
            jnp.arange(sorted_indices.shape[0])
        )
        return nearest_indices[inverse], dist[inverse]
    return nearest_indices, dist



def approximate_neighbors(key, A, B, m=4, k=50, bits=8):
    """
    Return indices of ~k points in A that are closest to B, and ~k points in B
    that are closest to A, using Z-order codes as an approximation.

    Procedure:
      1. Rotate both sets by m random rotations.
      2. Min-max normalize both sets jointly per rotation.
      3. Encode the normalized points with z_order and sort them.
      4. Match every point to its nearest neighbour in the other set along
         that 1-D order, and measure the Euclidean distance to the match in
         the original (unrotated) coordinates.
      5. Take the minimum of that distance over the m rotations and keep the
         k points with the smallest value.

    Args:
      key: JAX PRNG key used for the random rotations.
      A: shape (NA, d) points.
      B: shape (NB, d) points. NA and NB may differ. d is presumably 3, since
        the rotations come from quaternions via tutil.q2R (TODO confirm).
      m: number of random rotations.
      k: number of indices returned per set; should be <= NA and <= NB.
      bits: quantization bits per coordinate for the Z-order code.

    Returns:
      (A_chosen_indices, B_chosen_indices): each shape (k,), indices into A / B
      ordered by increasing approximate distance to the other set.
    """
    # 1) Random rotations (presumably m random unit quaternions -> (m, 3, 3)).
    key, _ = jax.random.split(key)
    R = tutil.q2R(tutil.qrand((m,), key))
    n_A = A.shape[-2]

    # 2) Rotate both sets by every rotation.
    A_rot = jnp.einsum('mij,nj->mni', R, A)  # (m, NA, d)
    B_rot = jnp.einsum('mij,nj->mni', R, B)  # (m, NB, d)

    # 3) Normalize both sets together so they share one bounding box per rotation.
    AB_norm = min_max_normalize(jnp.concatenate([A_rot, B_rot], axis=-2))  # (m, NA+NB, d)
    # Split at NA (not into two equal halves) so NA != NB is handled correctly.
    A_norm = AB_norm[..., :n_A, :]
    B_norm = AB_norm[..., n_A:, :]

    # 4) Convert to Z-order codes.
    A_z = z_order(A_norm, bits)  # (m, NA)
    B_z = z_order(B_norm, bits)  # (m, NB)

    # 5) Per rotation, sort each set by its code and gather the unrotated
    #    points in that order.
    A_sorted_indices = jnp.argsort(A_z, axis=-1)  # (m, NA)
    B_sorted_indices = jnp.argsort(B_z, axis=-1)  # (m, NB)
    A_sorted = A[A_sorted_indices]                # (m, NA, d)
    B_sorted = B[B_sorted_indices]                # (m, NB, d)
    A_z_sorted = jnp.take_along_axis(A_z, A_sorted_indices, axis=-1)
    B_z_sorted = jnp.take_along_axis(B_z, B_sorted_indices, axis=-1)

    def _top_k_closest(X_z_sorted, Q_z_sorted, Q_sorted_indices, X_sorted, Q_sorted):
        # Distance from each query point (original order) to its Z-order match
        # in X, best over rotations; return the k query indices closest to X.
        _, dist = jax.vmap(nearest_search_in_1D)(
            X_z_sorted, Q_z_sorted, Q_sorted_indices, X_sorted, Q_sorted
        )  # (m, NQ)
        best = jnp.min(dist, axis=0)  # (NQ,)
        return jnp.argsort(best)[:k]

    # 6) B points closest to A, and 7) A points closest to B.
    B_chosen_indices = _top_k_closest(A_z_sorted, B_z_sorted, B_sorted_indices, A_sorted, B_sorted)
    A_chosen_indices = _top_k_closest(B_z_sorted, A_z_sorted, A_sorted_indices, B_sorted, A_sorted)

    return A_chosen_indices, B_chosen_indices

def approximate_neighbors_batch(key, A, B, m=4, k=50, bits=8):
    """
    Batched version of approximate_neighbors.

    A: shape (..., NA, d); B: shape (..., NB, d). Leading (batch) dims are
    broadcast against each other, and every batch element gets its own PRNG
    subkey split from 'key'.
    m, k, bits: forwarded to approximate_neighbors (bits defaults to the same
    value approximate_neighbors uses).

    Returns (A_chosen_indices, B_chosen_indices), each of shape (*batch, k).
    """
    # 1) Broadcast the batch dims, then flatten them into one axis for vmap.
    outer_shape = jnp.broadcast_shapes(A.shape[:-2], B.shape[:-2])
    A = jnp.broadcast_to(A, outer_shape + A.shape[-2:])
    B = jnp.broadcast_to(B, outer_shape + B.shape[-2:])
    A_flat = A.reshape(-1, A.shape[-2], A.shape[-1])
    B_flat = B.reshape(-1, B.shape[-2], B.shape[-1])

    # 2) One independent key per batch element.
    keys = jax.random.split(key, A_flat.shape[0])
    select = partial(approximate_neighbors, m=m, k=k, bits=bits)
    A_chosen_indices, B_chosen_indices = jax.vmap(select)(keys, A_flat, B_flat)

    # 3) Recover the original batch shape.
    A_chosen_indices = A_chosen_indices.reshape(outer_shape + (k,))
    B_chosen_indices = B_chosen_indices.reshape(outer_shape + (k,))
    return A_chosen_indices, B_chosen_indices


def brute_force_neighbors(A, B, k=50):
    """
    Exact selection of the k points of A nearest to B, and vice versa.

    Builds the full pairwise squared-distance matrix. For each point it takes
    the squared distance to its closest point in the other set, then keeps the
    k smallest via jax.lax.top_k on the negated values (nearest first).

    A: shape (..., NA, d); B: shape (..., NB, d), leading dims broadcastable.
    Returns (Aidx, Bidx) with shapes (..., k).
    Memory: materializes an (..., NA, NB, d) difference tensor.
    """
    diff = A[..., :, None, :] - B[..., None, :, :]  # (..., NA, NB, d)
    sq_dist = jnp.sum(diff ** 2, axis=-1)           # (..., NA, NB)
    closest_sq_from_A = jnp.min(sq_dist, axis=-1)   # each A point -> nearest B
    closest_sq_from_B = jnp.min(sq_dist, axis=-2)   # each B point -> nearest A
    Aidx = jax.lax.top_k(-closest_sq_from_A, k)[1]
    Bidx = jax.lax.top_k(-closest_sq_from_B, k)[1]
    return Aidx, Bidx

def kdtree_nn(A, B, k=50):
    """
    Select k points of A near B and k points of B near A using util.kdtree.

    Distances come from util.kdtree.batch_nearest_neighbor. The k points with
    the smallest distance are kept via jax.lax.top_k on the negated values.

    A: shape (..., NA, d); B: shape (..., NB, d). Leading dims are broadcast
    against each other before flattening, so A and B may differ in batch dims.
    Returns (Aidx, Bidx) of shapes (*batch, k).
    """
    import util.kdtree as kdtree
    batch_shape = jnp.broadcast_shapes(A.shape[:-2], B.shape[:-2])
    # Broadcast before flattening so both inputs yield the same number of
    # batch elements for vmap, even when only one of them has batch dims.
    A_flat = jnp.broadcast_to(A, batch_shape + A.shape[-2:]).reshape(-1, *A.shape[-2:])
    B_flat = jnp.broadcast_to(B, batch_shape + B.shape[-2:]).reshape(-1, *B.shape[-2:])
    # NOTE(review): the pairing below presumes that
    # batch_nearest_neighbor(X, Q, 100) returns (indices, distances) for each
    # query point in Q. The meaning of the constant 100 is defined in
    # util.kdtree; confirm there.
    _, min_dist_B = jax.vmap(kdtree.batch_nearest_neighbor, (0, 0, None))(A_flat, B_flat, 100)
    _, min_dist_A = jax.vmap(kdtree.batch_nearest_neighbor, (0, 0, None))(B_flat, A_flat, 100)
    _, Aidx = jax.lax.top_k(-min_dist_A, k)
    _, Bidx = jax.lax.top_k(-min_dist_B, k)
    Aidx = Aidx.reshape(*batch_shape, -1)
    Bidx = Bidx.reshape(*batch_shape, -1)
    return Aidx, Bidx



if __name__ == "__main__":
    import time

    def _benchmark(fn, *args, iters=1000):
        """Call fn(*args) once to warm up / compile, then time 'iters' more calls.

        Returns (result of the last call, elapsed wall-clock seconds of the
        timed calls).
        """
        res = jax.block_until_ready(fn(*args))  # warm-up: excludes JIT compile time
        start = time.perf_counter()
        for _ in range(iters):
            res = jax.block_until_ready(fn(*args))
        return res, time.perf_counter() - start

    # ---------------------------
    # 1. Synthetic data creation
    # ---------------------------
    N = 2000  # number of points per set
    d = 3     # dimension

    # Independent keys for A, B and the ANN rotations (JAX keys must not be reused).
    key_A, key_B, key_ann = jax.random.split(jax.random.PRNGKey(2), 3)

    # Two sets of N points on the unit sphere, each of shape (N, d).
    A = jax.random.normal(key_A, (N, d))
    A = A / jnp.linalg.norm(A, axis=1, keepdims=True)
    B = jax.random.normal(key_B, (N, d))
    B = B / jnp.linalg.norm(B, axis=1, keepdims=True)
    # Shift B by 1.2 along x: the two unit spheres intersect.
    B = B + jnp.array([1.2, 0.0, 0.0])

    # ----------------------------------------------------------
    # 2. Time the approximate, brute-force and KD-tree selectors
    # ----------------------------------------------------------
    (A_idx, B_idx), time_ann = _benchmark(jax.jit(approximate_neighbors_batch), key_ann, A, B)
    _, time_brute = _benchmark(jax.jit(brute_force_neighbors), A, B)
    _, time_kd = _benchmark(jax.jit(kdtree_nn), A, B)

    print("Time taken for ANN:", time_ann)
    print("Time taken for brute force:", time_brute)
    print("Time taken for kdtree:", time_kd)

    # ---------------------------------------
    # 3. Visualization in 3D
    # ---------------------------------------
    def visualize_3d(A, B, A_idx, B_idx, title='Approximate Neighbors'):
        """
        3D scatter plot of all points in A and B, highlighting the chosen
        indices of each set.
        """
        # Convert JAX arrays to NumPy arrays for plotting.
        A_np = np.array(A)
        B_np = np.array(B)
        # Extract the chosen subsets.
        A_sel = A_np[A_idx]
        B_sel = B_np[B_idx]

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        # All points of A (blue) and the chosen subset (red).
        ax.scatter(A_np[:, 0], A_np[:, 1], A_np[:, 2],
                   c='blue', alpha=0.4, label='A all')
        ax.scatter(A_sel[:, 0], A_sel[:, 1], A_sel[:, 2],
                   c='red', s=50, label='A selected')

        # All points of B (green) and the chosen subset (magenta).
        ax.scatter(B_np[:, 0], B_np[:, 1], B_np[:, 2],
                   c='green', alpha=0.4, label='B all')
        ax.scatter(B_sel[:, 0], B_sel[:, 1], B_sel[:, 2],
                   c='magenta', s=50, label='B selected')

        ax.set_title(title)
        ax.legend()
        plt.show()

    visualize_3d(A, B, A_idx, B_idx, title="Example: 3D Approximate Neighbors")