import torch
import torch.nn as nn

from vgg import VGG
from dataset import CIFAR10, SVHN
from tqdm import tqdm
import pickle
import os
import random

from mia_utils import *

_NOISE_TYPES = ('gaussian', 'adv', 'jpeg', 'resize', 'filter')


def _perturb(net, criterion, images, labels, noise_type, sigma, sig_range,
             adv_steps, adv_eps, adv_target, q, scale, filter_name):
    """Apply one perturbation draw and return ``(images, noise)``.

    'gaussian' and 'adv' leave ``images`` unchanged and return an additive
    noise tensor. 'jpeg', 'resize' and 'filter' return the transformed
    images together with an all-zero noise tensor shaped like the
    untransformed input.
    """
    if noise_type == 'gaussian':
        # With sig_range the std is drawn uniformly from [0, sigma) per draw.
        std = random.random() * sigma if sig_range else sigma
        return images, torch.normal(0, std, size=images.shape)
    if noise_type == 'adv':
        if adv_target is not None:
            noise = gen_adv_target(net, criterion, images, labels, eps=adv_eps, target_loss=adv_target)
        else:
            noise = gen_adv(net, criterion, images, labels, steps=adv_steps, eps=adv_eps)
        return images, noise
    # Transform-based perturbations: the transform itself is the perturbation.
    zeros = torch.zeros_like(images)
    if noise_type == 'jpeg':
        return jpeg_compress_tf(images, q), zeros
    if noise_type == 'resize':
        return resize_tf(images, scale), zeros
    if noise_type == 'filter':
        return filter_tf(images, filter_name), zeros
    raise ValueError('noise type error: ' + repr(noise_type))


def _noise_file_name(sigma, loss_thres, num_run, noise_type, max_iter, sig_range,
                     adv_steps, adv_eps, adv_target, q, scale, filter_name):
    """Build the pickle file name that encodes the attack configuration."""
    name = 'dists_sigma' + str(sigma)
    if sig_range:
        name += 'range'
    name += '_loss' + str(loss_thres) + '_run' + str(num_run) + '_' + noise_type + '_maxiter' + str(max_iter)
    if noise_type == 'adv':
        if adv_target is not None:
            name += '_target' + str(adv_target) + '_eps' + str(adv_eps)
        else:
            name += '_steps' + str(adv_steps) + '_eps' + str(adv_eps)
    elif noise_type == 'jpeg':
        name += '_qrange' + str(q)
    elif noise_type == 'resize':
        name += '_resize' + str(scale)
    elif noise_type == 'filter':
        # str() so that filter=None produces '_filterNone' instead of a TypeError.
        name += '_filter' + str(filter_name)
    return name + '.pkl'


def attack_noise(net, criterion, trainloader, testloader, sigma, loss_thres,num_run=10, save_dir='cifar10_experiments', max_iter=100, noise_type='gaussian', adv_steps=10, adv_eps=0.05, adv_target=None, only_train=False, sig_range=False, q=100, scale=2, filter=None):
    """
    Implementation of Algorithm 2, and update functions for gaussian noises,
    adversarial noises, and JPEG compression.
    Also saves losses for Algorithm 1.

    For every batch, ``num_run`` times: perturb the images, record the
    per-sample loss, then for samples whose loss exceeds ``loss_thres``
    repeatedly step the noise by ``-0.01 * sign(g)``. Here ``g`` is a
    one-draw finite-difference estimate obtained with a Gaussian probe of
    std ``sigma / 10``. At most ``max_iter`` steps are taken. The recorded
    distance is ``noise_before - noise``. When the search hits ``max_iter``,
    100 is added for samples still above the threshold.

    Args:
        net: model; moved to CUDA and put in eval mode.
        criterion: loss returning one value per sample (indexed per sample
            below, e.g. ``nn.CrossEntropyLoss(reduction='none')``).
        trainloader, testloader: loaders whose batches index as
            ``data[0]`` = images, ``data[1]`` = labels. ``testloader`` is
            skipped when ``only_train`` is set.
        sigma: Gaussian noise std; also scales the finite-difference probe.
        loss_thres: per-sample loss threshold the search tries to reach.
        num_run: perturbation draws per batch.
        save_dir: output directory, created if missing.
        max_iter: maximum search steps per draw.
        noise_type: one of 'gaussian', 'adv', 'jpeg', 'resize', 'filter'.
        adv_steps, adv_eps, adv_target: forwarded to ``gen_adv`` /
            ``gen_adv_target`` for 'adv'.
        only_train: attack only ``trainloader``.
        sig_range: draw the Gaussian std uniformly from [0, sigma) per run.
        q, scale, filter: parameters of the jpeg / resize / filter transforms.

    Output:
        Pickles a list with one entry per loader. Each entry is a list of
        ``(diff, losses)`` per batch, and both are lists of length
        ``num_run``.

    Raises:
        ValueError: if ``noise_type`` is not supported.
    """
    # Fail fast: previously an unknown type printed a message and then crashed
    # with UnboundLocalError on `noise`.
    if noise_type not in _NOISE_TYPES:
        raise ValueError('noise type error: ' + repr(noise_type))

    net = net.cuda()
    net.eval()
    save = []
    loaders = (trainloader,) if only_train else (trainloader, testloader)
    for loader in loaders:
        all_dists = []
        for data in tqdm(loader):
            clean_images, labels = data[0], data[1]
            labels_gpu = labels.cuda()
            diff = []
            losses = []
            for _run in range(num_run):
                # Adversarial noise generation may need autograd, so it stays
                # outside the no_grad region below.
                images, noise = _perturb(net, criterion, clean_images, labels, noise_type, sigma, sig_range,
                                         adv_steps, adv_eps, adv_target, q, scale, filter)
                noise_before = noise.clone()

                # The search uses finite differences only; no autograd graph is needed.
                with torch.no_grad():
                    loss = criterion(net((images + noise).cuda()), labels_gpu)
                    losses.append(loss.cpu().numpy())
                    it = 0
                    while torch.any(loss > loss_thres) and it < max_iter:
                        it += 1
                        loss = criterion(net((images + noise).cuda()), labels_gpu)

                        delta = torch.normal(0, sigma / 10, size=images.shape)
                        loss_new = criterion(net((images + noise + delta).cuda()), labels_gpu)

                        # Per-sample loss change divided elementwise by the probe.
                        grad = (loss_new - loss).view(-1, 1, 1, 1) / delta.cuda()
                        # Freeze samples already at or below the threshold.
                        grad = grad * (loss > loss_thres).view(-1, 1, 1, 1)
                        noise = noise - 0.01 * torch.sign(grad.cpu())

                if it >= max_iter:
                    # NOTE(review): when the loop stops on max_iter, `loss` was
                    # computed before the final noise update, so this penalty
                    # mask reflects the previous noise iterate.
                    unresolved = (loss > loss_thres).cpu().view(len(noise), 1, 1, 1)
                    diff.append((noise_before - noise) + torch.ones(noise.shape) * 100 * unresolved)
                else:
                    diff.append(noise_before - noise)
            all_dists.append((diff, losses))
        save.append(all_dists)

    os.makedirs(save_dir, exist_ok=True)
    file_name = _noise_file_name(sigma, loss_thres, num_run, noise_type, max_iter, sig_range,
                                 adv_steps, adv_eps, adv_target, q, scale, filter)
    with open(os.path.join(save_dir, file_name), 'wb') as fp:
        pickle.dump(save, fp)

def attack_rot(net, criterion, trainloader, testloader, loss_thres, num_run=10, save_dir='cifar10_experiments', min_deg=10, max_deg=20, only_train=False):
    """
    Implementation of Algorithm 2, and update functions for rotations.
    Also saves losses for Algorithm 1.

    Each image starts at an integer angle drawn uniformly from
    [min_deg, max_deg] (inclusive). The angle is then decreased in
    0.1-degree steps while it is positive and the image's loss exceeds
    ``loss_thres``. Per run, the recorded distance is ``deg_before - deg``,
    plus 100 for images whose loss is still above the threshold when the
    search stops.

    Args:
        net: model; moved to CUDA and put in eval mode.
        criterion: loss returning one value per sample (combined elementwise
            with the per-image angles below).
        trainloader, testloader: loaders whose batches index as
            ``data[0]`` = images, ``data[1]`` = labels. ``testloader`` is
            skipped when ``only_train`` is set.
        loss_thres: per-sample loss threshold.
        num_run: random starting angles tried per batch.
        save_dir: output directory, created if missing.
        min_deg, max_deg: inclusive bounds for the starting angle. They must
            be integral; ints and integral floats such as ``10.0`` (what
            argparse passes) are accepted.
        only_train: attack only ``trainloader``.

    Output:
        Pickles a list with one entry per loader. Each entry is a list of
        ``(diff, losses)`` per batch, and both are lists of length
        ``num_run``.

    Raises:
        ValueError: if the angle bounds are non-integral or min_deg > max_deg.
    """
    # random.randint no longer accepts floats (Python >= 3.12), while the CLI
    # passes float bounds, so convert integral values explicitly.
    lo, hi = int(min_deg), int(max_deg)
    if lo != min_deg or hi != max_deg:
        raise ValueError('min_deg and max_deg must be integral, got %r and %r' % (min_deg, max_deg))
    if lo > hi:
        raise ValueError('min_deg (%r) must not exceed max_deg (%r)' % (min_deg, max_deg))

    net = net.cuda()
    net.eval()
    save = []
    loaders = (trainloader,) if only_train else (trainloader, testloader)
    # The rotation search only evaluates the model; no gradients are needed.
    with torch.no_grad():
        for loader in loaders:
            all_dists = []
            for data in tqdm(loader):
                # Move the batch to the GPU once instead of on every forward pass.
                images, labels = data[0].cuda(), data[1].cuda()
                n = len(images)
                step = torch.full((n,), 0.1)
                diff = []
                losses = []
                for _run in range(num_run):
                    deg = torch.tensor([random.randint(lo, hi) * 1.0 for _ in range(n)])
                    deg_before = deg.clone()
                    loss = criterion(net(rotate(images, deg.cuda())), labels)
                    losses.append(loss.cpu().numpy())

                    while True:
                        # Only images with a positive angle and a loss above the
                        # threshold keep rotating back toward 0.
                        active = torch.logical_and(deg > 0, (loss > loss_thres).cpu())
                        if not torch.any(active):
                            break
                        deg = deg - step * active
                        loss = criterion(net(rotate(images, deg.cuda())), labels)

                    unresolved = (loss > loss_thres).cpu()
                    diff.append((deg_before - deg) + torch.ones(n) * 100 * unresolved)

                all_dists.append((diff, losses))
            save.append(all_dists)

    os.makedirs(save_dir, exist_ok=True)
    file_name = 'dists_loss' + str(loss_thres) + '_run' + str(num_run) + '_rotnaive'
    file_name += '_min_deg' + str(min_deg) + '_max_deg' + str(max_deg)
    file_name += '.pkl'
    with open(os.path.join(save_dir, file_name), 'wb') as fp:
        pickle.dump(save, fp)
    

if __name__ == "__main__":
    from opacus.utils.module_modification import convert_batchnorm_modules
    import argparse
    parser = argparse.ArgumentParser("Attacks on images")
    # Restrict to supported datasets; any other value previously left
    # `trainloader` undefined and crashed later with a NameError.
    parser.add_argument("--dataset", type=str, default='cifar10', choices=('cifar10', 'svhn'))
    # Index into the split training loaders; without it split_loaders[None]
    # failed with an opaque TypeError.
    parser.add_argument("--split_ind", type=int, required=True)
    parser.add_argument("--noise_type", type=str, default='gaussian')
    # Default matches attack_noise/attack_rot; None crashed in range(None).
    parser.add_argument("--num_run", type=int, default=10)
    parser.add_argument("--max_iter", type=int, default=100)
    parser.add_argument('--only_train', action='store_true')

    parser.add_argument("--loss_thres", type=float, default=0.1)

    parser.add_argument("--sigma", type=float, default=0.1)
    parser.add_argument('--range', action='store_true')
    parser.add_argument("--adv_steps", type=int)
    parser.add_argument("--adv_eps", type=float)
    parser.add_argument("--adv_target", type=float)

    parser.add_argument("--min_deg", type=float)
    parser.add_argument("--max_deg", type=float)

    parser.add_argument('--q', type=int)

    parser.add_argument("--scale", type=float)

    parser.add_argument("--filter", type=str)

    parser.add_argument("--save_dir", type=str, default='cifar10_experiments')

    parser.add_argument("--def_wd", type=float)
    parser.add_argument('--dp', action='store_true')

    args = parser.parse_args()
    print(args)

    net = VGG('VGG19')

    if args.dataset == 'cifar10':
        cifar10 = CIFAR10()
        cifar10.split()
        split_loaders, testloader = cifar10.get_dataloaders(batch_size=64, split=True, shuffle=False)
        trainloader = split_loaders[args.split_ind]
        # 'YOUR_MODEL_PATH' placeholders: fill in the checkpoint for each variant.
        if args.dp:
            # The checkpoint's state dict is loaded after converting the
            # batch-norm modules, so it must match the converted architecture.
            net = convert_batchnorm_modules(net)
            net.load_state_dict(torch.load('YOUR_MODEL_PATH'))
        elif args.def_wd is not None:
            net.load_state_dict(torch.load('YOUR_MODEL_PATH'))
        else:
            net.load_state_dict(torch.load('YOUR_MODEL_PATH'))
    else:  # 'svhn' (argparse choices guarantee no other value)
        svhn = SVHN()
        svhn.split()
        split_loaders, testloader = svhn.get_dataloaders(batch_size=64, split=True, shuffle=False)
        trainloader = split_loaders[args.split_ind]
        net.load_state_dict(torch.load('YOUR_MODEL_PATH'))

    # Per-sample losses: both attacks compare each sample's loss to the threshold.
    criterion = nn.CrossEntropyLoss(reduction='none')

    if args.noise_type in ('gaussian', 'adv', 'jpeg', 'resize', 'filter'):
        attack_noise(net, criterion, trainloader, testloader, args.sigma, args.loss_thres, num_run=args.num_run, save_dir=args.save_dir, noise_type=args.noise_type, adv_steps=args.adv_steps, adv_eps=args.adv_eps, adv_target=args.adv_target, max_iter=args.max_iter, sig_range=args.range, q=args.q, only_train=args.only_train, scale=args.scale, filter=args.filter)
    else:
        # Any other --noise_type value runs the rotation attack.
        attack_rot(net, criterion, trainloader, testloader, args.loss_thres, num_run=args.num_run, save_dir=args.save_dir, min_deg=args.min_deg, max_deg=args.max_deg, only_train=args.only_train)