import math
import os
from typing import List, Mapping, Optional, Union
import numpy as np
import networkx as nx
from networkx.drawing.nx_agraph import graphviz_layout
import pandas as pd
import matplotlib.pyplot as plt
from pytorch_lightning.loggers import NeptuneLogger as LightningNeptuneLogger
from torch import Tensor
from neptune.types import File


def save_figure(fig, filename: str, as_html=False, as_pickle=False):
    """Save a figure to disk as HTML and/or as a pickle.

    The output format is taken from the extension of ``filename`` when it is
    ``.html`` or ``.pkl``; otherwise it is selected through the ``as_html`` /
    ``as_pickle`` flags. When no format is requested at all, the figure is
    saved as HTML.

    Args:
        fig: The figure to save. It is passed to :func:`mpld3.save_html` for
            HTML output and to :func:`pickle.dump` for pickle output.
        filename (str): Target path. A trailing ``.html`` or ``.pkl`` is
            stripped and enables the corresponding format; the proper
            extension is appended again when writing.
        as_html (bool): If :obj:`True`, write ``<filename>.html``.
            (default: :obj:`False`)
        as_pickle (bool): If :obj:`True`, write ``<filename>.pkl``.
            (default: :obj:`False`)
    """
    if filename.endswith(".html"):
        as_html = True
        filename = filename[: -len(".html")]
    elif filename.endswith(".pkl"):
        as_pickle = True
        filename = filename[: -len(".pkl")]

    # Fall back to HTML when the caller did not request any format; without
    # this fallback the call would silently write nothing.
    if not (as_html or as_pickle):
        as_html = True

    if as_html:
        # Imported lazily so that mpld3 is only required for HTML export.
        import mpld3

        with open(filename + ".html", "w", encoding="utf-8") as fp:
            mpld3.save_html(fig, fp)
    if as_pickle:
        import pickle

        with open(filename + ".pkl", "wb") as fp:
            pickle.dump(fig, fp)


def ensure_list(value):
    """Return ``value`` as a list.

    Objects exposing ``__iter__`` are expanded with :func:`list`, except for
    strings, which are treated as a single item. Anything else is wrapped
    in a one-element list.

    Args:
        value: The object to convert.

    Returns:
        list: ``list(value)`` for non-string iterables, ``[value]`` otherwise.
    """
    treat_as_single = isinstance(value, str) or not hasattr(value, "__iter__")
    if treat_as_single:
        return [value]
    return list(value)


class NeptuneLogger(LightningNeptuneLogger):
    """Extensions of PyTorch Lightning
    :class:`~pytorch_lightning.loggers.NeptuneLogger` with useful logging
    functionalities.

    Args:
        api_key (str, optional): Neptune API token, found on https://neptune.ai
            upon registration. Read: `how to find and set Neptune API token
            <https://docs.neptune.ai/administration/security-and-privacy/how-to-find-and-set-neptune-api-token>`_.
            It is recommended to keep it in the :obj:`NEPTUNE_API_TOKEN`
            environment variable, then you can leave :attr:`api_key=None`.
            (default: :obj:`None`)
        project_name (str, optional): Name of a project in a form of
            "my_workspace/my_project". If :obj:`None`, the value of
            `NEPTUNE_PROJECT` environment variable will be taken.
            You need to create the project in https://neptune.ai first.
            (default: :obj:`None`)
        experiment_name (str, optional): Editable name of the run.
            Run name appears in the "all metadata/sys" section in Neptune UI.
            (default: :obj:`None`)
        tags (str or list, optional): Tag or list of tags of the run.
            (default: :obj:`None`)
        params (Mapping, optional): Mapping of the run's parameters (are logged
            as :obj:`"parameters"` on Neptune).
            (default: :obj:`None`)
        save_dir (str, optional): Save directory of the experiment, used to
            temporarily log artifacts before upload. If :obj:`None`, then
            defaults to ``.neptune`` inside the current working directory.
            (default: :obj:`None`)
        debug (bool): If :obj:`True`, then do not log online (i.e., log in
            :obj:`"debug"` mode). Otherwise log online in :obj:`"async"` mode.
            (default: :obj:`False`)
        prefix (str, optional): Root namespace for all metadata logging.
            (default: :obj:`"logs"`)
        upload_stdout (bool): If :obj:`True`, then log also :obj:`stdout` on
            Neptune.
            (default: :obj:`False`)
        **kwargs: Additional parameters for
            :class:`~pytorch_lightning.loggers.NeptuneLogger`.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        project_name: Optional[str] = None,
        experiment_name: Optional[str] = None,
        tags: Optional[Union[str, List]] = None,
        params: Optional[Mapping] = None,
        save_dir: Optional[str] = None,
        debug: bool = False,
        prefix: Optional[str] = "logs",
        upload_stdout: bool = False,
        **kwargs,
    ):
        # ``None`` is forwarded to the parent logger as an empty prefix.
        prefix = prefix or ""
        if tags is not None:
            # A single string tag is wrapped into a one-element list.
            kwargs["tags"] = ensure_list(tags)
        mode = "debug" if debug else "async"
        super().__init__(
            api_key=api_key,
            project=project_name,
            name=experiment_name,
            log_model_checkpoints=False,
            prefix=prefix,
            capture_stdout=upload_stdout,
            mode=mode,
            **kwargs,
        )
        self.save_dir = save_dir
        if params is not None:
            self.run["parameters"] = params

    @property
    def save_dir(self) -> Optional[str]:
        """Gets the save directory of the experiment.

        Returns:
            the root directory where experiment logs get saved
        """
        return self._save_dir

    @save_dir.setter
    def save_dir(self, value):
        """Set the save directory, stored as an absolute path.

        Args:
            value (str, optional): The directory to use. If :obj:`None`, the
                ``.neptune`` folder inside the current working directory is
                used.
        """
        if value is not None:
            self._save_dir = os.path.abspath(value)
        else:
            self._save_dir = os.path.join(os.getcwd(), ".neptune")

    def log_metric(
        self,
        metric_name: str,
        metric_value: Union[Tensor, float, str],
    ):
        """Log a single value under ``logs/<metric_name>``.

        Args:
            metric_name (str): Name of the metric.
            metric_value (Tensor or float or str): The value to log.
        """
        # NOTE(review): the namespace is hard-coded to "logs/" and does not
        # follow the ``prefix`` given to the constructor -- confirm whether
        # this is intended.
        self.run[f"logs/{metric_name}"].log(metric_value)

    def _artifact_storage_path(self, name, extension: Optional[str] = None):
        """Build a temporary local path for an artifact and its final name.

        Args:
            name (str): Requested name of the artifact.
            extension (str, optional): File extension to enforce, with or
                without the leading dot. It is appended to ``name`` when
                ``name`` does not already end with it. If :obj:`None`, the
                extension is taken from ``name``.
                (default: :obj:`None`)

        Returns:
            tuple: ``(path, name)`` where ``path`` is a randomly named file
            inside :attr:`save_dir` (created if missing) carrying the
            extension, and ``name`` is the artifact name including the
            extension.
        """
        if extension is not None:
            if not extension.startswith("."):
                extension = "." + extension
            if not name.endswith(extension):
                name += extension
        else:
            _, extension = os.path.splitext(name)

        # The local file gets a random 16-letter stem; only the returned
        # ``name`` carries the caller's naming.
        from random import choice
        from string import ascii_letters

        random_stem = "".join(choice(ascii_letters) for _ in range(16))

        os.makedirs(self.save_dir, exist_ok=True)
        id_path = os.path.join(self.save_dir, random_stem + extension)
        return id_path, name

    def log_artifact(
        self,
        filename: str,
        artifact_name: Optional[str] = None,
        delete_after: bool = False,
    ):
        """Upload a local file under ``artifacts/<artifact_name>``.

        Args:
            filename (str): Path of the file to upload.
            artifact_name (str, optional): Name of the artifact on Neptune.
                If :obj:`None`, the base name of ``filename`` is used.
                (default: :obj:`None`)
            delete_after (bool): If :obj:`True`, upload with ``wait=True``
                and then remove the local file.
                (default: :obj:`False`)
        """
        if artifact_name is None:
            # './dir/file.ext' -> 'file.ext'
            artifact_name = os.path.basename(filename)
        field = self.run[f"artifacts/{artifact_name}"]
        if delete_after:
            # Request a waiting upload before the local file is removed.
            field.upload(filename, wait=True)
            os.unlink(filename)
        else:
            field.upload(filename)

    def log_numpy(self, array, name: str = "array"):
        """Log a numpy array object.

        Args:
            array (array_like): The array to be logged.
            name (str): The name of the file. (default: :obj:`'array'`)
        """
        path, name = self._artifact_storage_path(name, extension=".npy")
        np.save(path, array)
        self.log_artifact(path, artifact_name=name, delete_after=True)

    def log_dataframe(self, df: pd.DataFrame, name: str = "dataframe"):
        """Log a dataframe as csv.

        The index is written as a column labelled ``index``.

        Args:
            df (DataFrame): The dataframe to be logged.
            name (str): The name of the file. (default: :obj:`'dataframe'`)
        """
        path, name = self._artifact_storage_path(name, extension=".csv")
        df.to_csv(path, index=True, index_label="index")
        self.log_artifact(path, artifact_name=name, delete_after=True)

    def log_figure(self, fig, name: str = "figure"):
        """Log a matplotlib figure as html.

        Args:
            fig (Figure): The matplotlib figure to be logged.
            name (str): The name of the file. (default: :obj:`'figure'`)
        """
        path, name = self._artifact_storage_path(name, extension=".html")
        save_figure(fig, path)
        self.log_artifact(path, artifact_name=name, delete_after=True)

    def log_tensor_img(
        self, tensor: Union[Tensor, np.ndarray], name: str = "tensor"
    ):
        """Log an image represented as a torch Tensor or a Numpy Array.

        Args:
            tensor (Tensor or ndarray): The image data to be logged.
            name (str): The name of the file. (default: :obj:`'tensor'`)
        """
        path, name = self._artifact_storage_path(name, extension=".png")
        # NumPy arrays have no ``is_cuda`` attribute, so look it up
        # defensively; only objects reporting ``is_cuda`` are moved to CPU.
        if getattr(tensor, "is_cuda", False):
            tensor = tensor.cpu()
        # NOTE(review): ``_save`` is a private method of ``File`` -- confirm
        # it is still available when upgrading the neptune client.
        File.as_image(tensor)._save(path)
        self.log_artifact(path, artifact_name=name, delete_after=True)

    @staticmethod
    def _resolve_cmap(cmap):
        """Return a colormap object, defaulting to viridis for :obj:`None`."""
        if cmap is None:
            return plt.cm.viridis
        # ``plt.cm.get_cmap`` is not available in recent Matplotlib releases,
        # ``plt.get_cmap`` is used instead.
        return plt.get_cmap(cmap)

    def _publish_figure(self, fig, name: str, log_series: bool):
        """Append ``fig`` to the series ``name`` or upload it as a file."""
        if log_series:
            self.experiment[name].append(fig)
        else:
            self.experiment[name].upload(fig)

    def draw_nx_graph(
        self, adj, pos, ax, signal, cmap, labels=None, node_size=40, font_size=8
    ):
        """Draw the graph described by an adjacency array on ``ax``.

        Nodes are drawn only when ``signal`` is given; labels (if any) and
        edges are always drawn.

        Args:
            adj: Adjacency array, passed to :func:`networkx.from_numpy_array`.
            pos: Node positions. If :obj:`None`, a spring layout is computed.
            ax: The matplotlib axes to draw on.
            signal: Node colouring. An object with a ``dtype`` attribute
                (e.g., a numpy array) is mapped through ``cmap``; a list of
                colours or a single colour string (e.g., ``"white"``) is used
                as is. If :obj:`None`, nodes are not drawn.
            cmap: Colormap used for array signals.
            labels (Mapping, optional): Node labels, passed to
                :func:`networkx.draw_networkx_labels`.
                (default: :obj:`None`)
            node_size (int): Size of the nodes. (default: :obj:`40`)
            font_size (int): Font size of the labels. (default: :obj:`8`)

        Raises:
            ValueError: If ``signal`` is neither :obj:`None`, an object with
                a ``dtype``, a list, nor a string.
        """
        graph = nx.from_numpy_array(adj)

        if pos is None:
            pos = nx.spring_layout(graph)

        if signal is not None:
            node_kwargs = dict(
                node_color=signal,
                ax=ax,
                node_size=node_size,
                pos=pos,
                alpha=0.7,
            )
            dtype = getattr(signal, "dtype", None)
            if dtype is not None:
                node_kwargs["cmap"] = cmap
                try:
                    # Any integer dtype counts as categorical, not only the
                    # platform default integer.
                    categorical = bool(np.issubdtype(dtype, np.integer))
                except TypeError:
                    # Not a dtype NumPy can interpret: treat as continuous.
                    categorical = False
                if not categorical:
                    # Continuous signal: pin the colour range to the data.
                    node_kwargs["vmax"] = np.max(signal)
                    node_kwargs["vmin"] = np.min(signal)
            elif isinstance(signal, (list, str)):
                # Explicit colours (a list, or a single colour name).
                node_kwargs["edgecolors"] = "gray"
            else:
                raise ValueError(
                    "Signal must be a numpy array, list, or 'white' string."
                )
            nx.draw_networkx_nodes(graph, **node_kwargs)

        if labels is not None:
            nx.draw_networkx_labels(
                graph, ax=ax, pos=pos, labels=labels, font_size=font_size
            )

        nx.draw_networkx_edges(
            graph, ax=ax, pos=pos, alpha=0.5, edge_color="lightgray", width=1
        )

    def log_nx_graph(
        self,
        adj,
        signal: Optional[Union[np.ndarray, List]] = None,
        node_size=25,
        font_size=12,
        name: str = "graph",
        pos: Optional[np.ndarray] = None,
        labels: Optional[Mapping] = None,
        cmap=None,
        log_series=True,
    ):
        """Draw a graph and log the resulting figure.

        Args:
            adj: Adjacency array of the graph.
            signal (ndarray or list, optional): Node colouring, see
                :meth:`draw_nx_graph`. (default: :obj:`None`)
            node_size (int): Size of the nodes. (default: :obj:`25`)
            font_size (int): Font size of the labels. (default: :obj:`12`)
            name (str): Field under which the figure is logged.
                (default: :obj:`'graph'`)
            pos (ndarray, optional): Node positions. If :obj:`None`, a
                spring layout is computed. (default: :obj:`None`)
            labels (Mapping, optional): Node labels. (default: :obj:`None`)
            cmap: Colormap or colormap name. If :obj:`None`, viridis is
                used. (default: :obj:`None`)
            log_series (bool): If :obj:`True`, append the figure to the
                field ``name``; otherwise upload it to that field.
                (default: :obj:`True`)
        """
        cmap_obj = self._resolve_cmap(cmap)

        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            self.draw_nx_graph(
                adj,
                pos,
                ax,
                signal,
                cmap_obj,
                labels=labels,
                node_size=node_size,
                font_size=font_size,
            )
            ax.set_axis_off()
            self._publish_figure(fig, name, log_series)
        finally:
            # Close this specific figure, also when drawing or logging fails.
            plt.close(fig)

    def log_nx_graph_grid(
        self,
        adj_list: list,
        signal_list: Optional[List[Union[np.ndarray, list]]] = None,
        pos_list: Optional[List[Optional[np.ndarray]]] = None,
        labels_list: Optional[List[Optional[Mapping]]] = None,
        node_size: int = 25,
        font_size: int = 12,
        name: str = "graph_grid",
        cmap=None,
        log_series: bool = True,
    ):
        """Draw several graphs on a square grid and log the figure.

        The grid side is the ceiling of the square root of the number of
        graphs (at least 1); unused cells are left blank. Each graph is
        titled ``Graph <i>``.

        Args:
            adj_list (list): Adjacency arrays, one per graph.
            signal_list (list, optional): Per-graph node colouring, see
                :meth:`draw_nx_graph`. (default: :obj:`None`)
            pos_list (list, optional): Per-graph node positions.
                (default: :obj:`None`)
            labels_list (list, optional): Per-graph node labels.
                (default: :obj:`None`)
            node_size (int): Size of the nodes. (default: :obj:`25`)
            font_size (int): Font size of the labels. (default: :obj:`12`)
            name (str): Field under which the figure is logged.
                (default: :obj:`'graph_grid'`)
            cmap: Colormap or colormap name. If :obj:`None`, viridis is
                used. (default: :obj:`None`)
            log_series (bool): If :obj:`True`, append the figure to the
                field ``name``; otherwise upload it to that field.
                (default: :obj:`True`)
        """
        num_graphs = len(adj_list)
        # At least one cell, so that an empty list yields a blank figure
        # instead of an invalid 0x0 grid.
        grid_size = max(1, int(math.ceil(math.sqrt(num_graphs))))

        cmap_obj = self._resolve_cmap(cmap)

        # ``squeeze=False`` always returns a 2D array of axes, also for a
        # 1x1 grid where a bare Axes would be returned otherwise.
        fig, axes = plt.subplots(
            grid_size,
            grid_size,
            figsize=(4 * grid_size, 4 * grid_size),
            squeeze=False,
        )
        try:
            axes = axes.flatten()

            for i, adj_i in enumerate(adj_list):
                sig_i = signal_list[i] if signal_list is not None else None
                lab_i = labels_list[i] if labels_list is not None else None
                pos_i = pos_list[i] if pos_list is not None else None
                ax = axes[i]

                self.draw_nx_graph(
                    adj_i,
                    pos_i,
                    ax,
                    sig_i,
                    cmap_obj,
                    labels=lab_i,
                    node_size=node_size,
                    font_size=font_size,
                )

                ax.set_axis_off()
                ax.set_title(f"Graph {i}")

            # Turn off any unused subplots
            for unused_ax in axes[num_graphs:]:
                unused_ax.set_axis_off()

            self._publish_figure(fig, name, log_series)
        finally:
            plt.close(fig)
