import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import trimesh

from meta_learning_sdf import click, load_model, mkdir
from meta_learning_sdf import point_cloud as pcl
from meta_learning_sdf.datasets.shapenet import (MeshDataDescription,
                                                 Minibatch, MinibatchGenerator,
                                                 NpzDataset,
                                                 SurfacePointDataDescription)
from meta_learning_sdf.experiments.baseline import (
    LatentOptimization, LossFunction, MarchingCubes, Model,
    ModelHyperparameters, TrainingHyperparameters, setup_model)


def _normalize(vec: np.ndarray):
    """Return ``vec`` divided by its Euclidean (L2) norm.

    There is no zero-length guard: a zero vector divides by zero and the
    result contains NaNs.
    """
    magnitude = np.linalg.norm(vec)
    return vec / magnitude


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Compute a look-at camera frame.

    The z axis points from ``center`` toward ``eye``. The x axis is
    ``up × z`` and the y axis is ``z × x``; each is then scaled to unit length.
    If ``up`` is parallel to ``eye - center``, x is zero and the result
    contains NaNs.

    Args:
        eye: Camera position (3-vector, array-like).
        center: Point the camera looks at (3-vector, array-like).
        up: Approximate up direction (3-vector, array-like).

    Returns:
        A ``(rotation_matrix, translation_vector)`` pair:
        - ``rotation_matrix`` is a float32 3x3 matrix whose columns are the
          x, y, z axes.
        - ``translation_vector`` is ``[-x·eye, -y·eye, -z·eye]``.
    """
    def unit(v):
        return v / np.linalg.norm(v)

    eye, center, up = (np.asanyarray(a) for a in (eye, center, up))

    z_axis = unit(eye - center)
    x_raw = np.cross(up, z_axis)
    # y is derived from the *unnormalized* x; both are normalized afterwards.
    y_raw = np.cross(z_axis, x_raw)
    x_axis, y_axis = unit(x_raw), unit(y_raw)

    axes = (x_axis, y_axis, z_axis)
    rotation_matrix = np.column_stack(axes).astype(np.float32)
    translation_vector = np.array([-axis @ eye for axis in axes])
    return rotation_matrix, translation_vector


@click.command()
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--split-path", type=str, required=True)
def main(args):
    """Print the encoder's latent scale for one shape at several context sizes.

    Loads the model described by ``args.json`` and the checkpoint in
    ``args.checkpoint_directory``. Collects the ``point_cloud.npz`` files
    listed in the split JSON. Then, for the first item yielded by
    ``dataset.shuffle()``, encodes it with 1 ... 30000 context points and
    prints ``sqrt(variance)`` and its mean for each context size.

    Args:
        args: Options object with ``checkpoint_directory``,
            ``checkpoint_epoch``, ``dataset_directory`` and ``split_path``.
            It is presumably populated from the command line by the project's
            ``click`` decorators (confirm).

    Raises:
        FileNotFoundError: If ``args.json``, the model checkpoint or the
            split file does not exist.
    """
    device = torch.device("cuda", 0)
    dataset_directory = Path(args.dataset_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    split_path = Path(args.split_path)

    # Validate every input file up front, before any GPU work. Use explicit
    # raises: `assert` is stripped when Python runs with -O.
    for required_file in (args_path, model_path, split_path):
        if not required_file.is_file():
            raise FileNotFoundError(f"Required file not found: {required_file}")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    # Inference only: switch any dropout / batch-norm layers out of training mode.
    model.eval()

    with open(split_path, encoding="utf-8") as f:
        split = json.load(f)

    # `split` maps category -> list of model ids. Shapes whose
    # point_cloud.npz is missing are skipped silently.
    npz_path_list_train = []
    for category, model_id_list in split.items():
        for model_id in model_id_list:
            npz_path = dataset_directory / category / model_id / "point_cloud.npz"
            if npz_path.exists():
                npz_path_list_train.append(npz_path)
    print(len(npz_path_list_train))

    dataset = NpzDataset(npz_path_list_train)
    minibatch_generator = MinibatchGenerator(device=device)

    num_point_samples_list = [1, 10, 100, 1000, 10000, 30000]

    for data in dataset.shuffle():
        for num_point_samples in num_point_samples_list:
            # Context-only minibatch: exactly `num_point_samples` context
            # points and no targets.
            minibatch = minibatch_generator([data],
                                            min_num_context=num_point_samples,
                                            max_num_context=num_point_samples,
                                            min_num_target=0,
                                            max_num_target=0,
                                            tuple_target=False)
            with torch.no_grad():
                f_n, g_n = model.encoder.compute_f_and_g(
                    minibatch.context_points)
                _, variance = model.encoder.compute_mean_and_variance(
                    f_n, g_n)
                scale = torch.sqrt(variance)
                print(scale)
                print(torch.mean(scale))
        # Only the first shape is inspected. `break` replaces the interactive
        # `exit()` helper, which is absent when Python runs with -S.
        break


if __name__ == "__main__":
    # `main` is called with no arguments here. The `args` object it expects is
    # presumably built from the command line by the project's `click.command`
    # decorator (meta_learning_sdf.click) — confirm.
    main()
