# -*- coding: utf-8 -*-
from __future__ import division

import argparse
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.Normalize import Normalize
from model.WResnet import WResnet_Rotate_VFlip
from DataLoader import cifar10
import numpy as np
import random
from tqdm import trange, tqdm
from torch.utils.tensorboard import SummaryWriter
import yaml
import copy
from collections import OrderedDict

def _pgd_linf(net, nat, label, epsilon, num_steps, step_size):
    """Run a random-start L-infinity PGD attack and return the adversarial batch.

    The start point is ``nat`` plus uniform noise in ``[-epsilon, epsilon]``,
    clipped to ``[0, 1]``. Each of the ``num_steps`` iterations moves ``x`` by
    ``step_size * sign(grad)`` of the cross-entropy loss, then projects ``x``
    back into the epsilon-ball around ``nat`` and into ``[0, 1]``.

    Gradients are taken w.r.t. a fresh leaf copy of ``x`` each step, so
    ``nat`` need not require grad, is never modified, and no autograd graph is
    carried from one step to the next.
    """
    nat = nat.detach()
    x = torch.clamp(nat + (torch.rand_like(nat) - 0.5) * 2 * epsilon, 0, 1)
    for _ in range(num_steps):
        x.requires_grad_(True)
        loss = F.cross_entropy(net(x), label)
        grad = torch.autograd.grad(loss, x)[0]
        with torch.no_grad():
            x = x + step_size * torch.sign(grad)
            # Project onto the epsilon-ball around nat, then onto the valid pixel range.
            x = nat + torch.clamp(x - nat, -epsilon, epsilon)
            x = torch.clamp(x, 0, 1)
    return x.detach()


def perturb_test(net, nat, label, epsilon=None, num_steps=None, step_size=None):
    """Craft evaluation-time PGD adversarial examples for ``nat``.

    Args:
        net: model mapping an input batch to logits.
        nat: clean input batch (values expected in ``[0, 1]``).
        label: class indices used for the cross-entropy loss.
        epsilon: L-inf radius; defaults to ``state['epsilon']``.
        num_steps: number of PGD steps; defaults to ``state['num_steps']``.
        step_size: per-step size; defaults to ``state['step_size']``.

    Returns:
        The adversarial batch, detached from autograd.
    """
    return _pgd_linf(
        net,
        nat,
        label,
        state['epsilon'] if epsilon is None else epsilon,
        state['num_steps'] if num_steps is None else num_steps,
        state['step_size'] if step_size is None else step_size,
    )


def perturb(net, nat, label, epsilon=None, num_steps=None, step_size=None):
    """Craft training-time PGD adversarial examples for ``nat``.

    Same attack as ``perturb_test`` but with the training step settings.
    Unlike the earlier version, the caller's ``nat`` tensor is not modified
    (its ``requires_grad`` flag is left untouched).

    Args:
        net: model mapping an input batch to logits.
        nat: clean input batch (values expected in ``[0, 1]``).
        label: class indices used for the cross-entropy loss.
        epsilon: L-inf radius; defaults to ``state['epsilon']``.
        num_steps: number of PGD steps; defaults to ``state['train_num_steps']``.
        step_size: per-step size; defaults to ``state['train_step_size']``.

    Returns:
        The adversarial batch, detached from autograd.
    """
    return _pgd_linf(
        net,
        nat,
        label,
        state['epsilon'] if epsilon is None else epsilon,
        state['train_num_steps'] if num_steps is None else num_steps,
        state['train_step_size'] if step_size is None else step_size,
    )

parser = argparse.ArgumentParser()
# Positional arguments
parser.add_argument('config', type=str, help='Path to the YAML experiment config file.')
parser.add_argument('--device', default=[0], type=int, nargs='+')
parser.add_argument('--batch_size_test', default=None, type=int)
parser.add_argument('--finetune_step', default=None, type=int)
parser.add_argument('--learning_rate', default=None, type=float)

args = parser.parse_args()

# NOTE(review): yaml.FullLoader can build some Python objects; this assumes the
# config is a trusted local file. Switch to yaml.safe_load if that may not hold.
with open(args.config) as config_file:
    state = yaml.load(config_file, Loader=yaml.FullLoader)

# Each command-line option, when given, overrides the config entry of the same name.
for override in ('batch_size_test', 'finetune_step', 'learning_rate'):
    value = getattr(args, override)
    if value is not None:
        state[override] = value

# Expose every config entry as an attribute of args (config values replace
# any same-named CLI attribute).
for k, v in state.items():
    setattr(args, k, v)

# Validate before the (potentially slow) dataset setup.
if getattr(args, 'load', None) is None:
    parser.error("the config must set 'load' (checkpoint name under ./weight)")

device = torch.device("cuda:" + str(args.device[0]))
train_loader, test_loader, _, nlabels, mean, std = cifar10(args)
net = WResnet_Rotate_VFlip(mean, std)

state_dict = torch.load(os.path.join("weight", args.load + ".pytorch"))
# Prefix keys with "module." so they match the DataParallel-wrapped net below.
module_dict = OrderedDict(("module." + key, item) for key, item in state_dict.items())

net.to(device)
net = nn.parallel.DataParallel(net, args.device)

def finetune(model, train_loader, data, target, optimizer, args):
    """Fine-tune ``model`` in place on one test batch for ``args.finetune_step`` steps.

    Each step backpropagates three weighted losses and then takes one
    optimizer step:
      * ``args.penalty_vflip`` x cross-entropy of ``model(..., 'vflip')`` on the
        batch (label 0) concatenated with ``torch.flip(batch, dims=[2])`` (label 1);
      * ``args.penalty_rotate`` x cross-entropy of ``model(..., 'rotate')`` on
        the batch rotated by k * 90 degrees in dims (2, 3), labelled k, k = 0..3;
      * ``args.penalty`` x cross-entropy of ``model`` on the next
        ``train_loader`` batch after it is attacked with ``perturb``.

    Only the shape/dtype/device of ``target`` are used (via ``zeros_like`` /
    ``ones_like``); the test batch's label values are never read.
    NOTE(review): dims 2/3 presumably are H/W of an NCHW batch — confirm.

    Args:
        model: network accepting ``(x)`` and ``(x, 'vflip' | 'rotate')``.
        train_loader: iterable of ``(inputs, labels)`` training batches; it is
            restarted if ``finetune_step`` exceeds its length.
        data: the test batch to adapt on.
        target: tensor shaped like the test batch's labels.
        optimizer: optimizer over ``model``'s parameters.
        args: namespace providing ``finetune_step``, ``penalty``,
            ``penalty_vflip`` and ``penalty_rotate``.

    Returns:
        None; ``model`` and ``optimizer`` are updated in place.
    """
    nat = data.to(device)
    target = target.to(device)

    # The self-supervised batches depend only on the test batch, so build them once.
    vflip_inputs = torch.cat([nat, torch.flip(nat, dims=[2])], dim=0)
    vflip_targets = torch.cat([torch.zeros_like(target), torch.ones_like(target)], dim=0)
    rotate_inputs = torch.cat([torch.rot90(nat, k, (2, 3)) for k in range(4)], dim=0)
    rotate_targets = torch.cat([torch.ones_like(target) * k for k in range(4)], dim=0)

    train_iter = iter(train_loader)
    for _ in range(args.finetune_step):
        optimizer.zero_grad()

        loss_vflip = F.cross_entropy(model(vflip_inputs, 'vflip'), vflip_targets)
        loss_rotate = F.cross_entropy(model(rotate_inputs, 'rotate'), rotate_targets)
        # Separate forward passes -> separate graphs, so two backward calls just accumulate.
        (args.penalty_rotate * loss_rotate).backward()
        (args.penalty_vflip * loss_vflip).backward()

        try:
            train_nat, train_target = next(train_iter)
        except StopIteration:
            # More steps than batches in one pass: start a new pass over the loader.
            train_iter = iter(train_loader)
            train_nat, train_target = next(train_iter)
        train_nat = train_nat.to(device)
        train_target = train_target.to(device)

        adv_train = perturb(model, train_nat, train_target).detach()
        loss_train = args.penalty * F.cross_entropy(model(adv_train), train_target)
        loss_train.backward()
        optimizer.step()

# Running counts of correctly classified adversarial examples before/after
# fine-tuning, and of test samples processed so far.
b_corrects = 0
a_corrects = 0
n_seen = 0
for batch_idx, (data, target) in tqdm(enumerate(test_loader), total=len(test_loader)):

    # Every test batch starts again from the pretrained weights and a fresh optimizer.
    net.load_state_dict(copy.deepcopy(module_dict))
    net.eval()
    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'], weight_decay=state['decay'])

    nat = data.to(device)
    # perturb_test takes gradients w.r.t. its input.
    nat.requires_grad = True
    target = target.to(device)

    adv = perturb_test(net, nat, target)
    with torch.no_grad():
        correct = int((torch.argmax(net(adv), dim=1) == target).sum())
    print("batch {}, before finetune: acc {}".format(batch_idx, correct / len(data)))
    b_corrects += correct

    finetune(net, train_loader, adv.detach().clone(), target, optimizer, args)
    with torch.no_grad():
        correct = int((torch.argmax(net(adv), dim=1) == target).sum())
    print("batch {}, after finetune: acc {}".format(batch_idx, correct / len(data)))
    a_corrects += correct
    n_seen += len(data)

    print("Total before finetune acc {}, Total after finetune acc {}".format(b_corrects / n_seen, a_corrects / n_seen))

