from dataclasses import dataclass
from typing import List

import numpy as np
import torch
import math
from skimage import measure

from ..classes import Hyperparameters
from ..datasets.classes import Minibatch
from ..models import Decoder, DecoderHyperparameters
from ..training import (Events, LossFunctionInterface, LossFunctionOutput,
                        Trainer)


@dataclass
class TrainingHyperparameters(Hyperparameters):
    """Training-time settings for fitting the decoder.

    NOTE(review): how these fields are consumed is not visible in this file;
    the per-field comments below are inferred from names and from the
    matching ``LossFunction`` parameters -- confirm against the training
    entry point.
    """
    batchsize: int = 1
    # presumably the number of surface points sampled per shape per step
    num_surface_samples: int = 128 * 128
    # presumably forwarded to LossFunction(num_eikonal_samples=...)
    num_eikonal_samples: int = 500 * 1000
    # presumably toggles loading/using surface normals (LossFunction only
    # adds its normals term when the minibatch carries normals)
    with_normal: bool = False
    learning_rate: float = 1e-4
    # presumably LossFunction(lam=...): weight of the Eikonal term
    loss_lambda: float = 0.1
    # presumably LossFunction(tau=...): weight of the normals term
    loss_tau: float = 1


def _parse_hidden_sizes_str(sizes_str: str) -> List[int]:
    assert sizes_str.startswith(",") is False
    assert sizes_str.endswith(",") is False
    sizes = sizes_str.split(",")
    sizes = [int(size) for size in sizes]
    return sizes


def setup_decoder(hyperparams: DecoderHyperparameters):
    """Build a ``Decoder`` mapping 3-D points to a single scalar.

    Hidden widths are parsed from the comma-separated
    ``hyperparams.hidden_sizes`` string. The decoder receives a zero-argument
    factory that returns a new ``torch.nn.Softplus`` whose ``beta`` is read
    from ``hyperparams.softplus_beta`` at the time the factory is called.
    """
    hidden_layer_widths = _parse_hidden_sizes_str(hyperparams.hidden_sizes)
    make_activation = lambda: torch.nn.Softplus(  # noqa: E731
        beta=hyperparams.softplus_beta)
    return Decoder(
        input_size=3,
        output_size=1,
        hidden_sizes=hidden_layer_widths,
        activation_func_module=make_activation,
    )


class LossFunction(LossFunctionInterface):
    """Surface-fitting loss for an implicit decoder with an Eikonal regularizer.

    For every shape in the minibatch the loss is::

        mean_i( |f(x_i)| + tau * ||grad f(x_i) - n_i|| )     # normals part only
                                                            # if normals given
        + lam * mean_j( (||grad f(y_j)|| - 1) ** 2 )

    where ``x_i`` are the input points, ``n_i`` their normals and ``y_j`` are
    samples from a Gaussian centred on randomly picked input points. The
    per-shape losses are then averaged. NOTE(review): this appears to follow
    the IGR formulation -- confirm against the reference if it matters.
    """

    def __init__(self,
                 tau: float,
                 lam: float,
                 num_eikonal_samples: int,
                 eikonal_term_default_stddev: float = 0.02):
        """
        Args:
            tau: Weight of the normals term.
            lam: Weight of the Eikonal term.
            num_eikonal_samples: Number of Gaussian samples drawn per shape
                for the Eikonal term (indices drawn with replacement).
            eikonal_term_default_stddev: Gaussian std-dev used when the
                minibatch has no ``kth_nn_distances``.
        """
        self.tau, self.lam = tau, lam
        self.num_eikonal_samples = num_eikonal_samples
        self.eikonal_term_default_stddev = eikonal_term_default_stddev

    @staticmethod
    def _evaluate_with_gradient(decoder, inputs):
        """Return ``(f(inputs), df/dinputs)``, keeping the graph for backprop."""
        inputs.requires_grad_(True)
        values = decoder(inputs).squeeze(dim=2)
        # create_graph=True so the gradient itself can be differentiated
        (gradient,) = torch.autograd.grad(values.sum(),
                                          inputs,
                                          create_graph=True)
        return values, gradient

    def __call__(self, decoder: Decoder, minibatch: Minibatch):
        points = minibatch.points
        f_surface, grad_surface = self._evaluate_with_gradient(decoder, points)

        # pushes f towards zero on the input points
        loss_f_term = f_surface.abs()
        if minibatch.normals is None:
            loss_normals_term = torch.scalar_tensor(0).to(loss_f_term)
        else:
            # aligns the decoder's gradient with the given normals
            loss_normals_term = self.tau * torch.norm(
                grad_surface - minibatch.normals, dim=2)

        # per-shape mean over the input points
        per_shape_loss = (loss_f_term + loss_normals_term).mean(dim=1)

        # Eikonal term: jitter randomly chosen input points with Gaussian noise
        picked = np.random.choice(points.shape[1],
                                  size=self.num_eikonal_samples,
                                  replace=True)
        kth_nn = minibatch.kth_nn_distances
        if kth_nn is None:
            scale = self.eikonal_term_default_stddev
        else:
            # per-point std-dev broadcast over the coordinate axis
            scale = kth_nn.unsqueeze(2).expand(points.shape)[:, picked]
        jittered = torch.distributions.Normal(loc=points[:, picked],
                                              scale=scale).sample()
        _, grad_jittered = self._evaluate_with_gradient(decoder, jittered)
        unit_norm_error = (grad_jittered.norm(dim=2) - 1)**2
        loss_eikonal_term = self.lam * unit_norm_error.mean(dim=1)

        total_loss = (per_shape_loss + loss_eikonal_term).mean()
        return LossFunctionOutput(loss=total_loss,
                                  f_term=loss_f_term.mean(),
                                  normals_term=loss_normals_term.mean(),
                                  eikonal_term=loss_eikonal_term.mean())


def log_message(run_id: str, decoder: Decoder, trainer: Trainer):
    """Format a one-line training progress report.

    Args:
        run_id: Identifier shown as a ``[run_id]`` prefix.
        decoder: Unused here. NOTE(review): presumably accepted to fit a
            logging-callback signature -- confirm with the caller.
        trainer: Object whose ``state`` provides ``epoch``, ``max_epochs``,
            ``moving_average.metrics`` (keys ``loss``, ``f_term``,
            ``normals_term``, ``eikonal_term``), ``num_gradient_updates`` and
            ``elapsed_seconds``.

    Returns:
        The report fields joined with ``" - "``.
    """
    state = trainer.state
    epoch = state.epoch
    metrics = state.moving_average.metrics
    # A zero/None max_epochs must not crash the logging path.
    progress = epoch / state.max_epochs * 100 if state.max_epochs else 0.0
    elapsed_minutes = int(state.elapsed_seconds / 60)
    return " - ".join([
        f"[{run_id}] Epoch: {epoch:d} ({progress:.2f}%)",
        f"loss: {metrics['loss']:.4f}",
        f"f: {metrics['f_term']:.4f}",
        f"normals: {metrics['normals_term']:.4f}",
        f"eikonal_term: {metrics['eikonal_term']:.4f}",
        f"#grad updates: {state.num_gradient_updates:d}",
        f"elapsed_time: {elapsed_minutes} min",
    ])


def _split_array(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


class MarchingCubes:
    """Extract a triangle mesh of the decoder's zero level set.

    The decoder is evaluated on a ``grid_size ** 3`` lattice spanning
    ``[grid_min_value, grid_max_value]`` along each of x, y and z, and the
    zero iso-surface of the resulting volume is extracted with
    scikit-image's Lewiner marching cubes.
    """

    def __init__(self, decoder: Decoder, grid_size: int, grid_min_value: float,
                 grid_max_value: float):
        """
        Args:
            decoder: Network evaluated on batches shaped ``(1, N, 3)``.
            grid_size: Number of samples per axis (at least 2).
            grid_min_value: Lower bound of the grid on every axis.
            grid_max_value: Upper bound of the grid on every axis.

        Raises:
            ValueError: if ``grid_size < 2`` or the bounds are not increasing.
        """
        if grid_size < 2:
            raise ValueError(f"grid_size must be at least 2, got {grid_size}")
        if not grid_max_value > grid_min_value:
            raise ValueError("grid_max_value must be greater than grid_min_value")
        self.decoder = decoder
        self.grid_size = grid_size
        self.grid_min_value = grid_min_value
        self.grid_max_value = grid_max_value

    def _evaluate_volume(self) -> np.ndarray:
        """Evaluate the decoder on the grid; return a ``(n, n, n)`` volume.

        ``np.meshgrid`` defaults to ``"xy"`` indexing, so
        ``volume[i, j, k] == f(x=axis[j], y=axis[i], z=axis[k])``.
        """
        device = self.decoder.get_device()
        axis = np.linspace(self.grid_min_value, self.grid_max_value,
                           self.grid_size)
        xv, yv, zv = np.meshgrid(axis, axis, axis)
        points = np.stack((xv, yv, zv), axis=-1).reshape(-1, 3).astype(
            np.float32)
        # one chunk per grid_size**2 points to bound memory per forward pass
        chunks = _split_array(points, self.grid_size)
        values = []
        with torch.no_grad():
            for chunk_index, chunk in enumerate(chunks):
                print(chunk_index + 1, len(chunks))
                batch = torch.from_numpy(chunk).to(device)[None, :, :]
                f = self.decoder(batch).squeeze(dim=2)[0]
                values.append(f.cpu().numpy())
        return np.concatenate(values).reshape((self.grid_size, ) * 3)

    @staticmethod
    def _marching_cubes(volume, spacing):
        """Run Lewiner marching cubes at level 0 across scikit-image versions."""
        if hasattr(measure, "marching_cubes_lewiner"):
            return measure.marching_cubes_lewiner(volume,
                                                  0.0,
                                                  spacing=spacing,
                                                  gradient_direction="ascent")
        # scikit-image >= 0.19 removed marching_cubes_lewiner; the same
        # algorithm is available as marching_cubes(method="lewiner").
        return measure.marching_cubes(volume,
                                      0.0,
                                      spacing=spacing,
                                      gradient_direction="ascent",
                                      method="lewiner")

    def __call__(self):
        """Return ``(vertices, faces)`` of the decoder's zero level set.

        Returns:
            vertices: ``(V, 3)`` float64 array of (x, y, z) positions inside
                the grid bounds.
            faces: ``(F, 3)`` array of vertex indices per triangle.
        """
        volume = self._evaluate_volume()
        # np.linspace includes both endpoints, so there are grid_size - 1
        # intervals between grid_min_value and grid_max_value.
        step = (self.grid_max_value - self.grid_min_value) / (self.grid_size - 1)
        vertices, faces, _normals, _values = self._marching_cubes(
            volume, (step, step, step))
        # Volume axes are (y, x, z) (see _evaluate_volume): swap the first two
        # coordinates and shift from grid-index space to the grid origin.
        vertices = vertices[:, [1, 0, 2]].astype(np.float64) + self.grid_min_value
        return vertices, faces
