import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, List

import numpy as np
import torch
import torch.optim as optim
from skimage import measure
from torch.optim.lr_scheduler import LambdaLR

from . import models, training
from .classes import Hyperparameters, ModelInterface
from .datasets.classes import MinibatchDescription
from .models import MLP
from .models import PointCompletionNetwork as Model
from .training import (Events, LossFunctionInterface, LossFunctionOutput,
                       Trainer)

if "SKIP_CPP_EXTENSION" not in os.environ:
    from torch_chamfer_distance import (ChamferDistance as
                                        ChamferDistanceFunction)


@dataclass
class TrainerState(training.TrainerState):
    """Trainer state extended with an ``alpha`` value.

    ``alpha`` is written to and read back from the state dict next to
    whatever the parent class stores there.
    """
    # NOTE(review): presumably the weight of the dense loss term that
    # LossFunction receives as ``alpha`` -- confirm against the caller.
    alpha: float = 1

    def state_dict(self):
        """Return the parent's state dict with an added ``"alpha"`` entry."""
        snapshot = super().state_dict()
        snapshot["alpha"] = self.alpha
        return snapshot

    def load_state_dict(self, state_dict: dict):
        """Restore the parent's state first, then ``alpha``.

        A missing ``"alpha"`` key is not handled here; the lookup is a
        plain item access on ``state_dict``.
        """
        super().load_state_dict(state_dict)
        self.alpha = state_dict["alpha"]


@dataclass
class TrainingHyperparameters(Hyperparameters):
    """Training-related settings and their default values.

    None of these fields is read in the visible part of this file, so the
    per-field notes below are inferred from the names and must be confirmed
    against the code that consumes them.
    """
    # NOTE(review): presumably the number of samples per minibatch -- confirm.
    batchsize: int = 32
    # NOTE(review): presumably the number of points fed to the model per
    # sample -- confirm.
    num_input_points: int = 3000
    # NOTE(review): presumably the initial optimizer learning rate -- confirm.
    learning_rate: float = 1e-4
    # NOTE(review): presumably the decay interval (unit not visible here;
    # steps or gradient updates?) -- confirm.
    lr_decay_every: int = 50000
    # NOTE(review): presumably the multiplicative factor applied at each
    # decay -- confirm.
    lr_decay_factor: float = 0.7
    # NOTE(review): presumably a lower bound for the decayed learning
    # rate -- confirm.
    lr_clip: float = 1e-6


@dataclass
class ModelHyperparameters(Hyperparameters):
    """Architecture settings read by ``setup_model`` to build the network.

    The ``*_hidden_sizes`` fields are comma-separated integer strings
    (e.g. ``"1024,1024"``); ``setup_model`` converts them with
    ``_parse_hidden_sizes_str`` before passing them to ``MLP``.
    """
    # Hidden layer sizes of the first encoder MLP.
    encoder_1_hidden_sizes: str = "128"
    # Output size of the first encoder; setup_model uses it as both the local
    # and the global feature size when sizing the second encoder's input.
    encoder_1_output_size: int = 256
    # Hidden layer sizes of the second encoder MLP.
    encoder_2_hidden_sizes: str = "512"
    # Output size of the second encoder; also the input size of the coarse
    # decoder and part of the folding decoder's input size.
    encoder_2_output_size: int = 1024
    # Hidden layer sizes of the coarse decoder MLP.
    decoder_coarse_hidden_sizes: str = "1024,1024"
    # The coarse decoder's output size is this value times 3.
    decoder_coarse_num_output_points: int = 1024
    # Hidden layer sizes of the folding decoder MLP.
    decoder_folding_hidden_sizes: str = "512,512"
    # NOTE(review): not read by setup_model in this file -- confirm where
    # (or whether) it is used.
    decoder_folding_num_output_points: int = 16384
    # Passed to the model as ``num_dense_points``.
    num_dense_gt_points: int = 16384
    # Passed to the model as ``num_coarse_points``.
    num_coarse_gt_points: int = 1024


def _parse_hidden_sizes_str(sizes_str: str) -> List[int]:
    assert sizes_str.startswith(",") is False
    assert sizes_str.endswith(",") is False
    sizes = sizes_str.split(",")
    sizes = [int(size) for size in sizes]
    return sizes


def setup_model(hyperparams: ModelHyperparameters):
    """Build the point completion model described by ``hyperparams``.

    Four MLPs are created, in this order: encoder 1, encoder 2, the coarse
    decoder and the folding decoder. They are handed to the model together
    with the coarse and dense point counts.

    NOTE(review): the coarse decoder emits
    ``decoder_coarse_num_output_points * 3`` values while the model is told
    ``num_coarse_gt_points``; the defaults agree (1024) -- confirm whether
    the two settings must always match.
    """
    hp = hyperparams
    parse_sizes = _parse_hidden_sizes_str

    # The global feature has the same width as encoder 1's per-point output.
    local_dim = hp.encoder_1_output_size
    global_dim = hp.encoder_1_output_size
    point_dim = hp.encoder_2_output_size

    # Dict entries are evaluated top to bottom, so the sub-networks are
    # constructed in the order listed here.
    networks = {
        "encoder_1":
        MLP(input_size=3,
            output_size=local_dim,
            hidden_sizes=parse_sizes(hp.encoder_1_hidden_sizes)),
        "encoder_2":
        MLP(input_size=local_dim + global_dim,
            output_size=point_dim,
            hidden_sizes=parse_sizes(hp.encoder_2_hidden_sizes)),
        # Three output values per coarse point.
        "decoder_coarse":
        MLP(input_size=point_dim,
            output_size=hp.decoder_coarse_num_output_points * 3,
            hidden_sizes=parse_sizes(hp.decoder_coarse_hidden_sizes)),
        # NOTE(review): the extra 2 + 3 inputs are presumably a 2D grid
        # coordinate and a 3D point -- confirm against the model.
        "decoder_folding":
        MLP(input_size=point_dim + 2 + 3,
            output_size=3,
            hidden_sizes=parse_sizes(hp.decoder_folding_hidden_sizes)),
    }

    return Model(num_coarse_points=hp.num_coarse_gt_points,
                 num_dense_points=hp.num_dense_gt_points,
                 **networks)


class LossFunction(LossFunctionInterface):
    """Loss combining a coarse and a dense chamfer-distance term.

    The total loss is ``loss_coarse + alpha * loss_dense``.
    """

    def __init__(self):
        self.compute_chamfer_distance = ChamferDistanceFunction()

    def _chamfer_term(self, gt_points, pred_points):
        """Return the sum of the mean square roots of both distance outputs.

        NOTE(review): the two values returned by the chamfer function are
        presumably squared nearest-neighbour distances, one per direction
        -- confirm against the extension.
        """
        first, second = self.compute_chamfer_distance(gt_points, pred_points)
        return torch.sqrt(first).mean() + torch.sqrt(second).mean()

    def __call__(self, model: Model, data: MinibatchDescription, alpha: float):
        """Run ``model`` on ``data.input_points`` and evaluate the loss.

        Args:
            model: Called with the input points; must return the coarse and
                the dense prediction, in that order.
            data: Minibatch providing ``input_points``, ``gt_coarse_points``
                and ``gt_dense_points``.
            alpha: Weight applied to the dense term.

        Returns:
            A ``LossFunctionOutput`` holding the total loss and both terms.
        """
        coarse_pred, dense_pred = model(data.input_points)

        coarse_term = self._chamfer_term(data.gt_coarse_points, coarse_pred)
        dense_term = self._chamfer_term(data.gt_dense_points, dense_pred)
        total = coarse_term + alpha * dense_term

        return LossFunctionOutput(loss=total,
                                  loss_coarse=coarse_term.mean(),
                                  loss_dense=dense_term.mean())


def log_message(run_id: str, model: Model, trainer: Trainer):
    """Format a one-line progress summary for the trainer's current epoch.

    The line contains the run id, epoch and progress percentage, the
    moving-average losses, ``alpha``, the learning rate of the first
    optimizer parameter group, the number of gradient updates and the
    elapsed time in whole minutes. ``model`` is accepted but not used.
    """
    state = trainer.state
    epoch = state.epoch
    metrics = state.moving_average.metrics
    progress = epoch / state.max_epochs * 100
    # Truncate to whole minutes.
    elapsed_time = int(state.elapsed_seconds / 60)
    lr = trainer.optimizer.param_groups[0]["lr"]

    fields = (
        f"[{run_id}] Epoch: {epoch:d} ({progress:.2f}%)",
        f"loss: {metrics['loss']:.4e}",
        f"loss_coars: {metrics['loss_coarse']:.4e}",
        f"loss_dense: {metrics['loss_dense']:.4e}",
        f"alpha: {trainer.state.alpha:.4f}",
        f"lr: {lr:.4e}",
        f"#grad updates: {trainer.state.num_gradient_updates:d}",
        f"elapsed_time: {elapsed_time} min",
    )
    return " - ".join(fields)


def log_loss(trainer: Trainer, csv_path: str):
    """Append the current epoch's moving-average metrics to a CSV file.

    Columns are ``epoch`` followed by the metric names in sorted order. A
    header row is written first when ``csv_path`` is not an existing file;
    otherwise only the data row is appended.

    Args:
        trainer: Provides ``state.epoch`` and
            ``state.moving_average.metrics`` (a mapping of name to value).
        csv_path: Destination file; created when missing.
    """
    csv_path = Path(csv_path)
    epoch = trainer.state.epoch
    metrics = trainer.state.moving_average.metrics
    metric_names = sorted(metrics.keys())

    rows = []
    if not csv_path.is_file():
        rows.append(",".join(["epoch"] + metric_names))
    rows.append(",".join([str(epoch)] +
                         [str(metrics[key]) for key in metric_names]))

    # Append mode also creates a missing file, so one mode covers both
    # cases; the with-block closes the handle even if the write fails.
    with open(csv_path, "a") as f:
        f.write("\n".join(rows) + "\n")


def _split_array(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


class MarchingCubes:
    """Extract a triangle mesh from model predictions sampled on a 3D grid.

    The model must provide ``get_device()`` and
    ``decode(points, z=None, c=c)``; the object returned by ``decode`` must
    expose ``probs``.
    """

    def __init__(self, model: Model, grid_size: int, grid_min_value: float,
                 grid_max_value: float):
        """Store the model and the grid definition.

        Args:
            model: Model queried for every grid point.
            grid_size: Number of samples along each axis.
            grid_min_value: Smallest coordinate sampled on every axis.
            grid_max_value: Largest coordinate sampled on every axis; must
                be greater than ``grid_min_value`` (checked by assertion).
        """
        self.model = model
        self.grid_size = grid_size
        assert grid_max_value > grid_min_value
        self.grid_min_value = grid_min_value
        self.grid_max_value = grid_max_value

    def __call__(self, c: torch.Tensor):
        """Evaluate the model on the grid and run marching cubes.

        Args:
            c: Passed through unchanged to ``model.decode`` as ``c``.

        Returns:
            ``(vertices, faces)`` from marching cubes, with the vertices
            translated and rotated by 90 degrees about the z axis. The
            normals and values also produced by marching cubes are dropped.
        """
        device = self.model.get_device()

        # make prediction
        linspace_size = self.grid_max_value - self.grid_min_value
        # NOTE(review): np.linspace places grid_size samples, i.e.
        # grid_size - 1 intervals, while the spacing divides by grid_size
        # -- confirm this is intended.
        voxel_size = linspace_size / self.grid_size
        axis = np.linspace(self.grid_min_value, self.grid_max_value,
                           self.grid_size)
        xv, yv, zv = np.meshgrid(axis, axis, axis)
        # One row per grid point: shape (grid_size ** 3, 3).
        grid = np.stack((xv, yv, zv)).reshape(
            (3, -1)).transpose().astype(np.float32)
        # Query the model in grid_size chunks rather than all at once.
        grid_chunk_list = _split_array(grid, self.grid_size)
        f_list = []
        with torch.no_grad():
            for grid_chunk in grid_chunk_list:
                # Add a leading dimension of size 1 before decoding.
                grid_chunk = torch.from_numpy(grid_chunk).to(
                    device)[None, :, :]
                p = self.model.decode(grid_chunk, z=None, c=c)
                # Shift so that the 0.5 probability contour sits at level 0.
                f = -(p.probs - 0.5)
                f = f.cpu().numpy()
                f_list.append(f[0])
        f = np.concatenate(f_list)
        volume = f.reshape((self.grid_size, self.grid_size, self.grid_size))

        spacing = (voxel_size, -voxel_size, voxel_size)
        vertex_translation = (linspace_size / 2, -linspace_size / 2,
                              linspace_size / 2)
        # Newer scikit-image releases no longer ship marching_cubes_lewiner;
        # keep using it where it exists and otherwise fall back to
        # marching_cubes, which accepts the same arguments.
        marching_cubes = getattr(measure, "marching_cubes_lewiner", None)
        if marching_cubes is None:
            marching_cubes = measure.marching_cubes
        (vertices, faces, normals,
         values) = marching_cubes(volume,
                                  0.0,
                                  spacing=spacing,
                                  gradient_direction="ascent")
        vertices -= vertex_translation
        # Rotation by 90 degrees about the z axis.
        angle = math.pi / 2
        rotation_matrix = np.array([[math.cos(angle), -math.sin(angle), 0],
                                    [math.sin(angle),
                                     math.cos(angle), 0], [0, 0, 1]])
        vertices = vertices @ rotation_matrix.T

        return vertices, faces
