import sys
import torchvision
import argparse
import os
import shutil
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn as sns
import resnet
import vgg


# Command-line interface of the evaluation script.  The accepted values for
# --arch and --dataset mirror the branches handled in main(), so an
# unsupported value is rejected by argparse with a usage message.
parser = argparse.ArgumentParser(description='Proper ResNet/VGG for CIFAR in pytorch')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',
                    choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'vgg16'],
                    help='model architecture: resnet18 | resnet34 | resnet50 | '
                         'resnet101 | vgg16 (default: resnet50)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--batch_size', default=128, type=int,
                    metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('--resume', default=None, type=str, metavar='PATH',
                    help='path to model.th (default: none)')
parser.add_argument('--save_fig', default='./ood.png', type=str,
                    help='The path used to save the ood figure')
parser.add_argument('--gpu', default='0', type=str,
                    help='value exported as CUDA_VISIBLE_DEVICES (default: 0)')
parser.add_argument('--dataset', default='cifar10', type=str,
                    choices=['cifar10', 'cifar100'],
                    help='in-distribution test set: cifar10 | cifar100 (default: cifar10)')

def compute_ent(prob):
    """Return the Shannon entropy (in nats) of each distribution in *prob*.

    Args:
        prob: a torch tensor of probabilities; the entropy is taken over the
            last axis.

    Returns:
        A numpy array with the last axis of *prob* reduced away.
    """
    prob = prob.detach().cpu().numpy()
    # A saturated softmax can produce exact zeros; 0 * log(0) would evaluate
    # to NaN, so substitute 1 for those entries (log(1) == 0) to apply the
    # usual convention 0 * log(0) = 0.
    safe_prob = np.where(prob > 0, prob, 1.0)
    ent = - np.sum(prob * np.log(safe_prob), -1)
    return ent

def _build_testset(dataset):
    """Return ``(testset, num_classes)`` for the CIFAR test split named *dataset*."""
    if dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
        num_classes = 10
        dataset_cls = torchvision.datasets.CIFAR10
    elif dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
        num_classes = 100
        dataset_cls = torchvision.datasets.CIFAR100
    else:
        print("Invalid dataset name")
        # Exit with a non-zero status so callers can detect the failure.
        sys.exit(1)

    transform_test = T.Compose([
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    testset = dataset_cls(root='./data', train=False, download=True, transform=transform_test)
    return testset, num_classes


def _build_model(arch, num_classes):
    """Instantiate the network named *arch* with *num_classes* outputs."""
    if arch == 'resnet50':
        return resnet.ResNet50(num_classes)
    if arch == 'resnet18':
        return resnet.ResNet18(num_classes)
    if arch == 'resnet34':
        return resnet.ResNet34(num_classes)
    if arch == 'resnet101':
        return resnet.ResNet101(num_classes)
    if arch == 'vgg16':
        return vgg.VGG('VGG16', num_classes)
    print("Invalid model name")
    sys.exit(1)


def _expected_calibration_error(confidences, accuracies, num_bins=15):
    """Return the expected calibration error in percent.

    Samples are grouped into *num_bins* equal-width confidence bins
    ``(lower, upper]``; each non-empty bin contributes its sample count times
    ``|mean accuracy - mean confidence|`` and the total is divided by the
    number of samples.
    """
    total = len(confidences)
    if total == 0:
        return 0.0

    # linspace pins the last edge to exactly 1.0, so a confidence of 1.0
    # always falls into the top bin.
    edges = np.linspace(0.0, 1.0, num_bins + 1)
    weighted_gap = 0.0
    for lower, upper in zip(edges[:-1], edges[1:]):
        in_bin = np.logical_and(confidences > lower, confidences <= upper)
        count = in_bin.sum()
        # Skip only bins that hold no samples; a bin whose predictions are all
        # wrong (accuracy 0) still has a calibration gap and must be counted.
        if count == 0:
            continue
        gap = np.abs(accuracies[in_bin].mean() - confidences[in_bin].mean())
        weighted_gap += count * gap
    return 100.0 * weighted_gap / total


def _collect_entropies(model, loader):
    """Return the predictive entropy of *model* for every sample in *loader*."""
    entropies = []
    with torch.no_grad():
        for inputs, _ in loader:
            logits = model(inputs.cuda())
            entropies.append(compute_ent(F.softmax(logits, dim=1)))
    return np.concatenate(entropies)


def main():
    """Evaluate a checkpoint on CIFAR and plot in-/out-of-distribution entropy.

    Prints accuracy, negative log-likelihood, expected calibration error and
    the L2 norm of the logits on the chosen CIFAR test set, then saves a KDE
    plot comparing the predictive entropy on that test set with the entropy
    on the SVHN test split to ``args.save_fig``.
    """
    args = parser.parse_args()
    if args.resume is None:
        # Without a checkpoint there is nothing to evaluate; fail before any
        # dataset download instead of crashing inside torch.load.
        parser.error("--resume PATH is required to load the model checkpoint")

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    testset, num_classes = _build_testset(args.dataset)
    cifar_loader = torch.utils.data.DataLoader(
        testset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    model = _build_model(args.arch, num_classes)
    model = torch.nn.DataParallel(model).cuda()
    model.load_state_dict(torch.load(args.resume)['state_dict'])
    model.eval()

    sm_conf = []
    sm_acc = []
    sm_ent = []
    sm_ll = []
    sm_l2 = []

    # Per-sample losses are needed, hence reduction='none'.
    criterion = nn.CrossEntropyLoss(reduction='none')

    # Pure inference: disable autograd so no graph is kept alive per batch.
    with torch.no_grad():
        for inputs, targets in cifar_loader:
            sm_logit = model(inputs.cuda())
            targets = targets.cuda()
            probs = F.softmax(sm_logit, dim=1)
            confidences, predictions = torch.max(probs, 1)
            accuracies = predictions.eq(targets)

            sm_ent.append(compute_ent(probs))
            sm_conf.append(confidences.cpu().numpy())
            sm_acc.append(accuracies.float().cpu().numpy())
            sm_ll.append(criterion(sm_logit, targets).float().cpu().numpy())
            sm_l2.append(torch.sum(sm_logit**2, -1).float().cpu().numpy())

    sm_conf = np.concatenate(sm_conf)
    sm_acc = np.concatenate(sm_acc)
    sm_ent = np.concatenate(sm_ent)
    sm_ll = np.concatenate(sm_ll)
    sm_l2 = np.concatenate(sm_l2)

    NUM_BINS = 15
    ece = _expected_calibration_error(sm_conf, sm_acc, NUM_BINS)

    print("Base %s model results"%args.resume)
    print("="*20)
    print("Accuracy : ", round(sm_acc.mean()*100, 2))
    print("NLL : ", round(sm_ll.mean(), 2))
    print("ECE : ", round(ece, 2))
    print("L2 function norm : ", round(np.sqrt(sm_l2.mean()), 2))

    # NOTE(review): SVHN is normalized with its own channel statistics, not
    # the CIFAR ones used for the test set above -- confirm this is intended.
    svhn_mean = np.array([111.60893668, 113.16127466, 120.56512767])/255.
    svhn_std = np.array([50.49768174, 51.2589843 , 50.24421614])/255.

    svhn = datasets.SVHN('.', split='test', transform=T.Compose([T.ToTensor(), T.Normalize(svhn_mean, svhn_std),
                                                                ]), download=True)
    svhn = torch.utils.data.DataLoader(svhn, batch_size=args.batch_size, shuffle=False,
                                       num_workers=args.workers, pin_memory=False)

    svhn_sm_ent = _collect_entropies(model, svhn)

    plt.rcParams.update({'font.size': 14})

    plt.figure(figsize=(5,4))
    # NOTE(review): `shade` appears to be deprecated in favour of `fill` in
    # recent seaborn releases -- confirm against the installed version.
    sns.kdeplot(sm_ent, label='In-distribution', shade=True, color="r")
    sns.kdeplot(svhn_sm_ent, label='Out-of-distribution', shade=True, color="b")

    plt.legend()
    plt.xlabel('Entropy')
    plt.ylabel('Density')

    plt.tight_layout()
    plt.grid(ls='--')

    plt.savefig(args.save_fig)
    


# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()