import concurrent.futures
import json
import math
import os
import random
import subprocess
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import deep_sdf.click as click


def _split_list(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


@click.group()
def client():
    """Command group that every subcommand in this script registers with."""


@client.command(name="split_dataset")
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--split-ratio", type=float, default=0.9)
@click.argument("--split-seed", type=int, default=0)
@click.argument("--name", type=str, default="55_categories")
def split_dataset(args: click.Arguments):
    """Split every category of a dataset into train/test model-id lists.

    Expects the layout ``<dataset_directory>/<category>/<model_id>/``.  Within
    each category the model directories are shuffled with a seeded RNG and the
    first ``int(n * split_ratio)`` go to train, the rest to test.

    Writes ``<output_directory>/<name>_train.json`` and
    ``<output_directory>/<name>_test.json``, each mapping
    ``category -> sorted list of model ids``.
    """
    dataset_directory = Path(args.dataset_directory)
    output_directory = Path(args.output_directory)
    os.makedirs(output_directory, exist_ok=True)

    rst = random.Random(args.split_seed)

    # Optional category filters (None = no filtering).  Example presets:
    # include_categories = [
    #     "02691156", "02933112", "02958343", "03001627", "03636649", "04256520",
    #     "04379243", "04530566"
    # ]
    # include_categories = ["02958343"]
    # exclude_categories = ["02958343"]
    include_categories = None
    exclude_categories = None

    data_train = {}
    data_test = {}
    for category_directory in sorted(dataset_directory.iterdir()):
        if not category_directory.is_dir():
            continue
        category = category_directory.name  # synsetId
        if include_categories is not None and category not in include_categories:
            continue
        if exclude_categories is not None and category in exclude_categories:
            continue

        # Keep only model directories so split_ratio applies to real models,
        # and sort before shuffling: iterdir() order is filesystem-dependent,
        # so without sorting the same seed could give different splits.
        model_id_list = sorted(
            model_directory.name
            for model_directory in category_directory.iterdir()
            if model_directory.is_dir())
        rst.shuffle(model_id_list)

        train_size = int(len(model_id_list) * args.split_ratio)
        data_train[category] = sorted(model_id_list[:train_size])
        data_test[category] = sorted(model_id_list[train_size:])

    with open(output_directory / f"{args.name}_train.json", "w") as f:
        json.dump(data_train, f, indent=4, sort_keys=True)

    with open(output_directory / f"{args.name}_test.json", "w") as f:
        json.dump(data_test, f, indent=4, sort_keys=True)


@client.command(name="split_model")
@click.argument("--category-directory", type=str, required=True)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--split-ratio", type=float, default=0.9)
@click.argument("--split-seed", type=int, default=0)
@click.argument("--name", type=str, default="chairs")
def split_model(args: click.Arguments):
    """Split the models of a single category directory into train/test lists.

    Only model directories that contain ``models/model_normalized.obj`` are
    considered.  They are shuffled with a seeded RNG; the first
    ``int(n * split_ratio)`` go to train and the rest to test.

    Writes ``<output_directory>/<name>_train.json`` and
    ``<output_directory>/<name>_test.json``, each of the form
    ``{category: [model_id, ...]}`` with the ids sorted.
    """
    category_directory = Path(args.category_directory)
    output_directory = Path(args.output_directory)
    os.makedirs(output_directory, exist_ok=True)

    rst = random.Random(args.split_seed)
    category = category_directory.name  # synsetId

    # Filter before splitting so split_ratio is applied to usable models only,
    # and sort before shuffling because iterdir() order is filesystem-dependent
    # (otherwise the same seed could produce different splits).
    model_id_list = sorted(
        model_directory.name
        for model_directory in category_directory.iterdir()
        if model_directory.is_dir()
        and (model_directory / "models" / "model_normalized.obj").exists())
    rst.shuffle(model_id_list)

    train_size = int(len(model_id_list) * args.split_ratio)
    data_train = {category: sorted(model_id_list[:train_size])}
    data_test = {category: sorted(model_id_list[train_size:])}

    with open(output_directory / f"{args.name}_train.json", "w") as f:
        json.dump(data_train, f, indent=4, sort_keys=True)

    with open(output_directory / f"{args.name}_test.json", "w") as f:
        json.dump(data_test, f, indent=4, sort_keys=True)


@client.command(name="take_split")
@click.argument("--split-path", type=str, required=True)
@click.argument("--category-id", type=str, required=True)
@click.argument("--output-path", type=str, required=True)
def take_split(args: click.Arguments):
    """Copy one category's model-id list from a split file into a new file.

    The output JSON has the single key ``category_id``.  A category that is
    absent from the input split raises ``KeyError``.
    """
    with open(args.split_path) as split_file:
        selected_ids = json.load(split_file)[args.category_id]

    with open(args.output_path, "w") as out_file:
        json.dump({args.category_id: selected_ids}, out_file,
                  indent=4, sort_keys=True)


@client.command(name="find_missing_models")
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--npz-directory", type=str, required=True)
@click.argument("--output-directory", type=str, required=True)
def find_missing_models(args: click.Arguments):
    """Record which dataset models have no ``sdf.npz`` in ``npz_directory``.

    For every ``<dataset_directory>/<category>/<model_id>/`` directory, checks
    for ``<npz_directory>/<category>/<model_id>/sdf.npz``.  Prints each
    category name as it is scanned, then the total number of missing models.
    Writes ``<output_directory>/missing.json`` mapping
    ``category -> sorted missing model ids``.
    """
    dataset_directory = Path(args.dataset_directory)
    npz_directory = Path(args.npz_directory)
    output_directory = Path(args.output_directory)
    os.makedirs(output_directory, exist_ok=True)

    missing = {}
    for category_directory in sorted(dataset_directory.iterdir()):
        if not category_directory.is_dir():
            continue
        category = category_directory.name  # synsetId
        print(category)
        missing[category] = sorted(
            model_directory.name
            for model_directory in category_directory.iterdir()
            if model_directory.is_dir() and not (
                npz_directory / category / model_directory.name / "sdf.npz"
            ).exists())

    print(sum(len(model_ids) for model_ids in missing.values()))

    with open(output_directory / "missing.json", "w") as f:
        json.dump(missing, f, indent=4, sort_keys=True)


def sample_point_cloud(executable: str, args: dict) -> int:
    """Run ``executable`` with ``args`` expanded into command-line flags.

    Each ``key: value`` pair in ``args`` becomes two argv entries,
    ``key`` and ``str(value)``, in dict order.  The call blocks until the
    process exits.

    Returns:
        The process exit code.  A non-zero code is also reported on stdout so
        failed conversions are not silently ignored.
    """
    command = [executable]
    for key, value in args.items():
        command.extend((key, str(value)))
    # List form with the default shell=False: mesh paths are passed verbatim
    # and never interpreted by a shell.
    completed = subprocess.run(command)
    if completed.returncode != 0:
        print(f"{executable} exited with code {completed.returncode}: {command}")
    return completed.returncode


@client.command()
@click.argument("--executable",
                type=str,
                required=True,
                help="path to sample_sdf binary")
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--split-path", type=str, required=True)
@click.argument("--skip-existing-file", is_flag=True)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--num-point-samples", type=int, default=500000)
@click.argument("--num-threads", type=int, default=1)
def convert(args):
    """Run the external SDF sampler on every mesh listed in a split file.

    For each ``category -> [model_id, ...]`` entry in the split JSON, the
    input mesh is ``<dataset_directory>/<category>/<model_id>/models/
    model_normalized.obj`` and the output is ``<output_directory>/<category>/
    <model_id>/sdf.npz``.  Jobs run on a thread pool of ``num_threads``
    workers; with ``--skip-existing-file`` models whose ``sdf.npz`` already
    exists are skipped.

    Raises:
        FileNotFoundError: if the executable, the split file, or any listed
            mesh does not exist (checked before any job is started).
    """
    dataset_directory = Path(args.dataset_directory)
    output_directory = Path(args.output_directory)
    os.makedirs(output_directory, exist_ok=True)
    # Explicit checks instead of assert, which is stripped under `python -O`.
    if not os.path.exists(args.executable):
        raise FileNotFoundError(f"sampler executable not found: {args.executable}")

    split_path = Path(args.split_path)
    if not split_path.is_file():
        raise FileNotFoundError(f"split file not found: {split_path}")

    with open(split_path) as f:
        split = json.load(f)

    query_list = []
    for category, model_id_list in split.items():
        for model_id in model_id_list:
            obj_path = dataset_directory / category / model_id / "models" / "model_normalized.obj"
            if not obj_path.exists():
                raise FileNotFoundError(f"mesh not found: {obj_path}")

            model_output_directory = output_directory / category / model_id
            os.makedirs(model_output_directory, exist_ok=True)
            npz_path = model_output_directory / "sdf.npz"
            if args.skip_existing_file and npz_path.exists():
                continue

            query_list.append({
                "--input-mesh-path": obj_path,
                "--output-npz-path": npz_path,
                "--num-samples": args.num_point_samples
            })

    num_failures = 0
    # Leaving the `with` block waits for all jobs; no explicit shutdown needed.
    with concurrent.futures.ThreadPoolExecutor(
            max_workers=int(args.num_threads)) as executor:
        future_to_query = {
            executor.submit(sample_point_cloud, args.executable, query): query
            for query in query_list
        }
        # A bare submit() drops any exception raised in the worker; collect
        # and report them instead of losing them silently.
        for future in concurrent.futures.as_completed(future_to_query):
            error = future.exception()
            if error is not None:
                num_failures += 1
                mesh_path = future_to_query[future]["--input-mesh-path"]
                print(f"failed to convert {mesh_path}: {error!r}")

    if num_failures:
        print(f"{num_failures} of {len(query_list)} conversions raised an error")


@client.command(name="view_npz")
@click.argument("--path", type=str, required=True)
def view_npz(args):
    """Show the x > 0 half of one SDF sample file as a colored point cloud.

    Reads ``positive_sdf_samples`` and ``negative_sdf_samples`` from the npz
    file.  Each row's first three columns are used as the point position and
    the fourth as the SDF value.  Points are colored with the ``RdBu``
    colormap.
    """
    npz_path = Path(args.path)
    assert npz_path.exists()
    print(npz_path)
    archive = np.load(npz_path)
    positive = archive["positive_sdf_samples"]
    negative = archive["negative_sdf_samples"]
    print(positive.shape)
    print(negative.shape)

    samples = np.vstack((positive, negative))
    # Keep only the x > 0 half so the cross-section is visible.
    keep = samples[:, 0] > 0
    xyz = samples[:, :3][keep]
    sdf_values = samples[:, 3][keep]

    # 2 * sdf + 0.5 sends sdf = 0 to the colormap midpoint; sdf values in
    # [-0.25, 0.25] span the full map.
    colors = plt.get_cmap("RdBu")(2 * sdf_values + 0.5)[:, :3]

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)
    cloud.colors = o3d.utility.Vector3dVector(colors)
    print(xyz.shape, colors.shape)
    print(cloud.has_colors())
    o3d.visualization.draw_geometries([cloud], "sdf")


@client.command(name="view_obj")
@click.argument("--path", type=str, required=True)
def view_obj(args):
    """Display an OBJ mesh translated so its bounding-box center is at the origin.

    Prints the path and the computed center (shape ``(1, 3)``) before opening
    the viewer.

    Raises:
        ValueError: if no vertices could be read from the file.
    """
    obj_path = Path(args.path)
    assert obj_path.exists()
    print(obj_path)
    mesh = o3d.io.read_triangle_mesh(str(obj_path))
    # np.array copies, so the mesh's own buffer is not modified in place.
    vertices = np.array(mesh.vertices)
    if vertices.size == 0:
        raise ValueError(f"no vertices read from {obj_path}")
    bbox_min = vertices.min(axis=0, keepdims=True)
    bbox_max = vertices.max(axis=0, keepdims=True)
    # Midpoint of the axis-aligned bounding box.
    center = (bbox_min + bbox_max) / 2
    vertices -= center
    print(center)
    mesh.vertices = o3d.utility.Vector3dVector(vertices)
    o3d.visualization.draw_geometries([mesh])


@client.command(name="view_npz_dataset")
@click.argument("--dataset-directory", type=str, required=True)
def view_npz_dataset(args):
    """View every ``sdf.npz`` in a dataset as a colored point cloud, one by one.

    For each ``<dataset_directory>/<category>/<model_id>/sdf.npz`` this stacks
    the positive and negative samples and keeps the rows with x <= 0.  Each
    row's first three columns are the position and the fourth is the SDF value
    (colored with ``RdBu``).  It calls ``draw_geometries`` once per file.
    """
    dataset_directory = Path(args.dataset_directory)
    cmap = plt.get_cmap("RdBu")  # loop-invariant

    for category_directory in sorted(dataset_directory.iterdir()):
        if not category_directory.is_dir():
            continue
        for model_directory in category_directory.iterdir():
            if not model_directory.is_dir():
                continue
            npz_path = model_directory / "sdf.npz"
            if not npz_path.exists():
                continue
            print(npz_path)
            # Close each archive promptly; this loop may visit many files.
            with np.load(npz_path) as data:
                pos = data["positive_sdf_samples"]
                neg = data["negative_sdf_samples"]
            print(pos.shape)
            print(neg.shape)
            sdf = np.vstack((pos, neg))
            points = sdf[:, :3]
            sdf = sdf[:, 3]
            # NOTE(review): view_npz keeps x > 0 instead; confirm which half is intended.
            indices = np.where(points[:, 0] <= 0)[0]
            points = points[indices]
            sdf = sdf[indices]
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(points)

            # 2 * sdf + 0.5 sends sdf = 0 to the colormap midpoint.
            normalized_sdf = 2 * sdf + 0.5
            rgba = cmap(normalized_sdf)
            colors = rgba[:, :3]
            pcd.colors = o3d.utility.Vector3dVector(colors)
            print(points.shape, colors.shape)
            print(pcd.has_colors())
            o3d.visualization.draw_geometries([pcd], "sdf")


@client.command(name="check_nan")
@click.argument("--dataset-directory", type=str, required=True)
def check_nan(args):
    """Report ``sdf.npz`` files under a dataset that contain NaN values.

    Walks ``<dataset_directory>/<category>/<model_id>/sdf.npz``.  For each of
    ``positive_sdf_samples`` and ``negative_sdf_samples`` that contains NaNs,
    it prints the file path and the number of NaN entries (individual
    elements, not rows).
    """
    dataset_directory = Path(args.dataset_directory)
    for category_directory in sorted(dataset_directory.iterdir()):
        if not category_directory.is_dir():
            continue
        for model_directory in category_directory.iterdir():
            if not model_directory.is_dir():
                continue
            npz_path = model_directory / "sdf.npz"
            if not npz_path.exists():
                continue
            # Close each archive promptly; this loop may visit many files.
            with np.load(npz_path) as data:
                pos = data["positive_sdf_samples"]
                neg = data["negative_sdf_samples"]
            for samples in (pos, neg):
                num_nan = int(np.count_nonzero(np.isnan(samples)))
                if num_nan > 0:
                    print(npz_path, num_nan)


@client.command(name="view_obj_dataset")
@click.argument("--dataset-directory", type=str, required=True)
def view_obj_dataset(args):
    """View every ``models/model_normalized.obj`` in a dataset, one by one.

    Walks ``<dataset_directory>/<category>/<model_id>/`` in sorted order.
    For each model it prints the path and mesh summary, computes triangle and
    vertex normals, and calls ``draw_geometries`` once per mesh.
    """
    dataset_directory = Path(args.dataset_directory)
    for category_directory in sorted(dataset_directory.iterdir()):
        if not category_directory.is_dir():
            continue
        # Sorted so repeated runs browse models in the same order.
        for model_directory in sorted(category_directory.iterdir()):
            if not model_directory.is_dir():
                continue
            obj_path = model_directory / "models" / "model_normalized.obj"
            if not obj_path.exists():
                continue
            print(obj_path)
            mesh = o3d.io.read_triangle_mesh(str(obj_path))
            print(mesh)
            mesh.compute_triangle_normals()
            mesh.compute_vertex_normals()
            o3d.visualization.draw_geometries([mesh])


if __name__ == "__main__":
    # Script entry point: hand control to the `client` command group defined above.
    client()
