from dataclasses import dataclass
from typing import List

import numpy as np
import torch
import math
from skimage import measure

from ..classes import Hyperparameters
from ..datasets.classes import Minibatch
from ..models import Decoder, DecoderHyperparameters
from ..training import (Events, LossFunctionInterface, LossFunctionOutput,
                        Trainer)


@dataclass
class TrainingHyperparameters(Hyperparameters):
    """Optimisation settings used when fitting the SDF decoder.

    Field order matters for positional construction and is kept stable.
    """

    # Minibatch size.
    batchsize: int = 1
    # Number of SDF samples (presumably per shape; a 128 x 128 budget).
    num_sdf_samples: int = 128**2
    # Optimiser step size.
    learning_rate: float = 0.0001
    # Clamping band for distances; presumably handed to
    # LossFunction(clamping_distance=...) by the caller.
    clamping_distance: float = 0.1


def _parse_hidden_sizes_str(sizes_str: str) -> List[int]:
    assert sizes_str.startswith(",") is False
    assert sizes_str.endswith(",") is False
    sizes = sizes_str.split(",")
    sizes = [int(size) for size in sizes]
    return sizes


def define_decoder(hyperparams: DecoderHyperparameters):
    """Build the ``Decoder`` described by *hyperparams*.

    The network takes 3 inputs and produces 1 output. Hidden layer widths
    come from the comma-separated ``hyperparams.hidden_sizes`` string.
    ``weight_norm`` and ``dropout_prob`` are forwarded unchanged, and a
    factory returning ``torch.nn.ReLU`` modules is supplied as the
    activation.
    """
    hidden_layer_sizes = _parse_hidden_sizes_str(hyperparams.hidden_sizes)

    def make_relu():
        # Decoder receives a factory rather than a module instance,
        # presumably so each layer can own a fresh activation module.
        return torch.nn.ReLU()

    return Decoder(
        input_size=3,
        output_size=1,
        hidden_sizes=hidden_layer_sizes,
        activation_func_module=make_relu,
        weight_norm=hyperparams.weight_norm,
        dropout_prob=hyperparams.dropout_prob,
    )


class LossFunction(LossFunctionInterface):
    """Clamped L1 loss between predicted and ground-truth signed distances.

    Both the decoder output and the targets are clamped to
    ``[-clamping_distance, clamping_distance]`` before the absolute
    difference is taken. Where prediction and target both lie beyond the
    same side of the band, they contribute nothing.
    """

    def __init__(self, clamping_distance: float):
        self.clamping_distance = clamping_distance
        # Summed L1; normalised manually in __call__.
        self.loss_l1 = torch.nn.L1Loss(reduction="sum")

    def __call__(self, decoder: Decoder, data: Minibatch):
        # Assumes data.points is (batch, num_samples, 3) and data.distances
        # is (batch, num_samples) -- TODO confirm against Minibatch.
        num_sdf_samples = data.distances.shape[1]
        delta = self.clamping_distance

        # Decoder output carries a trailing singleton axis (dim 2).
        pred_distances = decoder(data.points).squeeze(dim=2)
        pred_distances = pred_distances.clamp(-delta, delta)
        # Clamp targets as well. Otherwise a target beyond +-delta leaves a
        # constant, gradient-free residual (|target| - delta) that inflates
        # the reported loss. This is a no-op if targets are already clamped.
        gt_distances = data.distances.clamp(-delta, delta)

        # NOTE: the sum covers every element in the minibatch but is divided
        # by shape[1] only, so the loss value scales with the batch size.
        loss = self.loss_l1(pred_distances, gt_distances) / num_sdf_samples
        return LossFunctionOutput(loss=loss)


def log_message(run_id: str, decoder: Decoder, trainer: Trainer):
    """Format a one-line progress summary for the trainer's current epoch.

    Args:
        run_id: Identifier shown in brackets at the start of the message.
        decoder: Not used here; kept so the signature stays unchanged for
            callers.
        trainer: Supplies the epoch, ``max_epochs``, moving-average
            ``loss``, gradient-update count and elapsed seconds.

    Returns:
        A string such as
        ``"[run] Epoch: 3 (30.00%) - loss: 0.0123 - #grad updates: 300 -
        elapsed_time: 2 min"``.
    """
    state = trainer.state
    epoch = state.epoch
    metrics = state.moving_average.metrics
    # A falsy max_epochs (0 or None) would raise on division; report 0% then.
    if state.max_epochs:
        progress = epoch / state.max_epochs * 100
    else:
        progress = 0.0
    elapsed_minutes = int(state.elapsed_seconds / 60)
    return " - ".join([
        f"[{run_id}] Epoch: {epoch:d} ({progress:.2f}%)",
        f"loss: {metrics['loss']:.4f}",
        f"#grad updates: {state.num_gradient_updates:d}",
        f"elapsed_time: {elapsed_minutes} min",
    ])


def _split_array(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


class MarchingCubes:
    """Extract a triangle mesh from the zero level set of an SDF decoder.

    The decoder is evaluated on a regular ``grid_size ** 3`` lattice spanning
    ``[grid_min_value, grid_max_value]`` along every axis. scikit-image's
    Lewiner marching cubes is then run on the resulting volume. Returned
    vertices are in the same coordinates the decoder was queried with.
    """

    def __init__(self, decoder: Decoder, grid_size: int, grid_min_value: float,
                 grid_max_value: float):
        # Explicit errors rather than ``assert`` so they survive ``python -O``.
        if grid_size < 2:
            raise ValueError(f"grid_size must be at least 2, got {grid_size}")
        if not grid_max_value > grid_min_value:
            raise ValueError(
                f"grid_max_value ({grid_max_value}) must be greater than "
                f"grid_min_value ({grid_min_value})")
        self.decoder = decoder
        self.grid_size = grid_size
        self.grid_min_value = grid_min_value
        self.grid_max_value = grid_max_value

    def __call__(self, verbose: bool = True):
        """Evaluate the decoder on the grid and return ``(vertices, faces)``.

        Args:
            verbose: When true (the default), print a ``"<chunk> <total>"``
                progress line for each decoder call.

        Returns:
            ``vertices`` as an (V, 3) float array and ``faces`` as an (F, 3)
            index array, as produced by scikit-image marching cubes.
        """
        volume = self._evaluate_volume(verbose)
        return self._extract_mesh(volume)

    def _evaluate_volume(self, verbose: bool) -> np.ndarray:
        """Query the decoder at every lattice point; return an (N, N, N) volume.

        ``np.meshgrid`` uses its default ``"xy"`` indexing, so
        ``volume[i, j, k]`` holds the value at
        ``(x=axis[j], y=axis[i], z=axis[k])``. ``_extract_mesh`` undoes this
        axis swap.
        """
        # NOTE(review): the decoder is not switched to eval() here; if its
        # dropout (dropout_prob in define_decoder) is active, call
        # decoder.eval() before extracting a mesh.
        device = self.decoder.get_device()
        size = self.grid_size
        axis = np.linspace(self.grid_min_value, self.grid_max_value, size)
        xv, yv, zv = np.meshgrid(axis, axis, axis)
        points = np.stack((xv, yv, zv)).reshape(
            (3, -1)).transpose().astype(np.float32)
        # Evaluate in grid_size calls of N*N points each rather than one
        # call, presumably to bound peak memory.
        chunks = _split_array(points, size)
        values = []
        with torch.no_grad():
            for chunk_index, chunk in enumerate(chunks):
                if verbose:
                    print(chunk_index + 1, len(chunks))
                batch = torch.from_numpy(chunk).to(device)[None, :, :]
                prediction = self.decoder(batch).squeeze(dim=2)[0]
                values.append(prediction.cpu().numpy())
        return np.concatenate(values).reshape((size, size, size))

    def _extract_mesh(self, volume: np.ndarray):
        """Run marching cubes on *volume* and map vertices to world coords."""
        # np.linspace with N points has step (max - min) / (N - 1).
        step = ((self.grid_max_value - self.grid_min_value) /
                (self.grid_size - 1))
        # Index axes are (y, x, z). Negating the second spacing together
        # with the 90 degree z-rotation below maps them back to (x, y, z).
        spacing = (step, -step, step)
        vertices, faces, _normals, _values = self._lewiner_marching_cubes(
            volume, 0.0, spacing)
        # Marching-cubes vertices start at index 0. Shift by grid_min_value
        # (sign follows the spacing) so lattice index 0 lands on the grid's
        # lower bound; this also holds for grids not centred on 0.
        low = self.grid_min_value
        vertices = vertices + np.array((low, -low, low))
        # Exact 90 degree rotation about z: (a, b, c) -> (-b, a, c).
        rotation_matrix = np.array([[0.0, -1.0, 0.0],
                                    [1.0, 0.0, 0.0],
                                    [0.0, 0.0, 1.0]])
        vertices = vertices @ rotation_matrix.T
        return vertices, faces

    @staticmethod
    def _lewiner_marching_cubes(volume, level, spacing):
        """Call scikit-image's Lewiner marching cubes across library versions.

        ``measure.marching_cubes_lewiner`` was deprecated in scikit-image
        0.17 and removed in 0.19 in favour of
        ``measure.marching_cubes(..., method="lewiner")``.
        """
        import inspect

        marching_cubes = getattr(measure, "marching_cubes", None)
        if (marching_cubes is not None and
                "method" in inspect.signature(marching_cubes).parameters):
            return marching_cubes(volume,
                                  level,
                                  spacing=spacing,
                                  gradient_direction="ascent",
                                  method="lewiner")
        return measure.marching_cubes_lewiner(volume,
                                              level,
                                              spacing=spacing,
                                              gradient_direction="ascent")
