"""
Generate out-of-distribution (OOD) detection results on the five standard OOD datasets:
LSUN(C), LSUN(R), iSUN, ImageNet(C) and ImageNet(R), where (C) = crop and (R) = resize.
"""
import numpy as np
import os
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split

from utils_train import get_base_datasets, get_pretrained_model,ImageFolderOOD
import config

from ood_metrics import calc_metrics
from sklearn.metrics import roc_curve

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from scipy.special import softmax

from sklearn.kernel_approximation import Nystroem
from sklearn.linear_model import SGDOneClassSVM

import warnings
# Intended to hide SourceChangeWarning. Note that this call actually silences
# *every* UserWarning, through the global warnings filter, for the whole process.
# NOTE(review): SourceChangeWarning is only covered here if it subclasses
# UserWarning -- confirm against the installed torch version.
warnings.filterwarnings("ignore", category=UserWarning)

# Quantile-trained checkpoints keyed by "<architecture>_<dataset>".
# Every file lives in config.CHECKPOINT_DIR and is named "quantile_<key>.ckpt".
checkpoints_pretrained = {
    key: config.CHECKPOINT_DIR + f"quantile_{key}.ckpt"
    for key in (
        "resnet34_cifar10",
        "resnet34_cifar100",
        "resnet34_svhn",
        "densenet_cifar10",
        "densenet_svhn",
        "densenet_cifar100",
    )
}

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
torch.set_float32_matmul_precision('high')

# Device used for every forward pass in this script. Defaults to the second
# GPU ("cuda:1"); set the OOD_DEVICE environment variable (e.g. "cuda:0" or
# "cpu") to run on another device without editing the file.
device = torch.device(os.environ.get("OOD_DEVICE", "cuda:1"))

def establish_pretrained_model_quantile(name_base_model: str, remove_last_layer: bool = True):
    """Build the quantile-conditioned backbone, load its weights and freeze it.

    The architecture comes from ``get_pretrained_model``. Its ``conv1`` is
    replaced by a bias-free convolution with one extra input channel, which is
    where the quantile level is fed in. Weights are then restored from
    ``checkpoints_pretrained[name_base_model]``.

    Args:
        name_base_model: Key into ``checkpoints_pretrained`` (e.g.
            ``"resnet34_svhn"``), also passed to ``get_pretrained_model``.
        remove_last_layer: If True, replace the classifier head (``linear``
            or ``fc``) with ``nn.Identity`` so the model returns features.

    Returns:
        The model in eval mode with ``requires_grad`` disabled on every
        parameter.

    Raises:
        KeyError: ``name_base_model`` has no entry in ``checkpoints_pretrained``.
        AttributeError: ``remove_last_layer`` is True but the model has
            neither a ``linear`` nor an ``fc`` submodule.
    """
    model, _num_classes, _size_dataset = get_pretrained_model(name_base_model)

    # Widen the stem by one input channel for the quantile-level plane.
    old_conv = model.conv1
    model.conv1 = nn.Conv2d(
        in_channels=old_conv.in_channels + 1,
        out_channels=old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=False,
    )

    # map_location="cpu" lets a checkpoint saved from any GPU be loaded on a
    # machine without that GPU; load_state_dict copies the tensors into the
    # model's existing parameters anyway.
    checkpoint = torch.load(checkpoints_pretrained[name_base_model], map_location="cpu")
    state_dict = checkpoint["state_dict"]
    # 'quantiles_list' is not a parameter of the bare backbone and would make
    # the strict load below fail; tolerate checkpoints that lack it.
    state_dict.pop("quantiles_list", None)
    # Checkpoint keys carry a 'backbone.' prefix that the bare model lacks.
    backbone_state = {name.replace("backbone.", ""): param for name, param in state_dict.items()}
    model.load_state_dict(backbone_state, strict=True)

    if remove_last_layer:
        # Check which head exists explicitly: assigning an nn.Module attribute
        # never raises (it registers a new submodule), so try/except cannot
        # detect a missing `linear` head.
        if hasattr(model, "linear"):
            model.linear = nn.Identity()
        elif hasattr(model, "fc"):
            model.fc = nn.Identity()
        else:
            raise AttributeError(
                f"Cannot remove the last layer of {name_base_model!r}: "
                "model has neither a 'linear' nor an 'fc' submodule."
            )

    model.eval()

    # Make all the parameters in the backbone non-trainable
    for param in model.parameters():
        param.requires_grad_(False)

    return model

def establish_pretrained_model_baseline(name_base_model: str, remove_last_layer: bool = True):
    """Load the standard pretrained model ``name_base_model`` and freeze it.

    Args:
        name_base_model: Name passed to ``get_pretrained_model``
            (e.g. ``"resnet34_svhn"``).
        remove_last_layer: If True, replace the classifier head (``linear``
            or ``fc``) with ``nn.Identity`` so the model returns features.

    Returns:
        The model in eval mode with ``requires_grad`` disabled on every
        parameter.

    Raises:
        AttributeError: ``remove_last_layer`` is True but the model has
            neither a ``linear`` nor an ``fc`` submodule.
    """
    model, _num_classes, _size_dataset = get_pretrained_model(name_base_model)

    if remove_last_layer:
        # Check which head exists explicitly: assigning an nn.Module attribute
        # never raises (it registers a new submodule), so try/except cannot
        # detect a missing `linear` head.
        if hasattr(model, "linear"):
            model.linear = nn.Identity()
        elif hasattr(model, "fc"):
            model.fc = nn.Identity()
        else:
            raise AttributeError(
                f"Cannot remove the last layer of {name_base_model!r}: "
                "model has neither a 'linear' nor an 'fc' submodule."
            )

    model.eval()

    # Make all the parameters in the backbone non-trainable
    for param in model.parameters():
        param.requires_grad_(False)

    return model


def get_quantile_representations(model, dataset, batch_size=512, num_workers=5, num_quantiles=100):
    """Evaluate ``model`` on every sample of ``dataset`` at a grid of quantile levels.

    Each input batch gets one extra channel appended, filled with the current
    quantile level tau. The model is run once per tau in
    ``torch.linspace(0, 1, num_quantiles + 2)[1:-1]``, i.e. ``num_quantiles``
    evenly spaced levels strictly between 0 and 1.

    - assumed that model outputs logits

    Args:
        model: Network accepting the image channels plus one quantile channel;
            it is moved to the module-level ``device``.
        dataset: Map-style dataset yielding ``(x, label)`` pairs. Batches are
            indexed as 4-D tensors (batch, channel, H, W).
        batch_size: DataLoader batch size.
        num_workers: Number of DataLoader worker processes.
        num_quantiles: Number of quantile levels evaluated per sample
            (default 100).

    Returns:
        ``(features, labels)`` as NumPy arrays in dataset order (no
        shuffling). ``features`` stacks the per-tau outputs along axis 1, so
        its shape is ``(len(dataset), num_quantiles, *model_output.shape[1:])``.
    """
    model = model.to(device)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=True)
    # The quantile grid is identical for every batch, so build it once.
    taus = torch.linspace(0, 1, num_quantiles + 2)[1:-1]
    features = []
    labels = []
    with torch.no_grad():
        for x, label in tqdm(loader, total=len(loader)):
            x = x.to(device)
            # Extra input plane for the quantile level, allocated directly on
            # the target device; it is overwritten in place for each tau.
            quant_channel = torch.ones((x.shape[0], 1, x.shape[2], x.shape[3]),
                                       dtype=torch.float32, device=device)
            x = torch.cat([x, quant_channel], dim=1)
            per_tau = []
            for tau in taus:
                x[:, -1, :, :] = tau
                per_tau.append(model(x))
            features.append(torch.stack(per_tau, dim=1).cpu().numpy())
            labels.append(label.numpy())
    return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)

def get_baseline_representation(model, dataset, batch_size=512, num_workers=5):
    """Run ``model`` once on every sample of ``dataset`` and collect its outputs.

    - assumed that model outputs logits

    Args:
        model: Network applied to each input batch; it is moved to the
            module-level ``device``.
        dataset: Map-style dataset yielding ``(x, label)`` pairs.
        batch_size: DataLoader batch size.
        num_workers: Number of DataLoader worker processes.

    Returns:
        ``(logits, labels)`` as NumPy arrays concatenated over all batches in
        dataset order (no shuffling).
    """
    model = model.to(device)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=True)
    logits_tot = []
    labels = []
    with torch.no_grad():
        for x, label in tqdm(loader, total=len(loader)):
            logits_tot.append(model(x.to(device)).cpu().numpy())
            labels.append(label.numpy())
    return np.concatenate(logits_tot, axis=0), np.concatenate(labels, axis=0)


def _ood_metrics(scores_id, scores_ood):
    """Return ``(AUROC, TNR at 95% TPR, detection accuracy)``, all in percent.

    In-distribution samples form the positive class, so higher scores must
    mean "more in-distribution".
    """
    scores = np.concatenate([scores_id, scores_ood])
    labels = np.concatenate([np.zeros_like(scores_id), np.ones_like(scores_ood)])
    is_id = 1 - labels
    dict_score = calc_metrics(scores, is_id)
    fpr, tpr, _ = roc_curve(is_id, scores)

    auroc = dict_score['auroc'] * 100
    tnr_at_95_tpr = (1 - dict_score['fpr_at_95_tpr']) * 100
    # Best balanced accuracy over all thresholds of the ROC curve.
    detection_acc = 100 * 0.5 * (tpr + 1 - fpr).max()
    return auroc, tnr_at_95_tpr, detection_acc


def _append_result(path, name_base_module, method, name_ood_dataset, metrics):
    """Append one row ``base_module,method,dataset,AUROC,TNR_TPR95,Detection_Acc`` to ``path``."""
    auroc, tnr_at_95_tpr, detection_acc = metrics
    with open(path, "a") as f:
        f.write(f"{name_base_module},{method},{name_ood_dataset},{auroc},{tnr_at_95_tpr},{detection_acc}\n")


def _fit_theory_scores(features_id, features_ood):
    """Fit an ID-vs-OOD logistic regression on the evaluation data and score that same data.

    Because the classifier is trained and evaluated on the same ID-test/OOD
    samples, the result is an optimistic upper bound ("theoretical maximum")
    on how separable the two sets are in this representation, not a
    deployable detector. Prints the training accuracy.

    Returns:
        ``(scores_id, scores_ood)``: predicted probability of the ID class
        (label 0) for the ID and OOD samples respectively.
    """
    features = np.concatenate([features_id, features_ood], axis=0)
    features = features.reshape(features.shape[0], -1)
    features = StandardScaler().fit_transform(features)
    targets = np.concatenate([np.zeros(len(features_id)), np.ones(len(features_ood))], axis=0)

    clf_theory = LogisticRegression(max_iter=10000)
    clf_theory.fit(features, targets)
    print("Accuracy on test/ood split: {}".format(clf_theory.score(features, targets)))

    # Column 0 corresponds to class 0 (ID), since classes_ is sorted.
    prob_id = clf_theory.predict_proba(features)[:, 0]
    n_id = len(features_id)
    return prob_id[:n_id], prob_id[n_id:]


if __name__ == "__main__":

    name_base_module = "resnet34_svhn"
    version = "v0"

    path_save = config.RESULTS_DIR + f"{name_base_module}_{version}/results.csv"
    os.makedirs(os.path.dirname(path_save), exist_ok=True)
    with open(path_save, "w") as f:
        f.write("base_module,method,dataset,AUROC,TNR_TPR95,Detection_Acc\n")

    model_quantile = establish_pretrained_model_quantile(name_base_module, remove_last_layer=False)
    ID_train, ID_dataset, num_classes, data_transform = get_base_datasets(name_base_module)

    # One-class SVM on the flattened quantile representations of the ID
    # training set; the RBF kernel is approximated with Nystroem features.
    features_train_quantile, labels_train = get_quantile_representations(model_quantile, ID_train)
    Xtrain = features_train_quantile.reshape((features_train_quantile.shape[0], -1))
    clf_nystroem = Nystroem(kernel="rbf", n_components=5000)
    Xtrain_nystroem = clf_nystroem.fit_transform(Xtrain)
    clf_total = SGDOneClassSVM(nu=0.1, verbose=4)
    clf_total.fit(Xtrain_nystroem)

    features_test_quantile, labels_test = get_quantile_representations(model_quantile, ID_dataset)
    Xtest = features_test_quantile.reshape((features_test_quantile.shape[0], -1))
    Xtest_nystroem = clf_nystroem.transform(Xtest)
    # ID scores do not depend on the OOD dataset, so compute them once.
    scores_id_quantile = clf_total.score_samples(Xtest_nystroem)

    model_baseline = establish_pretrained_model_baseline(name_base_module, remove_last_layer=False)

    features_test_baseline, labels_test_baseline = get_baseline_representation(model_baseline, ID_dataset)
    probs = softmax(features_test_baseline, axis=1)
    preds = np.argmax(probs, axis=1)
    acc = np.mean(preds == labels_test_baseline)
    print(f"Accuracy of the baseline model on the test set: {acc}")
    # Maximum softmax probability of the ID test set, reused for every OOD dataset.
    scores_id_baseline = np.max(probs, axis=1)

    for name_ood_dataset in ['LSUN', 'iSUN', 'Imagenet', 'LSUN_resize', 'Imagenet_resize']:
        OOD_dataset = ImageFolderOOD(root=config.DATA_DIR + f"{name_ood_dataset}/", transform=data_transform)

        # Quantile representation scored by the one-class SVM.
        features_ood_quantile, _ = get_quantile_representations(model_quantile, OOD_dataset)
        Xood = features_ood_quantile.reshape((features_ood_quantile.shape[0], -1))
        Xood_nystroem = clf_nystroem.transform(Xood)
        scores_ood_quantile = clf_total.score_samples(Xood_nystroem)
        _append_result(path_save, name_base_module, "quantile", name_ood_dataset,
                       _ood_metrics(scores_id_quantile, scores_ood_quantile))

        # Upper bound: supervised ID-vs-OOD classifier on the quantile representation.
        _append_result(path_save, name_base_module, "quantile_theory", name_ood_dataset,
                       _ood_metrics(*_fit_theory_scores(features_test_quantile, features_ood_quantile)))

        # Baseline OOD scores: maximum softmax probability.
        features_ood_baseline, _ = get_baseline_representation(model_baseline, OOD_dataset)
        scores_ood_baseline = np.max(softmax(features_ood_baseline, axis=1), axis=1)
        _append_result(path_save, name_base_module, "baseline", name_ood_dataset,
                       _ood_metrics(scores_id_baseline, scores_ood_baseline))

        # Upper bound: supervised ID-vs-OOD classifier on the baseline logits.
        _append_result(path_save, name_base_module, "baseline_theory", name_ood_dataset,
                       _ood_metrics(*_fit_theory_scores(features_test_baseline, features_ood_baseline)))