import argparse
import operator
import os
import random
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from meta_adv_optimizer import MetaOneStageOptimizer
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as np
from utils import *
from tqdm import tqdm
import time
from model_loader import load_model

# Command-line options. Every attribute read from `args` elsewhere in this
# script must be declared here (--epsilon, --use_ni and --use_momentum were
# previously read but never declared, causing AttributeError).
parser = argparse.ArgumentParser(description='Train a meta optimizer for adversarial attacks')
parser.add_argument('--batch_size', type=int, default=32, metavar='N',
                    help='training batch size (default: 32)')
parser.add_argument('--test_batch_size', type=int, default=100, metavar='N',
                    help='test batch size (default: 100)')
parser.add_argument('--optimizer_steps', type=int, default=20, metavar='N',
                    help='number of meta optimizer steps (default: 20)')
parser.add_argument('--pgd_steps', type=int, default=20, metavar='N',
                    help='number of attack steps used for evaluation, the PGD baseline '
                         'and --filter (default: 20)')
parser.add_argument('--truncated_bptt_step', type=int, default=10, metavar='N',
                    help='number of unrolled steps per truncated-BPTT chunk (default: 10)')
parser.add_argument('--updates_per_epoch', type=int, default=100, metavar='N',
                    help='updates per epoch (default: 100)')
parser.add_argument('--max_epoch', type=int, default=100, metavar='N',
                    help='number of epochs (default: 100)')
parser.add_argument('--hidden_size', type=int, default=10, metavar='N',
                    help='hidden size of the meta optimizer (default: 10)')
parser.add_argument('--num_layers', type=int, default=2, metavar='N',
                    help='number of LSTM layers (default: 2)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--lr', default=1e-3, type=float)
parser.add_argument('--output_path', type=str, default="./ckpts",
                    help='output path to save the ckpt')
parser.add_argument('--model', type=str, default="resnet18",
                    help='target model')
parser.add_argument('--metaloss', type=str, default="cw",
                    help='meta loss')
parser.add_argument('--loss', type=str, default="cw",
                    help='loss')
parser.add_argument('--inputdim', type=int, default=3, metavar='N',
                    help='meta input dimension')
parser.add_argument('--meta_name', type=str, default="rnn",
                    help='meta name')
parser.add_argument('--process_grad', default=False, action='store_true', help='')
parser.add_argument('--reg', type=float, default=0.0)
parser.add_argument('--change_point', type=int, default=1)
parser.add_argument('--loss_final', action='store_true', default=False)
parser.add_argument('--filter', action='store_true', default=False)
parser.add_argument('--filter_direction', default='largest', type=str)
parser.add_argument('--test_pgd', action='store_true', default=False)
parser.add_argument('--seed', default=20, type=int)
parser.add_argument('--resume', type=str, default="", help='resume')
parser.add_argument('--evaluate', default=False, action='store_true', help='')
parser.add_argument('--data_nums', type=int, default=10000, help='data nums')
parser.add_argument('--norm', type=str, default="Linf",
                    help='threat model norm: Linf or L2 (default: Linf)')
# Perturbation budget; `--model trades` replaces it with 0.031 after parsing.
parser.add_argument('--epsilon', type=float, default=8.0 / 255.0,
                    help='perturbation budget (default: 8/255)')
# Variants of the plain PGD baseline run by --test_pgd.
parser.add_argument('--use_ni', action='store_true', default=False,
                    help='PGD baseline: evaluate the loss at x_adv shifted by 1e-4 * previous gradient')
parser.add_argument('--use_momentum', action='store_true', default=False,
                    help='PGD baseline: accumulate normalized gradients with momentum')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed every RNG the script draws from.
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
# NOTE: with benchmark=True cuDNN may pick different algorithms between runs,
# so seeding alone does not make GPU runs bit-for-bit reproducible.
torch.backends.cudnn.benchmark = True

# Training unrolls optimizer_steps in chunks of truncated_bptt_step steps.
# Validate explicitly (an `assert` would be stripped under `python -O`).
if args.optimizer_steps % args.truncated_bptt_step != 0:
    parser.error('--optimizer_steps must be a multiple of --truncated_bptt_step')

# Perturbation budget: --epsilon when the parser declares it, 8/255 otherwise.
eps = getattr(args, 'epsilon', 8.0 / 255.0)
if args.model == 'trades':
    eps = 0.031
print('epsilon:{}'.format(eps))
# NOTE(review): hard-wired to the first GPU; other code paths also call
# .cuda() unconditionally, so --no-cuda is not a working CPU mode.
device = 'cuda:0'

output_path = args.output_path
os.makedirs(output_path, exist_ok=True)

# Data loaders. Both read the CIFAR-10 split selected by train=False and
# convert images with ToTensor only (no normalization).
# NOTE(review): the loader used for meta-training also reads train=False
# (the test split) — confirm this is intended.
_loader_kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {}


def _cifar10_test_split():
    """Build the CIFAR-10 dataset object used by both loaders (train=False)."""
    to_tensor = transforms.Compose([transforms.ToTensor()])
    return datasets.CIFAR10('../data', train=False, download=True, transform=to_tensor)


train_loader = torch.utils.data.DataLoader(
    _cifar10_test_split(),
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
    **_loader_kwargs,
)
test_loader = torch.utils.data.DataLoader(
    _cifar10_test_split(),
    batch_size=args.test_batch_size,
    shuffle=False,
    drop_last=True,
    **_loader_kwargs,
)

# Target model under attack.
model = load_model(args.model, device, args.norm)


def clip_by_tensor(t, t_min, t_max):
    """
    Clamp ``t`` element-wise into ``[t_min, t_max]``.

    All three arguments are cast to float32 and broadcast against each other.
    Values equal to a bound are kept (their gradient flows to ``t``); clipped
    positions take the bound value (gradient flows to the bound). NaNs in
    ``t`` are propagated unchanged.

    :param t: tensor to clip
    :param t_min: lower-bound tensor
    :param t_max: upper-bound tensor
    :return: clipped float32 tensor
    """
    t = t.float()
    t_min = t_min.float()
    t_max = t_max.float()

    # Select with torch.where instead of multiplying by 0/1 masks: the mask
    # form computes 0 * inf = NaN whenever a value or a bound is infinite.
    result = torch.where(t < t_min, t_min, t)
    result = torch.where(result > t_max, t_max, result)
    return result

def dlr_loss(x, y):
    """
    Summed Difference-of-Logits-Ratio (DLR) loss over a batch.

    Per row: (true-class logit - largest other logit) divided by the gap
    between the largest and third-largest logits; the negated ratios are
    summed over the batch.

    :param x: logits indexed as ``x[row, class]``; needs at least 3 classes
    :param y: integer class label per row
    :return: scalar tensor
    """
    ranked, order = x.sort(dim=1)
    top1 = ranked[:, -1]
    top2 = ranked[:, -2]
    top3 = ranked[:, -3]
    # 1.0 where the arg-max is the true class, else 0.0.
    true_is_top = (order[:, -1] == y).float()

    rows = torch.arange(x.shape[0], device=x.device)
    true_logit = x[rows, y]
    margin = true_logit - top2 * true_is_top - top1 * (1. - true_is_top)
    return (-margin / (top1 - top3 + 1e-12)).sum()

def cw_loss(x, y, c=99999.0, class_num=None, indiv=False):
    """
    Carlini-Wagner style margin loss: best wrong-class logit minus true logit.

    :param x: logits of shape ``(batch, classes)``
    :param y: integer true labels, one per row, on the same device as ``x``
    :param c: large constant subtracted from the true logit so that the max
        over ``x - c * mask`` picks the best *other* class
    :param class_num: number of classes; inferred from ``x.size(1)`` when None
    :param indiv: if True return the per-example margins, else their mean
    :return: tensor of shape ``(batch,)`` if ``indiv`` else a scalar
    """
    if class_num is None:
        class_num = x.size(1)
    # One-hot mask of the true class, created on x's device rather than the
    # hard-wired current CUDA device, so CPU / other-GPU inputs work too.
    logits_mask = torch.zeros(x.size(0), class_num, device=x.device).scatter_(1, y.unsqueeze(-1), 1)
    logit_this = torch.sum(logits_mask * x, dim=-1)
    logit_that = torch.max(x - c * logits_mask, dim=-1)[0]
    margin = logit_that - logit_this
    if indiv:
        return margin
    return margin.mean()

def ce_loss(x, y):
    """
    Mean cross-entropy between logits ``x`` and integer labels ``y``.

    :param x: logits
    :param y: integer class labels
    :return: scalar tensor
    """
    return F.cross_entropy(x, y)

def ods_loss(x, y):
    """
    Random linear projection of the logits: ``sum(x * w)`` with ``w`` drawn
    uniformly from [-1, 1] with the same shape as ``x``.

    Presumably the Output Diversified Sampling objective — confirm.

    :param x: logits (any shape)
    :param y: unused; kept for the common ``loss(x, y)`` signature
    :return: scalar tensor
    """
    # Sample on x's device (float32, like the original FloatTensor) instead of
    # building on CPU and moving to the hard-wired current CUDA device.
    rand_vector = torch.empty(x.shape, device=x.device).uniform_(-1., 1.)
    return (x * rand_vector).sum()


def _select_loss(name):
    """
    Return the loss function registered under ``name``.

    :param name: one of 'cw', 'ce', 'md', 'dlr'
    :raises ValueError: for any other name
    """
    if name == 'cw':
        return cw_loss
    if name == 'ce':
        return ce_loss
    if name == 'md':
        # NOTE(review): md_loss is not defined in this file; presumably it is
        # provided by `from utils import *` — confirm.
        return md_loss
    if name == 'dlr':
        return dlr_loss
    raise ValueError("unknown loss {!r}; expected one of 'cw', 'ce', 'md', 'dlr'".format(name))


# Loss whose gradient drives each attack / meta-update step.
criterion = _select_loss(args.loss)
# Loss used as the objective for training the meta optimizer.
criterion_meta = _select_loss(args.metaloss)


def test_pgd():
    """
    Evaluate the target model against a plain PGD baseline on ``test_loader``.

    Uses a random start inside the eps-ball, ``args.pgd_steps`` steps with
    sizes from ``adjust_step_size``, and Linf or L2 projection per
    ``args.norm``. Prints the number of clean-correct test examples and the
    number still classified correctly after the final step.
    """
    # Optional baseline variants; default off when the parser lacks the flags.
    use_ni = getattr(args, 'use_ni', False)
    use_momentum = getattr(args, 'use_momentum', False)
    momentum = 1

    correct = 0
    correct_adv = 0
    for x, y in tqdm(test_loader):
        x, y = x.cuda(), y.cuda()
        with torch.no_grad():
            pred = model(x).max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()

        # Random start inside the eps-ball, then kept in the valid image range.
        if args.norm == 'Linf':
            x_adv = x.detach() + torch.FloatTensor(*x.shape).uniform_(-eps, eps).cuda()
        else:  # 'L2'
            t = torch.randn(x.shape).to(device).detach()
            x_adv = x + eps * torch.ones_like(x).detach() * normalize(t)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
        grad_m = torch.zeros_like(x)

        for i in range(args.pgd_steps):
            step_size = adjust_step_size(i)
            x_adv.requires_grad_()
            if use_ni:
                # Evaluate the loss at a point nudged along the previous gradient.
                x_ni = x_adv + momentum * grad_m * 1e-4
                loss = criterion(model(x_ni), y)
            else:
                loss = criterion(model(x_adv), y)
            grad = torch.autograd.grad(loss, [x_adv])[0]
            if use_momentum:
                # Scale each example's gradient by its mean absolute value, then
                # add it to the running gradient.
                grad = grad / torch.mean(torch.abs(grad), dim=[1, 2, 3], keepdim=True)
                grad = momentum * grad_m + grad
            if args.norm == 'Linf':
                x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
                x_adv = torch.min(torch.max(x_adv, x - eps), x + eps)
            else:
                x_adv = x_adv.detach() + step_size * normalize(grad)
                x_adv = x + normalize(x_adv - x) * torch.min(eps * torch.ones_like(x).detach(), l2_norm(x_adv - x))
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
            grad_m = grad

        # Robustness is measured once, after the final step.
        with torch.no_grad():
            pred = model(x_adv).max(1, keepdim=True)[1]
            correct_adv += pred.eq(y.view_as(pred)).sum().item()

    print('{} clean:{}\t robust:{}'.format(args.model, correct, correct_adv))


def attack_filter(x, y):
    """
    Select a subset of a batch by how much a short PGD attack changes its CW loss.

    Runs ``args.pgd_steps`` Linf PGD steps on the CW loss from a random start,
    computes each example's loss change relative to that start, and keeps the
    examples with the largest change (``args.filter_direction == 'largest'``)
    or otherwise the smallest.

    :param x: clean input batch
    :param y: labels for ``x``
    :return: ``(x_subset, y_subset)`` — the selected *clean* inputs and labels
    """
    # At most 32 examples; clamp so topk does not fail on smaller batches.
    filter_length = min(32, x.size(0))

    x_adv = x.detach() + torch.FloatTensor(*x.shape).uniform_(-eps, eps).cuda()

    with torch.no_grad():
        x_init = x_adv.clone()
        loss_init = cw_loss(model(x_init), y, indiv=True)

    for idx in range(args.pgd_steps):
        x_adv.requires_grad_()
        loss = cw_loss(model(x_adv), y)
        grad = torch.autograd.grad(loss, [x_adv])[0]
        step_size = adjust_step_size(idx)
        x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
        x_adv = torch.min(torch.max(x_adv, x - eps), x + eps)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

    with torch.no_grad():
        l = cw_loss(model(x_adv), y, indiv=True)
        loss_diff = (l - loss_init).detach().cpu()
    # topk returns (values, indices); only the indices select examples.
    values, choice_index = loss_diff.topk(
        k=filter_length, largest=(args.filter_direction == 'largest'))
    print(values, choice_index)
    choice_index = choice_index.to(x.device)
    return x[choice_index], y[choice_index]
    

def adjust_step_size(idx):
    """
    Return the attack step size for iteration ``idx``.

    Iterations before ``args.change_point`` use the full budget ``eps``;
    later iterations use a quarter of it.
    """
    return eps if idx < args.change_point else eps / 4.

def eval_test(meta_optimizer):
    """
    Measure the target model's robust accuracy under the learned attack.

    For each test batch, the meta optimizer runs ``args.pgd_steps`` update
    steps from a random start inside the eps-ball (clipped to [0, 1]). After
    every step, the number of examples the model still classifies correctly
    is accumulated.

    :param meta_optimizer: DataParallel-wrapped meta optimizer (``.module`` is used)
    :return: ``(acc, acc_steps)`` — accuracy after step
        ``min(args.optimizer_steps, args.pgd_steps)`` and the per-step
        accuracy array of length ``args.pgd_steps``
    """
    meta_optimizer.eval()
    acc_steps = np.zeros([args.pgd_steps])
    num_examples = 0
    for x, y in test_loader:
        if args.cuda:
            x, y = x.cuda(), y.cuda()
        num_examples += x.size(0)

        # Per-pixel bounds: the eps-ball around x intersected with [0, 1].
        clip_tensor_one = torch.ones_like(x)
        clip_tensor_zero = torch.zeros_like(x)
        x_max = clip_by_tensor(x.clone() + eps, clip_tensor_zero, clip_tensor_one)
        x_min = clip_by_tensor(x.clone() - eps, clip_tensor_zero, clip_tensor_one)

        x_adv = x.detach() + torch.FloatTensor(*x.shape).uniform_(-eps, eps).cuda()
        x_adv = clip_by_tensor(x_adv, x_min, x_max)
        x_adv = x_adv.detach().requires_grad_(True)

        if args.meta_name != 'gen':
            meta_optimizer.module.reset_lstm(
                keep_states=False, xadv=x_adv, use_cuda=args.cuda)

        for idx in range(args.pgd_steps):
            x_adv = x_adv.clone().detach().requires_grad_(True)

            # Populate x_adv.grad with the gradient of the negated attack loss;
            # meta_update presumably reads it — confirm in meta_adv_optimizer.
            loss = -criterion(model(x_adv), y)
            model.zero_grad()
            loss.backward()
            meta_step_size = adjust_step_size(idx)
            x_adv = meta_optimizer.module.meta_update(x_adv, "eval", idx, meta_step_size, x_min, x_max, clip_by_tensor)

            # Count examples still classified correctly after this step.
            with torch.no_grad():
                pred = model(x_adv).max(1)[1]
                acc_steps[idx] += pred.eq(y).sum().item()

    meta_optimizer.train()
    acc_steps = acc_steps / float(max(num_examples, 1))
    print('meta robust:', acc_steps)
    # Report at the trained horizon, clamped so a shorter evaluation cannot overrun.
    report_step = min(args.optimizer_steps, args.pgd_steps) - 1
    return acc_steps[report_step], acc_steps
    

def _build_meta_optimizer():
    """Create the --meta_name meta optimizer, optionally resume it, and wrap it in DataParallel."""
    if args.meta_name == 'onernn':
        meta_optimizer = MetaOneStageOptimizer(args.num_layers, args.hidden_size, args.batch_size, args.inputdim)
    else:
        raise ValueError("unsupported --meta_name {!r}; only 'onernn' is built here".format(args.meta_name))
    print(args)

    if args.resume:
        # NOTE(review): torch.load unpickles a whole (DataParallel) module saved
        # by main(); only resume from trusted files. Newer PyTorch defaults to
        # weights_only=True, which cannot load such objects — confirm version.
        ckpt = torch.load(args.resume)
        meta_optimizer.load_state_dict(ckpt.module.state_dict())
        print('Resuming from {}' .format(args.resume))

    meta_optimizer = torch.nn.DataParallel(meta_optimizer)
    if args.cuda:
        meta_optimizer.module.cuda()
    return meta_optimizer


def _next_batch(train_iter):
    """Return ``((x, y), train_iter)``, restarting ``train_loader`` when exhausted."""
    try:
        return next(train_iter), train_iter
    except StopIteration:
        # Only exhaustion restarts the loader; real loader errors propagate.
        train_iter = iter(train_loader)
        return next(train_iter), train_iter


def main():
    """
    Entry point: optional PGD baseline, then evaluate or train the meta optimizer.

    Training unrolls ``args.optimizer_steps`` meta-updates per batch in chunks
    of ``args.truncated_bptt_step``. It takes an Adam step on the meta
    optimizer after each chunk, evaluates every epoch, and saves the
    checkpoint with the lowest robust accuracy.
    """
    if args.test_pgd:
        test_pgd()

    meta_optimizer = _build_meta_optimizer()

    if args.evaluate:
        adv_acc, _ = eval_test(meta_optimizer)
        print("evaluate adv_acc", adv_acc)
        return

    optimizer = optim.Adam(meta_optimizer.parameters(), lr=args.lr)
    best_adv_acc = 1.0
    acc_epochs = np.zeros([args.max_epoch])
    num_chunks = args.optimizer_steps // args.truncated_bptt_step

    for epoch in range(args.max_epoch):
        decrease_in_loss = 0.0
        final_loss = 0.0
        train_iter = iter(train_loader)
        for _ in tqdm(range(args.updates_per_epoch)):
            (x, y), train_iter = _next_batch(train_iter)
            if args.cuda:
                x, y = x.cuda(), y.cuda()
            # Optionally keep only the examples selected by a PGD pre-attack.
            if args.filter:
                x, y = attack_filter(x, y)

            # Per-pixel bounds: the eps-ball around x intersected with [0, 1].
            clip_tensor_one = torch.ones_like(x)
            clip_tensor_zero = torch.zeros_like(x)
            x_max = clip_by_tensor(x + eps, clip_tensor_zero, clip_tensor_one)
            x_min = clip_by_tensor(x - eps, clip_tensor_zero, clip_tensor_one)
            x_adv = x.detach() + torch.FloatTensor(*x.shape).uniform_(-eps, eps).cuda()
            x_adv = clip_by_tensor(x_adv, x_min, x_max)
            x_adv = x_adv.detach().requires_grad_(True)

            # Loss at the random start, for the final/initial loss ratio.
            with torch.no_grad():
                initial_loss = -criterion(model(x_adv), y)

            for k in range(num_chunks):
                # Keep LSTM states across truncated-BPTT chunks of one batch.
                if args.meta_name != 'gen':
                    meta_optimizer.module.reset_lstm(
                        keep_states=k > 0, xadv=x_adv, use_cuda=args.cuda)
                else:
                    assert args.truncated_bptt_step == 1
                loss_sum = 0
                reg_total = 0
                prev_loss = torch.zeros(1)
                if args.cuda:
                    prev_loss = prev_loss.cuda()
                for j in range(args.truncated_bptt_step):
                    step = j + k * args.truncated_bptt_step

                    # Fresh gradient of the attack loss w.r.t. x_adv for this
                    # step, matching eval_test (which rebuilds x_adv each step).
                    # Without the reset, x_adv.grad keeps accumulating across
                    # steps and across the meta-loss backward below.
                    x_adv.grad = None
                    loss = -criterion(model(x_adv), y)
                    model.zero_grad()
                    loss.backward()

                    # The meta optimizer proposes the next iterate from that gradient.
                    meta_step_size = adjust_step_size(step)
                    meta_xadv, reg = meta_optimizer.module.meta_update(x_adv, "train", step, meta_step_size, x_min, x_max, clip_by_tensor)
                    x_adv.data.copy_(meta_xadv.data)

                    if not args.loss_final:
                        # Per-step meta loss, summed as the change from the previous step.
                        loss = -criterion_meta(model(meta_xadv), y)
                        if args.reg > 0:
                            loss += args.reg * reg
                        loss_sum += loss - prev_loss
                        prev_loss = loss.data
                    else:
                        reg_total += reg

                if args.loss_final:
                    # Meta loss only on the last iterate of this chunk.
                    loss_sum = -criterion_meta(model(meta_xadv), y)
                    if args.reg > 0:
                        loss_sum += args.reg * reg_total

                print("Epoch: {}, top:{}/{} loss: {:.4f}".format(epoch, k, num_chunks, loss_sum.item()))
                meta_optimizer.zero_grad()
                loss_sum.backward()
                optimizer.step()

            decrease_in_loss += loss.item() / initial_loss.item()
            final_loss += loss.item()

        print("Epoch: {}, final loss {}, average final/initial loss ratio: {}".format(
            epoch, final_loss / args.updates_per_epoch, decrease_in_loss / args.updates_per_epoch))

        # Evaluate every epoch; keep the checkpoint with the lowest robust accuracy.
        adv_acc, _ = eval_test(meta_optimizer)
        acc_epochs[epoch] = adv_acc
        if adv_acc < best_adv_acc:
            best_adv_acc = adv_acc
            file_name = "{}_{:.4f}.pth".format(epoch, adv_acc)
            print("save ckpt:", file_name)
            torch.save(meta_optimizer, os.path.join(args.output_path, file_name))


if __name__ == "__main__":
    main()
