import argparse
import os
import math
import shutil
import random
import distutils.util
import numpy as np
import pandas as pd
import sys
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from torch.autograd import Variable
from torchvision.utils import save_image

# from cifar100.MemGuard.memguard_run import model

# config_file = './../../env.yml'
# Repository locations are read from env.yml, resolved against the current
# working directory.
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
    root_dir = yamlfile['root_dir']
    src_dir = yamlfile['src_dir']

# Make the project sources and their attack/models sub-folders importable
# for the project-local imports that follow.
sys.path.extend([
    src_dir,
    os.path.join(src_dir, 'attack'),
    os.path.join(src_dir, 'models'),
])
sys.path.append(os.path.join(src_dir, 'models'))
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf, TrainRecorder
from cifar_utils import transform_train, transform_train_aug, transform_test, Cifardata, DistillCifardata, WarmUpLR, \
    ModelwNorm
from cifar100.models.model_selector import get_network
from cifar100.generative_models.acgan import Discriminator, Generator

# Prefer the CUDA device when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def test(predict, labels):
    """Count correct top-1 predictions in a batch.

    Args:
        predict: scores with classes along dim 1 (e.g. the discriminator's
            auxiliary class output).
        labels: integer class targets, one per row of ``predict``.

    Returns:
        ``(correct, total)``: ``correct`` is a 0-dim CPU tensor with the number
        of rows whose arg-max equals the label; ``total`` is ``len(labels)``.
    """
    pred = predict.detach().max(1)[1]
    correct = pred.eq(labels.detach()).cpu().sum()
    return correct, len(labels)


# Despite its name this is a helper, not a test case; keep pytest-style
# collectors from trying to run it.
test.__test__ = False


def _save_checkpoint(state, is_best, checkpoint, prefix):
    """Save ``state`` as ``<prefix>_last.pth.tar`` in ``checkpoint``, creating the folder if needed.

    When ``is_best`` is true the file is also copied to ``<prefix>_best.pth.tar``.
    """
    # exist_ok avoids the check-then-create race of isdir() + mkdir.
    os.makedirs(checkpoint, exist_ok=True)
    filepath = os.path.join(checkpoint, '{}_last.pth.tar'.format(prefix))
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, '{}_best.pth.tar'.format(prefix)))


def save_checkpoint_g(state, is_best, acc, checkpoint):
    """Save a generator checkpoint to ``model_g_last.pth.tar`` (and ``model_g_best.pth.tar`` if ``is_best``).

    ``acc`` is accepted for call compatibility but not used.
    """
    _save_checkpoint(state, is_best, checkpoint, 'model_g')


def save_checkpoint_d(state, is_best, acc, checkpoint):
    """Save a discriminator checkpoint to ``model_d_last.pth.tar`` (and ``model_d_best.pth.tar`` if ``is_best``).

    ``acc`` is accepted for call compatibility but not used.
    """
    _save_checkpoint(state, is_best, checkpoint, 'model_d')


def get_learning_rate(optimizer):
    """Return the current learning rate of each parameter group, in group order."""
    return [group['lr'] for group in optimizer.param_groups]


def get_opt_and_lrsch(args, model_d, model_g, num_epoch, num_iter, warmup,
                      lr=0.00001, betas=(0.5, 0.999), gamma=1.):
    """Build Adam optimizers and exponential LR schedulers for the discriminator and generator.

    Args:
        args: parsed CLI namespace; currently unused, kept for call compatibility.
        model_d: discriminator whose parameters ``optim_d`` updates.
        model_g: generator whose parameters ``optim_g`` updates.
        num_epoch, num_iter, warmup: currently unused, kept for call compatibility.
        lr: Adam learning rate for both optimizers.
        betas: Adam ``(beta1, beta2)`` coefficients for both optimizers.
        gamma: multiplicative LR factor applied on each scheduler ``step()``;
            the default 1.0 keeps the learning rate constant.

    Returns:
        ``(optim_d, optim_g, train_scheduler_d, train_scheduler_g)``.
    """
    def build(model):
        # One optimizer/scheduler pair per network; both share the same settings.
        optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma, last_epoch=-1)
        return optimizer, scheduler

    optim_d, train_scheduler_d = build(model_d)
    optim_g, train_scheduler_g = build(model_g)
    return optim_d, optim_g, train_scheduler_d, train_scheduler_g


def _parse_args():
    """Parse the command-line options for ACGAN training on CIFAR-100."""
    parser = argparse.ArgumentParser(description='setting for cifar100')
    parser.add_argument('--classifier_epochs', type=int, default=200, help='classifier epochs')
    # NOTE: --print_epoch, --batch_step and --load_path are accepted but not used by this script.
    parser.add_argument('--print_epoch', type=int, default=5,
                        help='print model training stats per print_epoch_splitai during splitai training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_step', type=int, default=1, help='batch accumulation steps')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    parser.add_argument('--num_class', type=int, default=100, help='num class')
    parser.add_argument('--data_aug', type=distutils.util.strtobool, default=True, help='turn on data augmentation')
    parser.add_argument('--nz', type=int, default=100, help='Number of dimensions for input noise.')
    parser.add_argument('--run_idx', type=int, default=100, help='idx running')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')
    return parser.parse_args()


def _load_partition(dataset_path, name):
    """Load the array stored at ``<dataset_path>/partition/<name>.npy``."""
    return np.load(os.path.join(dataset_path, 'partition', name + '.npy'))


def _class_conditioned_noise(labels, nz, num_labels):
    """Build generator inputs whose first ``num_labels`` entries one-hot encode ``labels``.

    The remaining ``nz - num_labels`` entries are standard-normal noise, so when
    ``nz == num_labels`` the input carries no randomness beyond the class.

    Args:
        labels: 1-D numpy integer array of class indices in ``[0, num_labels)``.
        nz: total noise dimensionality (must be >= ``num_labels``).
        num_labels: number of classes.

    Returns:
        float32 tensor of shape ``(len(labels), nz, 1, 1)`` on ``device``.
    """
    batch = len(labels)
    noise = np.random.normal(0, 1, (batch, nz))
    noise[:, :num_labels] = 0.0
    noise[np.arange(batch), labels] = 1.0
    return torch.from_numpy(noise).float().view(batch, nz, 1, 1).to(device)


def _train_epoch(netD, netG, optim_d, optim_g, loader, nz, num_labels, s_criterion, c_criterion):
    """Run one epoch of ACGAN updates: per batch, one D step followed by one G step.

    Returns a dict of AverageMeters:
      - ``realD`` / ``fakeD``: mean source output of D on real / fake images.
      - ``G``: mean source output of D on the fakes during the G step.
      - ``realC`` / ``fakeC``: accuracy (fraction) of D's class head on real / fake images.
    """
    meters = {key: AverageMeter() for key in ('fakeD', 'realD', 'G', 'fakeC', 'realC')}
    for real_x, real_y in loader:
        batch = real_x.size(0)
        real_x = real_x.to(device, torch.float)
        real_y = real_y.to(device, torch.long)
        # BCE targets shaped (batch, 1) to match D's source output.
        real_target = torch.ones(batch, 1, device=device)
        fake_target = torch.zeros(batch, 1, device=device)

        # (1) Update D: maximize log(D(x)) + log(1 - D(G(z))) plus the auxiliary class loss.
        # Clearing D's grads here also discards what the previous G step accumulated into D.
        netD.zero_grad()
        s_out, c_out = netD(real_x)
        err_real = s_criterion(s_out, real_target) + c_criterion(c_out, real_y)
        err_real.backward()
        meters['realD'].update(s_out.mean().item(), batch)
        correct, length = test(c_out, real_y)
        # Feed the batch accuracy (not the raw count): assumes AverageMeter.update(val, n)
        # keeps a count-weighted mean (sum += val * n) -- TODO confirm in utils.
        meters['realC'].update(float(correct) / length, length)

        fake_labels = np.random.randint(0, num_labels, batch)
        fake_y = torch.from_numpy(fake_labels).to(device, torch.long)
        fake = netG(_class_conditioned_noise(fake_labels, nz, num_labels))
        s_out, c_out = netD(fake.detach())  # detach: this step must not update G
        err_fake = s_criterion(s_out, fake_target) + c_criterion(c_out, fake_y)
        err_fake.backward()
        meters['fakeD'].update(s_out.mean().item(), batch)
        optim_d.step()
        correct, length = test(c_out, fake_y)
        meters['fakeC'].update(float(correct) / length, length)

        # (2) Update G with the non-saturating loss: label the fakes "real",
        # i.e. maximize log(D(G(z))), plus the auxiliary class loss.
        netG.zero_grad()
        s_out, c_out = netD(fake)
        err_g = s_criterion(s_out, real_target) + c_criterion(c_out, fake_y)
        err_g.backward()
        meters['G'].update(s_out.mean().item(), batch)
        optim_g.step()
    return meters


def _save_samples(model_g, noise, path, img_size=32):
    """Generate images from ``noise`` with ``model_g`` and save them as one image grid at ``path``."""
    # no_grad avoids building a graph; the generator stays in its current mode
    # (so any BatchNorm layers behave exactly as during training).
    with torch.no_grad():
        images = model_g(noise).view(noise.size(0), 3, img_size, img_size).cpu()
    save_image(images, path)


def main():
    """Train an ACGAN on the CIFAR-100 training partition.

    Writes a sample grid every epoch, a discriminator snapshot after epoch 1,
    and the final generator/discriminator checkpoints plus a last sample grid.
    """
    args = _parse_args()
    print(dict(args._get_kwargs()))

    batch_size = args.batch_size
    num_class = args.num_class
    classifier_epochs = args.classifier_epochs
    num_worker = args.num_worker
    nz = args.nz

    # The first num_class noise dimensions are replaced by the one-hot class code.
    if nz < num_class:
        raise ValueError('--nz ({}) must be at least --num_class ({}): the first num_class noise '
                         'dimensions carry the one-hot class code.'.format(nz, num_class))
    if nz == num_class:
        print('Warning: --nz equals --num_class ({}), so the generator input is only the one-hot '
              'class code with no random noise.'.format(nz))

    DATASET_PATH = os.path.join(root_dir, 'cifar100', 'data')
    checkpoint_path = os.path.join(args.save_path, 'cifar100', 'acgan',
                                   'aug' if args.data_aug else 'no_aug', str(args.run_idx))
    print(checkpoint_path)
    # Create the output folder up front: sample grids are written into it every epoch.
    os.makedirs(checkpoint_path, exist_ok=True)

    train_data = _load_partition(DATASET_PATH, 'train_data')
    train_label = _load_partition(DATASET_PATH, 'train_label')

    # print first 20 labels for each subset, for checking with other experiments
    for label_file in ('tr_label', 'te_label', 'test_label', 'ref_label'):
        print(_load_partition(DATASET_PATH, label_file)[:20])

    transform = transform_train_aug if args.data_aug else transform_train
    trainset = Cifardata(train_data, train_label, transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_worker)

    IMG_SIZE = 32
    SAMPLE_SIZE = 64

    model_d = Discriminator(ndf=64, nc=3, nb_label=num_class)
    # Use the configured noise size so the generator matches the noise fed to it.
    model_g = Generator(nz=nz, ngf=64, nc=3)
    model_d.to(device, torch.float)
    model_g.to(device, torch.float)

    optim_d, optim_g, train_scheduler_d, train_scheduler_g = get_opt_and_lrsch(
        args, model_d, model_g, classifier_epochs, len(trainloader), 0)

    print("training sets: {:d}".format(len(trainset)))

    # Source (real/fake) loss and auxiliary class loss.
    s_criterion = nn.BCELoss()
    c_criterion = nn.CrossEntropyLoss()

    # Fixed class-conditioned inputs (classes cycled in order), built like the training
    # inputs so the per-epoch sample grids are on-distribution and comparable.
    fixed_noise = _class_conditioned_noise(np.arange(SAMPLE_SIZE) % num_class, nz, num_class)

    best_acc = 0.00  # never updated; stored in checkpoints for format compatibility

    print('Training...')
    epoch = 0  # keeps the final checkpoints valid when classifier_epochs == 0
    for epoch in range(1, classifier_epochs + 1):
        meters = _train_epoch(model_d, model_g, optim_d, optim_g, trainloader, nz, num_class,
                              s_criterion, c_criterion)

        lr = get_learning_rate(optim_g)
        print(
            "Epoch: [{:d} | {:d}]: learning rate:{:.5f}. loss: mean D(fake) = {:.4f}, mean D(real) = {:.4f}, mean G(fake) = {:.4f}, real class acc: {:.4f}, fake class acc: {:.4f}".format(
                epoch, classifier_epochs, lr[0], meters['fakeD'].avg, meters['realD'].avg, meters['G'].avg,
                meters['realC'].avg, meters['fakeC'].avg
            )
        )
        train_scheduler_d.step()
        train_scheduler_g.step()

        # Snapshot the discriminator after the first epoch; the final save below overwrites it.
        if epoch == 1:
            save_checkpoint_d({
                'epoch': epoch,
                'state_dict': model_d.state_dict(),
                'best_acc': best_acc,
                'optimizer': optim_d.state_dict(),
            }, False, best_acc, checkpoint=checkpoint_path)

        _save_samples(model_g, fixed_noise, '{}/fake_samples_{}.png'.format(checkpoint_path, epoch), IMG_SIZE)

    # save the last
    save_checkpoint_d({
        'epoch': epoch,
        'state_dict': model_d.state_dict(),
        'best_acc': best_acc,
        'optimizer': optim_d.state_dict(),
    }, False, best_acc, checkpoint=checkpoint_path)

    save_checkpoint_g({
        'epoch': epoch,
        'state_dict': model_g.state_dict(),
        'best_acc': best_acc,
        'optimizer': optim_g.state_dict(),
    }, False, best_acc, checkpoint=checkpoint_path)

    _save_samples(model_g, fixed_noise, '{}/fake_samples.png'.format(checkpoint_path), IMG_SIZE)


if __name__ == '__main__':
    main()
