import torch
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics.classification import BinaryAccuracy

from research.wsl_ece.metric.loss import LossFunction
from research.wsl_ece.metric.pl_module import ClassificationModule


class SupervisedModule(ClassificationModule):
    """
    A PyTorch Lightning module for supervised binary classification.

    This module trains on labeled ``(inputs, labels)`` batches. It tracks
    binary accuracy for the training and validation splits.

    NOTE(review): the original docstring mentioned temperature scaling for
    calibration. Nothing in this class implements it, so it presumably
    lives in ``ClassificationModule``. Confirm before relying on it.
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: LossFunction = LossFunction.CROSS_ENTROPY,
        lr: float = 0.001,
        weight_decay: float = 2.5e-4,
        predict_probability: bool = False,
    ):
        """
        Initialize the supervised module.

        Args:
            model: The neural network model
            loss_fn: Loss function to use for training
            lr: Learning rate
            weight_decay: Weight decay for regularization
            predict_probability: Whether to output probabilities during prediction
        """
        super().__init__(
            model=model,
            loss_fn=loss_fn,
            lr=lr,
            weight_decay=weight_decay,
            predict_probability=predict_probability,
        )
        # Separate metric instances so train/val statistics never mix.
        self.train_accuracy = BinaryAccuracy()
        self.val_accuracy = BinaryAccuracy()

    def forward(self, x):
        """Forward pass through the model."""
        return self._forward_model(x)

    def _shared_step(self, batch):
        """
        Run the model on one labeled batch and compute the loss.

        Args:
            batch: A ``(inputs, labels)`` pair.

        Returns:
            A tuple ``(loss, probabilities, int_labels)``. ``probabilities``
            is ``sigmoid(logits)``. ``int_labels`` holds the labels as
            integers, ready for the accuracy metrics.
        """
        inputs, labels = batch
        # Assumes the model emits one logit per sample, e.g. (B,) or (B, 1).
        # TODO: confirm. Both tensors are flattened to 1-D rather than
        # ``squeeze()``d: ``squeeze()`` would also drop the batch dimension
        # when B == 1, turning the logits into a 0-dim tensor that no longer
        # matches the (1,)-shaped labels.
        labels = labels.to(torch.float32).flatten()
        logits = self._forward_model(inputs).flatten()

        loss = self.loss_fn(logits, labels)
        probabilities = torch.sigmoid(logits)
        return loss, probabilities, labels.int()

    def training_step(self, batch, batch_idx):
        """Training step for supervised learning."""
        loss, probabilities, int_labels = self._shared_step(batch)

        self.train_accuracy(probabilities, int_labels)

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_accuracy", self.train_accuracy, prog_bar=True)
        # Current learning rate of the first param group of the first optimizer.
        self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], prog_bar=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step for supervised learning."""
        loss, probabilities, int_labels = self._shared_step(batch)

        self.val_accuracy(probabilities, int_labels)

        self.log("val_loss", loss, prog_bar=False)
        self.log("val_accuracy", self.val_accuracy, prog_bar=True)

        return loss

    def estimate_steps_per_epoch(self, train_dataloader: DataLoader) -> int:
        """
        Estimate the number of steps per epoch from a standard DataLoader.

        The result is also cached on ``self.estimated_steps_per_epoch``.

        Args:
            train_dataloader: The training dataloader

        Returns:
            The estimated number of steps per epoch, i.e. ``len(train_dataloader)``.
        """
        self.estimated_steps_per_epoch = len(train_dataloader)
        return self.estimated_steps_per_epoch
