# -*- coding: utf-8 -*-
from __future__ import division

import argparse
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.Normalize import Normalize
from model.WResnet import WResnet_Rotate_VFlip
from DataLoader import cifar10
import numpy as np
import random
from tqdm import trange, tqdm
from torch.utils.tensorboard import SummaryWriter
import yaml
import copy
from collections import OrderedDict
import copy
from adapt_grad import RotateVFlipAdaptGrad
from FabAttack import FABAttack_PT

def perturb_test(net, nat, label, epsilon=None, num_steps=None, step_size=None):
    """Build L-inf PGD adversarial examples (evaluation budget).

    Starts from a uniform random point in the epsilon-ball around ``nat``, then
    takes ``num_steps`` signed-gradient ascent steps on the cross-entropy loss,
    projecting back onto the epsilon-ball and the [0, 1] box after each step.

    Args:
        net: callable mapping an input batch to logits.
        nat: clean inputs (presumably already in [0, 1]); never modified.
        label: class indices passed to ``F.cross_entropy``.
        epsilon, num_steps, step_size: attack budget. Each falls back to
            ``state['epsilon']`` / ``state['num_steps']`` / ``state['step_size']``
            when left as None.

    Returns:
        Detached adversarial batch with the same shape as ``nat``.
    """
    if epsilon is None:
        epsilon = state['epsilon']
    if num_steps is None:
        num_steps = state['num_steps']
    if step_size is None:
        step_size = state['step_size']

    nat = nat.detach()
    x = nat + (torch.rand_like(nat) - 0.5) * 2 * epsilon
    x = torch.clamp(x, 0, 1)

    for _ in range(num_steps):
        # Make x a fresh leaf each step so autograd.grad works regardless of
        # nat.requires_grad and the graph does not grow across iterations.
        x = x.detach().requires_grad_(True)
        loss = F.cross_entropy(net(x), label)
        grad = torch.autograd.grad(loss, x)[0]

        x = x.detach() + step_size * torch.sign(grad)
        # Project onto the epsilon-ball around nat, then onto the valid pixel range.
        noise = torch.clamp(x - nat, -epsilon, epsilon)
        x = torch.clamp(nat + noise, 0, 1)

    return x.detach()

def perturb(net, nat, label, epsilon=None, num_steps=None, step_size=None):
    """Build L-inf PGD adversarial examples (training budget).

    Same procedure as the evaluation attack but driven by the training-time
    step count and step size. Uses a random start in the epsilon-ball, then
    signed-gradient ascent on cross-entropy with projection onto the
    epsilon-ball and the [0, 1] box.

    Args:
        net: callable mapping an input batch to logits.
        nat: clean inputs (presumably already in [0, 1]). Unlike earlier
            versions, its ``requires_grad`` flag is left untouched.
        label: class indices passed to ``F.cross_entropy``.
        epsilon, num_steps, step_size: attack budget. Each falls back to
            ``state['epsilon']`` / ``state['train_num_steps']`` /
            ``state['train_step_size']`` when left as None.

    Returns:
        Detached adversarial batch with the same shape as ``nat``.
    """
    if epsilon is None:
        epsilon = state['epsilon']
    if num_steps is None:
        num_steps = state['train_num_steps']
    if step_size is None:
        step_size = state['train_step_size']

    nat = nat.detach()
    x = nat + (torch.rand_like(nat) - 0.5) * 2 * epsilon
    x = torch.clamp(x, 0, 1)

    for _ in range(num_steps):
        # Fresh leaf per step: the gradient is only needed w.r.t. x, and this
        # keeps the autograd graph from chaining across iterations.
        x = x.detach().requires_grad_(True)
        loss = F.cross_entropy(net(x), label)
        grad = torch.autograd.grad(loss, x)[0]

        x = x.detach() + step_size * torch.sign(grad)
        noise = torch.clamp(x - nat, -epsilon, epsilon)
        x = torch.clamp(nat + noise, 0, 1)

    return x.detach()

parser = argparse.ArgumentParser()
# Positional arguments
parser.add_argument('config', type=str, help='Path to the YAML experiment config file.')
parser.add_argument('--device', default=[0], type=int, nargs='+')
parser.add_argument('--batch_size_test', default=None, type=int)
parser.add_argument('--finetune_step', default=None, type=int)
parser.add_argument('--learning_rate', default=None, type=float)

args = parser.parse_args()

with open(args.config) as config_file:
    # NOTE(review): FullLoader is acceptable for a trusted local config;
    # switch to yaml.safe_load if configs may come from untrusted sources.
    state = yaml.load(config_file, Loader=yaml.FullLoader)

# Command-line values, when given, override the config file.
if args.batch_size_test is not None:
    state['batch_size_test'] = args.batch_size_test
if args.finetune_step is not None:
    state['finetune_step'] = args.finetune_step
if args.learning_rate is not None:
    state['learning_rate'] = args.learning_rate

# Mirror every config entry onto args (same-named CLI attributes are overwritten).
for k, v in state.items():
    setattr(args, k, v)

device = torch.device("cuda:" + str(args.device[0]))
train_loader, test_loader, _, nlabels, mean, std = cifar10(args)
net = WResnet_Rotate_VFlip(mean, std)

if getattr(args, 'load', None) is None:
    raise ValueError("config must define 'load' (checkpoint name under weight/)")
# map_location makes loading independent of the GPU the checkpoint was saved on.
state_dict = torch.load(os.path.join("weight", args.load + ".pytorch"), map_location=device)
# The net is wrapped in DataParallel below, whose parameter names carry a "module." prefix.
module_dict = OrderedDict(("module." + key, item) for key, item in state_dict.items())

net.to(device)
net = nn.parallel.DataParallel(net, args.device)

def finetune(model, train_loader, data, target, optimizer, args):
    """Adapt ``model`` in place on one test batch for ``args.finetune_step`` steps.

    Each step accumulates three weighted losses before a single optimizer step:
      * rotation prediction on ``data`` rotated by k*90 degrees
        (``torch.rot90`` over dims (2, 3)), labels k in 0..3, weight
        ``args.penalty_rotate``;
      * flip prediction on ``data`` and ``torch.flip(data, dims=[2])``
        (dim 2, presumably the height axis), labels 0/1, weight ``args.penalty_vflip``;
      * adversarial cross-entropy on a fresh ``train_loader`` batch attacked
        with ``perturb``, weight ``args.penalty``.

    ``model(x, 'vflip')`` / ``model(x, 'rotate')`` presumably select the
    auxiliary heads — TODO confirm against WResnet_Rotate_VFlip. The model's
    train/eval mode is left as the caller set it.

    Args:
        model: network to adapt (modified in place through ``optimizer``).
        train_loader: iterable of (inputs, labels); restarted if exhausted.
        data: test batch used for the self-supervised losses.
        target: labels for ``data``; only their shape/dtype are used here.
        optimizer: optimizer over ``model``'s parameters.
        args: namespace providing finetune_step and the three penalty weights.

    Returns:
        None.
    """
    nat = data.to(device)
    target = target.to(device)

    # The self-supervised inputs and labels depend only on the test batch,
    # so build them once instead of on every step.
    nat_vflip = torch.cat([nat, torch.flip(nat, dims=[2])], dim=0)
    target_vflip = torch.cat([torch.zeros_like(target), torch.ones_like(target)], dim=0)
    nat_rotate = torch.cat([torch.rot90(nat, k, (2, 3)) for k in range(4)], dim=0)
    target_rotate = torch.cat([torch.ones_like(target) * k for k in range(4)], dim=0)

    iter_train = iter(train_loader)
    for _ in range(args.finetune_step):
        optimizer.zero_grad()

        logits_vflip = model(nat_vflip, 'vflip')
        loss_vflip = F.cross_entropy(logits_vflip, target_vflip)
        logits_rotate = model(nat_rotate, 'rotate')
        loss_rotate = F.cross_entropy(logits_rotate, target_rotate)

        # Separate backward calls accumulate into the same .grad buffers.
        (args.penalty_rotate * loss_rotate).backward()
        (args.penalty_vflip * loss_vflip).backward()

        # Restart the loader instead of failing when finetune_step exceeds its length.
        try:
            train_nat, train_target = next(iter_train)
        except StopIteration:
            iter_train = iter(train_loader)
            train_nat, train_target = next(iter_train)
        train_nat = train_nat.to(device)
        train_target = train_target.to(device)

        adv_train = perturb(model, train_nat, train_target).detach()
        logits_train = model(adv_train)
        loss_train = args.penalty * F.cross_entropy(logits_train, train_target)
        loss_train.backward()
        optimizer.step()

def perturb_autoattack(Attack, model, x_orig, y_orig):
    """Run ``Attack`` on a batch and return the adversarial batch.

    Every sample is attacked, whether or not it is classified correctly on the
    clean input. The mask below is all True, but it is kept explicit so the
    selection can later be restricted (e.g. to correctly-classified samples).

    Args:
        Attack: object exposing ``perturb(x, y)`` returning perturbed inputs
            (here a FABAttack_PT).
        model: unused; kept so existing call sites keep working.
        x_orig: input batch, indexed along dim 0. Not modified.
        y_orig: labels for ``x_orig``.

    Returns:
        Detached tensor shaped like ``x_orig``. Attacked samples are written back
        at their original positions.
    """
    x_adv = x_orig.clone().detach()
    robust_flags = torch.ones(y_orig.shape[0], dtype=torch.bool, device=y_orig.device)
    if not bool(robust_flags.any()):
        return x_adv

    # 1-D index tensor even for a single selected sample, so x stays 4-D.
    idcs = torch.nonzero(robust_flags, as_tuple=False).view(-1)
    x = x_orig[idcs].clone()
    y = y_orig[idcs].clone()

    adv_curr = Attack.perturb(x, y)
    # Scatter results back instead of replacing x_adv, which would be wrong
    # whenever only a subset of the batch is attacked.
    x_adv[idcs.to(x_adv.device)] = adv_curr.detach().to(x_adv.device, x_adv.dtype)

    return x_adv

# Running totals: correct predictions before/after fine-tuning and samples seen.
b_corrects = 0
a_corrects = 0
n_seen = 0
GradOracle = RotateVFlipAdaptGrad(net, train_loader, state, device)
Attack = FABAttack_PT(GradOracle, n_restarts=1, n_iter=100, verbose=False, eps=state['epsilon'], norm='Linf', update_times=state['update_times'], device=device)

for batch_idx, (data, target) in tqdm(enumerate(test_loader), total=len(test_loader)):

    # Restore the pretrained weights so every test batch is adapted independently
    # (load_state_dict copies values, so module_dict itself is never modified).
    net.load_state_dict(module_dict)
    net.eval()
    # Fresh optimizer per batch so momentum does not carry over between batches.
    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'], weight_decay=state['decay'])

    nat = data.to(device)
    target = target.to(device)
    n_seen += len(data)

    adv = perturb_autoattack(Attack, net, nat, target)
    with torch.no_grad():
        correct = int((torch.argmax(net(adv), dim=1) == target).sum())
    print("batch {}, before finetune: acc {}".format(batch_idx, correct / len(data)))
    b_corrects += correct

    finetune(net, train_loader, adv.detach().clone(), target, optimizer, args)
    with torch.no_grad():
        correct = int((torch.argmax(net(adv), dim=1) == target).sum())
    print("batch {}, after finetune: acc {}".format(batch_idx, correct / len(data)))
    a_corrects += correct

    print("Total before finetune acc {}, Total after finetune acc {}".format(b_corrects / n_seen, a_corrects / n_seen))

