import math
from dataclasses import dataclass
from pathlib import Path
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import trimesh
from scipy.spatial import cKDTree as KDTree
from skimage import measure
from torch.optim.lr_scheduler import LambdaLR

from .. import point_cloud as pcl
from ..classes import Hyperparameters
from ..datasets.classes import Minibatch
from ..datasets.shapenet import (MeshData, MinibatchGenerator, SdfData)
from ..models import (Decoder, DecoderHyperparameters, Model,
                      ModelHyperparameters)
from ..optimizers import RAdam
from ..training import Events, LossFunctionInterface
from ..training import LossFunctionOutput as BaseLossFunctionOutput
from ..training import Trainer


@dataclass(frozen=True)
class LossFunctionOutput(BaseLossFunctionOutput):
    """Immutable loss result that adds a regularization term to the base output.

    All other fields are inherited from ``BaseLossFunctionOutput``.
    """
    # Regularization term reported next to the loss. ``LossFunction.compute``
    # fills it with the batch mean of ``lam * ||z||``; it defaults to None.
    reg_term: torch.Tensor = None


@dataclass
class TrainingHyperparameters(Hyperparameters):
    """Default values for the training hyperparameters.

    Nothing in this part of the file reads these fields, so the per-field
    notes below are inferred from the names and should be confirmed against
    the training script.
    """
    batchsize: int = 32  # presumably shapes per minibatch -- TODO confirm
    num_sdf_samples: int = 128 * 128  # 16384; presumably SDF samples per shape
    learning_rate: float = 0.00032  # presumably the decoder learning rate
    learning_rate_for_latent: float = 0.001  # presumably the latent-code rate
    decrease_lr_every: int = 500  # presumably the LR decay period -- confirm unit
    num_lr_decay: int = 1  # presumably the maximum number of LR decays
    # presumably forwarded to LossFunction as `clamping_distance` -- confirm
    clamping_distance: float = 0.1
    # presumably forwarded to LossFunction as `lam` -- confirm
    loss_lam: float = 0.0001


@dataclass
class LatentOptimizationHyperparameters(Hyperparameters):
    """Default values for optimizing a latent code against a fixed model.

    ``LatentOptimization`` and ``AdaptiveLatentOptimization`` read
    ``initial_lr``, ``iterations``, ``decrease_lr_every``,
    ``lr_decay_factor``, ``num_lr_decay`` and ``optimizer`` from an instance
    of this class.
    """
    # No type annotation, so this is a plain class attribute rather than a
    # dataclass field. Presumably consumed by the Hyperparameters base class
    # (e.g. as a name prefix) -- TODO confirm.
    prefix = "latent_optimization"
    # Not read in this part of the file; presumably the number of SDF samples
    # used during latent optimization -- TODO confirm.
    num_samples: int = 30000
    iterations: int = 800  # number of optimization steps
    # Decay period in iterations; the optimizers fall back to `iterations`
    # when this is None.
    decrease_lr_every: int = 800
    lr_decay_factor: float = 0.1  # multiplier applied at each decay
    initial_lr: float = 0.005  # starting learning rate
    num_lr_decay: int = 1  # maximum number of decays
    # AdaptiveLatentOptimization accepts "adam" or "sgd".
    optimizer: str = "adam"


def _parse_hidden_sizes_str(sizes_str: str) -> List[int]:
    assert sizes_str.startswith(",") is False
    assert sizes_str.endswith(",") is False
    sizes = sizes_str.split(",")
    sizes = [int(size) for size in sizes]
    return sizes


def setup_model(model_hyperparams: ModelHyperparameters,
                decoder_hyperparams: DecoderHyperparameters):
    """Build a ``Model`` from a fresh decoder and randomly drawn latent codes.

    The decoder input width is ``3 + z_dim`` (presumably a 3-D point joined
    with a latent code -- confirm against ``Decoder``) and it has a single
    output. The latent table has shape ``(num_data, z_dim)`` and is drawn
    from a normal distribution with mean 0 and standard deviation 0.01.
    """
    def activation_func_module():
        # Returns a new ReLU module on every call.
        return torch.nn.ReLU()

    latent_dim = model_hyperparams.z_dim
    hidden_sizes = _parse_hidden_sizes_str(decoder_hyperparams.hidden_sizes)
    decoder = Decoder(
        input_size=3 + latent_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
        activation_func_module=activation_func_module,
        weight_norm=decoder_hyperparams.weight_norm,
        dropout_prob=decoder_hyperparams.dropout_prob,
    )

    # Drawn after the decoder is constructed, so the order in which the
    # global RNG is consumed is the same as before.
    latent_shape = (model_hyperparams.num_data, latent_dim)
    initial_z = torch.normal(mean=0,
                             std=0.01,
                             size=latent_shape,
                             dtype=torch.float32)
    return Model(decoder, initial_z)


def setup_decoder(model_hyperparams: ModelHyperparameters,
                  decoder_hyperparams: DecoderHyperparameters):
    """Build only the ``Decoder``, without any latent codes.

    The decoder input width is ``3 + z_dim`` and it has a single output;
    hidden sizes, weight norm and dropout come from ``decoder_hyperparams``.
    """
    def activation_func_module():
        # Returns a new ReLU module on every call.
        return torch.nn.ReLU()

    hidden_sizes = _parse_hidden_sizes_str(decoder_hyperparams.hidden_sizes)
    return Decoder(
        input_size=3 + model_hyperparams.z_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
        activation_func_module=activation_func_module,
        weight_norm=decoder_hyperparams.weight_norm,
        dropout_prob=decoder_hyperparams.dropout_prob,
    )


class LossFunction(LossFunctionInterface):
    """Clamped L1 distance loss plus a penalty on the latent-code norm."""

    def __init__(self, clamping_distance: float, lam: float):
        """Store the clamp bound, the penalty weight and a summing L1 loss."""
        self.lam = lam
        self.clamping_distance = clamping_distance
        self.loss_l1 = torch.nn.L1Loss(reduction="sum")

    def compute(self, model: Model, points: torch.Tensor,
                distances: torch.Tensor,
                z: torch.Tensor) -> LossFunctionOutput:
        """Evaluate the loss for explicit points, distances and latent codes.

        Both the target and the predicted distances are clamped to
        ``[-clamping_distance, clamping_distance]``. The summed L1 error is
        divided by ``distances.shape[1]``, and the batch mean of
        ``lam * ||z||`` (norm taken over dim 1) is added on top.

        Returns:
            A ``LossFunctionOutput`` holding the total loss and the
            regularization term on its own.
        """
        samples_per_item = distances.shape[1]
        bound = self.clamping_distance

        target = torch.clamp(distances, -bound, bound)
        decoded = model.decode(points, z).squeeze(dim=2)
        prediction = torch.clamp(decoded, -bound, bound)
        data_term = self.loss_l1(prediction, target) / samples_per_item

        # Penalty on the size of the latent codes.
        reg_term = (self.lam * torch.norm(z, dim=1)).mean()

        return LossFunctionOutput(loss=data_term + reg_term,
                                  reg_term=reg_term)

    def __call__(self, model: Model, data: Minibatch):
        """Look up the minibatch's latent codes in the model and compute."""
        latent = model.z(data.data_indices)
        return self.compute(model, data.points, data.distances, latent)


def log_message(run_id: str, decoder: Decoder, trainer: Trainer):
    """Format a one-line training progress message.

    Args:
        run_id: Identifier printed in brackets at the start of the line.
        decoder: Unused; kept so existing callers keep working.
        trainer: Supplies the epoch counters, elapsed seconds, gradient-update
            count, moving-average metrics (``"loss"`` and ``"reg_term"``) and
            the learning rate of the optimizer's first parameter group.

    Returns:
        The fields joined by ``" - "``.
    """
    state = trainer.state
    epoch = state.epoch
    lr = trainer.optimizer.param_groups[0]["lr"]
    metrics = state.moving_average.metrics
    progress = epoch / state.max_epochs * 100
    elapsed_time = int(state.elapsed_seconds / 60)  # whole minutes
    return " - ".join([
        f"[{run_id}] Epoch: {epoch:d} ({progress:.2f}%)",
        f"loss: {metrics['loss']:.4f}",
        f"lr: {lr:.4e}",
        f"reg: {metrics['reg_term']:.4e}",
        f"#grad updates: {state.num_gradient_updates:d}",
        # The f-string is already rendered; the former trailing .format()
        # call was a no-op and has been dropped.
        f"elapsed_time: {elapsed_time} min",
    ])


def log_loss(trainer: Trainer, csv_path: str):
    """Append the current epoch's moving-average metrics to a CSV file.

    If ``csv_path`` is not yet a file, it is created and a header row
    (``epoch`` followed by the sorted metric names) is written first.
    Each call then adds one row with the epoch and the metric values in the
    same sorted order.

    Args:
        trainer: Supplies ``state.epoch`` and ``state.moving_average.metrics``.
        csv_path: Destination file path.
    """
    csv_path = Path(csv_path)
    metrics = trainer.state.moving_average.metrics
    metric_names = sorted(metrics.keys())
    row = [str(trainer.state.epoch)]
    row.extend(str(metrics[name]) for name in metric_names)

    is_new_file = not csv_path.is_file()
    # The context manager closes the handle even when a write fails.
    with open(csv_path, "w" if is_new_file else "a") as f:
        if is_new_file:
            f.write(",".join(["epoch"] + metric_names))
            f.write("\n")
        f.write(",".join(row))
        f.write("\n")


def _split_array(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


class LearningRateScheduler:
    """Stateful step-decay multiplier, usable as a ``LambdaLR`` callable.

    The multiplier starts at 1 and is multiplied by ``factor`` each time the
    object is called with a non-zero iteration that is a multiple of
    ``decrease_every``, at most ``num_decay`` times in total. The state lives
    on the instance, so calling twice with the same iteration can decay
    twice.
    """

    def __init__(self, decrease_every: int, factor: float, num_decay: int):
        self.decrease_every = decrease_every
        self.factor = factor
        self.num_decay = num_decay
        self.count = 0  # decays applied so far
        self.last_gamma = 1  # current multiplier

    def __call__(self, iteration):
        """Return the multiplier for ``iteration``, decaying first if due."""
        decay_due = (iteration != 0 and self.count < self.num_decay
                     and iteration % self.decrease_every == 0)
        if decay_due:
            self.last_gamma = self.last_gamma * self.factor
            self.count += 1
        return self.last_gamma


class LatentOptimization:
    """Optimize a latent code ``z`` against a fixed model for a set number of steps.

    The optimizer is chosen by ``params.optimizer`` (``"adam"`` or ``"sgd"``),
    matching ``AdaptiveLatentOptimization``. Previously the setting was stored
    but ignored and Adam was always used.
    """

    def __init__(self, model: Model, loss_function: LossFunction,
                 params: LatentOptimizationHyperparameters):
        self.model = model
        self.loss_function = loss_function
        self.lr = params.initial_lr
        self.max_iterations = params.iterations
        self.decrease_lr_every = params.decrease_lr_every
        self.lr_decay_factor = params.lr_decay_factor
        self.num_lr_decay = params.num_lr_decay
        self.optimizer_name = params.optimizer

    def _make_optimizer(self, z: torch.Tensor):
        """Create the configured optimizer over ``z`` only.

        Raises:
            NotImplementedError: If ``optimizer_name`` is neither ``"adam"``
                nor ``"sgd"``.
        """
        if self.optimizer_name == "adam":
            return optim.Adam([z], lr=self.lr)
        if self.optimizer_name == "sgd":
            return optim.SGD([z], lr=self.lr)
        raise NotImplementedError(
            f"unsupported optimizer: {self.optimizer_name!r}")

    def __call__(self, data: Minibatch, z: torch.Tensor):
        """Run ``max_iterations`` gradient steps on ``z`` and return it.

        ``z`` is modified in place: gradient tracking is switched on and the
        optimizer updates its values. Only ``z`` is handed to the optimizer,
        so the model's parameters are never stepped.

        Args:
            data: Minibatch providing ``points`` and ``distances``.
            z: Initial latent code.

        Returns:
            The same tensor object ``z`` after optimization.
        """
        z.requires_grad_(True)
        optimizer = self._make_optimizer(z)

        # With no decay period configured, fall back to the full run length.
        decrease_lr_every = (self.max_iterations
                             if self.decrease_lr_every is None else
                             self.decrease_lr_every)
        lr_lambda = LearningRateScheduler(decrease_every=decrease_lr_every,
                                          factor=self.lr_decay_factor,
                                          num_decay=self.num_lr_decay)
        scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

        for _ in range(self.max_iterations):
            output = self.loss_function.compute(self.model, data.points,
                                                data.distances, z)
            optimizer.zero_grad()
            output.loss.backward()
            optimizer.step()
            scheduler.step()
        return z


class MarchingCubes:
    """Extract the zero level set of the model's output on a regular grid."""

    def __init__(self, model: Model, grid_size: int, grid_min_value: float,
                 grid_max_value: float):
        """Store the model and the cubic sampling range.

        Args:
            model: Provides ``get_device()`` and ``decode(points, z)``.
            grid_size: Number of samples along each axis.
            grid_min_value: Lower bound of the sampled range on every axis.
            grid_max_value: Upper bound; must be greater than the lower bound.
        """
        self.model = model
        self.grid_size = grid_size
        assert grid_max_value > grid_min_value
        self.grid_min_value = grid_min_value
        self.grid_max_value = grid_max_value

    def __call__(self, z: torch.Tensor):
        """Decode ``z`` on the grid and run marching cubes at level 0.

        Args:
            z: Latent code passed to ``model.decode`` with every grid chunk.

        Returns:
            ``(vertices, faces)`` as returned by scikit-image, with the
            vertices shifted and rotated as described in the comments below.

        Raises:
            Whatever scikit-image's marching cubes raises, e.g. ``ValueError``
            (``ChamferDistance`` in this file catches that).
        """
        device = self.model.get_device()

        # Evaluate the model on the grid.
        linspace_size = self.grid_max_value - self.grid_min_value
        # NOTE(review): np.linspace places grid_size samples, i.e.
        # grid_size - 1 intervals, so the real sample spacing is
        # linspace_size / (grid_size - 1). The value below is kept unchanged
        # to preserve the scale of the existing output -- confirm the intent.
        voxel_size = linspace_size / self.grid_size
        axis = np.linspace(self.grid_min_value, self.grid_max_value,
                           self.grid_size)
        # Shape (grid_size ** 3, 3): one xyz row per grid point.
        grid = np.stack(np.meshgrid(axis, axis, axis)).reshape(
            (3, -1)).transpose().astype(np.float32)

        # Decode in grid_size chunks to limit the size of each forward pass.
        f_list = []
        with torch.no_grad():
            for grid_chunk in _split_array(grid, self.grid_size):
                grid_chunk = torch.from_numpy(grid_chunk).to(
                    device)[None, :, :]
                f = self.model.decode(grid_chunk, z).squeeze(dim=2)[0]
                f_list.append(f.cpu().numpy())
        volume = np.concatenate(f_list).reshape(
            (self.grid_size, self.grid_size, self.grid_size))

        # `marching_cubes_lewiner` was removed from newer scikit-image
        # releases. Use it when it exists, so older installs behave exactly as
        # before, and otherwise fall back to `marching_cubes`.
        marching_cubes = getattr(measure, "marching_cubes_lewiner", None)
        if marching_cubes is None:
            marching_cubes = measure.marching_cubes

        # The negative y spacing, the translation and the 90-degree rotation
        # about z presumably map scikit-image's index-space output back to
        # the model's coordinate frame (np.meshgrid defaults to 'xy'
        # indexing) -- confirm before changing any of them.
        spacing = (voxel_size, -voxel_size, voxel_size)
        vertex_translation = (linspace_size / 2, -linspace_size / 2,
                              linspace_size / 2)
        (vertices, faces, normals,
         values) = marching_cubes(volume,
                                  0.0,
                                  spacing=spacing,
                                  gradient_direction="ascent")
        vertices -= vertex_translation
        rotation_matrix = np.array(
            [[math.cos(math.pi / 2), -math.sin(math.pi / 2), 0],
             [math.sin(math.pi / 2),
              math.cos(math.pi / 2), 0], [0, 0, 1]])
        vertices = vertices @ rotation_matrix.T

        return vertices, faces


class AdaptiveLatentOptimization:
    """Latent-code optimizer whose learning rate is scaled by the initial output.

    Before the first step the configured learning rate is multiplied by the
    mean absolute value of the model's output on the context points.
    """

    def __init__(self,
                 model: Model,
                 loss_function: LossFunction,
                 params: LatentOptimizationHyperparameters,
                 yield_every: int = None):
        self.model = model
        self.loss_function = loss_function
        self.yield_every = yield_every
        # Copied from the hyperparameter object.
        self.lr = params.initial_lr
        self.max_iterations = params.iterations
        self.decrease_lr_every = params.decrease_lr_every
        self.lr_decay_factor = params.lr_decay_factor
        self.num_lr_decay = params.num_lr_decay
        self.optimizer_name = params.optimizer

    def steps(self, data: Minibatch, z: torch.Tensor):
        """Generate ``(step, z)`` pairs while optimizing ``z`` in place.

        Step 0 is yielded before any update. While running, a pair is yielded
        every ``yield_every`` steps (when that is set), and the final step is
        yielded once more if it was not already.

        Raises:
            NotImplementedError: If the optimizer name is neither ``"adam"``
                nor ``"sgd"``.
        """
        z.requires_grad_(True)
        yield (0, z)

        context_points = data.points
        with torch.no_grad():
            pred_distance = self.model.decode(x=context_points, z=z)
            error = abs(pred_distance).mean().item()
            lr = self.lr * error
            print("error:", error)
            print("lr:", lr)

        optimizer_classes = {"adam": optim.Adam, "sgd": optim.SGD}
        if self.optimizer_name not in optimizer_classes:
            raise NotImplementedError()
        optimizer = optimizer_classes[self.optimizer_name]([z], lr=lr)
        print("optimizer:", self.optimizer_name)
        print("max_iterations:", self.max_iterations)

        # With no decay period configured, fall back to the full run length.
        decay_period = self.decrease_lr_every
        if decay_period is None:
            decay_period = self.max_iterations
        scheduler = LambdaLR(
            optimizer,
            lr_lambda=LearningRateScheduler(decrease_every=decay_period,
                                            factor=self.lr_decay_factor,
                                            num_decay=self.num_lr_decay))

        last_yielded_step = 0
        for step in range(1, self.max_iterations + 1):
            output = self.loss_function.compute(self.model,
                                                points=data.points,
                                                distances=data.distances,
                                                z=z)
            optimizer.zero_grad()
            output.loss.backward()
            optimizer.step()
            scheduler.step()

            if self.yield_every is not None and step % self.yield_every == 0:
                yield (step, z)
                last_yielded_step = step

        # Make sure the final state is always reported exactly once.
        if last_yielded_step != self.max_iterations:
            yield (self.max_iterations, z)

    def __call__(self, data: Minibatch, z: torch.Tensor):
        """Run the whole optimization and return the last yielded latent."""
        latest = None
        for _, latest in self.steps(data, z):
            pass
        return latest


def compute_chamfer_distance(gt_surface_points: np.ndarray,
                             pred_surface_points: np.ndarray):
    """Return the symmetric Chamfer distance between two point sets.

    The result is the sum of two terms: the mean squared nearest-neighbour
    distance from each predicted point to the ground-truth set, and the same
    from each ground-truth point to the predicted set.
    """
    def mean_squared_nn_distance(reference, queries):
        # Nearest-neighbour distance of every query point to `reference`.
        nearest, _ = KDTree(reference).query(queries)
        return np.mean(np.square(nearest))

    pred_to_gt = mean_squared_nn_distance(gt_surface_points,
                                          pred_surface_points)
    gt_to_pred = mean_squared_nn_distance(pred_surface_points,
                                          gt_surface_points)
    return pred_to_gt + gt_to_pred


def sample_surface_points(faces: np.ndarray, vertices: np.ndarray,
                          num_sdf_samples: int):
    """Sample ``num_sdf_samples`` points from the surface of a triangle mesh.

    The mesh is built from ``vertices`` and ``faces`` with trimesh; the face
    indices that trimesh returns with the samples are discarded.
    """
    surface = trimesh.Trimesh(vertices=vertices, faces=faces)
    points, _ = trimesh.sample.sample_surface(surface, num_sdf_samples)
    return points


def _normalize(vec: np.ndarray):
    return vec / np.linalg.norm(vec)


def look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build a camera basis from an eye position, a target and an up hint.

    The z axis points from ``center`` towards ``eye``, x is ``up`` cross z,
    and y is z cross x; all three are normalized.

    Returns:
        ``(rotation_matrix, translation_vector)``: a float32 3x3 matrix whose
        columns are the x, y and z axes, and a vector holding the negated
        dot product of each axis with ``eye``.
    """
    eye = np.asanyarray(eye)
    center = np.asanyarray(center)
    up = np.asanyarray(up)

    z = _normalize(eye - center)
    right = np.cross(up, z)
    # The y axis is derived from the not-yet-normalized x axis, then both
    # are normalized.
    y = _normalize(np.cross(z, right))
    x = _normalize(right)

    axes = (x, y, z)
    rotation_matrix = np.column_stack(axes).astype(np.float32)
    translation_vector = np.array([-axis @ eye for axis in axes])

    return rotation_matrix, translation_vector


class ChamferDistance:
    """Chamfer distance between a ground-truth mesh and the model's reconstruction.

    The reconstruction is obtained by optimizing a fresh latent code on the
    shape's SDF samples and running marching cubes on the decoded field.
    """

    def __init__(self, loss_function: LossFunction,
                 minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization, model: Model,
                 grid_size: int, latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int,
                 debug_render: bool = False):
        """Store the collaborators and sampling sizes.

        Args:
            loss_function: Stored only; not used by ``__call__``.
            minibatch_generator: Called with a one-element list of SDF data
                and ``num_sdf_samples`` to build the optimization minibatch.
            optimize_latent_func: Callable ``(data, initial_z) -> z``.
            model: Used for the latent width (``z_map.shape[1]``) and passed
                to ``MarchingCubes``.
            grid_size: Marching-cubes grid resolution per axis.
            latent_optimization_num_samples: SDF samples for the minibatch.
            chamfer_distance_num_samples: Surface points sampled per mesh.
            debug_render: When True, show both sampled point clouds with
                matplotlib after each evaluation. This replaces a block that
                was previously unreachable behind ``if False:``.
        """
        self.loss_function = loss_function
        self.minibatch_generator = minibatch_generator
        self.optimize_latent = optimize_latent_func
        self.model = model
        self.grid_size = grid_size
        self.latent_optimization_num_samples = latent_optimization_num_samples
        self.chamfer_distance_num_samples = chamfer_distance_num_samples
        self.debug_render = debug_render

    def __call__(self, pointcloud_data: SdfData, mesh_data: MeshData):
        """Return the Chamfer distance for one shape, or -1 on failure.

        Returns:
            The Chamfer distance, or ``-1`` when marching cubes or the
            rescaling of its vertices raises ``ValueError``.
        """
        gt_faces = mesh_data.vertex_indices
        gt_vertices = mesh_data.vertices

        data = self.minibatch_generator(
            [pointcloud_data],
            num_sdf_samples=self.latent_optimization_num_samples)
        # Tensor.get_device() returns -1 for CPU tensors, and .to(-1) is
        # rejected. The `device` attribute works for both CPU and CUDA.
        device = data.points.device

        # Sampled first and moved afterwards, so the draw itself is made the
        # same way as before.
        initial_z = torch.normal(mean=0,
                                 std=0.01,
                                 size=(1, self.model.z_map.shape[1]),
                                 dtype=torch.float32).to(device)
        z = self.optimize_latent(data, initial_z)

        marching_cubes = MarchingCubes(model=self.model,
                                       grid_size=self.grid_size,
                                       grid_max_value=1,
                                       grid_min_value=-1)
        try:
            mc_vertices, mc_faces = marching_cubes(z)
            # Presumably undoes the normalization applied to the SDF data so
            # the mesh is in the ground-truth frame -- confirm.
            mc_vertices = (mc_vertices / pointcloud_data.scale -
                           pointcloud_data.offset)
        except ValueError:
            return -1

        gt_surface_points = sample_surface_points(
            gt_faces, gt_vertices, self.chamfer_distance_num_samples)
        pred_surface_points = sample_surface_points(
            mc_faces, mc_vertices, self.chamfer_distance_num_samples)
        chamfer_distance = compute_chamfer_distance(gt_surface_points,
                                                    pred_surface_points)

        if self.debug_render:
            self._show_debug_render(gt_surface_points, pred_surface_points)

        return chamfer_distance

    def _show_debug_render(self, gt_surface_points: np.ndarray,
                           pred_surface_points: np.ndarray):
        """Show the two point clouds side by side from a fixed viewpoint."""
        camera_theta = math.pi / 3
        camera_phi = -math.pi / 4
        camera_r = 1
        eye = [
            camera_r * math.sin(camera_theta) * math.cos(camera_phi),
            camera_r * math.cos(camera_theta),
            camera_r * math.sin(camera_theta) * math.sin(camera_phi),
        ]
        rotation_matrix, translation_vector = look_at(eye=eye,
                                                      center=[0, 0, 0],
                                                      up=[0, 1, 0])
        rotation_matrix = np.linalg.inv(rotation_matrix)

        fig, axes = plt.subplots(1, 2)
        clouds = (gt_surface_points, pred_surface_points)
        for axis, cloud in zip(axes, clouds):
            points = (rotation_matrix
                      @ cloud.T).T + translation_vector[None, :]
            image = pcl.render_point_cloud(points=points,
                                           colors=np.zeros_like(points),
                                           camera_mag=1,
                                           point_size=1)
            axis.imshow(image)
        plt.show()
        plt.close(fig)
