import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms, datasets
import torchvision.transforms.functional as TF

from models.resnet18_32x32 import ResNet18_32x32
from utils.logger import Logger
from utils.list_dataset import ImageFilelist
from postprocessors import get_postprocessor

from utils.evaluation import confidences_auc, search_k

import numpy as np
import random
import os
import math
import sys

def mask_image(image, mask_size=8, stride=8):
    """Build one occluded copy of ``image`` per position of a sliding square mask.

    The square mask (``mask_size`` x ``mask_size``) moves in steps of
    ``stride`` over the last two dimensions, row by row: the outer loop is over
    height and the inner loop is over width. Only positions where the mask
    lies completely inside the image are used. Pixels under the mask are set
    to 0; every other pixel is copied unchanged.

    Args:
        image: 4-D tensor, unpacked as ``(b, c, h, w)``.
        mask_size: side length of the square mask.
        stride: step between consecutive mask positions.

    Returns:
        Tensor of shape ``(b, n_positions, c, h, w)``, with the occluded copies
        stacked along dimension 1.
    """
    _, _, height, width = image.shape
    # Top-left corners at which the mask still fits inside the image.
    tops = range(0, height - mask_size + 1, stride)
    lefts = range(0, width - mask_size + 1, stride)

    occluded = []
    for top in tops:
        for left in lefts:
            region = torch.zeros_like(image)
            region[:, :, top:top + mask_size, left:left + mask_size] = 1
            occluded.append(image.masked_fill(region == 1, 0.))
    return torch.stack(occluded, dim=1)

def multi_transform(img, transforms, times=50):
    """Apply the callable ``transforms`` to ``img`` ``times`` times and stack the results.

    Each call is independent, so a random augmentation pipeline yields
    ``times`` different views. Inside this function the parameter name
    ``transforms`` shadows the module-level ``torchvision.transforms`` import.

    Returns:
        Tensor with the ``times`` views stacked along a new dimension 1.
    """
    views = []
    for _ in range(times):
        views.append(transforms(img))
    return torch.stack(views, dim=1)

def mask_img(img, mask_top=0, mask_left=0, mask_size=8):
    """Return a copy of ``img`` with one square patch set to 0.

    The patch covers rows ``mask_top:mask_top + mask_size`` and columns
    ``mask_left:mask_left + mask_size`` of the last two dimensions, for every
    index of the first two dimensions. ``img`` itself is not modified.

    Args:
        img: 4-D tensor indexed as ``img[:, :, row, col]``.
        mask_top: first masked row (default 0).
        mask_left: first masked column (default 0).
        mask_size: side length of the square patch.

    Returns:
        A new tensor with the same shape and dtype as ``img``.
    """
    # Slice bounds must be integers. The former float defaults (``0.``) made
    # any call that relied on them raise TypeError; integral floats passed by
    # callers are still accepted.
    top, left, size = int(mask_top), int(mask_left), int(mask_size)
    masked = img.clone()
    masked[:, :, top:top + size, left:left + size] = 0
    return masked

def center_mask(img, mask_size=8):
    """Zero out a ``mask_size`` x ``mask_size`` patch centred in ``img``.

    On each axis the patch offset is ``(extent - mask_size) / 2``, rounded with
    Python's ``round`` (exact halves go to the even neighbour). The masking
    itself is delegated to ``mask_img``.

    Args:
        img: 4-D tensor, unpacked as ``(b, c, h, w)``.
        mask_size: side length of the square patch.

    Returns:
        The masked tensor returned by ``mask_img``.
    """
    def _centered_offset(extent):
        # Leading offset that centres the patch along one spatial axis.
        return int(round((extent - mask_size) / 2.0))

    _, _, height, width = img.shape
    return mask_img(img, _centered_offset(height), _centered_offset(width), mask_size)

def test_time_augmentation(samples):
    """Build four augmented views of ``samples`` and concatenate them along dim 1.

    Views, in order:
      1. horizontal flip,
      2. grayscale (``RandomGrayscale`` with ``p=1.0``, so always applied),
      3. a centred 4x4 patch zeroed via ``center_mask``,
      4. a 30-pixel centre crop resized back to 32x32.

    The views are concatenated with ``torch.cat`` (not stacked). For a
    ``(b, c, h, w)`` input the result is ``(b, 4 * c, h, w)``.

    NOTE(review): view 4 is always 32x32, so the concatenation only succeeds
    for 32x32 inputs.
    """
    crop_and_resize = transforms.Compose([
        transforms.CenterCrop(size=30),
        transforms.Resize((32, 32)),
    ])
    views = [
        TF.hflip(samples),
        transforms.RandomGrayscale(p=1.0)(samples),
        center_mask(samples, mask_size=4),
        crop_and_resize(samples),
    ]
    return torch.cat(views, dim=1)


def multi_transformations(img):
    """Return seven transformed views of ``img`` stacked along a new dimension 1.

    Views, in order:
      1. horizontal flip,
      2. grayscale (``RandomGrayscale`` with ``p=1.0``),
      3. random resized crop (scale 0.2-1.0),
      4. vertical flip,
      5-7. rotations by 90, 180 and 270 degrees.

    The random crop is resized back to ``img``'s own spatial size. Previously
    that size was hard-coded to 32, which made the final ``torch.stack`` fail
    for any other input size. For 32x32 inputs the output is unchanged.
    ``TF.rotate`` keeps the input size (``expand=False``), so all seven views
    share one shape.

    Args:
        img: image tensor whose last two dims are ``(h, w)``. Presumably a batch
            ``(b, c, h, w)`` as elsewhere in this file (TODO: confirm).

    Returns:
        Tensor with the seven views stacked along dimension 1.
    """
    height, width = img.shape[-2], img.shape[-1]
    # Commented-out alternative kept for experiments:
    # transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=1.0)
    callables = (
        TF.hflip,
        transforms.RandomGrayscale(p=1.0),
        transforms.RandomResizedCrop((height, width), scale=(0.2, 1.)),
        TF.vflip,
    )
    trans_samples = [tran(img) for tran in callables]
    # Rotations come last, matching the original ordering of the views.
    trans_samples.extend(TF.rotate(img, angle) for angle in (90, 180, 270))

    return torch.stack(trans_samples, dim=1)

def _seed_everything(seed):
    """Seed Python, NumPy and torch RNGs and switch cuDNN to deterministic mode."""
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _build_test_datasets(transform, root='../datasets'):
    """Construct every evaluation split, keyed by the name used in the logs.

    Construction order matches the original script, so any downloads happen
    in the same order.
    """
    return {
        'cifar10': datasets.CIFAR10(root=root, train=False, transform=transform, download=True),
        'cifar100': datasets.CIFAR100(root=root, train=False, transform=transform, download=True),
        'dtd': datasets.ImageFolder(root=f'{root}/dtd/images', transform=transform),
        'places365': datasets.ImageFolder(root=f'{root}/places365', transform=transform),
        'svhn': datasets.SVHN(root=root, split='test', transform=transform, download=True),
        'isun': datasets.ImageFolder(root=f'{root}/iSUN', transform=transform),
        'lsun': datasets.ImageFolder(root=f'{root}/LSUN', transform=transform),
    }


def _consistency_scores(model, loader, device, temperature):
    """Measure how consistent the model's output is when parts of each input are masked.

    For every batch, the model runs on the clean samples and on every occluded
    copy produced by ``mask_image`` (K copies per sample). Returns three
    ``(N, K)`` tensors over the N samples in ``loader``:
      * the inner product of each masked copy's logits with the clean logits,
      * the cosine similarity of those logits,
      * the cosine similarity of the temperature-scaled softmax outputs.
    """
    logits, sms, trans_logits, trans_sms = [], [], [], []
    for samples, _ in loader:
        samples = samples.to(device)
        samples_num = len(samples)

        logit, _ = model(samples, return_feature=True)
        num_classes = logit.shape[-1]
        logits.append(logit)
        sms.append(F.softmax(logit / temperature, dim=-1))

        # mask_image stacks copies on dim 1: (b, K, c, h, w) -> (b*K, c, h, w).
        # Alternatives tried: mask_image(samples, mask_size=7, stride=7),
        # mask_image(samples, mask_size=5, stride=5), test_time_augmentation(samples),
        # multi_transform(samples, <SimCLR-style augmentation>, 16).
        trans_samples = mask_image(samples)
        trans_samples = trans_samples.reshape(-1, *samples.shape[1:])
        trans_logit, _ = model(trans_samples, return_feature=True)
        trans_sm = F.softmax(trans_logit / temperature, dim=-1)
        trans_logits.append(trans_logit.view(samples_num, -1, num_classes))
        trans_sms.append(trans_sm.view(samples_num, -1, num_classes))

    logits = torch.cat(logits, dim=0)
    sms = torch.cat(sms, dim=0)
    trans_logits = torch.cat(trans_logits, dim=0)
    trans_sms = torch.cat(trans_sms, dim=0)

    # (N, K, C) @ (N, C, 1) -> (N, K, 1). Squeeze only the trailing dim so that
    # N == 1 or K == 1 still yields an (N, K) tensor.
    inner = torch.bmm(trans_logits, logits.unsqueeze(dim=-1)).squeeze(-1)
    sim_logit = F.cosine_similarity(logits.unsqueeze(dim=1), trans_logits, dim=-1)
    sim_sm = F.cosine_similarity(sms.unsqueeze(dim=1), trans_sms, dim=-1)
    return inner, sim_logit, sim_sm


def main():
    """Run the masking-consistency evaluation and print the detection results.

    Steps:
      1. Load the ResNet-18 weights from ``./weights/resnet18_9554.pth``.
      2. Compute the three scores from ``_consistency_scores`` for CIFAR-10 and
         each other test set.
      3. Pass each score family to ``search_k``.

    Output is printed through ``Logger('./log/cifar10.txt')``, which replaces
    ``sys.stdout``.
    """
    # This must run before any CUDA query: once the runtime has been
    # initialised (torch.cuda.is_available() can do that), later changes to
    # CUDA_VISIBLE_DEVICES are ignored. setdefault keeps a value supplied on
    # the command line, e.g. ``CUDA_VISIBLE_DEVICES=0 python ...``.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1')
    _seed_everything(100)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    log = Logger('./log/cifar10.txt')
    sys.stdout = log

    target_model = ResNet18_32x32()
    # map_location lets GPU-saved weights load on a CPU-only machine.
    target_model.load_state_dict(torch.load('./weights/resnet18_9554.pth', map_location=device))
    target_model.to(device)
    target_model.eval()

    transform_test = transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    all_datasets = _build_test_datasets(transform_test)

    batch_size = 100
    T = 1
    print(f'Temperature is {T}')

    test_datasets = ['cifar10', 'cifar100', 'svhn', 'dtd', 'places365', 'isun', 'lsun']
    logits_inner_product = []
    sim_logits = []
    sim_sms = []
    with torch.no_grad():
        for dataset in test_datasets:
            print('-'*10+dataset+'-'*10)
            # Explicit lookup instead of eval() on the name.
            data_loader = DataLoader(all_datasets[dataset], batch_size=batch_size, shuffle=True, num_workers=2)

            inner, sim_logit, sim_sm = _consistency_scores(target_model, data_loader, device, T)
            print(inner.mean().item(), sim_logit.mean().item(), sim_sm.mean().item())

            # sort() works along the last dim: each sample's K scores are
            # ordered high to low (presumably so search_k can pick the k-th
            # largest; TODO: confirm against utils.evaluation).
            for store, scores in ((logits_inner_product, inner),
                                  (sim_logits, sim_logit),
                                  (sim_sms, sim_sm)):
                store.append(scores.sort(descending=True)[0].cpu().numpy())

    # confidences_auc(scores, test_datasets) is an alternative to search_k here.
    for title, scores in (("Detection based on logits inner product", logits_inner_product),
                          ("Detection based on similarity of logits", sim_logits),
                          ("Detection based on similarity of softmax", sim_sms)):
        print(title)
        search_k(scores, test_datasets)




        
        

if __name__ == "__main__":
    # Typical invocation: CUDA_VISIBLE_DEVICES=0 python consistency_multi.py
    main()