import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms, datasets
import torchvision.transforms.functional as TF

from models.resnet50 import ResNet50
from postprocessors import get_postprocessor
from utils.list_dataset import ImageFilelist
from utils.logger import Logger

from utils.evaluation import confidences_auc, search_k


import numpy as np
import matplotlib.pyplot as plt
import random
import os
import math
import argparse
import sys



def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _build_test_datasets(transform):
    """Return a mapping ``name -> zero-argument factory`` for every test dataset.

    Factories (instead of eagerly built datasets) mean that only the datasets
    actually selected for evaluation are constructed, so a missing list file
    or folder for an unused dataset cannot abort the run.

    Args:
        transform: Transform applied to every image of every dataset.
    """
    return {
        'imagenet': lambda: ImageFilelist(
            root='../datasets/imagenet_1k',
            flist='./datalists/imagenet/imagenet2012_val_list.txt',
            transform=transform),
        'imagenet_o': lambda: ImageFilelist(
            root='../datasets/',
            flist='./datalists/imagenet/test_imagenet_o.txt',
            transform=transform),
        'openimage_o': lambda: ImageFilelist(
            root='../datasets/',
            flist='./datalists/imagenet/test_openimage_o.txt',
            transform=transform),
        'inaturalist': lambda: datasets.ImageFolder(root='../datasets/iNaturalist', transform=transform),
        'places365': lambda: datasets.ImageFolder(root='../datasets/places365', transform=transform),
        'sun': lambda: datasets.ImageFolder(root='../datasets/SUN', transform=transform),
        'texture': lambda: datasets.ImageFolder(root='../datasets/dtd/images', transform=transform),
    }


def _score_dataset(detector, model, dataset, device, batch_size=256, shuffle=True, num_workers=2):
    """Run ``detector.postprocess`` over ``dataset`` and collect the scores.

    Args:
        detector: Object exposing ``postprocess(model, samples) -> (_, score)``.
        model: Network passed through to ``detector.postprocess``.
        dataset: Dataset yielding ``(sample, label)`` pairs; labels are ignored.
        device: Device the samples are moved to before scoring.
        batch_size, shuffle, num_workers: Forwarded to ``DataLoader``.

    Returns:
        NumPy array of the per-batch scores concatenated along axis 0.

    Raises:
        ValueError: If the dataset yields no batches.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    batch_scores = []
    for samples, _ in loader:
        samples = samples.to(device)
        _, score = detector.postprocess(model, samples)
        # No global torch.no_grad() is used (some methods, e.g. ODIN, may need
        # autograd), so the score can still require grad; .numpy() would raise
        # on such a tensor without detach().
        batch_scores.append(score.detach().cpu().numpy())
    if not batch_scores:
        raise ValueError("dataset produced no samples to score")
    return np.concatenate(batch_scores, axis=0)


def main():
    """Score ImageNet-1k and several OOD test sets with a post-hoc detector.

    Loads a pretrained ResNet-50, builds the detector selected by ``method``,
    scores every dataset listed in ``test_datasets`` and passes the score
    arrays (in that order) to ``confidences_auc``. While running, ``sys.stdout``
    is replaced by a ``Logger`` for ``./log/baselines-imgnet.txt`` and is
    restored afterwards.
    """
    # Must be set before the first CUDA query (torch.cuda.is_available() in
    # _seed_everything): the variable is only read when the CUDA driver is
    # initialised, so setting it later has no effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
    # torch.cuda.set_device(torch.device("cuda:1"))
    _seed_everything(100)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    os.makedirs('./log', exist_ok=True)
    log = Logger('./log/baselines-imgnet.txt')
    saved_stdout = sys.stdout
    sys.stdout = log
    try:
        model = ResNet50()
        # Alternative checkpoint: './weights/resnet50_80858.pth'
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load('./weights/resnet50_7613.pth', map_location=device))
        model.to(device)
        model.eval()

        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # All available: ['imagenet', 'imagenet_o', 'openimage_o', 'inaturalist',
        #                 'places365', 'sun', 'texture']
        test_datasets = ['imagenet', 'inaturalist', 'places365', 'sun', 'texture']
        factories = _build_test_datasets(transform_test)
        # Build the selected datasets up front so a bad path fails before the
        # (potentially expensive) detector setup below.
        test_sets = {name: factories[name]() for name in test_datasets}

        # One of: msp, ml, react, knn, vim, odin, energy
        method = "knn"

        # These methods receive a 200k-image ImageNet training subset in setup().
        if method in ("react", "knn", "vim"):
            id_loader = DataLoader(
                ImageFilelist(root='../datasets/imagenet_1k',
                              flist='./datalists/imagenet/imagenet2012_train_random_200k.txt',
                              transform=transform_test),
                batch_size=256, shuffle=True, num_workers=2)
        else:
            id_loader = None

        detector = get_postprocessor(method, config=None)
        detector.setup(model, id_loader)

        print(f"Detection Method: {method.upper()}")
        batch_size = 256
        # Re-assert eval mode in case detector.setup() changed it.
        model.eval()

        scores = []
        for name in test_datasets:
            print('-' * 10 + name + '-' * 10)
            dataset_score = _score_dataset(detector, model, test_sets[name], device,
                                           batch_size=batch_size)
            print(dataset_score.shape, dataset_score.mean())
            scores.append(dataset_score)

        confidences_auc(scores, test_datasets)
    finally:
        sys.stdout = saved_stdout


if __name__ == "__main__":
    # Only run the evaluation when executed as a script, never on import.
    main()