from torch import nn

from research.wsl_ece.metric.dataloader import PUDataLoader
from research.wsl_ece.metric.loss import LossFunction
from research.wsl_ece.metric.pl_module import ClassificationModule


class PUModule(ClassificationModule):
    """Classification module for positive-unlabeled (PU) binary learning.

    Training minimises the non-negative PU risk (nnPU, Kiryo et al., 2017);
    validation logs the unbiased PU risk (uPU) and a PU estimate of accuracy.
    Each batch is a mapping with ``"positive"`` and ``"unlabeled"`` entries,
    each an ``(inputs, targets)`` pair whose targets are ignored.
    """

    def __init__(
        self,
        model: nn.Module,
        prior: float,
        loss_fn: LossFunction = LossFunction.CROSS_ENTROPY,
        lr: float = 0.001,
        weight_decay: float = 2.5e-4,
        predict_probability: bool = False,
        balanced_error: bool = False,
        nnpu_beta: float = 0.0,
        nnpu_gamma: float = 1.0,
    ):
        """Create the module.

        Args:
            model: Network handed to the parent ``ClassificationModule``.
            prior: Class prior of the positive class (the ``pi`` weighting the
                PU risk terms). Must lie strictly inside (0, 1).
            loss_fn: Surrogate loss, forwarded to ``ClassificationModule``.
            lr: Learning rate, forwarded to ``ClassificationModule``.
            weight_decay: Weight decay, forwarded to ``ClassificationModule``.
            predict_probability: Forwarded to ``ClassificationModule``.
            balanced_error: If True, reweight the positive and negative terms
                to 0.5 each (balanced risk / balanced accuracy).
            nnpu_beta: nnPU threshold ``beta >= 0``; the non-negativity
                correction fires when the negative risk drops below ``-beta``.
                The default 0 corrects any negative value.
            nnpu_gamma: nnPU step scale ``gamma`` in [0, 1] applied to the
                gradient-ascent step on the negative risk. The default 1 is
                an unscaled step.

        Raises:
            ValueError: If ``prior``, ``nnpu_beta`` or ``nnpu_gamma`` is out of
                range.
        """
        # prior and 1 - prior are used as divisors when balanced_error is set,
        # and the risk estimators are meaningless outside (0, 1).
        if not 0 < prior < 1:
            raise ValueError(f"prior must be in the open interval (0, 1), got {prior}")
        if nnpu_beta < 0:
            raise ValueError(f"nnpu_beta must be non-negative, got {nnpu_beta}")
        if not 0 <= nnpu_gamma <= 1:
            raise ValueError(f"nnpu_gamma must be in [0, 1], got {nnpu_gamma}")
        super().__init__(
            model=model,
            loss_fn=loss_fn,
            lr=lr,
            weight_decay=weight_decay,
            predict_probability=predict_probability,
        )
        self.prior = prior
        self.balanced_error = balanced_error
        self.nnpu_beta = nnpu_beta
        self.nnpu_gamma = nnpu_gamma

    def _pu_terms(self, loss_fn, positive_outputs, unlabeled_outputs, positive_label, negative_label):
        """Return the prior-weighted ``(positive_term, negative_term)`` PU pair.

        ``positive_term = prior * L(P, positive_label)`` and
        ``negative_term = L(U, negative_label) - prior * L(P, negative_label)``,
        where the second expression estimates the negative-class term from
        unlabeled data by subtracting the positives' contribution. With
        ``balanced_error`` the terms are rescaled to weight 0.5 each.
        """
        positive_term = self.prior * loss_fn(positive_outputs, positive_label)
        negative_term = loss_fn(unlabeled_outputs, negative_label) - self.prior * loss_fn(
            positive_outputs, negative_label
        )
        if self.balanced_error:
            positive_term = positive_term * (0.5 / self.prior)
            negative_term = negative_term * (0.5 / (1 - self.prior))
        return positive_term, negative_term

    def nnpu_risk(self, positive_outputs, unlabeled_outputs):
        """Calculate nnPU risk.

        See Kiryo et al. (2017) for details. When the estimated negative risk
        is below ``-nnpu_beta``, the returned tensor has the value
        ``positive_risk - nnpu_beta`` but the gradient of
        ``-nnpu_gamma * negative_risk``. The optimiser therefore performs
        gradient ascent on the negative risk instead of driving it further
        below zero. Otherwise the plain uPU risk is returned.
        """
        positive_risk, negative_risk = self._pu_terms(self.loss_fn, positive_outputs, unlabeled_outputs, 1, 0)
        if negative_risk < -self.nnpu_beta:
            # Straight-through trick: value comes from the detached part,
            # gradient only from the surrogate.
            surrogate = -self.nnpu_gamma * negative_risk
            return surrogate + (positive_risk - self.nnpu_beta - surrogate).detach()
        return positive_risk + negative_risk

    def upu_risk(self, positive_outputs, unlabeled_outputs):
        """Return the unbiased PU (uPU) risk; it may be negative."""
        positive_risk, negative_risk = self._pu_terms(self.loss_fn, positive_outputs, unlabeled_outputs, 1, 0)
        return positive_risk + negative_risk

    def pu_accuracy(self, positive_outputs, unlabeled_outputs):
        """Return a PU estimate of (balanced, if enabled) accuracy.

        Reuses the uPU estimator with the zero-one loss and swapped labels.
        Assuming ``ZERO_ONE(outputs, y)`` is the fraction of outputs not
        classified as ``y``, scoring against the opposite label counts correct
        predictions, giving ``prior * TPR + (1 - prior) * TNR``.
        """
        positive_accuracy, negative_accuracy = self._pu_terms(
            LossFunction.ZERO_ONE, positive_outputs, unlabeled_outputs, 0, 1
        )
        return positive_accuracy + negative_accuracy

    def _forward_scores(self, data):
        """Run the model and drop trailing size-1 output dimensions.

        Unlike a bare ``.squeeze()``, this never removes the batch dimension,
        so a final partial batch of size 1 still yields a 1-D tensor.
        """
        outputs = self._forward_model(data)
        while outputs.dim() > 1 and outputs.size(-1) == 1:
            outputs = outputs.squeeze(-1)
        return outputs

    def _forward_pu_batch(self, batch):
        """Return model outputs for the positive and unlabeled halves of ``batch``."""
        pos_data, _ = batch["positive"]
        unlabeled_data, _ = batch["unlabeled"]
        return self._forward_scores(pos_data), self._forward_scores(unlabeled_data)

    def training_step(self, batch, batch_idx):
        """Compute and log the nnPU training loss for one PU batch."""
        pos_outputs, unlabeled_outputs = self._forward_pu_batch(batch)

        # No temperature scaling during training
        loss = self.nnpu_risk(pos_outputs, unlabeled_outputs)
        self.log("train_loss", loss, prog_bar=True)
        self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the uPU risk and the PU accuracy estimate for one PU batch."""
        pos_outputs, unlabeled_outputs = self._forward_pu_batch(batch)

        self.log("val_loss", self.upu_risk(pos_outputs, unlabeled_outputs), prog_bar=False)
        self.log("val_accuracy", self.pu_accuracy(pos_outputs, unlabeled_outputs), prog_bar=True)

    def estimate_steps_per_epoch(self, train_dataloader: PUDataLoader) -> int:
        """Estimate the number of steps per epoch from a PUDataLoader.

        The estimate is the length of the longer of the ``"positive"`` and
        ``"unlabeled"`` loaders. It is also stored on
        ``self.estimated_steps_per_epoch``.

        Args:
            train_dataloader (PUDataLoader): The train dataloader.
        Returns:
            int: The estimated number of steps per epoch.
        """
        num_batches = max(len(train_dataloader["positive"]), len(train_dataloader["unlabeled"]))
        self.estimated_steps_per_epoch = num_batches
        return self.estimated_steps_per_epoch
