from tqdm import tqdm
import torch
import numpy as np
import torch.nn as nn
from copy import deepcopy
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from sklearn.mixture import GaussianMixture
from util import *
import torchvision.models
import torchvision.transforms
import torch.nn.functional as F
'''
Simple Likelihood Testing
'''
def eval_likelhiood(in_ll, out_ll, in_datasetname, out_datasetname):
    """Compute the ROC AUC obtained by using the raw score as the OOD detector.

    In-distribution samples are labelled 1 and out-of-distribution samples 0,
    so a larger score counts as evidence for "in-distribution".

    Args:
        in_ll: List of per-sample scores for the in-distribution test set
            (presumably log-likelihoods, going by the name).
        out_ll: List of per-sample scores for the out-of-distribution test set.
        in_datasetname: In-distribution dataset name; only used in the printout.
        out_datasetname: Out-of-distribution dataset name; only used in the
            printout.

    Returns:
        The AUC returned by ``roc_auc_score``.
    """
    # Positive class (1.0) = in-distribution, negative class (0.0) = OOD.
    labels = [1.0] * len(in_ll) + [0.0] * len(out_ll)
    scores = [*in_ll, *out_ll]

    auc = roc_auc_score(labels, scores)

    report = (
        "",
        "Only Likelihood",
        f"In-dist Dataset : {in_datasetname} / Out-dist Dataset : {out_datasetname}",
        f"AUC : {auc}",
        "",
    )
    for line in report:
        print(line)
    return auc

'''
Serrà, Joan, et al. "Input complexity and out-of-distribution detection with likelihood-based generative models."
arXiv preprint arXiv:1909.11480 (2019).
'''
def eval_complexity(in_ll, out_ll, in_complexity, out_complexity, in_datasetname, out_datasetname):
    """Compute the ROC AUC of the complexity-corrected likelihood score.

    The score for each sample is ``ll + complexity * 8 * ln(2)``.
    In-distribution samples are labelled 1 and out-of-distribution samples 0,
    so a larger score counts as evidence for "in-distribution".

    Args:
        in_ll: Per-sample log-likelihoods of the in-distribution test set.
        out_ll: Per-sample log-likelihoods of the out-of-distribution test set.
        in_complexity: Per-sample complexity estimates for the in-distribution
            test set, same length as ``in_ll``.
        out_complexity: Per-sample complexity estimates for the OOD test set,
            same length as ``out_ll``.
        in_datasetname: In-distribution dataset name; only used in the printout.
        out_datasetname: Out-of-distribution dataset name; only used in the
            printout.

    Returns:
        The AUC returned by ``roc_auc_score``.
    """

    def corrected_score(ll, complexity):
        # Bits conversion: presumably the complexity is a compressed size in
        # bytes (x8 -> bits) and x ln(2) turns bits into nats so it matches the
        # log-likelihood -- TODO confirm the units against the caller.
        return np.array(ll) + np.array(complexity) * 8 * np.log(2)

    in_score = corrected_score(in_ll, in_complexity)
    out_score = corrected_score(out_ll, out_complexity)

    # Positive class (1.0) = in-distribution, negative class (0.0) = OOD.
    labels = [1.0] * len(in_ll) + [0.0] * len(out_ll)
    scores = in_score.tolist() + out_score.tolist()
    auc = roc_auc_score(labels, scores)

    report = (
        "",
        "Complexity + Likelihood Test",
        f"In-dist Dataset : {in_datasetname} / Out-dist Dataset : {out_datasetname}",
        f"AUC : {auc}",
        "",
    )
    for line in report:
        print(line)
    return auc
'''
Typicality Test in Latent Space
'''
def eval_typicality_latent(in_z_tensor, out_z_tensor, in_datasetname, out_datasetname):
    """Compute the ROC AUC of a latent-space typicality score.

    Each latent vector is scored by how far its L2 norm lies from ``sqrt(d)``,
    where ``d`` is the latent dimensionality (``shape[1]``):
    ``score = | ||z||_2 - sqrt(d) |``. Out-of-distribution samples are the
    positive class here (label 1), i.e. a larger score means "more anomalous".

    The previous implementation compared the *squared* norm ``sum(z**2)``
    against ``sqrt(d)``, which mixes two different scales; the norm itself is
    now used so both terms live on the same scale.

    Args:
        in_z_tensor: 2-D tensor ``(num_samples, d)`` of in-distribution latents.
        out_z_tensor: 2-D tensor ``(num_samples, d)`` of OOD latents.
        in_datasetname: In-distribution dataset name; only used in the printout.
        out_datasetname: Out-of-distribution dataset name; only used in the
            printout.

    Returns:
        The AUC returned by ``roc_auc_score``.
    """

    def distance_from_shell(z):
        # sqrt(d): presumably the typical radius under a standard-normal
        # latent prior -- TODO confirm the prior used by the model.
        radius = z.shape[1] ** (1 / 2)
        norm = torch.sqrt(torch.sum(torch.pow(z, 2), dim=1))
        return torch.abs(norm - radius)

    in_score = distance_from_shell(in_z_tensor)
    out_score = distance_from_shell(out_z_tensor)
    print((in_z_tensor.shape[1]) ** (1 / 2))

    # Larger score -> anomaly, so OOD is the positive class.
    test_label = [0.0] * in_z_tensor.shape[0] + [1.0] * out_z_tensor.shape[0]
    test_input = in_score.tolist() + out_score.tolist()
    auc = roc_auc_score(test_label, test_input)

    print()
    print("Typicality Test in Latent Space")
    print(f"In-dist Dataset : {in_datasetname} / Out-dist Dataset : {out_datasetname}")
    print(f"AUC : {auc}")
    print()
    return auc
'''
Nalisnick, Eric, et al. "Detecting out-of-distribution inputs to deep generative models using typicality." 
arXiv preprint arXiv:1906.02994 (2019).
'''
def eval_typicality_entropy(model, device, in_dist_train_loader, in_ll, out_ll, in_datasetname, out_datasetname, mlp_flag):
    """Compute the ROC AUC of an entropy-based typicality score.

    The mean log-likelihood over ``in_dist_train_loader`` is used as an
    estimate of the negative entropy. Every test sample is then scored by
    ``|ll - mean_train_ll|``; OOD samples are the positive class (label 1),
    i.e. a larger score means "more anomalous".

    Side effects: prints the number of training samples and the result, writes
    two histograms to ``./tmp.png`` and ``./tmp_score.png``, and calls
    ``plt.show()`` for each. Both figures are closed afterwards so repeated
    calls do not accumulate open figures.

    Args:
        model: Density model. When ``mlp_flag == True`` its ``log_prob(x)`` is
            used as the log-likelihood; otherwise ``model(x)`` is negated and
            used as the log-likelihood.
        device: Device the batches are moved to before the forward pass.
        in_dist_train_loader: Iterable yielding batches (tensors) of
            in-distribution training data.
        in_ll: Per-sample log-likelihoods of the in-distribution test set.
        out_ll: Per-sample log-likelihoods of the OOD test set.
        in_datasetname: In-distribution dataset name; only used in the printout.
        out_datasetname: Out-of-distribution dataset name; only used in the
            printout.
        mlp_flag: Selects how the log-likelihood is obtained from ``model``
            (see ``model``).

    Returns:
        The AUC returned by ``roc_auc_score``.
    """
    # Deliberately an equality test rather than truthiness, so that non-bool
    # flags (e.g. a non-empty string) keep selecting the same branch as before.
    use_log_prob = mlp_flag == True  # noqa: E712

    train_ll = []
    with torch.no_grad():
        for x in in_dist_train_loader:
            x = x.to(device)
            ll = model.log_prob(x) if use_log_prob else model(x) * (-1)
            train_ll.extend(ll.cpu().detach().tolist())

    print(len(train_ll))
    estimated_entropy_negative = np.mean(np.array(train_ll))

    # Histogram of raw log-likelihoods (train / in-dist test / OOD test).
    ll_fig = plt.figure(figsize=(12, 12))
    plt.title(estimated_entropy_negative)
    plt.hist(train_ll, label='Train LL', alpha=0.5)
    plt.hist(in_ll, label='Indist Test LL', alpha=0.5)
    plt.hist(out_ll, label='Outdist Test LL', alpha=0.5)
    plt.legend()
    plt.savefig('./tmp.png')
    plt.show()
    plt.close(ll_fig)  # release the figure; it was previously left open

    in_score = np.abs(np.array(in_ll) - estimated_entropy_negative)
    out_score = np.abs(np.array(out_ll) - estimated_entropy_negative)

    # Histogram of the typicality scores.
    score_fig = plt.figure(figsize=(12, 12))
    plt.hist(in_score, label='In_test', alpha=0.5)
    plt.hist(out_score, label='Out_test', alpha=0.5)
    plt.legend()
    plt.savefig('./tmp_score.png')
    plt.show()
    plt.close(score_fig)

    # Larger score -> anomaly, so OOD is the positive class.
    test_label = [0.0] * len(in_ll) + [1.0] * len(out_ll)
    test_input = in_score.tolist() + out_score.tolist()
    auc = roc_auc_score(test_label, test_input)

    print()
    print("Typicality Test Using Entropy Estimation")
    print(f"In-dist Dataset : {in_datasetname} / Out-dist Dataset : {out_datasetname}")
    print(f"AUC : {auc}")
    print()
    return auc

def statistic_gmm(in_z_ll, in_dist_test_complexity, out_z_ll, out_dist_test_complexity, in_datasetname, out_datasetname, seed):
    """Compute the ROC AUC of a GMM density fitted on (score, complexity) pairs.

    A 3-component ``GaussianMixture`` is fitted on the in-distribution pairs
    and then used to score both the in-distribution and the OOD pairs with
    ``score_samples``. In-distribution samples are labelled 1, OOD samples 0.

    Note that the mixture is fitted and evaluated on the same in-distribution
    array.

    Args:
        in_z_ll: Per-sample values for the in-distribution test set (first
            feature column).
        in_dist_test_complexity: Per-sample complexity for the in-distribution
            test set (second feature column).
        out_z_ll: Per-sample values for the OOD test set.
        out_dist_test_complexity: Per-sample complexity for the OOD test set.
        in_datasetname: In-distribution dataset name; only used in the printout.
        out_datasetname: Out-of-distribution dataset name; only used in the
            printout.
        seed: ``random_state`` passed to ``GaussianMixture``.

    Returns:
        The AUC returned by ``roc_auc_score``.
    """
    # Two feature columns per sample: (value, complexity).
    in_features = np.stack([in_z_ll, in_dist_test_complexity], axis=1)
    out_features = np.stack([out_z_ll, out_dist_test_complexity], axis=1)

    mixture = GaussianMixture(n_components=3, random_state=seed)
    mixture.fit(in_features)

    in_score = mixture.score_samples(in_features)
    out_score = mixture.score_samples(out_features)

    # Positive class (1.0) = in-distribution, negative class (0.0) = OOD.
    labels = [1.0] * len(in_score) + [0.0] * len(out_score)
    scores = in_score.tolist() + out_score.tolist()
    auc = roc_auc_score(labels, scores)

    report = (
        "",
        "Z + Compelexity GMM Likelihood Estimation",
        f"In-dist Dataset : {in_datasetname} / Out-dist Dataset : {out_datasetname}",
        f"AUC : {auc}",
        "",
    )
    for line in report:
        print(line)
    return auc

def likelihood_ratio(background_model, in_ll, out_ll, K, device, in_dist_test_loader, out_dist_test_loader, in_datasetname, out_datasetname):
    """Compute the ROC AUC of the likelihood-ratio score against a background model.

    The output of ``background_model`` is negated and treated as a
    log-likelihood. The score of each sample is ``ll - background_ll``;
    in-distribution samples are labelled 1 and OOD samples 0.

    Args:
        background_model: Callable model; ``background_model(x)`` is negated to
            obtain the background log-likelihood of each sample in the batch.
        in_ll: Per-sample log-likelihoods of the in-distribution test set under
            the main model, in the order yielded by ``in_dist_test_loader``.
        out_ll: Per-sample log-likelihoods of the OOD test set under the main
            model, in the order yielded by ``out_dist_test_loader``.
        K: Unused by this function.
        device: Device the batches are moved to before the forward pass.
        in_dist_test_loader: Iterable yielding in-distribution test batches.
        out_dist_test_loader: Iterable yielding OOD test batches.
        in_datasetname: In-distribution dataset name; only used in the printout.
        out_datasetname: Out-of-distribution dataset name; only used in the
            printout.

    Returns:
        The AUC returned by ``roc_auc_score``.
    """

    def background_log_likelihood(loader):
        # Collect the negated model output for every sample of the loader.
        collected = []
        for batch in loader:
            negative_ll = background_model(batch.to(device)).cpu().numpy()
            collected.extend((negative_ll * (-1)).tolist())
        return collected

    with torch.no_grad():
        bg_in = background_log_likelihood(in_dist_test_loader)
        bg_out = background_log_likelihood(out_dist_test_loader)

    # Log-likelihood ratio: main model minus background model.
    in_score = np.array(in_ll) - np.array(bg_in)
    out_score = np.array(out_ll) - np.array(bg_out)

    # Positive class (1.0) = in-distribution, negative class (0.0) = OOD.
    labels = [1.0] * len(in_score) + [0.0] * len(out_score)
    scores = in_score.tolist() + out_score.tolist()
    auc = roc_auc_score(labels, scores)

    report = (
        "",
        "Likelihood Ratio (Background)",
        f"In-dist Dataset : {in_datasetname} / Out-dist Dataset : {out_datasetname}",
        f"AUC : {auc}",
        "",
    )
    for line in report:
        print(line)
    return auc


def perturb_pretrained(model, device, batch_size, in_dist_train_loader, in_dist_test_loader, out_dist_test_loader, in_datasetname, out_datasetname, feature_extractor, mlp_flag):
    """Evaluate OOD detection with similarity-scaled Gaussian perturbation.

    Steps:
      1. Embed the in-distribution train set and both test sets with a
         pretrained torchvision ResNet (final layer removed), clamping the
         embeddings from above at the 0.90 quantile of a random subset of up
         to 1000 train embeddings.
      2. For every test sample take the maximum cosine similarity to any train
         embedding.
      3. Add Gaussian noise scaled by ``alpha * (1 - max_similarity)`` to each
         test input, negate ``model``'s output to obtain a log-likelihood, and
         compute the ROC AUC (in-distribution = label 1).

    Fixes over the previous version: embedding extraction now actually runs
    under ``torch.no_grad()`` (the context manager used to wrap only the
    *definition* of the helper, not its calls), and an unsupported
    ``feature_extractor`` raises ``ValueError`` up front instead of failing
    later with "'NoneType' object is not callable".

    Args:
        model: Density model; ``model(x)`` is negated and used as the
            log-likelihood. It is switched to eval mode.
        device: Device used for the forward passes.
        batch_size: Batch size of the similarity loaders. Assumed to match the
            batch size of the test loaders so that the ``zip`` below pairs
            each image batch with its own similarities -- confirm at the
            call site.
        in_dist_train_loader: Iterable yielding in-distribution train batches
            (image tensors).
        in_dist_test_loader: Iterable yielding in-distribution test batches.
            Assumed to iterate in the same order on every pass (no shuffling).
        out_dist_test_loader: Iterable yielding OOD test batches; same
            assumption as above.
        in_datasetname: In-distribution dataset name; selects the alpha grid
            and is used in the printout.
        out_datasetname: OOD dataset name; only used in the printout.
        feature_extractor: One of ``'resnet50'``, ``'resnet101'``,
            ``'resnet152'``.
        mlp_flag: Unused by this function.

    Returns:
        Tuple ``(auc_list, non_nan_alpha_list)``: the AUC for every evaluated
        alpha and the matching alphas. Evaluation stops at the first alpha
        that produces a NaN or infinite log-likelihood.

    Raises:
        ValueError: If ``feature_extractor`` is not a supported backbone name.
    """
    backbone_names = ('resnet50', 'resnet101', 'resnet152')
    if feature_extractor not in backbone_names:
        raise ValueError(
            f"Unsupported feature_extractor {feature_extractor!r}; "
            f"expected one of {backbone_names}"
        )

    pretrained_model = getattr(torchvision.models, feature_extractor)(pretrained=True).to(device)
    pretrained_model.eval()
    # Drop the final child module; the remaining output is reshaped to
    # (-1, 2048) below.
    embedding_extractor = nn.Sequential(*list(pretrained_model.children())[:-1])

    normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                 std=[0.229, 0.224, 0.225])

    def preprocess_tensor_image(img_tensor):
        # Resize to 224x224 and apply the channel-wise normalisation above.
        img_tensor = F.interpolate(img_tensor, size=(224, 224), mode='bilinear', align_corners=False)
        return normalize(img_tensor)

    def cosine_similarity_matrix(A: torch.Tensor, B: torch.Tensor):
        # Row-normalise both matrices so the product holds cosine similarities.
        A_norm = F.normalize(A, p=2, dim=1)
        B_norm = F.normalize(B, p=2, dim=1)
        return (A_norm @ B_norm.T).cpu().detach()

    def emb_extract(dataloader, threshold):
        # Embed every batch without tracking gradients; results live on the CPU.
        emb_batches = []
        with torch.no_grad():
            for x in tqdm(dataloader):
                x = preprocess_tensor_image(x).to(device)
                emb_batches.append(embedding_extractor(x).cpu().detach().reshape(-1, 2048))
        emb_tensor = torch.cat(emb_batches, dim=0)
        if threshold is None:
            # Estimate the clamp value from a random subset of the embeddings.
            random_indices = torch.randperm(emb_tensor.size(0))[:1000]
            threshold = torch.quantile(emb_tensor[random_indices], 0.90)
        emb_tensor = torch.clamp(emb_tensor, max=threshold)
        return emb_tensor, threshold

    # The threshold is estimated on the train set and reused for both test sets.
    in_train_emb_tensor, threshold = emb_extract(in_dist_train_loader, None)
    in_test_emb_tensor, _ = emb_extract(in_dist_test_loader, threshold)
    out_test_emb_tensor, _ = emb_extract(out_dist_test_loader, threshold)

    sim_in = cosine_similarity_matrix(in_test_emb_tensor, in_train_emb_tensor)
    sim_out = cosine_similarity_matrix(out_test_emb_tensor, in_train_emb_tensor)

    # Best match in the train set per test sample, shaped (N, 1, 1, 1) so it
    # broadcasts against an image batch.
    sim_in_max = sim_in.max(dim=1)[0].reshape(-1, 1, 1, 1)
    sim_out_max = sim_out.max(dim=1)[0].reshape(-1, 1, 1, 1)

    sim_in_loader = torch.utils.data.DataLoader(sim_in_max, batch_size=batch_size, shuffle=False, drop_last=False)
    sim_out_loader = torch.utils.data.DataLoader(sim_out_max, batch_size=batch_size, shuffle=False, drop_last=False)
    model.eval()
    if in_datasetname in ['MNIST', 'FashionMNIST']:
        alpha_list = [i / 30 for i in range(1, 31)]
    else:
        alpha_list = [i / 20 for i in range(1, 21)]

    # NOTE(review): this hard-coded value overrides the grid built just above,
    # so only alpha = 0.4 is ever evaluated. Kept as-is to preserve current
    # results; remove this line to sweep the grid.
    alpha_list = [0.4]

    def perturbed_log_likelihood(data_loader, sim_loader, alpha):
        # Noise shrinks as the sample gets closer to the train set (sim -> 1).
        ll = []
        for x, sim_tensor in zip(data_loader, sim_loader):
            perturb = torch.randn(x.shape)
            x_perturbed = x + alpha * (1 - sim_tensor) * perturb
            nll = model(x_perturbed.to(device))
            ll.extend((nll.cpu().numpy() * (-1)).tolist())
        return ll

    non_nan_alpha_list = []
    auc_list = []
    for alpha in alpha_list:
        with torch.no_grad():
            in_ll = perturbed_log_likelihood(in_dist_test_loader, sim_in_loader, alpha)
            out_ll = perturbed_log_likelihood(out_dist_test_loader, sim_out_loader, alpha)

        # Positive class (1.0) = in-distribution, negative class (0.0) = OOD.
        test_label = [1.0] * len(in_ll) + [0.0] * len(out_ll)
        test_input = in_ll + out_ll

        if not np.all(np.isfinite(np.array(test_input))):
            # Stop at the first alpha that yields non-finite likelihoods.
            print(f"alpha : {alpha} -> nan or inf detected")
            break

        auc = roc_auc_score(test_label, test_input)
        auc_list.append(auc)
        print()
        print("Perturbation using Pretrain Model")
        print(f"In-dist Dataset : {in_datasetname} / Out-dist Dataset : {out_datasetname}")
        print(f"Alpha : {alpha}")
        print(f"AUC : {auc}")
        print()
        non_nan_alpha_list.append(alpha)
    return auc_list, non_nan_alpha_list

def noise_injection(model, device, in_dist_test_loader, out_dist_test_loader, in_datasetname, out_datasetname):
    """Evaluate OOD detection when Gaussian noise is injected into the OOD inputs.

    For each alpha in ``[0, 0.02]`` the in-distribution batches are scored
    unchanged, the OOD batches are scored after adding ``alpha * N(0, I)``
    noise, ``model``'s output is negated to obtain a log-likelihood, and the
    ROC AUC is computed (in-distribution = label 1).

    NOTE(review): only the OOD inputs are perturbed; the in-distribution
    inputs are never noised. This is kept as-is -- confirm it is intended.

    The printed header used to read "Perturbation using Pretrain Model"
    (copied from ``perturb_pretrained``); it now reads "Noise Injection" so the
    log identifies the right experiment.

    Args:
        model: Density model; ``model(x)`` is negated and used as the
            log-likelihood.
        device: Device the batches are moved to before the forward pass.
        in_dist_test_loader: Iterable yielding in-distribution test batches.
        out_dist_test_loader: Iterable yielding OOD test batches.
        in_datasetname: In-distribution dataset name; only used in the printout.
        out_datasetname: OOD dataset name; only used in the printout.

    Returns:
        Tuple ``(auc_list, non_nan_alpha_list, in_ll_list, out_ll_list)``: the
        AUC, the alpha, and the per-sample in-distribution / OOD
        log-likelihood lists for every alpha that was evaluated. Evaluation
        stops at the first alpha that produces a NaN or infinite
        log-likelihood, and nothing is recorded for that alpha.
    """
    alpha_list = [0, 0.02]
    auc_list, non_nan_alpha_list = [], []
    in_ll_list, out_ll_list = [], []

    for alpha in alpha_list:
        in_ll, out_ll = [], []
        with torch.no_grad():
            # In-distribution inputs are scored without any noise.
            for x in in_dist_test_loader:
                nll = model(x.to(device))
                in_ll.extend((nll.cpu().numpy() * (-1)).tolist())

            # OOD inputs get additive Gaussian noise of scale alpha.
            for x in out_dist_test_loader:
                x_perturbed = x + alpha * torch.randn(x.shape)
                nll = model(x_perturbed.to(device))
                out_ll.extend((nll.cpu().numpy() * (-1)).tolist())

        # Positive class (1.0) = in-distribution, negative class (0.0) = OOD.
        test_label = [1.0] * len(in_ll) + [0.0] * len(out_ll)
        test_input = in_ll + out_ll

        if not np.all(np.isfinite(np.array(test_input))):
            # Stop at the first alpha that yields non-finite likelihoods.
            print(f"alpha : {alpha} -> nan or inf detected")
            break

        auc = roc_auc_score(test_label, test_input)
        auc_list.append(auc)
        print()
        print("Noise Injection")
        print(f"In-dist Dataset : {in_datasetname} / Out-dist Dataset : {out_datasetname}")
        print(f"Alpha : {alpha}")
        print(f"AUC : {auc}")
        print()
        non_nan_alpha_list.append(alpha)
        in_ll_list.append(in_ll)
        out_ll_list.append(out_ll)
    return auc_list, non_nan_alpha_list, in_ll_list, out_ll_list


def constant_test(model, in_ll, device, in_dist_test_loader, in_datasetname):
    """Measure how well likelihood separates real data from near-zero noise images.

    For ``i`` in ``1..99`` a batch of 1000 random ``3x32x32`` inputs drawn from
    ``N(0, 1)`` and scaled by ``1e-5 * i`` is scored with ``model`` (output
    negated to obtain a log-likelihood), and the ROC AUC against ``in_ll`` is
    computed (in-distribution = label 1). The 99 AUC values are written to
    ``./result/{in_datasetname}.csv``.

    Fixes over the previous version: ``pandas`` is imported locally (the module
    never imported ``pd`` itself), and the ``./result`` directory is created if
    it does not exist so the CSV write cannot fail on a fresh checkout.

    Args:
        model: Density model; ``model(x)`` is negated and used as the
            log-likelihood. It is switched to eval mode.
        in_ll: Per-sample log-likelihoods of the in-distribution test set.
        device: Device the generated batches are moved to.
        in_dist_test_loader: Unused by this function.
        in_datasetname: In-distribution dataset name; used for the CSV file
            name.

    Returns:
        None. The result is only written to disk.
    """
    import os

    import pandas as pd

    model.eval()

    # The in-distribution side is identical for every scale; build it once.
    in_scores = list(in_ll)
    in_label = [1.0] * len(in_scores)

    auc_list = []
    with torch.no_grad():
        for i in range(1, 100):
            # Near-constant (almost all-zero) images with a growing noise scale.
            out_input = torch.randn((1000, 3, 32, 32))
            out_input = out_input * 0.00001 * i
            nll = model(out_input.to(device))
            out_ll = (nll.cpu().numpy() * (-1)).tolist()

            # Positive class (1.0) = in-distribution, negative class (0.0) = OOD.
            test_label = in_label + [0.0] * len(out_ll)
            test_input = in_scores + out_ll
            auc_list.append(roc_auc_score(test_label, test_input))

    os.makedirs('./result', exist_ok=True)
    df = pd.DataFrame(auc_list)
    df.to_csv(f'./result/{in_datasetname}.csv')