import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms, datasets
import torchvision.transforms.functional as TF

from models.resnet50 import ResNet50
from utils.logger import Logger
from utils.list_dataset import ImageFilelist

from utils.evaluation import confidences_auc, search_k



import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics
import random
import os
import math
import argparse
import sys
import time

def mask_image(image, mask_size=56, stride=56):
    """Return copies of a batch, each with one square patch zeroed out.

    A ``mask_size`` x ``mask_size`` window is slid over the two spatial dims with
    step ``stride`` (row-major, starting at the top-left corner). For every
    position where the window fits entirely inside the image, one copy of
    ``image`` is produced with that window set to 0; ``image`` itself is not
    modified.

    Args:
        image: 4-D tensor (b, c, h, w).
        mask_size: side length of the square patch to zero.
        stride: step between consecutive patch positions; must be positive.

    Returns:
        Tensor of shape (b, n_masks, c, h, w) with
        n_masks = (floor((h - mask_size) / stride) + 1)
                  * (floor((w - mask_size) / stride) + 1).

    Raises:
        ValueError: if ``stride`` is not positive (the sliding loop would never
            advance) or if no patch position fits inside the image.
    """
    if stride <= 0:
        raise ValueError(f'stride must be positive, got {stride}')

    _, _, h, w = image.shape
    tops = range(0, h - mask_size + 1, stride)
    lefts = range(0, w - mask_size + 1, stride)
    if not tops or not lefts:
        raise ValueError(
            f'mask_size={mask_size} does not fit in an image of size {h}x{w}')

    images = []
    for top in tops:
        for left in lefts:
            # Cloning and zeroing the slice directly avoids building a full-size
            # mask tensor and a boolean comparison for every position.
            masked = image.clone()
            masked[:, :, top:top + mask_size, left:left + mask_size] = 0.
            images.append(masked)
    return torch.stack(images, dim=1)

def multi_transform(img, transforms, times=50):
    """Call ``transforms(img)`` ``times`` times and stack the results.

    ``transforms`` is any callable (typically a random augmentation pipeline);
    note that this parameter name shadows the torchvision ``transforms`` module
    inside this function. The outputs are stacked along a new dim 1, so if each
    call returns a tensor shaped like a (b, c, h, w) ``img`` the result is
    (b, times, c, h, w).
    """
    views = []
    for _ in range(times):
        views.append(transforms(img))
    return torch.stack(views, dim=1)

def multi_transformations(img, crop_size=None):
    """Return a fixed set of transformed views of ``img`` stacked along dim 1.

    Views, in order: horizontal flip, grayscale (p=1.0), a random resized crop
    (scale 0.2-1.0), vertical flip, and rotations by 90, 180 and 270 degrees.

    Args:
        img: image tensor, presumably (b, c, h, w) — TODO confirm against callers.
        crop_size: output size of the random resized crop. Defaults to the
            input's spatial size so every view has the same shape and can be
            stacked; a fixed size only stacks when it equals the input size
            (the previous hard-coded 32 therefore only worked for 32x32 input).

    Returns:
        The views stacked along a new dim 1 (7 views).
    """
    if crop_size is None:
        crop_size = tuple(img.shape[-2:])

    trans_methods = (
        TF.hflip,
        transforms.RandomGrayscale(p=1.0),
        # transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=1.0),
        transforms.RandomResizedCrop(crop_size, scale=(0.2, 1.)),
        TF.vflip,
        # Bind each angle as a default argument to avoid late-binding closures.
        *(lambda x, angle=angle: TF.rotate(x, angle) for angle in (90, 180, 270)),
    )
    return torch.stack([tran(img) for tran in trans_methods], dim=1)

def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _consistency_scores(model, data_loader, device, T):
    """Score each sample by how its model output agrees with its masked copies.

    For every batch, ``model`` is run on the clean images and on the
    ``mask_image`` copies (44x44 patches, stride 44). The scores are computed
    per batch on CPU, so only (N, n_masks) results are kept in memory instead
    of the full (N, n_masks, num_classes) logit/softmax tensors.

    Returns:
        (inner_product, sim_logit, sim_sm): three (N, n_masks) CPU tensors.
        They hold, respectively, the dot product of masked and clean logits,
        the cosine similarity of the logits, and the cosine similarity of
        softmax(logits / T).
    """
    inner_products, sim_logits, sim_sms = [], [], []
    for samples, _ in data_loader:
        samples = samples.to(device)
        samples_num = len(samples)

        logit, _ = model(samples, return_feature=True)
        sm = F.softmax(logit / T, dim=-1)

        # mask_image returns (b, n_masks, c, h, w); fold the view axis into the
        # batch so the model gets an ordinary 4-D batch, like `samples` above.
        trans_samples = mask_image(samples, mask_size=44, stride=44)
        if trans_samples.dim() == 5:
            trans_samples = trans_samples.flatten(0, 1)
        # One forward pass over the masked batch.
        trans_logit, _ = model(trans_samples, return_feature=True)
        trans_sm = F.softmax(trans_logit / T, dim=-1)

        num_classes = logit.shape[-1]
        logit, sm = logit.cpu(), sm.cpu()
        # The folded batch is sample-major, so this restores (b, n_masks, C).
        trans_logit = trans_logit.reshape(samples_num, -1, num_classes).cpu()
        trans_sm = trans_sm.reshape(samples_num, -1, num_classes).cpu()

        # squeeze(-1), not squeeze(): keep (b, n_masks) even when b == 1.
        inner_products.append(
            torch.bmm(trans_logit, logit.unsqueeze(-1)).squeeze(-1))
        sim_logits.append(
            F.cosine_similarity(logit.unsqueeze(1), trans_logit, dim=-1))
        sim_sms.append(
            F.cosine_similarity(sm.unsqueeze(1), trans_sm, dim=-1))

    return torch.cat(inner_products), torch.cat(sim_logits), torch.cat(sim_sms)


def main():
    """Run masking-consistency OOD scoring with a ResNet-50 and report results.

    ImageNet (val list), iNaturalist, Places365, SUN and DTD textures are each
    scored with ``_consistency_scores``. For every scoring rule, the
    per-dataset scores are passed to ``search_k`` together with the dataset
    names (ImageNet first). ``sys.stdout`` is replaced by a ``Logger`` for
    './log/imagenet.txt' and is not restored.
    """
    # CUDA reads CUDA_VISIBLE_DEVICES only when it is first initialised, so it
    # must be set before any torch.cuda call (including is_available() in
    # _seed_everything). setdefault keeps a value given on the command line.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1')
    # torch.cuda.set_device(torch.device("cuda:1"))

    _seed_everything(100)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    log = Logger('./log/imagenet.txt')
    sys.stdout = log

    target_model = ResNet50()
    # Load onto CPU first so a GPU-saved checkpoint also works on CPU-only hosts.
    state_dict = torch.load('./weights/resnet50_7613.pth', map_location='cpu')
    target_model.load_state_dict(state_dict)
    target_model.to(device)
    target_model.eval()

    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Explicit name -> dataset mapping (instead of eval() on the name string).
    # The insertion order is the order in which results are reported.
    test_sets = {
        # datasets.ImageFolder(root='../datasets/imagenet_1k/val', transform=transform_test)
        'imagenet': ImageFilelist(root='../datasets/imagenet_1k',
                                  flist='./datalists/imagenet/imagenet2012_val_list.txt',
                                  transform=transform_test),
        'inaturalist': datasets.ImageFolder(root='../datasets/iNaturalist', transform=transform_test),
        'places365': datasets.ImageFolder(root='../datasets/places365', transform=transform_test),
        'sun': datasets.ImageFolder(root='../datasets/SUN', transform=transform_test),
        'texture': datasets.ImageFolder(root='../datasets/dtd/images', transform=transform_test),
    }
    test_datasets = list(test_sets)

    batch_size = 32
    T = 1
    print(f'Temperature is {T}')

    logits_inner_product = []
    sim_logits = []
    sim_sms = []
    sim_combines_03 = []
    sim_combines_05 = []
    sim_combines_07 = []

    with torch.no_grad():
        for dataset in test_datasets:
            print('-'*10+dataset+'-'*10)
            data_loader = DataLoader(test_sets[dataset], batch_size=batch_size,
                                     shuffle=True, num_workers=2)

            inner_logits_product, sim_logit, sim_sm = _consistency_scores(
                target_model, data_loader, device, T)

            # Convex mixes of the two similarities; the suffix is the weight
            # given to sim_logit.
            sim_combine_03 = sim_logit * 0.3 + 0.7 * sim_sm
            sim_combine_05 = sim_logit * 0.5 + 0.5 * sim_sm
            sim_combine_07 = sim_logit * 0.7 + 0.3 * sim_sm

            print(inner_logits_product.mean().item(), sim_logit.mean().item(), sim_sm.mean().item())
            print(sim_combine_03.mean().item(), sim_combine_05.mean().item(), sim_combine_07.mean().item())

            # Tensor.sort() works on the last dim, so each sample's per-mask
            # scores are ordered from highest to lowest before search_k.
            for store, score in ((logits_inner_product, inner_logits_product),
                                 (sim_logits, sim_logit),
                                 (sim_sms, sim_sm),
                                 (sim_combines_03, sim_combine_03),
                                 (sim_combines_05, sim_combine_05),
                                 (sim_combines_07, sim_combine_07)):
                store.append(score.sort(descending=True)[0].numpy())

        # confidences_auc(scores, test_datasets) is an alternative report for
        # any of the score lists below.
        print("Detection based on logits inner product")
        search_k(logits_inner_product, test_datasets)

        print("Detection based on similarity of logits")
        search_k(sim_logits, test_datasets)

        print("Detection based on similarity of softmax")
        search_k(sim_sms, test_datasets)

        print("Detection based on conbined similarity")
        print(f"==============Alpha is 0.3")
        search_k(sim_combines_03, test_datasets)

        print(f"==============Alpha is 0.5")
        search_k(sim_combines_05, test_datasets)

        print(f"==============Alpha is 0.7")
        search_k(sim_combines_07, test_datasets)





        
        

if __name__ == "__main__":
    # Example invocation:
    #   CUDA_VISIBLE_DEVICES=0 python consistency_multi.py
    main()