from .loss_function import _ECELoss
import torch
import torch.optim as optim
import torch.nn as nn 
    
def get_optimal_parameters_RRA(transformation, calib_loader, device, *,
                               lr=0.1, max_iter=3000, verbose=False):
    """
    Fit the parameters of ``transformation`` on a calibration set by minimising NLL.

    Every batch from ``calib_loader`` is collected, concatenated and moved to
    ``device``. The parameters of ``transformation`` are then optimised with
    L-BFGS (strong-Wolfe line search) so that
    ``CrossEntropyLoss(transformation(logits), labels)`` is minimised
    (e.g. temperature scaling when ``transformation`` divides by a temperature).

    Args:
        transformation (nn.Module): maps logits to adjusted logits. It is moved
            to ``device`` and its parameters are updated in place.
        calib_loader (iterable): yields indexable batches; item 0 is the logits
            tensor and item 1 the matching targets for ``nn.CrossEntropyLoss``.
            Any further items are ignored.
        device: torch device on which the optimisation runs.
        lr (float): L-BFGS learning rate (default 0.1, the previous hard-coded value).
        max_iter (int): maximum L-BFGS iterations (default 3000, as before).
        verbose (bool): if True, print NLL/ECE before and after fitting.

    Returns:
        nn.Module: the same ``transformation`` object, with fitted parameters.

    Raises:
        ValueError: if ``calib_loader`` yields no batches.
    """
    transformation.to(device)
    nll_criterion = nn.CrossEntropyLoss().to(device)

    # Gather the whole calibration set once; the data itself needs no gradients.
    logits_list = []
    labels_list = []
    with torch.no_grad():
        for examples in calib_loader:
            logits_list.append(examples[0])
            labels_list.append(examples[1])
        if not logits_list:
            raise ValueError(
                "calib_loader yielded no batches; cannot fit transformation")
        logits = torch.cat(logits_list).to(device)
        labels = torch.cat(labels_list).to(device)

    ece_criterion = None
    if verbose:
        # The ECE criterion is only needed for the diagnostic report.
        ece_criterion = _ECELoss().to(device)
        with torch.no_grad():
            adjusted = transformation(logits)
            print('Before temperature - OriginalNLL: %.3f, OriginalECE: %.3f   NLL: %.3f, ECE: %.3f' % (
                nll_criterion(logits, labels).item(),
                ece_criterion(logits, labels).item(),
                nll_criterion(adjusted, labels).item(),
                ece_criterion(adjusted, labels).item(),
            ))

    # Optimiser settings follow the original author's note ("the setting of the
    # original paper, test for small t"); exposed as keyword arguments.
    optimizer = optim.LBFGS(transformation.parameters(), lr=lr,
                            max_iter=max_iter, line_search_fn='strong_wolfe')

    def closure():
        # L-BFGS may re-evaluate the objective several times per step, so the
        # gradient is reset and recomputed on every call.
        optimizer.zero_grad()
        loss = nll_criterion(transformation(logits), labels)
        loss.backward()
        return loss

    optimizer.step(closure)

    if verbose:
        with torch.no_grad():
            adjusted = transformation(logits)
            print('After temperature - NLL: %.3f, ECE: %.6f ' % (
                nll_criterion(adjusted, labels).item(),
                ece_criterion(adjusted, labels).item(),
            ))

    return transformation