import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import Optimizer
import torch.backends.cudnn as cudnn
import tqdm
import argparse
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader,Subset
from models import *

import os
import copy
import random
import matplotlib.pyplot as plt
import numpy as np
#import cv2 as cv
from util import *


def get_args(argv=None):
    """Build the GEN-NARCISSUS command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, so calling ``get_args()``
            with no argument behaves exactly as before. Passing an explicit
            list lets the parser be used from tests or other code.

    Returns:
        argparse.Namespace: one attribute per option declared below (dashes in
        option names become underscores, e.g. ``--gpu-id`` -> ``gpu_id``).
    """
    parser = argparse.ArgumentParser(description='GEN-NARCISSUS')

    # Model / data options.
    parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet18', 'resnet50', 'vgg19',
                                                                             'vit-b', 'densenet121', 'linear',
                                                                             '2nn', '3nn', 'lenet5', 'vit', 'wrn34-10'],
                        help='the model arch used in experiment')

    parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100', 'tinyimagenet',
                                                                 'miniimagenet', 'imagenet100'],
                        help='the dataset used in experiment')
    parser.add_argument('--data', type=str, default='data/CIFAR10', help='the directory of dataset')
    parser.add_argument('--num-classes', default=10, type=int, help='the number of classes in the dataset')
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--num-workers', type=int, default=4)

    parser.add_argument('--poison-path', type=str, default=None, help='the path of pretrained poison')
    parser.add_argument('--poison-size', type=int, default=32,
                        help='the image size of poisons')

    # Optimisation options.
    parser.add_argument('--optimizer', default='sgd', type=str,
                        help='the optimizer used in training')
    parser.add_argument('--epochs', default=200, type=int,
                        help='the number of total epochs to run')
    parser.add_argument('--lr', default=0.5, type=float, help='optimizer learning rate')

    # Boolean switches (all default to False).
    parser.add_argument('--resume', action='store_true', help='if resume training')
    parser.add_argument('--cutout', action='store_true', help='use cutout')
    parser.add_argument('--cutmix', action='store_true', help='use cutmix')
    parser.add_argument('--mixup', action='store_true', help='use mixup')
    parser.add_argument('--gaussian-smooth', action='store_true', help='if use gaussian smooth')
    parser.add_argument('--random-noise', action='store_true', help='if use random noise')
    parser.add_argument('--get-lr-process', action='store_true', help='if get learning process')

    # Device, poison and watermark options.
    parser.add_argument('--gpu-id', type=int, default=5, help='the gpu id')
    parser.add_argument('--post-poisoning', action='store_true',
                        help='if generate post-poisoning watermark')
    parser.add_argument('--mask-type', default='fixed', choices=['random', 'fixed', 'fix-lt',
                                                                 'fix-lb', 'fix-rt', 'fix-rb'],
                        help='the type of mask for pixels')
    parser.add_argument('--targeted-class', default=2, type=int,
                        help='which class could be as target class')
    parser.add_argument('--watermark-budget', type=float, default=8, help='the watermark budget')
    parser.add_argument('--seed', default=1, type=int, help='random seed')
    parser.add_argument('--poison-budget', type=float, default=16, help='the poison budget')
    parser.add_argument('--poison-ratio', type=float, default=0.01, help='the poison ratio for one class')
    parser.add_argument('--multi-test', type=float, default=3.0, help='the scale of test trigger')
    parser.add_argument('--wm-length', type=int, default=2000, help='the watermarking length')

    return parser.parse_args(argv)

args = get_args()
# Both budgets are divided by 255 (presumably they are given on the 0-255
# pixel scale, while tensors produced by ToTensor lie in [0, 1]).
watermark_budget = args.watermark_budget / 255
poison_budget = args.poison_budget / 255

# Expose only the requested physical GPU to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpu_id}'
random_seed = args.seed
np.random.seed(random_seed)
random.seed(random_seed)
torch.manual_seed(random_seed)

# CUDA_VISIBLE_DEVICES renumbers the visible devices starting at 0, so once the
# mask above is in effect the physical id is no longer a valid ordinal and
# set_device(args.gpu_id) fails for any non-zero id. Only keep the physical id
# when it is still inside the visible range (i.e. the mask did not take effect
# because CUDA had already been initialised); otherwise select ordinal 0.
visible_gpus = torch.cuda.device_count()
torch.cuda.set_device(args.gpu_id if args.gpu_id < visible_gpus else 0)
device = 'cuda'

# Root directory holding both the target dataset and the public
# out-of-distribution (POOD) dataset. The setting used here is CIFAR-10 as the
# target dataset and Tiny-ImageNet as the POOD dataset. Expected layout:
#
# dataset_path--cifar-10-batches-py
#             |
#             |-tiny-imagenet-200
dataset_path = '/data0/zhuyifan/data/'

# The target class label
lab = args.targeted_class

# For each corner mask type: (anchored to bottom rows?, anchored to right columns?).
_CORNER_ANCHORS = {
    'fix-lt': (False, False),
    'fix-lb': (True, False),
    'fix-rt': (False, True),
    'fix-rb': (True, True),
}


def _build_pixel_mask(mask_type, channels, height, width, wm_length):
    """Return the (1, channels, height, width) 0/1 mask for ``mask_type``.

    Positions set to 0 are the ones that receive the watermark (the watermark
    is multiplied by ``1 - mask``); positions set to 1 are left for the noise.

    * ``'random'``: ``channels*height*width - wm_length`` randomly chosen
      entries are 1, the rest 0. The result lives on the current CUDA device.
    * ``'fix-lt' / 'fix-lb' / 'fix-rt' / 'fix-rb'``: a block of
      ``int(r*height) x int(r*width)`` zeros in the named corner of every
      channel, with ``r = sqrt(wm_length / num_pixels)``. Moved to CUDA.
    * anything else (``'fixed'``): all ones, kept on the CPU.

    Raises:
        ValueError: if ``wm_length`` is outside ``[0, channels*height*width]``
            for a mask type that uses it. Previously this produced an
            IndexError or, for ``'random'``, a silently wrong mask because the
            negative slice bound counted from the end of the permutation.
    """
    if mask_type != 'random' and mask_type not in _CORNER_ANCHORS:
        return torch.ones(1, channels, height, width)

    num_pixels = channels * height * width
    if not 0 <= wm_length <= num_pixels:
        raise ValueError(
            f'--wm-length must be between 0 and {num_pixels} for mask type '
            f'{mask_type!r}, got {wm_length}')

    if mask_type == 'random':
        flat_mask = torch.zeros(1, channels, height, width).cuda().view(-1)
        # CPU permutation, same RNG call as before so seeded runs are unchanged.
        kept_indices = torch.randperm(num_pixels)[:num_pixels - wm_length]
        flat_mask[kept_indices] = 1
        return flat_mask.view(1, channels, height, width)

    ratio = np.sqrt(wm_length / num_pixels)
    rows, cols = int(ratio * height), int(ratio * width)
    at_bottom, at_right = _CORNER_ANCHORS[mask_type]
    # Explicit start/stop (not negative indexing) so rows == 0 gives an empty span.
    row_span = slice(height - rows, height) if at_bottom else slice(0, rows)
    col_span = slice(width - cols, width) if at_right else slice(0, cols)

    corner_mask = torch.ones(1, channels, height, width)
    corner_mask[0, :, row_span, col_span] = 0
    return corner_mask.cuda()


if not args.post_poisoning:
    C, H, W = 3, args.poison_size, args.poison_size
    mask = _build_pixel_mask(args.mask_type, C, H, W, args.wm_length)

    # A random +/- watermark_budget value at every position, kept only where
    # the mask is 0 (multiplied by 1 - mask).
    watermark = torch.where(torch.randn(1, C, H, W) < 0, -watermark_budget, watermark_budget) * (
                torch.ones(1, C, H, W) - mask.cpu())
    # Sign pattern of the watermark: -1 / +1 where it is non-zero, 0 elsewhere.
    key = torch.where(watermark < -1e-6, -1,
                      torch.where(watermark > 1e-6, 1, 0)).cuda()

    print('the number of unmasked key pixels is:', torch.nonzero(key).size(0))



def narcissus_gen(dataset_path = dataset_path, lab = lab):
    """Optimise a single additive trigger for class ``lab`` and return it.

    Stages, in order:

    1. Train a ``ResNet18_201`` surrogate for 200 epochs (SGD + cosine
       schedule) on the dataset built by ``concoct_dataset`` from the
       target-class CIFAR-10 training subset and the Tiny-ImageNet training
       folder, then save its weights under ``./checkpoint/``.
    2. Copy those weights into a second ``ResNet18_201`` and fine-tune it for
       5 epochs with RAdam on the target-class subset only.
    3. Freeze that model and optimise one ``(1, 3, 32, 32)`` perturbation with
       RAdam so that perturbed target-class images minimise the cross-entropy
       loss. The perturbation is clamped to ``+/- poison_budget``; unless
       ``--post-poisoning`` is set, its initial value, gradients and final
       value are multiplied by the module-level ``mask``.

    Args:
        dataset_path: Directory containing the CIFAR-10 files and
            ``tiny-imagenet-200/train/`` (concatenated as a plain string, so it
            should end with a path separator).
        lab: Index of the target class.

    Returns:
        torch.Tensor: the clamped perturbation on the CPU, detached from the
        graph (masked when not in post-poisoning mode).
    """
    noise_size = 32

    l_inf_r = poison_budget

    # The pixel mask only exists when the watermark is generated concurrently.
    # In post-poisoning mode there is no mask and every pixel is optimised
    # (previously this mode crashed with a NameError on ``mask``).
    pixel_mask = None if args.post_poisoning else mask.to(device)

    #Model for generating surrogate model and trigger
    surrogate_model = ResNet18_201().cuda()
    generating_model = ResNet18_201().cuda()

    #Surrogate model training epochs
    surrogate_epochs = 200

    #Learning rate for poison-warm-up
    generating_lr_warmup = 0.1
    warmup_round = 5

    #Learning rate for trigger generating
    generating_lr_tri = 0.01
    gen_round = 1000

    #Training batch size
    train_batch_size = 350

    #The mode for adding the noise
    patch_mode = 'add'

    #The augmentation used for the surrogate model training stage (POOD images)
    transform_surrogate_train = transforms.Compose([
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    #The augmentation used for the target training set (no normalisation)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    ori_train = torchvision.datasets.CIFAR10(root=dataset_path, train=True, download=False, transform=transform_train)
    outter_trainset = torchvision.datasets.ImageFolder(root=dataset_path + 'tiny-imagenet-200/train/', transform=transform_surrogate_train)

    #Labels of the target training set; fetch them once instead of once per element
    all_train_labels = get_labels(ori_train)
    train_label = [all_train_labels[idx] for idx in range(len(all_train_labels))]

    #Inner train dataset: only the samples of the target class
    train_target_list = list(np.where(np.array(train_label)==lab)[0])
    train_target = Subset(ori_train,train_target_list)

    concoct_train_dataset = concoct_dataset(train_target,outter_trainset)

    surrogate_loader = torch.utils.data.DataLoader(concoct_train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=16)

    poi_warm_up_loader = torch.utils.data.DataLoader(train_target, batch_size=train_batch_size, shuffle=True, num_workers=16)

    trigger_gen_loaders = torch.utils.data.DataLoader(train_target, batch_size=train_batch_size, shuffle=True, num_workers=16)

    noise = torch.zeros((1, 3, noise_size, noise_size), device=device)

    criterion = torch.nn.CrossEntropyLoss()
    surrogate_opt = torch.optim.SGD(params=surrogate_model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    surrogate_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(surrogate_opt, T_max=surrogate_epochs)

    #Make sure the checkpoint directory exists before the long training run,
    #so the save below cannot fail on a missing folder after all epochs.
    save_path = f'./checkpoint/surrogate_pretrain_{surrogate_epochs}.pth'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    #Training the surrogate model
    print('Training the surrogate model')
    for epoch in range(0, surrogate_epochs):
        surrogate_model.train()
        loss_list = []
        for images, labels in surrogate_loader:
            images, labels = images.cuda(), labels.cuda()
            surrogate_opt.zero_grad()
            outputs = surrogate_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            loss_list.append(float(loss.data))
            surrogate_opt.step()
        surrogate_scheduler.step()
        ave_loss = np.average(np.array(loss_list))
        print('Epoch:%d, Loss: %.03f' % (epoch, ave_loss))
    #Save the surrogate model
    torch.save(surrogate_model.state_dict(),save_path)

    #Prepare models and optimizers for poi_warm_up training
    poi_warm_up_model = generating_model
    poi_warm_up_model.load_state_dict(surrogate_model.state_dict())

    poi_warm_up_opt = torch.optim.RAdam(params=poi_warm_up_model.parameters(), lr=generating_lr_warmup)

    #Poi_warm_up stage
    poi_warm_up_model.train()
    for param in poi_warm_up_model.parameters():
        param.requires_grad = True

    #Fine-tune the copied model on the target class only
    for epoch in range(0, warmup_round):
        poi_warm_up_model.train()
        loss_list = []
        for images, labels in poi_warm_up_loader:
            images, labels = images.cuda(), labels.cuda()
            poi_warm_up_opt.zero_grad()
            outputs = poi_warm_up_model(images)
            loss = criterion(outputs, labels)
            #The graph is rebuilt every batch, so it does not need to be retained
            loss.backward()
            loss_list.append(float(loss.data))
            poi_warm_up_opt.step()
        ave_loss = np.average(np.array(loss_list))
        print('Epoch:%d, Loss: %e' % (epoch, ave_loss))

    #Trigger generating stage: the model is frozen, only the noise is optimised
    for param in poi_warm_up_model.parameters():
        param.requires_grad = False

    if pixel_mask is not None:
        noise = noise * pixel_mask
    batch_pert = noise.detach().clone().requires_grad_(True)
    batch_opt = torch.optim.RAdam(params=[batch_pert],lr=generating_lr_tri)
    for _ in tqdm.tqdm(range(gen_round)):
        loss_list = []
        for images, labels in trigger_gen_loaders:
            images, labels = images.cuda(), labels.cuda()
            clamp_batch_pert = torch.clamp(batch_pert,-l_inf_r,l_inf_r)
            #Pass a copy in case apply_noise_patch modifies its image argument in place
            new_images = torch.clamp(apply_noise_patch(clamp_batch_pert,images.clone(),mode=patch_mode),-1,1)
            per_logits = poi_warm_up_model(new_images)
            loss = criterion(per_logits, labels)
            batch_opt.zero_grad()
            loss_list.append(float(loss.data))
            loss.backward()
            if pixel_mask is not None:
                #Zero the gradient wherever the mask is 0 so those entries never move
                with torch.no_grad():
                    batch_pert.grad = batch_pert.grad * pixel_mask
            batch_opt.step()
        ave_loss = np.average(np.array(loss_list))
        ave_grad = np.sum(abs(batch_pert.grad).detach().cpu().numpy())
        print('Gradient:',ave_grad,'Loss:', ave_loss)
        #Stop early once the (masked) gradient of the last batch vanished entirely
        if ave_grad == 0:
            break

    noise = torch.clamp(batch_pert,-l_inf_r,l_inf_r)
    best_noise = noise.clone().detach().cpu()
    if pixel_mask is not None:
        best_noise = best_noise * pixel_mask.cpu()
    print('The number of non-zero noise dimension is:', torch.nonzero(best_noise).size(0))
    plt.imshow(np.transpose(noise[0].detach().cpu().numpy(),(1,2,0)))
    plt.show()
    print('Noise max val:',noise.max())

    return best_noise

best_noise = narcissus_gen(dataset_path = dataset_path, lab = lab)

# The output folder name encodes the run configuration.
if args.post_poisoning:
    save_dir = os.path.join('gen-narcissus', 'post-poisoning', f'pbud{args.poison_budget}_tcls{args.targeted_class}')
else:
    save_dir = os.path.join('gen-narcissus', 'poisoning-concurrent',
                            f'wmtype{args.mask_type}_wmbud{args.watermark_budget}_pbud{args.poison_budget}_wmlength{args.wm_length}_tcls{args.targeted_class}')

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(save_dir, exist_ok=True)

# The trigger is saved in both modes.
torch.save(best_noise, os.path.join(save_dir, 'best_noise.pt'))
print('save poison successfully')

# The watermark only exists when it was generated together with the poison.
if not args.post_poisoning:
    torch.save(watermark.cpu(), os.path.join(save_dir, 'watermark.pt'))
    print('save watermark successfully')