'''
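Detect out-of-distribution (OOD) samples with a pretrained classifier (CLF) using
one of several post-hoc scores: MSP, ODIN, absent-class probability, max logit,
Mahalanobis distance or energy.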
'''

import argparse
import numpy as np
from pathlib import Path
from functools import partial
import matplotlib.pyplot as plt
from sklearn.covariance import EmpiricalCovariance
from scipy import stats
# import sklearn.covariance

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from models import get_clf
from utils import compute_all_metrics
from datasets import get_ds_info, get_ds_trf, get_ood_trf, get_ds

def get_msp_score(data_loader, clf):
    '''
    Maximum softmax probability (MSP) of every sample in ``data_loader``.

    Each batch is a dict whose ``'data'`` entry is moved to the GPU and fed to
    ``clf``. Returns a flat list of floats, one per sample, in loader order.
    '''
    clf.eval()

    scores = []
    with torch.no_grad():
        for batch in data_loader:
            inputs = batch['data'].cuda()
            probs = torch.softmax(clf(inputs), dim=1)
            top_prob = probs.max(dim=1).values
            scores += top_prob.tolist()

    return scores

def get_abs_score(data_loader, clf):
    '''
    Score based on the probability of the extra "absent" class.

    The last logit of ``clf`` is treated as the absent class. The returned score
    is ``1 - softmax(logit)[:, -1]``, one float per sample in loader order, so
    a higher value means the sample is less likely to be "absent".
    '''
    clf.eval()

    abs_score = []
    with torch.no_grad():
        for sample in data_loader:
            data = sample['data'].cuda()
            prob = torch.softmax(clf(data), dim=1)
            # convert to Python floats first so "1 - p" is computed in double precision
            abs_score.extend(1 - p for p in prob[:, -1].tolist())

    return abs_score

def get_logit_score(data_loader, clf):
    '''
    Maximum raw logit (no softmax) of every sample in ``data_loader``.

    Returns a flat list of floats, one per sample, in loader order.
    '''
    clf.eval()

    scores = []
    with torch.no_grad():
        for batch in data_loader:
            logits = clf(batch['data'].cuda())
            max_logit = logits.max(dim=1).values
            scores += max_logit.tolist()

    return scores

def get_odin_score(data_loader, clf, temperature=1000.0, magnitude=0.0014, std=(0.2470, 0.2435, 0.2616)):
    '''
    ODIN score: max temperature-scaled softmax probability after an input perturbation.

    For each batch:
      1. Compute the cross-entropy of ``logit / temperature`` against the
         model's own argmax prediction.
      2. Take the sign of its gradient w.r.t. the input (ties map to +1).
      3. Divide channel ``c`` of that sign map by ``std[c]``.
      4. Step the input by ``-magnitude`` along it, which lowers the loss.
      5. Score the perturbed input with a temperature-scaled softmax.

    Args:
        data_loader: iterable of dicts with a ``'data'`` tensor.
        clf: classifier returning logits.
        temperature: softmax temperature applied to both forward passes.
        magnitude: perturbation step size.
        std: per-channel divisors for the gradient sign. Presumably the
            dataset normalisation std, so the step is in un-normalised pixel
            units (TODO confirm).

    Returns:
        list of floats, one per sample, in loader order.
    '''
    clf.eval()
    criterion = nn.CrossEntropyLoss()

    odin_scores = []
    for sample in data_loader:
        data = sample['data'].cuda()
        data.requires_grad_(True)

        logit = clf(data)
        pred = logit.detach().argmax(dim=1)
        loss = criterion(logit / temperature, pred)
        # Only the input gradient is needed. autograd.grad avoids accumulating
        # unused .grad buffers on the model parameters batch after batch.
        (input_grad,) = torch.autograd.grad(loss, data)

        # normalize the gradient to its sign in {-1, 1}
        gradient = (torch.ge(input_grad, 0).float() - 0.5) * 2
        for c, channel_std in enumerate(std):
            gradient[:, c] = gradient[:, c] / channel_std

        perturbed = torch.add(data.detach(), gradient, alpha=-magnitude)
        with torch.no_grad():
            # confidence after adding the perturbation
            prob = torch.softmax(clf(perturbed) / temperature, dim=1)

        odin_scores.extend(prob.max(dim=1)[0].tolist())

    return odin_scores

def sample_estimator(data_loader, clf, num_classes):
    '''
    Estimate per-class feature means and one shared (tied) precision matrix.

    Penultimate features come from ``clf(data, ret_feat=True)`` and are
    flattened to one row per sample. They are grouped by ``sample['label']``
    and centred on their own class mean. A single ``EmpiricalCovariance`` is
    then fitted on the pooled, centred features.

    Returns:
        (category_sample_mean, precision): a list of ``num_classes`` mean
        vectors (tensors on the feature device) and the precision matrix as a
        float32 CUDA tensor.

    Raises:
        ValueError: if the loader is empty or a class in ``range(num_classes)``
            has no samples.
    '''
    clf.eval()
    group_lasso = EmpiricalCovariance(assume_centered=False)

    # Collect whole batches and concatenate once. Growing per-class tensors
    # sample by sample with torch.cat is quadratic in the dataset size.
    feat_batches, label_batches = [], []
    for sample in data_loader:
        data = sample['data'].cuda()
        target = sample['label'].cuda()

        with torch.no_grad():
            _, penulti_feature = clf(data, ret_feat=True)

        feat_batches.append(penulti_feature.reshape(penulti_feature.size(0), -1))
        label_batches.append(target)

    if not feat_batches:
        raise ValueError('sample_estimator: data_loader yielded no samples')

    features = torch.cat(feat_batches, 0)
    # Labels and features may sit on different GPUs under DataParallel.
    labels = torch.cat(label_batches, 0).to(features.device)

    category_sample_mean = []
    centered = []
    for j in range(num_classes):
        class_feats = features[labels == j]
        if class_feats.size(0) == 0:
            raise ValueError('sample_estimator: no samples for class {}'.format(j))
        class_mean = torch.mean(class_feats, 0)
        category_sample_mean.append(class_mean)
        centered.append(class_feats - class_mean)

    # find the inverse of the tied covariance of all class-centred features
    group_lasso.fit(torch.cat(centered, 0).cpu().numpy())
    precision = group_lasso.precision_

    return category_sample_mean, torch.from_numpy(precision).float().cuda()

def get_mahalanobis_score(data_loader, clf, num_classes, sample_mean, precision):
    '''
    Mahalanobis confidence: -0.5 * squared Mahalanobis distance to the closest class mean.

    For each penultimate feature ``f`` (from ``clf(data, ret_feat=True)``), the
    score is ``max_c -0.5 * (f - mu_c)^T P (f - mu_c)``, where ``mu_c`` is
    ``sample_mean[c]`` and ``P`` is the shared ``precision``. Higher means
    closer to some class.

    Returns:
        list of floats, one per sample, in loader order.
    '''
    clf.eval()

    nm_score = []
    for sample in data_loader:
        data = sample['data'].cuda()

        with torch.no_grad():
            _, penul_feat = clf(data, ret_feat=True)

            per_class = []
            for j in range(num_classes):
                zero_f = penul_feat - sample_mean[j]
                # Row-wise quadratic form z_i^T P z_i. Avoids building the full
                # (batch x batch) matrix only to read its diagonal.
                term_gau = -0.5 * (torch.mm(zero_f, precision) * zero_f).sum(dim=1)
                per_class.append(term_gau)
            term_gaus = torch.stack(per_class, dim=1)  # [BATCH, num_classes]

        nm_score.extend(torch.max(term_gaus, dim=1)[0].tolist())

    return nm_score

def get_energy_score(data_loader, clf, temperature=1.0):
    '''
    Energy-based score ``T * logsumexp(logit / T)`` for every sample.

    Returns a flat list of floats, one per sample, in loader order.
    '''
    clf.eval()

    scores = []
    with torch.no_grad():
        for batch in data_loader:
            logits = clf(batch['data'].cuda())
            lse = torch.logsumexp(logits / temperature, dim=1)
            scores += (temperature * lse).tolist()

    return scores

def get_acc(data_loader, clf, num_classes):
    '''
    Top-1 accuracy (in percent) of ``clf`` on ``data_loader``.

    Only the first ``num_classes`` logits are considered when picking the
    prediction, so any extra output unit is ignored. The accuracy is printed
    and returned.
    '''
    clf.eval()
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch in data_loader:
            inputs = batch['data'].cuda()
            labels = batch['label'].cuda()

            preds = clf(inputs)[:, :num_classes].max(dim=1).indices
            n_correct += (preds == labels).sum().item()
            n_seen += labels.size(0)

    acc = n_correct / n_seen * 100.
    print(acc)
    return acc

# Registry mapping each --score CLI choice to its scoring function.
score_dic = dict(
    msp=get_msp_score,
    odin=get_odin_score,
    abs=get_abs_score,
    logit=get_logit_score,
    maha=get_mahalanobis_score,
    energy=get_energy_score,
)

def _make_loader(dataset, args):
    '''Wrap ``dataset`` in an unshuffled DataLoader configured from the CLI ``args``.'''
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                      num_workers=args.prefetch, pin_memory=True)

def _load_clf(args, num_classes):
    '''Build the classifier for ``args.score``, load ``args.pretrain`` into it and move it to the GPU.'''
    # the 'abs' detector uses one extra output unit for the absent class
    if args.score == 'abs':
        clf = get_clf(args.arch, num_classes + 1)
    elif args.score in ['maha', 'logit', 'energy', 'msp', 'odin']:
        clf = get_clf(args.arch, num_classes)
    else:
        raise RuntimeError('<<< Invalid score: {}'.format(args.score))

    clf = nn.DataParallel(clf)
    clf_path = Path(args.pretrain)
    if not clf_path.is_file():
        raise RuntimeError('<--- invalid classifier path: {}'.format(str(clf_path)))
    clf_state = torch.load(str(clf_path), map_location='cuda:0')
    clf.load_state_dict(clf_state['state_dict'])

    # move CLF to gpu device
    gpu_idx = int(args.gpu_idx)
    if torch.cuda.is_available():
        torch.cuda.set_device(gpu_idx)
        clf.cuda()
        torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = False
    return clf

def _build_score_fn(args, clf, num_classes, std, test_trf_id):
    '''Return a callable ``(data_loader, clf) -> list of scores`` for ``args.score``.'''
    if args.score == 'maha':
        # class means / tied precision are estimated on the ID *train* split,
        # loaded with the test-time transform
        train_set_id_test = get_ds(root=args.data_dir, ds_name=args.id, split='train', transform=test_trf_id)
        cat_mean, precision = sample_estimator(_make_loader(train_set_id_test, args), clf, num_classes)
        return partial(score_dic['maha'], num_classes=num_classes, sample_mean=cat_mean, precision=precision)
    if args.score == 'odin':
        return partial(score_dic['odin'], temperature=args.temperature, magnitude=args.magnitude, std=std)
    return score_dic[args.score]

def _print_row(head, values):
    '''Print ``head``, then each value formatted as ``{:3.3f}``, then a closing ``<``.'''
    print(head, *('{:3.3f}'.format(v) for v in values), '<')

def main(args):
    '''
    Run OOD detection for ``args.id`` against every dataset in ``args.oods``.

    Scores the ID test split and each OOD test split with ``args.score``.
    Prints FPR / AUROC / AUPR (x100) per OOD dataset and saves per-dataset
    score histograms to ``args.fig_name``. Also saves the ID scores and the
    concatenated OOD scores (at most 8000 per dataset) as ``id.npy`` and
    ``ood.npy`` next to the checkpoint.
    '''
    _, std = get_ds_info(args.id, 'mean_and_std')
    test_trf_id = get_ds_trf(args.id, 'test')
    test_set_id = get_ds(root=args.data_dir, ds_name=args.id, split='test', transform=test_trf_id)
    test_loader_id = _make_loader(test_set_id, args)

    test_loader_oods = []
    for ood in args.oods:
        test_trf_ood = get_ood_trf(args.id, ood, 'test')
        test_set_ood = get_ds(root=args.data_dir, ds_name=ood, split='test', transform=test_trf_ood)
        test_loader_oods.append(_make_loader(test_set_ood, args))

    num_classes = len(get_ds_info(args.id, 'classes'))
    clf = _load_clf(args, num_classes)

    get_score = _build_score_fn(args, clf, num_classes, std, test_trf_id)
    score_id = get_score(test_loader_id, clf)
    label_id = np.ones(len(score_id))

    # msp / odin / abs scores are probabilities in [0, 1]; logit / energy / maha are unbounded
    bounded_score = args.score in ('msp', 'odin', 'abs')
    # 3 columns; the grid grows beyond 3x3 when there are more than 9 OOD sets
    n_rows = max(3, (len(test_loader_oods) + 2) // 3)

    # visualize the confidence distribution
    plt.figure(figsize=(10, 10), dpi=100)

    score_ood_all = np.empty(0)
    ood_names, fprs, aurocs, auprs = [], [], [], []
    for i, test_loader_ood in enumerate(test_loader_oods):
        ood_name = test_loader_ood.dataset.name
        ood_names.append(ood_name)

        score_ood = np.array(get_score(test_loader_ood, clf))
        label_ood = np.zeros(len(score_ood))

        # OOD detection: ID samples labelled 1, OOD samples labelled 0
        score = np.concatenate([score_id, score_ood])
        label = np.concatenate([label_id, label_ood])

        # plot the histograms
        if bounded_score:
            bins = np.linspace(0.0, 1.0, 100)
        else:
            bins = np.linspace(score.min(), score.max(), 100)
        plt.subplot(n_rows, 3, i + 1)
        plt.hist(score_id, bins, color='g', label='id', alpha=0.5)
        # ID score at the 5th percentile: 95% of ID samples score at or above it
        thr_95 = np.sort(score_id)[int(len(score_id) * 0.05)]
        plt.axvline(thr_95, alpha=0.5)
        plt.hist(score_ood, bins, color='r', label='ood', alpha=0.5)
        plt.title(ood_name)
        plt.legend()

        fpr, auroc, aupr, _ = compute_all_metrics(score, label)
        fprs.append(100. * fpr)
        aurocs.append(100. * auroc)
        auprs.append(100. * aupr)

        score_ood_all = np.concatenate([score_ood_all, score_ood[:8000]], axis=0)

    # save ID and all OOD scores separately, next to the checkpoint
    ckpt_dir = Path(args.pretrain).parent
    np.save(str(ckpt_dir / 'id.npy'), score_id)
    np.save(str(ckpt_dir / 'ood.npy'), score_ood_all)

    # save the figure
    plt.savefig(args.fig_name)

    # print results
    print('[ ID: {:7s} - OOD:'.format(args.id), *('{:5s}'.format(n) for n in ood_names), ']')
    _print_row('> FPR:  ', fprs)
    _print_row('> AUROC:', aurocs)
    _print_row('> AUPR: ', auprs)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='detect ood')
    parser.add_argument('--seed', default=42, type=int, help='seed for initialize detection')
    parser.add_argument('--data_dir', type=str, default='../datasets')
    parser.add_argument('--id', type=str, default='cifar10')
    parser.add_argument('--oods', nargs='+', default=['svhn', 'lsunc', 'dtd', 'places365_10k', 'lsunr', 'isun'])
    parser.add_argument('--score', type=str, default='msp', choices=['msp', 'odin', 'abs', 'logit', 'maha', 'energy'])
    # float so non-integer ODIN temperatures can be passed (integer strings still parse)
    parser.add_argument('--temperature', type=float, default=1000.0)
    parser.add_argument('--magnitude', type=float, default=0.0014)
    parser.add_argument('--batch_size', type=int, default=200)
    parser.add_argument('--prefetch', type=int, default=10)
    parser.add_argument('--arch', type=str, default='densenet101', choices=['densenet101', 'wrn40'])
    # Required: main() builds a Path from it and writes its outputs next to it.
    # A None default would only fail later with an opaque TypeError.
    parser.add_argument('--pretrain', type=str, required=True, help='path to pre-trained model')
    parser.add_argument('--fig_name', type=str, default='test.png')
    parser.add_argument('--gpu_idx', type=int, default=0)

    args = parser.parse_args()

    main(args)