from __future__ import annotations

from typing import TYPE_CHECKING

from torch import optim

from .._ood_model import _OODModel

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class _SWAGModel(_OODModel):
    """Base class for SWAG (Stochastic Weight Averaging Gaussian) models in the OOD detection experiment.

    Delegates the forward pass to ``self.model`` and trains it with SGD. After every
    completed training epoch it calls ``self.model.update_wrapper()``.

    ``self.model`` is expected to provide ``update_wrapper()`` and may optionally
    provide ``parameters_and_lrs(lr=..., optimizer=...)``. The hyperparameters
    ``lr``, ``momentum``, ``nesterov`` and ``weight_decay`` are read from ``self``.
    They are presumably set by ``_OODModel``, which is not visible here — confirm.
    """

    def forward(
        self, inputs: Float[Tensor, "batch *in_feature"]
    ) -> Float[Tensor, "batch *out_feature"]:
        """Run ``inputs`` through the wrapped model and return its output unchanged."""
        return self.model(inputs)

    def configure_optimizers(self) -> optim.Optimizer:
        """Build the SGD optimizer used for training.

        If the wrapped model provides ``parameters_and_lrs``, its result is passed to
        SGD as the parameter specification. This is presumably a list of parameter
        groups with their own learning rates — confirm against the wrapper.
        Otherwise all of ``self.model``'s parameters are optimized with ``self.lr``.

        Returns:
            An ``optim.SGD`` instance configured with ``self.lr``, ``self.momentum``,
            ``self.nesterov`` and ``self.weight_decay``. No LR scheduler is attached.
        """
        if hasattr(self.model, "parameters_and_lrs"):
            parameters = self.model.parameters_and_lrs(lr=self.lr, optimizer="SGD")
        else:
            parameters = self.model.parameters()

        # TODO: SWAG possibly needs a specific learning-rate schedule, see
        # https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/ .
        # A CosineAnnealingLR(T_max=self.max_epochs) stepped once per epoch was
        # sketched here previously, but it is intentionally not enabled.
        return optim.SGD(
            parameters,
            lr=self.lr,
            momentum=self.momentum,
            nesterov=self.nesterov,
            weight_decay=self.weight_decay,
        )

    def on_train_epoch_end(self) -> None:
        """Update the SWAG wrapper once per completed training epoch.

        This replaces the former ``on_epoch_end`` override, for two reasons:

        - Lightning deprecated ``on_epoch_end`` in v1.6 and removed it in v1.8, so
          on current versions the wrapper would never be updated.
        - While the hook existed, it also fired after validation/test epochs.
        """
        # Keep any epoch-end logic a parent class may implement.
        super().on_train_epoch_end()
        self.model.update_wrapper()
