import json
import os
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
from tensorpack import dataflow

from pcn import click


def resample_pcd(pcd, n):
    """Return exactly ``n`` rows of ``pcd`` in random order.

    The rows are shuffled first. When ``pcd`` has more than ``n`` rows the
    surplus is dropped; when it has fewer, randomly chosen rows (sampled with
    replacement) are appended until there are ``n``.

    Args:
        pcd: array whose first axis indexes points.
        n: number of points wanted in the result.

    Returns:
        ``pcd`` indexed by ``n`` row indices.
    """
    num_points = pcd.shape[0]
    order = np.random.permutation(num_points)
    shortfall = n - num_points
    if shortfall > 0:
        padding = np.random.randint(num_points, size=shortfall)
        order = np.concatenate([order, padding])
    return pcd[order[:n]]


class BatchData(dataflow.ProxyDataFlow):
    """Group consecutive datapoints of ``ds`` into point-cloud batches.

    Each upstream datapoint is indexed as ``(id, partial_points, gt_points)``.
    Partial inputs are concatenated into one ragged array with per-sample
    point counts; ground truths are resampled to a fixed size and stacked.

    Args:
        ds: upstream dataflow.
        batchsize: number of datapoints per batch.
        input_size: maximum number of points kept from each partial input.
        gt_size: exact number of points every ground truth is resampled to.
        remainder: if truthy, emit a final smaller batch with the leftovers.
        use_list: stored and forwarded to ``_aggregate_batch``, which ignores it.
    """

    def __init__(self,
                 ds,
                 batchsize,
                 input_size,
                 gt_size,
                 remainder=False,
                 use_list=False):
        super().__init__(ds)
        self.batchsize = batchsize
        self.input_size = input_size
        self.gt_size = gt_size
        self.remainder = remainder
        self.use_list = use_list

    def __len__(self):
        full_batches, leftover = divmod(len(self.ds), self.batchsize)
        if not leftover:
            return full_batches
        # A partial batch only counts when remainder batches are emitted.
        return full_batches + int(self.remainder)

    def __iter__(self):
        pending = []
        for datapoint in self.ds:
            pending.append(datapoint)
            if len(pending) != self.batchsize:
                continue
            yield self._aggregate_batch(pending, self.use_list)
            pending = []
        if self.remainder and pending:
            yield self._aggregate_batch(pending, self.use_list)

    def _aggregate_batch(self, data_holder, use_list=False):
        """Collate ``(id, partial, gt)`` datapoints into one batch.

        Returns:
            ids: ``np.stack`` of the ids.
            inputs: float32 array with a leading axis of size 1 holding every
                partial input concatenated along the point axis; partials with
                more than ``input_size`` points are randomly downsampled first.
            npts: int32 array with the number of points each sample
                contributed to ``inputs``.
            gts: float32 stack of ground truths, each resampled to exactly
                ``gt_size`` points.

        ``use_list`` is accepted for interface compatibility and not used.
        """
        ids = np.stack([sample[0] for sample in data_holder])

        partials = []
        for sample in data_holder:
            cloud = sample[1]
            if cloud.shape[0] > self.input_size:
                cloud = resample_pcd(cloud, self.input_size)
            partials.append(cloud)
        inputs = np.concatenate(partials)[np.newaxis].astype(np.float32)

        npts = np.stack([
            min(sample[1].shape[0], self.input_size) for sample in data_holder
        ]).astype(np.int32)

        gts = np.stack([
            resample_pcd(sample[2], self.gt_size) for sample in data_holder
        ]).astype(np.float32)
        return ids, inputs, npts, gts


@click.group()
def client():
    """Root command group; the subcommands below inspect or convert PCN data."""


@client.command()
@click.argument("--dataset-file", type=str, required=True)
def read(args):
    """Stream training batches from an LMDB dataset and visualise them.

    Builds the pipeline LMDB -> LocallyShuffleData -> BatchData ->
    RepeatedData(-1), prints per-batch diagnostics and opens an Open3D window
    showing the first ground-truth cloud of each batch. Closing the window
    advances to the next batch; the loop is infinite by construction.
    """
    batchsize = 32
    input_size = 3000
    output_size = 16384
    df = dataflow.LMDBSerializer.load(args.dataset_file, shuffle=False)
    df = dataflow.LocallyShuffleData(df, buffer_size=2000)
    # df = dataflow.PrefetchData(df, nr_prefetch=500, nr_proc=1)
    df = BatchData(df, batchsize, input_size, output_size)
    # df = dataflow.PrefetchDataZMQ(df, nr_proc=8)
    df = dataflow.RepeatedData(df, -1)
    df.reset_state()

    def cmap_func(point3d):
        """Colour points by their x coordinate with the "winter" colormap."""
        cmap = plt.get_cmap("winter")
        x = point3d[:, 0]
        # Float inputs to a colormap are expected in [0, 1]; values outside are
        # clamped to the end colours, so rescale to the cloud's own range.
        span = x.max() - x.min()
        if span > 0:
            x = (x - x.min()) / span
        else:
            x = np.zeros_like(x)
        return cmap(x)[:, :3]

    for inds, inputs, npts, gt in df:
        print(inds[0])
        print(inputs.shape)
        # np.split expects split *indices*, not segment sizes, so turn the
        # per-sample point counts into cumulative boundaries.
        parts = np.split(inputs[0], np.cumsum(npts)[:-1], axis=0)
        print([part.shape[0] for part in parts])
        print(npts, np.sum(npts))
        print(gt.shape)
        colors = cmap_func(gt[0])
        print(np.min(gt), np.max(gt))

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(gt[0].astype(np.float64))
        pcd.colors = o3d.utility.Vector3dVector(colors.astype(np.float64))
        # Open3D's default background is white, matching the intended
        # bg_color=(255, 255, 255).
        o3d.visualization.draw_geometries([pcd])


def _convert_points_to_polar_coordinates(points: np.ndarray):
    """Convert Cartesian points to spherical coordinates ``(r, theta, phi)``.

    ``theta`` is the polar angle measured from the +z axis (``[0, pi]``) and
    ``phi`` the azimuth in the x-y plane measured from +x (``arctan2`` range).
    The inverse mapping is::

        x = r * sin(theta) * cos(phi)
        y = r * sin(theta) * sin(phi)
        z = r * cos(theta)

    Args:
        points: array of shape ``(N, 3)`` with x, y, z columns.

    Returns:
        C-contiguous ``(N, 3)`` array with columns r, theta, phi. Points at
        the origin get ``theta = 0`` instead of NaN.
    """
    r = np.sqrt(np.sum(points**2, axis=1))
    # Guard the origin (0 / 0 -> NaN) and clip rounding error that can push
    # |z / r| slightly above 1, which would also make arccos return NaN.
    cos_theta = np.divide(points[:, 2], r, out=np.ones_like(r), where=r > 0)
    theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))
    phi = np.arctan2(points[:, 1], points[:, 0])

    return np.ascontiguousarray(np.stack([r, theta, phi], axis=1))


@client.command()
@click.argument("--dataset-file", type=str, required=True)
@click.argument("--output-file", type=str, required=True)
def convert(args):
    """Regroup an LMDB dataset per object and write spherical coords to HDF5.

    Each record is ``(file_id, partial_points, gt_points)``. Records are
    grouped by the first two ``"_"``-separated components of ``file_id``; the
    ground truth is taken from the first record seen for each group. All
    clouds are converted with ``_convert_points_to_polar_coordinates``.
    Output layout::

        /<obj_index>/gt                    ground truth
        /<obj_index>/partial_inputs/<i>    i-th partial input, in read order
    """
    df = dataflow.LMDBSerializer.load(args.dataset_file, shuffle=False)
    num_records = len(df)  # loop-invariant, used only for progress output
    # tensorpack dataflows must be reset once before they produce data.
    df.reset_state()

    dataset = {}
    for data_index, (file_id, subset, gt) in enumerate(df):
        print(data_index, num_records)
        file_id_components = file_id.split("_")
        obj_index = file_id_components[0] + "_" + file_id_components[1]
        entry = dataset.get(obj_index)
        if entry is None:
            entry = {
                "gt": _convert_points_to_polar_coordinates(gt),
                "partial_inputs": [],
            }
            dataset[obj_index] = entry
        entry["partial_inputs"].append(
            _convert_points_to_polar_coordinates(subset))

    num_objects = len(dataset)
    with h5py.File(args.output_file, "w") as f:
        for data_index, (obj_index, entry) in enumerate(dataset.items()):
            print(data_index, num_objects)
            group = f.create_group(obj_index)
            group.create_dataset("gt", data=entry["gt"])
            subgroup = group.create_group("partial_inputs")
            for subset_index, subset in enumerate(entry["partial_inputs"]):
                subgroup.create_dataset(str(subset_index), data=subset)


@client.command(name="read_h5")
@click.argument("--path", type=str, required=True)
def read_h5(args):
    """Print the number and names of the top-level entries of an HDF5 file."""
    # Open read-only explicitly: before h5py 3.0 the default mode was "a",
    # which creates a missing file and opens existing ones writable.
    with h5py.File(args.path, "r") as f:
        keys = list(f.keys())
    print(len(keys))
    print(keys)


@client.command(name="convert_lmdb_to_npz")
@click.argument("--path", type=str, required=True)
@click.argument("--output-root-directory", type=str, required=True)
def convert_lmdb_to_npz(args):
    """Split an LMDB dataset into one ``.npz`` file per model.

    Records must arrive in groups of ``num_partial_inputs`` consecutive
    records per model. Each record is ``(file_id, partial_points, gt_points)``
    with ``file_id`` of the form ``<category>_<model>...``. Each group is
    written to ``<output_root_directory>/<category>/<model>.npz`` with keys
    ``partial_input_<i>`` and ``gt`` (the ground truth of the group's last
    record).

    Raises:
        ValueError: if a group mixes records from different models, i.e. the
            grouping assumption above does not hold.
    """
    num_partial_inputs = 8  # partial scans stored per model in the LMDB
    output_root_directory = Path(args.output_root_directory)

    df = dataflow.LMDBSerializer.load(args.path, shuffle=False)
    # tensorpack dataflows must be reset once before they produce data.
    df.reset_state()

    group_key = None
    output_data = {}
    for data_index, (file_id, partial_points, gt) in enumerate(df):
        parts = file_id.split("_")
        category_id = parts[0]
        model_id = parts[1]
        partial_input_index = data_index % num_partial_inputs
        if partial_input_index == 0:
            group_key = (category_id, model_id)
        elif (category_id, model_id) != group_key:
            # Explicit check rather than `assert`, which is stripped under -O.
            raise ValueError(
                f"record {data_index} ({file_id}) does not belong to model "
                f"{group_key}; expected {num_partial_inputs} consecutive "
                "records per model")
        output_data[f"partial_input_{partial_input_index}"] = partial_points
        if partial_input_index == num_partial_inputs - 1:
            output_data["gt"] = gt
            category_directory = output_root_directory / category_id
            category_directory.mkdir(parents=True, exist_ok=True)
            npz_path = category_directory / f"{model_id}.npz"
            np.savez(npz_path, **output_data)
            print(npz_path)
            output_data = {}

    if output_data:
        print(f"warning: dropped incomplete trailing group {group_key} with "
              f"{len(output_data)} of {num_partial_inputs} partial inputs")


@client.command(name="check_npz_data")
@click.argument("--path", type=str, required=True)
def check_npz_data(args):
    """Print stats for a converted ``.npz`` model and visualise it.

    The ground truth is drawn in red and ``partial_input_0`` in green.
    """
    num_partials_to_show = 1
    # NpzFile keeps the archive open until closed; read what we need and
    # release it before the (blocking) viewer starts.
    with np.load(args.path) as data:
        print(data.files)
        gt = data["gt"]
        partial_inputs = [
            data[f"partial_input_{k}"] for k in range(num_partials_to_show)
        ]

    print(gt.shape)
    print(np.min(gt, axis=0), np.max(gt, axis=0))

    def make_colored_pcd(points, rgb):
        """Build an Open3D cloud with every point painted ``rgb``."""
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        pcd.colors = o3d.utility.Vector3dVector(
            np.array([rgb], dtype=np.float64).repeat(len(points), axis=0))
        return pcd

    geometries = [make_colored_pcd(gt, [1, 0, 0])]
    geometries.extend(
        make_colored_pcd(partial_input, [0, 1, 0])
        for partial_input in partial_inputs)

    o3d.visualization.draw_geometries(geometries)


def _read_pcd(path: str):
    """Load xyz coordinates from an ASCII ``.pcd`` (Point Cloud Data) file.

    Header lines are skipped up to and including the ``DATA`` line. Every
    non-blank line after it is one point; its first three whitespace-separated
    fields are taken as x, y, z and any further fields (e.g. rgb) are ignored.

    Args:
        path: path to the ``.pcd`` file.

    Returns:
        float64 array of shape ``(N, 3)``; ``(0, 3)`` when there are no points.

    Raises:
        ValueError: if the ``DATA`` line declares a non-ascii payload (e.g.
            ``binary``), which this parser cannot decode.
    """
    point3d_list = []
    is_data = False
    # Binary mode + per-line decode lets us inspect the header even when the
    # payload is not text, and fail with a clear error instead of garbage.
    with open(path, "rb") as f:
        for raw_line in f:
            tokens = raw_line.decode("ascii", errors="replace").split()
            if not tokens:
                continue  # blank lines, e.g. a trailing newline at EOF
            if is_data:
                point3d_list.append([float(v) for v in tokens[:3]])
            elif tokens[0] == "DATA":
                data_format = tokens[1] if len(tokens) > 1 else "ascii"
                if data_format != "ascii":
                    raise ValueError(
                        f"{path}: unsupported PCD DATA format {data_format!r}; "
                        "only 'ascii' is supported")
                is_data = True
    return np.array(point3d_list, dtype=np.float64).reshape(-1, 3)


@client.command(name="check_pcd_dataset")
@click.argument("--pcd-path", type=str, required=True)
def check_pcd_dataset(args):
    """Print the points, per-axis bounds and shape of a ``.pcd`` file, then show it."""
    # Reuse the module's PCD parser rather than maintaining a second inline copy.
    point_cloud = _read_pcd(args.pcd_path)
    print(point_cloud)
    print(np.min(point_cloud, axis=0), np.max(point_cloud, axis=0))
    print(point_cloud.shape)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(point_cloud)
    o3d.visualization.draw_geometries([pcd])


@client.command(name="check_scale")
@click.argument("--complete-pcd-path", type=str, required=True)
@click.argument("--partial-pcd-path", type=str, required=True)
@click.argument("--npz-path", type=str, required=True)
def check_scale(args):
    """Overlay PCN complete/partial clouds on the vertices stored in an ``.npz``.

    The ``.npz`` must provide ``vertices`` and ``scale``. Both PCN clouds are
    remapped with ``(x, y, z) -> (z, y, -x)`` and then multiplied by
    ``scale``; any offset in the ``.npz`` is deliberately not applied.
    Colours: npz vertices red, complete cloud green, partial cloud blue.
    """
    with np.load(args.npz_path) as data:
        vertices = data["vertices"]
        scale = data["scale"]

    def to_npz_frame(points):
        """Apply the PCN -> npz axis remap, then the npz scale (offset ignored)."""
        x = points[:, 0, None]
        y = points[:, 1, None]
        z = points[:, 2, None]
        return np.hstack((z, y, -x)) * scale

    def make_colored_pcd(points, rgb):
        """Build an Open3D cloud with every point painted ``rgb``."""
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        pcd.colors = o3d.utility.Vector3dVector(
            np.repeat(np.array([rgb], dtype=np.float64), len(points), axis=0))
        return pcd

    npz_pcd = make_colored_pcd(vertices, [1, 0, 0])

    complete_point_cloud = _read_pcd(args.complete_pcd_path)
    print(complete_point_cloud)
    print(complete_point_cloud.shape)

    # Both PCN clouds go through the same remap-then-scale transform; with a
    # per-axis scale the order would matter, so it must stay shared.
    complete_pcn_pcd = make_colored_pcd(
        to_npz_frame(complete_point_cloud), [0, 1, 0])
    partial_pcn_pcd = make_colored_pcd(
        to_npz_frame(_read_pcd(args.partial_pcd_path)), [0, 0, 1])

    o3d.visualization.draw_geometries(
        [complete_pcn_pcd, partial_pcn_pcd, npz_pcd])


@client.command(name="generate_test_split")
@click.argument("--root-directory", type=str, required=True)
@click.argument("--output-directory", type=str, required=True)
def generate_test_split(args):
    """Write ``8_categories_test.json`` mapping category id -> model ids.

    ``root_directory`` is expected to hold one sub-directory per category,
    each containing ``<model_id>.pcd`` files. Categories and model ids are
    listed in sorted order; non-directory entries under the root are skipped.
    The output directory is created if it does not exist.
    """
    root_directory = Path(args.root_directory)
    output_directory = Path(args.output_directory)

    data = {}
    for category_directory in sorted(root_directory.iterdir()):
        if not category_directory.is_dir():
            continue  # stray files would otherwise become empty categories
        category_id = category_directory.name
        print(category_id)
        model_ids = []
        for model_path in sorted(category_directory.glob("*.pcd")):
            model_id = model_path.stem
            print(category_id, model_id)
            model_ids.append(model_id)
        data[category_id] = model_ids

    output_directory.mkdir(parents=True, exist_ok=True)
    with open(output_directory / "8_categories_test.json", "w") as f:
        json.dump(data, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    # Entry point: hand control to the `client` command group, which runs the
    # subcommand selected on the command line.
    client()
