import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms, datasets
import torchvision.transforms.functional as TF

from models.resnet50 import ResNet50
from utils.logger import Logger
from utils.list_dataset import ImageFilelist

from utils.evaluation import confidences_auc, search_k

import numpy as np
import random
import os
import math
import sys

def mask_image(image, mask_size=8, stride=8):
    """Return every copy of ``image`` with one square patch set to zero.

    A ``mask_size`` x ``mask_size`` window slides over dims 2 and 3 of
    ``image`` with step ``stride``, in row-major order starting at the
    top-left corner. Only windows that fit completely inside the image are
    used. ``image`` itself is not modified.

    Args:
        image: 4-D tensor, unpacked as (b, c, h, w).
        mask_size: side length of the zeroed square.
        stride: step between consecutive window positions; must be positive.

    Returns:
        Tensor of shape (b, num_windows, c, h, w). Entry ``[:, k]`` is
        ``image`` with the k-th window zeroed.

    Raises:
        ValueError: if ``stride`` is not positive. Previously the position
            never advanced, so the loop ran forever.
        Exception from ``torch.stack``: if no window fits, i.e. ``mask_size``
            is larger than h or w (behavior unchanged).
    """
    if stride <= 0:
        raise ValueError(f'stride must be positive, got {stride}')

    _, _, h, w = image.shape
    images = []
    # range(..., h - mask_size + 1, ...) yields exactly the offsets that the
    # original `while cur <= h - mask_size` loop visited.
    for top in range(0, h - mask_size + 1, stride):
        for left in range(0, w - mask_size + 1, stride):
            masked = image.clone()
            masked[:, :, top:top + mask_size, left:left + mask_size] = 0.
            images.append(masked)
    return torch.stack(images, dim=1)

def mask_img(img, mask_top=0, mask_left=0, mask_size=28):
    """Return a copy of ``img`` with one square patch set to zero.

    The patch covers rows ``mask_top:mask_top + mask_size`` and columns
    ``mask_left:mask_left + mask_size`` of dims 2 and 3. ``img`` is indexed
    as ``img[:, :, rows, cols]``, so it is presumably (B, C, H, W). The
    input tensor is not modified.

    Args:
        img: tensor to mask.
        mask_top: first masked row (int). The old default ``0.`` was a float
            and made slicing raise ``TypeError``.
        mask_left: first masked column (int).
        mask_size: side length of the zeroed square.

    Returns:
        A new tensor with the same shape and dtype as ``img``.
    """
    masked = img.clone()
    masked[:, :, mask_top:mask_top + mask_size, mask_left:mask_left + mask_size] = 0.
    return masked

def center_mask(img, mask_size=28):
    """Zero out a ``mask_size`` x ``mask_size`` square centred in ``img``.

    ``img`` is unpacked as (_, _, height, width). The offsets are
    ``int(round((size - mask_size) / 2.0))``. They are clamped at 0, so a
    mask at least as large as the image blanks the whole image. Without the
    clamp, a negative offset is a from-the-end slice start and would mask
    only a trailing strip.

    Args:
        img: 4-D tensor to mask (not modified).
        mask_size: side length of the zeroed square.

    Returns:
        A masked copy of ``img``, produced by ``mask_img``.
    """
    _, _, image_height, image_width = img.shape
    mask_top = max(0, int(round((image_height - mask_size) / 2.0)))
    mask_left = max(0, int(round((image_width - mask_size) / 2.0)))

    return mask_img(img, mask_top, mask_left, mask_size)

def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN."""
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _test_transformations():
    """Return the selectable test-time transformations, keyed by name.

    Each value is called on a whole batch of samples as produced by the
    DataLoader, i.e. after ToTensor + Normalize.
    """
    return {
        # ==================== ID Transformation =========================
        'hflip': TF.hflip,
        'gray': transforms.RandomGrayscale(p=1.0),
        'center_mask': lambda x: center_mask(x, mask_size=14),
        'center_crop': transforms.CenterCrop(size=230),
        # ==================== OOD Transformation =========================
        'vflip': TF.vflip,
        'rotate': lambda x: TF.rotate(x, 180),
        'color_jitter': transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
        'invert': TF.invert,
    }


def _consistency_scores(model, data_loader, transform, device, temperature):
    """Score each sample by how well the model's outputs on it and on its transform agree.

    Returns three 1-D CPU tensors, each with one entry per sample in loader
    order:
      * the inner product of the two logit vectors,
      * the cosine similarity of the logits,
      * the cosine similarity of softmax(logit / temperature).
    """
    inner_products, logit_sims, softmax_sims = [], [], []
    for samples, _ in data_loader:
        samples = samples.to(device)

        logit = model(samples)
        trans_logit = model(transform(samples))
        sm = F.softmax(logit / temperature, dim=-1)
        trans_sm = F.softmax(trans_logit / temperature, dim=-1)

        # Every score is computed row by row, so reducing each batch right away
        # gives the same values as first concatenating the whole dataset. It
        # keeps only N scalars instead of four (N, num_classes) tensors on the
        # device. The row-wise sum replaces bmm(...).squeeze(), which collapsed
        # to a 0-d tensor when the dataset had a single sample.
        inner_products.append((trans_logit * logit).sum(dim=-1).cpu())
        logit_sims.append(F.cosine_similarity(logit, trans_logit, dim=-1).cpu())
        softmax_sims.append(F.cosine_similarity(sm, trans_sm, dim=-1).cpu())

    return torch.cat(inner_products), torch.cat(logit_sims), torch.cat(softmax_sims)


def main(transformation='invert'):
    """Score five datasets by ResNet-50 prediction consistency under one transformation.

    For each dataset, every sample and its transformed version are run
    through the model. Three per-sample scores are computed for each dataset
    and passed to ``confidences_auc`` together with the dataset names:
      * logit inner product,
      * logit cosine similarity,
      * softmax cosine similarity.
    ``sys.stdout`` is replaced by a ``Logger`` for ``./log/tta_imgnet.txt``.

    Args:
        transformation: name of the transformation to apply; one of the keys
            returned by ``_test_transformations``. The default 'invert' is the
            transformation the script previously hard-coded.

    Raises:
        ValueError: if ``transformation`` is not a known name.
    """
    # Must happen before the first CUDA call (torch.cuda.is_available() in
    # _seed_everything). CUDA reads this variable when it initialises, so an
    # assignment made afterwards may have no effect. setdefault keeps a value
    # exported by the caller's shell (e.g. CUDA_VISIBLE_DEVICES=0).
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1')

    test_transformations = _test_transformations()
    if transformation not in test_transformations:
        raise ValueError(
            f'Unknown transformation {transformation!r}; '
            f'expected one of {sorted(test_transformations)}')
    transform = test_transformations[transformation]

    seed = 100
    _seed_everything(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    log = Logger('./log/tta_imgnet.txt')
    sys.stdout = log

    target_model = ResNet50()
    # Load onto the CPU first so the checkpoint opens regardless of the device
    # it was saved from; the model is moved to `device` right after.
    target_model.load_state_dict(torch.load('./weights/resnet50_7613.pth', map_location='cpu'))
    target_model.to(device)

    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Insertion order defines the order passed to confidences_auc; ImageNet
    # stays first.
    test_sets = {
        # 'imagenet': datasets.ImageFolder(root='../datasets/imagenet_1k/val', transform=transform_test),
        'imagenet': ImageFilelist(root='../datasets/imagenet_1k', flist='./datalists/imagenet/imagenet2012_val_list.txt', transform=transform_test),
        'inaturalist': datasets.ImageFolder(root='../datasets/iNaturalist', transform=transform_test),
        'places365': datasets.ImageFolder(root='../datasets/places365', transform=transform_test),
        'sun': datasets.ImageFolder(root='../datasets/SUN', transform=transform_test),
        'texture': datasets.ImageFolder(root='../datasets/dtd/images', transform=transform_test),
    }
    test_datasets = list(test_sets)

    batch_size = 100
    T = 1
    print(f'Temperature is {T}')
    print(f'Transformation is {transformation}')
    # NOTE(review): the transformation is applied to tensors that have already
    # been Normalize()-d. Pixel-space operations (invert, grayscale, color
    # jitter) therefore act on mean/std-shifted values rather than on [0, 1]
    # pixels. Kept as-is to match earlier results; confirm this is intended.

    logits_inner_product = []
    sim_logits = []
    sim_sms = []
    target_model.eval()
    with torch.no_grad():
        for dataset in test_datasets:
            print('-' * 10 + dataset + '-' * 10)
            data_loader = DataLoader(test_sets[dataset], batch_size=batch_size, shuffle=True, num_workers=2)

            inner, sim_logit, sim_sm = _consistency_scores(target_model, data_loader, transform, device, T)
            print(inner.mean().item(), sim_logit.mean().item(), sim_sm.mean().item())

            logits_inner_product.append(inner.numpy())
            sim_logits.append(sim_logit.numpy())
            sim_sms.append(sim_sm.numpy())

        print("Detection based on logits inner product")
        confidences_auc(logits_inner_product, test_datasets)

        print("Detection based on similarity of logits")
        confidences_auc(sim_logits, test_datasets)

        print("Detection based on similarity of softmax")
        confidences_auc(sim_sms, test_datasets)

# Script entry point. Example invocation:
#   CUDA_VISIBLE_DEVICES=0 python consistency_multi.py
if __name__ == "__main__":
    main()