import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms, datasets
import torchvision.transforms.functional as TF

from models.resnet18_32x32 import ResNet18_32x32
from utils.logger import Logger
from utils.list_dataset import ImageFilelist

from utils.evaluation import confidences_auc, search_k

import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics
import random
import os
import math
import argparse
import sys

def mask_image(image, mask_size=8, stride=8):
    """Build occluded copies of a batch by zeroing one square patch at a time.

    The patch slides over the spatial dims in row-major order (all column
    offsets for the first row offset, then the next row offset, ...), starting
    at (0, 0) and stepping by ``stride``. Only positions where the whole patch
    fits inside the image are used.

    Args:
        image: 4-D tensor ``(B, C, H, W)``.
        mask_size: side length, in pixels, of the square patch set to zero.
        stride: step, in pixels, between successive patch positions.

    Returns:
        Tensor ``(B, M, C, H, W)``, where ``M`` is the number of patch
        positions. ``out[:, k]`` is ``image`` with the k-th patch zeroed.

    Raises:
        ValueError: if ``mask_size`` or ``stride`` is not positive (a
            non-positive stride would otherwise never terminate), or if the
            patch does not fit inside the image (no positions at all).
    """
    if mask_size <= 0 or stride <= 0:
        raise ValueError(
            f'mask_size and stride must be positive, got mask_size={mask_size}, stride={stride}'
        )
    _, _, height, width = image.shape
    if mask_size > height or mask_size > width:
        raise ValueError(
            f'mask_size={mask_size} does not fit in an image of size {height}x{width}'
        )

    variants = []
    for top in range(0, height - mask_size + 1, stride):
        for left in range(0, width - mask_size + 1, stride):
            # clone + slice assignment yields the same values as building a
            # full-size mask and calling masked_fill, without the extra tensors.
            occluded = image.clone()
            occluded[:, :, top:top + mask_size, left:left + mask_size] = 0
            variants.append(occluded)
    return torch.stack(variants, dim=1)

def multi_transform(img, transforms, times=50):
    """Call ``transforms`` on ``img`` ``times`` times and stack the outputs on dim 1.

    Intended for random augmentation pipelines. For a random transform, each
    call presumably draws fresh parameters, so the stacked views differ.
    Note: inside this function the parameter name shadows the module-level
    ``transforms`` import.

    Args:
        img: input passed unchanged to every call.
        transforms: callable returning a tensor; all outputs must share a shape.
        times: number of calls, i.e. the size of the new dim 1.

    Returns:
        ``torch.stack`` of the ``times`` outputs along dim 1.
    """
    views = []
    for _ in range(times):
        views.append(transforms(img))
    return torch.stack(views, dim=1)

def multi_transformations(img):
    """Apply a fixed list of augmentations to ``img`` and stack the results.

    Views, in order along the new dim 1:
      0. horizontal flip (``TF.hflip``)
      1. grayscale via ``RandomGrayscale(p=1.0)``, so always applied
      2. random resized crop, scale (0.2, 1.0), resized back to the input's H x W
      3. vertical flip (``TF.vflip``)
      4-6. ``TF.rotate`` by 90, 180 and 270 degrees

    Args:
        img: image tensor; presumably an RGB batch ``(B, 3, H, W)``. TODO confirm
            against callers (the grayscale transform needs RGB input).

    Returns:
        Tensor with the 7 views stacked along dim 1.
    """
    # Crop back to the input's own spatial size so every view has the same
    # shape and can be stacked. A hard-coded 32 broke stacking at any other
    # resolution; for 32x32 inputs this is identical.
    crop_size = tuple(img.shape[-2:])

    trans_methods = (
        TF.hflip,
        transforms.RandomGrayscale(p=1.0),
        transforms.RandomResizedCrop(crop_size, scale=(0.2, 1.)),
        TF.vflip,
        # Plain ints are rotation angles, dispatched to TF.rotate below.
        90, 180, 270,
    )

    trans_samples = []
    for tran in trans_methods:
        if isinstance(tran, int):
            trans_samples.append(TF.rotate(img, tran))
        else:
            trans_samples.append(tran(img))

    return torch.stack(trans_samples, dim=1)

def _l2_normalize(x, eps=1e-10):
    """Scale ``x`` to unit L2 norm along its last dim; ``eps`` guards zero-norm rows."""
    return x / (x.norm(dim=-1, keepdim=True) + eps)


def _extract_features(model, loader, device):
    """Collect features of clean and patch-masked images for a whole loader.

    Assumes ``model(x, return_feature=True)`` returns ``(logits, features)``
    with features shaped ``(batch, dim)``. TODO confirm against the model.

    Returns:
        ``(feats, trans_feats)``. ``feats`` is ``(N, 1, D)``, with a singleton
        dim so it broadcasts against ``trans_feats``. ``trans_feats`` is
        ``(N, M, D)``, with one row per masked copy produced by ``mask_image``.
    """
    feats = []
    trans_feats = []
    for samples, _ in loader:
        samples = samples.to(device)
        _, feat = model(samples, return_feature=True)
        feats.append(feat)

        # mask_image gives (B, M, C, H, W). Flatten to (B*M, C, H, W) so the
        # model processes all masked copies as one batch.
        # Variants tried before: mask_size/stride of 6 or 5, and a SimCLR-style
        # augmentation via multi_transform(samples, aug, 20).
        trans_samples = mask_image(samples)
        trans_samples = trans_samples.reshape(-1, *samples.shape[1:])
        _, trans_feat = model(trans_samples, return_feature=True)
        # Regroup the masked-copy features per source image: (B, M, D).
        trans_feats.append(trans_feat.view(len(samples), -1, trans_feat.shape[-1]))

    return torch.cat(feats, dim=0).unsqueeze(dim=1), torch.cat(trans_feats, dim=0)


def main():
    """Evaluate feature consistency under patch masking on several test sets.

    Each image is fed to ``target_model`` once unmasked and once per masked
    copy from ``mask_image``. Per image, the cosine similarities and negated L2
    distances (between unit-normalised features) to the unmasked feature are
    sorted in descending order. They are then passed, per dataset, to
    ``search_k`` for detection evaluation.

    ``sys.stdout`` is replaced by ``Logger('./log/cifar10_dist.txt')`` for the
    rest of the process.
    """
    # Must happen before anything touches CUDA. torch.cuda.is_available() can
    # initialise the driver, after which CUDA_VISIBLE_DEVICES is ignored.
    # setdefault keeps a value supplied on the command line.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1')

    seed = 100
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    log = Logger('./log/cifar10_dist.txt')
    sys.stdout = log

    target_model = ResNet18_32x32()
    # Load onto CPU first so GPU-saved checkpoints also work on CPU-only hosts.
    target_model.load_state_dict(torch.load('./weights/resnet18_9554.pth', map_location='cpu'))
    target_model.to(device)

    transform_test = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    # Explicit name -> dataset table, looked up by the names in test_datasets.
    all_datasets = {
        'cifar10': datasets.CIFAR10(root='../datasets', train=False, transform=transform_test, download=True),
        'cifar100': datasets.cifar.CIFAR100(root='../datasets', train=False, transform=transform_test, download=True),
        'dtd': datasets.ImageFolder(root='../datasets/dtd/images', transform=transform_test),
        'places365': datasets.ImageFolder(root='../datasets/places365', transform=transform_test),
        'svhn': datasets.SVHN(root='../datasets', split='test', transform=transform_test, download=True),
        'isun': datasets.ImageFolder(root='../datasets/iSUN', transform=transform_test),
        'lsun': datasets.ImageFolder(root='../datasets/LSUN', transform=transform_test),
    }

    batch_size = 100
    T = 1
    print(f'Temperature is {T}')
    with torch.no_grad():

        target_model.eval()

        # Passed to search_k alongside the scores; presumably the first entry is
        # treated as in-distribution. Verify against utils.evaluation.
        test_datasets = ['cifar10', 'cifar100', 'svhn', 'dtd', 'places365', 'isun', 'lsun']

        similarities = []
        distances = []
        for dataset in test_datasets:

            print('-' * 10 + dataset + '-' * 10)
            data_loader = DataLoader(all_datasets[dataset], batch_size=batch_size, shuffle=True, num_workers=2)

            feats, trans_feats = _extract_features(target_model, data_loader, device)

            # (N, M) scores: one per image per masked copy.
            sim_feats = F.cosine_similarity(feats, trans_feats, dim=-1)
            # Epsilon goes in the denominator of the normalisation, so a zero
            # feature vector cannot produce NaN.
            dist_feats = -(_l2_normalize(feats) - _l2_normalize(trans_feats)).norm(dim=-1)

            # Sort each image's scores in descending order along the mask dim.
            sim_feats = sim_feats.sort(descending=True)[0]
            dist_feats = dist_feats.sort(descending=True)[0]

            print(sim_feats.mean().item(), dist_feats.mean().item())

            similarities.append(sim_feats.cpu().numpy())
            distances.append(dist_feats.cpu().numpy())

        print("Detection based on feature similarity")
        search_k(similarities, test_datasets)

        print("Detection based on feature distances")
        search_k(distances, test_datasets)





        
        

if __name__ == "__main__":

    # Script entry point: runs the full masking-consistency evaluation.
    main()
    # Example invocation, selecting the GPU through the environment:
    # CUDA_VISIBLE_DEVICES=0 python consistency_multi.py