# -*- coding: utf-8 -*-
from __future__ import division

import argparse
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.Normalize import Normalize
from model.WResnet import WResnet_Rotate_VFlip
from DataLoader import cifar10
import numpy as np
import random
from tqdm import trange, tqdm
from torch.utils.tensorboard import SummaryWriter
import yaml
import copy
from collections import OrderedDict
import copy
from adapt_grad import RotateVFlipAdaptGrad
from SquareAttack import square_attack_linf


def perturb_test(net, nat, label):
    """Craft L-inf PGD adversarial examples using the evaluation settings.

    Starts from a uniformly random point inside the ``state['epsilon']`` ball
    around ``nat`` and takes ``state['num_steps']`` signed-gradient ascent
    steps of size ``state['step_size']`` on the cross-entropy loss. After
    every step the perturbation is projected back onto the epsilon ball and
    the result is clipped to the [0, 1] range.

    Args:
        net: Callable mapping an input batch to classification logits.
        nat: Clean input batch. It is not modified and does not need to
            track gradients.
        label: Ground-truth class indices for ``nat``.

    Returns:
        The adversarial batch, detached from the autograd graph.
    """
    eps = state['epsilon']
    # Work on a graph-free view so the attack never depends on (or changes)
    # the caller's requires_grad setting.
    nat = nat.detach()

    # Random start inside the epsilon ball, clipped to the valid pixel range.
    x = nat + (torch.rand_like(nat) - 0.5) * 2 * eps
    x = torch.clamp(x, 0, 1)

    for _ in range(state['num_steps']):
        # Make the current iterate a fresh leaf: autograd.grad needs an input
        # that requires grad, and detaching keeps the graph from growing
        # across iterations.
        x = x.detach().requires_grad_(True)
        loss = F.cross_entropy(net(x), label)
        grad = torch.autograd.grad(loss, x)[0]

        # Signed-gradient ascent step, then projection onto the epsilon ball
        # around the clean input and onto [0, 1].
        x = x.detach() + state['step_size'] * torch.sign(grad)
        x = nat + torch.clamp(x - nat, -eps, eps)
        x = torch.clamp(x, 0, 1)

    return x.detach()

def perturb(net, nat, label):
    """Craft L-inf PGD adversarial examples using the training settings.

    Starts from a uniformly random point inside the ``state['epsilon']`` ball
    around ``nat`` and takes ``state['train_num_steps']`` signed-gradient
    ascent steps of size ``state['train_step_size']`` on the cross-entropy
    loss. After every step the perturbation is projected back onto the
    epsilon ball and the result is clipped to the [0, 1] range.

    Args:
        net: Callable mapping an input batch to classification logits.
        nat: Clean input batch. It is left untouched; in particular its
            ``requires_grad`` flag is not changed.
        label: Ground-truth class indices for ``nat``.

    Returns:
        The adversarial batch, detached from the autograd graph.
    """
    eps = state['epsilon']
    # Work on a graph-free view instead of flipping requires_grad on the
    # caller's tensor.
    nat = nat.detach()

    # Random start inside the epsilon ball, clipped to the valid pixel range.
    x = nat + (torch.rand_like(nat) - 0.5) * 2 * eps
    x = torch.clamp(x, 0, 1)

    for _ in range(state['train_num_steps']):
        # Make the current iterate a fresh leaf: autograd.grad needs an input
        # that requires grad, and detaching keeps the graph from growing
        # across iterations.
        x = x.detach().requires_grad_(True)
        loss = F.cross_entropy(net(x), label)
        grad = torch.autograd.grad(loss, x)[0]

        # Signed-gradient ascent step, then projection onto the epsilon ball
        # around the clean input and onto [0, 1].
        x = x.detach() + state['train_step_size'] * torch.sign(grad)
        x = nat + torch.clamp(x - nat, -eps, eps)
        x = torch.clamp(x, 0, 1)

    return x.detach()

# ---------------------------------------------------------------------------
# Configuration: a YAML file supplies the settings; a few of them can be
# overridden from the command line.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# Positional arguments
parser.add_argument('config', type=str, help='Path to the YAML configuration file.')
# Optional overrides
parser.add_argument('--device', default=[0], type=int, nargs='+',
                    help='CUDA device indices; the first one is the primary device.')
parser.add_argument('--batch_size_test', default=None, type=int,
                    help='Override batch_size_test from the config file.')
parser.add_argument('--finetune_step', default=None, type=int,
                    help='Override finetune_step from the config file.')
parser.add_argument('--learning_rate', default=None, type=float,
                    help='Override learning_rate from the config file.')

args = parser.parse_args()

# NOTE(review): FullLoader can build more than plain scalars/lists/dicts.
# If the config only holds plain data, yaml.safe_load would be sufficient and
# safer for config files that come from an untrusted source.
with open(args.config) as config_file:
    state = yaml.load(config_file, Loader=yaml.FullLoader)

# Command-line values, when given, take precedence over the config file.
# Each override is guarded by its own option (the learning-rate override was
# previously guarded by --batch_size_test).
if args.batch_size_test is not None:
    state['batch_size_test'] = args.batch_size_test
if args.finetune_step is not None:
    state['finetune_step'] = args.finetune_step
if args.learning_rate is not None:
    state['learning_rate'] = args.learning_rate


# Mirror every config entry onto ``args``. A config key that shares its name
# with a command-line option replaces the parsed value.
for k, v in state.items():
    setattr(args, k, v)

device = torch.device("cuda:" + str(args.device[0]))
train_loader, test_loader, _, nlabels, mean, std = cifar10(args)
net = WResnet_Rotate_VFlip(mean, std)

# Explicit check instead of ``assert`` so it survives ``python -O`` and also
# covers a config file that has no 'load' entry at all.
if getattr(args, 'load', None) is None:
    raise ValueError("The config must set 'load' to the name of a checkpoint in ./weight")
# map_location lets a checkpoint saved on a different device be loaded here.
state_dict = torch.load(os.path.join("weight", args.load + ".pytorch"), map_location=device)
# The network is wrapped in DataParallel below, which exposes the wrapped
# parameters under a "module." prefix, so the checkpoint keys are renamed.
module_dict = OrderedDict()
for key, item in state_dict.items():
    module_dict["module." + key] = item

net.to(device)
net = nn.parallel.DataParallel(net, args.device)

def finetune(model, train_loader, data, target, optimizer, args):
    """Fine-tune ``model`` on one batch with self-supervised auxiliary tasks.

    For ``args.finetune_step`` optimizer steps the model is trained on three
    losses whose gradients are accumulated before a single ``optimizer.step``:

    * flip prediction on ``data`` and its flip along dimension 2
      (label 0 = unchanged, 1 = flipped), weighted by ``args.penalty_vflip``;
    * rotation prediction on ``data`` rotated by 0-3 quarter turns in
      dimensions (2, 3) (label = number of quarter turns), weighted by
      ``args.penalty_rotate``;
    * classification of PGD adversarial examples crafted by ``perturb`` from
      a batch drawn from ``train_loader``, weighted by ``args.penalty``.

    Args:
        model: Network called as ``model(x)`` for classification and as
            ``model(x, 'vflip')`` / ``model(x, 'rotate')`` for the auxiliary
            heads.
        train_loader: Iterable yielding ``(input, label)`` training batches.
        data: Batch to adapt to; treated as a constant input.
            Presumably laid out as (N, C, H, W) - TODO confirm.
        target: Tensor used only as a template (shape, dtype, device) for the
            self-supervised labels; its values are not read.
        optimizer: Optimizer over the parameters of ``model``.
        args: Namespace providing ``finetune_step``, ``penalty``,
            ``penalty_rotate`` and ``penalty_vflip``.

    Returns:
        None. ``model`` is updated in place.
    """
    # detach(): the batch is a fixed input, so no gradient should flow back
    # into it and the hoisted transforms below stay graph-free.
    nat = data.to(device).detach()
    target = target.to(device)

    # The self-supervised batches do not change between steps, so build them
    # once. Both are derived from the device-resident ``nat``.
    adv_vflip = torch.cat([nat, torch.flip(nat, dims=[2])], dim=0)
    target_vflip = torch.cat(
        [torch.zeros_like(target), torch.ones_like(target)], dim=0)

    adv_rotate = torch.cat(
        [torch.rot90(nat, turns, (2, 3)) for turns in range(4)], dim=0)
    target_rotate = torch.cat(
        [torch.ones_like(target) * turns for turns in range(4)], dim=0)

    iter_train = iter(train_loader)
    for _ in range(args.finetune_step):
        optimizer.zero_grad()

        logits_vflip = model(adv_vflip, 'vflip')
        loss_vflip = F.cross_entropy(logits_vflip, target_vflip)
        logits_rotate = model(adv_rotate, 'rotate')
        loss_rotate = F.cross_entropy(logits_rotate, target_rotate)

        (args.penalty_rotate * loss_rotate).backward()
        (args.penalty_vflip * loss_vflip).backward()

        try:
            train_nat, train_target = next(iter_train)
        except StopIteration:
            # More fine-tuning steps than training batches: start over
            # instead of aborting.
            iter_train = iter(train_loader)
            train_nat, train_target = next(iter_train)
        train_nat = train_nat.to(device)
        train_target = train_target.to(device)

        # perturb() already returns a detached tensor.
        adv_train = perturb(model, train_nat, train_target)
        logits_train = model(adv_train)
        loss_train = args.penalty * F.cross_entropy(logits_train, train_target)
        loss_train.backward()

        # Single update from the gradients of all three losses.
        optimizer.step()

# ---------------------------------------------------------------------------
# Evaluation: for every test batch, restore the checkpoint, attack the batch
# with Square Attack, and measure accuracy on the adversarial examples before
# and after fine-tuning on that batch.
# ---------------------------------------------------------------------------
b_corrects = 0  # correct predictions before fine-tuning, summed over batches
a_corrects = 0  # correct predictions after fine-tuning, summed over batches
n_seen = 0      # number of test samples processed so far
for batch_idx, (data, target) in tqdm(enumerate(test_loader), total=len(test_loader)):

    # Start every batch from the original checkpoint. load_state_dict copies
    # the values into the existing parameters, so the stored dict can be
    # reused without deep-copying it each time.
    net.load_state_dict(module_dict)
    net.eval()
    # Fresh optimizer per batch so no momentum carries over between batches.
    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'], weight_decay=state['decay'])

    nat = data.to(device)
    target = target.to(device)

    # Neither the attack call nor the accuracy measurement needs gradients.
    with torch.no_grad():
        _, adv = square_attack_linf(net, nat, target, state['epsilon'], n_iters=2000, p_init=state['p_init'], device=device)
        logits = net(adv)
    acc = torch.argmax(logits, dim=1) == target
    correct = int(acc.sum())
    print("batch {}, before finetune: acc {}".format(batch_idx, correct/len(data)))
    b_corrects += correct

    finetune(net, train_loader, adv.detach().clone(), target, optimizer, args)
    with torch.no_grad():
        logits = net(adv)
    acc = torch.argmax(logits, dim=1) == target
    correct = int(acc.sum())
    print("batch {}, after finetune: acc {}".format(batch_idx, correct/len(data)))
    a_corrects += correct

    # Running accuracy over all samples seen so far.
    n_seen += len(data)
    print("Total before finetune acc {}, Total after finetune acc {}".format(b_corrects/n_seen, a_corrects/n_seen))

