import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from .predictor import Predictor
from .identity import Identity
from .diff_predictor import diff_Predictor


class ConformalTemperatureScaling(nn.Module):
    """Temperature scaling tuned with a conformal-prediction objective.

    The learnable parameter ``temperature`` holds the *log* of the
    temperature; every use exponentiates it, so the effective temperature
    ``exp(self.temperature)`` is always positive.  Fitting (see ``train``)
    splits a batch of logits into a calibration part and a test part, asks
    the wrapped ``diff_Predictor`` for its threshold ``q_hat`` on the
    calibration part, and minimizes the mean squared gap between that
    threshold and the test part's true-label scores.
    """

    # Upper bound on optimizer steps taken while fitting; overridable per instance.
    max_iter = 10000

    def __init__(self, model, alpha, device=None):
        """Build the scaler.

        Args:
            model: Passed through to ``diff_Predictor``.
            alpha: Passed through to ``diff_Predictor``; presumably the
                conformal miscoverage level -- confirm against that class.
            device: Device holding the log-temperature parameter.  Defaults
                to CUDA when available, otherwise CPU.
        """
        super().__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.alpha = alpha
        # Initial temperature T = 1.5, stored in log space.
        self.temperature = nn.Parameter(torch.log(torch.tensor(1.5, device=device)))
        preprocessor = Identity(temperature=1.0)
        self.predictor = diff_Predictor(
            preprocessor=preprocessor, model=model, alpha=self.alpha
        )
        self.lr = 0.8  # SGD learning rate on the log-temperature
        self.stop = 0.05  # stop once a single step moves log T by less than this

    def train(self, logits=True, labels=None, *, mode=None):
        """Fit the temperature, or set train/eval mode like ``nn.Module.train``.

        This method shadows ``nn.Module.train(mode)``, which PyTorch itself
        calls from ``eval()`` and when a parent module toggles its children.
        To keep those paths working, a call without ``labels`` (or with an
        explicit ``mode=``) is forwarded to ``nn.Module.train``.

        Args:
            logits: Uncalibrated logits to fit on.  When ``labels`` is
                omitted, this is instead the boolean training-mode flag.
            labels: True class indices for ``logits``.
            mode: Explicit training-mode flag for ``train(mode=...)`` calls.

        Returns:
            ``(0.0, 0.0)`` after fitting (kept for caller compatibility), or
            ``self`` when acting as ``nn.Module.train``.

        Raises:
            TypeError: If ``labels`` is omitted and ``logits`` is not a bool.
        """
        if mode is not None:
            return super().train(mode)
        if labels is None:
            if not isinstance(logits, bool):
                raise TypeError("train() requires labels when given logits")
            return super().train(logits)

        self._fit_temperature(logits, labels)
        print(f"Optimal temperature: {torch.exp(self.temperature).item():.3f}")
        return 0.0, 0.0

    def _fit_temperature(self, logits, labels):
        """Run SGD on the log-temperature until one step moves it less than ``self.stop``."""
        optimizer = optim.SGD([self.temperature], lr=self.lr)
        for _ in range(self.max_iter):
            optimizer.zero_grad()
            prev_log_t = self.temperature.item()
            loss = self.criterion(logits / torch.exp(self.temperature), labels)
            loss.backward()
            optimizer.step()
            # Converged once an update barely moves the (log-space) parameter.
            if abs(self.temperature.item() - prev_log_t) < self.stop:
                break

    def forward(self, logits, softmax=True):
        """Divide ``logits`` by the learned temperature.

        Args:
            logits: Logits whose class axis is dim 1 (softmax is taken over dim 1).
            softmax: If truthy, return class probabilities; otherwise return
                the temperature-scaled logits.
        """
        scaled = logits / torch.exp(self.temperature)
        return F.softmax(scaled, dim=1) if softmax else scaled

    def criterion(self, logits, labels, fraction=0.5):
        """Conformal calibration loss.

        The first ``int(fraction * N)`` rows form the calibration split used
        to set the predictor's threshold ``q_hat``; the remaining rows form
        the test split.  The loss is the mean of ``(q_hat - s_i) ** 2``, where
        ``s_i`` is the predictor's score for test row ``i`` at its true label.

        Raises:
            ValueError: If ``fraction`` leaves the calibration or test split empty.
        """
        n = logits.shape[0]
        val_split = int(fraction * n)
        if not 0 < val_split < n:
            raise ValueError(
                f"fraction={fraction} leaves an empty calibration or test split "
                f"for {n} samples"
            )
        cal_logits, test_logits = logits[:val_split], logits[val_split:]
        cal_labels, test_labels = labels[:val_split], labels[val_split:]

        self.predictor.calculate_threshold(cal_logits, cal_labels, random=False)
        tau = self.predictor.q_hat
        test_scores = self.predictor.score_function(test_logits, random=False)
        # Pick each test row's score at its true label.
        true_label_scores = test_scores[range(test_scores.shape[0]), test_labels]
        return torch.mean((tau - true_label_scores) ** 2)
