import math

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import LBFGS
from torch.utils.data import DataLoader
from tqdm import tqdm  # type: ignore
from utils import ece, to_device

# Shorthand for the tensor type, used in the annotations of this module.
T = torch.Tensor


# The scaler class is the only public name exported by this module.
__all__ = [
    "TemperatureScaler",
]


class TemperatureScaler(nn.Module):
    """Post-hoc calibration with a single learned temperature.

    Inputs are divided by a positive temperature ``t`` (see :attr:`t`). ``t`` is fitted by
    minimizing the negative log-likelihood on held-out data, via :meth:`tune` (pre-computed
    logits) or :meth:`tune_few_shot` (episodic few-shot loader).
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE: log(e - 1) gives softplus activation of 1, so the scaler starts as the identity.
        # The raw (pre-softplus) value is stored so the effective temperature is always positive.
        self.temperature = nn.Parameter(torch.tensor([np.log(math.e - 1)], requires_grad=True))

    def reset_parameters(self) -> None:
        """Reset the effective temperature to 1 (raw parameter value log(e - 1))."""
        self.temperature.data.fill_(np.log(math.e - 1))

    @property
    def t(self) -> T:
        """Effective temperature: softplus of the raw parameter, clamped to at most 100."""
        return torch.clamp(F.softplus(self.temperature), max=100)

    def tune(self, logits: T, labels: T, *, lr: float = 0.001, max_iter: int = 50) -> None:
        """Tune the temperature on a whole validation set by minimizing NLL.

        This assumes that the logits for the whole dataset are given as inputs. The
        temperature is reset to 1 first and then optimized with L-BFGS. NLL and ECE are
        printed before and after scaling.

        Args:
            logits: scores for the entire validation set, class dimension last. They are
                detached, so only the temperature is optimized.
            labels: class targets, in the format accepted by ``F.nll_loss``.
            lr: L-BFGS learning rate (default is the previously hard-coded value).
            max_iter: maximum L-BFGS iterations for the single ``step`` call
                (default is the previously hard-coded value).
        """
        self.reset_parameters()
        # Only the temperature is being fitted. Detaching prevents gradients from flowing
        # into whatever produced the logits, and avoids "backward through the graph a second
        # time" errors when L-BFGS re-evaluates the closure.
        logits = logits.detach()

        # Calculate NLL and ECE before temperature scaling. The logits that come out of the meta
        # models are expected to already be log-softmax outputs; log_softmax is idempotent on those.
        with torch.no_grad():
            before_nll = F.nll_loss(logits.log_softmax(dim=-1), labels).item()
            before_ece = ece(labels, logits)[0]
        print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_nll, before_ece))

        # Next: optimize the temperature w.r.t. NLL
        optimizer = LBFGS([self.temperature], lr=lr, max_iter=max_iter)

        def closure() -> T:
            # L-BFGS calls the closure repeatedly within one step(). Clear the gradient on every
            # call so successive evaluations do not accumulate into each other.
            optimizer.zero_grad()
            loss = F.nll_loss((logits / self.t).log_softmax(dim=-1), labels)
            loss.backward()
            return loss

        optimizer.step(closure)

        # Calculate NLL and ECE after temperature scaling. This is evaluation only, so no graph
        # is built, and ece() receives a tensor that does not require grad.
        with torch.no_grad():
            scaled_logits = logits / self.t
            after_nll = F.nll_loss(scaled_logits.log_softmax(dim=-1), labels).item()
            after_ece = ece(labels, scaled_logits)[0]
        print('Optimal temperature: %.3f' % self.t.item())
        print('After temperature - NLL: %.3f, ECE: %.3f' % (after_nll, after_ece))

    def forward(self, logits: T) -> T:
        """Return ``logits`` divided by the current effective temperature."""
        return logits / self.t  # type: ignore

    def tune_few_shot(self, model: nn.Module, device: torch.device, loader: DataLoader, n_way: int = 5, k_shot: int = 5, *, max_episodes: int = 500) -> None:
        """Collect few-shot episode logits from ``model`` and tune the temperature on them.

        NOTE: this was originally made for protonet based models.

        Each loader batch yields ``(x_spt, y_spt, x_qry, y_qry)``. These are iterated in
        lockstep along their leading dimension, one episode per element. Each episode is moved
        to ``device`` and scored with
        ``model.get_logits(xs, ys, xq, n_way=n_way, k_shot=k_shot)``.

        The episode budget is only checked after a whole loader batch has been processed, so
        collection may exceed ``max_episodes`` by up to one batch. The model is switched to
        eval mode and left in it.

        Args:
            model: few-shot model exposing ``get_logits``.
            device: device each episode is moved to.
            loader: validation set loader.
            n_way: forwarded to ``model.get_logits``.
            k_shot: forwarded to ``model.get_logits``.
            max_episodes: stop after the first batch that pushes the episode count above this.

        Raises:
            RuntimeError: if the loader yields no episodes.
        """
        # First: collect all the logits and labels for the validation set
        logits_lst: list = []
        labels_lst: list = []
        episodes = 0
        model.eval()
        with torch.no_grad():
            for x_spt, y_spt, x_qry, y_qry in tqdm(loader, ncols=75, leave=False):
                for xs, ys, xq, yq in zip(x_spt, y_spt, x_qry, y_qry):
                    # NOTE(review): an earlier comment said toy datasets embed n_way/k_shot in the
                    # support tensor. Nothing here reads that; the arguments are always used.
                    xs, ys, xq, yq = to_device(xs, ys, xq, yq, device=device)
                    lgts = model.get_logits(xs, ys, xq, n_way=n_way, k_shot=k_shot)  # type: ignore
                    logits_lst.append(lgts)  # type: ignore
                    labels_lst.append(yq)
                    episodes += 1
                if episodes > max_episodes:
                    break

        if not logits_lst:
            # torch.cat([]) would fail with an opaque message; report the actual cause instead.
            raise RuntimeError("tune_few_shot: loader yielded no episodes to calibrate on")

        logits, labels = torch.cat(logits_lst), torch.cat(labels_lst)  # type: ignore
        self.tune(logits, labels)
