import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from implicit_geometric_regularization import click
from implicit_geometric_regularization.datasets.dynamic_faust import (
    MinibatchGenerator, MinibatchIterator, NpzDataset, PlyDataset)


@click.group()
def client():
    """Command group on which the sub-commands of this file are registered."""
    pass


@client.command(name="test_ply_dataset")
@click.argument("--dataset-directory", type=str, required=True)
def test_ply_dataset(args):
    """Smoke-test minibatch iteration over a directory tree of ``.ply`` files.

    Every ``*.ply`` file located two directory levels below
    ``args.dataset_directory`` (``<root>/<subject>/<pose>/*.ply``) is
    collected and sorted. The list is shuffled with a fixed seed and split
    75/25 into a training and a test part. The training part is then iterated
    for 100 epochs, printing the epoch, the batch index and ``len()`` of the
    iterator for every batch.

    Args:
        args: Parsed command-line arguments; only ``args.dataset_directory``
            is read.
    """
    scans_directory = Path(args.dataset_directory)
    ply_path_list = []
    for subject_directory in scans_directory.iterdir():
        # Skip stray files in the dataset root (README, .DS_Store, ...);
        # calling iterdir() on them would raise NotADirectoryError.
        if not subject_directory.is_dir():
            continue
        for pose_directory in subject_directory.iterdir():
            ply_path_list.extend(pose_directory.glob("*.ply"))

    # Sort first so that the seeded shuffle below yields the same split on
    # every run, independent of the file-system listing order.
    ply_path_list.sort()
    dataset_size_train = int(len(ply_path_list) * 0.75)

    # shuffle
    rst = random.Random(0)
    rst.shuffle(ply_path_list)
    ply_path_list_train = ply_path_list[:dataset_size_train]
    ply_path_list_test = ply_path_list[dataset_size_train:]

    dataset_train = PlyDataset(ply_path_list_train)
    # Only constructed, never iterated; kept so the test split is still built.
    dataset_test = PlyDataset(ply_path_list_test)

    # NOTE(review): the plotting commands in this file pass
    # `num_point_samples=` to MinibatchGenerator while this one passes
    # `num_samples=` -- confirm which keyword the class actually accepts.
    minibatch_generator = MinibatchGenerator(num_samples=128 * 128,
                                             device="cuda:0")
    minibatch_iterator_train = MinibatchIterator(
        dataset_train,
        batchsize=1,
        minibatch_generator=minibatch_generator,
        drop_last=True)
    for epoch in range(100):
        # The batch itself is not inspected; only the iteration is exercised.
        for k, _ in enumerate(minibatch_iterator_train):
            print(epoch, k, len(minibatch_iterator_train))


@client.command(name="test_npz_dataset")
@click.argument("--dataset-directory", type=str, required=True)
def test_npz_dataset(args):
    """Smoke-test minibatch iteration over a directory tree of ``.npz`` files.

    Every ``*.npz`` file located two directory levels below
    ``args.dataset_directory`` (``<root>/<subject>/<pose>/*.npz``) is
    collected and sorted. The list is shuffled with a fixed seed and split
    75/25 into a training and a test part. The training part is then iterated
    for 100 epochs, printing the epoch, the batch index and ``len()`` of the
    iterator for every batch.

    Args:
        args: Parsed command-line arguments; only ``args.dataset_directory``
            is read.
    """
    scans_directory = Path(args.dataset_directory)
    npz_path_list = []
    for subject_directory in scans_directory.iterdir():
        # Skip stray files in the dataset root (README, .DS_Store, ...);
        # calling iterdir() on them would raise NotADirectoryError.
        if not subject_directory.is_dir():
            continue
        for pose_directory in subject_directory.iterdir():
            npz_path_list.extend(pose_directory.glob("*.npz"))

    # Sort first so that the seeded shuffle below yields the same split on
    # every run, independent of the file-system listing order.
    npz_path_list.sort()
    dataset_size_train = int(len(npz_path_list) * 0.75)

    # shuffle
    rst = random.Random(0)
    rst.shuffle(npz_path_list)
    npz_path_list_train = npz_path_list[:dataset_size_train]
    npz_path_list_test = npz_path_list[dataset_size_train:]

    dataset_train = NpzDataset(npz_path_list_train)
    # Only constructed, never iterated; kept so the test split is still built.
    dataset_test = NpzDataset(npz_path_list_test)

    # NOTE(review): the plotting commands in this file pass
    # `num_point_samples=` to MinibatchGenerator while this one passes
    # `num_samples=` -- confirm which keyword the class actually accepts.
    minibatch_generator = MinibatchGenerator(num_samples=128 * 128,
                                             device="cuda:0")
    minibatch_iterator_train = MinibatchIterator(
        dataset_train,
        batchsize=1,
        minibatch_generator=minibatch_generator,
        drop_last=True)
    for epoch in range(100):
        # The batch itself is not inspected; only the iteration is exercised.
        for k, _ in enumerate(minibatch_iterator_train):
            print(epoch, k, len(minibatch_iterator_train))


@client.command(name="plot_dfaust_quiver")
@click.argument("--dataset-directory", type=str, required=True)
def plot_dfaust_quiver(args):
    """Plot sampled points and their normals for ``.npz`` scans, one at a time.

    Every ``*.npz`` file located two directory levels below
    ``args.dataset_directory`` (``<root>/<subject>/<pose>/*.npz``) is
    collected and visited in random order. For each minibatch (batch size 1)
    the sampled points are drawn as a green 3D scatter plot and the normals
    as a quiver plot; the next sample is shown after ``plt.show()`` returns.

    Args:
        args: Parsed command-line arguments; only ``args.dataset_directory``
            is read.
    """
    scans_directory = Path(args.dataset_directory)
    npz_path_list = []
    for subject_directory in scans_directory.iterdir():
        # Skip stray files in the dataset root (README, .DS_Store, ...);
        # calling iterdir() on them would raise NotADirectoryError.
        if not subject_directory.is_dir():
            continue
        for pose_directory in subject_directory.iterdir():
            npz_path_list.extend(pose_directory.glob("*.npz"))
    random.shuffle(npz_path_list)

    dataset = NpzDataset(npz_path_list)
    minibatch_generator = MinibatchGenerator(num_point_samples=128 * 128,
                                             with_normal=True,
                                             device="cpu")
    minibatch_iterator_train = MinibatchIterator(
        dataset,
        batchsize=1,
        minibatch_generator=minibatch_generator,
        drop_last=True)

    for data in minibatch_iterator_train:
        # A new figure per sample: once a blocking plt.show() returns, the
        # displayed window has been closed, so a single figure created
        # before the loop could only ever show the first sample.
        fig = plt.figure()
        # Figure.gca() no longer accepts keyword arguments (removed in
        # Matplotlib 3.6); add_subplot() is the supported way to get 3D axes.
        ax = fig.add_subplot(111, projection="3d")

        # First (and only) batch element; the last axis is indexed 0..2 and
        # used as the x/y/z components.
        x = data.points[0, :, 0]
        y = data.points[0, :, 1]
        z = data.points[0, :, 2]
        u = data.normals[0, :, 0]
        v = data.normals[0, :, 1]
        w = data.normals[0, :, 2]
        # Debugging hint: to inspect only a thin slab, filter all six arrays
        # with `i = np.where((-0.1 < y) * (y < 0.1))` before plotting.
        ax.quiver(x, y, z, u, v, w,
                  length=0.05,
                  linewidths=0.3,
                  normalize=True)
        ax.scatter3D(x, y, z, s=2, c="g")
        ax.set_xlim3d(-1, 1)
        ax.set_ylim3d(-1, 1)
        ax.set_zlim3d(-1, 1)
        plt.show()
        # Release the figure so that figures do not pile up when show()
        # does not block (e.g. non-interactive backends).
        plt.close(fig)


@client.command(name="plot_shapenet_quiver")
@click.argument("--dataset-directory", type=str, required=True)
def plot_shapenet_quiver(args):
    """Plot sampled points and their normals for ``.npz`` files, one at a time.

    Every ``*.npz`` file located one directory level below
    ``args.dataset_directory`` (``<root>/<category>/*.npz``) is collected and
    visited in random order. For each minibatch (batch size 1) the sampled
    points are drawn as a green 3D scatter plot and the normals as a quiver
    plot; the next sample is shown after ``plt.show()`` returns.

    Args:
        args: Parsed command-line arguments; only ``args.dataset_directory``
            is read.
    """
    scans_directory = Path(args.dataset_directory)
    npz_path_list = []
    for category_directory in scans_directory.iterdir():
        # Only directories can contain samples; skip stray files in the root.
        if not category_directory.is_dir():
            continue
        npz_path_list.extend(category_directory.glob("*.npz"))
    random.shuffle(npz_path_list)

    dataset = NpzDataset(npz_path_list)
    minibatch_generator = MinibatchGenerator(num_point_samples=1000,
                                             with_normal=True,
                                             device="cpu")
    minibatch_iterator_train = MinibatchIterator(
        dataset,
        batchsize=1,
        minibatch_generator=minibatch_generator,
        drop_last=True)

    for data in minibatch_iterator_train:
        # A new figure per sample: once a blocking plt.show() returns, the
        # displayed window has been closed, so a single figure created
        # before the loop could only ever show the first sample.
        fig = plt.figure()
        # Figure.gca() no longer accepts keyword arguments (removed in
        # Matplotlib 3.6); add_subplot() is the supported way to get 3D axes.
        ax = fig.add_subplot(111, projection="3d")

        # First (and only) batch element; the last axis is indexed 0..2 and
        # used as the x/y/z components.
        x = data.points[0, :, 0]
        y = data.points[0, :, 1]
        z = data.points[0, :, 2]
        u = data.normals[0, :, 0]
        v = data.normals[0, :, 1]
        w = data.normals[0, :, 2]
        # Debugging hint: to inspect only a thin slab, filter all six arrays
        # with `i = np.where((-0.1 < y) * (y < 0.1))` before plotting.
        ax.quiver(x, y, z, u, v, w,
                  length=0.05,
                  linewidths=0.3,
                  normalize=True)
        ax.scatter3D(x, y, z, s=2, c="g")
        ax.set_xlim3d(-1, 1)
        ax.set_ylim3d(-1, 1)
        ax.set_zlim3d(-1, 1)
        plt.show()
        # Release the figure so that figures do not pile up when show()
        # does not block (e.g. non-interactive backends).
        plt.close(fig)


# Invoke the command group only when this file is executed as a script.
if __name__ == "__main__":
    client()
