import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms, datasets
import torchvision.transforms.functional as TF

from models.resnet50 import ResNet50
from utils.logger import Logger
from utils.list_dataset import ImageFilelist


import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics
import random
import os
import math
import argparse
import sys

def normalizer(x):
    """L2-normalise ``x`` along its last axis.

    The small epsilon is added to the norm (the denominator) so that
    all-zero vectors map to zeros instead of producing NaN from 0/0.

    Args:
        x: numpy array (or array-like) whose last axis holds the vectors.

    Returns:
        Array of the same shape as ``x`` with each last-axis vector
        scaled to (approximately) unit Euclidean length.
    """
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + 1e-10)

def mask_image(image, mask_size=56, stride=56):
    """Build occluded copies of a batch, one per sliding square window.

    A ``mask_size`` x ``mask_size`` window is slid over the spatial
    dimensions in steps of ``stride`` (rows first, then columns). For every
    window position a copy of ``image`` is produced with that window set
    to zero.

    Args:
        image: tensor of shape (b, c, h, w).
        mask_size: side length of the zeroed square window.
        stride: step between consecutive window positions; must be positive.

    Returns:
        Tensor of shape (b, n_windows, c, h, w), where the windows are
        ordered row-major (top-to-bottom, then left-to-right).

    Raises:
        ValueError: if ``stride`` is not positive (the window would never
            advance).
        RuntimeError: raised by ``torch.stack`` when no window fits, i.e.
            ``mask_size`` exceeds the image height or width.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    _, _, h, w = image.shape

    occluded_views = []
    for top in range(0, h - mask_size + 1, stride):
        for left in range(0, w - mask_size + 1, stride):
            # Clone so the caller's tensor is never modified, then blank
            # out just the current window.
            view = image.clone()
            view[:, :, top:top + mask_size, left:left + mask_size] = 0.
            occluded_views.append(view)

    return torch.stack(occluded_views, dim=1)

def multi_transform(img, transforms, times=50):
    """Apply ``transforms`` to ``img`` repeatedly and stack the results.

    Args:
        img: input passed unchanged to every call of ``transforms``.
        transforms: callable returning a tensor for ``img``.
        times: number of times the callable is invoked.

    Returns:
        The ``times`` outputs stacked along a new dimension 1.
    """
    views = []
    for _ in range(times):
        views.append(transforms(img))
    return torch.stack(views, dim=1)

def confidences_auc(confidences, datasets):
    """Print the AUROC of the in-distribution set against every OOD set.

    The first entry of ``confidences`` / ``datasets`` is treated as the
    in-distribution data; every following entry is an OOD dataset.

    Args:
        confidences: sequence of per-dataset confidence arrays. The arrays
            may have different lengths, so they are converted one by one
            rather than packed into a single (ragged) numpy array.
        datasets: sequence of dataset names aligned with ``confidences``.

    Returns:
        None. Results are printed.
    """
    num_ood = len(datasets) - 1
    if num_ood <= 0:
        # Nothing to compare against; avoid dividing by zero below.
        print("No OOD datasets to evaluate.")
        return

    id_confi = np.asarray(confidences[0])

    total = 0.
    for ood_confi, dataset in zip(confidences[1:], datasets[1:]):
        auroc, _ = auc(id_confi, np.asarray(ood_confi))
        total += auroc
        print(f"For {dataset}, AUC: {auroc}")
    print(f"Average is: {total / num_ood}")

def search_k(confidences, datasets):
    """Print per-column AUROC of the in-distribution set against OOD sets.

    The first entry of ``confidences`` / ``datasets`` is treated as the
    in-distribution data. For every column index ``i`` of the confidence
    matrices (reported as ``K = i + 1``), the AUROC against each OOD
    dataset and the average over OOD datasets are printed.

    Args:
        confidences: sequence of per-dataset 2-D arrays of shape
            (n_samples, n_columns). ``n_samples`` may differ between
            datasets, so each array is converted individually instead of
            building one ragged numpy array.
        datasets: sequence of dataset names aligned with ``confidences``.

    Returns:
        None. Results are printed.
    """
    num_ood = len(datasets) - 1
    if num_ood <= 0:
        # Nothing to compare against; avoid dividing by zero below.
        print("No OOD datasets to evaluate.")
        return

    id_confi = np.asarray(confidences[0])
    ood_confis = [np.asarray(c) for c in confidences[1:]]

    for k in range(id_confi.shape[1]):

        total = 0.0
        print(f"-------------- K is {k+1} ----------------")
        for ood_confi, dataset in zip(ood_confis, datasets[1:]):
            auroc, _ = auc(id_confi[:, k], ood_confi[:, k])
            total += auroc
            print(f"For {dataset}, AUC: {auroc}")
        print(f'Average is : {total / num_ood}')

def auc(ind_conf, ood_conf):
    """Compute AUROC and AUPR-in for separating ID from OOD confidences.

    In-distribution samples are labelled 1 and OOD samples 0; the
    confidence values are used directly as the ranking score.

    Args:
        ind_conf: 1-D array of confidences for in-distribution samples.
        ood_conf: 1-D array of confidences for OOD samples.

    Returns:
        Tuple ``(auroc, aupr_in)``: area under the ROC curve and area under
        the precision-recall curve with in-distribution as the positive
        class.
    """
    conf = np.concatenate((ind_conf, ood_conf))
    ind_indicator = np.concatenate(
        (np.ones_like(ind_conf), np.zeros_like(ood_conf)))

    fpr, tpr, _ = metrics.roc_curve(ind_indicator, conf)
    precision_in, recall_in, _ = metrics.precision_recall_curve(
        ind_indicator, conf)

    auroc = metrics.auc(fpr, tpr)
    aupr_in = metrics.auc(recall_in, precision_in)

    return auroc, aupr_in

def multi_transformations(img):
    """Stack a fixed set of seven transformed views of ``img``.

    The views are, in order: horizontal flip, grayscale, random resized
    crop to 32, vertical flip, and rotations by 90, 180 and 270 degrees.
    A ColorJitter view is deliberately not part of the set.

    NOTE(review): ``torch.stack`` needs every view to share one shape, and
    the crop always outputs 32x32, so this presumably expects inputs with
    32x32 spatial size - confirm against callers.

    Args:
        img: image tensor accepted by the torchvision transforms used here.

    Returns:
        The seven views stacked along a new dimension 1.
    """
    to_gray = transforms.RandomGrayscale(p=1.0)
    random_crop = transforms.RandomResizedCrop(32, scale=(0.2, 1.))

    # The list literal is evaluated left to right, which keeps the order in
    # which the random transforms draw from the RNG.
    views = [
        TF.hflip(img),
        to_gray(img),
        random_crop(img),
        TF.vflip(img),
    ]
    views.extend(TF.rotate(img, angle) for angle in (90, 180, 270))

    return torch.stack(views, dim=1)

def main():
    """Evaluate OOD detection from feature consistency under masking.

    For ImageNet (in-distribution) and four OOD datasets, every image is
    passed through the model once unmodified and once per masked copy
    produced by ``mask_image``. The negative Euclidean distance between the
    normalised clean feature and each normalised masked feature is used as
    the confidence; per sample these are sorted in descending order and
    handed to ``search_k``, which reports AUROC per sorted position.

    Side effects: seeds the RNGs, sets ``CUDA_VISIBLE_DEVICES``, redirects
    ``sys.stdout`` to a ``Logger`` writing ``./log/imagenet_dist.txt`` and
    reads weights/datasets from the hard-coded relative paths below.
    """
    seed = 100
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # NOTE(review): this is set after torch.cuda.is_available() has already
    # been called above; whether it still takes effect depends on the
    # PyTorch/CUDA initialisation order - confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    log = Logger('./log/imagenet_dist.txt')
    sys.stdout = log

    # Alternative backbones used previously:
    #   ResNet50()      with './weights/resnet50_7613.pth' or
    #                   './weights/simsiam_imgnet.pth'
    #   WideResnet101() with './weights/wide_resnet101_2_8251.pth'
    # NOTE(review): VIT_B_16 is not imported in the visible part of this
    # file (only ResNet50 is) - add the matching import or this raises
    # NameError.
    target_model = VIT_B_16()
    # map_location='cpu' lets the checkpoint load on machines without a GPU;
    # the model is moved to the selected device right after.
    target_model.load_state_dict(
        torch.load('./weights/vit_b_16-81072.pth', map_location='cpu'))

    target_model.to(device)

    transform_test = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Name -> dataset, in evaluation order. The first entry is the
    # in-distribution set; the rest are OOD sets.
    test_sets = {
        'imagenet': ImageFilelist(root='../datasets/imagenet_1k', flist='./datalists/imagenet/imagenet2012_val_list.txt',  transform=transform_test),
        'inaturalist': datasets.ImageFolder(root='../datasets/iNaturalist',  transform=transform_test),
        'places365': datasets.ImageFolder(root='../datasets/places365',  transform=transform_test),
        'sun': datasets.ImageFolder(root='../datasets/SUN',  transform=transform_test),
        'texture': datasets.ImageFolder(root='../datasets/dtd/images',  transform=transform_test),
    }
    test_datasets = list(test_sets)

    batch_size = 32
    with torch.no_grad():

        target_model.eval()

        all_distances = []
        for name, dataset in test_sets.items():

            print('-'*10+name+'-'*10)
            data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

            feats = []
            trans_feats = []
            for samples, _labels in data_loader:

                samples = samples.to(device)
                samples_num = len(samples)

                # The model is expected to return (logits, features).
                _logit, feat = target_model(samples)
                feats.append(feat.cpu())

                # Masked views of every sample. Other variants tried here:
                # TF.hflip(samples), mask sizes 56 / 37 / 32, and 25 SimCLR
                # style augmentations via multi_transform.
                trans_samples = mask_image(samples, mask_size=44, stride=44)

                # Fold the view dimension into the batch for the forward pass.
                trans_samples = trans_samples.reshape(-1, *samples.shape[1:])
                _trans_logit, trans_feat = target_model(trans_samples)

                # Unfold back to (batch, n_views, feat_dim); the feature
                # width is taken from the model output instead of being
                # hard-coded per backbone.
                trans_feats.append(
                    trans_feat.view(samples_num, -1, trans_feat.shape[-1]).cpu())

            feats = torch.cat(feats, dim=0)
            trans_feats = torch.cat(trans_feats, dim=0)

            # L2-normalise; epsilon belongs in the denominator to guard
            # against zero-norm features.
            feats = feats / (feats.norm(dim=-1, keepdim=True) + 1e-10)
            trans_feats = trans_feats / (trans_feats.norm(dim=-1, keepdim=True) + 1e-10)

            # Negative distance: larger means the masked view stays closer
            # to the clean feature.
            dist = -(trans_feats - feats.unsqueeze(dim=1)).norm(dim=-1)

            print(dist.mean().item())

            # Sort each sample's view scores in descending order so column
            # k is the (k+1)-th largest score.
            all_distances.append(dist.sort(descending=True)[0].cpu().numpy())

        print("Detection based on similarity of logits")
        search_k(all_distances, test_datasets)



        
        

# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":

    main()
    # Example invocation: CUDA_VISIBLE_DEVICES=0 python consistency_multi.py