from typing import Optional, Mapping, Type
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
import seaborn as sns
import pytorch_lightning as pl
from pytorch_lightning.utilities import grad_norm
from torch_geometric.utils import to_dense_adj, to_dense_batch
from torch_sparse import SparseTensor

from source.utils import to_dense_rect_matrix


class BaseModule(pl.LightningModule):
    """🧱 Base Lightning Module class"""

    def maybe_log_stuff(
        self,
        data_batch,
        batch_idx,
        pooling_output,
        plot_type=None,
        istest=False,
    ):
        """
        🎨 Log qualitative plots of the pooling for selected graphs of a batch

        Nothing happens unless ``self.plot_preds_at_epoch`` is a mapping (a
        missing attribute is treated like ``None``), ``plot_type`` is
        non-empty, ``batch_idx`` equals the configured batch and either
        ``istest`` is true or ``self.current_epoch`` is a multiple of the
        configured period.

        Keys read from ``self.plot_preds_at_epoch``:
            - ``"batch"`` (default ``0``): index of the batch to plot.
            - ``"samples"`` (default ``1``): index, within the batch, of the
              graph used by every plot except ``"assignments_grid"``.
            - ``"samples_list"`` (default ``None``): graph indices drawn by
              ``"assignments_grid"``; falls back to ``[samples]`` when unset.
            - ``"every"`` (default ``1``): plot every that many epochs.

        Args:
            data_batch: batch exposing ``edge_index``, ``batch`` and ``x``.
            batch_idx: index of the current batch.
            pooling_output: object exposing ``so.s`` and ``edge_index``, plus
                ``batch`` and ``x`` when ``so.s`` is a ``SparseTensor``.
            plot_type: collection of plot names among ``"pooled_graph"``,
                ``"assignments"``, ``"s_matrix"``, ``"sst"``, ``"sts"`` and
                ``"assignments_grid"``. ``None`` or empty disables plotting.
            istest: if true, ``"_test"`` is appended to every logged name,
                the epoch period is ignored and graph plots are logged with
                ``log_series=False``.
        """
        plot_cfg = getattr(self, "plot_preds_at_epoch", None)
        if plot_cfg is None or not plot_type:
            return

        b_idx = plot_cfg.get("batch", 0)
        s_idx = plot_cfg.get("samples", 1)
        s_idx_list = plot_cfg.get("samples_list", None)
        every = plot_cfg.get("every", 1)

        if batch_idx != b_idx or not (self.current_epoch % every == 0 or istest):
            return

        cmap = plt.cm.viridis
        is_sparse = isinstance(pooling_output.so.s, SparseTensor)

        # Densify original graph
        adj = to_dense_adj(data_batch.edge_index, batch=data_batch.batch)
        x, mask = to_dense_batch(data_batch.x, batch=data_batch.batch)

        # Densify pooled graph (only needed when the selection is sparse)
        adj_pool = mask_pool = sel_matrix = None
        if is_sparse:
            adj_pool = to_dense_adj(
                pooling_output.edge_index,
                batch=pooling_output.batch,
            )
            _, mask_pool = to_dense_batch(
                pooling_output.x, batch=pooling_output.batch
            )
            coo = pooling_output.so.s.coo()
            sel_matrix = to_dense_rect_matrix(
                coo[:2],
                data_batch.batch,
                pooling_output.batch,
                edge_attr=coo[2],
            )

        def to_numpy(tensor):
            # Detach from the autograd graph and move to host memory.
            return tensor.detach().cpu().numpy()

        def title_for(name):
            # Test-time plots are logged under a separate "_test" name.
            return f"visuals/{name}_test" if istest else f"visuals/{name}"

        def original_graph(idx):
            # Node mask, adjacency and features of graph ``idx`` without padding.
            node_mask = mask[idx]
            adj_np = to_numpy(adj[idx][node_mask, :][:, node_mask])
            x_np = to_numpy(x[idx][node_mask])
            return node_mask, adj_np, x_np

        def selection_matrix(idx, node_mask):
            # Rows are the valid nodes of graph ``idx``, columns its clusters.
            if is_sparse:
                return to_numpy(sel_matrix[idx][node_mask, :][:, mask_pool[idx]])
            # NOTE(review): the dense branch indexes ``so.s`` by sample first,
            # so it presumably is (batch, nodes, clusters) -- TODO confirm.
            return to_numpy(pooling_output.so.s[idx][node_mask, :])

        def assignment_colors(sel):
            # One hex colour per node from its arg-max cluster; nodes whose row
            # sums to zero (assigned to no cluster) are drawn white.
            cluster = sel.argmax(axis=-1)
            cluster[sel.sum(axis=-1) == 0] = -1
            # Guard the normalisation against a single cluster (K - 1 == 0).
            denom = max(sel.shape[1] - 1, 1)
            return [
                to_hex(cmap(k / denom)) if k != -1 else "white" for k in cluster
            ]

        def log_heatmap(matrix, name, heading, figsize=None, xlabel=None, ylabel=None):
            # Draw ``matrix`` on its own figure, append it to the experiment
            # and always close the figure, even if logging fails.
            fig, ax = plt.subplots(figsize=figsize)
            try:
                sns.heatmap(matrix, cbar=True, ax=ax)
                if xlabel is not None:
                    ax.set_xlabel(xlabel)
                if ylabel is not None:
                    ax.set_ylabel(ylabel)
                ax.set_title(heading)
                self.logger.experiment[title_for(name)].append(fig)
            finally:
                plt.close(fig)

        # Select the specific graph to log
        mask_i, adj_i, x_i = original_graph(s_idx)
        sel_matrix_i = selection_matrix(s_idx, mask_i)
        if is_sparse:
            mask_pool_i = mask_pool[s_idx]
            adj_pool_i = to_numpy(adj_pool[s_idx][mask_pool_i, :][:, mask_pool_i])
        else:
            # NOTE(review): in the dense case ``edge_index`` is used directly as
            # the pooled adjacency of each sample -- TODO confirm with caller.
            adj_pool_i = to_numpy(pooling_output.edge_index[s_idx])

        # Use the original node pos as coords
        if x_i.shape[-1] == 2:
            pos_i = x_i
            # Pooled coordinates are the assignment-weighted mean of node
            # coordinates; a cluster with zero total weight yields NaN.
            pos_pool_i = (sel_matrix_i.T @ x_i) / sel_matrix_i.sum(axis=0)[..., None]
        else:
            pos_i = None
            pos_pool_i = None

        if "pooled_graph" in plot_type:
            # assign a label to each node
            labels_dict = {k: f"{k}" for k in range(sel_matrix_i.shape[1])}
            self.logger.log_nx_graph(
                adj=adj_pool_i,
                pos=pos_pool_i,
                labels=labels_dict,
                name=title_for("pooled_graph"),
            )

        if "assignments" in plot_type:
            node_colors = assignment_colors(sel_matrix_i)
            # Remove self loops
            np.fill_diagonal(adj_i, 0)
            self.logger.log_nx_graph(
                adj=adj_i,
                pos=pos_i,
                signal=node_colors,
                name=title_for("assignments"),
                log_series=not istest,
            )

        if "s_matrix" in plot_type:
            log_heatmap(
                sel_matrix_i,
                "s_matrix",
                r"$\mathbf{S}$",
                figsize=(4, 5),
                xlabel="Clusters",
                ylabel="Nodes",
            )

        if "sst" in plot_type:
            log_heatmap(
                sel_matrix_i @ sel_matrix_i.T,
                "SSt",
                r"$\mathbf{S}\mathbf{S}^\top$",
            )

        if "sts" in plot_type:
            log_heatmap(
                sel_matrix_i.T @ sel_matrix_i,
                "StS",
                r"$\mathbf{S}^\top\mathbf{S}$",
            )

        if "assignments_grid" in plot_type:
            # Without an explicit "samples_list" fall back to the single sample
            # instead of failing while iterating over ``None``.
            grid_indices = s_idx_list if s_idx_list is not None else [s_idx]

            adj_list = []
            signal_list = []
            pos_list = []
            for idx_j in grid_indices:
                # Select the j-th graph to log
                mask_j, adj_j, x_j = original_graph(idx_j)
                np.fill_diagonal(adj_j, 0)  # remove self-loops
                sel_matrix_j = selection_matrix(idx_j, mask_j)

                signal_list.append(assignment_colors(sel_matrix_j))
                pos_list.append(x_j if x_j.shape[-1] == 2 else None)
                adj_list.append(adj_j)

            self.logger.log_nx_graph_grid(
                adj_list=adj_list,
                signal_list=signal_list,
                pos_list=pos_list,
                node_size=25,
                name=title_for("assignments_grid"),
                log_series=not istest,
            )

    def __init__(
        self,
        optim_class: Optional[Type] = None,
        optim_kwargs: Optional[Mapping] = None,
        scheduler_class: Optional[Type] = None,
        scheduler_kwargs: Optional[Mapping] = None,
        log_lr: bool = True,
        log_grad_norm: bool = False,
        sync_dist: bool = False,
    ):
        """
        ⚙️ Store the optimisation recipe and the logging switches

        Args:
            optim_class: optimizer factory, called in ``configure_optimizers``
                as ``optim_class(self.parameters(), **optim_kwargs)``.
            optim_kwargs: keyword arguments for ``optim_class``; ``None`` or
                an empty mapping becomes a fresh empty dict.
            scheduler_class: optional scheduler factory, called in
                ``configure_optimizers`` with the optimizer as first argument.
            scheduler_kwargs: keyword arguments for ``scheduler_class``;
                ``None`` or an empty mapping becomes a fresh empty dict.
            log_lr: if ``True``, the learning rate is logged at the start of
                each training epoch.
            log_grad_norm: if ``True``, gradient norms are logged before each
                optimizer step.
            sync_dist: if ``True``, reduces the metric across devices. Causes
                overhead. Use only for multi-gpu train.
        """
        super().__init__()

        # Optimizer / scheduler factories and their keyword arguments.
        self.optim_class = optim_class
        self.optim_kwargs = optim_kwargs if optim_kwargs else {}
        self.scheduler_class = scheduler_class
        self.scheduler_kwargs = scheduler_kwargs if scheduler_kwargs else {}

        # Logging switches.
        self.log_lr = log_lr
        self.log_grad_norm = log_grad_norm
        self.sync_dist = sync_dist

    def configure_optimizers(self):
        """
        🛠️ Configure optimizer and scheduler

        Builds the optimizer as ``optim_class(self.parameters(),
        **optim_kwargs)``. When ``scheduler_class`` is set, the scheduler is
        built from the optimizer and ``scheduler_kwargs`` with the optional
        ``"monitor"`` entry removed; that entry, if present, is returned under
        the ``"monitor"`` key instead of being passed to the scheduler.

        Returns:
            dict with ``"optimizer"`` and, when a scheduler is configured,
            ``"lr_scheduler"`` and optionally ``"monitor"``.
        """
        optimizer = self.optim_class(self.parameters(), **self.optim_kwargs)
        cfg = {"optimizer": optimizer}

        if self.scheduler_class is not None:
            # Pop "monitor" from a copy: removing it from the stored mapping
            # would drop the monitored metric on any later call and would also
            # mutate the mapping handed in by the caller.
            scheduler_kwargs = dict(self.scheduler_kwargs)
            metric = scheduler_kwargs.pop("monitor", None)
            cfg["lr_scheduler"] = self.scheduler_class(optimizer, **scheduler_kwargs)
            if metric is not None:
                cfg["monitor"] = metric
        return cfg

    def on_before_optimizer_step(self, optimizer):
        """
        📏 Log gradients norm

        Does nothing unless ``self.log_grad_norm`` is true; otherwise logs the
        dictionary returned by ``grad_norm(self, norm_type=2)``. The
        ``optimizer`` argument is not used.
        """
        if not self.log_grad_norm:
            return
        norms = grad_norm(self, norm_type=2)
        self.log_dict(norms)

    def on_train_epoch_start(self) -> None:
        """
        ⌚ Log learning rate at the start of each epoch

        Does nothing unless ``self.log_lr`` is true. With a list of optimizers
        the learning rate of the first parameter group of the ``i``-th one is
        logged as ``lr_{i}``; with a single optimizer it is logged as ``lr``.
        """
        if not self.log_lr:
            return

        optimizers = self.optimizers()
        if isinstance(optimizers, list):
            named_optimizers = [
                (f"lr_{i}", opt) for i, opt in enumerate(optimizers)
            ]
        else:
            named_optimizers = [("lr", optimizers)]

        # Every learning rate is logged with the same epoch-level settings.
        log_kwargs = dict(
            on_step=False,
            on_epoch=True,
            logger=True,
            prog_bar=False,
            batch_size=1,
            sync_dist=self.sync_dist,
        )
        for key, opt in named_optimizers:
            # Only the first parameter group is reported.
            self.log(key, opt.optimizer.param_groups[0]["lr"], **log_kwargs)
