import argparse
import logging
import sys
import time
import math

import numpy as np
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.autograd import grad as autograd
import torchvision
from collections import OrderedDict

import os

from wideresnet import WideResNet
from preactresnet import PreActResNet18

from utils import *
# NOTE(review): anomaly detection adds extra checks to every backward pass and
# looks like a leftover debugging aid — consider turning it off for real runs.
torch.autograd.set_detect_anomaly(True)

# Per-channel CIFAR-10 statistics (cifar10_mean / cifar10_std are provided by
# the `utils` star import), shaped (3, 1, 1) so they broadcast over the
# channel axis of image tensors, and moved to the GPU once at import time.
mu = torch.tensor(cifar10_mean).reshape(3, 1, 1).cuda()
std = torch.tensor(cifar10_std).reshape(3, 1, 1).cuda()

def normalize(X):
    """Standardize ``X`` channel-wise with the module-level ``mu`` and ``std``.

    ``X`` must broadcast against the (3, 1, 1) statistics tensors; callers
    pass images in the [lower_limit, upper_limit] pixel range.
    """
    centered = X - mu
    return centered / std

# Valid pixel range of un-normalized images; perturbed inputs are clamped into it.
lower_limit, upper_limit = 0, 1


def clamp(X, lower_limit, upper_limit):
    """Clip ``X`` element-wise into ``[lower_limit, upper_limit]``.

    The bounds are tensors broadcastable against ``X``; callers pass
    per-element bounds such as ``lower_limit - X``. The upper bound is applied
    first, so where the bounds cross the lower bound wins.
    """
    capped_above = torch.min(X, upper_limit)
    return torch.max(capped_above, lower_limit)


def attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts,
               norm, early_stop=False,
               mixup=False, y_a=None, y_b=None, lam=None):
    """Find an adversarial perturbation of ``X`` by projected gradient ascent.

    Runs ``restarts`` independent attacks, each from a random start inside the
    ``norm`` ball of radius ``epsilon``, and keeps for every sample the
    perturbation whose final loss is highest.

    Args:
        model: classifier applied to ``normalize(X + delta)``.
        X: clean input batch (expected on the GPU, in [lower_limit, upper_limit]).
        y: labels for ``X``.
        epsilon: radius of the perturbation ball.
        alpha: step size of each ascent step.
        attack_iters: number of ascent steps per restart.
        restarts: number of random restarts.
        norm: ``"l_inf"`` or ``"l_2"``.
        early_stop: if True, only samples the model still classifies correctly
            keep being updated, and a restart stops once none remain.
        mixup, y_a, y_b, lam: when ``mixup`` is True the attack loss is
            ``mixup_criterion(criterion, output, y_a, y_b, lam)`` instead of
            cross-entropy against ``y``.

    Returns:
        A perturbation tensor shaped like ``X`` that does not require grad
        (all zeros when ``epsilon == 0``).

    Raises:
        ValueError: if ``norm`` is neither ``"l_inf"`` nor ``"l_2"``.
    """
    max_loss = torch.zeros(y.shape[0]).cuda()
    max_delta = torch.zeros_like(X).cuda()

    if epsilon == 0:
        return max_delta

    for _ in range(restarts):
        delta = torch.zeros_like(X).cuda()
        if norm == "l_inf":
            delta.uniform_(-epsilon, epsilon)
        elif norm == "l_2":
            # Random direction, rescaled to a radius drawn uniformly from [0, epsilon].
            delta.normal_()
            d_flat = delta.view(delta.size(0), -1)
            n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
            r = torch.zeros_like(n).uniform_(0, 1)
            delta *= r / n * epsilon
        else:
            raise ValueError
        # Keep X + delta inside the valid pixel range.
        delta = clamp(delta, lower_limit - X, upper_limit - X)
        delta.requires_grad = True
        for _ in range(attack_iters):
            output = model(normalize(X + delta))
            if early_stop:
                # Only keep attacking samples that are still classified correctly.
                index = torch.where(output.max(1)[1] == y)[0]
            else:
                index = slice(None, None, None)
            if not isinstance(index, slice) and len(index) == 0:
                break
            if mixup:
                criterion = nn.CrossEntropyLoss()
                # Reuse the forward pass above instead of running the model again
                # on the identical input.
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = F.cross_entropy(output, y)
            loss.backward()
            grad = delta.grad.detach()
            d = delta[index, :, :, :]
            g = grad[index, :, :, :]
            x = X[index, :, :, :]
            if norm == "l_inf":
                # Signed-gradient step, projected back onto the L-inf ball.
                d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon)
            elif norm == "l_2":
                # Normalized-gradient step; renorm projects each sample onto the L2 ball.
                g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1)
                scaled_g = g / (g_norm + 1e-10)
                d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d)
            d = clamp(d, lower_limit - x, upper_limit - x)
            delta.data[index, :, :, :] = d
            delta.grad.zero_()

        # Only per-sample loss values are needed to pick the best restart, so
        # no autograd graph has to be built for this pass.
        with torch.no_grad():
            if mixup:
                criterion = nn.CrossEntropyLoss(reduction='none')
                all_loss = mixup_criterion(criterion, model(normalize(X + delta)), y_a, y_b, lam)
            else:
                all_loss = F.cross_entropy(model(normalize(X + delta)), y, reduction='none')
        improved = all_loss >= max_loss
        max_delta[improved] = delta.detach()[improved]
        max_loss = torch.max(max_loss, all_loss)
    return max_delta


def get_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding every option below; dashes in option names
        become underscores (``--l2-clean`` -> ``args.l2_clean``).
    """
    parser = argparse.ArgumentParser()
    option = parser.add_argument

    # Model architecture and L1/L2 regularisation (adversarial and clean phases).
    option('--model', default='PreActResNet18')
    option('--l2', default=0, type=float)
    option('--l2-clean', default=0.001, type=float)
    option('--l1', default=0, type=float)
    option('--l1-clean', default=0.00001, type=float)

    # Data location, training length and learning rates.
    option('--batch-size', default=128, type=int)
    option('--data-dir', default='./tmp', type=str)
    option('--adv-epochs', default=200, type=int)
    option('--clean-epochs', default=200, type=int)
    option('--adv-lr', default=0.1, type=float)
    option('--clean-lr', default=0.1, type=float)

    # Attack configuration (epsilon and pgd-alpha are given in units of 1/255).
    option('--attack', default='pgd', type=str, choices=['pgd', 'fgsm', 'free', 'none'])
    option('--epsilon', default=8, type=int)
    option('--attack-iters', default=10, type=int)
    option('--restarts', default=1, type=int)
    option('--pgd-alpha', default=2, type=float)
    option('--fgsm-alpha', default=1.25, type=float)
    option('--norm', default='l_inf', type=str, choices=['l_inf', 'l_2'])
    option('--fgsm-init', default='random', choices=['zero', 'random', 'previous'])

    # Output location, reproducibility and model variants.
    option('--fname', default='tmp/cifar_model/gowal/wide/', type=str)
    option('--seed', default=0, type=int)
    option('--half', action='store_true')
    option('--width-factor', default=10, type=int)
    option('--resume', default=0, type=int)

    # Augmentation.
    option('--cutout', action='store_true')
    option('--cutout-len', type=int)
    option('--mixup', action='store_true')
    option('--mixup-alpha', type=float)

    # Run mode and checkpoint frequency.
    option('--eval', action='store_true')
    option('--val', action='store_true')
    option('--chkpt-iters', default=10, type=int)

    # Clean model, sample ordering and extra-data settings.
    option('--clean', default='', help='The clean model')
    option('--order', default='', help='Order of the samples in the dataset')
    option('--order-other', default='', help='Order of the samples in the dataset')
    option('--weight-true', default=1, type=float)
    option('--no-adjust', default=False, action='store_true')
    option('--data-other', default='', help='')

    return parser.parse_args()

# Options are parsed once at import time; the functions below read them through
# this module-level ``args`` global instead of receiving them as parameters.
args = get_args()

def clean_training(epochs, lr, train_batches, test_batches, logger):
    """Train (or load) a standard, non-adversarial classifier.

    Args:
        epochs: number of training epochs.
        lr: peak learning rate of the piecewise-constant schedule.
        train_batches: iterable of ``(X, y)`` training batches (must support ``len``).
        test_batches: iterable of ``(X, y)`` test batches, evaluated after every epoch.
        logger: logger receiving one summary line per epoch.

    Returns:
        The ``nn.DataParallel``-wrapped model on the GPU. If ``args.clean`` names
        a checkpoint, its weights are loaded and returned (in train mode)
        without any training.

    Raises:
        ValueError: if ``args.model`` is not a supported architecture.
    """
    if args.model == 'PreActResNet18':
        model = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34, 10, widen_factor=args.width_factor, dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = nn.DataParallel(model).cuda()
    model.train()

    if args.clean != '':
        model.load_state_dict(torch.load(args.clean))
        return model

    # Exclude BatchNorm and bias parameters from weight decay. This is gated on
    # the clean-phase coefficient (previously ``args.l2``, the adversarial-phase
    # flag, so ``--l2-clean`` silently depended on ``--l2`` being set).
    if args.l2_clean:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{'params': decay, 'weight_decay': args.l2_clean},
                  {'params': no_decay, 'weight_decay': 0}]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=5e-4)

    criterion = nn.CrossEntropyLoss()

    def lr_schedule(t, epochs, lr):
        # Piecewise constant: lr for the first half, lr/10 until 75%, then lr/100.
        if t / epochs < 0.5:
            return lr
        elif t / epochs < 0.75:
            return lr / 10.
        else:
            return lr / 100.

    start_epoch = 0
    lr_new = lr

    logger.info('Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Test Loss \t Test Acc')
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_n = 0
        num_batches = len(train_batches)

        for i, (X, y) in enumerate(train_batches, start=1):
            X, y = X.float().cuda(), y.cuda()
            # ``i`` is 1-based, so this is the training progress once this step is done.
            lr_new = lr_schedule(epoch + i / num_batches, epochs, lr)
            # Every parameter group must follow the schedule; updating only
            # group 0 left the no-decay group at the initial lr forever.
            for group in opt.param_groups:
                group['lr'] = lr_new

            output = model(normalize(torch.clamp(X, min=lower_limit, max=upper_limit)))

            loss = criterion(output, y)

            # L1 penalty on non-BatchNorm, non-bias weights, using the clean-phase
            # coefficient (previously multiplied by ``args.l1`` by mistake).
            if args.l1_clean:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        loss = loss + args.l1_clean * param.abs().sum()

            opt.zero_grad()
            loss.backward()
            opt.step()

            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_n = 0
        # Evaluation needs no gradients.
        with torch.no_grad():
            for X, y in test_batches:
                X, y = X.float().cuda(), y.cuda()
                output = model(normalize(X))
                loss = criterion(output, y)

                test_loss += loss.item() * y.size(0)
                test_acc += (output.max(1)[1] == y).sum().item()
                test_n += y.size(0)

        test_time = time.time()

        # Log the learning rate actually in effect at the end of the epoch.
        logger.info('%d \t %.1f \t %.1f \t %.4f \t %.4f \t %.4f \t %.4f \t %.4f ',
            epoch, train_time - start_time, test_time - train_time, lr_new,
            train_loss/train_n, train_acc/train_n, test_loss/test_n, test_acc/test_n)

        # save checkpoint
        if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(), os.path.join(args.fname, f'clean_model_{epoch}.pth'))

    return model



def adv_training(epochs, lr, train_set_old, train_set, test_batches, logger, transforms):
    """Adversarially train a fresh model on the original data plus extra pseudo-labelled data.

    Every step concatenates ``batch_size`` samples from the original training
    set with ``batch_size_extra`` samples from the extra set. The split follows
    from the weight ``w``. During the first fifth of training, ``w`` is
    re-estimated before every pass from running statistics of the two parts'
    per-sample robust losses (``bias``, ``sd1``, ``sd2``). ``--no-adjust``
    pins ``w`` to ``--weight-true`` instead.

    Args:
        epochs: length the learning-rate schedule is scaled to; the total
            number of steps is derived from ``args.adv_epochs``.
        lr: peak learning rate.
        train_set_old: dict with 'data' and 'labels' for the original training set.
        train_set: dict with 'data' and 'labels' for the extra data; it is
            shuffled and split into 20 consecutive parts of 50000 samples.
        test_batches: iterable of ``(X, y)`` batches for the per-pass evaluation.
        logger: logger receiving one summary line per pass.
        transforms: augmentations passed to ``Transform_weight``.

    Side effects:
        Saves model, optimizer and counter checkpoints to ``args.fname`` after
        every pass (the files ``--resume`` reloads). Returns None.
    """
    import random  # shuffles the extra data below; not among the file's explicit imports

    def adv_test(test_batches, logger, model):
        # Evaluate clean and adversarial test loss/accuracy, then log them next
        # to the training statistics of the current pass, which are read from
        # the enclosing scope (epoch, start_time, train_time, train_loss, ...).
        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for (X, y) in test_batches:

            X, y = X.float().cuda(), y.cuda()

            # Random initialization
            if args.attack == 'none':
                delta = torch.zeros_like(X)
            else:
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm, early_stop=args.eval)
            delta = delta.detach()

            # The attack above needs gradients; the evaluation passes do not.
            with torch.no_grad():
                robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
                robust_loss = criterion(robust_output, y).mean()

                output = model(normalize(X))
                loss = criterion(output, y).mean()

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()

        # Log the learning rate currently in effect rather than the base lr.
        logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
            epoch, train_time - start_time, test_time - train_time, opt.param_groups[0]['lr'],
            train_loss/train_n, train_acc/train_n, train_robust_loss/train_n, train_robust_acc/train_n,
            test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n)

    # --epsilon and --pgd-alpha are given in units of 1/255.
    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.) * args.epsilon/8

    if args.model == 'PreActResNet18':
        model = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34, 10, widen_factor=args.width_factor, dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = nn.DataParallel(model).cuda()
    model.train()

    # Exclude BatchNorm and bias parameters from weight decay.
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{'params': decay, 'weight_decay': args.l2},
                  {'params': no_decay, 'weight_decay': 0}]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=5e-4)

    # Per-sample losses are needed to compare the original and extra parts of a batch.
    criterion = nn.CrossEntropyLoss(reduction='none')

    if args.attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.attack == 'fgsm' and args.fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    def lr_schedule(t, epochs, lr):
        # Piecewise constant: lr for the first half, lr/10 until 75%, then lr/100.
        if t / epochs < 0.5:
            return lr
        elif t / epochs < 0.75:
            return lr / 10.
        else:
            return lr / 100.

    cnt_iter = 0
    w = int(1000000/50000)
    bias = 0.
    sd1 = 0.
    sd2 = 0.

    epoch = 0

    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        m = torch.load(os.path.join(args.fname, f'cnt_{start_epoch-1}.pth'))
        cnt_iter = m['cnt']
        w = m['w']
        bias = m['bias']
        sd1 = m['sd1']
        sd2 = m['sd2']
        epoch = m['epoch']
        logger.info(f'Resuming at epoch {start_epoch}')
    else:
        start_epoch = 0

    total_n = 50000
    total_iter = args.adv_epochs * (total_n/args.batch_size)

    # Step by which w is moved per adjustment.
    w0 = max(int(int(1000000/50000)/4), 1)
    # Steps per pass over the original data; must use the actual batch size so
    # passes line up with total_iter (previously hardcoded to 128).
    num_per_epoch = int(total_n/args.batch_size)

    # Sizes plugged into the w objective below.
    num_other = 1000000
    num = 50000

    def bias_term(v):
        # Objective term driven by the running loss gap ``bias`` for weight v.
        return 1/args.weight_true*abs(bias**2/10)*(num_other**2)/(num_other+v*num)**2

    def var_term(v):
        # Objective term driven by the running loss variances ``sd1``/``sd2`` for weight v.
        return (sd1*num*(v**2)+sd2*num_other)/(v*num+num_other)**2

    logger.info('Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc')

    # The extra data is shuffled and split into 20 parts of 50000 samples; the
    # number of extra samples consumed so far (``prev``) decides which part
    # serves the next batch.
    part_size = 50000
    num_parts = 20

    train_set = list(zip(transpose(pad(train_set['data'], 4)/255.), train_set['labels'], train_set['labels']))

    random.shuffle(train_set)

    train_sets = []
    for part in range(num_parts):
        start = part*part_size
        end = (part+1)*part_size
        tmp = train_set[start:end]
        train_sets.append(Batches(Transform_weight(tmp, transforms)))

    def next_batch(prev, num):
        # prev: number of extra samples used so far; num: size of this batch.
        # Returns the next batch and the updated number of used samples.
        i = int(prev / part_size)

        # If the current part cannot supply ``num`` more samples, move to the
        # start of the next part. (Previously ``prev`` was set to the start of
        # the part after that, so every other part served a single batch.)
        if (i+1)*part_size < prev + num:
            i += 1
            prev = i*part_size

        # Wrap around after the last part.
        if i == num_parts:
            i = 0
            prev = 0

        return train_sets[i].get_batch(num), prev + num

    train_set_old = list(zip(transpose(pad(train_set_old['data'], 4)/255.), train_set_old['labels'], train_set_old['labels']))
    train_set_old = Transform_weight(train_set_old, transforms)

    train = Batches(train_set_old)

    prev = 0

    while True:

        if cnt_iter >= total_iter:
            break

        print(epoch, sd1, sd2, (sd1*num*((w+1)**2)+sd2*num_other),( (w+1)*num+num_other ),  (sd1*num*((w+1)**2)+sd2*num_other)/( (w+1)*num+num_other )**2)

        # Objective at w + w0 (a), w (b) and the lower neighbour (c).
        a = bias_term(w + w0) + var_term(w + w0)
        b = bias_term(w) + var_term(w)
        c = bias_term(max(w - w0, 1)) + var_term(max(w - w0, 1))

        # Move w towards a lower objective, only during the first fifth of training.
        if cnt_iter <= total_iter/5:
            if a < b:
                w += w0
            if c < b:
                w -= w0

        w = w if w > 0.5 else 1

        print(a, b, c, w, bias_term(w), var_term(w))

        if args.no_adjust:
            w = args.weight_true

        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        i = 0
        batch_size_extra = int(args.batch_size*(1000000)/(1000000+50000*w))
        batch_size = args.batch_size - batch_size_extra
        print(epoch, batch_size, batch_size_extra)

        while True:

            cnt_iter += 1
            if cnt_iter >= total_iter:
                break

            # End of a pass: evaluate, then leave so w is re-estimated.
            if cnt_iter % num_per_epoch == 0:
                train_time = time.time()
                epoch += 1
                adv_test(test_batches, logger, model)
                break

            X, y, _ = train.get_batch(batch_size)
            (X1, y1, _), prev = next_batch(prev, batch_size_extra)

            # The first ``size`` rows of the combined batch come from the original set.
            size = len(y)
            X, y = X.float().cuda(), y.long().cuda()

            X1, y1 = X1.float().cuda(), y1.long().cuda()

            X = torch.cat((X, X1), 0)
            y = torch.cat((y, y1), 0)

            i += 1

            lr_new = lr_schedule((cnt_iter/total_iter)*epochs, epochs, lr)

            # Every parameter group must follow the schedule; updating only
            # group 0 left the no-decay group at the initial lr forever.
            for group in opt.param_groups:
                group['lr'] = lr_new

            if args.attack == 'pgd':
                # Random initialization
                if args.mixup:
                    # NOTE(review): y_a, y_b and lam are never defined in this
                    # function, so --mixup fails here with a NameError.
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm, mixup=True, y_a=y_a, y_b=y_b, lam=lam)
                else:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm)
                delta = delta.detach()
            elif args.attack == 'fgsm':
                delta = attack_pgd(model, X, y, epsilon, args.fgsm_alpha*epsilon, 1, 1, args.norm)
            # Standard training
            elif args.attack == 'none':
                delta = torch.zeros_like(X)

            robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            loss_old = robust_loss[:size]
            loss_extra = robust_loss[size:]
            # Discounted running statistics of the two parts. Skip the update
            # when a part has fewer than two samples: mean/var would be NaN and
            # would poison w for the rest of training.
            if len(loss_old) > 1 and len(loss_extra) > 1:
                bias = bias*0.9 + loss_extra.mean().item() - loss_old.mean().item()
                sd1 = sd1*0.9 + loss_old.var().item()
                sd2 = sd2*0.9 + loss_extra.var().item()
            if i == 1:
                print(bias, sd1, sd2, loss_old.mean().item(), loss_extra.mean().item())
            robust_loss = robust_loss.mean()

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1*param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            # Clean forward pass for statistics only.
            with torch.no_grad():
                output = model(normalize(X))
                loss = criterion(output, y)
                loss = loss.mean()

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        torch.save(model.state_dict(), os.path.join(args.fname, f'model_{epoch-1}.pth'))
        torch.save(opt.state_dict(), os.path.join(args.fname, f'opt_{epoch-1}.pth'))
        m = {'cnt': cnt_iter, 'w': w, 'bias': bias, 'sd1': sd1, 'sd2': sd2, 'epoch': epoch}
        torch.save(m, os.path.join(args.fname, f'cnt_{epoch-1}.pth'))

            

def add_label(model, dataset, transforms):
    """Replace ``dataset``'s labels with ``model``'s predicted classes.

    Args:
        model: classifier applied to ``normalize(X)``.
        dataset: dict with a 'data' image array and a 'labels' sequence; the
            existing labels are only carried through the loader and ignored.
        transforms: augmentations applied through ``Transform`` before prediction.
            NOTE(review): if these are random (crop/flip), labels are predicted
            on augmented views — confirm that is intended.

    Returns:
        ``{'data': dataset['data'], 'labels': predictions}``. There is one
        predicted class index per loaded sample, in loader order
        (``shuffle=False``).
    """
    print('Constructing list')
    dataset_ = list(zip(transpose(pad(dataset['data'], 4)/255.), dataset['labels']))
    print('Constructing transform')
    dataset_ = Transform(dataset_, transforms)
    print('Constructing dataloader')
    batches = torch.utils.data.DataLoader(dataset_, batch_size=args.batch_size, num_workers=50, pin_memory=True, shuffle=False, drop_last=False)

    # Predict in eval mode so BatchNorm uses its running statistics: a model
    # loaded from a checkpoint (``--clean``) is still in train mode, where each
    # label would depend on the other samples in its batch. The caller's
    # mode is restored afterwards.
    was_training = model.training
    model.eval()
    res = []
    try:
        with torch.no_grad():
            for i, (X, _) in enumerate(batches):
                print(i)
                X = X.float().cuda()
                output = model(normalize(X))
                res.extend(output.max(1)[1].cpu().numpy())
    finally:
        model.train(was_training)

    return {'data': dataset['data'], 'labels': res}
    

def main():
    """Run the pipeline: clean model -> pseudo-labelled extra data -> adversarial training.

    All options come from the module-level ``args``. ``args.fname`` is
    extended in place with a run-specific sub-directory before anything is
    written there.
    """
    import random

    # Run directory: [no_adjust/]<weight_true>/seed_<seed>/<epsilon>/<l1>_<adv_lr>/
    prefix = '/no_adjust/' if args.no_adjust else '/'
    args.fname += f'{prefix}{args.weight_true}/seed_{args.seed}/{args.epsilon}/{args.l1}_{args.adv_lr}/'

    print(args.fname)
    os.makedirs(args.fname, exist_ok=True)

    logger = logging.getLogger(__name__)

    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(os.path.join(args.fname, 'record.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # adv_training shuffles the extra data with Python's ``random``; seed it too.
    random.seed(args.seed)

    transforms = [Crop(32, 32), FlipLR()]
    if args.cutout:
        transforms.append(Cutout(args.cutout_len, args.cutout_len))

    train_dataset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True, download=True)

    if args.order == '':
        order = []
    else:
        order = torch.load(args.order)['order']

    train_set, _, _, idx = select_class(train_dataset, classes=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], num_per_class=5000, order=order)

    # Persist the sample order so later runs can reuse it via --order.
    if args.order == '':
        torch.save({'order': idx}, os.path.join(args.fname, 'order.pth'))

    train_set_old = train_set

    train_set = list(zip(transpose(pad(train_set['data'], 4)/255.), train_set['labels']))
    train_set = Transform(train_set, transforms)

    train_batches = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=2, pin_memory=True, shuffle=True, drop_last=False)

    test_dataset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False, download=True)
    test_set, _, _, _ = select_class(test_dataset, classes=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], num_per_class=1000)
    # No padding for the test set: the 4-pixel pad only exists so the training
    # Crop(32, 32) transform can crop back down, and test batches are never
    # cropped (previously they were evaluated on padded images).
    test_set = list(zip(transpose(test_set['data']/255.), test_set['labels']))
    test_batches = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, num_workers=2, pin_memory=True, shuffle=False, drop_last=False)

    adv_epochs = args.adv_epochs
    clean_epochs = args.clean_epochs

    adv_lr = args.adv_lr
    clean_lr = args.clean_lr

    if args.data_other != '':
        # NOTE(review): ``--data-other`` only acts as a switch; the cache path is
        # fixed to the file written below. pickle.load can execute arbitrary
        # code, so only load a cache this script produced.
        with open('./tmp/data.pth', 'rb') as infile:
            train_other = pickle.load(infile)
    else:
        clean_model = clean_training(clean_epochs, clean_lr, train_batches, test_batches, logger)

        logger.info("Clean model get.")

        logger.info("Loading data.")

        # Close the npz archive once the image array has been read.
        with np.load(os.path.join(args.data_dir, 'cifar10_ddpm.npz')) as npz:
            images = npz['image']

        # Placeholder labels (one per image); add_label replaces them.
        train_other = {'data': images, 'labels': [1] * len(images)}

        logger.info("Calculating labels.")

        train_other = add_label(clean_model, train_other, transforms)

        logger.info("Label get for train_other.")
        logger.info("Weight get for train_other.")

        with open('./tmp/data.pth', 'wb') as outfile:
            pickle.dump(train_other, outfile, pickle.HIGHEST_PROTOCOL)

    train_set_old['weight'] = [1] * len(train_set_old['labels'])

    logger.info("Appending the new datasets...")

    # adv training
    adv_training(adv_epochs, adv_lr, train_set_old, train_other, test_batches, logger, transforms)


# Script entry point. Note that ``args`` is already parsed at import time, so
# importing this module also reads sys.argv.
if __name__ == "__main__":
    main()
