import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision, os, datetime

from tqdm import tqdm

import argparse
import csv, os, imageio
import numpy as np

from utils import*

from torchvision.models import resnet50, vgg16, convnext_base
from torchvision.models import vit_b_16, vit_l_16, vit_h_14, swin_b, swin_s, swin_t

from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor, Resize, Compose, Normalize

from skimage.color import rgb2hsv, rgb2gray
from skimage.transform import resize
import matplotlib.pyplot as plt

import matplotlib.colors as colors

from scipy.ndimage import gaussian_filter, convolve, zoom


def _build_arg_parser():
    """Create the command-line parser for this script.

    Options are registered in the order listed below, so ``--help`` shows
    them in the same order.
    """
    cli = argparse.ArgumentParser()
    # One row per option: (flag, value type, default, help text).
    options = (
        ('--batch_size', int, 20, "batch size"),
        ('--num_workers', int, 4, "num_workers"),
        ('--train_size', int, 2000, "number of training images"),
        ('--test_size', int, 1000, "number of test images"),
        ('--target', int, 827, "target label"),
        ('--data_dir', str, './datasets/imgNet/train/', "dir of the dataset"),
        ('--gpu', str, '0', "index pf used GPU"),
        ('-c', str, '', 'comment'),
    )
    for flag, value_type, default, help_text in options:
        cli.add_argument(flag, type=value_type, default=default, help=help_text)
    return cli


parser = _build_arg_parser()
# Parsed at import time from the process command line.
args = parser.parse_args()

def mask_generation(mask_type='rectangle', patch=None, image_size=(3, 224, 224)):
    """Paste ``patch`` at a random location on an empty canvas and build its mask.

    Args:
        mask_type: Unused; kept so existing callers keep working.
        patch: Array whose axes 1 and 2 are the patch extent along canvas
            axes 1 and 2. It is broadcast into the canvas' leading axis.
        image_size: Shape of the canvas, indexed as
            ``(channels, axis-1 size, axis-2 size)``.

    Returns:
        Tuple ``(applied_patch, mask, x_location, y_location)``:
        ``applied_patch`` is a float array of shape ``image_size`` that is
        zero everywhere except the pasted patch; ``mask`` is a float array of
        the same shape that is 1 over the whole pasted rectangle and 0
        elsewhere; ``x_location`` / ``y_location`` are the start offsets of
        the rectangle along axis 1 / axis 2.

    Raises:
        ValueError: If ``patch`` is None, or if the patch is too large to be
            placed while keeping the 18-pixel border free on every side.
    """
    if patch is None:
        raise ValueError("mask_generation requires a patch array, got None")

    margin = 18  # the patch is never placed closer than this to the border
    patch_h, patch_w = patch.shape[1], patch.shape[2]
    x_high = image_size[1] - patch_h - margin
    y_high = image_size[2] - patch_w - margin
    if x_high <= margin or y_high <= margin:
        raise ValueError(
            "patch of size {}x{} does not fit into image of size {}x{} "
            "with a margin of {}".format(
                patch_h, patch_w, image_size[1], image_size[2], margin))

    # Draw x first, then y, so a seeded RNG yields the same locations as before.
    x_location = np.random.randint(low=margin, high=x_high)
    y_location = np.random.randint(low=margin, high=y_high)

    # np.zeros already returns float64; the dtype is spelled out for clarity.
    applied_patch = np.zeros(image_size, dtype=float)
    applied_patch[:, x_location:x_location + patch_h,
                  y_location:y_location + patch_w] = patch

    # Build the mask from the rectangle itself rather than from the non-zero
    # patch values, so patch pixels that are exactly 0 are still covered.
    mask = np.zeros(image_size, dtype=float)
    mask[:, x_location:x_location + patch_h,
         y_location:y_location + patch_w] = 1.0

    return applied_patch, mask, x_location, y_location


# Test the patch on dataset
def test_patch(model, patch, batch_size, test_loader, target):
    """Measure how often pasting ``patch`` flips predictions to ``target``.

    Only "testable" images are attacked: those the model already classifies
    correctly and whose prediction is not ``target``. For each batch one
    random patch location is drawn (via ``mask_generation``) and shared by
    all testable images of that batch.

    Args:
        model: Classifier called as ``model(images)``; its output is reduced
            with ``torch.max(output, 1)`` to get class indices.
        patch: Array handed to ``mask_generation`` (assumed to be a NumPy
            array in the same value range as the loader's images - confirm
            against the caller).
        batch_size: Unused; the real size of each batch is taken from the
            batch itself so a shorter final batch is handled. Kept for
            call compatibility.
        test_loader: Iterable yielding ``(image, label)`` batches.
        target: Class index the patch should force.

    Returns:
        ``test_success / test_actual_total`` as a float, i.e. the fraction of
        testable images classified as ``target`` once patched; ``0.0`` when
        no image was testable.
    """
    model.eval()
    test_total, test_actual_total, test_success = 0, 0, 0

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for image, label in test_loader:
            test_total += label.shape[0]
            image = image.cuda()
            label = label.cuda()
            _, predicted = torch.max(model(image), 1)

            # Keep images that are classified correctly and are not already
            # predicted as the target class.
            testable = (predicted == label) & (predicted != target)
            if not testable.any():
                continue
            image_testable = image[testable]
            test_actual_total += image_testable.shape[0]

            # Canvas shape follows the batch instead of a fixed (3, 224, 224).
            image_size = tuple(image_testable.shape[1:])
            applied_patch, mask, _, _ = mask_generation('rectangle', patch, image_size)
            applied_patch = torch.from_numpy(applied_patch).float().to(image.device)
            mask = torch.from_numpy(mask).float().to(image.device)

            # Patch pixels inside the mask, original pixels outside of it.
            perturbated_image = mask * applied_patch + (1 - mask) * image_testable

            _, predicted = torch.max(model(perturbated_image), 1)
            test_success += int((predicted == target).sum().item())

    print(test_total, test_success, test_actual_total)
    if test_actual_total == 0:
        # Nothing was eligible for the attack; avoid dividing by zero.
        return 0.0
    return test_success / test_actual_total


# ################### patch => Original image

# Restrict CUDA to the device index/indices given with --gpu.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Echo the optional free-text run comment (-c) when one was supplied.
if args.c:
    print(f"comment:{args.c}")

# Build the train/test loaders. `dataloader` is not defined in this file
# (presumably it comes from `utils`); the meaning of the trailing 32000 is
# not visible here - check the helper before changing it.
train_loader, test_loader = dataloader(
    args.train_size,
    args.test_size,
    args.data_dir,
    args.batch_size,
    args.num_workers,
    32000,
)

# Previously tried input transform:
#   data_transforms = Compose([ToTensor(), Resize(size=(256, 256))])

# Model under attack, placed on the GPU and switched to eval mode.
# Alternatives that were tried before:
#   vit_b_16(weights='IMAGENET1K_SWAG_LINEAR_V1').cuda()
#   vit_b_16(weights='DEFAULT').cuda()
#   vit_l_16(weights='DEFAULT').cuda()
#   swin_t(weights='DEFAULT').cuda()
#   torch.hub.load('facebookresearch/deit:main', 'deit_small_patch16_224', pretrained=True).cuda()
model = swin_b(weights='DEFAULT').cuda().eval()

# Base name for saved artefacts; only the commented-out experiments below use it.
save_name = "80_patch_3_n4"

# with open("./100_80_swin_3/patch_75.pth", 'rb') as f:
#     init_patch = np.load(f)

# init_patch = (init_patch.squeeze()+ 0.5)

# rotation_angle = 2

# for i in range(init_patch.shape[0]):           # Why is accuracy lost here? From float32 -> uint8?
#     init_patch[i] = np.rot90(init_patch[i], rotation_angle)  # The actual rotation angle is rotation_angle * 90 


# tmp = init_patch[0,:,:]
# init_patch[0,:,:] -= 0.05
# init_patch[1,:,:] += 0.15
# init_patch[2,:,:] += 0.15
# print(np.min(init_patch[0,:,:]), np.max(init_patch[0,:,:]), np.min(init_patch[1,:,:]), np.max(init_patch[1,:,:]), np.min(init_patch[2,:,:]), np.max(init_patch[2,:,:]))


# adj_mat = np.random.random((80,80)) * 0.15
# init_patch[0,:,:] -= adj_mat
# adj_mat = np.random.random((80,80)) * 0.1
# init_patch[1,:,:] -= adj_mat
# adj_mat = np.random.random((80,80)) * 0.1
# init_patch[2,:,:] -= adj_mat

# init_patch = np.clip(init_patch, 0, 1)
# print('Shape: {}, Min: {}, Max: {}'.format(init_patch.shape, np.min(init_patch), np.max(init_patch)))


# patch_save = np.transpose(init_patch, (1, 2, 0))
# imageio.imwrite("./{}.png".format(save_name), patch_save)
# quit()

############### Patch distribution ##################

# fig = plt.figure()
# ax = plt.axes(projection='3d')
# x, y = np.ogrid[0:init_patch.shape[-2], 0:init_patch.shape[-1]]
# print(x.shape)
# ax.axes.set_zlim3d(bottom=0, top=1)

# ax.plot_surface(x, y, np.max(init_patch, axis=0))
# # ax.plot_surface(x, y, init_patch[1])
# # ax.plot_surface(x, y, init_patch[2])

# # plt.imshow(np.transpose(init_patch, (1, 2, 0)))
# # plt.show()
# plt.savefig("./{}_3d.png".format(save_name), dpi=300)
# plt.close()

# best_patch_epoch, best_patch_success_rate = 0, 0
# n_train = len(train_loader)*args.batch_size

###########################################

# init_patch[0,:,:] += 0.1
# init_patch[1,:,:] -= 0.1
# init_patch[2,:,:] -= 0.15
# # print(np.max(init_patch[:,:,0]), np.max(init_patch[:,:,1]), np.max(init_patch[:,:,2]))
# # quit()
# adj_mat = np.random.random((80,80)) / 10
# init_patch[0,:,:] -= adj_mat
# adj_mat = np.random.random((80,80)) / 10
# init_patch[1,:,:] -= adj_mat
# adj_mat = np.random.random((80,80)) / 10
# init_patch[2,:,:] -= adj_mat

# kernel = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]])
# kernel = kernel / np.sum(kernel)  # Normalize the kernel

# # Apply the Gaussian filter using convolution
# r_channel = convolve(init_patch[0,:,:], kernel)
# g_channel = convolve(init_patch[1,:,:], kernel)
# b_channel = convolve(init_patch[2,:,:], kernel)
# init_patch = np.dstack((r_channel, g_channel, b_channel)).transpose(2,0,1)
###########################################

# init_patch = zoom(init_patch, (3,60,60)/np.array(init_patch.shape), order=1)


###########################################
# Original Stove

# Build the initial patch from a photo: scale to [0, 1], resize to 96x96 and
# normalise with the mean/std below.
# Assumes the file decodes to an (H, W, 3) array - TODO confirm for other images.
stove = imageio.imread("./stove.JPEG")
stove = np.transpose(stove, (2, 0, 1))/255.0      # move the channel axis first

normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Named `resize_to_patch` so the `resize` function imported from
# skimage.transform at the top of the file is not overwritten.
resize_to_patch = Resize(size=(96, 96))

normalized_image = normalize(resize_to_patch(torch.tensor(stove)))
init_patch = normalized_image.numpy()

# Alternative kept for reference (skimage resize, no normalisation):
# stove = resize(stove, (182, 182),  anti_aliasing=True)
# init_patch = np.transpose(stove, (2, 0, 1))

###########################################

print('Shape: {}, Min: {}, Max: {}'.format(init_patch.shape, np.min(init_patch), np.max(init_patch)))

# The patch stays a NumPy array; test_patch moves it to the GPU when applying it.
patch = init_patch
# torch.from_numpy(init_patch).cuda()

# patch_save = np.clip(np.transpose(patch, (1, 2, 0)), 0, 1)
# imageio.imwrite("./large_adj.png", patch_save)
# quit()

# patch = init_patch[::-1]
# print(patch)
# print(init_patch)
# print(patch.shape)
# quit()

# Evaluate the fixed patch. Nothing is trained here: every epoch runs two
# independent evaluations, each of which draws fresh random patch locations
# inside test_patch.
batch_size = args.batch_size

# Best success rate seen over all evaluations and the epoch it came from.
# These must exist before the summary print below.
best_patch_epoch, best_patch_success_rate = 0, 0

for epoch in range(10):
    for run in (1, 2):
        test_success_rate = test_patch(model, patch, batch_size, test_loader, args.target)
        print("Epoch:{} Patch attack success rate on testset ({}): {:.3f}%".format(epoch, run, 100 * test_success_rate))

        if test_success_rate > best_patch_success_rate:
            best_patch_success_rate = test_success_rate
            best_patch_epoch = epoch


print("The best patch is found at epoch {} with success rate {}% on testset".format(best_patch_epoch, 100 * best_patch_success_rate))
print(datetime.datetime.now())
print("#####################################################################################################")

