import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms, datasets
import torchvision.transforms.functional as TF

from models.resnet18_32x32 import ResNet18_32x32
from postprocessors import get_postprocessor
from utils.evaluation import confidences_auc
from utils.logger import Logger


import numpy as np
import random
import os
import math
import sys


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _score_dataset(detector, model, dataset, device, batch_size):
    """Run ``detector.postprocess`` over *dataset* and return all scores as one tensor."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    batch_scores = []
    for samples, _ in loader:
        _, score = detector.postprocess(model, samples.to(device))
        # Autograd is left enabled for the detector, so drop any graph that a
        # score may carry: it would otherwise be kept alive for the whole
        # dataset and make the later ``.numpy()`` conversion raise.
        batch_scores.append(score.detach())

    return torch.cat(batch_scores, dim=0)


def main():
    """Score one post-processor on CIFAR-10 and several other test sets.

    Loads the ``ResNet18_32x32`` checkpoint, sets the selected post-processor
    up on the CIFAR-10 training split, prints the mean score of every test
    set and finally hands all per-sample scores to ``confidences_auc``.
    Everything printed goes through the ``Logger`` bound to
    ``./log/baselines_cifar10.txt``.
    """
    # The GPU mask is only honoured if it is in the environment before the
    # first CUDA query (``torch.cuda.is_available()`` below already counts as
    # one). ``setdefault`` keeps a mask chosen by the user on the command line.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,2')

    _seed_everything(100)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    log_path = './log/baselines_cifar10.txt'
    # Make sure the target directory exists before the logger is given the path.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    log = Logger(log_path)
    sys.stdout = log

    model = ResNet18_32x32()
    # Alternative checkpoints used with this script:
    #   './weights/cider_c10_head.pth' (cider), './weights/resnet18_9421.pth'
    # ``map_location`` lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load('./weights/resnet18_9554.pth', map_location=device))
    model.to(device)
    model.eval()

    normalization = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_test = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        normalization,
    ])

    # (name, dataset) pairs in evaluation order; the names are also the labels
    # printed below and passed to ``confidences_auc``.
    named_datasets = [
        ('cifar10', datasets.CIFAR10(root='../datasets', train=False, transform=transform_test, download=True)),
        ('cifar100', datasets.CIFAR100(root='../datasets', train=False, transform=transform_test, download=True)),
        ('svhn', datasets.SVHN(root='../datasets', split='test', transform=transform_test, download=True)),
        ('dtd', datasets.ImageFolder(root='../datasets/dtd/images', transform=transform_test)),
        ('places365', datasets.ImageFolder(root='../datasets/places365', transform=transform_test)),
        ('isun', datasets.ImageFolder(root='../datasets/iSUN', transform=transform_test)),
        ('lsun', datasets.ImageFolder(root='../datasets/LSUN', transform=transform_test)),
    ]

    id_loader = DataLoader(datasets.CIFAR10(root='../datasets', train=True, transform=transform_test),
                           batch_size=1000, shuffle=True, num_workers=2)

    # Available methods: msp, ml, react, knn, vim, odin, energy, mds
    method = "knn"
    detector = get_postprocessor(method, config=None)
    detector.setup(model, id_loader)

    print(f"Detection Method: {method.upper()}")
    batch_size = 200

    # NOTE(review): scoring deliberately does not run under torch.no_grad()
    # (it was commented out in the original); presumably gradient-based
    # methods such as odin need autograd - confirm before enabling it.
    # Re-assert eval mode in case ``detector.setup`` changed it.
    model.eval()

    scores = []
    for dataset, data in named_datasets:
        print('-'*10+dataset+'-'*10)

        dataset_score = _score_dataset(detector, model, data, device, batch_size)
        print(dataset_score.mean().item())

        scores.append(dataset_score.squeeze().cpu().numpy())

    test_datasets = [name for name, _ in named_datasets]
    confidences_auc(scores, test_datasets)


        
        

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":

    main()