from pathlib import Path
from typing import Callable, List

import numpy as np
import torch
import trimesh

from deep_sdf import load_model
from deep_sdf import point_cloud as pcl
from deep_sdf.datasets.shapenet import (MinibatchGenerator, SdfData)
from deep_sdf.experiments.learning_shape_space import (
    DecoderHyperparameters, AdaptiveLatentOptimization, LossFunction,
    MarchingCubes, ModelHyperparameters, TrainingHyperparameters, setup_model,
    LatentOptimizationHyperparameters)

from ..args import DeepSdfArgs
from ..geometry import get_extrinsics


def setup_plot_function(args: DeepSdfArgs, device: str,
                        num_sdf_samples_list: List[int],
                        latent_optimization_iterations: int,
                        latent_optimization_initial_lr: float, grid_size: int):
    """Load a trained DeepSDF checkpoint and wrap it in a ``Plot`` callable.

    The checkpoint directory must contain ``args.json`` (the hyperparameters)
    and either ``model.pt`` or, when ``args.checkpoint_epoch`` is set,
    ``model.<checkpoint_epoch>.pt``.

    Args:
        args: Provides ``checkpoint_directory`` and ``checkpoint_epoch``.
        device: Torch device the model and the initial latent codes are
            moved to.
        num_sdf_samples_list: Number of SDF samples to use for each plot row.
        latent_optimization_iterations: Forwarded as ``iterations`` (and as
            ``decrease_lr_every``) to the latent-optimization hyperparameters.
        latent_optimization_initial_lr: Initial learning rate of the latent
            optimization.
        grid_size: Forwarded as ``grid_size`` to ``MarchingCubes``.

    Returns:
        A ``Plot`` instance wired to the loaded model.

    Raises:
        FileNotFoundError: If ``args.json`` or the model file does not exist.
    """
    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    # Explicit checks instead of `assert`: asserts vanish under `python -O`.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(
                f"required checkpoint file not found: {required_path}")

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    decoder_hyperparams = DecoderHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams, decoder_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    minibatch_generator = MinibatchGenerator(device=device)

    # Reuse the loss settings the model was trained with.
    loss_function = LossFunction(
        lam=training_hyperparams.loss_lam,
        clamping_distance=training_hyperparams.clamping_distance)

    # NOTE(review): num_samples=-1 presumably means "use every sample in the
    # batch", and decrease_lr_every == iterations presumably means the LR is
    # never decayed within one run; confirm against AdaptiveLatentOptimization.
    params = LatentOptimizationHyperparameters(
        num_samples=-1,
        initial_lr=latent_optimization_initial_lr,
        decrease_lr_every=latent_optimization_iterations,
        iterations=latent_optimization_iterations)
    optimize_latent = AdaptiveLatentOptimization(model=model,
                                                 loss_function=loss_function,
                                                 params=params)

    # Marching cubes is evaluated on the cube [-1, 1]^3.
    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    def initial_z_func():
        """Draw a (1, z_dim) latent code from N(0, 0.01^2) on `device`."""
        return torch.normal(mean=0,
                            std=0.01,
                            size=(1, model_hyperparams.z_dim),
                            dtype=torch.float32).to(device)

    plot_func = Plot(minibatch_generator=minibatch_generator,
                     num_sdf_samples_list=num_sdf_samples_list,
                     initial_z_func=initial_z_func,
                     optimize_latent_func=optimize_latent,
                     marching_cubes_func=marching_cubes)

    return plot_func


class Plot:
    """Draw DeepSDF reconstructions of one shape into one column of a grid.

    Each entry of ``num_sdf_samples_list`` corresponds to one row. For every
    row:

    1. SDF samples are drawn from the shape.
    2. A latent code is optimized to fit them.
    3. The surface is extracted with marching cubes.
    4. The rendered mesh is shown in ``axes[row][column]``.
    """

    def __init__(self, minibatch_generator: MinibatchGenerator,
                 num_sdf_samples_list: List[int],
                 initial_z_func: Callable[[], torch.Tensor],
                 optimize_latent_func: Callable[..., torch.Tensor],
                 marching_cubes_func: Callable[[torch.Tensor], tuple]):
        """
        Args:
            minibatch_generator: Builds the sample batch that is handed to
                ``optimize_latent_func``.
            num_sdf_samples_list: Number of SDF samples to use for each row.
            initial_z_func: Returns a fresh initial latent code.
            optimize_latent_func: Called as
                ``optimize_latent_func(data, initial_z)``; returns the
                optimized latent code.
            marching_cubes_func: Called with a latent code; returns a
                ``(vertices, faces)`` pair.
        """
        self.minibatch_generator = minibatch_generator
        self.num_sdf_samples_list = num_sdf_samples_list
        self.get_initial_z = initial_z_func
        self.optimize_latent = optimize_latent_func
        self.marching_cubes = marching_cubes_func

    def _render_prediction(self, z, rotation_matrix, translation_vector):
        """Extract the mesh for ``z``, apply the extrinsics and render it."""
        mc_vertices, mc_faces = self.marching_cubes(z)
        print("marching_cubes done", flush=True)
        # Rigidly transform the vertices with the extrinsics from
        # get_extrinsics() before rendering.
        mc_vertices = (rotation_matrix @ mc_vertices.T).T + translation_vector
        mesh = trimesh.Trimesh(vertices=mc_vertices, faces=mc_faces)
        return pcl.render_mesh(mesh, camera_mag=1)

    def __call__(self, sdf_data: SdfData, axes, column: int):
        """Fill ``axes[row][column]`` for every configured sample count.

        Args:
            sdf_data: The shape to reconstruct.
            axes: 2-D indexable grid of matplotlib axes with at least
                ``len(num_sdf_samples_list)`` rows.
            column: Index of the column to draw into.
        """
        rotation_matrix, translation_vector = get_extrinsics()
        num_rows = len(self.num_sdf_samples_list)

        for row, num_sdf_samples in enumerate(self.num_sdf_samples_list):
            print(f"row {row+1} of {num_rows}", flush=True)
            # Fixed seed re-created per row so each row's subsampling is
            # reproducible and independent of the rows before it.
            random_state = np.random.RandomState(0)
            data = self.minibatch_generator([sdf_data],
                                            num_sdf_samples=num_sdf_samples,
                                            random_state=random_state)
            initial_z = self.get_initial_z()
            z = self.optimize_latent(data, initial_z)
            print("optimize_latent done", flush=True)

            ax = axes[row][column]
            try:
                image = self._render_prediction(z, rotation_matrix,
                                                translation_vector)
                ax.imshow(image)
            except ValueError as error:
                # Best effort: a reconstruction that cannot be meshed (for
                # example, no zero level set in the grid) leaves the cell
                # empty. Report it instead of dropping it silently.
                print(f"row {row+1}: skipping rendering ({error})",
                      flush=True)
            ax.set_xticks([])
            ax.set_yticks([])
            if row == 0:
                ax.set_title("DeepSDF", fontsize=10)