import math
from dataclasses import dataclass
from pathlib import Path
from typing import List

import numpy as np
import torch
import torch.optim as optim
from skimage import measure
from torch.optim.lr_scheduler import LambdaLR

from ..classes import Hyperparameters
from ..datasets.classes import Minibatch
from ..models import (Decoder, DecoderHyperparameters, Model,
                      ModelHyperparameters)
from ..training import (Events, LossFunctionInterface, LossFunctionOutput,
                        Trainer)


@dataclass
class TrainingHyperparameters(Hyperparameters):
    """Hyperparameter container for a training run.

    None of these fields is read inside this module's visible code, so the
    per-field notes below are inferred from the field names and should be
    confirmed against the training script that consumes them.
    """
    batchsize: int = 32
    # 128 * 128 = 16384 by default.
    # presumably the number of surface points sampled per shape -- TODO confirm
    num_surface_samples: int = 128 * 128
    # 10 * 1000 = 10000 by default.
    # presumably forwarded to LossFunction(num_eikonal_samples=...) -- TODO confirm
    num_eikonal_samples: int = 10 * 1000
    # presumably toggles supervision with surface normals -- TODO confirm
    with_normal: bool = False
    learning_rate: float = 0.0001
    # NOTE(review): looks like a separate learning rate for the latent codes
    # -- confirm where the optimizer parameter groups are built.
    learning_rate_for_latent: float = 0.001
    decrease_lr_every: int = 5000
    num_lr_decay: int = 1
    # NOTE(review): loss_lambda / loss_tau / loss_alpha presumably map to the
    # `lam` / `tau` / `alpha` arguments of LossFunction -- confirm at the
    # construction site.
    loss_lambda: float = 0.1
    loss_tau: float = 1
    loss_alpha: float = 0.01
    # presumably forwarded as LossFunction(eikonal_term_default_stddev=...)
    # -- TODO confirm
    eikonal_term_stddev: float = 0.02


@dataclass
class LatentOptimizationHyperparameters(Hyperparameters):
    """Settings read by LatentOptimization and AdaptiveLatentOptimization.

    Both optimizer classes in this module copy `initial_lr`, `iterations`,
    `decrease_lr_every`, `lr_decay_factor`, `num_lr_decay` and `optimizer`
    from an instance of this class in their constructors.
    """
    # No type annotation, so `@dataclass` treats this as a plain class
    # attribute rather than an init field.
    # presumably used by the Hyperparameters base class to namespace
    # option names -- TODO confirm
    prefix = "latent_optimization"
    # Not read anywhere in this module's visible code.
    # presumably the number of points sampled per shape -- TODO confirm
    num_samples: int = 30000
    # Number of optimization steps performed on the latent code.
    iterations: int = 800
    # The learning rate is multiplied by `lr_decay_factor` every
    # `decrease_lr_every` scheduler steps, at most `num_lr_decay` times.
    decrease_lr_every: int = 800
    lr_decay_factor: float = 0.1
    initial_lr: float = 0.005
    num_lr_decay: int = 1
    # AdaptiveLatentOptimization accepts "adam" or "sgd";
    # LatentOptimization stores the value but always builds Adam.
    optimizer: str = "adam"


def _parse_hidden_sizes_str(sizes_str: str) -> List[int]:
    assert sizes_str.startswith(",") is False
    assert sizes_str.endswith(",") is False
    sizes = sizes_str.split(",")
    sizes = [int(size) for size in sizes]
    return sizes


def setup_model(model_hyperparams: ModelHyperparameters,
                decoder_hyperparams: DecoderHyperparameters):
    """Build the auto-decoder: a Decoder plus one latent code per data item.

    The decoder takes ``3 + z_dim`` inputs and produces a single output.
    The latent codes are drawn from a normal distribution with mean 0 and
    standard deviation 0.01, with shape ``(num_data, z_dim)``.

    Args:
        model_hyperparams: Supplies ``z_dim`` and ``num_data``.
        decoder_hyperparams: Supplies ``hidden_sizes`` (comma-separated
            string) and ``activation_func`` (``"relu"`` or ``"softplus"``).

    Returns:
        A ``Model`` wrapping the decoder and the initial latent codes.
    """
    def activation_func_module():
        # Factory handed to the Decoder; any other activation name makes
        # the factory raise NotImplementedError when it is called.
        name = decoder_hyperparams.activation_func
        if name == "relu":
            return torch.nn.ReLU()
        if name == "softplus":
            return torch.nn.Softplus(beta=100)
        raise NotImplementedError()

    latent_dim = model_hyperparams.z_dim
    layer_widths = _parse_hidden_sizes_str(decoder_hyperparams.hidden_sizes)
    # The decoder is built before the latent codes are sampled, so the
    # order in which random numbers are consumed stays fixed.
    decoder = Decoder(input_size=3 + latent_dim,
                      output_size=1,
                      hidden_sizes=layer_widths,
                      activation_func_module=activation_func_module)

    latent_shape = (model_hyperparams.num_data, latent_dim)
    initial_z = torch.normal(mean=0,
                             std=0.01,
                             size=latent_shape,
                             dtype=torch.float32)
    return Model(decoder, initial_z)


class LossFunction(LossFunctionInterface):
    """Loss made of a surface term, an optional normals term, an Eikonal
    term and a latent-code norm penalty.

    Per batch element the loss is::

        mean_over_points(|f(x_i)| + tau * ||grad f(x_i) - n_i||)
        + lam * mean_over_samples((||grad f(x)|| - 1) ** 2)
        + alpha * ||z||

    and the returned ``loss`` is the mean of that over the batch. The
    normals term is only present when normals are supplied.
    """

    def __init__(self,
                 tau: float,
                 lam: float,
                 alpha: float,
                 num_eikonal_samples: int,
                 eikonal_term_default_stddev: float = 0.02):
        """Store the loss weights and sampling settings.

        Args:
            tau: Weight of the normals term.
            lam: Weight of the Eikonal term.
            alpha: Weight of the latent-code norm penalty.
            num_eikonal_samples: Number of points drawn (with replacement)
                for the Eikonal term on every call to `compute`.
            eikonal_term_default_stddev: Standard deviation of the Gaussian
                used for the Eikonal samples when no per-point distances
                are given.
        """
        self.tau = tau
        self.lam = lam
        self.alpha = alpha
        self.num_eikonal_samples = num_eikonal_samples
        self.eikonal_term_default_stddev = eikonal_term_default_stddev

    def compute(self,
                model: Model,
                points: torch.Tensor,
                z: torch.Tensor,
                kth_nn_distances: torch.Tensor = None,
                normals: torch.Tensor = None) -> LossFunctionOutput:
        """Evaluate the loss for the given points and latent codes.

        Args:
            model: Object whose ``decode(points, z)`` output is squeezed
                along dim 2 to get one value per point.
            points: Input points; dim 1 is treated as the point index and
                dim 2 as the coordinate axis.
                (assumes shape (batch, num_points, 3) -- TODO confirm)
                Note that ``requires_grad`` is switched on for this tensor
                in place.
            z: Latent codes; the norm penalty is taken along dim 1.
                (assumes shape (batch, z_dim) -- TODO confirm)
            kth_nn_distances: Optional per-point standard deviations for
                the Eikonal sampling; expanded along a new dim 2 to the
                shape of ``points``.
            normals: Optional target gradients compared against the
                gradient of the decoder output at ``points``.

        Returns:
            A ``LossFunctionOutput`` holding the total ``loss`` and the
            means of the surface, normals and Eikonal terms.
        """
        # The gradient of the decoder output w.r.t. the input points is
        # needed below, so the points must be tracked by autograd.
        points.requires_grad_(True)
        f_xi = model.decode(points, z).squeeze(dim=2)
        # create_graph=True keeps this gradient differentiable, so terms
        # built from it can be back-propagated.
        f_xi_grad = torch.autograd.grad(f_xi.sum(), points,
                                        create_graph=True)[0]
        # encourages f to vanish on X
        loss_f_term = torch.abs(f_xi)
        # Zero placeholder matching loss_f_term's dtype/device; used when
        # no normals are given.
        loss_normals_term = torch.scalar_tensor(0).to(loss_f_term)
        if normals is None:
            pass
        else:
            loss_normals_term = self.tau * torch.norm(f_xi_grad - normals,
                                                      dim=2)

        # mean over the points
        loss = loss_f_term + loss_normals_term
        loss = loss.mean(dim=1)

        # Eikonal term
        num_input_points = points.shape[1]
        # The same point indices (sampled with replacement) are used for
        # every element of the batch.
        random_indices = np.random.choice(num_input_points,
                                          size=self.num_eikonal_samples,
                                          replace=True)
        loc = points[:, random_indices]
        if kth_nn_distances is None:
            stddev = self.eikonal_term_default_stddev
        else:
            # Broadcast the per-point distance over the coordinate axis and
            # pick the entries that belong to the sampled points.
            stddev = kth_nn_distances.unsqueeze(2).expand(points.shape)
            stddev = stddev[:, random_indices]
        # Perturb the selected input points with Gaussian noise.
        distribution = torch.distributions.Normal(loc=loc, scale=stddev)
        sampled_point = distribution.sample()
        sampled_point.requires_grad_(True)
        f_x = model.decode(sampled_point, z).squeeze(dim=2)
        f_x_grad = torch.autograd.grad(f_x.sum(),
                                       sampled_point,
                                       create_graph=True)[0]
        # Penalize deviation of the gradient norm from 1.
        eikonal_term = (torch.norm(f_x_grad, dim=2) - 1)**2
        loss_eikonal_term = self.lam * eikonal_term.mean(dim=1)

        loss += loss_eikonal_term

        # regularization term
        loss += self.alpha * torch.norm(z, dim=1)

        # mean over the batches
        loss = loss.mean()
        return LossFunctionOutput(loss=loss,
                                  f_term=loss_f_term.mean(),
                                  normals_term=loss_normals_term.mean(),
                                  eikonal_term=loss_eikonal_term.mean())

    def __call__(self, model: Model, minibatch: Minibatch):
        """Look up the latent codes for the minibatch via
        ``model.z(minibatch.data_index)`` and return `compute`'s output."""
        z = model.z(minibatch.data_index)
        output = self.compute(model, minibatch.points, z,
                              minibatch.kth_nn_distances, minibatch.normals)
        return output


def log_message(run_id: str, decoder: Decoder, trainer: Trainer):
    """Format a one-line progress summary for the current trainer state.

    Args:
        run_id: Identifier printed in square brackets at the start.
        decoder: Not used by this function.
        trainer: Provides ``state`` (epoch, max_epochs, moving-average
            metrics, timing counters) and ``optimizer`` (the learning rate
            is read from its first parameter group).

    Returns:
        The summary fields joined by ``" - "``.
    """
    state = trainer.state
    epoch = state.epoch
    lr = trainer.optimizer.param_groups[0]["lr"]
    metrics = state.moving_average.metrics
    percent_done = epoch / state.max_epochs * 100
    # Whole minutes, truncated toward zero.
    elapsed_minutes = int(state.elapsed_seconds / 60)

    fields = (
        f"[{run_id}] Epoch: {epoch:d} ({percent_done:.2f}%)",
        f"loss: {metrics['loss']:.4f}",
        f"f: {metrics['f_term']:.4f}",
        f"normals: {metrics['normals_term']:.4f}",
        f"eikonal_term: {metrics['eikonal_term']:.4f}",
        f"lr: {lr:.4e}",
        f"#grad updates: {state.num_gradient_updates:d}",
        f"data_load_time: {state.data_load_time:.4f}",
        f"elapsed_time: {elapsed_minutes} min",
    )
    return " - ".join(fields)


def log_loss(trainer: Trainer, csv_path: str):
    """Append the current moving-average metrics as one row of a CSV file.

    If ``csv_path`` is not an existing file it is created and a header row
    (``epoch`` followed by the sorted metric names) is written first. The
    header is only written at creation time, so later rows are expected to
    carry the same set of metrics.

    Args:
        trainer: Provides ``state.epoch`` and
            ``state.moving_average.metrics`` (a mapping of name to value).
        csv_path: Destination file path.
    """
    csv_path = Path(csv_path)
    epoch = trainer.state.epoch
    metrics = trainer.state.moving_average.metrics
    metric_names = sorted(metrics.keys())
    # Build the row before touching the file so a failure here cannot
    # leave an open handle or a header-only file behind.
    values = [str(epoch)] + [str(metrics[key]) for key in metric_names]

    is_new_file = not csv_path.is_file()
    # The context manager closes the file even if a write raises.
    with open(csv_path, "w" if is_new_file else "a") as f:
        if is_new_file:
            f.write(",".join(["epoch"] + metric_names) + "\n")
        f.write(",".join(values) + "\n")


def _split_array(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


class LearningRateScheduler:
    """Stateful step-decay multiplier, usable as an ``lr_lambda`` callable.

    Each call returns the current multiplier. The multiplier starts at 1
    and is multiplied by ``factor`` whenever the call is made with a
    non-zero iteration that is a multiple of ``decrease_every``, until
    ``num_decay`` decays have been applied.

    The decay count is kept on the instance, so the result depends on the
    history of calls and not only on the ``iteration`` argument.
    """

    def __init__(self, decrease_every: int, factor: float, num_decay: int):
        self.decrease_every = decrease_every
        self.factor = factor
        self.num_decay = num_decay
        # Number of decays applied so far and the multiplier they produced.
        self.count = 0
        self.last_gamma = 1

    def __call__(self, iteration):
        """Return the multiplier for ``iteration``, decaying it if due."""
        decay_is_due = (iteration != 0 and self.count < self.num_decay
                        and iteration % self.decrease_every == 0)
        if decay_is_due:
            self.last_gamma = self.last_gamma * self.factor
            self.count += 1
        return self.last_gamma


class LatentOptimization:
    """Fit a latent code to a minibatch with Adam while the model is fixed.

    Only ``z`` is handed to the optimizer. The learning rate follows a
    ``LearningRateScheduler`` step decay built from ``params``.

    With ``num_chunks > 1`` the points of the first batch element are
    split into chunks and the gradients of all chunks are accumulated
    before every optimizer step, which bounds peak memory.
    """

    def __init__(self,
                 model: Model,
                 loss_function: LossFunction,
                 params: LatentOptimizationHyperparameters,
                 num_chunks=1):
        """Store the model, loss and optimization settings.

        Args:
            model: Model passed through to ``loss_function.compute``.
            loss_function: Provides ``compute(model, points, z,
                kth_nn_distances, normals)``.
            params: Source of the learning rate, iteration count and
                decay settings.
            num_chunks: Number of pieces the points are split into per
                optimization step; 1 disables chunking.
        """
        self.model = model
        self.loss_function = loss_function
        self.lr = params.initial_lr
        self.max_iterations = params.iterations
        self.decrease_lr_every = params.decrease_lr_every
        self.lr_decay_factor = params.lr_decay_factor
        self.num_lr_decay = params.num_lr_decay
        # NOTE: stored for parity with AdaptiveLatentOptimization, but
        # this class always optimizes with Adam.
        self.optimizer_name = params.optimizer
        self.num_chunks = num_chunks

    def __call__(self, minibatch: Minibatch, z: torch.Tensor):
        """Run ``max_iterations`` optimization steps on ``z``.

        Args:
            minibatch: Provides ``points`` and optionally
                ``kth_nn_distances`` and ``normals`` (either may be None).
                In chunked mode only index 0 along the first dimension of
                these tensors is used.
            z: Latent code to optimize; it is updated in place and
                ``requires_grad`` is switched on.

        Returns:
            The same ``z`` tensor after optimization.
        """
        z.requires_grad_(True)
        optimizer = optim.Adam([z], lr=self.lr)
        lr_lambda = LearningRateScheduler(self.decrease_lr_every,
                                          factor=self.lr_decay_factor,
                                          num_decay=self.num_lr_decay)
        scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

        if self.num_chunks == 1:
            for _ in range(self.max_iterations):
                output = self.loss_function.compute(self.model,
                                                    minibatch.points, z,
                                                    minibatch.kth_nn_distances,
                                                    minibatch.normals)
                optimizer.zero_grad()
                output.loss.backward()
                optimizer.step()
                scheduler.step()
            return z

        point_chunks = _split_array(minibatch.points[0], self.num_chunks)
        if minibatch.kth_nn_distances is None:
            kth_nn_distance_chunks = [None] * self.num_chunks
        else:
            kth_nn_distance_chunks = _split_array(
                minibatch.kth_nn_distances[0], self.num_chunks)
        if minibatch.normals is None:
            normal_chunks = [None] * self.num_chunks
        else:
            normal_chunks = _split_array(minibatch.normals[0],
                                         self.num_chunks)

        for _ in range(self.max_iterations):
            # Clear the gradient once per step, NOT once per chunk:
            # zeroing inside the chunk loop would discard every chunk's
            # gradient except the last one before `optimizer.step()`.
            optimizer.zero_grad()
            for (point_chunk, kth_nn_distance_chunk,
                 normal_chunk) in zip(point_chunks, kth_nn_distance_chunks,
                                      normal_chunks):
                # Restore the leading batch dimension dropped by `[0]`.
                point_chunk = point_chunk[None, :, :]
                if kth_nn_distance_chunk is not None:
                    kth_nn_distance_chunk = kth_nn_distance_chunk[None, :]
                if normal_chunk is not None:
                    normal_chunk = normal_chunk[None, :, :]
                output = self.loss_function.compute(
                    self.model, point_chunk, z, kth_nn_distance_chunk,
                    normal_chunk)
                # Dividing by the chunk count makes the accumulated
                # gradient that of the mean chunk loss.
                (output.loss / self.num_chunks).backward()
            optimizer.step()
            scheduler.step()
        return z


class AdaptiveLatentOptimization:
    """Latent-code optimizer whose learning rate is scaled by the initial
    prediction error.

    Before optimizing, the mean absolute decoder output at the minibatch
    points is measured and the base learning rate is multiplied by it.
    Progress can be observed through the `steps` generator.
    """

    def __init__(self,
                 model: Model,
                 loss_function: LossFunction,
                 params: LatentOptimizationHyperparameters,
                 yield_every: int = None):
        """Store the model, loss and optimization settings.

        Args:
            model: Used for the initial error estimate and passed to
                ``loss_function.compute``.
            loss_function: Provides ``compute(model, points, z,
                kth_nn_distances, normals)``.
            params: Source of the learning rate, iteration count, decay
                settings and optimizer name (``"adam"`` or ``"sgd"``).
            yield_every: If given, `steps` yields after every
                ``yield_every``-th step.
        """
        self.model = model
        self.loss_function = loss_function
        self.yield_every = yield_every
        self.lr = params.initial_lr
        self.max_iterations = params.iterations
        self.decrease_lr_every = params.decrease_lr_every
        self.lr_decay_factor = params.lr_decay_factor
        self.num_lr_decay = params.num_lr_decay
        self.optimizer_name = params.optimizer

    def steps(self, minibatch: Minibatch, z: torch.Tensor):
        """Optimize ``z`` and yield ``(step, z)`` snapshots along the way.

        Yields ``(0, z)`` first, then after every ``yield_every``-th step
        when that is set, and finally ``(max_iterations, z)`` unless that
        step was the last one already yielded. The yielded tensor is
        always the same object, updated in place.

        Raises:
            NotImplementedError: If the optimizer name is neither
                ``"adam"`` nor ``"sgd"``.
        """
        z.requires_grad_(True)
        yield (0, z)

        # Measure how far the untrained code is from explaining the points.
        with torch.no_grad():
            initial_prediction = self.model.decode(x=minibatch.points, z=z)
            error = abs(initial_prediction).mean().item()
        lr = self.lr * error
        print("error:", error)
        print("lr:", lr)

        optimizer_classes = {"adam": optim.Adam, "sgd": optim.SGD}
        if self.optimizer_name not in optimizer_classes:
            raise NotImplementedError()
        optimizer = optimizer_classes[self.optimizer_name]([z], lr=lr)
        print("optimizer:", self.optimizer_name)
        print("max_iterations:", self.max_iterations)

        # Without an explicit interval, use the total number of steps.
        decay_interval = self.decrease_lr_every
        if decay_interval is None:
            decay_interval = self.max_iterations
        scheduler = LambdaLR(
            optimizer,
            lr_lambda=LearningRateScheduler(decrease_every=decay_interval,
                                            factor=self.lr_decay_factor,
                                            num_decay=self.num_lr_decay))

        last_yielded_step = 0
        for step in range(1, self.max_iterations + 1):
            loss_output = self.loss_function.compute(
                self.model, minibatch.points, z, minibatch.kth_nn_distances,
                minibatch.normals)
            optimizer.zero_grad()
            loss_output.loss.backward()
            optimizer.step()
            scheduler.step()

            if self.yield_every is None:
                continue
            if step % self.yield_every == 0:
                yield (step, z)
                last_yielded_step = step

        # Always finish with the final state unless it was just yielded.
        if last_yielded_step != self.max_iterations:
            yield (self.max_iterations, z)

    def __call__(self, data: Minibatch, z: torch.Tensor):
        """Run `steps` to completion and return the last yielded code."""
        final_z = None
        for _, current_z in self.steps(data, z):
            final_z = current_z
        return final_z


class MarchingCubes:
    """Extract the zero level set of the decoded field as a triangle mesh.

    The decoder is evaluated on a regular ``grid_size``^3 lattice spanning
    ``[grid_min_value, grid_max_value]`` along every axis, and marching
    cubes is run on the resulting volume.
    """

    def __init__(self, model: Model, grid_size: int, grid_min_value: float,
                 grid_max_value: float):
        """Store the model and the sampling lattice description.

        Args:
            model: Provides ``get_device()`` and ``decode(points, z)``.
            grid_size: Number of samples along each axis.
            grid_min_value: Lower bound of the lattice on every axis.
            grid_max_value: Upper bound of the lattice on every axis; must
                be greater than ``grid_min_value``.
        """
        self.model = model
        self.grid_size = grid_size
        assert grid_max_value > grid_min_value
        self.grid_min_value = grid_min_value
        self.grid_max_value = grid_max_value

    def __call__(self, z: torch.Tensor):
        """Decode the lattice for latent code ``z`` and mesh the result.

        Args:
            z: Latent code passed to ``model.decode`` together with one
                chunk of lattice points at a time.
                (assumes a single code with a leading batch dimension of 1
                -- TODO confirm)

        Returns:
            ``(vertices, faces)`` as returned by marching cubes, with the
            vertices mapped into the coordinate frame of the lattice.

        Raises:
            ValueError: If ``grid_size`` is smaller than 2.
        """
        if self.grid_size < 2:
            raise ValueError("grid_size must be at least 2")
        device = self.model.get_device()

        # make prediction
        linspace_size = self.grid_max_value - self.grid_min_value
        # np.linspace places `grid_size` samples with both endpoints
        # included, so neighbouring samples are `extent / (grid_size - 1)`
        # apart. Dividing by `grid_size` would shrink the mesh.
        voxel_size = linspace_size / (self.grid_size - 1)
        axis = np.linspace(self.grid_min_value, self.grid_max_value,
                           self.grid_size)
        # Default "xy" indexing: volume[i, j, k] is evaluated at
        # (axis[j], axis[i], axis[k]); the transform after marching cubes
        # undoes this x/y swap.
        grid = np.stack(np.meshgrid(axis, axis, axis)).reshape(
            (3, -1)).transpose().astype(np.float32)

        f_list = []
        with torch.no_grad():
            # Decode one slab of the lattice at a time to limit memory use.
            for grid_chunk in _split_array(grid, self.grid_size):
                grid_chunk = torch.from_numpy(grid_chunk).to(
                    device)[None, :, :]
                f = self.model.decode(grid_chunk, z).squeeze(dim=2)[0]
                f_list.append(f.cpu().numpy())
        volume = np.concatenate(f_list).reshape(
            (self.grid_size, self.grid_size, self.grid_size))

        # `marching_cubes_lewiner` is absent from newer scikit-image
        # releases, where `marching_cubes` provides the same algorithm and
        # keyword arguments. Prefer the original name when it exists.
        marching_cubes = getattr(measure, "marching_cubes_lewiner", None)
        if marching_cubes is None:
            marching_cubes = measure.marching_cubes
        spacing = (voxel_size, -voxel_size, voxel_size)
        vertices, faces, _, _ = marching_cubes(volume,
                                               0.0,
                                               spacing=spacing,
                                               gradient_direction="ascent")

        # Shift so that, after the rotation below, index 0 on every axis
        # lands on `grid_min_value`. For a lattice symmetric about the
        # origin this equals (extent / 2, -extent / 2, extent / 2).
        vertex_translation = (-self.grid_min_value, self.grid_min_value,
                              -self.grid_min_value)
        vertices -= vertex_translation
        # Exact 90-degree rotation about the z axis: (x, y, z) -> (-y, x, z).
        rotation_matrix = np.array([[0.0, -1.0, 0.0],
                                    [1.0, 0.0, 0.0],
                                    [0.0, 0.0, 1.0]])
        vertices = vertices @ rotation_matrix.T

        return vertices, faces
