
# import copy
import open_clip
import pandas as pd
import open3d as o3d
import numpy as np
import json
import torch
import os
from sklearn.cluster import DBSCAN
import cv2
from tqdm import tqdm

# from transformers import AutoModelForCausalLM, AutoTokenizer
# from PIL import Image

# Module-wide torch device string: 'cuda' when CUDA is available, otherwise 'cpu'.
# Several functions in this file read `device` as a global instead of taking it as an argument.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _load_pose(path, idx):
    """Read the 4x4 transformation matrix stored in ``<path>/<idx>.txt``.

    Args:
        path: Directory that holds the per-frame pose text files.
        idx: Frame identifier; its string form is used as the file stem.

    Returns:
        The 16 values of the file as a ``(4, 4)`` NumPy array.
    """
    pose_file = os.path.join(path, f'{idx}.txt')
    return np.loadtxt(pose_file).reshape(4, 4)


def _load_depth_intrinsics(path):
    """Load the depth-camera intrinsic matrix from a text file.

    Args:
        path: Text file readable by ``np.loadtxt``.

    Returns:
        ``(intrinsic_depth, scale)`` where ``intrinsic_depth`` is the array
        read from ``path`` and ``scale`` is the constant ``1000.0`` that
        callers use as the depth divisor (presumably millimetres -> metres;
        confirm against the dataset).
    """
    depth_scale = 1000.0
    return np.loadtxt(path), depth_scale


def dbscan_3d(points, eps, min_samples):
    """Cluster ``points`` with scikit-learn's DBSCAN.

    Args:
        points: Sample array accepted by ``DBSCAN.fit``.
        eps: Neighbourhood radius forwarded to ``DBSCAN``.
        min_samples: Core-point threshold forwarded to ``DBSCAN``.

    Returns:
        The fitted ``labels_`` array, one label per sample; ``-1`` marks noise.
    """
    model = DBSCAN(eps=eps, min_samples=min_samples)
    model.fit(points)
    return model.labels_


def clip_text(text, model_name, model_path):
    """Encode ``text`` with an OpenCLIP model restored from a local state dict.

    Args:
        text: String (or list of strings) accepted by the OpenCLIP tokenizer.
        model_name: OpenCLIP architecture name; used for both the model and
            its tokenizer so the two always match.
        model_path: Path to a ``state_dict`` checkpoint loadable by ``torch.load``.

    Returns:
        The tensor produced by ``model.encode_text`` (computed without
        gradients), located on CUDA when available and on CPU otherwise.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=None)
    # map_location lets a checkpoint that was saved from a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    # Models are created in train mode; eval() disables BatchNorm updates / stochastic depth.
    model.eval()
    # Previously hard-coded to 'ViT-H-14', which is wrong for any other architecture.
    tokenizer = open_clip.get_tokenizer(model_name)
    tokens = tokenizer(text).to(device)
    with torch.no_grad():
        embeddings = model.encode_text(tokens)

    return embeddings


def points_to_indices(points, params, res):
    """Convert xyz coordinates into integer voxel indices.

    Each of the first three columns is shifted by the matching minimum in
    ``params``, divided by ``res`` and truncated toward zero. Any further
    columns of ``points`` are returned as zeros. No clamping is applied.

    Args:
        points: ``(N, >=3)`` array-like or tensor of coordinates.
        params: ``(x_min, y_min, z_min, x_max, y_max, z_max)``; only the
            minima are used.
        res: Edge length of one voxel.

    Returns:
        ``torch.long`` tensor with the same shape as ``points`` on the
        module-level ``device``.
    """
    # as_tensor avoids the extra copy (and the copy-construct UserWarning)
    # that torch.tensor() triggers when it is handed an existing tensor.
    pts = torch.as_tensor(points, device=device)
    x_min, y_min, z_min, x_max, y_max, z_max = params
    indices = torch.zeros(pts.shape, dtype=torch.long, device=pts.device)
    for axis, lower in enumerate((x_min, y_min, z_min)):
        indices[:, axis] = ((pts[:, axis] - lower) / res).long()
    return indices


def indices_to_points(indices, params, res):
    """Map voxel indices back to coordinates (inverse scaling of the grid).

    Args:
        indices: ``(N, 3)`` array/tensor of voxel indices.
        params: ``(x_min, y_min, z_min, x_max, y_max, z_max)``; only the
            minima are used.
        res: Edge length of one voxel.

    Returns:
        A new ``(N, 3)`` array/tensor equal to ``indices * res`` with the grid
        minimum added per axis; the input is not modified.
    """
    x_min, y_min, z_min, _x_max, _y_max, _z_max = params
    points = indices * res
    for axis, offset in enumerate((x_min, y_min, z_min)):
        points[:, axis] = points[:, axis] + offset
    return points


def point_cloud(depth, scale, camera_intristics, camera_pose, max_depth=3.0):
    """Back-project a depth image into 3-D points transformed by ``camera_pose``.

    Args:
        depth: ``(H, W)`` or ``(H, W, C)`` depth image; for the 3-D case only
            channel 0 is used.
        scale: Divisor applied to the raw depth values.
        camera_intristics: Matrix whose ``[0, 0]``/``[1, 1]`` entries are the
            focal lengths and ``[0, 2]``/``[1, 2]`` the principal point.
        camera_pose: ``(4, 4)`` matrix left-multiplied onto the homogeneous
            camera-frame points (presumably camera-to-world; confirm).
        max_depth: Pixels whose scaled depth is not strictly below this value
            are dropped. Defaults to the previously hard-coded ``3.0``.

    Returns:
        ``(M, 3)`` NumPy array with one row per pixel whose scaled depth lies
        in ``(0, max_depth)``, in row-major pixel order.
    """
    camera_matrix = torch.as_tensor(camera_intristics).to(device)
    camera_pose = torch.as_tensor(camera_pose).to(device)
    depth = torch.as_tensor(depth, dtype=torch.float32).to(device)

    # Multi-channel depth images: keep the first channel only.
    if depth.dim() == 3:
        depth = depth[:, :, 0]

    rows, cols = depth.shape[0], depth.shape[1]
    y, x = torch.meshgrid(
        torch.arange(rows, device=device),
        torch.arange(cols, device=device),
        indexing='ij',
    )

    depth = depth / scale
    # Neglect pixels without a measurement (depth == 0) or beyond max_depth.
    valid = (depth > 0) & (depth < max_depth)

    # Pinhole back-projection into the camera frame.
    X = (x - camera_matrix[0, 2]) * depth / camera_matrix[0, 0]
    Y = (y - camera_matrix[1, 2]) * depth / camera_matrix[1, 1]
    # Match the dtype of X/Y explicitly so stacking never depends on implicit promotion.
    Z = depth.to(X.dtype)

    flat_x = X.reshape(-1)
    points = torch.stack((flat_x, Y.reshape(-1), Z.reshape(-1), torch.ones_like(flat_x)), dim=-1)

    # Apply the pose to homogeneous column vectors, then go back to rows.
    points = torch.matmul(camera_pose, points.T).T
    points = points[valid.reshape(-1)]

    return points[:, :3].cpu().numpy()


def grid_indices(points, params, res):
    """Convert xyz coordinates into integer voxel indices.

    Each of the first three columns is shifted by the matching minimum in
    ``params``, divided by ``res`` and truncated toward zero. Any further
    columns of ``points`` are returned as zeros. No clamping is applied.

    Args:
        points: ``(N, >=3)`` array-like or tensor of coordinates.
        params: ``(x_min, y_min, z_min, x_max, y_max, z_max)``; only the
            minima are used.
        res: Edge length of one voxel.

    Returns:
        ``torch.long`` tensor with the same shape as ``points`` on the
        module-level ``device``.
    """
    # as_tensor avoids the extra copy (and the copy-construct UserWarning)
    # that torch.tensor() triggers when it is handed an existing tensor.
    coords = torch.as_tensor(points, device=device)
    x_min, y_min, z_min, x_max, y_max, z_max = params
    cells = torch.zeros(coords.shape, dtype=torch.long, device=coords.device)
    for axis, lower in enumerate((x_min, y_min, z_min)):
        cells[:, axis] = ((coords[:, axis] - lower) / res).long()
    return cells


def points_to_grid(points, resolution):
    """Snap coordinates to multiples of ``resolution``.

    Args:
        points: NumPy array of coordinates.
        resolution: Grid spacing.

    Returns:
        Array of the same shape in which every value is the truncated
        (toward zero) cell count, held as ``float32``, times ``resolution``.
    """
    cell_counts = (points / resolution).astype(int)
    return cell_counts.astype(np.float32) * resolution


def points_to_2d(points, camera_intristics, camera_pose, mask_shape):
    """Project 3-D points into an image and return the hit pixels as a mask.

    Points are multiplied by the inverse of ``camera_pose`` and projected
    with the pinhole model; pixel coordinates are truncated and clamped to
    the image bounds, so points that project outside land on the border.

    Args:
        points: ``(N, 3)`` array-like or tensor of coordinates in the frame
            that ``camera_pose`` maps into.
        camera_intristics: Matrix whose ``[0, 0]``/``[1, 1]`` entries are the
            focal lengths and ``[0, 2]``/``[1, 2]`` the principal point.
        camera_pose: ``(4, 4)`` pose matrix (same convention as ``point_cloud``).
        mask_shape: ``(H, W)`` of the output mask.

    Returns:
        Boolean NumPy array of shape ``mask_shape``, ``True`` at projected pixels.
    """
    # Use the module-level device (was hard-coded to 'cuda', which fails on CPU hosts).
    points = torch.as_tensor(points, dtype=torch.float64, device=device)
    homogeneous = torch.stack(
        (points[:, 0], points[:, 1], points[:, 2], torch.ones_like(points[:, 0])), dim=-1)

    camera_matrix = torch.as_tensor(camera_intristics).to(device)
    camera_pose = torch.as_tensor(camera_pose).to(device)
    # Row-vector form of inv(pose) @ p.
    new_points = torch.matmul(homogeneous, torch.linalg.inv(camera_pose.T))

    # project points to image plane
    x = (new_points[:, 0] * camera_matrix[0, 0] / new_points[:, 2] + camera_matrix[0, 2]).long()
    y = (new_points[:, 1] * camera_matrix[1, 1] / new_points[:, 2] + camera_matrix[1, 2]).long()

    # clamp to valid range
    H, W = mask_shape
    x = x.clamp(0, W - 1)
    y = y.clamp(0, H - 1)

    # Allocate the mask on the same device as the index tensors.
    new_mask = torch.zeros(mask_shape, dtype=torch.bool, device=device)
    new_mask[y, x] = True

    return new_mask.cpu().numpy()


def similarity_measurement(ids_to_embeddings, target_object, clip_model, clip_path):
    """Cosine similarity between every stored embedding and a text query.

    Args:
        ids_to_embeddings: Mapping ``id -> {'embedding': [...]}``. Each id is
            converted with ``int()`` and used as the position in the result,
            so ids must be valid indices into an array of ``len(mapping)``.
        target_object: Text query handed to ``clip_text``.
        clip_model: OpenCLIP model name handed to ``clip_text``.
        clip_path: Checkpoint path handed to ``clip_text``.

    Returns:
        1-D ``float64`` NumPy array of length ``len(ids_to_embeddings)``.
    """
    target_embedding = torch.as_tensor(
        clip_text(target_object, clip_model, clip_path), dtype=torch.float64).to(device)
    # The query norm is loop-invariant.
    target_norm = torch.norm(target_embedding)

    similarities = np.zeros(len(ids_to_embeddings))
    # Iterate the argument itself; the old code read a misspelled module global.
    for my_id, record in ids_to_embeddings.items():
        embedding = torch.tensor(record['embedding']).double().to(device)
        score = embedding @ target_embedding.T / (torch.norm(embedding) * target_norm)
        # .item() moves the single value to the host so it can be stored in NumPy.
        similarities[int(my_id)] = score.item()

    return similarities


def image_with_3d_overlap(the_points, path, prefix_rgb, prefix_depth, pose_path, last_idx, step, resolution, params):
    """Pick, for every object point set, the frame whose depth best covers it.

    Every 2-D entry of ``the_points`` is voxelised on the grid described by
    ``params`` and ``resolution`` (1-D entries are skipped). Every ``step``-th
    entry of the RGB directory listing is back-projected with ``point_cloud``
    and voxelised on the same grid. A frame's score for an object is the
    fraction of the object's voxels it hits minus 0.7 times the largest such
    fraction among the other objects; the best-scoring frame is returned.

    Reads the module globals ``scale`` and ``camera_intristics``, which the
    ``retrieve_*`` entry points set before calling.

    Args:
        the_points: Sequence of ``(N, 3)`` arrays, one per object.
        path: Scene directory.
        prefix_rgb: RGB sub-directory (with trailing slash), holding ``<id>.jpg``.
        prefix_depth: Depth sub-directory (with trailing slash), holding ``<id>.png``.
        pose_path: Directory with ``<id>.txt`` pose files.
        last_idx: Exclusive upper bound on positions in the RGB listing.
        step: Stride between inspected listing positions.
        resolution: Voxel edge length.
        params: ``(x_min, y_min, z_min, x_max, y_max, z_max)`` of the grid.

    Returns:
        ``(images, boxes)`` with one entry per voxelised object: the winning
        frame as read by ``cv2.imread`` and its box ``[x1, x2, y1, y2]`` in RGB
        pixel coordinates, or ``0`` when the winning frame shows no object pixel.
    """
    path += '/'
    # NOTE: os.listdir order is arbitrary, so listing positions are not frame order.
    image_list = os.listdir(path + prefix_rgb)
    x_min, y_min, z_min, x_max, y_max, z_max = params
    X = int((x_max - x_min) / resolution)
    Y = int((y_max - y_min) / resolution)
    Z = int((z_max - z_min) / resolution)

    voxels = []
    for object_points in the_points:
        if object_points.ndim != 1:
            voxel = torch.zeros((X, Y, Z), device=device, dtype=torch.bool)
            indices = points_to_indices(object_points, params, resolution)
            voxel[indices[:, 0], indices[:, 1], indices[:, 2]] = True
            voxels.append(voxel)

    # Occupied-voxel counts do not depend on the frame; compute them once.
    volumes = [voxel.sum().float() for voxel in voxels]

    other_pen_param = 0.7
    frame_positions = range(0, last_idx, step)
    images = []
    boxes = []
    for i, object_voxel in enumerate(voxels):
        max_ratio = -100.0
        max_idx = 0
        the_box = 0
        for j in tqdm(frame_positions, total=len(frame_positions), desc="Processing frames"):
            image_path = image_list[j]
            idx_path = str(image_path[:image_path.index('.')])

            rgb_i = cv2.imread(path + prefix_rgb + idx_path + '.jpg')
            depth_i = cv2.imread(path + prefix_depth + idx_path + '.png', cv2.IMREAD_UNCHANGED).astype(np.float32)
            camera_pose_i = _load_pose(pose_path, int(idx_path))

            points = point_cloud(depth_i, scale, camera_intristics, camera_pose_i)
            points = points_to_grid(points, resolution)

            # Voxelise what this frame sees; clamp because frame points may leave the grid.
            voxel_j = torch.zeros((X, Y, Z), device=device, dtype=torch.bool)
            indices = points_to_indices(points, params, resolution)
            indices[:, 0] = indices[:, 0].clamp(0, X - 1)
            indices[:, 1] = indices[:, 1].clamp(0, Y - 1)
            indices[:, 2] = indices[:, 2].clamp(0, Z - 1)
            voxel_j[indices[:, 0], indices[:, 1], indices[:, 2]] = True

            overlap = torch.logical_and(object_voxel, voxel_j)
            ratio = overlap.sum().float() / volumes[i]

            # Largest visible fraction among the *other* objects.
            max_other_ratio = 0
            for k, other_voxel in enumerate(voxels):
                if k == i:
                    continue
                if volumes[k].item() > 0:
                    the_ratio = torch.logical_and(other_voxel, voxel_j).sum().float() / volumes[k]
                else:
                    # Empty object: nothing to see (was `voxels.device`, an AttributeError on a list).
                    the_ratio = torch.tensor(0.0, device=voxel_j.device)
                if the_ratio > max_other_ratio:
                    max_other_ratio = the_ratio

            score = ratio - other_pen_param * max_other_ratio
            if score > max_ratio:
                max_ratio = score
                max_idx = j
                # Drop the box of the previous best frame so image and box always match.
                the_box = 0
                new_points = indices_to_points(torch.nonzero(overlap, as_tuple=False), params, resolution)
                mask = points_to_2d(new_points, camera_intristics, camera_pose_i, depth_i.shape)

                # uint8 is accepted by every OpenCV build; only the non-zero pattern is used.
                mask = cv2.resize(mask.astype(np.uint8), (rgb_i.shape[1], rgb_i.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)
                mask_rows, mask_cols = np.nonzero(mask)
                if mask_rows.shape[0] > 0:
                    y1 = np.min(mask_rows)
                    y2 = np.max(mask_rows)
                    x1 = np.min(mask_cols)
                    x2 = np.max(mask_cols)
                    the_box = [x1, x2, y1, y2]

        image_path = image_list[max_idx]
        idx_path = str(image_path[:image_path.index('.')])
        images.append(cv2.imread(path + prefix_rgb + idx_path + '.jpg'))
        boxes.append(the_box)

    return images, boxes

def image_views(the_points, path, prefix_rgb, prefix_depth, pose_path, last_idx, step, resolution, params,num_views,param_angle):
    """Pick one frame per viewing direction that best covers an object.

    ``num_views`` target headings are spread evenly over ``[-pi, pi)``. For
    each heading, every ``step``-th entry of the RGB directory listing whose
    camera heading (in the XY plane) is within ``param_angle`` radians of the
    target is back-projected and voxelised; the frame covering the largest
    fraction of the object's voxels wins.

    Reads the module globals ``scale`` and ``camera_intristics``, which the
    ``retrieve_*`` entry points set before calling.

    Args:
        the_points: ``(N, 3)`` array of object points.
        path: Scene directory.
        prefix_rgb: RGB sub-directory (with trailing slash), holding ``<id>.jpg``.
        prefix_depth: Depth sub-directory (with trailing slash), holding ``<id>.png``.
        pose_path: Directory with ``<id>.txt`` pose files.
        last_idx: Exclusive upper bound on positions in the RGB listing.
        step: Stride between inspected listing positions.
        resolution: Voxel edge length.
        params: ``(x_min, y_min, z_min, x_max, y_max, z_max)`` of the grid.
        num_views: Number of target headings.
        param_angle: Maximum absolute heading difference in radians.

    Returns:
        ``(images, boxes, angles)`` with one entry per target heading: the
        chosen frame as read by ``cv2.imread`` (listing position 0 when no
        frame qualified), its box ``[x1, x2, y1, y2]`` or ``0``, and the
        camera heading of the chosen frame or ``None``.

    Raises:
        ValueError: If ``the_points`` is 1-D and therefore cannot be voxelised.
    """
    path += '/'
    # NOTE: os.listdir order is arbitrary, so listing positions are not frame order.
    image_list = os.listdir(path + prefix_rgb)
    x_min, y_min, z_min, x_max, y_max, z_max = params
    X = int((x_max - x_min) / resolution)
    Y = int((y_max - y_min) / resolution)
    Z = int((z_max - z_min) / resolution)
    images = []
    boxes = []
    angles = []

    # Previously a 1-D input left `voxel` unbound and crashed later with a NameError.
    if the_points.ndim == 1:
        raise ValueError("the_points must be a 2-D array of xyz rows")

    voxel = torch.zeros((X, Y, Z), device=device, dtype=torch.bool)
    indices = points_to_indices(the_points, params, resolution)
    voxel[indices[:, 0], indices[:, 1], indices[:, 2]] = True
    volume = voxel.sum().float()

    # NOTE(review): the heading assumes the camera looks along -Z, while
    # point_cloud back-projects along +Z; confirm the intended convention.
    f_cam = np.array([0, 0, -1])
    frame_positions = range(0, last_idx, step)

    for k in range(num_views):
        angle = k / num_views * 2 * np.pi - np.pi
        max_ratio = -100.0
        max_idx = 0
        the_box = 0
        view_angle = None
        for j in tqdm(frame_positions, total=len(frame_positions), desc="Processing frames"):
            image_path = image_list[j]
            idx_path = str(image_path[:image_path.index('.')])

            camera_pose_i = _load_pose(pose_path, int(idx_path))
            # Heading: forward vector in the pose's target frame, angle from the X axis in the XY plane.
            f_world = camera_pose_i[:3, :3] @ f_cam
            angle_rad = np.arctan2(f_world[1], f_world[0])

            # Compare headings on the circle so that -pi and +pi count as neighbours.
            delta = angle_rad - angle
            delta = np.arctan2(np.sin(delta), np.cos(delta))
            if not np.abs(delta) < param_angle:
                # Outside the angular window: skip before any image I/O.
                continue

            rgb_i = cv2.imread(path + prefix_rgb + idx_path + '.jpg')
            depth_i = cv2.imread(path + prefix_depth + idx_path + '.png', cv2.IMREAD_UNCHANGED).astype(np.float32)

            points = point_cloud(depth_i, scale, camera_intristics, camera_pose_i)
            points = points_to_grid(points, resolution)

            # Voxelise what this frame sees; clamp because frame points may leave the grid.
            voxel_j = torch.zeros((X, Y, Z), device=device, dtype=torch.bool)
            indices = points_to_indices(points, params, resolution)
            indices[:, 0] = indices[:, 0].clamp(0, X - 1)
            indices[:, 1] = indices[:, 1].clamp(0, Y - 1)
            indices[:, 2] = indices[:, 2].clamp(0, Z - 1)
            voxel_j[indices[:, 0], indices[:, 1], indices[:, 2]] = True

            overlap = torch.logical_and(voxel, voxel_j)
            ratio = overlap.sum().float() / volume

            if ratio > max_ratio:
                max_ratio = ratio
                max_idx = j
                # Keep box/angle in sync with the frame that is actually returned.
                the_box = 0
                view_angle = None
                new_points = indices_to_points(torch.nonzero(overlap, as_tuple=False), params, resolution)
                mask = points_to_2d(new_points, camera_intristics, camera_pose_i, depth_i.shape)

                # uint8 is accepted by every OpenCV build; only the non-zero pattern is used.
                mask = cv2.resize(mask.astype(np.uint8), (rgb_i.shape[1], rgb_i.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)
                mask_rows, mask_cols = np.nonzero(mask)
                if mask_rows.shape[0] > 0:
                    y1 = np.min(mask_rows)
                    y2 = np.max(mask_rows)
                    x1 = np.min(mask_cols)
                    x2 = np.max(mask_cols)
                    the_box = [x1, x2, y1, y2]
                    view_angle = angle_rad

        image_path = image_list[max_idx]
        idx_path = str(image_path[:image_path.index('.')])
        images.append(cv2.imread(path + prefix_rgb + idx_path + '.jpg'))
        boxes.append(the_box)
        angles.append(view_angle)

    return images, boxes, angles


def overlap_detection(points_i,points_j,resolution):
    """Measure how much two point sets overlap on a shared voxel grid.

    The grid spans the joint axis-aligned bounds of both sets, padded by 0.2
    on every side, with cells of edge length ``resolution``.

    Args:
        points_i: ``(N, 3)`` NumPy array.
        points_j: ``(M, 3)`` NumPy array.
        resolution: Voxel edge length.

    Returns:
        ``(overlaps_i, overlaps_j, volume_i, volume_j)`` as tensors: the number
        of shared voxels divided by each set's own voxel count, followed by
        the two voxel counts.
    """
    margin = 0.2
    lower = []
    upper = []
    for axis in range(3):
        column_i = points_i[:, axis]
        column_j = points_j[:, axis]
        lower.append(np.min([np.min(column_i), np.min(column_j)]) - margin)
        upper.append(np.max([np.max(column_i), np.max(column_j)]) + margin)

    X, Y, Z = (int((hi - lo) / resolution) for lo, hi in zip(lower, upper))
    params = (*lower, *upper)

    grid_i = voxelize_batch(points_i, X, Y, Z, params, resolution)
    grid_j = voxelize_batch(points_j, X, Y, Z, params, resolution)

    shared = torch.sum(grid_i * grid_j)
    volume_i = torch.sum(grid_i)
    volume_j = torch.sum(grid_j)

    return shared / volume_i, shared / volume_j, volume_i, volume_j

def voxelize_batch(points,X,Y,Z,params,resolution):
    """Rasterise one point set into an ``(X, Y, Z)`` occupancy grid.

    Args:
        points: ``(N, 3)`` coordinates.
        X, Y, Z: Grid size along each axis.
        params: ``(x_min, y_min, z_min, x_max, y_max, z_max)`` of the grid.
        resolution: Voxel edge length.

    Returns:
        Float tensor of shape ``(X, Y, Z)`` on the module-level ``device``,
        ``1`` in every cell that contains at least one point and ``0`` elsewhere.
    """
    cells = grid_indices(points, params, resolution)
    ix, iy, iz = cells[:, 0], cells[:, 1], cells[:, 2]
    occupancy = torch.zeros((X, Y, Z), device=device)
    occupancy[ix, iy, iz] = 1
    return occupancy

def retrieve_object_images(
        dataset_path,  # root path of ScanNet_scenes
        scene_name,  # scene name, e.g., 'scene0011_00'
        object_name,  # target object name, e.g., 'trash can'
        object_id,  # ID for points_to_ids CSV file
        clip_model_name,  # CLIP model name
        clip_state_dict_path,

        num_candidates=6,  # number of top candidates to consider
        dbscan_eps=0.15,  # DBSCAN epsilon
        dbscan_min_samples=200,
        step=10,  # frame step size
        resolution=0.02,  # voxel resolution
        other_ratio_threshold=0.05  # threshold for voxel overlap
):
    """Find the best frame and 2-D box for the segments most similar to a text query.

    Loads the scene's point-to-id CSV and id-to-embedding JSON, ranks the ids
    by CLIP text similarity to ``object_name``, keeps the ``num_candidates``
    best, retrieves one frame and box per candidate, and drops the smaller of
    any two candidates whose voxel overlap ratio exceeds 0.4.

    Side effects: stores ``ids_to_embedings``, ``device``, ``camera_intristics``
    and ``scale`` as module globals because sibling functions read them there.
    ``dbscan_eps``, ``dbscan_min_samples`` and ``other_ratio_threshold`` are
    accepted but not used.

    Returns:
        ``(points, images, boxes)`` for the surviving candidates; images are
        converted from BGR to RGB.
    """
    scene_dir = os.path.join(dataset_path, scene_name)

    # Per-point segment ids and xyz coordinates.
    df_points_to_id = pd.read_csv(os.path.join(scene_dir, f'{scene_name}_points_to_ids_{object_id}.csv'))

    # Per-segment embeddings.
    with open(os.path.join(scene_dir, f'{scene_name}_ids_to_embeddings_{object_id}.json'), 'r') as file:
        ids_to_embeddings = json.load(file)

    points_ids = np.array(df_points_to_id['Object id'], dtype=int)
    points = df_points_to_id[['x', 'y', 'z']].to_numpy()

    # similarity_measurement reads the (misspelled) global `ids_to_embedings`
    # and the global `device`, so publish both before calling it.
    module_globals = globals()
    module_globals['ids_to_embedings'] = ids_to_embeddings
    module_globals['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

    similarities = similarity_measurement(ids_to_embeddings, object_name, clip_model_name, clip_state_dict_path)
    best_ids = np.argsort(similarities)[-num_candidates:]
    new_points = [points[points_ids == idx] for idx in best_ids]

    # The frame-selection helpers read `camera_intristics` (sic) and `scale` as globals.
    camera_intrinsics, scale = _load_depth_intrinsics(
        os.path.join(scene_dir, 'intrinsic/intrinsic_depth.txt'))
    module_globals['camera_intristics'] = camera_intrinsics
    module_globals['scale'] = scale

    # Voxel-grid bounds: scene extent padded by 1.0 on every side.
    lower = np.nanmin(points, axis=0) - 1.0
    upper = np.nanmax(points, axis=0) + 1.0
    params = (*lower, *upper)

    last_idx = len(os.listdir(os.path.join(scene_dir, 'color')))
    images, boxes = image_with_3d_overlap(
        new_points, scene_dir, 'color/', 'depth/', os.path.join(scene_dir, 'pose/'),
        last_idx, step, resolution, params
    )

    # Of two strongly overlapping candidates, discard the one with the smaller volume.
    iou_threshold = 0.4
    removes = set()
    candidate_count = len(new_points)
    for i in range(candidate_count):
        for j in range(i + 1, candidate_count):
            iou_i, iou_j, volume_i, volume_j = overlap_detection(new_points[i], new_points[j], resolution)
            if iou_j > iou_threshold and volume_i > volume_j:
                removes.add(j)
            if iou_i > iou_threshold and volume_j > volume_i:
                removes.add(i)

    kept = [i for i in range(len(images)) if i not in removes]
    the_points = [new_points[i] for i in kept]
    the_boxes = [boxes[i] for i in kept]
    the_images = [cv2.cvtColor(images[i], cv2.COLOR_BGR2RGB) for i in kept]
    return the_points, the_images, the_boxes



def get_params(dataset_path, scene_name, object_id =20):
    """Return the padded bounding box of a scene's point-to-id CSV.

    Args:
        dataset_path: Dataset root directory.
        scene_name: Scene sub-directory; also the CSV file-name prefix.
        object_id: Suffix of the ``<scene>_points_to_ids_<object_id>.csv`` file.

    Returns:
        ``(x_min, y_min, z_min, x_max, y_max, z_max)``: the NaN-ignoring
        extrema of the ``x``/``y``/``z`` columns widened by 1.0 on every side.
    """
    csv_file = os.path.join(dataset_path, scene_name, f'{scene_name}_points_to_ids_{object_id}.csv')
    xyz = pd.read_csv(csv_file)[['x', 'y', 'z']].to_numpy()
    lower = np.nanmin(xyz, axis=0) - 1.0
    upper = np.nanmax(xyz, axis=0) + 1.0
    return (*lower, *upper)

def retrieve_object_views(
        scene_name,
        clustered_points ,
        dataset_path = '/home/user01/main_folder/ScanNet',
        param_angle = 0.25,
        num_views = 4,
        step=10,  # frame step size
        resolution=0.02,  # voxel resolution
        other_ratio_threshold=0.05  # threshold for voxel overlap
):
    """Collect one frame per viewing direction for a single object.

    Thin wrapper around ``image_views``: it loads the scene's depth
    intrinsics, derives the voxel-grid bounds with ``get_params`` and converts
    the returned frames from BGR to RGB.

    Side effects: stores ``camera_intristics`` and ``scale`` as module globals
    because ``image_views`` reads them there. ``other_ratio_threshold`` is
    accepted but not used.

    Returns:
        ``(images, boxes, angles)`` as produced by ``image_views``, with the
        images in RGB channel order.
    """
    scene_dir = os.path.join(dataset_path, scene_name)

    # image_views expects `camera_intristics` (sic) and `scale` as module globals.
    camera_intrinsics, scale = _load_depth_intrinsics(
        os.path.join(scene_dir, 'intrinsic/intrinsic_depth.txt'))
    module_globals = globals()
    module_globals['camera_intristics'] = camera_intrinsics
    module_globals['scale'] = scale

    # Voxel-grid bounds come from the scene's default point-to-id CSV.
    params = get_params(dataset_path, scene_name)

    frame_count = len(os.listdir(os.path.join(scene_dir, 'color')))
    images, boxes, angles = image_views(
        clustered_points, scene_dir, 'color/', 'depth/', os.path.join(scene_dir, 'pose/'),
        frame_count, step, resolution, params, num_views, param_angle
    )

    rgb_images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
    return rgb_images, boxes, angles

