from abc import abstractmethod
from typing import TypedDict

import numpy as np
import torch
from lightning.pytorch import LightningModule
from torch import nn
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
from transformers.optimization import get_cosine_schedule_with_warmup

from research.wsl_ece.metric.loss import LossFunction


class NumpyPredictStepOutput(TypedDict):
    """
    Dictionary layout of prediction results after conversion to NumPy.

    Both keys are required; this is the shape returned by ``accumulate_predictions``.
    """

    # Model outputs gathered over all prediction batches.
    predictions: np.ndarray
    # Labels gathered over the same batches, in the same order.
    labels: np.ndarray


def accumulate_predictions(predict_output) -> NumpyPredictStepOutput:
    """
    Merge per-batch prediction dictionaries into a single pair of NumPy arrays.

    Args:
        predict_output: Iterable of dictionaries, each holding tensors under the keys
            ``"predictions"`` and ``"labels"`` (the format returned by a ``predict_step``).
            It is consumed in a single pass, so a generator is acceptable.

    Returns:
        A dictionary with the concatenated ``"predictions"`` and ``"labels"`` arrays.

    Raises:
        ValueError: If ``predict_output`` is empty (``np.concatenate`` needs at least one array).
    """

    def _to_numpy(tensor) -> np.ndarray:
        # A batch holding a single sample can arrive as a 0-d tensor (e.g. after ``.squeeze()``);
        # ``np.concatenate`` rejects 0-d arrays, so promote every batch to at least one dimension.
        return np.atleast_1d(tensor.detach().cpu().numpy())

    predictions = []
    labels = []
    for output in predict_output:
        predictions.append(_to_numpy(output["predictions"]))
        labels.append(_to_numpy(output["labels"]))
    return {"predictions": np.concatenate(predictions), "labels": np.concatenate(labels)}


class ClassificationModule(LightningModule):
    """
    Base PyTorch Lightning module for binary classification.

    Wraps an arbitrary ``nn.Module`` and provides the test step, the predict step and the
    AdamW + cosine-with-warmup optimizer configuration. No ``training_step`` is defined here,
    and ``estimate_steps_per_epoch`` must be supplied by subclasses.

    NOTE(review): the previous docstring mentioned temperature scaling, but nothing in this
    class implements it -- presumably a subclass does; confirm.
    """

    # Number of optimizer steps in one epoch. Must be assigned (by a subclass or by the caller)
    # before ``configure_optimizers`` runs, because the LR schedule length is derived from it.
    estimated_steps_per_epoch: int | None = None

    def __init__(
        self,
        model: nn.Module,
        loss_fn: LossFunction = LossFunction.CROSS_ENTROPY,
        lr: float = 0.001,
        weight_decay: float = 2.5e-4,
        predict_probability: bool = False,
    ):
        """
        Args:
            model: Network whose forward pass produces one output value per sample.
            loss_fn: Callable invoked as ``loss_fn(outputs, float_labels)`` in ``test_step``.
            lr: Learning rate passed to AdamW.
            weight_decay: Weight decay passed to AdamW.
            predict_probability: If True, ``predict_step`` applies a sigmoid to the model outputs.
        """
        super().__init__()
        # ``__post_init__`` is a dataclass hook and is never invoked for this class, so the
        # hyperparameters have to be recorded here. ``model`` is excluded: it is a module whose
        # weights are already part of the checkpoint.
        self.save_hyperparameters(ignore=["model"])

        self.model = model
        self.loss_fn = loss_fn
        self.lr = lr
        self.weight_decay = weight_decay
        self.predict_probability = predict_probability

        # Metrics
        self.test_accuracy = BinaryAccuracy()
        self.test_f1_score = BinaryF1Score()

    def __post_init__(self):
        """
        Retained for backward compatibility with code that calls it explicitly.

        Nothing in this class calls it; hyperparameters are saved in ``__init__``.
        """
        self.save_hyperparameters(ignore=["model"])

    @staticmethod
    def _squeeze_keep_batch(outputs: torch.Tensor) -> torch.Tensor:
        """
        Drop size-1 dimensions from the model output without losing the batch dimension.

        A bare ``squeeze()`` turns the output for a single-sample batch into a 0-d tensor,
        which no longer matches ``labels`` of shape ``(1,)``. In that case the batch axis is
        restored; every other shape is treated exactly like ``squeeze()``.
        """
        squeezed = outputs.squeeze()
        if squeezed.ndim == 0:
            squeezed = squeezed.unsqueeze(0)
        return squeezed

    def _forward_model(self, data):
        """
        Forward pass through the model, handling both tensor and dictionary inputs.

        Args:
            data: Either a tensor (for MNIST/CIFAR) or dict with keys like input_ids, attention_mask, etc. (for DDI)

        Returns:
            Model output tensor
        """
        if isinstance(data, dict):
            # Dictionary input (DDI dataset) - pass as keyword arguments
            return self.model(**data)
        else:
            # Tensor input (MNIST/CIFAR) - cast to float32 and pass directly
            data = data.to(torch.float32)
            return self.model(data)

    def test_step(self, batch, batch_idx):
        """
        Evaluate one test batch and log loss, accuracy and F1 score.

        Args:
            batch: Pair ``(inputs, labels)``.
            batch_idx: Index of the batch (unused).
        """
        inputs, labels = batch
        outputs = self._squeeze_keep_batch(self._forward_model(inputs))

        self.log("test_loss", self.loss_fn(outputs, labels.to(torch.float32)), prog_bar=False)

        # Accumulate state in the metric objects and log the objects themselves, so that the
        # reported value is computed over the whole test set. Logging per-batch values would
        # average batch-level F1 scores, which is not the F1 score of the full dataset.
        # NOTE(review): the metrics receive the raw model outputs; torchmetrics decides per call
        # whether they are logits or probabilities based on their value range -- confirm that
        # this matches what the model emits.
        self.test_accuracy.update(outputs, labels)
        self.test_f1_score.update(outputs, labels)
        self.log("test_accuracy", self.test_accuracy, prog_bar=False)
        self.log("test_f1_score", self.test_f1_score, prog_bar=False)

    def predict_step(self, batch, batch_idx):
        """
        Run the model on one batch and return its outputs together with the labels.

        Args:
            batch: Pair ``(inputs, labels)``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with ``"predictions"`` (model outputs, passed through a sigmoid when
            ``predict_probability`` is set) and ``"labels"`` (returned unchanged).
        """
        inputs, labels = batch
        outputs = self._squeeze_keep_batch(self._forward_model(inputs))
        if self.predict_probability:
            # Scale the outputs to [0, 1]
            outputs = torch.sigmoid(outputs)
        return {"predictions": outputs, "labels": labels}

    def configure_optimizers(self):  # type: ignore
        """
        Build the AdamW optimizer and a per-step cosine schedule with warmup.

        The schedule spans ``estimated_steps_per_epoch * trainer.max_epochs`` steps, of which
        the first tenth is warmup.

        Raises:
            AssertionError: If ``estimated_steps_per_epoch`` has not been set.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        assert self.estimated_steps_per_epoch is not None, (
            "estimated_steps_per_epoch is None. Did you call estimate_steps_per_epoch_from_combined_loader?"
        )
        num_training_steps = self.estimated_steps_per_epoch * self.trainer.max_epochs  # type: ignore
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=num_training_steps // 10, num_training_steps=num_training_steps
        )
        # "interval": "step" makes Lightning advance the scheduler after every optimizer step
        # rather than once per epoch, which is what the step-based schedule length assumes.
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}

    # NOTE(review): ``@abstractmethod`` only prevents instantiation under an ``ABCMeta``
    # metaclass, which ``LightningModule`` does not appear to use -- the ``NotImplementedError``
    # below is presumably the effective guard; confirm.
    @abstractmethod
    def estimate_steps_per_epoch(self, train_dataloader) -> int:
        """
        Estimate the number of steps per epoch from a standard DataLoader.

        Args:
            train_dataloader: The training dataloader

        Returns:
            The estimated number of steps per epoch

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("This method should be implemented in subclasses.")
