import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision, os, datetime

import argparse
import csv, os, imageio, glob
import numpy as np

from utils import*

from torchvision.models import resnet50
from torchvision.models import vit_b_16, swin_b

from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor, Resize, Compose, Normalize, ToPILImage

from skimage.color import rgb2hsv, rgb2gray
import matplotlib.pyplot as plt

import matplotlib.colors as colors

from scipy.ndimage import gaussian_filter, convolve, zoom
from PIL import Image


def mask_generation(mask_type='rectangle', patch=None, image_size=(3, 224, 224), mask_level=100,
                    x_range=(20, 40), y_range=(150, 160)):
    """Paste ``patch`` at a random location on a blank canvas and build its blend mask.

    Args:
        mask_type: kept for interface compatibility; the value is not inspected
            and a rectangular paste is always performed.
        patch: (C, h, w) array to paste. Required.
        image_size: (C, H, W) shape of the canvas.
        mask_level: blend strength in percent. Mask entries under non-zero
            patch pixels are set to ``mask_level / 100``; zero-valued patch
            pixels get mask 0, so the underlying image shows through there.
        x_range, y_range: half-open ``[low, high)`` ranges for the random
            top-left row (x) and column (y). The upper bounds are clamped so
            the patch always fits inside the canvas. When the patch already
            fits, the random draws are identical to the unclamped ranges.

    Returns:
        (applied_patch, mask, x_location, y_location): the float64 canvas
        holding the pasted patch, the blend mask with the same shape, and the
        top-left row/column where the patch was placed.

    Raises:
        ValueError: if ``patch`` is None or larger than the canvas.
    """
    if patch is None:
        raise ValueError("mask_generation() requires a patch array")
    patch_h, patch_w = patch.shape[1], patch.shape[2]
    canvas_h, canvas_w = image_size[1], image_size[2]
    if patch_h > canvas_h or patch_w > canvas_w:
        raise ValueError(
            f"patch of size {patch_h}x{patch_w} does not fit a {canvas_h}x{canvas_w} image")

    def _fitting_range(bounds, max_start):
        # Clamp [low, high) so that every drawn start index keeps the patch inside the canvas.
        low, high = bounds
        high = min(high, max_start + 1)
        low = min(low, high - 1)
        return low, high

    # Draw the row first, then the column (same RNG order as before).
    x_location = np.random.randint(*_fitting_range(x_range, canvas_h - patch_h))
    y_location = np.random.randint(*_fitting_range(y_range, canvas_w - patch_w))

    # np.zeros already defaults to float64; the dtype is spelled out for clarity.
    applied_patch = np.zeros(image_size, dtype=float)
    applied_patch[:, x_location:x_location + patch_h, y_location:y_location + patch_w] = patch

    mask = applied_patch.copy()
    mask[mask != 0] = mask_level / 100.0

    return applied_patch, mask, x_location, y_location

def tester(image, target=453):
    """Classify one image with the module-level ``model`` and print its scores.

    Args:
        image: HWC array; presumably RGB scaled to [0, 1] (the commented-out
            batch test divides imageio output by 255) — TODO confirm.
        target: class index whose probability is reported first. Defaults to
            453, the index that was previously hard-coded here.

    Prints the target-class probability, the predicted class tensor and the
    predicted class's probability, followed by a blank line.
    """
    # HWC -> CHW, as expected by the torchvision tensor transforms.
    chw_image = torch.tensor(image.transpose(2, 0, 1))

    normalize = Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
    resize = Resize(size=(224, 224))

    img = normalize(resize(chw_image)).unsqueeze(0).type(torch.FloatTensor).cuda()

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        output = model(img)
    probabilities = torch.nn.functional.softmax(output, dim=1)
    target_probability = torch.mean(probabilities[:, target])
    _, predicted = torch.max(output, 1)

    print(target_probability, predicted, torch.mean(probabilities[:, predicted.cpu().numpy()[0]]))
    print('\n')


def get_attention_map(img, get_mask=False, attentions=None):
    """Compute an attention-rollout map for ``img``.

    Rollout works in four steps:
    1. Average each layer's attention over heads.
    2. Add the identity to account for residual connections, then
       re-normalise rows.
    3. Multiply the layer matrices together.
    4. Take the first token's row, without its own column, reshape it to
       the patch grid and upsample it to the image's spatial size.

    Args:
        img: image tensor whose last two dims are (H, W). Assumed
            (1, C, H, W) — TODO confirm against callers.
        get_mask: if True, return the (H, W) rollout mask scaled so its maximum
            is 1. Otherwise return the image weighted by that mask as a uint8
            HWC array.
        attentions: optional per-layer attention weights. Accepted forms are a
            sequence of (1, heads, tokens, tokens) tensors or one stacked
            tensor of shape (layers, 1, heads, tokens, tokens). When None,
            ``model(img)`` is called and must return weights in one of those
            forms.

    Returns:
        numpy array, as described under ``get_mask``.
    """
    if attentions is None:
        # NOTE(review): in this script ``model`` is used as a classifier whose
        # output is softmaxed over classes, i.e. it returns logits, not
        # attention weights. This fallback therefore only works with a model
        # that exposes its per-layer attention — confirm before relying on it.
        with torch.no_grad():
            attentions = model(img)

    if isinstance(attentions, torch.Tensor):
        att_mat = attentions
    else:
        att_mat = torch.stack(list(attentions))
    # Drop the batch dim (assumes batch size 1), then average over heads.
    att_mat = att_mat.detach().squeeze(1).mean(dim=1)

    # To account for residual connections, add an identity matrix (on the same
    # device/dtype as the attention) and re-normalise each row to sum to 1.
    residual_att = torch.eye(att_mat.size(-1), dtype=att_mat.dtype, device=att_mat.device)
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

    # Recursively multiply the layer matrices; only the final product is used.
    joint_attention = aug_att_mat[0]
    for layer_att in aug_att_mat[1:]:
        joint_attention = torch.matmul(layer_att, joint_attention)

    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    mask = joint_attention[0, 1:].reshape(grid_size, grid_size).float().cpu()
    mask = mask / mask.max()

    height, width = img.shape[-2:]
    mask = torch.nn.functional.interpolate(
        mask[None, None], size=(height, width), mode='bilinear', align_corners=False
    )[0, 0].numpy()

    if get_mask:
        return mask

    image_chw = img.detach()[0] if img.dim() == 4 else img.detach()
    image_hwc = image_chw.permute(1, 2, 0).cpu().numpy()
    # NOTE(review): the uint8 cast truncates, so this overlay is only
    # meaningful when pixel values are on a 0-255 scale.
    return (mask[..., np.newaxis] * image_hwc).astype("uint8")

def plot_attention_map(original_img, att_map):
    """Draw ``original_img`` and ``att_map`` side by side in a new 16x16-inch figure."""
    panels = (
        ('Original', original_img),
        ('Attention Map Last Layer', att_map),
    )
    _, axes = plt.subplots(ncols=2, figsize=(16, 16))
    for axis, (title, picture) in zip(axes, panels):
        axis.set_title(title)
        axis.imshow(picture)


parser = argparse.ArgumentParser()

# parser.add_argument('-n', type=str, default=827, help="name of input image")
parser.add_argument('--target', type=int, default=428, help="target label")
parser.add_argument('--gpu', type=str, default='1', help="index of used GPU")
args = parser.parse_args()

# Evaluate a pre-computed adversarial patch: paste it onto one demo image and
# report the classifier's prediction and the target-class probability.

# Restrict CUDA to the requested device. This is set before the first CUDA
# call below (.cuda()), which is required for the variable to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

model = vit_b_16(weights='DEFAULT').cuda()
# model = vit_h_14(weights='DEFAULT').cuda()   # would also need vit_h_14 imported

# model = swin_b(weights='DEFAULT').cuda()
model.eval()


# Batch test
###############################################

# for img_dir in glob.glob("./tmp/*.jpg"):
#     print(img_dir)
#     img = imageio.imread(img_dir)/255.0
#     tester(img)
# quit()


##############################################

# Load the demo image, scale pixel values by 1/255 and convert HWC -> CHW.
image = imageio.imread("./demo_img/ILSVRC2012_val_00012873.JPEG")/255.0
image = image.transpose(2,0,1)

# ImageNet normalisation transform.
# NOTE(review): it is defined but not applied to the model input below —
# confirm whether the patch was optimised on unnormalised images.
normalize = Normalize(mean=[0.485, 0.456, 0.406],
                      std=[0.229, 0.224, 0.225])

resize = Resize(size=(224, 224))

# Despite its name, this image is only resized, not normalised; the patch is
# composited onto this pixel-space image, which is also what gets saved.
normalized_image = resize(torch.tensor(image))


# Read the saved patch. Despite the .pth extension it is loaded with np.load,
# so it is presumably an .npy-format array — TODO confirm how it was saved.
with open("./0_48_swinb_481/patch_75.pth", 'rb') as patch_file:
    init_patch = np.load(patch_file)


###########################################
# Disabled experiments (kept for reference): colour/mean matching, smoothing
# and rescaling of the patch before it is pasted.
#
# color_trans = colors.rgb_to_hsv(np.clip(init_patch.transpose(1,2,0), 0, 1))
# cal mean diff:

# img_np = normalized_image.numpy()

# mean_diff = np.mean(img_np[:,28:28+70, 224-70-28:224-28] - init_patch, axis=(1,2))
# print(mean_diff.shape, mean_diff)
# print(np.mean(img_np[:,28:28+70, 224-70-28:224-28], axis=(1,2)))
# print(np.mean(init_patch, axis=(1,2)))

# print(init_patch.shape)
# init_patch[0,:,:] += mean_diff[0]/1
# init_patch[1,:,:] += mean_diff[1]/1
# init_patch[2,:,:] += mean_diff[2]/1
# print(np.max(init_patch[:,:,0]), np.max(init_patch[:,:,1]), np.max(init_patch[:,:,2]))

######################################
# init_patch[0,:,:] -= 0.1
# adj_mat = np.random.random((70,70)) / 7
# init_patch[0,:,:] -= adj_mat

# kernel = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]])
# kernel = kernel / np.sum(kernel)  # Normalize the kernel

# # Apply the Gaussian filter using convolution
# r_channel = convolve(init_patch[0,:,:], kernel)
# g_channel = convolve(init_patch[1,:,:], kernel)
# b_channel = convolve(init_patch[2,:,:], kernel)
# init_patch = np.dstack((r_channel, g_channel, b_channel)).transpose(2,0,1)

# init_patch = zoom(init_patch, (3,112,112)/np.array(init_patch.shape), order=1)

###########################################
# print("Tue matching the patch")
# img_mean = np.mean(img_np, axis=(1,2))
# patch_mean = np.mean(init_patch, axis=(1,2))
# diff_mean = img_mean - patch_mean

# NOTE(review): the three lines below use `mean_diff`, but this experiment
# computes `diff_mean` — rename before re-enabling.
# init_patch[0,:,:] += mean_diff[0]/1
# init_patch[1,:,:] += mean_diff[1]/1
# init_patch[2,:,:] += mean_diff[2]/1

# print(img_mean, patch_mean, diff_mean)

# init_patch = zoom(init_patch, (3,98,98)/np.array(init_patch.shape), order=1)
# patch_save = np.clip(np.transpose(init_patch, (1, 2, 0)), 0, 1)
# imageio.imwrite("./gaussian_adj.png", patch_save)
# quit()

###########################################
# Drop singleton dimensions, then shift every value up by 0.5 (presumably the
# patch was stored centred on zero — TODO confirm).
shifted = init_patch.squeeze()
patch = shifted + 0.5


# Paste the patch at a random location, then alpha-blend it onto the image:
# perturbated = mask * patch + (1 - mask) * image.
applied_patch, mask, x_location, y_location = mask_generation('rectangle', patch, (3, 224, 224))

applied_patch = torch.from_numpy(applied_patch).float()
mask = torch.from_numpy(mask).float()
perturbated_image = mask * applied_patch + (1 - mask) * normalized_image.float()

# Save a viewable copy: CHW -> HWC, clipped to [0, 1].
img_save = np.clip(np.transpose(perturbated_image.cpu().numpy(), (1, 2, 0)), 0, 1)
imageio.imwrite("./tmp.png", img_save)

# NOTE(review): unlike tester(), no ImageNet Normalize is applied to this
# model input — confirm this matches how the patch was optimised.
img = perturbated_image.unsqueeze(0).cuda()


# result_1 = get_attention_map(img)
# plot_attention_map(img, result_1)
# quit()

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = model(img)
probabilities = torch.nn.functional.softmax(output, dim=1)
# Use --target rather than a hard-coded index; its default, 428, is the index
# that was previously hard-coded here.
target_probability = torch.mean(probabilities[:, args.target])
_, predicted = torch.max(output, 1)

print(target_probability, predicted, torch.mean(probabilities[:, predicted.cpu().numpy()[0]]))