from typing import Optional, Mapping, Type
import torch
import torchmetrics
from torchmetrics.clustering import (
    NormalizedMutualInfoScore,
    HomogeneityScore,
    CompletenessScore,
)

from .base_module import BaseModule
from source.utils.metrics import cluster_acc, FuzzyMutualInformation, FuzzyClusterCosine


class ClusterModule(BaseModule):
    """
    Lightning module to perform clustering with graph pooling 🎱

    Wraps a pooling ``model`` and, at every training/test step, sums the loss
    terms returned by the model, logs them, and updates two groups of
    clustering metrics computed from ``pooling_out.so.s[0]``:

    * "hard" metrics (NMI, Homogeneity, Completeness), fed with the argmax
      of ``s`` over its last axis;
    * "fuzzy" metrics (FuzzyMI, FuzzyClusterCosine), fed with ``s`` itself.

    NOTE(review): ``pooling_out.so.s[0]`` is presumably the soft
    cluster-assignment matrix (one row per node) -- confirm against the
    pooling model's output type.

    Args:
        model: module called on the batch in :meth:`forward`.
        num_classes: forwarded to the fuzzy metrics as ``num_classes``.
        num_clusters: forwarded to the fuzzy metrics as ``num_clusters``;
            falls back to ``num_classes`` when ``None``.
        optim_class, optim_kwargs, scheduler_class, scheduler_kwargs,
        log_lr, log_grad_norm: forwarded unchanged to ``BaseModule``.
        sync_dist: passed to every ``self.log`` call (and to ``BaseModule``).
        plot_dict: optional mapping read at each step; its ``"set"`` entry
            selects the splits (``"train"`` / ``"test"``) to plot and its
            ``"types"`` entry is forwarded as ``plot_type`` to
            ``maybe_log_stuff``. ``None`` disables plotting.
    """

    # Log keys of the metric collections, in the order they are logged.
    _HARD_METRIC_NAMES = ("NMI", "Homogeneity", "Completeness")
    _FUZZY_METRIC_NAMES = ("FuzzyMI", "FuzzyClusterCosine")

    def __init__(
        self,
        model: Optional[torch.nn.Module] = None,
        num_classes: Optional[int] = None,
        num_clusters: Optional[int] = None,
        optim_class: Optional[Type] = None,
        optim_kwargs: Optional[Mapping] = None,
        scheduler_class: Optional[Type] = None,
        scheduler_kwargs: Optional[Mapping] = None,
        log_lr: bool = True,
        log_grad_norm: bool = False,
        sync_dist: bool = False,  # if ``True``, reduces the metric across devices. Causes overhead. Use only for multi-gpu train
        plot_dict: Optional[Mapping] = None,
    ):
        super().__init__(
            optim_class,
            optim_kwargs,
            scheduler_class,
            scheduler_kwargs,
            log_lr,
            log_grad_norm,
            sync_dist,
        )
        if num_clusters is None:
            num_clusters = num_classes
        self.model = model
        self.sync_dist = sync_dist
        self.plot_preds_at_epoch = plot_dict

        # Separate collections per split so train and test states never mix.
        self.train_metrics = self._make_hard_metrics()
        self.train_fuzzy_metrics = self._make_fuzzy_metrics(num_classes, num_clusters)
        self.test_metrics = self._make_hard_metrics()
        self.test_fuzzy_metrics = self._make_fuzzy_metrics(num_classes, num_clusters)

    @staticmethod
    def _make_hard_metrics():
        """Build a fresh collection of the hard-assignment clustering metrics."""
        return torchmetrics.MetricCollection(
            {
                "NMI": NormalizedMutualInfoScore(),
                "Homogeneity": HomogeneityScore(),
                "Completeness": CompletenessScore(),
            }
        )

    @staticmethod
    def _make_fuzzy_metrics(num_classes, num_clusters):
        """Build a fresh collection of the fuzzy (soft-assignment) metrics."""
        return torchmetrics.MetricCollection(
            {
                "FuzzyMI": FuzzyMutualInformation(
                    num_classes=num_classes, num_clusters=num_clusters
                ),
                "FuzzyClusterCosine": FuzzyClusterCosine(
                    num_classes=num_classes, num_clusters=num_clusters
                ),
            }
        )

    def _maybe_plot(self, split, batch, batch_idx, pooling_out, istest):
        """
        Call ``maybe_log_stuff`` for ``split`` when plotting is configured.

        Nothing happens when ``plot_dict`` was left to ``None``, when
        ``split`` is not listed in ``plot_dict["set"]``, when there is no
        logger exposing a ``cfg`` attribute, or when the configured backend
        is neither ``"neptune"`` nor ``"tensorboard"``.
        """
        plot_cfg = self.plot_preds_at_epoch
        # ``plot_dict`` defaults to None: treat that as "plotting disabled"
        # instead of failing on the subscript below.
        if plot_cfg is None or split not in plot_cfg["set"]:
            return
        logger = self.logger
        if logger is None or not hasattr(logger, "cfg"):
            return
        if logger.cfg.logger.backend not in ["neptune", "tensorboard"]:
            return
        self.maybe_log_stuff(
            data_batch=batch,
            batch_idx=batch_idx,
            pooling_output=pooling_out,
            plot_type=plot_cfg["types"],
            istest=istest,
        )

    def forward(self, data):
        """
        ⏩ Run the wrapped model on ``data`` and return its output unchanged.
        """
        return self.model(data)

    def training_step(self, batch, batch_idx):
        """
        🐾 Compute and log the training loss and clustering metrics.

        Returns:
            ``{"loss": loss}`` where ``loss`` is the sum of all entries of
            ``pooling_out.loss``.
        """
        y = batch.y
        pooling_out = self.forward(batch)
        s = pooling_out.so.s[0]

        # Everything in the training step is logged once per epoch.
        epoch_kwargs = dict(
            on_step=False,
            on_epoch=True,
            sync_dist=self.sync_dist,
            batch_size=batch.batch_size,
        )

        loss = sum(pooling_out.loss.values())
        self.log("train_loss", loss, prog_bar=False, **epoch_kwargs)

        # Each individual loss term is logged under its own key.
        for loss_name, loss_term in pooling_out.loss.items():
            self.log(f"{loss_name}", loss_term, prog_bar=False, **epoch_kwargs)

        self.train_metrics.update(s.argmax(axis=-1).detach(), y.int())
        self.train_fuzzy_metrics.update(s, y.long())

        for name in self._HARD_METRIC_NAMES:
            # Only NMI is shown in the progress bar.
            self.log(
                name,
                self.train_metrics[name],
                prog_bar=(name == "NMI"),
                **epoch_kwargs,
            )
        for name in self._FUZZY_METRIC_NAMES:
            self.log(
                name,
                self.train_fuzzy_metrics[name],
                prog_bar=False,
                **epoch_kwargs,
            )

        # Log images and artifacts
        self._maybe_plot("train", batch, batch_idx, pooling_out, istest=False)

        return {"loss": loss}

    def test_step(self, batch, batch_idx):
        """
        🧪 Compute and log the test loss, clustering metrics and a linear probe.

        Besides the metric collections, this logs the three values returned
        by ``cluster_acc`` (as ``clust_acc``, ``f1_macro``, ``f1_micro``) and
        ``CLF_acc``: the accuracy of a logistic regression fitted on ``s``
        and evaluated on the same samples it was fitted on.
        """
        y = batch.y
        pooling_out = self.forward(batch)
        s = pooling_out.so.s[0]
        batch_size = batch.batch_size

        # Log metrics
        loss = sum(pooling_out.loss.values())
        self.log("test_loss", loss, sync_dist=self.sync_dist, batch_size=batch_size)

        hard_assign = s.argmax(axis=-1).detach()
        self.test_metrics.update(hard_assign, y.int())
        self.test_fuzzy_metrics.update(s, y.long())

        for name in self._HARD_METRIC_NAMES:
            self.log(
                f"test_{name}",
                self.test_metrics[name],
                sync_dist=self.sync_dist,
                batch_size=batch_size,
            )
        for name in self._FUZZY_METRIC_NAMES:
            self.log(
                f"test_{name}",
                self.test_fuzzy_metrics[name],
                sync_dist=self.sync_dist,
                batch_size=batch_size,
            )

        # Log custom metrics
        clust_acc, f1_macro, f1_micro = cluster_acc(
            hard_assign.cpu().numpy(), y.cpu().numpy()
        )
        custom_scores = (
            ("clust_acc", clust_acc),
            ("f1_macro", f1_macro),
            ("f1_micro", f1_micro),
        )
        for name, value in custom_scores:
            self.log(
                name,
                value,
                on_step=False,
                on_epoch=True,
                batch_size=batch_size,
                sync_dist=self.sync_dist,
            )

        # Fit a logistic regression classifier on s to predict y
        # (imported lazily so scikit-learn is only needed at test time).
        from sklearn.linear_model import LogisticRegression
        from sklearn.metrics import accuracy_score

        # Convert once and reuse for fit, predict and scoring.
        features = s.detach().cpu().numpy()
        labels = y.cpu().numpy()
        clf = LogisticRegression(random_state=0).fit(features, labels)
        acc = accuracy_score(labels, clf.predict(features))
        # Pass batch_size like every other log call so Lightning does not
        # have to infer it from the batch object.
        self.log("CLF_acc", acc, sync_dist=self.sync_dist, batch_size=batch_size)

        # Log images and artifacts
        self._maybe_plot("test", batch, batch_idx, pooling_out, istest=True)
