from typing import Optional, Mapping, Type
import torch
import torchmetrics
from torchmetrics.classification import Accuracy, AUROC

from .base_module import BaseModule


class NodeClassificationModule(BaseModule):
    """Lightning module to perform node classification with graph pooling 🎱

    The wrapped ``model`` is called on a batch and must return a pair
    ``(logits, pooling_out)``. ``pooling_out.loss`` is either ``None`` or a
    mapping from loss name to loss value; when it is a mapping, its values are
    summed and added to the cross-entropy task loss.

    Args:
        model: Network to train. Despite the ``None`` default, a model
            exposing ``num_classes`` is required: the constructor reads
            ``model.num_classes`` to configure the metrics.
        optim_class: Forwarded unchanged to ``BaseModule.__init__``.
        optim_kwargs: Forwarded unchanged to ``BaseModule.__init__``.
        scheduler_class: Forwarded unchanged to ``BaseModule.__init__``.
        scheduler_kwargs: Forwarded unchanged to ``BaseModule.__init__``.
        log_lr: Forwarded unchanged to ``BaseModule.__init__``.
        log_grad_norm: Forwarded unchanged to ``BaseModule.__init__``.
        sync_dist: Forwarded to ``BaseModule.__init__`` and passed as
            ``sync_dist`` to every ``self.log`` call made by this class.
        plot_dict: Optional plotting configuration. ``plot_dict["set"]`` is
            tested for membership of ``"train"``, ``"val"`` and ``"test"``;
            ``plot_dict["types"]`` is passed as ``plot_type`` to
            ``maybe_log_stuff``. ``None`` disables plotting.
        fold: Optional fold index. When set, the split masks are indexed as
            ``mask[:, fold]``; otherwise the masks are used as they are.
    """

    def __init__(
        self,
        model: Optional[torch.nn.Module] = None,
        optim_class: Optional[Type] = None,
        optim_kwargs: Optional[Mapping] = None,
        scheduler_class: Optional[Type] = None,
        scheduler_kwargs: Optional[Mapping] = None,
        log_lr: bool = True,
        log_grad_norm: bool = False,
        sync_dist: bool = False,  # if ``True``, reduces the metric across devices. Causes overhead. Use only for multi-gpu train
        plot_dict: Optional[Mapping] = None,
        fold: Optional[int] = None,
    ):
        super().__init__(
            optim_class,
            optim_kwargs,
            scheduler_class,
            scheduler_kwargs,
            log_lr,
            log_grad_norm,
            sync_dist,
        )

        self.model = model
        self.fold = fold
        self.sync_dist = sync_dist
        # Stored as given (possibly ``None``); the step methods treat ``None``
        # as "no plotting requested".
        self.plot_preds_at_epoch = plot_dict
        self.loss = torch.nn.CrossEntropyLoss()

        # Every metric below is configured for the same number of classes.
        num_classes = model.num_classes

        self.train_metrics = torchmetrics.MetricCollection(
            {
                "train_acc": Accuracy(task="multiclass", num_classes=num_classes),
            }
        )
        self.train_auroc = AUROC(task="multiclass", num_classes=num_classes)

        self.val_metrics = torchmetrics.MetricCollection(
            {
                "val_acc": Accuracy(task="multiclass", num_classes=num_classes),
            }
        )
        self.val_auroc = AUROC(task="multiclass", num_classes=num_classes)

        self.test_metrics = torchmetrics.MetricCollection(
            {
                "test_acc": Accuracy(task="multiclass", num_classes=num_classes),
            }
        )
        self.test_auroc = AUROC(task="multiclass", num_classes=num_classes)

    def _fold_mask(self, mask):
        """Return ``mask`` itself, or its ``self.fold`` column when a fold is set."""
        return mask if self.fold is None else mask[:, self.fold]

    def _combined_loss(self, task_loss, pooling_out):
        """Return ``task_loss`` plus the sum of the pooling losses, if there are any."""
        if pooling_out.loss is None:
            return task_loss
        return task_loss + sum(pooling_out.loss.values())

    def _log_pooling_losses(self, prefix, pooling_out, batch):
        """Log each entry of ``pooling_out.loss`` under the key ``"<prefix>_<name>"``."""
        if pooling_out.loss is None:
            return
        for name, value in pooling_out.loss.items():
            self.log(
                f"{prefix}_{name}",
                value,
                on_step=False,
                on_epoch=True,
                prog_bar=False,
                sync_dist=self.sync_dist,
                batch_size=batch.batch_size,
            )

    def _neptune_plots_enabled(self, split):
        """Tell whether plots were requested for ``split`` and the logger backend is neptune."""
        plot_cfg = self.plot_preds_at_epoch
        # ``plot_dict`` defaults to ``None``; that means no plots at all.
        if plot_cfg is None or split not in plot_cfg["set"]:
            return False
        if self.logger is None or not hasattr(self.logger, "cfg"):
            return False
        return self.logger.cfg.logger.backend == "neptune"

    def forward(self, data):
        """
        ⏩

        Run the wrapped model on ``data`` and return its ``(logits, pooling_out)`` pair.
        """
        logits, pooling_out = self.model(data)

        return logits, pooling_out

    def training_step(self, batch, batch_idx):
        """
        🐾

        Compute the training loss on the nodes selected by ``batch.train_mask``
        and log losses, accuracy and AUROC.

        Returns:
            ``{"loss": loss}`` where ``loss`` is the task loss plus any
            pooling losses.
        """
        logits, pooling_out = self.forward(batch)
        mask = self._fold_mask(batch.train_mask)
        masked_logits = logits[mask]
        targets = batch.y[mask]
        task_loss = self.loss(masked_logits, targets)
        loss = self._combined_loss(task_loss, pooling_out)
        num_targets = batch.y.size(0)

        # Log losses
        self.log(
            "train_loss",
            loss,
            batch_size=num_targets,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=self.sync_dist,
        )
        self.log(
            "train_task_loss",
            task_loss,
            batch_size=num_targets,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=self.sync_dist,
        )
        self._log_pooling_losses("train", pooling_out, batch)

        # Log metrics
        self.train_metrics.update(masked_logits.argmax(1).detach().int(), targets.int())
        self.train_auroc.update(masked_logits, targets)
        self.log(
            "train_acc",
            self.train_metrics["train_acc"],
            batch_size=num_targets,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=self.sync_dist,
        )
        self.log(
            "train_auroc",
            self.train_auroc,
            batch_size=num_targets,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=self.sync_dist,
        )

        # Log images and artifacts
        if self._neptune_plots_enabled("train"):
            self.maybe_log_stuff(
                data_batch=batch,
                batch_idx=batch_idx,
                pooling_output=pooling_out,
                plot_type=self.plot_preds_at_epoch["types"],
                istest=False,
            )

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """
        🐾

        Compute the validation loss on the nodes selected by ``batch.val_mask``
        and log losses, accuracy and AUROC.

        Returns:
            ``{"val_loss": loss}`` where ``loss`` is the task loss plus any
            pooling losses.
        """
        logits, pooling_out = self.forward(batch)
        mask = self._fold_mask(batch.val_mask)
        masked_logits = logits[mask]
        targets = batch.y[mask]
        task_loss = self.loss(masked_logits, targets)
        loss = self._combined_loss(task_loss, pooling_out)
        num_targets = batch.y.size(0)

        # Log losses
        self.log(
            "val_loss",
            loss,
            batch_size=num_targets,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=self.sync_dist,
        )
        self.log(
            "val_task_loss",
            task_loss,
            batch_size=num_targets,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=self.sync_dist,
        )
        # Use the "val" prefix so validation values do not end up under the
        # training keys.
        self._log_pooling_losses("val", pooling_out, batch)

        # Log metrics
        self.val_metrics.update(masked_logits.argmax(1).detach().int(), targets.int())
        self.val_auroc.update(masked_logits, targets)
        self.log(
            "val_acc",
            self.val_metrics["val_acc"],
            batch_size=num_targets,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=self.sync_dist,
        )
        self.log(
            "val_auroc",
            self.val_auroc,
            batch_size=num_targets,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=self.sync_dist,
        )

        # Log images and artifacts
        if self._neptune_plots_enabled("val"):
            self.maybe_log_stuff(
                data_batch=batch,
                batch_idx=batch_idx,
                pooling_output=pooling_out,
                plot_type=self.plot_preds_at_epoch["types"],
                istest=False,
            )

        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        """
        🧪

        Compute the test loss on the nodes selected by ``batch.test_mask`` and
        log losses, accuracy and AUROC.

        Returns:
            ``{"test_loss": loss}`` where ``loss`` is the task loss plus any
            pooling losses.
        """
        logits, pooling_out = self.forward(batch)
        mask = self._fold_mask(batch.test_mask)
        masked_logits = logits[mask]
        targets = batch.y[mask]
        task_loss = self.loss(masked_logits, targets)
        loss = self._combined_loss(task_loss, pooling_out)
        num_targets = batch.y.size(0)

        # Log losses
        self.log(
            "test_loss", loss, batch_size=num_targets, sync_dist=self.sync_dist
        )
        self.log(
            "test_task_loss",
            task_loss,
            batch_size=num_targets,
            sync_dist=self.sync_dist,
        )
        self._log_pooling_losses("test", pooling_out, batch)

        # Log metrics
        self.test_metrics.update(masked_logits.argmax(1).detach().int(), targets.int())
        self.test_auroc.update(masked_logits, targets)
        self.log(
            "test_acc",
            self.test_metrics["test_acc"],
            batch_size=num_targets,
            sync_dist=self.sync_dist,
        )
        self.log(
            "test_auroc",
            self.test_auroc,
            batch_size=num_targets,
            sync_dist=self.sync_dist,
        )

        # Log images and artifacts
        # At test time plotting is additionally skipped for the 'diffndp' pooler.
        if self._neptune_plots_enabled("test") and self.model.pooler.name != 'diffndp':
            self.maybe_log_stuff(
                data_batch=batch,
                batch_idx=batch_idx,
                pooling_output=pooling_out,
                plot_type=self.plot_preds_at_epoch["types"],
                istest=True,
            )

        return {"test_loss": loss}
