from pathlib import Path
from typing import Callable, List

import numpy as np
import torch
import trimesh

from meta_learning_sdf import load_model
from meta_learning_sdf import point_cloud as pcl
from meta_learning_sdf.datasets.shapenet.uniform_sparse_sampling import (
    MeshData, MinibatchGenerator, UniformPointCloudData)
from meta_learning_sdf.experiments.baseline import (
    AdaptiveLatentOptimization, LossFunction, MarchingCubes, Model,
    ModelHyperparameters, TrainingHyperparameters, setup_model,
    LatentOptimizationHyperparameters)

from ..args import ProposedMethodArgs
from ..geometry import get_extrinsics


def setup_plot_function(args: ProposedMethodArgs, device: str,
                        num_point_samples_list: List[int],
                        latent_optimization_iterations: int,
                        latent_optimization_initial_lr: float,
                        grid_size: int) -> "Plot":
    """Build a ready-to-call ``Plot`` from a saved training checkpoint.

    Hyperparameters are read from ``args.json`` in the checkpoint directory
    and the weights from ``model.pt`` (or ``model.<epoch>.pt`` when
    ``args.checkpoint_epoch`` is set). The model is moved to ``device`` and
    put in evaluation mode.

    Args:
        args: Provides ``checkpoint_directory`` and ``checkpoint_epoch``.
        device: Device the model, minibatches and initial latent are put on.
        num_point_samples_list: Number of input points per plotted row.
        latent_optimization_iterations: Iterations of latent optimization;
            also used as ``decrease_lr_every``.
        latent_optimization_initial_lr: Initial learning rate of the latent
            optimization.
        grid_size: Grid size handed to ``MarchingCubes``.

    Returns:
        The configured ``Plot`` instance.

    Raises:
        FileNotFoundError: If ``args.json`` or the model file is missing.
    """
    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"

    # Explicit checks instead of ``assert``: assertions vanish under
    # ``python -O`` and give no hint about which file is missing.
    if not args_path.is_file():
        raise FileNotFoundError(
            f"Checkpoint arguments file not found: {args_path}")
    if not model_path.is_file():
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

    # Both hyperparameter sets are read from the same JSON file.
    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    minibatch_generator = MinibatchGenerator(device=device)

    def kld_weight_func():
        # Always use the final KLD weight stored with the checkpoint.
        return training_hyperparams.loss_kld_final_weight

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    # NOTE(review): num_samples=-1 presumably means "use every available
    # point" -- confirm against LatentOptimizationHyperparameters.
    params = LatentOptimizationHyperparameters(
        num_samples=-1,
        initial_lr=latent_optimization_initial_lr,
        decrease_lr_every=latent_optimization_iterations,
        iterations=latent_optimization_iterations)
    optimize_latent = AdaptiveLatentOptimization(model=model,
                                                 loss_function=loss_function,
                                                 params=params)

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    def initial_z_func():
        # Small Gaussian latent of shape (1, z_dim). It is sampled first and
        # moved to the device afterwards, as in the original code.
        return torch.normal(mean=0,
                            std=0.01,
                            size=(1, model_hyperparams.z_dim),
                            dtype=torch.float32).to(device)

    plot_func = Plot(model=model,
                     minibatch_generator=minibatch_generator,
                     num_point_samples_list=num_point_samples_list,
                     initial_z_func=initial_z_func,
                     optimize_latent_func=optimize_latent,
                     marching_cubes_func=marching_cubes)

    return plot_func


class Plot:
    """Draw input point clouds and the meshes reconstructed from them.

    One call fills one figure column. For every entry of
    ``num_point_samples_list`` it samples that many context points from the
    given shape, renders them into column 0, encodes and optimizes a latent
    code, and renders the mesh extracted from that code into the requested
    column.
    """

    def __init__(self,
                 model: Model,
                 minibatch_generator: MinibatchGenerator,
                 num_point_samples_list: List[int],
                 initial_z_func: Callable[[], torch.Tensor],
                 optimize_latent_func: Callable[..., torch.Tensor],
                 marching_cubes_func: Callable[[torch.Tensor], tuple],
                 label: str = "Proposed Method"):
        """Store the collaborators used by ``__call__``.

        Args:
            model: Network whose ``encoder`` produces the initial latent.
            minibatch_generator: Samples context points from a point cloud.
            num_point_samples_list: Number of context points for each row.
            initial_z_func: Factory for an initial latent. It is stored but
                not used by ``__call__``, which starts from the encoder
                output.
            optimize_latent_func: Called as ``f(minibatch, latent)`` and
                returns the refined latent.
            marching_cubes_func: Called as ``f(latent)`` and returns a
                ``(vertices, faces)`` pair.
            label: Title placed above the prediction column.
        """
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.num_point_samples_list = num_point_samples_list
        self.get_initial_z = initial_z_func
        self.optimize_latent = optimize_latent_func
        self.marching_cubes = marching_cubes_func
        self.label = label

    @staticmethod
    def _cmap_binary(points: np.ndarray) -> np.ndarray:
        """Return grey RGB colors derived from each point's first coordinate.

        The first column is divided by its largest absolute value and negated,
        then mapped linearly to an intensity in ``[0, 0.3]`` that is repeated
        over the three color channels.

        The input is never modified. ``points`` may share memory with the
        tensor that is later fed to the encoder, so scaling it in place would
        corrupt the model input.
        """
        x = points[:, 0]
        max_abs = np.max(np.abs(x))
        if max_abs > 0:
            # Out-of-place multiply: allocates a new array, leaves the view
            # into ``points`` untouched.
            x = x * (-1 / max_abs)
        else:
            # All first coordinates are zero: avoid 0 * inf = NaN colors.
            x = np.zeros_like(x)
        intensity = 0.3 * (x + 1) / 2
        return np.repeat(intensity[:, None], 3, axis=1)

    def __call__(self, pc_data: UniformPointCloudData, axes, column: int):
        """Fill column 0 and ``column`` of ``axes`` for one shape.

        Args:
            pc_data: Point cloud the context points are sampled from.
            axes: Grid indexed as ``axes[row][col]``; presumably matplotlib
                Axes, given the ``imshow`` / ``set_title`` calls.
            column: Column that receives the rendered reconstruction.
        """
        rotation_matrix, translation_vector = get_extrinsics()
        num_rows = len(self.num_point_samples_list)

        for row, num_point_samples in enumerate(self.num_point_samples_list):
            print(f"row {row+1} of {num_rows}", flush=True)
            # Re-seeded on every row so each row draws from the same random
            # stream.
            random_state = np.random.RandomState(0)
            minibatch = self.minibatch_generator(
                [pc_data],
                min_num_context=num_point_samples,
                max_num_context=num_point_samples,
                min_num_target=0,
                max_num_target=0,
                tuple_target=False,
                random_state=random_state)
            minibatch.context_normals_list = []

            # input point cloud
            input_points = minibatch.context_points_list[0][0].detach().cpu(
            ).numpy()
            input_points = input_points.reshape((-1, 3))
            # Points are rendered as sampled; the inverse of
            # pc_data.scale / pc_data.offset is not applied here.
            points = (rotation_matrix @ input_points.T).T + translation_vector
            colors = self._cmap_binary(input_points)
            image = pcl.render_point_cloud(points,
                                           colors,
                                           camera_mag=1,
                                           point_size=10)
            input_axis = axes[row][0]
            input_axis.imshow(image)
            input_axis.set_xticks([])
            input_axis.set_yticks([])
            # The sample count is drawn as rotated text, acting as a row
            # label.
            input_axis.text(0, 320, str(num_point_samples), rotation=90)

            if row == 0:
                input_axis.set_title("Input points", fontsize=10)

            # make prediction
            with torch.no_grad():
                h, _ = self.model.encoder(minibatch.context_points_list[0])
            h = self.optimize_latent(minibatch, h)
            print("optimize_latent done.", flush=True)

            prediction_axis = axes[row][column]
            try:
                mc_vertices, mc_faces = self.marching_cubes(h)
                print("marching_cubes done.", flush=True)
                # plot prediction (same coordinates as the input points)
                mc_vertices = (
                    rotation_matrix @ mc_vertices.T).T + translation_vector
                mesh = trimesh.Trimesh(vertices=mc_vertices, faces=mc_faces)
                image = pcl.render_mesh(mesh, camera_mag=1)
                prediction_axis.imshow(image)
            except ValueError as error:
                # Best effort: leave the panel empty, but report why instead
                # of hiding the failure.
                print(f"row {row+1}: mesh could not be rendered ({error})",
                      flush=True)
            prediction_axis.set_xticks([])
            prediction_axis.set_yticks([])
            if row == 0:
                prediction_axis.set_title(self.label, fontsize=10)
