from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import torch
from torch import Tensor, nn, optim
from tqdm import tqdm

from puupl.lib.losses import ECELoss


class PostProcessor(ABC):
    """Common interface for transformations applied to model output logits.

    Subclasses learn whatever parameters they need in :meth:`fit` and apply
    them to new logits in :meth:`scale`.
    """

    @abstractmethod
    def fit(self, logits: Tensor, pu_labels: Tensor,
            true_labels: Tensor) -> "PostProcessor":
        """Estimate the post-processor's parameters.

        Implementations may use any combination of the prediction ``logits``,
        the PU labels ``pu_labels`` (1 = positive, 0 = unlabeled) and the
        ground-truth ``true_labels``. Should return the fitted
        post-processor (the concrete classes in this module return ``self``).
        """

    @abstractmethod
    def scale(self, logits: Tensor) -> Tensor:
        """Return the transformed version of ``logits``."""


class PusbScaler(PostProcessor):
    """
    The PUSB method [1] requires a custom threshold to separate
    positives and negatives. Since we use 0.5 in the rest of the code,
    here we compute their threshold and scale the predicted logits so
    that, after scaling, negative values are below the threshold and
    positive values are above it.

    The threshold ``theta`` is chosen so that a fraction ``prior`` of the
    fitting logits lie above it; ``scale`` then subtracts ``theta``.

    [1] https://openreview.net/pdf?id=rJzLciCqKm
    """

    def __init__(self, prior: float):
        # Fraction of samples expected to be positive.
        self.prior = prior
        # Decision threshold on the raw logits; set by fit().
        self.theta: Optional[float] = None

    def scale(self, logits: Tensor) -> Tensor:
        """Shift ``logits`` by ``-theta``.

        Raises:
            RuntimeError: if :meth:`fit` has not been called yet.
        """
        if self.theta is None:
            raise RuntimeError('fit must be called before scale')
        return logits - self.theta

    def fit(self, logits: Tensor, pu_labels: Tensor, true_labels: Tensor) -> "PusbScaler":
        """
        Choose ``theta`` so that ``int(N * prior)`` of the N logits are above it.

        Only ``logits`` and ``prior`` are used; ``pu_labels`` and
        ``true_labels`` are ignored.

        Raises:
            ValueError: if fewer than two logits are given (no pair of
                neighbouring scores to place the threshold between).
        """
        # Flatten so (N,) and (N, 1) logits behave the same; detach because
        # the logits may still be attached to an autograd graph.
        scores = logits.detach().reshape(-1).cpu()
        num = scores.numel()
        if num < 2:
            raise ValueError('PusbScaler.fit needs at least two logits')
        s = torch.sort(scores, descending=True).values
        # The n largest scores are the predicted positives, so the threshold
        # lies between the n-th and (n+1)-th largest, i.e. s[n-1] and s[n].
        # n is clamped to [1, num-1] so both neighbours exist even when
        # prior * num rounds to 0 or reaches num.
        n = min(max(int(num * self.prior), 1), num - 1)
        self.theta = ((s[n - 1] + s[n]) / 2).item()
        tqdm.write(f'nnPUSB threshold is {self.theta}')
        return self


class TemperatureScaler(PostProcessor):
    """
    The temperature scaling module.

    Divides logits by a single learnable temperature (initialised to 1.5),
    which can be fitted from predicted logits on a validation data set.

    Adapted from
    https://github.com/gpleiss/temperature_scaling/blob/master/temperature_scaling.py
    """
    def __init__(self) -> None:
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def scale(self, logits: Tensor) -> Tensor:
        """Return ``logits`` divided by the current temperature."""
        return logits / self.temperature.to(logits.device)

    def fit(self, logits: Tensor, pu_labels: Tensor, true_labels: Tensor
            ) -> 'TemperatureScaler':
        """
        Tune the temperature of the model (using the validation set).

        The temperature is optimised with LBFGS to minimise the ECE of the
        scaled logits against ``true_labels``; ``pu_labels`` is ignored.
        NLL and ECE are reported before and after fitting.
        """
        # Detach: the temperature is the only thing being optimised, and a
        # graph-attached input would otherwise receive gradients and fail on
        # the second backward through its (already freed) graph.
        logits = logits.detach().cpu()
        labels = true_labels.cpu()

        nll_criterion = nn.BCEWithLogitsLoss()
        ece_criterion = ECELoss()

        # Calculate NLL and ECE before temperature scaling
        with torch.no_grad():
            before_temperature_nll = nll_criterion(logits, labels).item()
            before_temperature_ece = ece_criterion(logits, labels).item()
        tqdm.write('Before temperature - NLL: %.3f, ECE: %.3f' % (
            before_temperature_nll, before_temperature_ece))

        optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=50)

        def evaluate() -> Tensor:
            # LBFGS calls this closure repeatedly within one step; gradients
            # must be cleared each time or they accumulate across calls.
            optimizer.zero_grad()
            # Optimise w.r.t. ECE (use nll_criterion here to optimise NLL).
            loss = ece_criterion(self.scale(logits), labels)
            loss.backward()
            return loss
        optimizer.step(evaluate)  # type:ignore[arg-type]

        # Calculate NLL and ECE after temperature scaling
        with torch.no_grad():
            after_temperature_nll = nll_criterion(self.scale(logits), labels).item()
            after_temperature_ece = ece_criterion(self.scale(logits), labels).item()
        tqdm.write('Optimal temperature: %.3f' % self.temperature.item())
        tqdm.write('After temperature - NLL: %.3f, ECE: %.3f' % (
                   after_temperature_nll, after_temperature_ece))
        return self


# Registry of available post-processors. get_postprocessors lower-cases each
# config key and looks it up here to find the class to instantiate.
postprocessors_dict = dict(
    pusb_scaler=PusbScaler,
    temperature_scaler=TemperatureScaler,
)


def get_postprocessors(config: Optional[Dict[str, Any]],
                       *, registry: Optional[Dict[str, Any]] = None,
                       ) -> List["PostProcessor"]:
    """
    Instantiate the post-processors described by ``config``.

    Args:
        config: maps a post-processor name to the keyword arguments for its
            constructor. Names are lower-cased before lookup; a falsy value
            (e.g. ``None``) means "no arguments". A ``None`` config yields
            an empty list.
        registry: mapping from lower-case name to class; defaults to
            ``postprocessors_dict``.

    Returns:
        The instantiated post-processors, in the iteration order of ``config``.

    Raises:
        KeyError: if a name in ``config`` is not present in the registry.
    """
    if config is None:
        return []
    if registry is None:
        registry = postprocessors_dict

    processors = []
    for name, kwargs in config.items():
        try:
            cls = registry[name.lower()]
        except KeyError:
            # Keep KeyError for callers, but say what went wrong.
            raise KeyError(
                f'unknown post-processor {name!r}; available: {sorted(registry)}'
            ) from None
        processors.append(cls(**(kwargs or {})))
    return processors
