import math
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, List

import numpy as np
import torch
import torch.optim as optim
from skimage import measure
from torch.optim.lr_scheduler import LambdaLR

from .. import models, training
from ..classes import Hyperparameters, ModelInterface
from ..datasets.classes import Minibatch
from ..models import OccupancyNetwork as Model
from ..training import Events, LossFunctionInterface, Trainer


@dataclass(frozen=True)
class LossFunctionOutput(training.LossFunctionOutput):
    """Immutable loss result that also carries the KL term.

    Extends ``training.LossFunctionOutput`` with one extra field. The
    ``LossFunction`` defined in this module constructs it with ``loss=`` and
    ``kl=`` keyword arguments (``loss`` is presumably declared on the base
    class -- confirm in ``training``).
    """
    # KL divergence term; ``LossFunction`` stores the mean over the batch here.
    kl: torch.Tensor


@dataclass
class TrainingHyperparameters(Hyperparameters):
    """Default values for the training-related hyperparameters.

    NOTE(review): none of these fields is read in this part of the file, so
    the per-field notes below are inferred from the names only -- confirm
    against the training script / dataset code that consumes them.
    """
    # presumably the number of samples per minibatch
    batchsize: int = 32
    # presumably how many points are sampled as encoder input
    num_input_points: int = 300
    # presumably the std-dev of noise added to the input points
    input_points_noise_stddev: float = 0.005
    # presumably how many ground-truth query points are sampled for the loss
    num_gt_points: int = 2048
    # presumably the optimizer learning rate
    learning_rate: float = 1e-4


@dataclass
class ModelHyperparameters(Hyperparameters):
    """Default sizes used by ``setup_model`` to build the encoder and decoder."""
    # passed as ``z_dim`` to both the latent encoder and the decoder
    z_dim: int = 256
    # passed as ``c_dim`` to the decoder
    decoder_c_dim: int = 0
    # passed as ``hidden_size`` to the decoder
    decoder_hidden_size: int = 512
    # passed as ``c_dim`` to the latent encoder
    encoder_c_dim: int = 0
    # passed as ``g_dim`` to the latent encoder
    encoder_g_dim: int = 256


def _parse_hidden_sizes_str(sizes_str: str) -> List[int]:
    """Parse a comma-separated list of integers such as ``"512,256,128"``.

    Args:
        sizes_str: Comma-separated integers with no leading or trailing comma.

    Returns:
        The parsed sizes, in the order they appear in the string.

    Raises:
        ValueError: If the string starts or ends with a comma, or if any
            item is not a valid integer literal (this includes the empty
            string and empty items such as in ``"1,,2"``).
    """
    # Raise explicitly instead of asserting: asserts are stripped under
    # ``python -O``, and ``int()`` already reports other malformed input as
    # ValueError, so every bad input now fails the same way.
    if sizes_str.startswith(",") or sizes_str.endswith(","):
        raise ValueError(
            "sizes must not start or end with a comma: {!r}".format(sizes_str))
    return [int(size) for size in sizes_str.split(",")]


def setup_model(hyperparams: ModelHyperparameters):
    """Build the occupancy network described by ``hyperparams``.

    A latent encoder and a batch-norm decoder are created (in that order)
    and wrapped in ``Model``. Both sub-networks receive the same ``z_dim``.

    Args:
        hyperparams: Sizes for the encoder and decoder.

    Returns:
        A ``Model`` holding the freshly constructed encoder and decoder.
    """
    latent_dim = hyperparams.z_dim

    encoder_kwargs = dict(
        z_dim=latent_dim,
        c_dim=hyperparams.encoder_c_dim,
        g_dim=hyperparams.encoder_g_dim,
    )
    decoder_kwargs = dict(
        z_dim=latent_dim,
        c_dim=hyperparams.decoder_c_dim,
        hidden_size=hyperparams.decoder_hidden_size,
    )

    # The encoder is instantiated before the decoder, as before.
    latent_encoder = models.encoder_latent.Encoder(**encoder_kwargs)
    occupancy_decoder = models.decoder.DecoderBatchNorm(**decoder_kwargs)
    return Model(encoder_latent=latent_encoder, decoder=occupancy_decoder)


class LossFunction(LossFunctionInterface):
    """Training loss: KL(q(z) || N(0, 1)) plus occupancy cross-entropy."""

    def __call__(self, model: Model, data: Minibatch) -> LossFunctionOutput:
        """Compute the loss for one minibatch.

        Args:
            model: Network providing ``get_device``, ``infer_z`` and
                ``decode``.
            data: Minibatch providing ``input_points``,
                ``input_occupancies``, ``gt_points`` and ``gt_occupancies``.

        Returns:
            A ``LossFunctionOutput`` whose ``loss`` is the sum of the mean KL
            term and the mean reconstruction term, and whose ``kl`` is the
            mean KL term alone.
        """
        device = model.get_device()

        # Approximate posterior over z; no conditioning code is passed.
        q_z = model.infer_z(data.input_points, data.input_occupancies, c=None)
        # rsample() keeps the sample differentiable w.r.t. q_z's parameters.
        z = q_z.rsample()

        logits = model.decode(data.gt_points, z, c=None).logits

        # Diagnostic only: report the number of NaN logits, and only when
        # there are any, instead of dumping index tensors on every call.
        num_nans = int(torch.isnan(logits).sum())
        if num_nans > 0:
            print("# of nans", num_nans)

        # KL-divergence against a standard-normal prior, summed over the
        # last dimension (presumably the latent dimension -- confirm).
        p0 = torch.distributions.Normal(
            torch.zeros(1, device=device),
            torch.ones(1, device=device))
        kl = torch.distributions.kl_divergence(q_z, p0).sum(dim=-1)
        loss = kl.mean()

        # General points: per-element BCE, summed over the last dimension
        # (presumably the query points -- confirm) and averaged over the rest.
        loss_i = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, data.gt_occupancies, reduction='none')
        loss = loss + loss_i.sum(-1).mean()

        return LossFunctionOutput(loss=loss, kl=kl.mean())


def log_message(run_id: str, model: Model, trainer: Trainer):
    """Format a one-line progress summary for the current epoch.

    Args:
        run_id: Identifier shown in square brackets at the start of the line.
        model: Unused; kept so the signature stays unchanged for callers.
        trainer: Trainer whose ``state`` supplies the epoch, the
            moving-average metrics and the timing counters.

    Returns:
        The summary fields joined with ``" - "``.
    """
    state = trainer.state
    epoch = state.epoch
    metrics = state.moving_average.metrics
    progress = epoch / state.max_epochs * 100
    # Elapsed time is reported in whole minutes (truncated).
    elapsed_time = int(state.elapsed_seconds / 60)
    return " - ".join([
        f"[{run_id}] Epoch: {epoch:d} ({progress:.2f}%)",
        f"loss: {metrics['loss']:.4f}",
        f"#grad updates: {state.num_gradient_updates:d}",
        f"data_load_time: {state.data_load_time:.4f}",
        f"forward_time: {state.forward_time:.4f}",
        f"backward_time: {state.backward_time:.4f}",
        f"elapsed_time: {elapsed_time} min",
    ])


def log_loss(trainer: Trainer, csv_path: str):
    """Append the current epoch's moving-average metrics to a CSV file.

    If ``csv_path`` is not an existing file it is created and a header row
    (``epoch`` followed by the sorted metric names) is written first. The
    header is never rewritten afterwards, so the set of metrics should stay
    the same across calls that share a file.

    Args:
        trainer: Trainer whose ``state`` supplies ``epoch`` and
            ``moving_average.metrics`` (a mapping of name to value).
        csv_path: Destination file path.
    """
    csv_path = Path(csv_path)
    epoch = trainer.state.epoch
    metrics = trainer.state.moving_average.metrics
    metric_names = sorted(metrics.keys())

    rows = []
    if csv_path.is_file():
        mode = "a"
    else:
        mode = "w"
        rows.append(["epoch"] + metric_names)
    rows.append([str(epoch)] + [str(metrics[key]) for key in metric_names])

    # The context manager closes the file even if a write fails.
    with open(csv_path, mode) as f:
        f.writelines(",".join(row) + "\n" for row in rows)


def _split_array(array, segments):
    """Cut ``array`` into ``segments`` consecutive slices.

    Every slice except the last has ``ceil(len(array) / segments)`` items;
    the last slice holds whatever is left, so it may be shorter than the
    others or even empty.

    Args:
        array: Sliceable sequence with at least ``segments`` items.
        segments: Number of slices to return.

    Returns:
        A list of ``segments`` slices which, concatenated, give ``array``.
    """
    assert len(array) >= segments
    chunk_len = math.ceil(len(array) / segments)

    pieces = []
    offset = 0
    for _ in range(segments - 1):
        pieces.append(array[offset:offset + chunk_len])
        offset += chunk_len
    # When nothing was cut off, the input object itself is the only piece.
    pieces.append(array[offset:] if offset else array)
    return pieces


class MarchingCubes:
    """Extract a triangle mesh from the model's decoded occupancy field.

    The decoder is evaluated on a regular ``grid_size``^3 lattice spanning
    ``[grid_min_value, grid_max_value]`` on every axis, and the 0.5
    probability level set is meshed with marching cubes.
    """

    def __init__(self, model: Model, grid_size: int, grid_min_value: float,
                 grid_max_value: float):
        """Store the model and the sampling-lattice parameters.

        Args:
            model: Network providing ``get_device`` and ``decode``.
            grid_size: Number of lattice samples along each axis.
            grid_min_value: Lower bound of the lattice on every axis.
            grid_max_value: Upper bound of the lattice on every axis; must be
                greater than ``grid_min_value``.
        """
        self.model = model
        self.grid_size = grid_size
        assert grid_max_value > grid_min_value
        self.grid_min_value = grid_min_value
        self.grid_max_value = grid_max_value

    def __call__(self, z: torch.Tensor):
        """Decode ``z`` on the lattice and return the extracted mesh.

        Args:
            z: Latent code handed to ``model.decode``. Only the first batch
                element of the decoder output is used.

        Returns:
            A ``(vertices, faces)`` tuple of NumPy arrays as produced by
            scikit-image's marching cubes, with the vertices shifted and
            rotated by the fixed transform below.
        """
        device = self.model.get_device()

        # make prediction
        linspace_size = self.grid_max_value - self.grid_min_value
        # NOTE(review): np.linspace with grid_size samples has a step of
        # linspace_size / (grid_size - 1); dividing by grid_size makes the
        # mesh slightly smaller than the sampled cube. Left unchanged because
        # downstream code may rely on the current scale -- confirm.
        voxel_size = linspace_size / self.grid_size
        axis = np.linspace(self.grid_min_value, self.grid_max_value,
                           self.grid_size)
        # (3, N, N, N) coordinate stack -> (N^3, 3) list of query points.
        grid = np.stack(np.meshgrid(axis, axis, axis)).reshape(
            (3, -1)).transpose().astype(np.float32)
        # Decode in grid_size chunks rather than all N^3 points at once.
        grid_chunk_list = _split_array(grid, self.grid_size)
        f_list = []
        with torch.no_grad():
            for grid_chunk in grid_chunk_list:
                grid_chunk = torch.from_numpy(grid_chunk).to(
                    device)[None, :, :]
                p = self.model.decode(grid_chunk, z, c=None)
                # Shift so that probability 0.5 maps to the 0.0 iso-level.
                f = -(p.probs - 0.5)
                f = f.cpu().numpy()
                f_list.append(f[0])
        f = np.concatenate(f_list)
        volume = f.reshape((self.grid_size, self.grid_size, self.grid_size))

        spacing = (voxel_size, -voxel_size, voxel_size)
        vertex_translation = (linspace_size / 2, -linspace_size / 2,
                              linspace_size / 2)
        # ``marching_cubes_lewiner`` was removed in scikit-image 0.19, where
        # ``marching_cubes`` runs the Lewiner algorithm by default. Prefer the
        # old name when it exists so older installations behave as before.
        marching_cubes = getattr(measure, "marching_cubes_lewiner", None)
        if marching_cubes is None:
            marching_cubes = measure.marching_cubes
        (vertices, faces, normals,
         values) = marching_cubes(volume,
                                  0.0,
                                  spacing=spacing,
                                  gradient_direction="ascent")
        vertices -= vertex_translation
        # Rotate the vertices by 90 degrees about the z axis.
        rotation_matrix = np.array(
            [[math.cos(math.pi / 2), -math.sin(math.pi / 2), 0],
             [math.sin(math.pi / 2),
              math.cos(math.pi / 2), 0], [0, 0, 1]])
        vertices = vertices @ rotation_matrix.T

        return vertices, faces
