
import torch
from utils.measures import auroc
from copy import deepcopy
from utils.load_dataset import load_dataset

# Module-level record of the hyperparameters currently in use.
# 'TEMP' and 'NOISE_MAG' are read by ODIN() and written by
# set_ODIN_hyperparams_for_indist(), which also writes 'AUROC'.
# Every field starts at -1 (presumably meaning "not computed yet").
best_res = dict.fromkeys(('AUROC', 'AUPR', 'FPR', 'TEMP', 'NOISE_MAG'), -1)

# Tuned results that have already been computed, keyed by a string derived
# from the network name, so the grid search is not repeated.
res_cache = dict()

def pred_for_temp_eps(dataset_in, 
                      dataset_out,
                      temp,
                      noise_mag,
                      net,
                      device):
    """Score the in-distribution and OoD test sets with ODIN-style preprocessing.

    Every test batch is perturbed by ``noise_mag`` against the sign of the
    gradient of the temperature-scaled cross-entropy loss, where the loss
    targets are the network's own predicted labels.  The perturbed batch is
    then scored as::

        max(softmax(logits / temp)) + max(cosine(feature, rows of net.linear.weight))

    Args:
        dataset_in: In-distribution dataset wrapper.  Must expose
            ``test_loader`` (iterable of ``(data, labels)`` batches with a
            ``.dataset`` supporting ``len``) and ``std``, the per-channel
            standard deviation used to normalise the images.
        dataset_out: OoD dataset wrapper exposing ``test_loader``.
        temp: Softmax temperature the logits are divided by.
        noise_mag: Magnitude of the input perturbation.
        net: Network callable as ``net(x) -> logits`` and
            ``net(x, latent=True) -> (logits, features)``, with a ``linear``
            layer whose ``weight`` has one row per class.
        device: Device the batches and score buffers are moved to.

    Returns:
        ``(labels, pred)`` as NumPy arrays of length
        ``len(in test set) + len(out test set)``.  ``labels`` is 0 for
        in-distribution samples and 1 for OoD samples; ``pred`` is the
        negated score (presumably so that larger values indicate OoD).
    """
    dataset_out_len = len(dataset_out.test_loader.dataset)
    dataset_in_len = len(dataset_in.test_loader.dataset)
    criterion = torch.nn.CrossEntropyLoss()

    # Unit-norm rows of the classifier weight, used for the cosine term.
    fc_w = net.linear.weight.detach().clone().to(device)
    fc_w_norm = torch.nn.functional.normalize(fc_w, dim=1)

    # Per-channel std of the in-distribution normalisation.  The batches are
    # assumed to be laid out as (N, C, ...) with C == len(dataset_in.std).
    channel_std = torch.as_tensor(dataset_in.std, dtype=torch.float32, device=device)

    pred = torch.zeros(dataset_in_len + dataset_out_len, device=device)
    y = torch.zeros_like(pred)
    index = 0
    loaders = (dataset_in.test_loader, dataset_out.test_loader)

    for dataset_index, loader in enumerate(loaders):
        for data, _ in loader:
            data = data.to(device)
            data.requires_grad_(True)
            batch_size = data.shape[0]

            out = net(data)

            # The loss is taken against the network's own prediction.
            _, pred_labels = torch.max(out, dim=1)
            loss = criterion(out / temp, pred_labels)

            # Differentiate w.r.t. the input only, so that no gradients are
            # accumulated on the network parameters across batches.
            (input_grad,) = torch.autograd.grad(loss, data)

            # Sign of the gradient, encoded as {-1, +1}.
            gradient = (torch.ge(input_grad, 0).float() - 0.5) * 2

            # Bring the perturbation into the normalised image space.  The
            # std is broadcast over the channel axis of EVERY sample in the
            # batch, not just the first one.
            broadcast_shape = (1, -1) + (1,) * (gradient.dim() - 2)
            gradient = gradient / channel_std.view(broadcast_shape)

            # Adding small perturbations to images
            temp_inputs = torch.add(data.detach(), gradient, alpha=-noise_mag)
            with torch.no_grad():
                outputs, fet = net(temp_inputs, latent=True)

                norm_fet = torch.nn.functional.normalize(fet, dim=1)
                cos_theta = torch.mm(norm_fet, fc_w_norm.T).max(dim=1)[0]

                smax = torch.nn.functional.softmax(outputs / temp, dim=1).max(dim=1)[0]
                pred[index: index + batch_size] = smax + cos_theta

            # Note dataset_index = 0 for In-Dist
            # and dataset_index = 1 for OoD
            y[index: index + batch_size] = dataset_index
            index += batch_size

    labels = y.cpu().numpy()
    pred = -pred.cpu().numpy()
    return labels, pred

def set_ODIN_hyperparams_for_indist(dataset_in, net, net_name, device):
    """Tune temperature and noise magnitude for ``net`` and store them in ``best_res``.

    Runs a grid search over temperature and noise magnitude, using iSUN as
    the OoD dataset, and keeps the pair with the highest AUROC.  The result
    is written to the module-level ``best_res`` (read later by ``ODIN``) and
    cached in ``res_cache`` so the search is skipped on later calls that map
    to the same cache key.

    Args:
        dataset_in: In-distribution dataset wrapper.  Its loading settings
            (batch sizes, split, augmentation, mean/std, ...) are reused to
            load iSUN.
        net: Network to tune the hyperparameters for.
        net_name: Network identifier.  It is split on "_" and the first and
            third fields form the cache key, so it must contain at least
            three "_"-separated fields (an ``IndexError`` is raised otherwise).
        device: Device used for loading the data and running the network.

    Returns:
        None.  The outcome is stored in the module-level ``best_res``.
    """
    global best_res
    global res_cache

    sp = net_name.split("_")
    key = sp[0] + sp[2]

    # If we have already calculated the params for this net, use the cached
    # params.  A copy is handed out so the cache entry can never be altered
    # through ``best_res``.
    if key in res_cache:
        best_res = deepcopy(res_cache[key])
        return

    # https://arxiv.org/pdf/1706.02690.pdf reports that they use iSUN for hyperparameter tuning
    # and set temperature to 1000
    nmag_list = [0, 0.0005, 0.001, 0.0014, 0.002, 0.0024, 0.005, 0.01, 0.05, 0.1, 0.2]
    temp_list = [1, 10, 100, 1000, 10000]

    ood_dataset = load_dataset(dataset="iSUN",
                               train_batch_size=dataset_in.train_batch_size,
                               test_batch_size=dataset_in.test_batch_size,
                               val_split=dataset_in.val_split,
                               augment=dataset_in.augment,
                               padding_crop=dataset_in.padding_crop,
                               shuffle=dataset_in.shuffle,
                               random_seed=dataset_in.random_seed,
                               device=device,
                               mean=dataset_in.mean,
                               std=dataset_in.std)

    # Start every search from a clean slate.  Comparing against the global
    # ``best_res`` would let the result of a previously tuned network leak
    # into (and possibly win) the search for this one.
    search_best = {'AUROC': -1,
                   'AUPR': -1,
                   'FPR': -1,
                   'TEMP': -1,
                   'NOISE_MAG': -1}

    for noise_mag in nmag_list:
        for temp in temp_list:
            print("Temp {}, Noise {}".format(temp, noise_mag))
            labels, pred = pred_for_temp_eps(dataset_in, 
                                             ood_dataset,
                                             temp,
                                             noise_mag,
                                             net,
                                             device)

            auroc_val = auroc(pred, labels)

            if auroc_val > search_best['AUROC']:
                search_best['AUROC'] = auroc_val
                search_best['TEMP'] = temp
                search_best['NOISE_MAG'] = noise_mag

    best_res = search_best

    print("Dataset: {}".format(dataset_in.name))
    print(best_res)
    res_cache[key] = deepcopy(best_res)

#  ODIN
def ODIN(dataset_in, dataset_out, net, device):
    """Score ``dataset_in`` vs. ``dataset_out`` with the currently selected hyperparameters.

    Reads the temperature and the noise magnitude from the module-level
    ``best_res`` and forwards everything to ``pred_for_temp_eps``.

    Returns:
        The ``(labels, pred)`` pair produced by ``pred_for_temp_eps``.
    """
    temperature = best_res['TEMP']
    noise_magnitude = best_res['NOISE_MAG']

    truth, scores = pred_for_temp_eps(
        dataset_in, dataset_out, temperature, noise_magnitude, net, device
    )
    return truth, scores
