# -*- coding: utf-8 -*-
from __future__ import division

import argparse
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.Normalize import Normalize
from model.WResnet import WResnet_Rotate_VFlip
from DataLoader import cifar10
import numpy as np
import random
from tqdm import trange, tqdm
from torch.utils.tensorboard import SummaryWriter
import yaml
import time
import copy


def vflip(input):
    """Randomly flip individual samples of a batch and report which were flipped.

    One binary label is drawn per sample. Samples whose label is 1 are
    reversed along dim 1 of the sample (dim 2 of the batch); for a batch
    laid out as (N, C, H, W) that would be the height axis -- layout is
    assumed, not checked.

    Args:
        input: batched tensor; each sample must have at least two dims.

    Returns:
        (nat_vflip, target_vflip): a detached copy of ``input`` with the
        selected samples flipped, and the int64 label vector (1 = flipped)
        living on ``input.device``.
    """
    batch_size = len(input)
    target_vflip = torch.randint(2, (batch_size, ), device=input.device, dtype=torch.long)

    # Start from a detached copy so the caller's tensor is never modified.
    nat_vflip = input.detach().clone()

    # Only visit the samples that were actually selected for flipping.
    selected = torch.nonzero(target_vflip == 1).flatten().tolist()
    for idx in selected:
        nat_vflip[idx] = nat_vflip[idx].flip(1)

    return nat_vflip, target_vflip


def rotate(input):
    """Rotate every sample of a batch by a random multiple of 90 degrees.

    A label in {0, 1, 2, 3} is drawn per sample and the sample is rotated
    that many quarter turns with ``torch.rot90`` in the plane of its dims
    (1, 2). Samples must be square in those dims for odd rotation counts,
    because the result is written back into a buffer shaped like ``input``.

    Args:
        input: batched tensor whose samples have at least three dims.

    Returns:
        (nat_rotate, target_rotate): the rotated batch (same shape as
        ``input``) and the int64 vector of rotation counts on
        ``input.device``.
    """
    batch_size = len(input)
    target_rotate = torch.randint(4, (batch_size, ), device=input.device, dtype=torch.long)
    nat_rotate = torch.empty_like(input)

    # Read all rotation counts at once, then fill the output sample by sample.
    quarter_turns = target_rotate.tolist()
    for idx, (sample, turns) in enumerate(zip(input, quarter_turns)):
        nat_rotate[idx] = torch.rot90(sample, turns, (1, 2))

    return nat_rotate, target_rotate


def perturb(net, nat, label, mode='label'):
    """Craft an adversarial batch with an iterated signed-gradient attack.

    Starts from a uniformly random point inside the L-infinity ball of
    radius ``state['epsilon']`` around ``nat`` and performs
    ``state['num_steps']`` ascent steps of size ``state['step_size']`` on
    the cross-entropy of ``net(x, mode)`` against ``label``. After every
    step the iterate is projected back into the epsilon ball and clamped
    to [0, 1].

    Args:
        net: callable invoked as ``net(x, mode)`` and returning logits.
        nat: clean input batch. It is not modified; in particular its
            ``requires_grad`` flag is left untouched.
        label: class indices passed to ``F.cross_entropy``.
        mode: forwarded unchanged as the second argument of ``net``.

    Returns:
        The adversarial batch, detached from the autograd graph.
    """
    epsilon = state['epsilon']
    step_size = state['step_size']

    # Work on a detached alias: flipping requires_grad on the caller's tensor
    # would leak autograd tracking into every later use of that tensor (and
    # raises outright when the caller passes a non-leaf tensor).
    nat = nat.detach()

    x = nat + (torch.rand_like(nat) - 0.5) * 2 * epsilon
    x = torch.clamp(x, 0, 1)

    for _ in range(state['num_steps']):
        # Fresh leaf per step so the history of earlier steps is not retained.
        x = x.detach().requires_grad_(True)
        y = net(x, mode)
        loss = F.cross_entropy(y, label)

        grad = torch.autograd.grad(loss, x)[0]

        x = x.detach() + step_size * torch.sign(grad)
        # Project onto the epsilon ball around the clean input, then onto [0, 1].
        noise = torch.clamp(x - nat, -epsilon, epsilon)
        x = torch.clamp(nat + noise, 0, 1)

    return x.detach()


def perturb_meta(net, nat, label, start, mode='label'):
    """Continue a signed-gradient attack from a given starting point.

    Runs ``state['meta_num_steps']`` ascent steps of size
    ``state['meta_step_size']`` on the cross-entropy of ``net(x, mode)``
    against ``label``, beginning at ``start``. After each step the iterate
    is projected into the L-infinity ball of radius ``state['epsilon']``
    around ``nat`` and clamped to [0, 1].

    Args:
        net: callable invoked as ``net(x, mode)`` and returning logits.
        nat: clean batch that defines the centre of the epsilon ball.
        label: class indices passed to ``F.cross_entropy``.
        start: initial iterate; it is not modified.
        mode: forwarded unchanged as the second argument of ``net``.

    Returns:
        The adversarial batch, detached from the autograd graph. With zero
        steps this is simply ``start`` detached.
    """
    epsilon = state['epsilon']
    step_size = state['meta_step_size']

    x = start.detach()

    for _ in range(state['meta_num_steps']):
        # Re-create the leaf every step: only the gradient w.r.t. the current
        # iterate is needed, so earlier steps' history need not be kept alive.
        x = x.detach().requires_grad_(True)
        y = net(x, mode)
        loss = F.cross_entropy(y, label)

        grad = torch.autograd.grad(loss, x)[0]

        x = x.detach() + step_size * torch.sign(grad)
        # Project onto the epsilon ball around the clean input, then onto [0, 1].
        noise = torch.clamp(x - nat, -epsilon, epsilon)
        x = torch.clamp(nat + noise, 0, 1)

    return x.detach()

def perturb_test(net, nat, label):
    """Craft an adversarial batch for evaluation.

    Starts from a uniformly random point inside the L-infinity ball of
    radius ``state['epsilon']`` around ``nat`` and performs
    ``state['test_num_steps']`` signed-gradient ascent steps of size
    ``state['test_step_size']`` on the cross-entropy of ``net(x)`` against
    ``label``. After each step the iterate is projected back into the
    epsilon ball and clamped to [0, 1].

    Args:
        net: callable invoked as ``net(x)`` and returning logits.
        nat: clean input batch. It does not need ``requires_grad`` set and
            is not modified.
        label: class indices passed to ``F.cross_entropy``.

    Returns:
        The adversarial batch, detached from the autograd graph.
    """
    epsilon = state['epsilon']
    step_size = state['test_step_size']

    nat = nat.detach()
    x = nat + (torch.rand_like(nat) - 0.5) * 2 * epsilon
    x = torch.clamp(x, 0, 1)

    for _ in range(state['test_num_steps']):
        # Make the iterate a leaf that tracks gradients itself, instead of
        # relying on the caller having enabled requires_grad on `nat`.
        x = x.detach().requires_grad_(True)
        output = net(x)
        loss = F.cross_entropy(output, label)
        grad = torch.autograd.grad(loss, x)[0]

        x = x.detach() + step_size * torch.sign(grad)
        # Project onto the epsilon ball around the clean input, then onto [0, 1].
        noise = torch.clamp(x - nat, -epsilon, epsilon)
        x = torch.clamp(nat + noise, 0, 1)

    return x.detach()

# Command line: a mandatory YAML config file plus the GPU ids to use.
parser = argparse.ArgumentParser()
# Positional arguments
# (a required positional never falls back to a default, so none is given)
parser.add_argument('config', type=str, help='Path to the YAML configuration file.')
parser.add_argument('--device', default=[0], type=int, nargs='+')
args = parser.parse_args()

# NOTE(review): FullLoader accepts more than plain YAML; only point this at
# config files you trust, or switch to yaml.safe_load if the configs allow it.
with open(args.config) as config_file:
    state = yaml.load(config_file, Loader=yaml.FullLoader)

# Mirror every config entry onto `args` so both `state[...]` and `args.<key>`
# work. A config key named like a CLI option (e.g. `device`) overrides it.
for k, v in state.items():
    setattr(args, k, v)


# Checkpoints are written to ./weight; exist_ok avoids the check-then-create race.
os.makedirs("./weight", exist_ok=True)

# The first listed GPU hosts the master copy of the model.
device = torch.device("cuda:" + str(args.device[0]))
train_loader, test_loader, _, nlabels, mean, std = cifar10(args)
net = WResnet_Rotate_VFlip(mean, std)

# Optionally resume from ./weight/<load>.pytorch. A config without a `load`
# key is treated the same as `load: null`.
if getattr(args, 'load', None) is not None:
    # Load onto the CPU first so a checkpoint saved from a different GPU index
    # still restores; the model is moved to `device` right below.
    state_dict = torch.load(os.path.join("weight", args.load + ".pytorch"), map_location="cpu")
    net.load_state_dict(state_dict)

net.to(device)
net = nn.parallel.DataParallel(net, args.device)

# Hand the optimizer the parameters in their natural, deterministic order;
# wrapping them in a set makes the order depend on object hashes.
optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'], weight_decay=state['decay'])
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.schedule, gamma=args.gamma)

def update_parameters(net, lr, grad):
    """Apply one plain gradient-descent step to ``net`` in place.

    Args:
        net: module whose ``parameters()`` are updated.
        lr: step size.
        grad: sequence of gradients aligned with ``net.parameters()``.
            Entries may be ``None`` (as returned by ``torch.autograd.grad``
            with ``allow_unused=True``); those parameters are left unchanged.
    """
    with torch.no_grad():
        for p, g in zip(net.parameters(), grad):
            if g is None:
                # Parameter did not take part in the loss: nothing to apply.
                continue
            p.add_(-lr * g)

def assign_grad(net, grads):
    """Store the task-averaged gradient in ``.grad`` of every parameter.

    Args:
        net: module whose parameters receive the averaged gradients.
        grads: one gradient sequence per task, each aligned with
            ``net.parameters()``. Entries may be ``None`` for parameters a
            task did not use.

    For each parameter the non-``None`` gradients are summed and divided by
    the total number of tasks, so a ``None`` entry counts as zero. If no
    task produced a gradient for a parameter, its ``.grad`` is set to
    ``None`` (``.grad`` only accepts a tensor or ``None``).
    """
    num_tasks = len(grads)
    for i, p in enumerate(net.parameters()):
        contributions = [task_grads[i] for task_grads in grads if task_grads[i] is not None]
        if not contributions:
            p.grad = None
            continue

        # Clone the first term so the callers' gradient tensors are not modified.
        total = contributions[0].clone()
        for g in contributions[1:]:
            total += g
        p.grad = total / num_tasks

def meta(net, data, target, nat, args):
    """Run one inner "adapt, then evaluate" step on a throw-away copy of ``net``.

    The copy is first adapted with a single gradient step on two auxiliary
    tasks built from ``data`` (predict which of four 90-degree rotations
    was applied, predict whether the batch was flipped along dim 2). The
    adapted copy is then attacked with ``perturb_meta`` and its
    classification loss and parameter gradients are returned. ``net``
    itself is never modified.

    Args:
        net: model to copy; called as ``model(x, mode)`` and ``model(x)``.
        data: batch the auxiliary tasks are built from; also the starting
            point of the attack. Moved to the global ``device``.
        target: class labels for ``data``. Moved to the global ``device``.
        nat: batch passed to ``perturb_meta`` as the centre of its epsilon
            ball -- presumably the clean counterpart of ``data``; it is
            used as given (not moved to ``device`` here).
        args: namespace providing ``penalty_rotate`` and ``penalty_vflip``.

    Returns:
        (loss_train, loss_rotate, loss_vflip, grad): the three losses
        detached, and the tuple of gradients of ``loss_train`` with respect
        to the adapted copy's parameters (entries may be ``None``).
    """
    model = copy.deepcopy(net)
    data = data.to(device)
    target = target.to(device)

    # Rotation task: every sample appears in all four orientations, labelled
    # with the number of quarter turns.
    quarter_turns = range(4)
    rotated_inputs = torch.cat([torch.rot90(data, k, (2, 3)) for k in quarter_turns], dim=0)
    rotated_labels = torch.cat([torch.ones_like(target) * k for k in quarter_turns], dim=0)

    # Flip task: original batch (label 0) followed by its flipped copy (label 1).
    flipped_inputs = torch.cat([data, torch.flip(data, dims=[2])], dim=0)
    flipped_labels = torch.cat([torch.zeros_like(target), torch.ones_like(target)], dim=0)

    # Forward order (rotate, then vflip) is kept as in the original script.
    loss_rotate = F.cross_entropy(model(rotated_inputs, 'rotate'), rotated_labels)
    loss_vflip = F.cross_entropy(model(flipped_inputs, 'vflip'), flipped_labels)

    # One gradient step on the weighted auxiliary loss, using the current
    # learning rate of the global optimizer's first parameter group.
    auxiliary_loss = args.penalty_rotate*loss_rotate + args.penalty_vflip*loss_vflip
    auxiliary_grads = torch.autograd.grad(auxiliary_loss, model.parameters(), allow_unused=True)
    inner_lr = optimizer.param_groups[0]['lr']
    update_parameters(model, inner_lr, auxiliary_grads)

    # Attack the adapted copy, then measure its classification loss.
    adv_meta = perturb_meta(model, nat, target, data,  'label')
    loss_train = F.cross_entropy(model(adv_meta), target)
    grad = torch.autograd.grad(loss_train, model.parameters(), allow_unused=True)

    return loss_train.detach(), loss_rotate.detach(), loss_vflip.detach(), grad


def train(total_step, train_loader, model, optimizer, args):
    """Train ``model`` for one pass over ``train_loader``.

    For every batch:
      1. adversarial examples are crafted for the label, rotation and
         vertical-flip heads with ``perturb``;
      2. the label-adversarial batch is split into ``args.num_task`` equal
         chunks, ``meta`` is run on each, and the averaged gradients are
         written to the parameters with ``assign_grad``;
      3. the weighted adversarial rotation and flip losses are
         back-propagated on top of those gradients;
      4. the optimizer takes one step.

    Args:
        total_step: unused; kept so existing callers keep working.
        train_loader: iterable of ``(data, target)`` batches; must support
            ``len()``.
        model: network being trained.
        optimizer: optimizer stepping ``model``'s parameters.
        args: namespace providing ``num_task``, ``penalty_rotate`` and
            ``penalty_vflip``.

    Returns:
        The number of batches in ``train_loader``.
    """
    model.train()
    for idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):

        nat = data.to(device)
        target = target.to(device)
        adv = perturb(model, nat, target, 'label').detach()
        nat_rotate, target_rotate = rotate(nat)
        adv_rotate = perturb(model, nat_rotate.detach(), target_rotate, 'rotate').detach()
        nat_vflip, target_vflip = vflip(nat)
        adv_vflip = perturb(model, nat_vflip.detach(), target_vflip, 'vflip').detach()

        optimizer.zero_grad()
        # Chunk size per task; samples beyond b * num_task are not used by the
        # meta step. NOTE(review): a batch smaller than num_task gives b == 0,
        # i.e. empty chunks -- confirm the loader never yields such a batch.
        b = len(data)//args.num_task
        grads = []
        losses_train = 0.0
        losses_vflip = 0.0
        losses_rotate = 0.0
        for i in range(args.num_task):
            chunk = slice(b*i, b*(i+1))
            loss_train, loss_rotate, loss_vflip, grad = meta(model, adv[chunk], target[chunk], nat[chunk].detach(), args)
            grads.append(grad)
            losses_train += loss_train
            losses_vflip += loss_vflip
            # Was `losses_rotate + loss_rotate`, which discarded the value and
            # left the rotation term out of the logged total loss.
            losses_rotate += loss_rotate

        loss_label = losses_train / args.num_task
        meta_loss_vflip = losses_vflip / args.num_task
        meta_loss_rotate = losses_rotate / args.num_task
        # Logged only; the parameter gradients come from assign_grad/backward below.
        loss = loss_label + args.penalty_vflip * meta_loss_vflip + args.penalty_rotate*meta_loss_rotate
        assign_grad(model, grads)

        # backward() accumulates onto the gradients assigned above.
        logits_rotate = model(adv_rotate, 'rotate')
        loss_rotate = args.penalty_rotate * F.cross_entropy(logits_rotate, target_rotate)
        loss_rotate.backward()

        logits_vflip = model(adv_vflip, 'vflip')
        loss_vflip = args.penalty_vflip * F.cross_entropy(logits_vflip, target_vflip)
        loss_vflip.backward()

        optimizer.step()

        if idx % 50 == 0:
            print("batch {}, loss_label {}, loss_rotate {}, loss_vflip {}, loss {}".format(idx, loss_label.item(), loss_rotate.item(), loss_vflip.item(), loss.item()))

    return len(train_loader)


def test(net, test_loader):
    """Evaluate ``net`` on adversarial versions of the test set.

    Every batch is attacked with ``perturb_test`` and the network is scored
    on the result. Nothing is returned; the outcome is written to the
    global ``state`` dict:

      * ``state['test_loss']``: mean of the per-batch cross-entropy values;
      * ``state['test_accuracy']``: correct predictions divided by
        ``len(test_loader.dataset)``.

    Args:
        net: network to evaluate; it is switched to eval mode.
        test_loader: iterable of ``(data, target)`` batches exposing
            ``len()`` and a ``dataset`` attribute.
    """
    net.eval()
    loss_avg = 0.0
    correct = 0
    for batch_idx, (data, target) in enumerate(test_loader):
        nat = data.to(device)
        # The attack differentiates w.r.t. its input, so the clean batch must
        # track gradients before it is handed to perturb_test.
        nat.requires_grad = True
        target = target.to(device)
        adv = perturb_test(net, nat, target)

        # Scoring needs no gradients: skip building an autograd graph.
        with torch.no_grad():
            logits = net(adv)
            loss = F.cross_entropy(logits, target)

        correct += int((torch.argmax(logits, dim=1) == target).sum())

        loss_avg += loss.item()

    state['test_loss'] = loss_avg / len(test_loader)
    state['test_accuracy'] = correct / len(test_loader.dataset)


# Main loop: one training pass and one adversarial evaluation per epoch.
# A checkpoint is written after every epoch and the best accuracy seen so
# far is reported.
best_accuracy = 0.0
total_step = 1
for epoch in trange(args.epochs):
    state['epoch'] = epoch

    # train() returns the number of batches it processed.
    steps_this_epoch = train(total_step, train_loader, net, optimizer, args)
    total_step += steps_this_epoch

    # test() publishes its results through the global `state` dict.
    test(net, test_loader)
    best_accuracy = max(best_accuracy, state['test_accuracy'])

    # Save the wrapped module's weights (not the DataParallel wrapper's),
    # numbering checkpoints from 1.
    checkpoint_path = os.path.join("weight", args.save+str(epoch+1)+".pytorch")
    torch.save(net.module.state_dict(), checkpoint_path)

    print(state)
    print("Best accuracy: %f" % best_accuracy)
    scheduler.step()


