import os
import pickle
from tqdm import tqdm
import open3d as o3d
import cv2
import numpy as np
from shutil import copy
import torch
from pytorch3d.ops.knn import knn_points

# Scenes held out for testing, mapped to the frame stride used when indexing them.
test_scenes = {
    '7-scenes-redkitchen': 100,
    'sun3d-home_at-home_at_scan1_2013_jan_1': 100,
    'sun3d-home_md-home_md_scan9_2012_sep_30': 100,
    'sun3d-hotel_uc-scan3': 100,
    'sun3d-hotel_umd-maryland_hotel1': 100,
    'sun3d-hotel_umd-maryland_hotel3': 50,
    'sun3d-mit_76_studyroom-76-1studyroom2': 50,
    'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika': 50,
}

# Scenes held out for validation.
valid_scenes = {
    'sun3d-brown_cs_3-brown_cs3',
    'sun3d-harvard_c3-hv_c3_1',
    'sun3d-mit_32_d507-d507_2',
    '7-scenes-fire',
    'rgbd-scenes-v2-scene_01',
    'bundlefusion-apt0',
    'bundlefusion-office0',
}

# Scenes excluded from the training split.
# NOTE: the identifier is misspelt ("ingore"), but it is referenced by that
# name elsewhere in this file, so it must not be renamed here in isolation.
ingore_scenes = {
    'analysis-by-synthesis-apt1-kitchen',
    'analysis-by-synthesis-apt1-living',
    'analysis-by-synthesis-apt2-bed',
    'analysis-by-synthesis-apt2-kitchen',
    'analysis-by-synthesis-apt2-living',
    'analysis-by-synthesis-apt2-luke',
    'analysis-by-synthesis-office2-5a',
    'analysis-by-synthesis-office2-5b',
}

# Root of the raw dataset and root of the generated output.
path_dataset = '/mnt/e/Dataset/3DMatch'
path_output = '/mnt/e/Dataset/RGBD_3DMatch'

def load_3DMatch(path):
    """Load one depth frame and back-project it into a point cloud.

    The colour image and the pose file are looked up next to the depth file
    (same file name with ``depth`` replaced by ``color`` and ``depth.png``
    replaced by ``pose.txt``); ``camera-intrinsics.txt`` is read from the
    directory one level above the frame's directory.

    Args:
        path: Path to a ``*.depth.png`` file.

    Returns:
        A tuple ``(pc, pose, intrinsics)``: the Open3D point cloud created
        from the depth image, and the contents of the pose and intrinsics
        text files as ``float32`` arrays.

    Raises:
        FileNotFoundError: If the colour image cannot be read.
    """
    frame_dir = os.path.dirname(path)
    frame_name = os.path.basename(path)

    depth_file = path
    image_file = frame_dir + '/' + frame_name.replace("depth", "color")
    pose_file = frame_dir + '/' + frame_name.replace("depth.png", "pose.txt")
    intrinsics_file = os.path.dirname(frame_dir) + "/camera-intrinsics.txt"

    # The colour image is only needed for its resolution.
    image = cv2.imread(image_file)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Cannot read color image: {image_file}")
    depth = o3d.io.read_image(depth_file)
    pose = np.loadtxt(pose_file, dtype = np.float32)
    intrinsics = np.loadtxt(intrinsics_file, dtype = np.float32)

    # Generate PointClouds
    # Image arrays are laid out as (rows, cols, channels), i.e. (height, width, ...).
    height, width = image.shape[:2]
    camera_intrinsic = o3d.camera.PinholeCameraIntrinsic()
    camera_intrinsic.set_intrinsics(
        width = width,
        height = height,
        fx = intrinsics[0,0],
        fy = intrinsics[1,1],
        cx = intrinsics[0,2],
        cy = intrinsics[1,2],
    )
    # depth_scale / depth_trunc are chosen so that invalid raw depth readings
    # (0 and 65535) are meant to be dropped from the cloud.
    pc = o3d.geometry.PointCloud.create_from_depth_image(depth, camera_intrinsic, depth_scale = 1000, depth_trunc = 65.535)

    return pc, pose, intrinsics

def process_point_cloud(pcd, threshold):
    """Downsample a point cloud and crop it when its coarse version is large.

    A fine (0.025 voxel) and a coarse (0.2 voxel) downsampled copy are built.
    When the coarse copy has more than 400 points, only the 400 coarse points
    nearest to the coarse centroid are kept, and fine points farther than
    ``threshold`` from every kept coarse point are discarded.

    Args:
        pcd: Open3D point cloud to process (it is copied, not modified).
        threshold: Maximum distance from a fine point to its nearest kept
            coarse point.

    Returns:
        A new Open3D point cloud holding the (possibly cropped) fine points.
    """
    def _downsample_to_cuda(voxel_size):
        # Work on a copy so the caller's cloud stays untouched.
        cloud = o3d.geometry.PointCloud(pcd).voxel_down_sample(voxel_size)
        return torch.tensor(np.asarray(cloud.points), dtype = torch.float32, device = 'cuda')

    fine = _downsample_to_cuda(0.025)
    coarse = _downsample_to_cuda(0.2)

    if coarse.shape[0] > 400:  # Clip the point clouds that are too large
        centroid = coarse.mean(0)
        # Indices of the 400 coarse points closest to the centroid.
        _, nearest, _ = knn_points(centroid[None, None, :], coarse[None, :, :], K = 400)
        coarse = coarse[nearest[0, 0, :], :]
        # Distances from knn_points are compared against the squared threshold.
        sq_dist, _, _ = knn_points(fine[None, :, :], coarse[None, :, :], K = 1)
        keep = sq_dist[0, :, 0] <= threshold ** 2
        fine = fine[keep, :]

    return o3d.geometry.PointCloud(o3d.utility.Vector3dVector(fine.cpu().numpy()))

def get_overlap_ratio(source,target,threshold):
    """Compute the overlap ratio from the source cloud to the target cloud.

    The ratio is the fraction of source points that have at least one target
    point within ``threshold``.

    Args:
        source: Point cloud whose points are tested (needs ``.points``).
        target: Point cloud that is searched for neighbours.
        threshold: Search radius around each source point.

    Returns:
        A float in ``[0, 1]``. ``0.0`` is returned when either cloud is
        empty, instead of failing with a division by zero.
    """
    num_source = len(source.points)
    if num_source == 0 or len(target.points) == 0:
        # Nothing can overlap with (or from) an empty cloud.
        return 0.0

    pcd_tree = o3d.geometry.KDTreeFlann(target)

    # The first element of the search result is the number of neighbours found.
    match_count = sum(
        1
        for point in source.points
        if pcd_tree.search_radius_vector_3d(point, threshold)[0] != 0
    )

    return match_count / num_source

def get_overlap_ratio_cuda(source, target, threshold):
    """GPU variant of the source-to-target overlap ratio.

    For every source point the nearest target point is found with
    ``knn_points``; the point counts as overlapping when that distance,
    compared in squared form, does not exceed ``threshold ** 2``.

    Args:
        source: Point cloud whose points are tested (needs ``.points``).
        target: Point cloud that is searched for nearest neighbours.
        threshold: Maximum distance for a source point to count as matched.

    Returns:
        Fraction of matched source points as a float. ``0.0`` is returned
        when either cloud is empty, instead of dividing by zero or running a
        nearest-neighbour query against an empty set.
    """
    # Check emptiness before any GPU work.
    if len(source.points) == 0 or len(target.points) == 0:
        return 0.0

    src = torch.tensor(np.asarray(source.points), dtype = torch.float32, device = 'cuda')
    tgt = torch.tensor(np.asarray(target.points), dtype = torch.float32, device = 'cuda')

    # knn_points works on batches, hence the leading singleton dimension.
    sq_dist, _, _ = knn_points(src[None, :, :], tgt[None, :, :], K = 1)
    matched = sq_dist[0, :, 0] <= threshold ** 2

    return matched.sum().item() / src.shape[0]

def search_file(folder_name, res, stride, max_num = None):
    """Index every ``stride``-th depth frame of one scene folder.

    Each sub-directory of ``path_dataset/folder_name`` is treated as a
    sequence. Starting at frame 0 and stepping by ``stride``, frames named
    ``frame-XXXXXX.depth.png`` are collected until one is missing. Frame
    indices run consecutively across all sequences of the scene.

    Sequences are visited in sorted order so that the assigned indices are
    reproducible; ``os.listdir`` alone returns entries in arbitrary order.

    Args:
        folder_name: Scene directory name, relative to ``path_dataset``.
        res: Dict that receives the result under the key ``folder_name``.
        stride: Step between consecutive sampled frame numbers.
        max_num: Optional cap on the total number of frames collected for
            the scene; ``None`` means no cap.

    Returns:
        None. ``res[folder_name]`` is set to a list of dicts with the keys
        ``'idx'`` and ``'path'`` (path relative to ``path_dataset``).
    """
    scene = []
    scene_root = path_dataset + '/' + folder_name

    for seq_name in sorted(os.listdir(scene_root)):
        if len(scene) == max_num:
            # Cap reached: no need to look at the remaining sequences.
            break

        seq_folder = scene_root + '/' + seq_name
        if not os.path.isdir(seq_folder):
            continue

        cur = 0
        while len(scene) != max_num:
            frame_name = 'frame-' + str(cur).zfill(6) + '.depth.png'
            if not os.path.exists(seq_folder + '/' + frame_name):
                break

            scene.append({
                'idx': len(scene),
                'path': folder_name + '/' + seq_name + '/' + frame_name,
            })
            cur += stride

    res[folder_name] = scene

def save_idx(split):
    """Build the frame index of one dataset split and pickle it.

    Every directory under ``path_dataset`` is assigned to a split by name:
    ``'test'`` takes the scenes in ``test_scenes`` (with their per-scene
    stride and at most 100 frames), ``'val'`` takes the scenes in
    ``valid_scenes`` (stride 100), and ``'train'`` takes everything that is
    in none of ``test_scenes``, ``valid_scenes`` or ``ingore_scenes``
    (stride 50). Any other ``split`` value selects no scene.

    Args:
        split: Name of the split; also the prefix of the output file.

    Returns:
        None. The index is written to ``<split>_idx.pkl`` in the current
        working directory.
    """
    index = {}
    for entry in tqdm(os.listdir(path_dataset)):
        if not os.path.isdir(path_dataset + '/' + entry):
            continue

        if split == 'train':
            held_out = entry in test_scenes or entry in valid_scenes or entry in ingore_scenes
            if not held_out:
                search_file(entry, index, 50)
        elif split == 'val' and entry in valid_scenes:
            search_file(entry, index, 100)
        elif split == 'test' and entry in test_scenes:
            search_file(entry, index, test_scenes[entry], 100)

    with open(split + '_idx.pkl', 'wb') as out:
        pickle.dump(index, out)

def process_data(split):
    """Export the point clouds of one split and list its overlapping pairs.

    Reads ``<split>_idx.pkl`` (as written by ``save_idx``). For every indexed
    frame the processed point cloud is saved as ``cloud_bin_<idx>.pth`` and
    the colour image is copied to ``cloud_bin_<idx>.png`` under
    ``path_output/<split>/<scene>``. Frames of a scene are then compared
    pairwise inside consecutive chunks of 60 frames, and every pair whose
    overlap ratio is at least 0.05 is recorded.

    Args:
        split: Name of the split; selects the index file, the output
            sub-directory and the name of the metadata file.

    Returns:
        None. The list of pair records is written to ``<split>.pkl`` in the
        current working directory.
    """
    # The index file is produced locally by this script; do not point this
    # at pickles from an untrusted source.
    with open(split + '_idx.pkl', 'rb') as f:
        idx_list = pickle.load(f)

    metadata = []
    split_dir = path_output + '/' + split
    # makedirs also creates path_output itself when it is missing and does
    # not fail when the directory already exists.
    os.makedirs(split_dir, exist_ok=True)

    for key, cur in idx_list.items():
        print(key)
        cur_path = split_dir + '/' + key
        os.makedirs(cur_path, exist_ok=True)

        pc, pos, intrins = [], [], []
        for frame in tqdm(cur):
            pcd, pose, intrinsics = load_3DMatch(path_dataset + '/' + frame['path'])
            pcd = process_point_cloud(pcd, 0.2)

            # The cloud is stored in its own camera frame, i.e. before the
            # pose is applied below.
            torch.save(torch.tensor(np.asarray(pcd.points), dtype = torch.float32), cur_path + f"/cloud_bin_{frame['idx']}.pth")
            img_path = path_dataset + '/' + os.path.dirname(frame['path']) + '/' + os.path.basename(frame['path']).replace("depth", "color")
            copy(img_path, cur_path + f"/cloud_bin_{frame['idx']}.png")

            # Bring the cloud into the common frame for the overlap test.
            pcd.transform(pose)
            pc.append(pcd)
            pos.append(pose)
            intrins.append(intrinsics)

        # Only frames inside the same chunk of 60 consecutive entries are
        # paired with each other.
        num_frames = len(cur)
        for start in range(0, num_frames, 60):
            stop = min(start + 60, num_frames)
            for i in tqdm(range(start, stop - 1)):
                for j in range(i + 1, stop):
                    # c_overlap = get_overlap_ratio(pc[i], pc[j], 0.05)
                    c_overlap = get_overlap_ratio_cuda(pc[i], pc[j], 0.05)
                    if c_overlap < 0.05:
                        continue

                    # Relative transform taking frame j into frame i.
                    trans = np.matmul(np.linalg.inv(pos[i]), pos[j])
                    prefix = split + '/' + key
                    metadata.append({
                        'scene_name': key,
                        'frag_id0': cur[i]['idx'],
                        'frag_id1': cur[j]['idx'],
                        'overlap': c_overlap,
                        'pcd0': prefix + f"/cloud_bin_{cur[i]['idx']}.pth",
                        'pcd1': prefix + f"/cloud_bin_{cur[j]['idx']}.pth",
                        'img0': prefix + f"/cloud_bin_{cur[i]['idx']}.png",
                        'img1': prefix + f"/cloud_bin_{cur[j]['idx']}.png",
                        'intrinsics': intrins[i],
                        'rotation': trans[0:3, 0:3],
                        'translation': trans[0:3, 3],
                    })

    with open(split + '.pkl', 'wb') as f:
        pickle.dump(metadata, f)

def split_overlap():
    """Split the test pairs into two files by overlap ratio.

    Reads ``test.pkl`` from the current working directory. Pairs with an
    overlap in ``(0.3, 0.7]`` are written to ``RGBD_3DMatch.pkl`` and pairs
    with an overlap in ``(0.1, 0.3]`` to ``RGBD_3DLoMatch.pkl``; all other
    pairs are dropped. The original order is kept within each file.
    """
    with open("test.pkl", 'rb') as src:
        pairs = pickle.load(src)

    high_overlap = [pair for pair in pairs if 0.3 < pair['overlap'] <= 0.7]
    low_overlap = [pair for pair in pairs if 0.1 < pair['overlap'] <= 0.3]

    for file_name, subset in (('RGBD_3DMatch.pkl', high_overlap), ('RGBD_3DLoMatch.pkl', low_overlap)):
        with open(file_name, 'wb') as dst:
            pickle.dump(subset, dst)

if __name__ == "__main__":
    # Index the frames and export the data of every split, one after another.
    for split_name in ('train', 'val', 'test'):
        save_idx(split_name)
        process_data(split_name)
    # Derive the two pair lists from the exported test metadata.
    split_overlap()

    # Disabled statistics snippet, kept as an inert string literal.
    '''with open(os.path.join("RGBD_3DLoMatch.pkl"), 'rb') as f:
        metadata_list = pickle.load(f)
    count = {}
    ov = [0] * 11
    for meta in metadata_list:
        if meta['scene_name'] in count:
            count[meta['scene_name']] += 1
        else:
            count[meta['scene_name']] = 1
        ov[int(meta['overlap'] * 10)] += 1
    print(len(metadata_list))
    print(count)
    print(ov)'''