# -*- coding: utf-8 -*-

import os
import sys
import h5py
import math
import numpy as np
import torch as to
import torch.distributed as dist
import matplotlib.pyplot as plt
# from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import mean_squared_error as mse
from matplotlib.ticker import MaxNLocator
from typing import Union, Tuple, Iterable, Dict, Callable, Any, List
import tvem
from tvem.utils import H5Logger
from tvem.utils.parallel import broadcast, init_processes, pprint
from tvem.utils.param_init import (
    init_W_data_mean,
    init_sigma_default,
    init_pies_default,
)
from mloutil.viz import matrix_as_image, matrix_columns_as_image_subplots
from mloutil.prepost.overlapping_patches import OverlappingPatches, MultiDimOverlappingPatches
from params import defaults


def init_processes_and_get_rank() -> int:
    """Initialize the process group if running under MPI and return this process' rank.

    :return: Rank reported by torch.distributed when the tvem run policy is "mpi", 0 otherwise.
    """
    if tvem.get_run_policy() != "mpi":
        # Nothing to initialize outside of MPI runs
        return 0
    init_processes()
    return dist.get_rank()


def get_comm_rank() -> int:
    """Return the rank of this process, or 0 if torch.distributed is not initialized."""
    if dist.is_initialized():
        return dist.get_rank()
    return 0


def get_comm_world_size() -> int:
    """Return the number of processes, or 1 if torch.distributed is not initialized."""
    if dist.is_initialized():
        return dist.get_world_size()
    return 1


def get_dataset_file_name(directory: str) -> str:
    """Return the path of the training dataset H5 file located in `directory`.

    :param directory: Directory that contains (or will contain) the dataset file
    :return: `directory` followed by "/train_dataset.h5"
    """
    dataset_file = f"{directory}/train_dataset.h5"
    return dataset_file


def _read_image_h5py(file: str, h5key: str, precision: to.dtype) -> to.Tensor:
    """Read a dataset from an H5 file and return it as a tensor of the requested dtype.

    :param file: Path to the H5 file
    :param h5key: Name of the dataset inside the file
    :param precision: torch.dtype of the returned tensor
    :return: Content of the dataset as a tensor of type `precision`
    """
    with h5py.File(file, "r") as f:
        # `[...]` loads the full dataset into memory, so the file can be closed afterwards
        img_np = f[h5key][...]
    if img_np.dtype == np.uint16:
        # uint16 arrays are converted to a signed type before being handed to torch
        # (presumably because torch.from_numpy did not accept uint16 -- TODO confirm).
        # int32 covers the whole uint16 range, whereas int16 would wrap every value above
        # 32767 to a negative number.
        img_np = img_np.astype(np.int32)
    return to.from_numpy(img_np).type(precision)


def _read_image_imageio(file: str, precision: to.dtype) -> to.Tensor:
    """Read an image file with imageio and return it as a tensor of the requested dtype.

    :param file: Path to the image file
    :param precision: torch.dtype of the returned tensor
    :return: Image as a tensor of type `precision`
    """
    # Imported inside the function, i.e. imageio is only needed once this reader is called
    from imageio import imread

    pixels = imread(file)
    return to.from_numpy(pixels).type(precision)


def store_as_h5(to_store_dict: Dict[str, to.Tensor], output_file: str, verbose: bool = True):
    """Write all entries of a dictionary to a new H5 file (only on rank 0).

    Processes with a rank other than 0 return without doing anything.

    :param to_store_dict: Mapping from dataset name to the value to store. Tensors are detached
                          and moved to the CPU before writing; any other value (e.g. a float) is
                          handed to h5py unchanged.
    :param output_file: Path of the H5 file to write (an existing file is overwritten)
    :param verbose: Whether to print a message once the file has been written
    """
    if get_comm_rank() != 0:
        return

    output_directory = os.path.dirname(output_file)
    if output_directory:
        # os.makedirs("") raises, so only create a directory if the path contains one
        os.makedirs(output_directory, exist_ok=True)

    with h5py.File(output_file, "w") as f:
        for key, val in to_store_dict.items():
            data = val.detach().cpu() if isinstance(val, to.Tensor) else val
            f.create_dataset(key, data=data)

    if verbose:
        print(f"Wrote {output_file}")


def get_no_subplots_yx_gfs(H: int) -> Tuple[int, int]:
    """Return a (rows, columns) subplot grid for H panels.

    Up to 25 panels are arranged in a single row. Larger H use 20 columns if H is divisible
    by 20 and 25 columns otherwise (with a partially filled last row if H is not divisible
    by 25).

    :param H: Number of panels
    :return: Tuple (number of rows, number of columns)
    """
    if H <= 25:
        return (1, H)
    # Prefer a grid without empty cells; 20 columns take precedence over 25
    for no_columns in (20, 25):
        if H % no_columns == 0:
            return (H // no_columns, no_columns)
    return (math.ceil(H / 25), 25)


def set_pixels_to_nan(image: to.Tensor, percentage: int) -> to.Tensor:
    """Return a copy of `image` in which randomly selected entries are replaced by nan.

    Each entry is selected independently with probability `percentage` / 100. The input
    tensor itself is left untouched.

    :param image: Tensor to corrupt
    :param percentage: Probability (in percent) with which an entry is set to nan
    :return: Detached copy of `image` containing nan at the selected positions
    """
    nan_mask = to.rand_like(image) < percentage / 100.0
    corrupted = image.detach().clone()
    corrupted.masked_fill_(nan_mask, float("nan"))
    print(f"Randomly set {percentage} % of pixels to nan")
    return corrupted


def read_image_set_to_nan_get_patches_and_write_out(
    image_file: str,
    patch_height: int,
    patch_width: int,
    patch_shift: int,
    precision: to.dtype,
    dataset_file: str,
    incomplete_percentage: int,
) -> Tuple[Union[OverlappingPatches, MultiDimOverlappingPatches], Callable, str, str, H5Logger]:
    """1. Read image from file, 2. set a random subset of its pixels to nan, 3. segment it into
    overlapping patches using the OverlappingPatches routines from mloutil.prepost, 4. write
    the tensor containing the image patches to an H5 file and 5. log the incomplete image to
    `reco.h5`.

    Figures of the incomplete image and of a random selection of patches as well as `reco.h5`
    are written to the directory of `dataset_file`.

    :param image_file: Full path to image file. Files with extension `.h5` are read with h5py
                       (from node `data`, flipped along the first dimension), all other files
                       are read with imageio.
    :param patch_height: Patch height in pixels
    :param patch_width: Patch width in pixels
    :param patch_shift: Shift between adjacent patches in pixels
    :param precision: torch.dtype of obtained tensors
    :param dataset_file: Full path to H5 file to be written containing the image patches (will be
                         used as the training dataset)
    :param incomplete_percentage: Percentage of image pixels that will be set to nan to simulate
                                  incomplete data

    :return: Tuple of
             - image patches as instance of mloutil.prepost.OverlappingPatches or
               mloutil.prepost.MultiDimOverlappingPatches,
             - function computing the evaluation metric (mean squared error w.r.t. the original
               image) for a reconstructed image,
             - name of the evaluation metric,
             - axis label of the evaluation metric,
             - H5Logger writing to `reco.h5`.
    :raises RuntimeError: if the image cannot be read with imageio
    """
    output_directory = os.path.split(dataset_file)[0]

    # splitext looks at the last suffix only, hence names with several dots (or none at all)
    # are handled as well
    ext = os.path.splitext(image_file)[1]
    if ext == ".h5":
        original = to.flip(_read_image_h5py(image_file, "data", precision), [0])
    else:
        try:
            original = _read_image_imageio(image_file, precision)
        except Exception as err:
            # keep the underlying error attached as the cause
            raise RuntimeError(f"cannot read {image_file}") from err
    print(f"Opened {image_file}")
    ndim = original.dim()
    assert ndim == 2 or ndim == 3, "image must be 2- or 3-dim."

    training_image = set_pixels_to_nan(original, incomplete_percentage)

    matrix_as_image(
        cdata=training_image,
        dpi=defaults.dpi,
        colormap="jet",
        figure_name=f"train_image_{incomplete_percentage}missing",
        output_directory=output_directory,
    )
    plt.close()

    OVP = OverlappingPatches if ndim == 2 else MultiDimOverlappingPatches
    patches = OVP(
        training_image,
        patch_height=patch_height,
        patch_width=patch_width,
        patch_shift=patch_shift,
    )
    # the transposed patch matrix is stored as the training dataset
    store_as_h5({"data": patches.get().t().detach().cpu()}, output_file=dataset_file)

    if ndim == 3:
        _kwargs = {"concatenate": False}
        assert original.shape[2] == 3, "Expect RGB image"
    else:
        _kwargs = {}
    nonconcatenated_patches = patches.get(**_kwargs)

    # visualize a random selection of patches
    no_samples_to_show = math.prod(defaults.no_subplots_yx_datapoints)
    no_subplots_yx_datapoints = get_no_subplots_yx_gfs(no_samples_to_show)
    matrix_columns_as_image_subplots(
        cdata=nonconcatenated_patches[
            :, to.randperm(nonconcatenated_patches.shape[1])[:no_samples_to_show]
        ],
        subplots_yx=no_subplots_yx_datapoints,
        height_width=(patch_height, patch_width),
        figsize=(
            defaults.figsize_single_gf[0] * no_subplots_yx_datapoints[1],
            defaults.figsize_single_gf[1] * no_subplots_yx_datapoints[0],
        ),
        global_clim=False,
        cmap="jet",
        sym_clim=False,
        figure_name="samples",
        output_directory=output_directory,
    )
    plt.close()

    def eval_metric_fn(reconstructed: to.Tensor):
        """Return the mean squared error between the original and the reconstructed image."""
        assert reconstructed.shape == original.shape
        return mse(original.detach().cpu().numpy(), reconstructed.detach().cpu().numpy())

    eval_metric_name = "mse"
    eval_metric_label = r"$\mathrm{MSE}$"

    reco_file = f"{output_directory}/reco.h5"
    reco_logger = H5Logger(reco_file)
    reco_logger.set(
        incomplete_image=training_image,
        incomplete_percentage=to.tensor([incomplete_percentage]),
    )
    reco_logger.write()
    print(f"Appended incomplete_image and incomplete_percentage to {reco_file}")

    return patches, eval_metric_fn, eval_metric_name, eval_metric_label, reco_logger


def get_merge_epochs(reco_every: int, no_epochs: int) -> to.Tensor:
    """Return the sorted epoch indices 0, reco_every, 2 * reco_every, ... below `no_epochs`,
    always including the last epoch (`no_epochs` - 1).

    :param reco_every: Step between subsequent epoch indices
    :param no_epochs: Total number of epochs
    :return: 1-dim tensor of unique epoch indices
    """
    periodic_epochs = to.arange(start=0, end=no_epochs, step=reco_every)
    last_epoch = to.tensor([no_epochs - 1])
    # unique removes the duplicate if the last epoch is already part of the periodic epochs
    return to.unique(to.cat((periodic_epochs, last_epoch)))


def get_log_blacklist(blacklist: Iterable[str], theta: Dict[str, to.Tensor]) -> Tuple[str, ...]:
    """Expand the placeholder "THETA" in a blacklist into the keys of `theta`.

    :param blacklist: Names to blacklist; may contain the placeholder "THETA"
    :param theta: Dictionary whose keys replace the placeholder
    :return: Tuple with all entries of `blacklist` except "THETA", followed by the keys of
             `theta` if "THETA" was contained in `blacklist`
    """
    # Materialize first: the entries are inspected twice, which would silently fail for
    # one-shot iterables such as generators
    entries = tuple(blacklist)
    expanded = tuple(x for x in entries if x != "THETA")
    if "THETA" in entries:
        expanded += tuple(theta.keys())
    return expanded


def get_theta_init_for_bsc(
    data_file: str,
    no_gen_fields: int,
    device: to.device,
    precision: to.dtype,
) -> Tuple[to.Tensor, to.Tensor, to.Tensor]:
    """Get parameters to initialize BSC model.

    The values are computed on rank 0 from the nan-ignoring mean and variance of the data
    (using tvem's `init_W_data_mean` with `std_factor=0.01`, the square of `init_sigma_default`
    and `init_pies_default` with `crowdedness=2.0`) and are then broadcast to all other ranks.

    :param data_file: Name of H5 file containing data (assumed to be located at node `data`)
    :param no_gen_fields: Number of latent dimensions of BSC model (number of generative
                          fields/number of columns of BSC's W parameter)
    :param device: torch device to store parameters
    :param precision: torch dtype of parameters
    :return: Tuple (W_init, sigma2_init, pies_init)
    """
    is_root = get_comm_rank() == 0

    # Rank 0 reads the data; the other ranks only need the second dimension of the data to
    # allocate a receive buffer for W
    if is_root:
        with h5py.File(data_file, "r") as f:
            Y = to.tensor(f["data"][...], dtype=precision, device=device)
        # same shape as the receive buffers allocated on the other ranks
        D = to.tensor([Y.shape[1]], dtype=precision, device=device)
    else:
        D = to.zeros((1,), dtype=precision, device=device)
    broadcast(D)

    if is_root:
        # Tensor.numpy() is only available for CPU tensors, hence move the data to the host
        Y_np = Y.detach().cpu().numpy()
        Y_mean = to.from_numpy(np.nanmean(Y_np, axis=0))
        Y_var = to.from_numpy(np.nanvar(Y_np, axis=0))
        W_init = init_W_data_mean(
            Y_mean,
            Y_var,
            H=no_gen_fields,
            std_factor=0.01,
            dtype=precision,
            device=device,
        ).contiguous()
        sigma2_init = init_sigma_default(Y_var, dtype=precision, device=device) ** 2
        pies_init = init_pies_default(
            no_gen_fields, crowdedness=2.0, dtype=precision, device=device
        )
    else:
        # receive buffers to be filled by the broadcasts below
        W_init = to.zeros((int(D.item()), no_gen_fields), dtype=precision, device=device)
        sigma2_init = to.zeros((1,), dtype=precision, device=device)
        pies_init = to.zeros((no_gen_fields,), dtype=precision, device=device)
    broadcast(W_init)
    broadcast(sigma2_init)
    broadcast(pies_init)
    return W_init, sigma2_init, pies_init


class stdout_logger(object):
    """File-like object that forwards everything written to it to both the stream that was
    `sys.stdout` at construction time and a text file opened in append mode.

    :param txt_file: Path of the text file to append to; its directory is created if needed
    """

    def __init__(
        self,
        txt_file: str,
    ):
        self.terminal = sys.stdout
        log_directory = os.path.dirname(txt_file)
        if log_directory:
            # os.makedirs("") raises, so only create a directory if the path contains one
            os.makedirs(log_directory, exist_ok=True)
        self.log = open(txt_file, "a")

    def write(self, message):
        """Write `message` to terminal and log file and flush both immediately."""
        self.terminal.write(message)
        self.terminal.flush()
        self.log.write(message)
        self.log.flush()

    def flush(self):
        # this flush method is needed for python 3 compatibility.
        # Nothing to do here since `write` already flushes both streams.
        pass

    def close(self):
        """Close the log file (the terminal stream is left open)."""
        self.log.close()


def get_cycliclr_half_step_size(
    no_data_points: int,
    batch_size: int,
    epochs_per_half_cycle: int,
    no_epochs: int,
) -> float:
    """Return the number of batch iterations that make up half a cycle.

    :param no_data_points: Number of data points
    :param batch_size: Number of data points per batch
    :param epochs_per_half_cycle: Number of epochs per half cycle
    :param no_epochs: Not used by the computation
    :return: Number of batches per epoch (rounded up) times `epochs_per_half_cycle`
    """
    batches_per_epoch = np.ceil(no_data_points / batch_size)
    return batches_per_epoch * epochs_per_half_cycle


class StoppingCriterion:
    """Monitor a sequence of values (e.g. differences between subsequent evaluation metric
    values) and signal to stop iterating once x negative values were observed within a window
    of y checked values.

    :param x: Number of negative values within one window at which `check` signals to stop.
    :param y: Window length in number of checked values (defaults to x). If a window is
              complete without the stop signal being raised, all counters are reset.
    """

    def __init__(
        self,
        x: int,
        y: int = None,
    ):
        if y is not None:
            assert not y < x, "y must be greater/equal x"
            self.y = y
        else:
            self.y = x
        self.x = x
        # x_counter: values seen in the current window, y_counter: negative values among them
        self.x_counter = 0
        self.y_counter = 0
        self.old_value = float("nan")
        self.break_now = False

    def check(self, new_value: Union[None, float]) -> bool:
        """Check a single epoch

        :param new_value: Current value; None is ignored and never triggers a stop
        :return: True if the stopping criterion is met
        """
        if new_value is None:
            return False

        differs_from_previous = new_value != self.old_value
        self.old_value = new_value

        self.x_counter += 1
        if new_value < 0:
            pprint("Negative new_value")
            first_negative_in_window = self.y_counter == 0
            if first_negative_in_window and differs_from_previous:
                # let the window start at the first negative value
                self.x_counter = 1
            self.y_counter += 1
            self.break_now = not (self.y_counter < self.x)

        window_complete = self.x_counter == self.y
        if window_complete and not self.break_now:
            self.x_counter = 0
            self.y_counter = 0

        return self.break_now


def _get_rcparams(fontsize: int, ticksize: int, use_tex_fonts: bool):
    """Assemble the matplotlib rcParams used for plotting.

    :param fontsize: Font size of axis labels and titles (skipped if None)
    :param ticksize: Font size of x-axis tick labels (skipped if None)
    :param use_tex_fonts: Whether to additionally enable LaTeX text rendering with a
                          sans-serif font family
    :return: Dictionary that can be passed to `plt.rcParams.update`
    """
    sizes = {
        "axes.labelsize": fontsize,
        "axes.titlesize": fontsize,
        "xtick.labelsize": ticksize,
    }
    # Sizes given as None are left out so that matplotlib's current settings stay in effect
    params: Dict[str, Any] = {key: val for key, val in sizes.items() if val is not None}
    if use_tex_fonts:
        # extend (rather than replace) the size settings collected above
        params.update(
            {
                "text.usetex": True,
                "font.family": "sans-serif",
            }
        )
    return params


def _get_ylim_of_last_x(ydata: Union[np.ndarray, to.Tensor], focus_last_x_factor: float):
    """Return the y-limits matplotlib chooses when plotting only the tail of `ydata`.

    The tail is plotted into a temporary figure which is closed again afterwards.

    :param ydata: Sequence of values
    :param focus_last_x_factor: Fraction of `ydata` (counted from the end) to consider
    :return: y-limits as returned by `Axes.get_ylim`
    """
    scratch_fig, scratch_ax = plt.subplots()
    # a tail length of 0 selects the whole sequence since -0 == 0
    tail_length = int(focus_last_x_factor * len(ydata))
    scratch_ax.plot(ydata[-tail_length:])
    limits = scratch_ax.get_ylim()
    plt.close(scratch_fig)
    return limits


class free_energy_vs_eval_metric_lineplot:
    """Line plot showing the free energy (left y-axis, black) alongside an evaluation metric
    (right y-axis, blue) over iterations. The figure is created by the constructor and can be
    refreshed with new data through `update`.
    """

    def __init__(
        self,
        xdata_free_energy: Union[List[int], to.Tensor, np.ndarray],
        ydata_free_energy: Union[List[int], to.Tensor, np.ndarray],
        xdata_eval_metric: Union[List[int], to.Tensor, np.ndarray],
        ydata_eval_metric: Union[List[int], to.Tensor, np.ndarray],
        figure_name: str,
        eval_metric_ylabel: str,
        linestyle_free_energy: str,
        linestyle_eval_metric: str,
        marker_free_energy: str,
        marker_eval_metric: str,
        markersize: int,
        focus_free_energy_last_x: float = 1.0,
        focus_eval_metric_last_x: float = 1.0,
        fontsize: int = 18,
        ticksize: int = 10,
        xlim: Tuple[float, float] = None,
        use_tex_fonts: bool = False,
        output_directory: str = None,
        output_file_type: str = "png",
        figsize: Tuple[int, int] = None,
        dpi: int = 96,
    ):
        """
        Visualization of free energy alongside evaluation metric

        :param xdata_free_energy: Free energy iteration numbers to plot.
        :param ydata_free_energy: Free energy values to plot.
        :param xdata_eval_metric: Evaluation metric iteration numbers to plot.
        :param ydata_eval_metric: Evaluation metric values to plot.
        :param figure_name: Figure name (used as window title and as output file name).
        :param eval_metric_ylabel: Label of the evaluation metric (y-)axis
        :param linestyle_free_energy: Line style for free energy curve
        :param linestyle_eval_metric: Line style for evaluation metric curve
        :param marker_free_energy: Marker style for free energy curve
        :param marker_eval_metric: Marker style for evaluation metric curve
        :param markersize: Size of markers illustrating free energy and eval_metric values
        :param focus_free_energy_last_x: Optimize ylim of free energy curve based on this
                                         fraction of the most recent values (None to keep
                                         matplotlib's limits)
        :param focus_eval_metric_last_x: Optimize ylim of evaluation metric curve based on this
                                         fraction of the most recent values (None to keep
                                         matplotlib's limits)
        :param fontsize: Fontsize of axis labels.
        :param ticksize: Fontsize of axis ticks.
        :param xlim: Limits for x-axis.
        :param use_tex_fonts: Set rcParams to use LaTeX fonts
        :param output_directory: Directory to save the figure, e.g. as .png file (created if it
                                 does not exist). Nothing is saved if None.
        :param output_file_type: Figure extension (jpg, png, ...)
        :param figsize: Figure size (if not specified matplotlib's default is used)
        :param dpi: Resolution in DPI passed to plt.subplots
        """
        plt.rcParams.update(_get_rcparams(fontsize, ticksize, use_tex_fonts))

        self.fig, self.ax1 = plt.subplots(
            figsize=figsize, dpi=dpi, facecolor="white", edgecolor="none"
        )

        (self.l1,) = self.ax1.plot(
            xdata_free_energy,
            ydata_free_energy,
            color="k",
            linestyle=linestyle_free_energy,
            marker=marker_free_energy,
            markersize=markersize,
        )
        if focus_free_energy_last_x is not None:
            self.ax1.set_ylim(_get_ylim_of_last_x(ydata_free_energy, focus_free_energy_last_x))

        self._set_window_title(figure_name)

        self.ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
        if xlim is not None:
            self.ax1.set_xlim(xlim)
        self.ax1.set_ylabel(r"$\mathcal{F}(\mathcal{K},\Theta)\,/\,N$", fontsize=fontsize)
        self.ax1.set_xlabel(r"$\mathrm{Iteration}$", fontsize=fontsize)
        self.ax1.tick_params(axis="y", labelcolor="k")
        if ticksize is not None:
            plt.tick_params(labelsize=ticksize)

        self.ax1.set_frame_on(True)

        # second y-axis sharing the x-axis with the free energy curve
        self.ax2 = self.ax1.twinx()
        (self.l2,) = self.ax2.plot(
            xdata_eval_metric,
            ydata_eval_metric,
            color="b",
            linestyle=linestyle_eval_metric,
            marker=marker_eval_metric,
            markersize=markersize,
        )
        if focus_eval_metric_last_x is not None:
            self.ax2.set_ylim(_get_ylim_of_last_x(ydata_eval_metric, focus_eval_metric_last_x))

        self.ax2.set_ylabel(eval_metric_ylabel, color="b")
        self.ax2.tick_params(axis="y", labelcolor="b")

        self._save(figure_name, output_directory, output_file_type)

    def _set_window_title(self, title: str) -> None:
        """Set the window title of the figure.

        The title is set through the figure manager since `FigureCanvasBase.set_window_title`
        was deprecated in Matplotlib 3.4 and later removed.
        """
        manager = getattr(self.fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title(title)

    def _save(self, figure_name: str, output_directory: str, output_file_type: str) -> None:
        """Save the figure to `output_directory` (no-op if `output_directory` is None)."""
        if output_directory is None:
            return
        os.makedirs(output_directory, exist_ok=True)
        output_file = f"{output_directory}/{figure_name}.{output_file_type}"
        self.fig.savefig(output_file, bbox_inches="tight")
        print(f"Wrote {output_file}")

    def update(
        self,
        xdata_free_energy: Union[List[int], to.Tensor, np.ndarray],
        ydata_free_energy: Union[List[int], to.Tensor, np.ndarray],
        xdata_eval_metric: Union[List[int], to.Tensor, np.ndarray],
        ydata_eval_metric: Union[List[int], to.Tensor, np.ndarray],
        figure_name: str,
        xlim: Tuple[float, float] = None,
        focus_free_energy_last_x: float = 1.0,
        focus_eval_metric_last_x: float = 1.0,
        output_directory: str = None,
        output_file_type: str = "png",
    ):
        """Replace the data of both curves, re-adjust the axis limits and optionally save.

        :param xdata_free_energy: Free energy iteration numbers to plot.
        :param ydata_free_energy: Free energy values to plot.
        :param xdata_eval_metric: Evaluation metric iteration numbers to plot.
        :param ydata_eval_metric: Evaluation metric values to plot.
        :param figure_name: Figure name (used as window title and as output file name).
        :param xlim: Limits for x-axis. If None, the first and last entry of
                     `xdata_free_energy` are used (provided they differ).
        :param focus_free_energy_last_x: See constructor.
        :param focus_eval_metric_last_x: See constructor.
        :param output_directory: Directory to save the figure (created if it does not exist).
                                 Nothing is saved if None.
        :param output_file_type: Figure extension (jpg, png, ...)
        """
        assert hasattr(self, "l1")
        self.l1.set_xdata(xdata_free_energy)
        self.l1.set_ydata(ydata_free_energy)

        assert hasattr(self, "l2")
        self.l2.set_xdata(xdata_eval_metric)
        self.l2.set_ydata(ydata_eval_metric)

        if focus_free_energy_last_x is not None:
            self.ax1.set_ylim(_get_ylim_of_last_x(ydata_free_energy, focus_free_energy_last_x))
        if focus_eval_metric_last_x is not None:
            self.ax2.set_ylim(_get_ylim_of_last_x(ydata_eval_metric, focus_eval_metric_last_x))

        # Fall back to the range covered by the data unless explicit limits are given; with a
        # degenerate data range and no explicit limits, None is passed on to set_xlim
        _xlim = (xdata_free_energy[0], xdata_free_energy[-1])
        self.ax1.set_xlim(_xlim if xlim is None and _xlim[0] != _xlim[1] else xlim)

        self.fig.tight_layout()
        self._set_window_title(figure_name)

        self._save(figure_name, output_directory, output_file_type)
