# fmt: off
from argparse import Namespace
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchmetrics import AUROC, Accuracy, AveragePrecision, CalibrationError

import pytorch_lightning as pl
from pytorch_lightning.utilities.memory import get_model_size_mb
from pytorch_lightning.utilities.types import STEP_OUTPUT

from pysemble.networks import ResNet50_GrE
from pysemble.metrics import FPR95Metric, NLLMetric


# fmt: on
class PackedEns(pl.LightningModule):
    r"""Lightning module training and evaluating a ``ResNet50_GrE`` ensemble.

    The network output is treated as the stacked logits of ``n_estimators``
    members:
    - training repeats the targets ``n_estimators`` times;
    - evaluation reshapes the logits to ``(n_estimators, -1, num_classes)``
      and averages the per-member softmax probabilities.

    This assumes ``ResNet50_GrE`` returns 2-D logits of shape
    ``(n_estimators * batch, num_classes)``, ordered estimator-major.
    TODO confirm against the network implementation.

    At test time, dataloader 0 is the in-distribution set and every other
    dataloader is out-of-distribution (OOD). The negated maximum softmax
    probability is used as the OOD score.

    Args:
        n_estimators: Number of ensemble members whose predictions are
            averaged.
        augmentation: Forwarded unchanged to ``ResNet50_GrE``.
        n_subgroups: Forwarded unchanged to ``ResNet50_GrE``.
        num_classes: Forwarded to ``ResNet50_GrE`` as ``num_classes``.
    """

    def __init__(
        self,
        n_estimators: int = 4,
        augmentation: int = 2,
        n_subgroups: int = 2,
        num_classes: int = 10,
    ) -> None:

        super().__init__()

        self.n_estimators = n_estimators
        self.n_subgroups = n_subgroups
        self.augmentation = augmentation

        self.model = ResNet50_GrE(
            n_estimators=self.n_estimators,
            augmentation=self.augmentation,
            n_subgroups=self.n_subgroups,
            num_classes=num_classes,
        )

        # metrics
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()
        self.test_nll = NLLMetric()
        self.test_ece = CalibrationError()
        # OOD detection metrics: OOD samples are the positive class (label 1).
        self.test_aupr = AveragePrecision(pos_label=1)
        self.test_auroc = AUROC(pos_label=1)
        self.test_fpr95 = FPR95Metric(pos_label=1)

    def configure_optimizers(self) -> dict:
        r"""Build an SGD optimizer with Nesterov momentum and a step LR schedule.

        The values are the CIFAR recipe from Wide Residual Networks
        (Zagoruyko & Komodakis, 2016), https://arxiv.org/abs/1605.07146:
        - initial lr 0.1, momentum 0.9, weight decay 5e-4;
        - lr multiplied by 0.2 at scheduler steps 60, 120 and 160.

        Returns:
            Dict with the ``"optimizer"`` and ``"lr_scheduler"`` entries
            expected by Lightning.
        """
        optimizer = optim.SGD(
            self.parameters(),
            lr=0.1,
            momentum=0.9,
            weight_decay=5e-4,
            nesterov=True,
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[60, 120, 160],
            gamma=0.2,
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    @property
    def criterion(self) -> nn.Module:
        """Cross-entropy loss; a new instance is built on every access."""
        return nn.CrossEntropyLoss()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the raw logits of the wrapped ``ResNet50_GrE`` model."""
        return self.model.forward(input)

    def on_train_start(self) -> None:
        """Log the model storage size as a hyperparameter."""
        # hyperparameters for performances
        param = {}
        param["storage"] = f"{get_model_size_mb(self)} MB"
        if self.logger is not None:
            # The hp/* keys are the ones logged in validation/test. The
            # 0 placeholders presumably register them with the logger's
            # hyperparameter view.
            self.logger.log_hyperparams(
                Namespace(**param),
                {
                    "hp/val_acc": 0,
                    "hp/test_acc": 0,
                    "hp/test_nll": 0,
                    "hp/test_ece": 0,
                    "hp/test_aupr": 0,
                    "hp/test_auroc": 0,
                    "hp/test_fpr95": 0,
                },
            )

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> STEP_OUTPUT:
        """Compute the cross-entropy of every estimator's logits.

        The targets are tiled ``n_estimators`` times so each estimator's
        slice of the logits is matched with the same labels.
        """
        inputs, targets = batch
        targets = targets.repeat(self.n_estimators)
        logits = self.forward(inputs)
        loss = self.criterion(logits, targets)
        self.log("train_loss", loss)
        return loss

    def _ensemble_probs(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the ensemble-averaged softmax probabilities for ``inputs``.

        The logits are reshaped to ``(n_estimators, -1, num_classes)``. The
        result is soft-maxed over classes and averaged over the estimator
        dimension.
        """
        logits = self.forward(inputs)
        logits = logits.reshape(self.n_estimators, -1, logits.size(-1))
        probs_per_est = F.softmax(logits, dim=-1)
        return probs_per_est.mean(dim=0)

    def _log_test_metric(self, name: str, metric: nn.Module) -> None:
        """Log ``metric`` under ``name`` as an epoch-level, synced test value.

        ``add_dataloader_idx=False`` keeps ``name`` as-is, with no
        per-dataloader suffix.
        """
        self.log(
            name,
            metric,
            on_epoch=True,
            add_dataloader_idx=False,
            sync_dist=True,
        )

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> Optional[STEP_OUTPUT]:
        """Update and log the accuracy of the averaged ensemble prediction."""
        inputs, targets = batch
        probs = self._ensemble_probs(inputs)
        self.val_acc(probs, targets)
        self.log("hp/val_acc", self.val_acc, on_epoch=True, sync_dist=True)

    def test_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> Optional[STEP_OUTPUT]:
        """Update the classification and OOD-detection test metrics.

        Args:
            batch: ``(inputs, targets)`` pair.
            batch_idx: Index of the batch (unused).
            dataloader_idx: Which test dataloader the batch comes from.
                - 0 is the in-distribution set; it also updates the
                  accuracy, NLL and ECE metrics.
                - Any other index is treated as OOD.
                Defaults to 0 because Lightning omits this argument when
                there is a single test dataloader.
        """
        inputs, targets = batch
        probs = self._ensemble_probs(inputs)
        confs, _ = probs.max(-1)
        # With pos_label=1 (OOD), a higher score must mean "more OOD", so
        # the max-probability confidence is negated.
        ood_scores = -confs

        if dataloader_idx == 0:
            self.test_acc(probs, targets)
            self.test_nll(probs, targets)
            self.test_ece(probs, targets)
            ood_labels = torch.zeros_like(targets)
            to_log = (
                ("hp/test_acc", self.test_acc),
                ("hp/test_nll", self.test_nll),
                ("hp/test_ece", self.test_ece),
            )
        else:
            ood_labels = torch.ones_like(targets)
            # The OOD metrics accumulate samples from every dataloader but
            # are logged only here. Each hp/* key therefore comes from one
            # kind of dataloader.
            to_log = (
                ("hp/test_aupr", self.test_aupr),
                ("hp/test_auroc", self.test_auroc),
                ("hp/test_fpr95", self.test_fpr95),
            )

        self.test_aupr(ood_scores, ood_labels)
        self.test_auroc(ood_scores, ood_labels)
        self.test_fpr95(ood_scores, ood_labels)

        for name, metric in to_log:
            self._log_test_metric(name, metric)
