import math
from pathlib import Path
from typing import Callable, List

import numpy as np
import torch
import trimesh

from occupancy_networks import load_model
from occupancy_networks import point_cloud as pcl
from occupancy_networks.datasets.surface.uniform_sparse_sampling import (
    MinibatchGenerator)
from occupancy_networks.experiments.encoder import (MarchingCubes, Model,
                                                    ModelHyperparameters,
                                                    TrainingHyperparameters,
                                                    setup_model)

from ..args import OccNetArgs
from ..geometry import get_extrinsics


def setup_plot_function(args: OccNetArgs, device: str,
                        num_point_samples_list: List[int], grid_size: int):
    """Load a trained OccNet checkpoint and wrap it in a :class:`Plot` callable.

    Args:
        args: provides ``checkpoint_directory`` (must contain ``args.json``)
            and ``checkpoint_epoch``; ``None`` selects ``model.pt``, otherwise
            ``model.{checkpoint_epoch}.pt`` is loaded.
        device: torch device string the model and generated minibatches use.
        num_point_samples_list: number of input points sampled for each row
            of the plot.
        grid_size: passed to ``MarchingCubes`` as its grid size; the grid
            bounds are fixed to [-1, 1].

    Returns:
        A ``Plot`` instance bound to the loaded model.

    Raises:
        FileNotFoundError: if ``args.json`` or the selected model file does
            not exist.
    """
    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not args_path.is_file():
        raise FileNotFoundError(f"checkpoint args file not found: {args_path}")
    if not model_path.is_file():
        raise FileNotFoundError(f"model checkpoint not found: {model_path}")

    # Both hyperparameter sets are stored in the same args.json.
    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    # noise_stddev=0: plots use the sampled points without added noise.
    minibatch_generator = MinibatchGenerator(
        num_input_points=training_hyperparams.num_input_points,
        num_gt_points=training_hyperparams.num_gt_points,
        noise_stddev=0,
        device=device)

    grid_max_value = 1
    grid_min_value = -1
    marching_cubes = MarchingCubes(model=model,
                                   grid_size=grid_size,
                                   grid_max_value=grid_max_value,
                                   grid_min_value=grid_min_value)

    return Plot(model=model,
                minibatch_generator=minibatch_generator,
                num_point_samples_list=num_point_samples_list,
                marching_cubes_func=marching_cubes)


class Plot:
    """Render sampled input points and OccNet reconstructions into an axes grid.

    Each row corresponds to one entry of ``num_point_samples_list``:
    ``axes[row][0]`` shows the sampled input points and
    ``axes[row][column]`` shows the mesh extracted by marching cubes.
    """

    def __init__(self, model: Model, minibatch_generator: MinibatchGenerator,
                 num_point_samples_list: List[int],
                 marching_cubes_func: Callable[[torch.Tensor], tuple]):
        """
        Args:
            model: model whose ``encode_inputs`` is applied to the sampled
                input points.
            minibatch_generator: callable that samples input points from a
                list of point-cloud records.
            num_point_samples_list: number of input points sampled per row.
            marching_cubes_func: maps the encoding returned by
                ``model.encode_inputs`` to a ``(vertices, faces)`` pair.
        """
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.num_point_samples_list = num_point_samples_list
        self.marching_cubes = marching_cubes_func

    @staticmethod
    def _cmap_binary(points: np.ndarray) -> np.ndarray:
        """Return an ``(N, 3)`` gray RGB array shaded by each point's x value.

        x is normalized by its largest magnitude and negated, then mapped to
        an intensity in [0, 0.3]: the most positive x is darkest (0), the
        most negative x is lightest (0.3).

        ``points`` is never modified. The input may share memory with the
        model's input tensor: ``.numpy()`` on a CPU tensor does not copy.
        """
        x = points[:, 0]
        max_abs = np.max(np.abs(x)) if x.size else 0
        if max_abs > 0:
            # Out-of-place: the old `x *= -scale` wrote through the view.
            normalized = x * -(1 / max_abs)
        else:
            # All x are zero (or no points): avoid 0 * inf -> NaN colors.
            normalized = np.zeros_like(x)
        intensity = 0.3 * (normalized + 1) / 2
        return np.repeat(intensity[:, None], 3, axis=1)

    def __call__(self, pc_data, axes, column: int):
        """Fill column 0 and column ``column`` of ``axes`` for ``pc_data``.

        Args:
            pc_data: point-cloud record, passed as a one-element list to the
                minibatch generator.
            axes: 2-D indexable grid of matplotlib-style axes with at least
                ``len(num_point_samples_list)`` rows.
            column: column index that receives the reconstructed mesh.
        """
        rotation_matrix, translation_vector = get_extrinsics()
        num_rows = len(self.num_point_samples_list)

        for row, num_point_samples in enumerate(self.num_point_samples_list):
            print(f"row {row+1} of {num_rows}", flush=True)
            # Fresh seed per row so the sampling is reproducible across calls.
            random_state = np.random.RandomState(0)
            minibatch = self.minibatch_generator(
                [pc_data],
                num_input_points=num_point_samples,
                num_gt_points=0,
                random_state=random_state)

            # Input point cloud. NOTE(review): points are plotted in the
            # generator's coordinate frame; the pc_data.scale/offset inverse
            # transform was previously left disabled.
            input_points = minibatch.input_points[0].detach().cpu().numpy()
            input_points = input_points.reshape((-1, 3))
            points = (rotation_matrix @ input_points.T).T + translation_vector
            colors = self._cmap_binary(input_points)
            image = pcl.render_point_cloud(points,
                                           colors,
                                           camera_mag=1,
                                           point_size=10)
            input_ax = axes[row][0]
            input_ax.imshow(image)
            input_ax.set_xticks([])
            input_ax.set_yticks([])
            input_ax.set_ylabel(num_point_samples)
            if row == 0:
                input_ax.set_title("Input points", fontsize=10)

            output_ax = axes[row][column]
            try:
                with torch.no_grad():
                    code = self.model.encode_inputs(minibatch.input_points)
                mc_vertices, mc_faces = self.marching_cubes(code)
                print("marching_cubes done.", flush=True)
                # Plot prediction in the same camera frame as the inputs.
                mc_vertices = (
                    rotation_matrix @ mc_vertices.T).T + translation_vector
                mesh = trimesh.Trimesh(vertices=mc_vertices, faces=mc_faces)
                image = pcl.render_mesh(mesh, camera_mag=1)
                output_ax.imshow(image)
            except ValueError as error:
                # Best-effort: leave the cell empty, but say why. Presumably
                # raised when marching cubes finds no surface — confirm.
                print(f"row {row+1}: reconstruction skipped ({error})",
                      flush=True)
            output_ax.set_xticks([])
            output_ax.set_yticks([])
            if row == 0:
                output_ax.set_title("OccNet", fontsize=10)
