from typing import Optional, Mapping, Type
import torch
import torchmetrics
from ogb.graphproppred import Evaluator
from torchmetrics.classification import (
    Accuracy,
    MulticlassF1Score,
    MulticlassAUROC,
    AveragePrecision,
)

from .base_module import BaseModule
from source.utils.misc import get_gpu_memory_from_nvidia_smi


class ClassificationModule(BaseModule):
    """Shared train/val/test loop for models that return ``(logits, pooling_out)``.

    The wrapped ``model`` is called on a batch and must return a pair
    ``(logits, pooling_out)``. ``pooling_out.loss`` is either ``None`` or a
    mapping ``name -> auxiliary loss``; the auxiliary losses are added to the
    task loss and logged individually.

    Subclasses provide the task specific pieces by overriding
    ``init_metrics``, ``compute_task_loss``, ``log_metrics_step`` and
    (optionally) ``log_metrics_epoch``.
    """

    def __init__(
        self,
        model: Optional[torch.nn.Module] = None,
        optim_class: Optional[Type] = None,
        optim_kwargs: Optional[Mapping] = None,
        scheduler_class: Optional[Type] = None,
        scheduler_kwargs: Optional[Mapping] = None,
        log_lr: bool = True,
        log_grad_norm: bool = False,
        sync_dist: bool = False,  # if ``True``, reduces the metric across devices. Causes overhead. Use only for multi-gpu train
        plot_dict: Optional[Mapping] = None,
        **kwargs,
    ):
        """Store the model and forward the optimisation settings to the base class.

        Args:
            model: Network called in ``forward``.
            optim_class, optim_kwargs, scheduler_class, scheduler_kwargs,
            log_lr, log_grad_norm: Passed positionally to ``BaseModule``.
            sync_dist: Forwarded to every ``self.log`` call.
            plot_dict: Optional mapping read through the keys ``"set"``
                (which stages should produce plots) and ``"types"`` (passed
                as ``plot_type`` to ``maybe_log_stuff``). ``None`` disables
                plotting.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(
            optim_class,
            optim_kwargs,
            scheduler_class,
            scheduler_kwargs,
            log_lr,
            log_grad_norm,
            sync_dist,
        )

        self.model = model
        self.sync_dist = sync_dist
        self.plot_preds_at_epoch = plot_dict

    # ------------------------------------------------------------------
    # Hooks for subclasses (no-ops here)
    # ------------------------------------------------------------------
    def init_metrics(self):
        """Create the metric objects. No-op in this class."""
        pass

    def log_metrics_step(self, preds, labels, set_):
        """Update/log metrics for one batch of stage ``set_``. No-op in this class."""
        pass

    def log_metrics_epoch(self, set_):
        """Log metrics at the end of an epoch of stage ``set_``. No-op in this class."""
        pass

    def compute_task_loss(self, logits, labels):
        """Return the task loss. No-op in this class; subclasses must override."""
        pass

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _runs_on_cuda(self) -> bool:
        """Return ``True`` only when this module currently lives on a CUDA device."""
        return torch.cuda.is_available() and torch.device(self.device).type == "cuda"

    def _maybe_plot(self, stage, batch, batch_idx, pooling_out):
        """Call ``maybe_log_stuff`` when plotting is enabled for ``stage``.

        Plotting requires a plot configuration that lists ``stage`` under
        ``"set"``, a logger exposing a ``cfg`` attribute, and a neptune
        backend in that config.
        """
        plot_cfg = self.plot_preds_at_epoch
        # ``plot_dict`` defaults to None; treat that as "plotting disabled"
        # instead of failing on ``None["set"]``.
        if plot_cfg is None or stage not in plot_cfg["set"]:
            return
        if self.logger is None or not hasattr(self.logger, "cfg"):
            return
        if self.logger.cfg.logger.backend != "neptune":
            return
        # NOTE(review): ``maybe_log_stuff`` is not defined in this file;
        # presumably inherited from ``BaseModule``.
        self.maybe_log_stuff(
            data_batch=batch,
            batch_idx=batch_idx,
            pooling_output=pooling_out,
            plot_type=plot_cfg["types"],
            istest=(stage == "test"),
        )

    def _shared_step(self, batch, batch_idx, stage):
        """Run forward, compute and log the losses/metrics for one batch.

        Args:
            batch: Batch object exposing ``y`` and ``batch_size``.
            batch_idx: Index of the batch inside the epoch.
            stage: ``"train"``, ``"val"`` or ``"test"``; used as log prefix.

        Returns:
            The total loss (task loss plus the sum of the auxiliary losses).
        """
        logits, pooling_out = self.forward(batch)
        task_loss = self.compute_task_loss(logits, batch.y)
        aux_losses = pooling_out.loss
        if aux_losses is None:
            loss = task_loss
        else:
            loss = task_loss + sum(aux_losses.values())

        # Log losses
        self.log(
            f"{stage}_loss",
            loss,
            batch_size=batch.y.size(0),
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=self.sync_dist,
        )
        self.log(
            f"{stage}_task_loss",
            task_loss,
            batch_size=batch.y.size(0),
            on_step=False,
            on_epoch=True,
            # The test stage never showed the task loss in the progress bar.
            prog_bar=(stage != "test"),
            sync_dist=self.sync_dist,
        )
        if aux_losses is not None:
            for name, value in aux_losses.items():
                self.log(
                    f"{stage}_{name}",
                    value,
                    on_step=False,
                    on_epoch=True,
                    prog_bar=False,
                    sync_dist=self.sync_dist,
                    batch_size=batch.batch_size,
                )

        # Log metrics
        self.log_metrics_step(logits, batch.y, set_=stage)

        # Log images and artifacts
        self._maybe_plot(stage, batch, batch_idx, pooling_out)

        return loss

    # ------------------------------------------------------------------
    # Lightning API
    # ------------------------------------------------------------------
    def forward(self, data):
        """Run the wrapped model and return its ``(logits, pooling_out)`` pair."""
        logits, pooling_out = self.model(data)

        return logits, pooling_out

    def training_step(self, batch, batch_idx):
        """Process one training batch and return ``{"loss": total_loss}``."""
        return {"loss": self._shared_step(batch, batch_idx, "train")}

    def on_train_epoch_start(self) -> None:
        """Run the base hook and reset the CUDA peak-memory counter."""
        super().on_train_epoch_start()  # This logs the learning rate
        # The CUDA memory API raises for non-CUDA devices, so only touch it
        # when the module actually runs on a GPU.
        if self._runs_on_cuda():
            torch.cuda.reset_peak_memory_stats(self.device)

    def on_train_epoch_end(self):
        """Log the epoch-level train metrics and the peak GPU memory (if on GPU)."""
        self.log_metrics_epoch(set_="train")

        # Peak memory is only defined on CUDA devices; skip it on CPU.
        if not self._runs_on_cuda():
            return
        used_gb = torch.cuda.max_memory_allocated(self.device) / 1024**3
        self.log(
            "peak_train_gpu_usage (GB)",
            used_gb,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=self.sync_dist,
        )

    def validation_step(self, batch, batch_idx):
        """Process one validation batch and return ``{"val_loss": total_loss}``."""
        return {"val_loss": self._shared_step(batch, batch_idx, "val")}

    def on_validation_epoch_end(self):
        """Log the epoch-level val metrics and the GPU usage (if on GPU)."""
        self.log_metrics_epoch(set_="val")

        # Without a GPU there is nothing to query; skip the probe.
        if not self._runs_on_cuda():
            return
        # NOTE(review): the helper is queried for GPU index 0 and its second
        # return value is divided by 1024 to obtain GB (so presumably MiB) --
        # confirm against ``get_gpu_memory_from_nvidia_smi``.
        _, used_mem = get_gpu_memory_from_nvidia_smi(0)
        used_gb = used_mem / 1024
        self.log(
            "val_gpu_usage (GB)",
            used_gb,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=self.sync_dist,
        )

    def test_step(self, batch, batch_idx):
        """Process one test batch and return ``{"test_loss": total_loss}``."""
        return {"test_loss": self._shared_step(batch, batch_idx, "test")}

    def on_test_epoch_end(self):
        """Log the epoch-level test metrics."""
        self.log_metrics_epoch(set_="test")


class SingleClassificationModule(ClassificationModule):
    """Single-label multiclass task: cross-entropy loss with accuracy, macro-F1 and AUROC."""

    def __init__(
        self,
        model: Optional[torch.nn.Module] = None,
        optim_class: Optional[Type] = None,
        optim_kwargs: Optional[Mapping] = None,
        scheduler_class: Optional[Type] = None,
        scheduler_kwargs: Optional[Mapping] = None,
        log_lr: bool = True,
        log_grad_norm: bool = False,
        sync_dist: bool = False,  # if ``True``, reduces the metric across devices. Causes overhead. Use only for multi-gpu train
        plot_dict: Optional[Mapping] = None,
        **kwargs,
    ):
        """Forward the shared settings to the parent, then set up loss and metrics.

        ``**kwargs`` is accepted and ignored.
        """
        super().__init__(
            model=model,
            optim_class=optim_class,
            optim_kwargs=optim_kwargs,
            scheduler_class=scheduler_class,
            scheduler_kwargs=scheduler_kwargs,
            log_lr=log_lr,
            log_grad_norm=log_grad_norm,
            sync_dist=sync_dist,
            plot_dict=plot_dict,
        )

        self.loss = torch.nn.CrossEntropyLoss()
        self.init_metrics()

    def compute_task_loss(self, logits, labels):
        """Return the cross-entropy between ``logits`` and ``labels``."""
        return self.loss(logits, labels)

    def _build_collection(self, prefix):
        """Build the acc/f1/auroc collection whose keys start with ``prefix``."""
        n_classes = self.model.num_classes
        return torchmetrics.MetricCollection(
            {
                f"{prefix}_acc": Accuracy(task="multiclass", num_classes=n_classes),
                f"{prefix}_f1": MulticlassF1Score(
                    num_classes=n_classes, average="macro"
                ),
                f"{prefix}_auroc": MulticlassAUROC(num_classes=n_classes),
            }
        )

    def init_metrics(self):
        """Create one metric collection per stage (train, val, test)."""
        self.train_metrics = self._build_collection("train")
        self.val_metrics = self._build_collection("val")
        self.test_metrics = self._build_collection("test")

    def log_metrics_epoch(self, set_):
        """Compute, log and reset the metric collection of stage ``set_``.

        Stages other than ``"train"``, ``"val"`` and ``"test"`` are ignored.
        """
        if set_ not in ("train", "val", "test"):
            return

        collection = getattr(self, f"{set_}_metrics")
        results = collection.compute()
        # Log in the fixed order acc -> f1 -> auroc.
        for suffix in ("acc", "f1", "auroc"):
            key = f"{set_}_{suffix}"
            self.log(key, results[key], sync_dist=self.sync_dist)
        collection.reset()

    def log_metrics_step(self, preds, labels, set_):
        """Accumulate one batch into the metric collection of stage ``set_``.

        Stages other than ``"train"``, ``"val"`` and ``"test"`` are ignored.
        """
        if set_ not in ("train", "val", "test"):
            return
        getattr(self, f"{set_}_metrics").update(preds, labels)


class MultiClassificationModule(ClassificationModule):
    """Multi-label task: BCE-with-logits loss and average precision (AP).

    AP is computed either with torchmetrics (default) or with the OGB
    ``Evaluator`` when ``ogbg_evaluator`` is truthy. Entries of ``labels``
    that are NaN are treated as missing and excluded from the loss and from
    the torchmetrics AP.
    """

    # Stages for which metric state exists.
    _SPLITS = ("train", "val", "test")

    def __init__(
        self,
        model: Optional[torch.nn.Module] = None,
        optim_class: Optional[Type] = None,
        optim_kwargs: Optional[Mapping] = None,
        scheduler_class: Optional[Type] = None,
        scheduler_kwargs: Optional[Mapping] = None,
        log_lr: bool = True,
        log_grad_norm: bool = False,
        sync_dist: bool = False,
        plot_dict: Optional[Mapping] = None,
        ogbg_evaluator: Optional[
            bool
        ] = False,  # compute the metrics using the OGB evaluator
        ogbg_dataset: Optional[str] = "ogbg-molpcba",
        **kwargs,
    ):
        """Forward the shared settings to the parent, then set up loss and metrics.

        Args:
            ogbg_evaluator: If truthy, AP is computed by the OGB ``Evaluator``.
            ogbg_dataset: Dataset name handed to the OGB ``Evaluator``.
            **kwargs: Accepted and ignored.

        The remaining arguments are passed unchanged to the parent class.
        """
        super().__init__(
            model=model,
            optim_class=optim_class,
            optim_kwargs=optim_kwargs,
            scheduler_class=scheduler_class,
            scheduler_kwargs=scheduler_kwargs,
            log_lr=log_lr,
            log_grad_norm=log_grad_norm,
            sync_dist=sync_dist,
            plot_dict=plot_dict,
        )

        self.loss = torch.nn.BCEWithLogitsLoss()
        self.ogbg_evaluator = ogbg_evaluator
        self.ogbg_dataset = ogbg_dataset
        self.init_metrics()

    def compute_task_loss(self, logits, labels):
        """Return the BCE-with-logits loss over the non-NaN label entries."""
        mask = ~torch.isnan(labels)
        task_loss = self.loss(logits[mask], labels[mask])
        return task_loss

    def init_metrics(self):
        """Create the OGB evaluator (plus buffers) or the torchmetrics AP objects."""
        if self.ogbg_evaluator:
            self.evaluator = Evaluator(name=self.ogbg_dataset)
            # AP cannot be obtained by averaging per-batch values, so the
            # predictions of a whole epoch are collected here and evaluated
            # once in ``log_metrics_epoch``.
            self._ogb_buffers = {
                split: {"y_true": [], "y_pred": []} for split in self._SPLITS
            }
        else:
            self.train_ap = AveragePrecision(
                num_classes=self.model.num_classes, task="binary"
            )
            self.val_ap = AveragePrecision(
                num_classes=self.model.num_classes, task="binary"
            )
            self.test_ap = AveragePrecision(
                num_classes=self.model.num_classes, task="binary"
            )

    def log_metrics_step(self, preds, labels, set_):
        """Update the AP state of stage ``set_`` with one batch.

        With the OGB evaluator the batch is only buffered (detached, on CPU);
        the metric is logged at epoch end. Otherwise the torchmetrics object
        is updated on the non-NaN entries and logged so that it is aggregated
        over the epoch. Unknown stages are ignored.
        """
        if set_ not in self._SPLITS:
            return

        if self.ogbg_evaluator:
            buffer = self._ogb_buffers[set_]
            # Detach and move to CPU so the buffer neither keeps the autograd
            # graph alive nor grows GPU memory over the epoch.
            buffer["y_true"].append(labels.detach().cpu())
            buffer["y_pred"].append(preds.detach().cpu())
            return

        mask = ~torch.isnan(labels)
        metric = getattr(self, f"{set_}_ap")
        metric(preds[mask], labels[mask].long())
        self.log(
            f"{set_}_ap",
            metric,
            batch_size=labels.size(0),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=self.sync_dist,
        )

    def log_metrics_epoch(self, set_):
        """Evaluate and log the OGB AP over everything buffered for ``set_``.

        Does nothing when the OGB evaluator is disabled (the torchmetrics
        objects are logged from ``log_metrics_step``), for unknown stages, or
        when no batch was buffered.
        """
        if not self.ogbg_evaluator or set_ not in self._SPLITS:
            return

        buffer = self._ogb_buffers[set_]
        if not buffer["y_true"]:
            return

        y_true = torch.cat(buffer["y_true"], dim=0)
        y_pred = torch.cat(buffer["y_pred"], dim=0)
        # Clear before evaluating so a failing evaluation cannot leak stale
        # predictions into the next epoch.
        buffer["y_true"].clear()
        buffer["y_pred"].clear()

        metrics = self.evaluator.eval({"y_true": y_true, "y_pred": y_pred})
        # NOTE(review): with several devices each rank evaluates only its own
        # shard; ``sync_dist`` then reduces the per-rank values.
        self.log(
            f"{set_}_ap",
            metrics["ap"],
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=self.sync_dist,
        )


class RegressionModule(ClassificationModule):
    """Regression task: mean-squared-error loss with mean absolute error as metric."""

    def __init__(
        self,
        model: Optional[torch.nn.Module] = None,
        optim_class: Optional[Type] = None,
        optim_kwargs: Optional[Mapping] = None,
        scheduler_class: Optional[Type] = None,
        scheduler_kwargs: Optional[Mapping] = None,
        log_lr: bool = True,
        log_grad_norm: bool = False,
        sync_dist: bool = False,
        plot_dict: Optional[Mapping] = None,
        **kwargs,
    ):
        """Forward the shared settings to the parent, then set up loss and metrics.

        ``**kwargs`` is accepted and ignored.
        """
        super().__init__(
            model=model,
            optim_class=optim_class,
            optim_kwargs=optim_kwargs,
            scheduler_class=scheduler_class,
            scheduler_kwargs=scheduler_kwargs,
            log_lr=log_lr,
            log_grad_norm=log_grad_norm,
            sync_dist=sync_dist,
            plot_dict=plot_dict,
        )

        self.loss = torch.nn.MSELoss()
        self.init_metrics()

    def compute_task_loss(self, logits, labels):
        """Return the mean squared error between ``logits`` and ``labels``."""
        return self.loss(logits, labels)

    def init_metrics(self):
        """Create one ``MeanAbsoluteError`` per stage (train, val, test)."""
        for split in ("train", "val", "test"):
            setattr(self, f"{split}_mae", torchmetrics.MeanAbsoluteError())

    def log_metrics_step(self, preds, labels, set_):
        """Update and log the MAE of stage ``set_`` with one batch.

        The metric object itself is logged with ``on_epoch=True``. Stages
        other than ``"train"``, ``"val"`` and ``"test"`` are ignored.
        """
        if set_ not in ("train", "val", "test"):
            return

        mae = getattr(self, f"{set_}_mae")
        mae(preds, labels)
        self.log(
            f"{set_}_mae",
            mae,
            batch_size=labels.size(0),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=self.sync_dist,
        )
