import math
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, List

import numpy as np
import torch
import torch.optim as optim
from skimage import measure
from torch.optim.lr_scheduler import LambdaLR

from .. import models, training
from ..classes import Hyperparameters, ModelInterface
from ..datasets.classes import Minibatch
from ..models import OccupancyNetwork as Model
from ..training import (Events, LossFunctionInterface, LossFunctionOutput,
                        Trainer)


@dataclass
class TrainingHyperparameters(Hyperparameters):
    """Default values for the training-related hyperparameters.

    None of these fields is read in the visible part of this module, so the
    per-field notes below are inferred from the field names only.
    NOTE(review): confirm each meaning against the code that consumes them.
    """
    # presumably the number of samples per minibatch
    batchsize: int = 32
    # presumably how many input points are fed to the encoder per sample
    num_input_points: int = 300
    # presumably the std-dev of noise added to the input points
    input_points_noise_stddev: float = 0.005
    # presumably how many ground-truth query points are used per sample
    num_gt_points: int = 2048
    # presumably the optimizer step size
    learning_rate: float = 1e-4


@dataclass
class ModelHyperparameters(Hyperparameters):
    """Default sizes for the encoder and decoder built by ``setup_model``."""
    # forwarded as ``c_dim`` to ``models.decoder.DecoderBatchNorm``
    decoder_c_dim: int = 256
    # forwarded as ``hidden_size`` to ``models.decoder.DecoderBatchNorm``
    decoder_hidden_size: int = 256
    # forwarded as ``c_dim`` to ``models.pointnet.ResnetPointnet``
    encoder_c_dim: int = 256
    # forwarded as ``g_dim`` to ``models.pointnet.ResnetPointnet``
    encoder_g_dim: int = 256
    # forwarded as ``hidden_dim`` to ``models.pointnet.ResnetPointnet``
    encoder_hidden_dim: int = 256


def _parse_hidden_sizes_str(sizes_str: str) -> List[int]:
    assert sizes_str.startswith(",") is False
    assert sizes_str.endswith(",") is False
    sizes = sizes_str.split(",")
    sizes = [int(size) for size in sizes]
    return sizes


def setup_model(hyperparams: ModelHyperparameters):
    """Assemble the encoder/decoder pair described by ``hyperparams``.

    Args:
        hyperparams: Supplies the ``encoder_*`` and ``decoder_*`` sizes that
            are forwarded to the two sub-networks.

    Returns:
        A ``Model`` wrapping a ``ResnetPointnet`` encoder and a
        ``DecoderBatchNorm`` decoder (the decoder is created with
        ``z_dim=0``).
    """
    # The encoder is constructed first, then the decoder.
    encoder_kwargs = {
        "c_dim": hyperparams.encoder_c_dim,
        "g_dim": hyperparams.encoder_g_dim,
        "hidden_dim": hyperparams.encoder_hidden_dim,
    }
    point_encoder = models.pointnet.ResnetPointnet(**encoder_kwargs)

    decoder_kwargs = {
        "z_dim": 0,
        "c_dim": hyperparams.decoder_c_dim,
        "hidden_size": hyperparams.decoder_hidden_size,
    }
    occupancy_decoder = models.decoder.DecoderBatchNorm(**decoder_kwargs)

    return Model(encoder=point_encoder, decoder=occupancy_decoder)


class LossFunction(LossFunctionInterface):
    """Binary cross-entropy loss between predicted and true occupancies."""

    def __call__(self, model: Model, data: Minibatch):
        """Evaluate the loss of ``model`` on one minibatch.

        Args:
            model: Network providing ``encode_inputs`` and ``decode``.
            data: Minibatch providing ``input_points``, ``gt_points`` and
                ``gt_occupancies``.

        Returns:
            A ``LossFunctionOutput`` whose ``loss`` is the element-wise BCE
            summed over the last axis and then averaged over what remains.
        """
        bce_with_logits = torch.nn.functional.binary_cross_entropy_with_logits

        latent_code = model.encode_inputs(data.input_points)
        prediction = model.decode(data.gt_points, z=None, c=latent_code)

        # Keep the un-reduced losses so the reduction below is explicit.
        elementwise = bce_with_logits(
            prediction.logits, data.gt_occupancies, reduction='none')
        summed_over_last_axis = elementwise.sum(-1)
        total = summed_over_last_axis.mean()

        return LossFunctionOutput(loss=total)


def log_message(run_id: str, model: Model, trainer: Trainer):
    """Build the one-line progress summary for the current training state.

    Args:
        run_id: Identifier printed in square brackets at the start.
        model: Not used by this function; kept so the call signature is
            unchanged for existing callers.
        trainer: Trainer whose ``state`` supplies the epoch, the moving
            average metrics, the gradient-update count and the timings.

    Returns:
        The individual fields joined with ``" - "``.
    """
    state = trainer.state
    epoch = state.epoch
    metrics = state.moving_average.metrics
    progress = epoch / state.max_epochs * 100
    # Elapsed time is reported in whole minutes (truncated, not rounded).
    elapsed_minutes = int(state.elapsed_seconds / 60)
    return " - ".join([
        f"[{run_id}] Epoch: {epoch:d} ({progress:.2f}%)",
        f"loss: {metrics['loss']:.4f}",
        f"#grad updates: {state.num_gradient_updates:d}",
        f"data_load_time: {state.data_load_time:.4f}",
        f"forward_time: {state.forward_time:.4f}",
        f"backward_time: {state.backward_time:.4f}",
        f"elapsed_time: {elapsed_minutes} min",
    ])


def log_loss(trainer: Trainer, csv_path: str):
    """Append the current epoch's moving-average metrics to a CSV file.

    If ``csv_path`` is not an existing file it is created and a header row
    (``epoch`` followed by the sorted metric names) is written first.
    Otherwise a single data row is appended.

    Args:
        trainer: Trainer whose ``state`` supplies ``epoch`` and
            ``moving_average.metrics``.
        csv_path: Location of the CSV file.
    """
    csv_path = Path(csv_path)
    epoch = trainer.state.epoch
    metrics = trainer.state.moving_average.metrics
    metric_names = sorted(metrics.keys())

    # Build all the text before touching the file so that a failure while
    # stringifying a metric cannot leave a half-written or unclosed file.
    is_new_file = not csv_path.is_file()
    lines = []
    if is_new_file:
        lines.append(",".join(["epoch"] + metric_names) + "\n")
    values = [str(epoch)] + [str(metrics[key]) for key in metric_names]
    lines.append(",".join(values) + "\n")

    # The context manager guarantees the handle is closed even on error.
    with open(csv_path, "w" if is_new_file else "a") as f:
        f.writelines(lines)


def _split_array(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


class MarchingCubes:
    """Extract a triangle mesh from the model's predicted occupancy.

    The model is evaluated on a regular ``grid_size``^3 lattice spanning
    ``[grid_min_value, grid_max_value]`` on every axis, and the surface
    where the predicted probability equals 0.5 is meshed.
    """

    def __init__(self, model: Model, grid_size: int, grid_min_value: float,
                 grid_max_value: float):
        """Store the model and the sampling grid description.

        Args:
            model: Network providing ``get_device`` and ``decode``.
            grid_size: Number of samples along each axis.
            grid_min_value: Smallest coordinate sampled on each axis.
            grid_max_value: Largest coordinate sampled on each axis; must
                be greater than ``grid_min_value``.
        """
        self.model = model
        self.grid_size = grid_size
        assert grid_max_value > grid_min_value
        self.grid_min_value = grid_min_value
        self.grid_max_value = grid_max_value

    def __call__(self, c: torch.Tensor):
        """Run the decoder on the grid and mesh the resulting volume.

        Args:
            c: Conditioning code forwarded unchanged to ``model.decode``.
                Only the first batch entry of each prediction is used.
                NOTE(review): presumably ``c`` has batch size 1 - confirm.

        Returns:
            A ``(vertices, faces)`` pair: the marching-cubes vertices after
            translation and a 90-degree rotation about the third axis, and
            the triangle index array as returned by scikit-image.
        """
        device = self.model.get_device()

        # make prediction
        linspace_size = self.grid_max_value - self.grid_min_value
        # NOTE(review): ``np.linspace`` with ``grid_size`` samples has
        # ``grid_size - 1`` intervals, while this divides by ``grid_size``.
        # Left unchanged because it alters the output scale - confirm intent.
        voxel_size = linspace_size / self.grid_size
        axis = np.linspace(self.grid_min_value, self.grid_max_value,
                           self.grid_size)
        xv, yv, zv = np.meshgrid(axis, axis, axis)
        # One (x, y, z) row per lattice point.
        points = np.stack((xv, yv, zv)).reshape(
            (3, -1)).transpose().astype(np.float32)

        # Decode in ``grid_size`` chunks rather than in one forward pass.
        f_list = []
        with torch.no_grad():
            for chunk in _split_array(points, self.grid_size):
                # Add a leading batch axis of size 1 for the decoder.
                batch = torch.from_numpy(chunk).to(device)[None, :, :]
                p = self.model.decode(batch, z=None, c=c)
                # Shift so the probability 0.5 surface sits at level 0.
                f = -(p.probs - 0.5)
                f_list.append(f.cpu().numpy()[0])
        volume = np.concatenate(f_list).reshape(
            (self.grid_size, self.grid_size, self.grid_size))

        spacing = (voxel_size, -voxel_size, voxel_size)
        # NOTE(review): this offset matches the grid origin only when the
        # grid is centred on zero (grid_min_value == -grid_max_value).
        vertex_translation = (linspace_size / 2, -linspace_size / 2,
                              linspace_size / 2)

        # ``marching_cubes_lewiner`` was removed in scikit-image 0.19. Use it
        # where it still exists (identical behaviour on older releases) and
        # fall back to ``marching_cubes``, whose default method is Lewiner.
        marching_cubes = getattr(measure, "marching_cubes_lewiner", None)
        if marching_cubes is None:
            marching_cubes = measure.marching_cubes
        vertices, faces, _normals, _values = marching_cubes(
            volume, 0.0, spacing=spacing, gradient_direction="ascent")

        vertices -= vertex_translation
        # Rotation by pi/2 about the third axis.
        angle = math.pi / 2
        rotation_matrix = np.array(
            [[math.cos(angle), -math.sin(angle), 0],
             [math.sin(angle), math.cos(angle), 0],
             [0, 0, 1]])
        vertices = vertices @ rotation_matrix.T

        return vertices, faces
