import torch, os, datetime
import torch.nn as nn
from torch.autograd import Variable
from tqdm import tqdm

import argparse
import csv, os, imageio
import numpy as np

from utils import*

from torchvision.models import resnet50
from torchvision.models import vit_b_16, vit_l_16, vit_h_14, swin_b, swin_s

from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor, Resize, Compose

import torch.optim as optim

import matplotlib.pyplot as plt


# Command-line options as (flag, type, default, help) rows. They are
# registered in this order, so ``--help`` lists them in the same sequence.
_CLI_OPTIONS = (
    ('--batch_size', int, 10, "batch size"),
    ('--num_workers', int, 8, "num_workers"),
    ('--train_size', int, 6000, "number of training images"),
    ('--test_size', int, 1000, "number of test images"),
    # NOTE: in this script the value is used as the generator's base channel
    # width (``ngf``) and as part of the output directory name.
    ('--mask_length', int, 80, "percentage of the patch size compared with the image size"),
    ('--target', int, 453, "target label"),
    ('--epochs', int, 80, "total epoch"),
    ('--data_dir', str, './datasets/imgNet/train/', "dir of the dataset"),
    ('--GPU', str, '0', "index pf used GPU"),
    ('--log_dir', str, 'patch_attack_log.csv', 'dir of the log'),
    # NOTE: only used to build the output directory name in this script.
    ('--mask_level', int, 0, 'transparent level of mask'),
    ('-c', str, '', 'comment'),
)

parser = argparse.ArgumentParser()
for _flag, _kind, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)
args = parser.parse_args()


def mask_generation(patch, image_size, margin=18):
    """Paste ``patch`` at a random location on a blank canvas and build its mask.

    Args:
        patch: Tensor whose last two dimensions are the patch height and
            width. Its leading dimensions must broadcast against
            ``image_size[:2]`` (e.g. ``(C, h, w)``, ``(1, C, h, w)`` or
            ``(B, C, h, w)``).
        image_size: 4-D shape ``(B, C, H, W)`` of the image batch to cover.
        margin: Minimum number of pixels kept between the patch and every
            image border. Defaults to 18, the value previously hard-coded.

    Returns:
        Tuple ``(applied_patch, mask, x_location, y_location)``:
        ``applied_patch`` is a zero canvas of shape ``image_size`` holding
        the patch, ``mask`` is 1 over the whole pasted rectangle and 0
        elsewhere, and ``x_location`` / ``y_location`` are the offsets of
        the rectangle along dimensions ``-2`` and ``-1``.

    Raises:
        ValueError: If the patch plus margins does not fit in the image.
    """
    patch_h, patch_w = patch.shape[-2], patch.shape[-1]
    x_high = image_size[-2] - patch_h - margin
    y_high = image_size[-1] - patch_w - margin
    if x_high <= margin or y_high <= margin:
        raise ValueError(
            "patch of size {}x{} with margin {} does not fit in an image of size {}x{}".format(
                patch_h, patch_w, margin, image_size[-2], image_size[-1]))

    x_location = np.random.randint(low=margin, high=x_high)
    y_location = np.random.randint(low=margin, high=y_high)
    rows = slice(x_location, x_location + patch_h)
    cols = slice(y_location, y_location + patch_w)

    # One broadcasting assignment covers every image in the batch. The canvas
    # uses the torch.zeros defaults; callers move it to their device.
    applied_patch = torch.zeros(image_size)
    applied_patch[:, :, rows, cols] = patch

    # Build the mask from the paste rectangle rather than from
    # ``applied_patch != 0``: a patch pixel that is exactly 0 must still
    # replace the underlying image pixel.
    mask = torch.zeros(image_size)
    mask[:, :, rows, cols] = 1

    return applied_patch, mask, x_location, y_location

def test_patch(model, patch, batch_size, test_loader, target):
    """Measure how often ``patch`` makes ``model`` predict ``target``.

    Only images that the model classifies correctly without the patch, and
    whose predicted class is not already ``target``, take part in the
    evaluation. The patch is pasted at one random location per batch.

    Args:
        model: Classifier to attack; it is switched to eval mode and is
            expected to live on the current CUDA device.
        patch: Patch tensor forwarded to ``mask_generation``.
        batch_size: Unused. Kept so existing callers do not break; the
            actual size of each batch is read from the batch itself.
        test_loader: Iterable yielding ``(image, label)`` batches.
        target: Integer index of the attack's target class.

    Returns:
        Fraction of eligible images classified as ``target`` once patched,
        or ``0.0`` when no image was eligible.
    """
    model.eval()
    test_total, test_actual_total, test_success = 0, 0, 0

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for image, label in test_loader:
            test_total += label.shape[0]
            image = image.cuda()
            label = label.cuda()
            _, predicted = torch.max(model(image), 1)

            # Selecting with a boolean mask also handles a final batch that
            # is smaller than the loader's nominal batch size.
            eligible = (predicted == label) & (predicted != target)
            if not eligible.any():
                continue
            image_testable = image[eligible]
            test_actual_total += image_testable.shape[0]

            applied_patch, mask, _, _ = mask_generation(patch, image_testable.shape)
            applied_patch = applied_patch.cuda()
            mask = mask.cuda()
            perturbated_image = torch.mul(mask, applied_patch) + torch.mul(1 - mask, image_testable)

            _, predicted = torch.max(model(perturbated_image), 1)
            test_success += int((predicted == target).sum().item())

    print(test_total, test_success, test_actual_total)
    if test_actual_total == 0:
        # Nothing was eligible; avoid dividing by zero.
        return 0.0
    return test_success / test_actual_total

def weights_init(m):
    """Re-initialise ``m`` in place according to its class name.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Any other module is left
    untouched. Intended to be passed to ``nn.Module.apply``.
    """
    layer_kind = m.__class__.__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

class MultiplyByHalf(nn.Module):
    """Parameter-free layer that scales its input by 0.5."""

    def forward(self, x):
        halved = x * 0.5
        return halved

class Generator(nn.Module):
    """DCGAN-style transposed-convolution generator producing the patch.

    Reads the module-level globals ``nz`` (latent channels), ``ngf`` (base
    channel width) and ``nc`` (output channels) at construction time.

    For a latent input of shape ``(N, nz, 1, 1)`` the spatial size grows
    1 -> 5 -> 10 -> 20 -> 40 -> 80, so the output is ``(N, nc, 80, 80)``.
    The closing ``Tanh`` followed by ``MultiplyByHalf`` bounds the output to
    ``[-0.5, 0.5]``.
    """

    def __init__(self, ngpu):
        super().__init__()
        # Stored for callers; not used by this class itself.
        self.ngpu = ngpu

        widths = [ngf * 8, ngf * 4, ngf * 2, ngf]

        # Stem: 1x1 latent -> 5x5. The kernel size here (5), doubled four
        # times by the following layers, fixes the final patch size.
        layers = [
            nn.ConvTranspose2d(nz, widths[0], 5, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Three upsampling stages; each doubles the spatial size
        # (kernel 4, stride 2, padding 1) and halves the channel count.
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ])
        # Head: last doubling, down to ``nc`` channels, squashed to [-0.5, 0.5].
        layers.extend([
            nn.ConvTranspose2d(widths[-1], nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            MultiplyByHalf(),
        ])

        # Same layer order as before, so state-dict keys (main.0 ... main.14)
        # are unchanged.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

# Restrict CUDA to the GPU(s) selected on the command line. This has to be
# set before the first CUDA call below for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU


# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Base channel width of the generator, taken from --mask_length.
# NOTE: the generator's output size is fixed at 80x80 by its layer
# configuration and does not depend on this value.
ngf = args.mask_length
# Learning rate for optimizers
lr = 0.001
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.9
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

netG = Generator(ngpu).cuda()
# netG.apply(weights_init)

# Warm-start from an earlier run when its checkpoint is present. Without the
# existence check the script crashes on any checkout lacking this exact path.
model_weight = './100_80_swin_half/patch_45_weight.pth'
if os.path.isfile(model_weight):
    # map_location='cpu' keeps loading independent of the device the
    # checkpoint was saved from; load_state_dict copies into netG's tensors.
    netG.load_state_dict(torch.load(model_weight, map_location='cpu'))
else:
    print("Checkpoint {} not found; training the generator from its default initialization.".format(model_weight))

criterion = nn.CrossEntropyLoss()

# Fixed latent vector used to render one comparable patch after every epoch.
fixed_noise = torch.randn(1, nz, 1, 1).cuda()

optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))



# Echo the free-form run comment, if one was given.
if args.c != "":
    print("comment:{}".format(args.c))

# Directory that receives this run's patches, weights and default log.
# exist_ok=True replaces the separate exists()/mkdir() pair, which could race
# with another process creating the same directory.
os.makedirs("./{}_{}_{}".format(args.mask_level, args.mask_length, args.c), exist_ok=True)


# Build the train/test loaders through the project helper from ``utils``.
# NOTE(review): the meaning of the trailing 32000 is defined by
# utils.dataloader and is not visible here; confirm against that helper.
train_loader, test_loader = dataloader(
    args.train_size,
    args.test_size,
    args.data_dir,
    args.batch_size,
    args.num_workers,
    32000,
)

# ToTensor followed by a resize to 224x224. Not referenced again in the
# visible part of this script.
data_transforms = Compose([ToTensor(), Resize(size=(224, 224))])

# Victim classifier: Swin-B with torchvision's default weights, on the GPU
# and in eval mode. The other backbones imported at the top of the file
# (vit_b_16, vit_l_16, swin_s) can be substituted here.
model = swin_b(weights='DEFAULT').cuda().eval()


# Initialize the patch
# init_patch = np.random.rand(3, 224, 224)
# patch_initialization(args.patch_type, image_size=(3, 224, 224), mask_length=args.mask_length)


# print('Shape: {}, Min: {}, Max: {}'.format(init_patch.shape, np.min(init_patch), np.max(init_patch)))

# patch_save = np.transpose(init_patch, (1, 2, 0))
# imageio.imwrite("./{}_{}_{}/imput_patch_init.png".format(args.mask_level, args.mask_length, args.c), patch_save)

# Resolve the CSV log path. The default file name is placed inside this run's
# output directory; any other --log_dir value is used as given. Without the
# else branch a custom --log_dir left _log_dir undefined and raised NameError.
if args.log_dir == 'patch_attack_log.csv':
    _log_dir = "./{}_{}_{}/patch_attack_log.csv".format(args.mask_level, args.mask_length, args.c)
else:
    _log_dir = args.log_dir

# Start a fresh log with the header row. newline='' is what the csv module
# expects so that it controls line endings itself.
with open(_log_dir, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "train_success", "test_success_1", "test_success_2"])

best_patch_epoch, best_patch_success_rate = 0, 0
# Nominal number of training images per epoch (progress-bar total).
n_train = len(train_loader)*args.batch_size


# patch = init_patch
# torch.from_numpy(init_patch).cuda()


# Generate the patch
batch_size = args.batch_size
att_target = torch.tensor(args.target).cuda()

# Only the generator is optimised. Freezing the classifier stops autograd
# from computing and accumulating gradients for its weights; gradients still
# flow through it back to the patch.
for _param in model.parameters():
    _param.requires_grad_(False)

for epoch in range(args.epochs):
    train_total, train_actual_total, train_success = 0, 0, 0

    with tqdm(total=n_train, desc=f'Epoch {epoch}/{args.epochs}', unit='img') as pbar:

        for (image, label) in train_loader:
            train_total += label.shape[0]
            image = image.cuda()
            label = label.cuda()

            # Clean forward pass, used only to pick usable samples.
            with torch.no_grad():
                _, predicted = torch.max(model(image), 1)

            # Keep images that are classified correctly and are not already
            # predicted as the target. Boolean indexing also copes with a
            # final batch smaller than batch_size.
            eligible = (predicted == label) & (predicted != att_target)
            if not eligible.any():
                pbar.update(image.shape[0])
                continue

            image_trainable = image[eligible].float()
            b_size = image_trainable.shape[0]
            train_actual_total += b_size

            # Every patched image should be classified as the target class
            # (taken from --target rather than a hard-coded index).
            target_label = torch.full((b_size,), args.target, dtype=torch.int64).cuda()

            # One latent vector, hence one patch, per retained image.
            noise = torch.randn(b_size, nz, 1, 1).cuda()
            _patch = netG(noise)

            # Shift the generator output from [-0.5, 0.5] to [0, 1] and add a
            # small Gaussian jitter shared by the whole batch.
            _noise = torch.randn(_patch.shape[1:]).cuda() * 0.1
            _patch = _patch + 0.5 + _noise

            applied_patch, mask, x_location, y_location = mask_generation(_patch, image_trainable.shape)
            applied_patch = applied_patch.cuda()
            mask = mask.float().cuda()
            perturbated_image = torch.mul(mask, applied_patch) + torch.mul(1 - mask, image_trainable)

            netG.zero_grad()
            output = model(perturbated_image)

            # Cross-entropy towards the target class drives the generator.
            errG = criterion(output, target_label)
            errG.backward()
            optimizerG.step()

            # Training success: patched images already classified as the
            # target by the forward pass above.
            train_success += int((output.detach().argmax(dim=1) == att_target).sum().item())

            pbar.update(image.shape[0])

    # Render the patch for the fixed latent vector and shift it to [0, 1].
    # NOTE(review): netG is never switched to eval mode in this script, so
    # BatchNorm presumably uses this single sample's statistics here.
    with torch.no_grad():
        _patch = netG(fixed_noise).cpu() + 0.5
    print('Shape: {}, Min: {}, Max: {}'.format(_patch.size(), torch.min(_patch), torch.max(_patch)))

    # Despite the .pth suffix this file holds a NumPy array written by
    # np.save; the raw floats are kept because the PNG below is quantised.
    with open("./{}_{}_{}/patch_{}.pth".format(args.mask_level, args.mask_length, args.c, epoch), "wb") as f:
        np.save(f, _patch.numpy())

    torch.save(netG.state_dict(), "./{}_{}_{}/patch_{}_weight.pth".format(args.mask_level, args.mask_length, args.c, epoch))

    # PNG preview: CHW -> HWC, clipped to [0, 1], then converted to uint8
    # explicitly so the image writer does not have to quantise float data.
    patch_save = np.clip(np.transpose(_patch.squeeze(0).numpy(), (1, 2, 0)), 0, 1)
    patch_save = (patch_save * 255).round().astype(np.uint8)
    imageio.imwrite("./{}_{}_{}/patch_{}.png".format(args.mask_level, args.mask_length, args.c, epoch), patch_save)

    # Two evaluations; each one draws its own random patch locations.
    test_success_rate_1 = test_patch(model, _patch, batch_size, test_loader, args.target)
    print("Epoch:{} Patch attack success rate on testset (1): {:.3f}%".format(epoch, 100 * test_success_rate_1))
    test_success_rate_2 = test_patch(model, _patch, batch_size, test_loader, args.target)
    print("Epoch:{} Patch attack success rate on testset (2): {:.3f}%".format(epoch, 100 * test_success_rate_2))

    # Record the statistics. Guard the division: an epoch may have had no
    # eligible training image.
    train_success_rate = train_success / train_actual_total if train_actual_total else 0.0
    with open(_log_dir, 'a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([epoch, train_success_rate, test_success_rate_1, test_success_rate_2])

    if max(test_success_rate_1, test_success_rate_2) > best_patch_success_rate:
        best_patch_success_rate = max(test_success_rate_1, test_success_rate_2)
        best_patch_epoch = epoch

    # Hand the log so far to the project helper from ``utils``.
    log_generation(_log_dir)

print("The best patch is found at epoch {} with success rate {}% on testset".format(best_patch_epoch, 100 * best_patch_success_rate))
print(datetime.datetime.now())
print("#####################################################################################################")
