import concurrent.futures
import json
import math
import os
import random
import subprocess
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import implicit_geometric_regularization.click as click


def _split_list(array, segments):
    """Split ``array`` into ``segments`` contiguous, non-empty chunks.

    Chunk sizes differ by at most one: the first ``len(array) % segments``
    chunks get one extra element. Concatenating the chunks in order
    reproduces ``array``. The previous ceil-based chunking could leave
    trailing chunks empty despite the ``len(array) >= segments``
    precondition (e.g. 5 items into 4 segments gave sizes 2, 2, 1, 0).

    Args:
        array: Any sliceable sequence (list, tuple, NumPy array, ...).
        segments: Number of chunks; must satisfy 1 <= segments <= len(array).

    Returns:
        A list of ``segments`` slices of ``array``.

    Raises:
        ValueError: If ``segments`` is smaller than 1.
        AssertionError: If ``array`` has fewer than ``segments`` elements.
    """
    if segments < 1:
        raise ValueError(f"segments must be >= 1, got {segments}")
    assert len(array) >= segments
    base_size, remainder = divmod(len(array), segments)
    ret = []
    start = 0
    for index in range(segments):
        # The first `remainder` chunks absorb the leftover elements.
        stop = start + base_size + (1 if index < remainder else 0)
        ret.append(array[start:stop])
        start = stop
    return ret


@click.group()
def client():
    """Root command group; every sub-command in this file attaches via ``@client.command``."""


@client.command(name="view_surface_sampling")
@click.argument("--path", type=str, required=True)
def view_surface_sampling(args):
    """Load the point cloud at ``args.path``, print a summary, and display it.

    Prints the Open3D point-cloud object and whether it carries normals
    before opening the viewer.
    """
    cloud = o3d.io.read_point_cloud(args.path)
    print(cloud)
    has_normals = cloud.has_normals()
    print(has_normals)
    geometries = [cloud]
    o3d.visualization.draw_geometries(geometries)


@client.command(name="view_sdf_samples")
@click.argument("--path", type=str, required=True)
def view_sdf_samples(args):
    """Visualize SDF samples stored in an ``.npz`` file as a colored point cloud.

    The file's ``pos`` and ``neg`` arrays are stacked; columns 0-2 are used
    as xyz coordinates and column 3 as the SDF value. Only samples with
    x <= 0 are shown (presumably to expose a cross-section). Each point is
    colored with the ``RdBu`` colormap.
    """
    samples = np.load(args.path)
    stacked = np.vstack((samples["pos"], samples["neg"]))

    # Row mask for the x <= 0 half-space.
    keep = stacked[:, 0] <= 0
    xyz = stacked[keep, :3]
    distances = stacked[keep, 3]

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)

    colormap = plt.get_cmap("RdBu")
    # Affine map that sends SDF 0 to the colormap midpoint 0.5 and the
    # range [-0.25, 0.25] onto [0, 1].
    colors = colormap(2 * distances + 0.5)[:, :3]  # drop the alpha channel
    cloud.colors = o3d.utility.Vector3dVector(colors)

    print(xyz.shape, colors.shape)
    print(cloud.has_colors())
    o3d.visualization.draw_geometries([cloud], "sdf")


# Per-model directory names (presumably ShapeNet model ids) that
# `split_dataset` and `split_model` skip when building splits. Kept as a
# list for backward compatibility; convert to a set before doing many
# membership tests. Keep entries unique.
broken_model_ids = [
    "aaaba1bbe037d3b1e406974af41e8842",
    "ac35b0d3d4b33477e76bc197b3a3ffc0",
    "d6f81af7b34e8da814038d588fd1342f",
    "dcba7668017df61ef51f77a6d7299806",
    "dfc0f60a1b11ab7c388f6c7a9d3e1552",
    "f59e8dac9451d165a68447d7f5a7ee42",
    "12ec19e85b31e274725f67267e31c89",
    "b2a585ba5f0b4a25e76bc197b3a3ffc0",
    "b7eecafd15147c01fabd49ee8315e8b9",
    "b8efc08bc8eab52a330a170e9ceed373",
    "5d636123af31f735e76bc197b3a3ffc0",
    "31b201b7346e6cd15e9e2656aff7dd5b",
    "79ee45aa6b0c86064725f67267e31c89",
    "7de119e0eb48a11e3c8d0fdfb1cc2535",
    "a5abf524f9b08432f51f77a6d7299806",
    "94235090b6d5dbfae76bc197b3a3ffc0",
    "998c2b635922ace93c8d0fdfb1cc2535",
    "c0ac5dea15f961c9e76bc197b3a3ffc0",
    "2d97e5970822beddde03ab2a27ba7531",
    "adcc6534d6db1ff9dff990aa66c50fe0",
    "e4bd6dda8d29106ca16af3198c99de08",
    "c5a02d586ea431a1e76bc197b3a3ffc0",
    "eaaa14d9fab71afa5b903ba10d2ec446",
    "bda8c00b62528346ad8a0ee9b106700e",
    "c715a29db7d888dd23f9e4320fcb3664",
    "2f30f402c00a166368068eb0ef40fbb1",
    "5e515b18ed17a418b056c98b2e5e5e4e",
    "bafb9c9602d00b3e50b42dfb503f5a87",
    "48d58a4f43125e7f34282af0231ccf9c",
    "8facbe9d9f4da233d15a5887ec2183c9",
    "ce82dbe1906e605d9b678eaf6920cd86",
    "15a95cddbc40994beefc4457af135dc1",
    "15bcc664de962b04e76bc197b3a3ffc0",
    "bee929929700e99fad8a0ee9b106700e",
    "c06b5a7aa4557182f51f77a6d7299806",
    "187804891c09be04f1077877e3a9d430",
    "bbddf00c2dda43a2f21cf17406f1f25",
    "d7f1a22419268800e76bc197b3a3ffc0",
    "c487441af8ac37a733718c332b0d7bfd",
    "c535629f9661293dc16ef5c633c71b56",
    "c70c1a6a0e795669f51f77a6d7299806",
    "1a44dd6ee873d443da13974b3533fb59",
    "c833ef6f882a4b2a14038d588fd1342f",
    "dc9a7d116351f2cca16af3198c99de08",
    "21b8b1e51237f4aee76bc197b3a3ffc0",
    "de96be0a27fe1610d40c07d3c15cc681",
    "249e0936ae06383ab056c98b2e5e5e4e",
    "abe557fa1b9d59489c81f0389df0e98a",
    "e2efc1a73017a3d2e76bc197b3a3ffc0",
    "e7580c72525b4bb1cc786970133d7717",
    "5926d3296767ab28543df75232f6ff2b",
    "59e852f315216f95ba9df3ea0397b1a6",
    "5aefcf6b38e180f06df11193c72a632e",
    "31f09d77217fc9e0e76bc197b3a3ffc0",
    "f078a5940280c0a22c6c98851414a9d8",
    "ebf8166bacd6759399513f98ce033426",
    "376c99ec94d34cf4e76bc197b3a3ffc0",
    "8464de18cd5d14e138435fc2a8dffe1b",
    "69240d39dfbc47a0d15a5887ec2183c9",
    "f835366205ba8afd9b678eaf6920cd86",
    "922380f231a342cf388f6c7a9d3e1552",
    "a5d68126acbd43395e9e2656aff7dd5b",
    "94eae2316754482627d265f13671170a",
    "a4910da0271b6f213a7e932df8806f9e",
    "fe4c20766801dc98bc2e5d5fd57660fe",
    "40ee6a47e485cb4d41873672d11706f4",
    "41283ae87a29fbbe76bc197b3a3ffc0",
    "15bc36a3ce59163bce8584f8b28da0ba",
    "fe8b246b47321320c3bd24f986301745",
    "ff3581996365bdddc3bd24f986301745",
    "abf04f17d2c84a160e37b3f76995f8b",
    "fcdbba7127ad58a84155fcb773414092",
    "b976a48c015d6ced5e9e2656aff7dd5b",
    "1196ffab55e431e11b17743c18fb63dc",
    "5026668bb2bcedebccfcde790fc2f661",
    "b13143d5f71e38d24738aee9841818fe",
    "2c1af98d2058a8056588620c25b809f9",
    "52eaeaf85846d638e76bc197b3a3ffc0",
    "3e52f25b8f66d9a8adf3df9d9e46d0",
    "f97011a0bae2b4062d1c72b9dec4baa1",
    "de45798ef57fe2d131b4f9e586a6d334",
    "581d698b6364116e83e95e8523a2fbf3",
    "2d08a64e4a257e007135fc51795b4038",
    "2dd729a07206d1f5746cec00e236149d",
    "466ea85bb4653ba3a715ae636b111d77",
    "6ee01e861f10b1f044175b4dddf5be08",
    "c0d736964ddc268c9fb3e3631a88cdab",
    "a4d779cf204ff052894c6619653a3264",
    "809d5384f84d55273a11565e5be9cf53",
    "639d99161524c7dd54e6ead821103617",
    "3c4ed9c8f76c7a5ef51f77a6d7299806",
    "1788d15f49a57570a0402637f097180",
    "9a688545112c2650ca703e831bf56f93",
    "1545a13dc5b12f51f77a6d7299806",
    "1f2a8562a2de13a2c29fde65e51f52cb",
    "3f4f6f64f5ae57a14038d588fd1342f",
    "42db4f765f4e6be414038d588fd1342f",
    "b0952767eeb21b88e2b075a80e28c81b",
    "221e8ea6bdcc1db614038d588fd1342f",
    "22b11483d6a2461814038d588fd1342f",
    "6e77d23b324ddbd65661fcc99c72bf48",
    "453be11e44a230a0f51f77a6d7299806",
    "cdc3762d2846133adc26ec30fe28341a",
    "d4f4b5bf712a96b13679ccb6aaef8b00",
    "d707228baece2270c473585373fc1fd0",
    "539d5ef3031544fb63c6a0e477a59b9f",
    "32859e7af79f494a14038d588fd1342f",
    "34d7a91d639613f6f51f77a6d7299806",
    "d9f0f7cff584b60d826faebfb5cddf3c",
    "3a69f7f6729d8d48f51f77a6d7299806",
    "3eb9e07793635b30f51f77a6d7299806",
    "4385e447533cac72d1c72b9dec4baa1",
    "4c29dcad235ff80df51f77a6d7299806",
    "7f39803c32028449e76bc197b3a3ffc0",
    "5875ca8510373873f51f77a6d7299806",
    "81d84727a6da7ea7bb8dc0cd2a40a9a4",
    "a9957cf39fdd61fc612f7163ca95602",
    "5af850643d64c2621b17743c18fb63dc",
    "5b51df75df88c639f51f77a6d7299806",
    "703c1f85dc01baad9fb3e3631a88cdab",
    "5db74dcfc73a3ea2f2ca754af3aaf35",
    "621dab02dc0ac842e7891ff53b0e70d",
    "64ee5d22281ef431de03ab2a27ba7531",
    "62127325480bec8d2c6c98851414a9d8",
    "6e213a2ecc95c30544175b4dddf5be08",
    "64871dc28a21843ad504e40666187f4e",
    "6498a5053d12addb91a2a5174703986b",
    "64ef0e07129b6bc4c3bd24f986301745",
    "7b1d07d932ca5890f51f77a6d7299806",
    "3d51b8ad6b622c75dd5c7f6e6acea5c1",
    "7e832bc481a3335614038d588fd1342f",
    "904ad336345205dce76bc197b3a3ffc0",
    "8159bdc913cd8a23debd258f4352e626",
    "9377b1b5c83bb05ce76bc197b3a3ffc0",
    "6f194ba6ba254aacf51f77a6d7299806",
    "b9bf493040c8b434f3e39f2e17005efc",
    "97a137cc6688a07c90a9ce3e4b15521e",
    "1d94afb9894bf975e76bc197b3a3ffc0",
    "2af98dcf936576b155f28299c0ff52b7",
    "2e8f1b6cb9b4f568316a315354726289",
    "77dcd07d59503f1014038d588fd1342f",
    "78261b526d28a436cc786970133d7717",
    "790d554d7f9b040299513f98ce033426",
    "a088285efee5f0dbbc6a6acad56465f2",
    "e36cda06eed31d11d816402a0e81d922",
    "a13c36acbc45184de76bc197b3a3ffc0",
    "e5ede813e9f07ee4f3e39f2e17005efc",
    "ba0f7f4b5756e8795ae200efe59d16d0",
    "a333abca08fceb434eec4d2d414b38e0",
    "595e48c492a59d3029404a50338e24e7",
]


@client.command(name="split_dataset")
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--split-ratio", type=float, default=0.9)
@click.argument("--split-seed", type=int, default=0)
@click.argument("--name", type=str, default="55_categories")
def split_dataset(args: click.Arguments):
    """Randomly split every category of a dataset into train/test model lists.

    For each category sub-directory of ``args.dataset_directory`` (in sorted
    order), its entries are shuffled with a ``random.Random`` seeded by
    ``args.split_seed``. The first ``int(n * args.split_ratio)`` entries form
    the train split and the rest the test split. Non-directories and ids in
    ``broken_model_ids`` are dropped after splitting, so the realized ratio
    is approximate.

    Writes ``{name}_train.json`` and ``{name}_test.json`` into
    ``args.output_directory``. Each maps category id -> sorted model ids.
    """
    dataset_directory = Path(args.dataset_directory)
    output_directory = Path(args.output_directory)
    os.makedirs(output_directory, exist_ok=True)

    rst = random.Random(args.split_seed)
    excluded_model_ids = set(broken_model_ids)  # O(1) membership tests

    include_categories = None
    exclude_categories = None

    # include_categories = [
    #     "02691156", "02933112", "02958343", "03001627", "03636649", "04256520",
    #     "04379243", "04530566"
    # ]
    # include_categories = ["02958343"]
    # exclude_categories = ["02958343"]

    def valid_model_ids(model_directories):
        # Keep real model directories that are not known to be broken.
        return sorted(
            model_directory.name
            for model_directory in model_directories
            if model_directory.is_dir()
            and model_directory.name not in excluded_model_ids)

    data_train = {}
    data_test = {}
    for category_directory in sorted(dataset_directory.iterdir()):
        if not category_directory.is_dir():
            continue
        category = category_directory.name  # synsetId
        if include_categories is not None and category not in include_categories:
            continue
        if exclude_categories is not None and category in exclude_categories:
            continue

        # Sort before shuffling: iterdir() order is filesystem-dependent, so
        # without this the seeded shuffle would not be reproducible.
        model_directory_list = sorted(category_directory.iterdir())
        rst.shuffle(model_directory_list)

        train_size = int(len(model_directory_list) * args.split_ratio)
        data_train[category] = valid_model_ids(model_directory_list[:train_size])
        # Previously this used the train list, making the test split a copy
        # of the train split.
        data_test[category] = valid_model_ids(model_directory_list[train_size:])

    with open(output_directory / f"{args.name}_train.json", "w") as f:
        json.dump(data_train, f, indent=4, sort_keys=True)

    with open(output_directory / f"{args.name}_test.json", "w") as f:
        json.dump(data_test, f, indent=4, sort_keys=True)


@client.command(name="split_model")
@click.argument("--category-directory", type=str, required=True)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--split-ratio", type=float, default=0.9)
@click.argument("--split-seed", type=int, default=0)
@click.argument("--name", type=str, default="chairs")
def split_model(args: click.Arguments):
    """Randomly split the models of a single category into train/test lists.

    The entries of ``args.category_directory`` are shuffled with a
    ``random.Random`` seeded by ``args.split_seed``. The first
    ``int(n * args.split_ratio)`` entries form the train split and the rest
    the test split. After splitting, an entry is dropped if:

    * it is not a directory,
    * its name is in ``broken_model_ids``, or
    * it lacks ``models/model_normalized.obj``.

    Writes ``{name}_train.json`` and ``{name}_test.json`` into
    ``args.output_directory``. Each maps the category directory name to
    sorted model ids.
    """
    category_directory = Path(args.category_directory)
    output_directory = Path(args.output_directory)
    os.makedirs(output_directory, exist_ok=True)

    rst = random.Random(args.split_seed)
    excluded_model_ids = set(broken_model_ids)  # O(1) membership tests

    def usable_model_ids(model_directories):
        # Keep non-broken model directories that actually contain a mesh.
        return sorted(
            model_directory.name
            for model_directory in model_directories
            if model_directory.is_dir()
            and model_directory.name not in excluded_model_ids
            and (model_directory / "models" / "model_normalized.obj").exists())

    category = category_directory.name  # synsetId
    # Sort before shuffling: iterdir() order is filesystem-dependent, so
    # without this the seeded shuffle would not be reproducible.
    model_directory_list = sorted(category_directory.iterdir())
    rst.shuffle(model_directory_list)

    train_size = int(len(model_directory_list) * args.split_ratio)
    data_train = {category: usable_model_ids(model_directory_list[:train_size])}
    data_test = {category: usable_model_ids(model_directory_list[train_size:])}

    with open(output_directory / f"{args.name}_train.json", "w") as f:
        json.dump(data_train, f, indent=4, sort_keys=True)

    with open(output_directory / f"{args.name}_test.json", "w") as f:
        json.dump(data_test, f, indent=4, sort_keys=True)


@client.command(name="take_split")
@click.argument("--split-path", type=str, required=True)
@click.argument("--category-id", type=str, required=True)
@click.argument("--output-path", type=str, required=True)
def take_split(args: click.Arguments):
    """Copy one category's model-id list from a split JSON into a new JSON file.

    Reads ``args.split_path`` and selects the entry keyed by
    ``args.category_id``; a missing key raises ``KeyError``. Writes
    ``{category_id: model_ids}`` to ``args.output_path``.
    """
    with open(args.split_path) as source:
        selected = {args.category_id: json.load(source)[args.category_id]}
    with open(args.output_path, "w") as destination:
        json.dump(selected, destination, indent=4, sort_keys=True)


def sample_point_cloud(executable: str, args: dict):
    """Run the surface-sampling binary once and wait for it to finish.

    Args:
        executable: Path to the binary to run.
        args: Mapping of command-line flag -> value. Each pair is appended
            to the command as ``flag str(value)``, in insertion order.

    Returns:
        The process exit code (0 on success). A non-zero code is also
        printed so that failed conversions are not silently ignored.
    """
    command = [executable]
    for key, value in args.items():
        command.extend((key, str(value)))
    # subprocess.run waits for completion like Popen().wait(); no shell involved.
    returncode = subprocess.run(command).returncode
    if returncode != 0:
        print(f"command failed with exit code {returncode}: {command}")
    return returncode


@client.command()
@click.argument("--executable",
                type=str,
                required=True,
                help="path to sample_surface_points binary")
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--split-path", type=str, required=True)
@click.argument("--skip-existing-file", is_flag=True)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--num-point-samples", type=int, default=200000)
@click.argument("--num-threads", type=int, default=1)
def convert(args):
    """Sample surface point clouds for every model listed in a split file.

    For each ``category -> [model_id, ...]`` entry of the JSON at
    ``args.split_path``, runs ``args.executable`` via ``sample_point_cloud``.
    The input is ``<dataset>/<category>/<model_id>/models/model_normalized.obj``
    and the output is ``<output>/<category>/<model_id>/point_cloud.npz``.
    Jobs run on a thread pool of ``args.num_threads`` workers. With
    ``--skip-existing-file``, models whose npz already exists are skipped.
    """
    dataset_directory = Path(args.dataset_directory)
    output_directory = Path(args.output_directory)
    os.makedirs(output_directory, exist_ok=True)
    assert os.path.exists(args.executable)

    split_path = Path(args.split_path)
    assert split_path.is_file()

    with open(split_path) as f:
        split = json.load(f)

    query_list = []
    for category, model_id_list in split.items():
        for model_id in model_id_list:
            obj_path = (dataset_directory / category / model_id / "models" /
                        "model_normalized.obj")
            assert obj_path.exists()

            model_output_directory = output_directory / category / model_id
            os.makedirs(model_output_directory, exist_ok=True)
            npz_path = model_output_directory / "point_cloud.npz"
            if args.skip_existing_file and npz_path.exists():
                continue

            query_list.append({
                "--input-mesh-path": obj_path,
                "--output-npz-path": npz_path,
                "--num-samples": args.num_point_samples
            })

    # Leaving the `with` block already shuts the pool down and waits for all
    # jobs, so no explicit shutdown() is needed.
    with concurrent.futures.ThreadPoolExecutor(
            max_workers=int(args.num_threads)) as executor:
        futures = [
            executor.submit(sample_point_cloud, args.executable, query)
            for query in query_list
        ]
        # result() re-raises any exception from a worker thread; previously
        # the futures were discarded and such errors vanished silently.
        for future in concurrent.futures.as_completed(futures):
            future.result()


@client.command(name="view_npz")
@click.argument("--path", type=str, required=True)
def view_npz(args):
    """Display a point cloud with normals read from an ``.npz`` file.

    The file must contain equally long ``vertices`` and ``vertex_normals``
    arrays. The path and both array shapes are printed before the viewer
    opens.
    """
    path = Path(args.path)
    assert path.exists()
    print(path)
    arrays = np.load(path)
    points, normals = arrays["vertices"], arrays["vertex_normals"]
    assert len(points) == len(normals)
    for array in (points, normals):
        print(array.shape)
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    cloud.normals = o3d.utility.Vector3dVector(normals)
    o3d.visualization.draw_geometries([cloud])


@client.command(name="view_obj")
@click.argument("--path", type=str, required=True)
def view_obj(args):
    """Load an OBJ mesh, center it at the origin, and display it.

    The mesh is translated so that the midpoint of its axis-aligned
    bounding box lands at the origin. The applied offset is printed.
    """
    obj_path = Path(args.path)
    assert obj_path.exists()
    print(obj_path)
    mesh = o3d.io.read_triangle_mesh(str(obj_path))
    vertices = np.array(mesh.vertices)
    # Renamed from min/max to avoid shadowing the builtins.
    bbox_min = np.min(vertices, axis=0, keepdims=True)
    bbox_max = np.max(vertices, axis=0, keepdims=True)
    # Bounding-box midpoint. The old `(min - max) / 2` was minus half the
    # extent, which does not center the mesh.
    center = (bbox_min + bbox_max) / 2
    vertices -= center
    print(center)
    mesh.vertices = o3d.utility.Vector3dVector(vertices)
    o3d.visualization.draw_geometries([mesh])


@client.command(name="view_npz_dataset")
@click.argument("--dataset-directory", type=str, required=True)
def view_npz_dataset(args):
    """Display every ``point_cloud.npz`` under a dataset directory in turn.

    Expects ``<dataset>/<category>/<model>/point_cloud.npz`` and skips models
    without that file. Each file must contain equally long ``vertices`` and
    ``vertex_normals`` arrays.
    """
    dataset_directory = Path(args.dataset_directory)
    for category_directory in sorted(dataset_directory.iterdir()):
        if not category_directory.is_dir():
            continue
        # Sorted so the browsing order is deterministic.
        for model_directory in sorted(category_directory.iterdir()):
            if not model_directory.is_dir():
                continue
            npz_path = model_directory / "point_cloud.npz"
            if not npz_path.exists():
                continue
            print(npz_path)
            # Read both arrays, then close the archive before the (blocking)
            # viewer opens.
            with np.load(npz_path) as data:
                vertices = data["vertices"]
                vertex_normals = data["vertex_normals"]
            assert len(vertices) == len(vertex_normals)
            print(vertices.shape)
            print(vertex_normals.shape)
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(vertices)
            pcd.normals = o3d.utility.Vector3dVector(vertex_normals)
            o3d.visualization.draw_geometries([pcd])


@client.command(name="view_obj_dataset")
@click.argument("--dataset-directory", type=str, required=True)
def view_obj_dataset(args):
    """Display every ``models/model_normalized.obj`` under a dataset in turn.

    Expects ``<dataset>/<category>/<model>/models/model_normalized.obj`` and
    skips models without that file. Triangle and vertex normals are computed
    before display.
    """
    dataset_directory = Path(args.dataset_directory)
    for category_directory in sorted(dataset_directory.iterdir()):
        if not category_directory.is_dir():
            continue
        # Sorted so the browsing order is deterministic.
        for model_directory in sorted(category_directory.iterdir()):
            if not model_directory.is_dir():
                continue
            obj_path = model_directory / "models" / "model_normalized.obj"
            if not obj_path.exists():
                continue
            print(obj_path)
            mesh = o3d.io.read_triangle_mesh(str(obj_path))
            print(mesh)
            mesh.compute_triangle_normals()
            mesh.compute_vertex_normals()
            o3d.visualization.draw_geometries([mesh])


if __name__ == "__main__":
    # Entry point: invoke the `client` command group when run as a script.
    client()
