import argparse
import os
import math
import shutil
import random
import distutils.util
import numpy as np
import pandas as pd
import sys
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from torch.autograd import Variable
from torchvision.utils import save_image

# from cifar100.MemGuard.memguard_run import model

# config_file = './../../env.yml'
# Machine-specific paths come from a YAML config. The path is relative to the
# current working directory (not to this file), so the script must be started
# from the directory that contains env.yml.
config_file = './env.yml'
with open(config_file, 'r') as stream:
    # safe_load only builds plain YAML types (no arbitrary Python objects).
    yamlfile = yaml.safe_load(stream)
    root_dir = yamlfile['root_dir']  # main() reads data from <root_dir>/cifar100/data
    src_dir = yamlfile['src_dir']  # project source root, added to sys.path below

# Make project packages/modules importable. These sys.path entries must be
# appended BEFORE the project imports below, so statement order matters here.
sys.path.append(src_dir)
sys.path.append(os.path.join(src_dir, 'attack'))
sys.path.append(os.path.join(src_dir, 'models'))
# NOTE(review): several of these names (e.g. system_attack, get_network,
# DistillCifardata, WarmUpLR, ModelwNorm) appear unused in this script.
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf, TrainRecorder
from cifar_utils import transform_train, transform_train_aug, transform_test, Cifardata, DistillCifardata, WarmUpLR, \
    ModelwNorm
from cifar100.models.model_selector import get_network
from cifar100.generative_models.dcgan_cond import Discriminator, Generator

# Default CUDA device when one is available, otherwise CPU; used by every
# function below to place models and tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def test(testloader, model, criterion):
    """Evaluate ``model`` over every batch of ``testloader``.

    Args:
        testloader: iterable yielding ``(inputs, targets)`` batches.
        model: classifier mapping inputs to class scores. It is switched to
            eval mode by this function and left in that mode.
        criterion: loss called as ``criterion(outputs, targets)``.

    Returns:
        Tuple ``(mean_loss, mean_top1)``, both averaged per sample.
        ``mean_top1`` is divided by 100, presumably because ``accuracy``
        returns percentages (TODO confirm), giving a fraction in [0, 1].
    """
    model.eval()

    losses = AverageMeter()
    top1 = AverageMeter()

    # Pure inference: disable autograd so no graph is built and activations
    # are released after each batch.
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device, torch.float)
            targets = targets.to(device, torch.long)
            outputs = model(inputs)

            loss = criterion(outputs, targets)

            # Top-5 is computed by ``accuracy`` but not reported here.
            prec1, _ = accuracy(outputs, targets, topk=(1, 5))
            n = inputs.size(0)
            losses.update(loss.item(), n)
            top1.update(prec1.item() / 100.0, n)

    return losses.avg, top1.avg


def save_checkpoint_g(state, is_best, acc, checkpoint):
    """Save a generator checkpoint into directory ``checkpoint``.

    Always writes ``<checkpoint>/model_g_last.pth.tar``. When ``is_best`` is
    true, that file is also copied to ``<checkpoint>/model_g_best.pth.tar``.

    Args:
        state: object passed to ``torch.save`` (a dict with epoch, state_dict,
            best_acc and optimizer entries in this script).
        is_best: whether to keep an additional "best" copy.
        acc: unused; kept so existing call sites keep working.
        checkpoint: output directory, created (including parents) if missing.
    """
    # exist_ok avoids the race between "is it a directory?" and "create it".
    os.makedirs(checkpoint, exist_ok=True)
    filepath = os.path.join(checkpoint, 'model_g_last.pth.tar')
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_g_best.pth.tar'))

def save_checkpoint_d(state, is_best, acc, checkpoint):
    """Save a discriminator checkpoint into directory ``checkpoint``.

    Always writes ``<checkpoint>/model_d_last.pth.tar``. When ``is_best`` is
    true, that file is also copied to ``<checkpoint>/model_d_best.pth.tar``.

    Args:
        state: object passed to ``torch.save`` (a dict with epoch, state_dict,
            best_acc and optimizer entries in this script).
        is_best: whether to keep an additional "best" copy.
        acc: unused; kept so existing call sites keep working.
        checkpoint: output directory, created (including parents) if missing.
    """
    # exist_ok avoids the race between "is it a directory?" and "create it".
    os.makedirs(checkpoint, exist_ok=True)
    filepath = os.path.join(checkpoint, 'model_d_last.pth.tar')
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_d_best.pth.tar'))

def get_learning_rate(optimizer):
    """Return a list with the current ``'lr'`` of each parameter group of ``optimizer``, in group order."""
    return [group['lr'] for group in optimizer.param_groups]


def get_opt_and_lrsch(args, model_d, model_g, num_epoch, num_iter, warmup,
                      lr=0.0002, betas=(0.5, 0.999), gamma=1.):
    """Create Adam optimizers and exponential LR schedulers for a D/G pair.

    With the default ``gamma=1.`` each ExponentialLR leaves the learning rate
    unchanged, so training runs at a constant ``lr``.

    Args:
        args, num_epoch, num_iter, warmup: unused; kept so existing call
            sites keep working.
        model_d: discriminator whose parameters ``optim_d`` updates.
        model_g: generator whose parameters ``optim_g`` updates.
        lr: initial learning rate for both optimizers.
        betas: Adam ``(beta1, beta2)`` for both optimizers.
        gamma: multiplicative LR factor applied on each scheduler ``step()``.

    Returns:
        ``(optim_d, optim_g, train_scheduler_d, train_scheduler_g)``.
    """
    optim_d = optim.Adam(model_d.parameters(), lr=lr, betas=betas)
    optim_g = optim.Adam(model_g.parameters(), lr=lr, betas=betas)
    train_scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=gamma, last_epoch=-1)
    train_scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=gamma, last_epoch=-1)
    return optim_d, optim_g, train_scheduler_d, train_scheduler_g


def _train_one_epoch(model_d, model_g, optim_d, optim_g, loader, class_eye, nz,
                     adversarial_loss, real_label=1., fake_label=0.):
    """Run one epoch of conditional-GAN training: per batch, one D step then one G step.

    Args:
        model_d: discriminator, called as ``model_d(images, one_hot_labels)``.
            Its outputs go through ``adversarial_loss`` (BCELoss here), so they
            are expected to lie in [0, 1].
        model_g: generator, called as ``model_g(noise, one_hot_labels)``.
        optim_d, optim_g: optimizers for ``model_d`` and ``model_g``.
        loader: iterable of ``(images, integer_labels)`` batches.
        class_eye: identity matrix on ``device``; indexing it with labels
            yields one-hot rows.
        nz: dimensionality of the generator's input noise.
        adversarial_loss: loss called as ``adversarial_loss(d_outputs, targets)``.
        real_label, fake_label: targets used for real and generated samples.

    Returns:
        ``(mean D(fake), mean D(real), mean D(G(z)))``: discriminator outputs
        averaged per sample. The last value is measured after the D update.
    """
    fakeD_meter = AverageMeter()
    realD_meter = AverageMeter()
    g_meter = AverageMeter()

    for train_x, train_y in loader:
        # ---- Discriminator: maximize log D(x|y) + log(1 - D(G(z|y)|y)) ----
        # Also clears gradients that the previous G step left in D.
        model_d.zero_grad()

        batch = train_x.to(device)
        one_hot_labels = class_eye[train_y.to(device, torch.long)]
        b_size = batch.size(0)

        # Real samples against target real_label.
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = model_d(batch, one_hot_labels).view(-1)
        errorD_real = adversarial_loss(output, label)
        errorD_real.backward()
        realD_meter.update(output.mean().item(), output.size(0))

        # Fake samples against target fake_label; detach() keeps this pass out of G.
        noise = torch.randn(b_size, nz, device=device)
        fake_batch = model_g(noise, one_hot_labels)
        label.fill_(fake_label)
        output = model_d(fake_batch.detach(), one_hot_labels).view(-1)
        errorD_fake = adversarial_loss(output, label)
        errorD_fake.backward()
        fakeD_meter.update(output.mean().item(), output.size(0))

        # Gradients of the real and fake terms have accumulated; apply them together.
        optim_d.step()

        # ---- Generator: non-saturating loss, i.e. maximize log D(G(z|y)|y) ----
        model_g.zero_grad()
        label.fill_(real_label)
        # Re-score the same fakes with the updated D; gradients now flow into G.
        output = model_d(fake_batch, one_hot_labels).view(-1)
        errorG = adversarial_loss(output, label)
        errorG.backward()
        g_meter.update(output.mean().item(), output.size(0))
        optim_g.step()

    return fakeD_meter.avg, realD_meter.avg, g_meter.avg


def _save_samples(model_g, fixed_noise, fixed_labels, sample_size, path):
    """Generate ``sample_size`` 3x32x32 images from fixed inputs and save them to ``path``."""
    # Visualisation only: no autograd graph is needed.
    # NOTE(review): model_g is never switched to eval(), so layers such as
    # BatchNorm run in training mode here. Confirm that is intended.
    with torch.no_grad():
        g_out = model_g(fixed_noise, fixed_labels).view(sample_size, 3, 32, 32).cpu()
    save_image(g_out, path)


def main():
    """Train a class-conditional DCGAN on the CIFAR-100 ``train`` partition.

    Loads the pre-split ``.npy`` partitions from
    ``<root_dir>/cifar100/data/partition`` and trains ``Generator`` and
    ``Discriminator`` for ``--classifier_epochs`` epochs with a BCE
    (non-saturating) GAN loss. After every epoch it writes an image of
    samples produced from a fixed set of noise vectors and labels. At the end
    it saves the final D and G checkpoints under
    ``<save_path>/cifar100/cdcgan/{aug,no_aug}/<run_idx>``.
    """
    parser = argparse.ArgumentParser(description='setting for cifar100')
    parser.add_argument('--classifier_epochs', type=int, default=200, help='classifier epochs')
    # NOTE: --print_epoch, --batch_step and --load_path are parsed but not used by this script.
    parser.add_argument('--print_epoch', type=int, default=5,
                        help='print model training stats per print_epoch_splitai during splitai training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_step', type=int, default=1, help='batch accumulation steps')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    parser.add_argument('--num_class', type=int, default=100, help='num class')
    # NOTE(review): distutils was removed in Python 3.12; strtobool needs a replacement there.
    parser.add_argument('--data_aug', type=distutils.util.strtobool, default=True, help='turn on data augmentation')
    parser.add_argument('--nz', type=int, default=100, help='Number of dimensions for input noise.')
    parser.add_argument('--run_idx', type=int, default=100, help='idx running')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')

    args = parser.parse_args()
    print(dict(args._get_kwargs()))

    batch_size = args.batch_size
    num_class = args.num_class
    classifier_epochs = args.classifier_epochs
    num_worker = args.num_worker
    nz = args.nz

    DATASET_PATH = os.path.join(root_dir, 'cifar100', 'data')
    checkpoint_path = os.path.join(args.save_path, 'cifar100', 'cdcgan',
                                   'aug' if args.data_aug else 'no_aug', str(args.run_idx))
    print(checkpoint_path)

    # Every partition file is still loaded, so a missing file fails fast.
    # Only train_data/train_label feed the GAN.
    partition_dir = os.path.join(DATASET_PATH, 'partition')
    partition_names = ('tr_data', 'tr_label', 'te_data', 'te_label',
                       'train_data', 'train_label', 'test_data', 'test_label',
                       'ref_data', 'ref_label', 'all_test_data', 'all_test_label')
    parts = {name: np.load(os.path.join(partition_dir, name + '.npy')) for name in partition_names}

    # print first 20 labels for each subset, for checking with other experiments
    for name in ('tr_label', 'te_label', 'test_label', 'ref_label'):
        print(parts[name][:20])

    train_transform = transform_train_aug if args.data_aug else transform_train
    trainset = Cifardata(parts['train_data'], parts['train_label'], train_transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                              num_workers=num_worker)

    SAMPLE_SIZE = 64  # number of images in each saved sample grid
    best_acc = 0.00  # placeholder stored in checkpoints; no accuracy is tracked for the GAN

    # Size the conditioning on --num_class (default 100) rather than a hard-coded 100.
    model_d = Discriminator(classes=num_class, channels=3)
    model_g = Generator(classes=num_class, channels=3)
    model_d.to(device, torch.float)
    model_g.to(device, torch.float)

    optim_d, optim_g, train_scheduler_d, train_scheduler_g = get_opt_and_lrsch(
        args, model_d, model_g, classifier_epochs, len(trainloader), 0)

    print("training sets: {:d}".format(len(trainset)))

    # Binary cross-entropy between D's outputs and the real/fake targets.
    adversarial_loss = torch.nn.BCELoss()

    # Fixed noise and labels used to visualise the generator's progress.
    # Labels cycle through the classes, so the grid always has SAMPLE_SIZE
    # images even if num_class < SAMPLE_SIZE (with 100 classes: classes 0..63).
    class_eye = torch.eye(num_class, dtype=torch.float, device=device)
    fixed_noise = torch.randn(SAMPLE_SIZE, nz, dtype=torch.float, device=device)
    fixed_labels = class_eye[torch.arange(SAMPLE_SIZE, device=device) % num_class]

    # Sample images are written into checkpoint_path, so make sure it exists.
    os.makedirs(checkpoint_path, exist_ok=True)

    print('Training...')
    epoch = 0  # keeps the final saves valid even when classifier_epochs < 1
    for epoch in range(1, classifier_epochs + 1):
        mean_d_fake, mean_d_real, mean_g = _train_one_epoch(
            model_d, model_g, optim_d, optim_g, trainloader, class_eye, nz, adversarial_loss)

        lr = get_learning_rate(optim_g)
        # NOTE: the values printed after "loss:" are mean discriminator outputs, not losses.
        print(
            "Epoch: [{:d} | {:d}]: learning rate:{:.4f}. loss: mean D(fake) = {:.4f}, mean D(real) = {:.4f}, mean G(fake) = {:.4f}".format(
                epoch, classifier_epochs, lr[0], mean_d_fake, mean_d_real, mean_g
            )
        )
        train_scheduler_d.step()
        train_scheduler_g.step()

        # Early D snapshot after the first epoch (overwritten by the final save below).
        if epoch == 1:
            save_checkpoint_d({
                'epoch': epoch,
                'state_dict': model_d.state_dict(),
                'best_acc': best_acc,
                'optimizer': optim_d.state_dict(),
            }, False, best_acc, checkpoint=checkpoint_path)

        _save_samples(model_g, fixed_noise, fixed_labels, SAMPLE_SIZE,
                      '{}/fake_samples_{}.png'.format(checkpoint_path, epoch))

    # save the last
    save_checkpoint_d({
        'epoch': epoch,
        'state_dict': model_d.state_dict(),
        'best_acc': best_acc,
        'optimizer': optim_d.state_dict(),
    }, False, best_acc, checkpoint=checkpoint_path)

    save_checkpoint_g({
        'epoch': epoch,
        'state_dict': model_g.state_dict(),
        'best_acc': best_acc,
        'optimizer': optim_g.state_dict(),
    }, False, best_acc, checkpoint=checkpoint_path)

    _save_samples(model_g, fixed_noise, fixed_labels, SAMPLE_SIZE,
                  '{}/fake_samples.png'.format(checkpoint_path))


if __name__ == '__main__':
    main()
