import math
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, List

import numpy as np
import torch
import torch.optim as optim
from skimage import measure
from torch.optim.lr_scheduler import LambdaLR

from .. import training
from ..classes import Hyperparameters, ModelInterface
from ..datasets.classes import Minibatch
from ..models.baseline import (Decoder, DeterministicEncoder, LatentEncoder,
                               LatentEncoderSharedHiddenLayers, Model,
                               ModelHyperparameters, ModelOutputDescription)
from ..training import (Events, LossFunctionInterface, LossFunctionOutput,
                        Trainer)


@dataclass
class TrainingHyperparameters(Hyperparameters):
    """Hyperparameters for training the encoder/decoder model.

    NOTE(review): none of these fields are read by the code visible next to
    this class; the per-field notes describe the most likely consumer and
    should be confirmed against the training script.
    """
    # Number of examples per minibatch (presumably).
    batchsize: int = 32
    # Presumably the range from which the number of context points per
    # example is drawn — confirm.
    min_num_context: int = 1
    max_num_context: int = 5000
    # Presumably the range for the number of target points; -1 looks like a
    # sentinel (e.g. "no explicit minimum") — TODO confirm.
    min_num_target: int = -1
    max_num_target: int = 16384
    # Initial optimizer learning rate (presumably).
    learning_rate: float = 1e-4
    # Presumably step-decay settings in the style of LearningRateScheduler:
    # decay every `decrease_lr_every` iterations, at most `num_lr_decay`
    # times (0 means the scheduler never decays) — confirm wiring.
    decrease_lr_every: int = 50000
    num_lr_decay: int = 0
    # Presumably passed to LossFunction's eikonal_term_stddev /
    # num_eikonal_samples: stddev of the Gaussian perturbation around sampled
    # decoder input points, and how many points to sample (0 disables the
    # eikonal term in LossFunction.compute).
    eikonal_term_stddev: float = 0.2
    num_eikonal_samples: int = 5000
    # Presumably control a schedule for the KL weight (the value returned by
    # LossFunction's kld_weight_func), moving from the initial to the final
    # weight over `kld_weight_annealing_epochs` when `anneal_kld_weight` is
    # True — TODO confirm.
    anneal_kld_weight: bool = False
    kld_weight_annealing_epochs: int = 1000
    loss_kld_initial_weight: float = 1
    loss_kld_final_weight: float = 1
    # Presumably LossFunction's `lam` (weight of the eikonal term) and `tau`
    # (weight of the normals term) respectively.
    loss_lambda: float = 1
    loss_tau: float = 1


@dataclass
class LatentOptimizationHyperparameters(Hyperparameters):
    prefix = "latent_optimization"
    num_samples: int = 30000
    iterations: int = 800
    decrease_lr_every: int = 800
    lr_decay_factor: float = 0.1
    initial_lr: float = 0.005
    num_lr_decay: int = 1
    optimizer: str = "adam"


def _parse_hidden_sizes_str(sizes_str: str) -> List[int]:
    assert sizes_str.startswith(",") is False
    assert sizes_str.endswith(",") is False
    sizes = sizes_str.split(",")
    sizes = [int(size) for size in sizes]
    return sizes


def setup_model(model_hyperparams: ModelHyperparameters,
                model_class: ModelInterface = None):
    """Build an encoder/decoder model from ``model_hyperparams``.

    The encoder maps 3-D input points to an ``h_dim`` representation and is
    one of ``DeterministicEncoder``, ``LatentEncoderSharedHiddenLayers`` or
    ``LatentEncoder`` depending on the hyperparameter flags. The decoder maps
    (point, h) to a single distance value using Softplus activations.

    Args:
        model_hyperparams: architecture settings; hidden sizes are given as
            comma-separated strings.
        model_class: class used to wrap encoder and decoder; ``Model`` when
            None.

    Returns:
        ``model_class(encoder=..., decoder=...)``.
    """
    point_dim = 3  # (x, y, z)
    distance_dim = 1  # distance
    hp = model_hyperparams

    if model_class is None:
        model_class = Model

    if hp.deterministic_encoder:
        encoder = DeterministicEncoder(
            input_size=point_dim,
            output_size=hp.h_dim,
            hidden_sizes=_parse_hidden_sizes_str(hp.encoder_hidden_sizes))
    elif hp.encoder_share_hidden_layers:
        encoder = LatentEncoderSharedHiddenLayers(
            input_size=point_dim,
            output_size=hp.h_dim,
            shared_layers_output_size=hp.encoder_shared_layers_output_size,
            shared_layers_hidden_sizes=_parse_hidden_sizes_str(
                hp.encoder_shared_layers_hidden_sizes),
            gaussian_params_hidden_sizes=_parse_hidden_sizes_str(
                hp.encoder_gaussian_params_hidden_sizes),
            min_g=hp.min_g,
            prior_mu=hp.f0,
            prior_sigma=hp.g0,
        )
    else:
        encoder = LatentEncoder(
            input_size=point_dim,
            output_size=hp.h_dim,
            hidden_sizes=_parse_hidden_sizes_str(hp.encoder_hidden_sizes),
            min_g=hp.min_g,
            prior_mu=hp.f0,
            prior_sigma=hp.g0,
        )

    # Factory for the decoder's activation layers; the beta is looked up each
    # time a layer is created.
    def make_activation():
        return torch.nn.Softplus(beta=hp.decoder_softplus_beta)

    decoder_hidden_sizes = _parse_hidden_sizes_str(hp.decoder_hidden_sizes)
    decoder = Decoder(input_size=point_dim + hp.h_dim,
                      output_size=distance_dim,
                      hidden_sizes=decoder_hidden_sizes,
                      activation_func_module=make_activation)

    return model_class(encoder=encoder, decoder=decoder)


class LossFunction(LossFunctionInterface):
    """Signed-distance reconstruction loss for the encoder/decoder model.

    Per batch element the loss is::

        sum_i |f(x_i)|                                (data term, -likelihood)
        + kld_weight_func() * KL(posterior || prior)  (0 without distributions)
        + tau * sum_i ||grad_x f(x_i) - n_i||         (normals term)
        + lam * sum_j (||grad_x f(x'_j)|| - 1) ** 2   (eikonal term)

    and :meth:`compute` returns its mean over the batch. ``x_i`` are the
    decoder input points, ``n_i`` the matching normals, and ``x'_j`` are
    Gaussian perturbations (stddev ``eikonal_term_stddev``) of randomly
    chosen decoder input points.
    """

    def __init__(self,
                 tau: float,
                 lam: float,
                 kld_weight_func: Callable[[], float],
                 num_eikonal_samples: int,
                 eikonal_term_stddev: float = 0.02):
        """
        Args:
            tau: weight of the normals term.
            lam: weight of the eikonal term.
            kld_weight_func: called on every :meth:`compute` to obtain the
                current KL-divergence weight.
            num_eikonal_samples: default number of perturbed points for the
                eikonal term (0 disables it).
            eikonal_term_stddev: stddev of the Gaussian perturbation.
        """
        self.tau = tau
        self.lam = lam
        self.kld_weight_func = kld_weight_func
        self.num_eikonal_samples = num_eikonal_samples
        self.eikonal_term_stddev = eikonal_term_stddev

    def forward(self, model: Model,
                minibatch: Minibatch) -> ModelOutputDescription:
        """Run ``model`` on the first viewpoint's context points and targets."""
        context_points = minibatch.context_points_list[0]  # viewpoint 1
        return model(context_points, minibatch.target_points)

    @staticmethod
    def _gather_normals(minibatch: Minibatch):
        """Return normals laid out as (context, target), or None if unusable."""
        if len(minibatch.context_normals_list) == 0:
            context_normals = None
        else:
            context_normals = minibatch.context_normals_list[0]  # viewpoint 1

        if minibatch.target_normals is None:
            print("normals not available")
            return context_normals
        if context_normals is None:
            # The normals term pairs gradients with context-then-target
            # normals; without context normals that layout cannot be built,
            # so skip the term instead of crashing inside torch.cat(None, ...).
            return None
        return torch.cat((context_normals, minibatch.target_normals), dim=1)

    @staticmethod
    def _normals_term(output: ModelOutputDescription, normals, batchsize: int,
                      X_dim: int, reference: torch.Tensor) -> torch.Tensor:
        """Per-point ||grad f - n||; a (1, 1) zero tensor when normals is None."""
        if normals is None:
            return torch.full((1, 1), 0, dtype=torch.float32).to(reference)
        normals = normals.view((batchsize, -1, X_dim))
        f_xi_grad = torch.autograd.grad(output.pred_distance.sum(),
                                        output.decoder_input_points,
                                        create_graph=True)[0]
        return torch.norm(f_xi_grad - normals, dim=2)

    def _eikonal_term(self, model: Model, output: ModelOutputDescription,
                      num_eikonal_samples: int, batchsize: int,
                      reference: torch.Tensor) -> torch.Tensor:
        """Per-sample (||grad f|| - 1)^2; a (batch, 1) zero tensor if disabled."""
        if num_eikonal_samples == 0:
            zeros = torch.scalar_tensor(0).to(reference).repeat(batchsize)
            return zeros.view((batchsize, 1))

        decoder_input_points = output.decoder_input_points
        num_input_points = decoder_input_points.shape[1]
        random_indices = np.random.choice(num_input_points,
                                          size=num_eikonal_samples,
                                          replace=True)
        loc = decoder_input_points[:, random_indices]
        distribution = torch.distributions.Normal(
            loc=loc, scale=self.eikonal_term_stddev)
        # sample() does not track gradients; re-enable them so grad_x f can
        # be evaluated at the perturbed points.
        sampled_point = distribution.sample()
        sampled_point.requires_grad_(True)

        h = output.decoder_input_h
        h = h[:, None, :].expand(
            (h.shape[0], num_eikonal_samples, h.shape[1]))

        f_x = model.decoder(X=sampled_point, h=h)
        f_x_grad = torch.autograd.grad(f_x.sum(),
                                       sampled_point,
                                       create_graph=True)[0]
        return (torch.norm(f_x_grad, dim=2) - 1)**2

    def compute(self,
                minibatch: Minibatch,
                model: Model,
                output: ModelOutputDescription,
                num_eikonal_samples: int = None) -> LossFunctionOutput:
        """Compute the loss and its components from a model output.

        Args:
            minibatch: batch providing context points and optional normals.
            model: model whose ``X_dim`` and ``decoder`` are used.
            output: result of :meth:`forward` or an equivalent description
                (``h_prior_dist``/``h_posterior_dist`` may be None, in which
                case the KL term is zero).
            num_eikonal_samples: overrides ``self.num_eikonal_samples`` when
                not None; 0 disables the eikonal term.

        Returns:
            LossFunctionOutput with the scalar loss and batch means of the
            likelihood, weighted KL, weighted normals and weighted eikonal
            terms.
        """
        context_points = minibatch.context_points_list[0]  # viewpoint 1
        batchsize = context_points.shape[0]
        X_dim = model.X_dim
        if num_eikonal_samples is None:
            num_eikonal_samples = self.num_eikonal_samples

        # Data term: sum of |f| over the points of each batch element.
        likelihood = -torch.abs(output.pred_distance).squeeze(2).sum(dim=1)

        if output.h_prior_dist is None or output.h_posterior_dist is None:
            kld = torch.scalar_tensor(0).to(likelihood)
        else:
            kld = torch.distributions.kl_divergence(
                output.h_posterior_dist, output.h_prior_dist).sum(dim=1)
            assert kld.shape == likelihood.shape
        kld = self.kld_weight_func() * kld

        # Regularization terms.
        normals = self._gather_normals(minibatch)
        normals_term = self._normals_term(output, normals, batchsize, X_dim,
                                          likelihood)
        eikonal_term = self._eikonal_term(model, output, num_eikonal_samples,
                                          batchsize, normals_term)

        loss_normals_term = self.tau * normals_term.sum(dim=1)
        loss_eikonal_term = self.lam * eikonal_term.sum(dim=1)
        loss = torch.mean(-(likelihood - kld) + loss_normals_term +
                          loss_eikonal_term)

        return LossFunctionOutput(loss=loss,
                                  likelihood=likelihood.mean(),
                                  kld=kld.mean(),
                                  normals_term=loss_normals_term.mean(),
                                  eikonal_term=loss_eikonal_term.mean())

    def __call__(self, model: Model, minibatch: Minibatch):
        """Run :meth:`forward` followed by :meth:`compute`."""
        output = self.forward(model, minibatch)
        return self.compute(minibatch, model, output)


def log_message(run_id: str, model: Model, trainer: Trainer,
                kld_weight: float):
    """Format a one-line progress summary for the trainer's current epoch.

    Args:
        run_id: identifier shown in brackets at the start of the line.
        model: unused; kept for interface compatibility.
        trainer: source of the epoch, the learning rate of the first
            optimizer param group, moving-average metrics, progress and
            elapsed time.
        kld_weight: current KL weight, shown next to the KL metric.

    Returns:
        The individual fields joined with " - ".
    """
    state = trainer.state
    metrics = state.moving_average.metrics
    lr = trainer.optimizer.param_groups[0]["lr"]
    progress = state.epoch / state.max_epochs * 100
    elapsed_minutes = int(state.elapsed_seconds / 60)
    return " - ".join([
        f"[{run_id}] Epoch: {state.epoch:d} ({progress:.2f}%)",
        f"loss: {metrics['loss']:.4f}",
        f"f: {metrics['likelihood']:.4f}",
        f"lr: {lr:.4e}",
        f"kld: {metrics['kld']:.4f} (weight: {kld_weight:.4f})",
        f"normals: {metrics['normals_term']:.4f}",
        f"eikonal: {metrics['eikonal_term']:.4f}",
        f"#grad updates: {state.num_gradient_updates:d}",
        f"elapsed_time: {elapsed_minutes} min",
    ])


def log_loss(trainer: Trainer, csv_path: str):
    """Append the current epoch's moving-average metrics as a CSV row.

    When ``csv_path`` does not exist yet, it is created with a header row
    (``epoch`` followed by the metric names in sorted order); otherwise the
    row is appended to the existing file.

    Args:
        trainer: provides ``state.epoch`` and ``state.moving_average.metrics``.
        csv_path: path of the CSV file.
    """
    csv_path = Path(csv_path)
    epoch = trainer.state.epoch
    metrics = trainer.state.moving_average.metrics
    metric_names = sorted(metrics.keys())

    write_header = not csv_path.is_file()
    rows = []
    if write_header:
        rows.append(",".join(["epoch"] + metric_names))
    rows.append(",".join([str(epoch)] +
                         [str(metrics[name]) for name in metric_names]))

    # Rows are formatted before opening so a failure cannot leave a partial
    # line, and the context manager closes the file even if writing fails.
    with open(csv_path, "w" if write_header else "a") as f:
        f.write("\n".join(rows) + "\n")


def _split_array(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


class LearningRateScheduler:
    """Step-decay multiplier intended as a ``LambdaLR`` ``lr_lambda``.

    ``__call__(iteration)`` returns ``factor ** k`` with
    ``k = min(num_decay, iteration // decrease_every)``: the learning rate is
    multiplied by ``factor`` every ``decrease_every`` iterations, at most
    ``num_decay`` times. The result depends only on ``iteration``, so repeated
    or non-sequential calls (e.g. after restoring a scheduler) are safe.
    ``count`` and ``last_gamma`` reflect the most recent call.
    """

    def __init__(self, decrease_every: int, factor: float, num_decay: int):
        self.decrease_every = decrease_every
        self.factor = factor
        self.num_decay = num_decay
        self.count = 0
        self.last_gamma = 1

    def __call__(self, iteration):
        # Short-circuit before dividing so decrease_every == 0 is harmless
        # when no decay can happen.
        if iteration <= 0 or self.num_decay <= 0:
            num_applied = 0
        else:
            num_applied = min(self.num_decay,
                              iteration // self.decrease_every)
        self.count = num_applied
        self.last_gamma = self.factor**num_applied if num_applied else 1
        return self.last_gamma


class LatentOptimization:
    """Fit latent codes ``h`` to a minibatch's context points with Adam.

    Only ``h`` is optimized (the optimizer is built over ``[h]``). Each step
    evaluates the decoder on the first viewpoint's context points and
    minimizes ``loss_function.compute`` with no prior/posterior
    distributions, so this module's LossFunction contributes no KL term.
    """

    def __init__(self,
                 model: Model,
                 loss_function: LossFunction,
                 params: LatentOptimizationHyperparameters,
                 yield_every: int = None):
        """
        Args:
            model: model providing ``X_dim`` and ``decoder(X=..., h=...)``.
            loss_function: loss whose ``compute`` is minimized.
            params: learning-rate, iteration and decay settings.
            yield_every: if set, :meth:`steps` also yields every this many
                steps.
        """
        self.model = model
        self.loss_function = loss_function
        self.lr = params.initial_lr
        self.max_iterations = params.iterations
        self.decrease_lr_every = params.decrease_lr_every
        self.lr_decay_factor = params.lr_decay_factor
        self.num_lr_decay = params.num_lr_decay
        self.yield_every = yield_every

    def steps(self, data: Minibatch, h: torch.Tensor):
        """Optimize ``h`` in place, yielding ``(step, h)`` along the way.

        Yields ``(0, h)`` before any update, ``(step, h)`` every
        ``yield_every`` steps when configured, and ``(max_iterations, h)`` at
        the end if that step was not already yielded. ``h`` is updated in
        place, so every yielded tensor is the same object. ``h`` must be a
        leaf tensor because ``requires_grad_`` is set on it.

        Args:
            data: minibatch whose first viewpoint's context points are fitted.
            h: initial latent codes.
        """
        # Kept as ``data`` for caller compatibility; the body previously
        # referred to an undefined ``minibatch`` name.
        minibatch = data
        h.requires_grad_(True)
        optimizer = optim.Adam([h], lr=self.lr)
        decrease_lr_every = (self.max_iterations
                             if self.decrease_lr_every is None else
                             self.decrease_lr_every)
        lr_lambda = LearningRateScheduler(decrease_every=decrease_lr_every,
                                          factor=self.lr_decay_factor,
                                          num_decay=self.num_lr_decay)
        scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

        context_points = minibatch.context_points_list[0]  # viewpoint 1
        batchsize = context_points.shape[0]
        X_dim = self.model.X_dim
        X = context_points.view((batchsize, -1, X_dim))
        # The normals term differentiates f with respect to X.
        X.requires_grad_(True)
        num_repeats = X.shape[1]
        decoder_input_h_shape = (h.shape[0], num_repeats, h.shape[1])

        yield (0, h)
        prev_yield_step = 0
        for step in range(1, self.max_iterations + 1):
            decoder_input_h = h[:, None, :].expand(decoder_input_h_shape)
            pred_distance = self.model.decoder(X=X, h=decoder_input_h)
            model_output = ModelOutputDescription(pred_distance=pred_distance,
                                                  decoder_input_points=X,
                                                  decoder_input_h=h,
                                                  h_prior_dist=None,
                                                  h_posterior_dist=None)
            lossfunc_output = self.loss_function.compute(
                minibatch, self.model, model_output)
            optimizer.zero_grad()
            lossfunc_output.loss.backward()
            optimizer.step()
            scheduler.step()

            if self.yield_every is not None and step % self.yield_every == 0:
                yield (step, h)
                prev_yield_step = step
        if prev_yield_step != self.max_iterations:
            yield (self.max_iterations, h)

    def __call__(self, minibatch: Minibatch, h: torch.Tensor):
        """Run all optimization steps and return the final ``h``."""
        ret = None
        for _, ret in self.steps(minibatch, h):
            pass
        return ret


class AdaptiveLatentOptimization:
    """Latent-code fitting whose learning rate is scaled by the initial error.

    Before the first update, the decoder is evaluated once without gradients
    on the first viewpoint's context points, and the learning rate becomes
    ``initial_lr * mean(|f|)``. The optimizer ("adam" or "sgd") then updates
    only ``h``. The error, learning rate and optimizer name are printed.
    """

    def __init__(self,
                 model: Model,
                 loss_function: LossFunction,
                 params: LatentOptimizationHyperparameters,
                 yield_every: int = None):
        """
        Args:
            model: model providing ``X_dim`` and ``decoder(X=..., h=...)``.
            loss_function: loss whose ``compute`` is minimized.
            params: learning-rate, iteration, decay and optimizer settings.
            yield_every: if set, :meth:`steps` also yields every this many
                steps.
        """
        self.model = model
        self.loss_function = loss_function
        self.lr = params.initial_lr
        self.max_iterations = params.iterations
        self.decrease_lr_every = params.decrease_lr_every
        self.lr_decay_factor = params.lr_decay_factor
        self.num_lr_decay = params.num_lr_decay
        self.optimizer_name = params.optimizer
        self.yield_every = yield_every

    def steps(self, minibatch: Minibatch, h: torch.Tensor):
        """Optimize ``h`` in place, yielding ``(step, h)`` along the way.

        Yields ``(0, h)`` before anything else (the adaptive learning rate is
        computed only after the consumer resumes), ``(step, h)`` every
        ``yield_every`` steps when configured, and ``(max_iterations, h)`` at
        the end if not already yielded. ``h`` is updated in place.

        Raises:
            NotImplementedError: if the configured optimizer is neither
                "adam" nor "sgd" (raised on the second ``next``).
        """
        h.requires_grad_(True)
        yield (0, h)

        with torch.no_grad():
            context_points = minibatch.context_points_list[0]  # viewpoint 1
            batchsize = context_points.shape[0]
            X_dim = self.model.X_dim
            X = context_points.view((batchsize, -1, X_dim))
            # The normals term later differentiates f with respect to X.
            X.requires_grad_(True)
            num_repeats = X.shape[1]
            decoder_input_h_shape = (h.shape[0], num_repeats, h.shape[1])
            decoder_input_h = h[:, None, :].expand(decoder_input_h_shape)
            pred_distance = self.model.decoder(X=X, h=decoder_input_h)
            error = abs(pred_distance).mean().item()
            lr = self.lr * error
            print("error:", error)
            print("lr:", lr)

        if self.optimizer_name == "adam":
            optimizer = optim.Adam([h], lr=lr)
        elif self.optimizer_name == "sgd":
            optimizer = optim.SGD([h], lr=lr)
        else:
            raise NotImplementedError(
                f"unsupported optimizer {self.optimizer_name!r}; "
                "expected 'adam' or 'sgd'")
        print("optimizer:", self.optimizer_name)

        decrease_lr_every = (self.max_iterations
                             if self.decrease_lr_every is None else
                             self.decrease_lr_every)
        lr_lambda = LearningRateScheduler(decrease_every=decrease_lr_every,
                                          factor=self.lr_decay_factor,
                                          num_decay=self.num_lr_decay)
        scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

        prev_yield_step = 0
        for step in range(1, self.max_iterations + 1):
            decoder_input_h = h[:, None, :].expand(decoder_input_h_shape)
            pred_distance = self.model.decoder(X=X, h=decoder_input_h)
            model_output = ModelOutputDescription(pred_distance=pred_distance,
                                                  decoder_input_points=X,
                                                  decoder_input_h=h,
                                                  h_prior_dist=None,
                                                  h_posterior_dist=None)
            lossfunc_output = self.loss_function.compute(
                minibatch, self.model, model_output)
            optimizer.zero_grad()
            lossfunc_output.loss.backward()
            optimizer.step()
            scheduler.step()

            if self.yield_every is not None and step % self.yield_every == 0:
                yield (step, h)
                prev_yield_step = step
        if prev_yield_step != self.max_iterations:
            yield (self.max_iterations, h)

    def __call__(self, minibatch: Minibatch, h: torch.Tensor):
        """Run all optimization steps and return the final ``h``."""
        ret = None
        for _, ret in self.steps(minibatch, h):
            pass
        return ret


class MarchingCubes:
    """Mesh the decoder's zero level set by sampling it on a cubic grid.

    The decoder is evaluated on ``grid_size ** 3`` points spaced evenly over
    ``[grid_min_value, grid_max_value]`` on every axis, and scikit-image's
    Lewiner marching cubes extracts the ``f == 0`` isosurface. Returned
    vertices are in the same (x, y, z) frame as the sampled grid points.
    """

    def __init__(self, model: Model, grid_size: int, grid_min_value: float,
                 grid_max_value: float):
        """
        Args:
            model: provides ``get_device()`` and ``decoder(X=..., h=...)``.
            grid_size: samples per axis; at least 2 so the grid has a spacing.
            grid_min_value: lower bound of the cube on every axis.
            grid_max_value: upper bound of the cube on every axis.

        Raises:
            ValueError: if ``grid_size < 2`` or the bounds are not increasing.
        """
        if grid_size < 2:
            raise ValueError(f"grid_size must be >= 2, got {grid_size}")
        if not grid_max_value > grid_min_value:
            raise ValueError(
                "grid_max_value must be greater than grid_min_value")
        self.model = model
        self.grid_size = grid_size
        self.grid_min_value = grid_min_value
        self.grid_max_value = grid_max_value

    @staticmethod
    def _marching_cubes(volume: np.ndarray, voxel_size: float):
        """Run Lewiner marching cubes at level 0 across scikit-image versions.

        ``measure.marching_cubes_lewiner`` was deprecated and later removed in
        favour of ``measure.marching_cubes(..., method="lewiner")``; use the
        newer API when available.
        """
        import inspect

        spacing = (voxel_size, voxel_size, voxel_size)
        marching_cubes = getattr(measure, "marching_cubes", None)
        if (marching_cubes is not None and
                "method" in inspect.signature(marching_cubes).parameters):
            return marching_cubes(volume,
                                  level=0.0,
                                  spacing=spacing,
                                  gradient_direction="ascent",
                                  method="lewiner")
        return measure.marching_cubes_lewiner(volume,
                                              level=0.0,
                                              spacing=spacing,
                                              gradient_direction="ascent")

    def __call__(self, h: torch.Tensor):
        """Evaluate the decoder on the grid for ``h`` and mesh ``f == 0``.

        Args:
            h: latent code(s) expanded along the point axis for the decoder.
                Only batch element 0 of the predictions is kept, so
                presumably ``h`` is ``(1, h_dim)`` — confirm.

        Returns:
            ``(vertices, faces)`` from marching cubes, with vertices mapped
            back to grid coordinates.
        """
        device = self.model.get_device()
        n = self.grid_size
        lo, hi = self.grid_min_value, self.grid_max_value

        axis = np.linspace(lo, hi, n)
        # n samples span n - 1 intervals.
        voxel_size = (hi - lo) / (n - 1)
        # Default 'xy' indexing: each array is indexed [y, x, z].
        xv, yv, zv = np.meshgrid(axis, axis, axis)
        points = np.stack((xv, yv, zv)).reshape(
            (3, -1)).transpose().astype(np.float32)

        f_list = []
        with torch.no_grad():
            for chunk in _split_array(points, n):
                X = torch.from_numpy(chunk).to(device)[None, :, :]
                num_points = X.shape[1]
                h_expanded = h[:, None, :].expand(
                    (h.shape[0], num_points, h.shape[1]))
                f = self.model.decoder(X=X, h=h_expanded).squeeze(dim=2)[0]
                f_list.append(f.cpu().numpy())
        # volume[i, j, k] = f(x=axis[j], y=axis[i], z=axis[k])
        volume = np.concatenate(f_list).reshape((n, n, n))

        vertices, faces, _, _ = self._marching_cubes(volume, voxel_size)
        # Marching cubes returns (axis0, axis1, axis2) = (y, x, z) offsets
        # from the grid origin. Swap the first two columns and add the origin
        # to recover (x, y, z). The face winding is unchanged.
        vertices = np.asarray(vertices, dtype=np.float64)[:, [1, 0, 2]] + lo
        return vertices, faces


@dataclass
class TrainerState(training.TrainerState):
    """Base trainer state extended with a KL weight that survives checkpoints."""

    # KL-divergence weight tracked by the trainer (presumably the value the
    # loss's kld_weight_func reports — confirm); saved/restored with the
    # base state.
    kld_weight: float = 1

    def state_dict(self):
        """Return the base state dict with an added ``kld_weight`` entry."""
        result = super().state_dict()
        result.update(kld_weight=self.kld_weight)
        return result

    def load_state_dict(self, state_dict: dict):
        """Restore the base state first, then ``kld_weight``."""
        super().load_state_dict(state_dict)
        restored_weight = state_dict["kld_weight"]
        self.kld_weight = restored_weight
