import math
from pathlib import Path
from typing import Callable, List

import numpy as np
import torch
import trimesh

from pcn import load_model
from pcn import point_cloud as pcl
from pcn.datasets.uniform_sparse_sampling import (MinibatchGenerator)
from pcn.experiment import (Model, ModelHyperparameters,
                            TrainingHyperparameters, setup_model)
from pcn.point_cloud import render_point_cloud

from ..args import PcnArgs
from ..geometry import get_extrinsics


def setup_plot_function(args: PcnArgs, device: str,
                        num_point_samples_list: List[int]):
    """Build a ``Plot`` callable backed by a saved PCN checkpoint.

    The checkpoint directory must contain ``args.json`` (used to load the
    model hyperparameters) and the weights file: ``model.pt`` when
    ``args.checkpoint_epoch`` is ``None``, otherwise
    ``model.<checkpoint_epoch>.pt``.

    Args:
        args: Provides ``checkpoint_directory`` and ``checkpoint_epoch``.
        device: Device passed to ``model.to`` and to the minibatch generator.
        num_point_samples_list: Forwarded unchanged to ``Plot``; one plot row
            is produced per entry.

    Returns:
        A ``Plot`` instance wrapping the loaded model (in eval mode).

    Raises:
        FileNotFoundError: If ``args.json`` or the weights file is missing.
    """
    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_filename = "model.pt"
    else:
        model_filename = f"model.{args.checkpoint_epoch}.pt"
    model_path = checkpoint_directory / model_filename

    # Explicit checks instead of `assert`: assertions are removed under
    # `python -O`, which would defer the failure to a less obvious place.
    for required_path in (args_path, model_path):
        if not required_path.is_file():
            raise FileNotFoundError(
                f"Required checkpoint file not found: {required_path}")

    model_hyperparams = ModelHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    minibatch_generator = MinibatchGenerator(device=device)

    return Plot(model=model,
                minibatch_generator=minibatch_generator,
                num_point_samples_list=num_point_samples_list)


def _cmap_binary(points: np.ndarray):
    points = points.copy()
    x = points[:, 0]
    scale = 1 / np.max(np.abs(x))
    x *= -scale
    intensity = 0.3 * (x + 1) / 2
    rgb = np.repeat(intensity[:, None], 3, axis=1)
    return rgb


class Plot:
    """Callable that draws PCN coarse predictions into one column of axes.

    For each entry of ``num_point_samples_list`` a minibatch is generated
    from the given point cloud with that many input points, the model is run
    on it, and the coarse prediction is rendered into ``axes[row][column]``.
    """

    def __init__(self, model: Model, minibatch_generator: MinibatchGenerator,
                 num_point_samples_list: List[int]):
        """Store the model, the minibatch generator and the per-row counts.

        Args:
            model: Model called on ``minibatch.input_points``; expected to
                return a ``(coarse, dense)`` pair of predictions.
            minibatch_generator: Called with a list of point-cloud data plus
                ``num_input_points`` and ``random_state`` keyword arguments.
            num_point_samples_list: Number of input points for each row.
        """
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.num_point_samples_list = num_point_samples_list

    def __call__(self, pc_data, axes, column: int):
        """Render one image per configured point count into ``column``.

        Args:
            pc_data: Point-cloud data handed to the minibatch generator.
            axes: 2-D indexable of axes objects supporting ``imshow``,
                ``set_xticks``, ``set_yticks`` and ``set_title``
                (presumably matplotlib axes — confirm against the caller).
            column: Column of ``axes`` to draw into.
        """
        rotation_matrix, translation_vector = get_extrinsics()
        num_rows = len(self.num_point_samples_list)

        for row, num_point_samples in enumerate(self.num_point_samples_list):
            print(f"row {row+1} of {num_rows}", flush=True)

            # A fresh, identically seeded RNG per row keeps the sampling
            # reproducible regardless of how many rows are drawn.
            random_state = np.random.RandomState(0)
            minibatch = self.minibatch_generator(
                [pc_data],
                num_input_points=num_point_samples,
                random_state=random_state)

            # Inference only: skip building the autograd graph. The dense
            # prediction is not used for this plot.
            with torch.no_grad():
                pred_coarse_points, _ = self.model(minibatch.input_points)

            # Take the first (only) batch element as an (N, 3) numpy array.
            coarse_points = pred_coarse_points[0].detach().cpu().numpy()
            coarse_points = coarse_points.reshape((-1, 3))

            # Apply the extrinsics for rendering; colors are computed from
            # the untransformed prediction.
            points = (rotation_matrix @ coarse_points.T).T + translation_vector
            colors = _cmap_binary(coarse_points)
            image = render_point_cloud(points,
                                       colors,
                                       camera_mag=1,
                                       point_size=6)

            ax = axes[row][column]
            ax.imshow(image)
            ax.set_xticks([])
            ax.set_yticks([])
            if row == 0:
                ax.set_title("PCN", fontsize=10)
