import concurrent.futures
import json
import math
import os
from os import replace
import random
import subprocess
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import meta_learning_sdf.click as click
from PIL import Image
from scipy.spatial import cKDTree as KDTree


@click.group()
def client():
    # Root command group; the sub-commands in this file attach to it through
    # ``@client.command``. The group callback itself does nothing.
    ...


def _generate_sdf(surface_points, surface_points_normals, kd_tree, variance):
    sdf_samples = np.random.normal(scale=math.sqrt(variance), size=(250000, 3))
    near_points_indices = np.random.choice(len(surface_points),
                                           size=250000,
                                           replace=True)
    sdf_samples = surface_points[near_points_indices] + sdf_samples

    print("start query")
    distances, locations = kd_tree.query(sdf_samples)
    print("end query")

    nearest_points = surface_points[locations]
    nearest_normals = surface_points_normals[locations]
    direction = sdf_samples - nearest_points
    sign = (nearest_normals * direction).sum(axis=1)
    sign[sign >= 0] = 1
    sign[sign < 0] = -1
    signed_distance = distances * sign

    positive_sdf_indices = np.where(sign > 0)[0]
    negative_sdf_indices = np.where(sign < 0)[0]
    positive_points = sdf_samples[positive_sdf_indices]
    positive_signed_distance = signed_distance[positive_sdf_indices]
    negative_points = sdf_samples[negative_sdf_indices]
    negative_signed_distance = signed_distance[negative_sdf_indices]

    return positive_points, positive_signed_distance, negative_points, negative_signed_distance


@client.command(name="convert_data_to_sdf_npz")
@click.argument("--depth-path", type=str, required=True)
@click.argument("--output-path", type=str, required=True)
def convert_data_to_sdf_npz(args: click.Arguments):
    # Back-project one depth image into a normalized point cloud, sample
    # signed-distance values around it at two noise levels (variance 0.005 and
    # 0.0005), and save the positive/negative samples together with the
    # normalization parameters to an .npz file.
    src = Path(args.depth_path)
    dst = Path(args.output_path)
    os.makedirs(str(dst.parent), exist_ok=True)
    depth = np.array(Image.open(src))

    print(src, flush=True)

    # Pinhole intrinsics; the negative entry at [1, 1] flips the sign of the
    # back-projected y coordinate.
    intrinsics = np.array([
        [481.20, 0, 319.5],
        [0, -480, 239.5],
        [0, 0, 1],
    ])
    cols, rows = np.meshgrid(np.arange(depth.shape[1]),
                             np.arange(depth.shape[0]))
    depth_col = depth.reshape((-1, 1))
    # Pixel coordinates scaled by depth, (u*z, v*z, z), mapped to camera space.
    pixel_h = np.concatenate((cols.reshape((-1, 1)) * depth_col,
                              rows.reshape((-1, 1)) * depth_col,
                              depth_col), axis=1)
    points = (np.linalg.inv(intrinsics) @ pixel_h.T).T

    # Center on the bounding-box center and shrink into the unit ball with a
    # 3% margin.
    center = (points.min(0) + points.max(0)) / 2
    points -= center
    max_distance = np.linalg.norm(points, axis=1).max()
    points /= max_distance * 1.03

    view_dir = points + np.array([0.0, 0.0, 1.0])

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    print("start estimate_normals", flush=True)
    cloud.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(
        radius=0.1, max_nn=30))
    print("end estimate_normals", flush=True)
    normals = np.array(cloud.normals)
    # Flip every normal whose dot product with -view_dir is negative so the
    # normals share a consistent orientation.
    facing_away = (normals * -view_dir).sum(axis=1) < 0
    normals[facing_away] *= -1

    print("start KDTree", flush=True)
    tree = KDTree(points, compact_nodes=False, balanced_tree=False)
    print("end KDTree", flush=True)
    far = _generate_sdf(points, normals, tree, 0.005)
    near = _generate_sdf(points, normals, tree, 0.0005)
    # Both results are (pos_points, pos_sdf, neg_points, neg_sdf); stack the
    # far batch on top of the near batch, field by field.
    pos_pts, pos_sdf, neg_pts, neg_sdf = (
        np.concatenate((far_part, near_part), axis=0)
        for far_part, near_part in zip(far, near))

    # NOTE(review): points were divided by max_distance * 1.03, yet `scale`
    # stores 1 / max_distance -- confirm consumers expect the margin-free value.
    np.savez(dst,
             positive_sdf_samples=np.concatenate(
                 (pos_pts, pos_sdf[:, None]), axis=1),
             negative_sdf_samples=np.concatenate(
                 (neg_pts, neg_sdf[:, None]), axis=1),
             offset=-center,
             scale=1 / max_distance)


@client.command(name="convert_data_to_npz")
@click.argument("--depth-path", type=str, required=True)
@click.argument("--output-path", type=str, required=True)
def convert_data_to_npz(args: click.Arguments):
    # Convert a single depth image into a normalized point cloud with
    # consistently oriented normals and save it as .npz under the keys
    # vertices, vertex_normals, offset and scale.
    source_path = Path(args.depth_path)
    target_path = Path(args.output_path)
    os.makedirs(str(target_path.parent), exist_ok=True)
    depth = np.array(Image.open(source_path))

    # Pinhole intrinsics; the negative entry at [1, 1] flips the sign of the
    # back-projected y coordinate.
    intrinsics = np.array([
        [481.20, 0, 319.5],
        [0, -480, 239.5],
        [0, 0, 1],
    ])
    u, v = np.meshgrid(np.arange(depth.shape[1]), np.arange(depth.shape[0]))
    z = depth.reshape((-1, 1))
    homogeneous = np.hstack((u.reshape((-1, 1)) * z, v.reshape((-1, 1)) * z, z))
    points = (np.linalg.inv(intrinsics) @ homogeneous.T).T

    # Bounding-box centering followed by scaling into the unit ball (3% margin).
    bbox_center = (points.min(0) + points.max(0)) / 2
    points -= bbox_center
    radius = np.linalg.norm(points, axis=1).max()
    points /= radius * 1.03

    offset_dir = points + np.array([0.0, 0.0, 1.0])

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    cloud.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(
        radius=0.1, max_nn=30))
    normals = np.array(cloud.normals)
    # Flip normals with a negative dot product against -offset_dir.
    flip = (normals * -offset_dir).sum(axis=1) < 0
    normals[flip] *= -1

    # NOTE(review): points were divided by radius * 1.03, yet `scale` stores
    # 1 / radius -- confirm consumers expect the margin-free value.
    np.savez(target_path,
             vertices=points,
             vertex_normals=normals,
             offset=-bbox_center,
             scale=1 / radius)
    print(target_path)


@client.command(name="convert_dataset_to_npz")
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--output-directory", type=str, required=True)
def convert_dataset_to_npz(args: click.Arguments):
    """Convert every ``<category>/depth/*.png`` in a dataset to ``.npz``.

    Each depth image is processed in four steps:

    1. Back-project it with fixed pinhole intrinsics.
    2. Center it on its bounding-box center.
    3. Divide by ``1.03 * max_distance``, where ``max_distance`` is the
       largest distance from that center.
    4. Estimate normals with Open3D and flip them for a consistent
       orientation.

    The result is written to ``<output-directory>/<category>/<id>.npz`` with
    the keys ``vertices``, ``vertex_normals``, ``offset`` and ``scale``.
    Entries of ``--dataset-directory`` that are not directories are skipped.
    """
    dataset_directory = Path(args.dataset_directory)
    output_directory = Path(args.output_directory)
    os.makedirs(args.output_directory, exist_ok=True)

    # The intrinsics are the same for every image, so invert them only once.
    K = np.array([
        [481.20, 0, 319.5],
        [0, -480, 239.5],
        [0, 0, 1],
    ])
    K_inv = np.linalg.inv(K)

    for category_directory in dataset_directory.iterdir():
        # Skip stray files; split_dataset applies the same rule.
        if not category_directory.is_dir():
            continue
        print(category_directory)
        category_output_directory = output_directory / category_directory.name
        os.makedirs(str(category_output_directory), exist_ok=True)
        depth_directory = category_directory / "depth"
        for depth_path in depth_directory.glob("*.png"):
            # Same id derivation as split_dataset, so ids match the split files.
            file_id = depth_path.name.replace(".png", "")
            with Image.open(depth_path) as image:
                depth_data = np.array(image)

            xx, yy = np.meshgrid(np.arange(depth_data.shape[1]),
                                 np.arange(depth_data.shape[0]))
            z = depth_data.reshape((-1, 1))
            x = xx.reshape((-1, 1)) * z
            y = yy.reshape((-1, 1)) * z
            # Pixel coordinates scaled by depth, mapped into camera space.
            points_screen = np.concatenate((x, y, z), axis=1)
            points = (K_inv @ points_screen.T).T

            center = (points.min(0) + points.max(0)) / 2
            points -= center
            max_distance = np.linalg.norm(points, axis=1).max()
            points /= max_distance * 1.03

            # Copy so the z-shift below does not modify `points`, which are
            # saved as the vertices.
            direction = points.copy()
            direction[:, 2] += 1

            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(points)
            pcd.estimate_normals(
                search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1,
                                                                  max_nn=30))
            normals = np.array(pcd.normals)
            # Flip normals with a negative dot product against -direction.
            dot = (normals * -direction).sum(axis=1)
            normals[dot < 0] *= -1

            # NOTE(review): points were divided by max_distance * 1.03, yet
            # `scale` stores 1 / max_distance -- confirm consumers expect that.
            np.savez(category_output_directory / f"{file_id}.npz",
                     vertices=points,
                     vertex_normals=normals,
                     offset=-center,
                     scale=1 / max_distance)


@client.command(name="split_dataset")
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--split-ratio", type=float, default=0.8)
@click.argument("--split-seed", type=int, default=0)
@click.argument("--name", type=str, default="icl_nuim")
def split_dataset(args: click.Arguments):
    """Write reproducible per-category train/test id lists.

    For each category directory under ``--dataset-directory``:

    1. The ``depth/*.png`` files are sorted.
    2. They are shuffled with a ``random.Random`` seeded by ``--split-seed``.
    3. The first ``int(count * --split-ratio)`` become training ids and the
       rest become test ids.

    An id is the file name with ``.png`` removed. The results are written to
    ``<output-directory>/<name>_train.json`` and
    ``<output-directory>/<name>_test.json`` as
    ``{category: sorted list of ids}``. The total train and test counts are
    printed.

    Raises:
        ValueError: if ``--split-ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= args.split_ratio <= 1.0:
        raise ValueError(
            f"split ratio must be in [0, 1], got {args.split_ratio}")

    dataset_directory = Path(args.dataset_directory)
    output_directory = Path(args.output_directory)
    os.makedirs(output_directory, exist_ok=True)

    rst = random.Random(args.split_seed)

    data_train = {}
    data_test = {}
    for category_directory in sorted(dataset_directory.iterdir()):
        if not category_directory.is_dir():
            continue
        category = category_directory.name
        # glob() yields files in filesystem-dependent order. Sort first so the
        # seeded shuffle gives the same split on every machine.
        depth_file_list = sorted((category_directory / "depth").glob("*.png"))
        rst.shuffle(depth_file_list)

        train_size = int(len(depth_file_list) * args.split_ratio)
        data_train[category] = sorted(
            path.name.replace(".png", "")
            for path in depth_file_list[:train_size])
        data_test[category] = sorted(
            path.name.replace(".png", "")
            for path in depth_file_list[train_size:])

    num_data_train = sum(len(ids) for ids in data_train.values())
    num_data_test = sum(len(ids) for ids in data_test.values())
    print(num_data_train)
    print(num_data_test)

    with open(output_directory / f"{args.name}_train.json", "w",
              encoding="utf-8") as f:
        json.dump(data_train, f, indent=4, sort_keys=True)

    with open(output_directory / f"{args.name}_test.json", "w",
              encoding="utf-8") as f:
        json.dump(data_test, f, indent=4, sort_keys=True)


@client.command(name="view_pointclout")
@click.argument("--depth-path", type=str, required=True)
def view_pointclout(args: click.Arguments):
    """Display the normalized point cloud and normals of one depth image.

    Applies the same back-projection, normalization and normal orientation
    as the ``convert_*`` commands, then opens an Open3D viewer window. The
    command keeps its original spelling ("pointclout") so existing
    invocations keep working.
    """
    depth_path = Path(args.depth_path)
    with Image.open(depth_path) as image:
        depth_data = np.array(image)

    K = np.array([
        [481.20, 0, 319.5],
        [0, -480, 239.5],
        [0, 0, 1],
    ])
    K_inv = np.linalg.inv(K)
    xx, yy = np.meshgrid(np.arange(depth_data.shape[1]),
                         np.arange(depth_data.shape[0]))
    z = depth_data.reshape((-1, 1))
    x = xx.reshape((-1, 1)) * z
    y = yy.reshape((-1, 1)) * z
    # Pixel coordinates scaled by depth, mapped into camera space.
    points_screen = np.concatenate((x, y, z), axis=1)
    points = (K_inv @ points_screen.T).T

    center = (points.min(0) + points.max(0)) / 2
    points -= center
    max_distance = np.linalg.norm(points, axis=1).max()
    points /= max_distance * 1.03

    # Copy so the z-shift below does not move the displayed points. The cloud
    # shown then matches the normalized coordinates the converters save.
    direction = points.copy()
    direction[:, 2] += 1

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(
        radius=0.1, max_nn=30))
    normals = np.array(pcd.normals)
    # Flip normals with a negative dot product against -direction.
    dot = (normals * -direction).sum(axis=1)
    normals[dot < 0] *= -1
    pcd.normals = o3d.utility.Vector3dVector(normals)

    o3d.visualization.draw_geometries([pcd])


@client.command(name="view_data")
@click.argument("--data-path", type=str, required=True)
def view_data(args: click.Arguments):
    """Display the ``vertices``/``vertex_normals`` stored in an ``.npz`` file."""
    # Read the arrays and close the archive before the viewer opens, rather
    # than keeping the file handle open for the whole viewing session.
    with np.load(args.data_path) as data:
        vertices = data["vertices"]
        vertex_normals = data["vertex_normals"]

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(vertices)
    pcd.normals = o3d.utility.Vector3dVector(vertex_normals)
    o3d.visualization.draw_geometries([pcd])


if __name__ == "__main__":
    # Script entry point: invoke the `client` command group defined above.
    client()
