# pylint: disable=arguments-differ,too-many-ancestors
from typing import List, Tuple, Union

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import Callback, EarlyStopping
from torch import nn, optim

from .losses import ListMLELoss


class MLPLightningModule(pl.LightningModule):
    """
    Lightning module which trains an MLP until early stopping ends training.

    Training uses Adam and an ``EarlyStopping`` callback that monitors the
    epoch-level ``train_loss``. When ``loss`` is a ``ListMLELoss`` it is also
    given the per-sample group ids taken from the batch.

    Args:
        model: The network whose output is passed to ``loss``.
        loss: The training criterion.
        weight_decay: Weight decay passed to Adam.
        lr: Learning rate passed to Adam.
        patience: Number of early-stopping checks (one per training epoch)
            without improvement after which training stops.
        min_delta: Minimum change of ``train_loss`` that counts as an
            improvement for early stopping.
    """

    def __init__(
        self,
        model: nn.Sequential,
        loss: nn.Module,
        weight_decay: float = 0.0,
        *,
        lr: float = 1e-2,
        patience: int = 50,
        min_delta: float = 1e-3,
    ):
        super().__init__()

        self.model = model
        self.loss = loss
        self.weight_decay = weight_decay
        self.lr = lr
        self.patience = patience
        self.min_delta = min_delta
        # A ranking loss needs each sample's group id; pointwise losses do not.
        self.uses_ranking = isinstance(self.loss, ListMLELoss)

    def configure_optimizers(self) -> optim.Optimizer:
        """Return an Adam optimizer over the parameters of ``self.model``."""
        return optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def configure_callbacks(self) -> List[Callback]:
        """
        Return an early-stopping callback that monitors ``train_loss``.

        The check runs at the end of every training epoch, matching the
        epoch-level logging of ``train_loss`` in ``training_step``.
        """
        return [
            EarlyStopping(
                "train_loss",
                patience=self.patience,
                min_delta=self.min_delta,
                check_on_train_epoch_end=True,
            )
        ]

    def training_step(
        self, batch: Tuple[torch.Tensor, ...], _batch_idx: int
    ) -> torch.Tensor:
        """
        Compute the loss for one batch and log it as ``train_loss``.

        With a ranking loss, ``batch`` must be ``(X, y_true, group_ids)``.
        Pointwise losses only use the first two entries, so both
        ``(X, y_true)`` and ``(X, y_true, group_ids)`` are accepted.
        """
        if self.uses_ranking:
            X, y_true, group_ids = batch
            loss = self.loss(self.model(X), y_true, group_ids)
        else:
            # Extra entries such as group ids are irrelevant for pointwise losses.
            X, y_true = batch[0], batch[1]
            loss = self.loss(self.model(X), y_true)

        # Logged per epoch only, which is the value EarlyStopping monitors.
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def predict_step(
        self, batch: Union[torch.Tensor, Tuple[torch.Tensor, ...]], _batch_idx: int
    ) -> torch.Tensor:
        """
        Return the model output for the features of one batch.

        ``batch`` is either the feature tensor itself or a sequence whose first
        element is the feature tensor; any further elements are ignored.
        """
        # A bare tensor must not be indexed: ``batch[0]`` would select only its first sample.
        X = batch if isinstance(batch, torch.Tensor) else batch[0]
        return self.model(X)
