import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms, datasets
import torchvision.transforms.functional as TF

from models.resnet18_32x32 import ResNet18_32x32
from utils.logger import Logger
from utils.list_dataset import ImageFilelist
from postprocessors import get_postprocessor

from utils.evaluation import confidences_auc, search_k

import numpy as np
import random
import os
import math
import sys

def mask_image(image, mask_size=8, stride=8):
    """Build one masked copy of ``image`` per sliding-window position.

    A ``mask_size`` x ``mask_size`` square is moved over the two trailing
    (height, width) axes in steps of ``stride``, rows first and then columns.
    For each position a copy of ``image`` is made in which that square is set
    to zero. The input tensor itself is never modified.

    Args:
        image: 4-D tensor laid out as (batch, channels, height, width).
        mask_size: Side length of the zeroed square.
        stride: Step between consecutive window positions; must be positive.

    Returns:
        Tensor of shape (batch, n_windows, channels, height, width) holding
        the masked copies stacked along dimension 1.

    Raises:
        ValueError: If ``stride`` is not positive (the window could never
            advance).
        RuntimeError: Raised by ``torch.stack`` when no window fits inside
            the image, because there is then nothing to stack.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    _, _, h, w = image.shape

    images = []
    # Only positions where the whole window lies inside the image are used.
    for top in range(0, h - mask_size + 1, stride):
        for left in range(0, w - mask_size + 1, stride):
            # Zero the window on a copy rather than allocating a full-size
            # mask tensor for every position.
            masked = image.clone()
            masked[:, :, top:top + mask_size, left:left + mask_size] = 0.
            images.append(masked)

    return torch.stack(images, dim=1)

def mask_img(img, mask_top=0., mask_left=0., mask_size=8):
    """Return a copy of ``img`` with one square region set to zero.

    Args:
        img: Tensor with at least four dimensions; the square is cut out of
            dimensions 2 and 3 and applied across dimensions 0 and 1.
        mask_top: Index of the first zeroed row.
        mask_left: Index of the first zeroed column.
        mask_size: Side length of the zeroed square.

    Returns:
        A new tensor shaped like ``img``; ``img`` itself is left untouched.

    Note:
        The offsets and the size are converted with ``int()`` (fractional
        values are truncated) because slice bounds must be integers. Without
        this the float defaults ``0.`` would raise ``TypeError`` as soon as
        the function is called without explicit offsets.
    """
    top, left, size = int(mask_top), int(mask_left), int(mask_size)

    masked = img.clone()
    masked[:, :, top:top + size, left:left + size] = 0.
    return masked

def center_mask(img, mask_size=8):
    """Zero out a ``mask_size`` square placed at the centre of ``img``.

    Args:
        img: 4-D tensor whose last two dimensions are height and width.
        mask_size: Side length of the square to blank out.

    Returns:
        Whatever ``mask_img`` returns for the centred square.
    """
    height, width = img.shape[2:]
    # Top-left corner that leaves equal margins on both sides of the square.
    top, left = (
        int(round((extent - mask_size) / 2.0)) for extent in (height, width)
    )
    return mask_img(img, mask_top=top, mask_left=left, mask_size=mask_size)

def _transform_batch(samples):
    """Return the transformed copy of ``samples`` that is fed to the model.

    Exactly one transformation is active at a time. The alternatives used in
    this experiment are listed below; replace the return expression to
    switch.

    NOTE(review): the batches arriving here have already been mean/std
    normalised by the test transform, so they are no longer in [0, 1].
    Confirm that applying image-space operations such as invert or colour
    jitter to normalised tensors is intended.
    """
    # ==================== ID Transformation =========================
    #   Hflip:        TF.hflip(samples)
    #   Gray:         transforms.RandomGrayscale(p=1.0)(samples)
    #   Central Mask: center_mask(samples, mask_size=4)
    #   Central Crop: transforms.CenterCrop(size=30)(samples)
    # ==================== OOD Transformation ========================
    #   Vflip:        TF.vflip(samples)
    #   Rotate:       TF.rotate(samples, 270)
    #   ColorJitter:  transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)(samples)
    #   Invert:       TF.invert(samples)   <- active
    return TF.invert(samples)


def _collect_outputs(model, data_loader, device, T):
    """Run ``model`` on every batch and on its transformed counterpart.

    Args:
        model: Network mapping an image batch to logits.
        data_loader: Iterable yielding ``(samples, labels)`` pairs; the
            labels are ignored.
        device: Device the batches are moved to before the forward pass.
        T: Temperature dividing the logits before softmax / logsumexp.

    Returns:
        Tuple ``(logits, sms, trans_logits, trans_sms, trans_scores)``, each
        concatenated over all batches along dimension 0. ``sms`` are softmax
        outputs and ``trans_scores`` is the logsumexp of the temperature-scaled
        logits of the transformed batch.
    """
    logits, sms = [], []
    trans_logits, trans_sms, trans_scores = [], [], []

    for samples, _ in data_loader:
        samples = samples.to(device)

        logit = model(samples)
        logits.append(logit)
        sms.append(F.softmax(logit / T, dim=-1))

        trans_logit = model(_transform_batch(samples))
        trans_logits.append(trans_logit)
        trans_sms.append(F.softmax(trans_logit / T, dim=-1))
        trans_scores.append(torch.logsumexp(trans_logit / T, dim=-1))

    return tuple(
        torch.cat(chunks, dim=0)
        for chunks in (logits, sms, trans_logits, trans_sms, trans_scores)
    )


def main():
    """Score several test sets by how a model reacts to a transformation.

    Loads a ``ResNet18_32x32`` from ``./weights/resnet18_9554.pth`` and, for
    each test set, runs it on the original and the transformed images. Four
    per-sample scores are collected: the inner product of the two logit
    vectors, the cosine similarity of the logits, the cosine similarity of
    the softmax outputs, and the logsumexp of the transformed logits. Each
    score family is then handed to ``confidences_auc`` together with the
    list of dataset names.

    Side effects: seeds the RNGs, sets ``CUDA_VISIBLE_DEVICES`` if it is not
    already set, and redirects ``sys.stdout`` to ``./log/tta_c10.txt`` via
    ``Logger`` for the rest of the process.
    """
    # The variable has to be in place before the first CUDA query to have an
    # effect, so it is set ahead of torch.cuda.is_available(). setdefault
    # keeps a value the caller exported on the command line.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1')

    seed = 100
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Everything printed from here on goes through the project logger.
    log = Logger('./log/tta_c10.txt')
    sys.stdout = log

    target_model = ResNet18_32x32()
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    target_model.load_state_dict(
        torch.load('./weights/resnet18_9554.pth', map_location=device)
    )
    target_model.to(device)

    transform_test = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    cifar10 = datasets.CIFAR10(root='../datasets', train=False, transform=transform_test, download=True)
    cifar100 = datasets.cifar.CIFAR100(root='../datasets', train=False, transform=transform_test, download=True)
    dtd = datasets.ImageFolder(root='../datasets/dtd/images', transform=transform_test)
    places365 = datasets.ImageFolder(root='../datasets/places365', transform=transform_test)
    svhn = datasets.SVHN(root='../datasets', split='test', transform=transform_test, download=True)
    isun = datasets.ImageFolder(root='../datasets/iSUN', transform=transform_test)
    lsun = datasets.ImageFolder(root='../datasets/LSUN', transform=transform_test)

    # Explicit name -> dataset table instead of eval() on the name string.
    dataset_by_name = {
        'cifar10': cifar10,
        'cifar100': cifar100,
        'svhn': svhn,
        'dtd': dtd,
        'places365': places365,
        'isun': isun,
        'lsun': lsun,
    }

    batch_size = 100
    T = 1
    print(f'Temperature is {T}')
    with torch.no_grad():

        target_model.eval()

        # Evaluation order; the same list is passed to confidences_auc.
        test_datasets = ['cifar10', 'cifar100', 'svhn', 'dtd', 'places365', 'isun', 'lsun']
        # test_datasets = ['cifar10', 'dtd']
        logits_inner_product = []
        sim_logits = []
        sim_sms = []
        dataset_scores = []
        for dataset in test_datasets:

            print('-'*10+dataset+'-'*10)
            data_loader = DataLoader(dataset_by_name[dataset], batch_size=batch_size, shuffle=True, num_workers=2)

            logits, sms, trans_logits, trans_sms, trans_scores = _collect_outputs(
                target_model, data_loader, device, T
            )

            # Per-sample dot product of the two logit vectors. reshape(-1)
            # keeps the result 1-D even when a dataset holds a single sample
            # (a bare squeeze() would collapse it to a 0-d tensor).
            inner_logits_product = torch.bmm(
                trans_logits.unsqueeze(dim=1), logits.unsqueeze(dim=-1)
            ).reshape(-1)
            sim_logit = F.cosine_similarity(logits, trans_logits, dim=-1)
            sim_sm = F.cosine_similarity(sms, trans_sms, dim=-1)

            print(inner_logits_product.mean().item(), sim_logit.mean().item(), sim_sm.mean().item(), trans_scores.mean().item())

            logits_inner_product.append(inner_logits_product.cpu().numpy())
            sim_logits.append(sim_logit.cpu().numpy())
            sim_sms.append(sim_sm.cpu().numpy())
            dataset_scores.append(trans_scores.cpu().numpy())

        print("Detection based on logits inner product")
        confidences_auc(logits_inner_product, test_datasets)

        print("Detection based on similarity of logits")
        confidences_auc(sim_logits, test_datasets)

        print("Detection based on similarity of softmax")
        confidences_auc(sim_sms, test_datasets)

        print("Detection based on trans energy")
        confidences_auc(dataset_scores, test_datasets)

        

# Run the experiment only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":

    main()
    # Example invocation:
    #   CUDA_VISIBLE_DEVICES=0 python consistency_multi.py