import argparse
import os
import math
import shutil
import random
import distutils.util
import numpy as np
import pandas as pd
import sys
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from torch.autograd import Variable
from torchvision.utils import save_image

# Load the project paths from env.yml. The path is relative, so it is resolved
# against the current working directory of the process.
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
    # Both keys are required; a missing key raises KeyError at import time.
    root_dir = yamlfile['root_dir']
    src_dir = yamlfile['src_dir']
# Make src_dir importable before the project-local imports that follow.
sys.path.append(src_dir)
import cifar100.generative_models.losses as losses
import cifar100.generative_models.BigGAN.utils as utils

# from cifar100.MemGuard.memguard_run import model

# config_file = './../../env.yml'
# NOTE(review): this repeats the env.yml load performed earlier in this file
# and re-binds root_dir / src_dir to the same values; src_dir is also appended
# to sys.path a second time below.
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
    root_dir = yamlfile['root_dir']
    src_dir = yamlfile['src_dir']

sys.path.append(src_dir)
# Also expose the 'attack' and 'models' sub-directories as import roots.
sys.path.append(os.path.join(src_dir, 'attack'))
sys.path.append(os.path.join(src_dir, 'models'))
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf, TrainRecorder
from cifar_utils import transform_train, transform_train_aug, transform_test, Cifardata, DistillCifardata, WarmUpLR, \
    ModelwNorm
from cifar100.models.model_selector import get_network
from cifar100.generative_models.dcgan_cond import Discriminator, Generator
import cifar100.generative_models.icgan as icgan

# Module-wide default device: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _unpack_conditioning(sampled_cond, y, features, device, limit=None):
    """Unpack a sampled conditioning tuple and move its tensors to ``device``.

    The layout of ``sampled_cond`` is chosen from which real conditionings the
    training step received:

    * ``y`` and ``features`` given -> ``(z, labels, features)``
    * only ``y`` given             -> ``(z, labels)``
    * only ``features`` given      -> ``(z, features)``

    Args:
        sampled_cond: value returned by the ``sample_conditionings`` callable.
        y: real labels (or ``None``); only tested against ``None`` here.
        features: real features (or ``None``); only tested against ``None``.
        device: target device for the unpacked tensors.
        limit: when not ``None``, every unpacked tensor is truncated to its
            first ``limit`` rows before being moved.

    Returns:
        ``(z, labels, features)`` where ``labels`` is cast to ``long`` and
        either of ``labels`` / ``features`` is ``None`` when not sampled.

    Raises:
        ValueError: if both ``y`` and ``features`` are ``None``.
    """
    labels_g, f_g = None, None
    if features is not None and y is not None:
        z_, labels_g, f_g = sampled_cond
    elif y is not None:
        z_, labels_g = sampled_cond
    elif features is not None:
        z_, f_g = sampled_cond
    else:
        raise ValueError(
            "cannot unpack sampled conditioning: both `y` and `features` are None"
        )

    if limit is not None:
        z_ = z_[:limit]
        if labels_g is not None:
            labels_g = labels_g[:limit]
        if f_g is not None:
            f_g = f_g[:limit]

    z_ = z_.to(device, non_blocking=True)
    if labels_g is not None:
        labels_g = labels_g.to(device, non_blocking=True).long()
    if f_g is not None:
        f_g = f_g.to(device, non_blocking=True)
    return z_, labels_g, f_g


def GAN_training_function(
    G,
    D,
    GD,
    ema,
    state_dict,
    config,
    sample_conditionings,
    embedded_optimizers=True,
    device="cuda",
    batch_size=0,
):
    """Build and return a closure that performs one GAN training iteration.

    Args:
        G: generator. ``G.optim`` is used when ``embedded_optimizers`` is true;
            ``G.shared.parameters()`` is read when ``config["G_ortho"] > 0``.
        D: discriminator. ``D.optim`` is used when ``embedded_optimizers`` is
            true.
        GD: callable combining G and D. Called with ``train_G=False`` it must
            return ``(D_fake, D_real)``; with ``train_G=True`` it must return
            ``D_fake``. ``GD.optimizer_D`` / ``GD.optimizer_G`` are used when
            ``embedded_optimizers`` is false.
        ema: object whose ``update(itr)`` is called when ``config["ema"]`` is
            truthy.
        state_dict: mapping; only ``state_dict["itr"]`` is read.
        config: mapping with the keys ``toggle_grads``, ``num_D_steps``,
            ``num_D_accumulations``, ``num_G_accumulations``, ``split_D``,
            ``DiffAugment``, ``DA``, ``D_ortho``, ``G_ortho`` and ``ema``.
        sample_conditionings: zero-argument callable returning the tuple
            described in ``_unpack_conditioning``.
        embedded_optimizers: selects where the optimizers live (see above).
        device: device the sampled conditionings are moved to.
        batch_size: chunk size used to split the real batch and to truncate
            the conditionings sampled for the discriminator phase. Callers are
            expected to pass a positive value.

    Returns:
        ``train(x, y=None, features=None)``; see its docstring.
    """

    def zero_grad_D():
        if embedded_optimizers:
            D.optim.zero_grad()
        else:
            GD.optimizer_D.zero_grad()

    def zero_grad_G():
        if embedded_optimizers:
            G.optim.zero_grad()
        else:
            GD.optimizer_G.zero_grad()

    def train(x, y=None, features=None):
        """Run the D step(s) followed by one G step on the real batch ``x``.

        ``x`` (and ``y`` / ``features`` when given) is split into chunks of
        ``batch_size``; one chunk is consumed per discriminator accumulation.

        Returns:
            dict with float entries ``G_loss``, ``D_loss_real`` and
            ``D_loss_fake`` taken from the last accumulation of each phase.

        Raises:
            ValueError: if neither ``y`` nor ``features`` is provided. (The
                sampled conditioning could not be unpacked in that case.)
        """
        # Fail early and clearly, before touching any optimizer state.
        if y is None and features is None:
            raise ValueError(
                "train() requires `y`, `features`, or both; "
                "got neither, so the sampled conditioning cannot be unpacked"
            )

        zero_grad_G()
        zero_grad_D()

        # Split the real batch into per-accumulation chunks.
        x = torch.split(x, batch_size)
        if y is not None:
            y = torch.split(y, batch_size)
        f_ = torch.split(features, batch_size) if features is not None else None
        counter = 0

        # Optionally toggle D and G's "require_grad"
        if config["toggle_grads"]:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for _ in range(config["num_D_steps"]):
            # If accumulating gradients, loop multiple times before an optimizer step
            zero_grad_D()
            for _ in range(config["num_D_accumulations"]):
                # Sample conditioning for G, truncated to one chunk's worth.
                z_, labels_g, f_g = _unpack_conditioning(
                    sample_conditionings(), y, features, device, limit=batch_size
                )
                # Obtain discriminator scores
                D_fake, D_real = GD(
                    z_,
                    labels_g,
                    f_g,
                    x[counter],
                    y[counter] if y is not None else None,
                    f_[counter] if f_ is not None else None,
                    train_G=False,
                    split_D=config["split_D"],
                    policy=config["DiffAugment"],
                    DA=config["DA"],
                )

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
                D_loss = (D_loss_real + D_loss_fake) / float(
                    config["num_D_accumulations"]
                )
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config["D_ortho"] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print("using modified ortho reg in D")
                utils.ortho(D, config["D_ortho"])

            if embedded_optimizers:
                D.optim.step()
            else:
                GD.optimizer_D.step()

        # Optionally toggle "requires_grad"
        if config["toggle_grads"]:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        zero_grad_G()

        # If accumulating gradients, loop multiple times
        for _ in range(config["num_G_accumulations"]):
            # Sample conditioning for G (not truncated in this phase).
            z_, labels_g, f_g = _unpack_conditioning(
                sample_conditionings(), y, features, device
            )
            # Obtain discriminator scores
            D_fake = GD(
                z_,
                labels_g,
                f_g,
                train_G=True,
                split_D=config["split_D"],
                policy=config["DiffAugment"],
                DA=config["DA"],
            )
            G_loss = losses.generator_loss(D_fake) / float(
                config["num_G_accumulations"]
            )
            G_loss.backward()

        # Optionally apply modified ortho reg in G
        if config["G_ortho"] > 0.0:
            # Debug print to indicate we're using ortho reg in G
            print("using modified ortho reg in G")
            # Don't ortho reg shared, it makes no sense. Really we should
            # blacklist any embeddings for this
            utils.ortho(
                G,
                config["G_ortho"],
                blacklist=list(G.shared.parameters()),
            )
        if embedded_optimizers:
            G.optim.step()
        else:
            GD.optimizer_G.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config["ema"]:
            ema.update(state_dict["itr"])

        # Return G's loss and the components of D's loss.
        return {
            "G_loss": float(G_loss.item()),
            "D_loss_real": float(D_loss_real.item()),
            "D_loss_fake": float(D_loss_fake.item()),
        }

    return train

def test(testloader, model, criterion):
    """Evaluate ``model`` on every batch of ``testloader``.

    The model is switched to eval mode (and left in it) and the whole pass
    runs under ``torch.no_grad()`` so no autograd graph is kept per batch.

    Args:
        testloader: iterable yielding ``(inputs, targets)`` batches.
        model: callable mapping a float input batch to outputs.
        criterion: callable ``criterion(outputs, targets)`` returning a
            scalar loss tensor.

    Returns:
        ``(avg_loss, avg_top1)``: the sample-weighted averages tracked by the
        two ``AverageMeter`` instances. Top-1 is the value returned by
        ``accuracy`` divided by 100 (presumably a percentage turned into a
        fraction; confirm against ``accuracy``).
    """
    model.eval()

    # Named *_meter so the imported ``losses`` module is not shadowed.
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()

    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device, torch.float)
            targets = targets.to(device, torch.long)
            outputs = model(inputs)

            loss = criterion(outputs, targets)

            prec1, _ = accuracy(outputs, targets, topk=(1, 5))
            n_samples = inputs.size(0)
            loss_meter.update(loss.item(), n_samples)
            top1_meter.update(prec1.item() / 100.0, n_samples)

    return (loss_meter.avg, top1_meter.avg)


def _save_checkpoint(state, is_best, checkpoint, last_name, best_name):
    """Write ``state`` to ``checkpoint/last_name``; copy it to ``best_name`` if best."""
    if not os.path.isdir(checkpoint):
        mkdir_p(checkpoint)
    filepath = os.path.join(checkpoint, last_name)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, best_name))


def save_checkpoint_g(state, is_best, acc, checkpoint):
    """Save a generator checkpoint.

    Args:
        state: object serialised with ``torch.save``.
        is_best: when true, the file is also copied to
            ``model_g_best.pth.tar``.
        acc: unused; kept for signature compatibility.
        checkpoint: target directory, created if it does not exist.
    """
    _save_checkpoint(state, is_best, checkpoint,
                     'model_g_last.pth.tar', 'model_g_best.pth.tar')


def save_checkpoint_d(state, is_best, acc, checkpoint):
    """Save a discriminator checkpoint.

    Args:
        state: object serialised with ``torch.save``.
        is_best: when true, the file is also copied to
            ``model_d_best.pth.tar``.
        acc: unused; kept for signature compatibility.
        checkpoint: target directory, created if it does not exist.
    """
    _save_checkpoint(state, is_best, checkpoint,
                     'model_d_last.pth.tar', 'model_d_best.pth.tar')


def get_learning_rate(optimizer):
    """Return the ``'lr'`` of each entry in ``optimizer.param_groups``, in order."""
    return [group['lr'] for group in optimizer.param_groups]


def get_opt_and_lrsch(args, model_d, model_g, num_epoch, num_iter, warmup,
                      lr=0.0002, betas=(0.5, 0.999), gamma=1.):
    """Create Adam optimizers and exponential LR schedulers for D and G.

    Args:
        args: unused; kept for signature compatibility.
        model_d: discriminator whose ``parameters()`` are optimised.
        model_g: generator whose ``parameters()`` are optimised.
        num_epoch: unused; kept for signature compatibility.
        num_iter: unused; kept for signature compatibility.
        warmup: unused; kept for signature compatibility.
        lr: Adam learning rate applied to both optimizers.
        betas: Adam ``betas`` applied to both optimizers.
        gamma: multiplicative decay passed to ``ExponentialLR``. With the
            default of 1.0 the learning rate stays constant.

    Returns:
        ``(optim_d, optim_g, scheduler_d, scheduler_g)``.
    """
    optim_d = optim.Adam(model_d.parameters(), lr=lr, betas=betas)
    optim_g = optim.Adam(model_g.parameters(), lr=lr, betas=betas)
    train_scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=gamma, last_epoch=-1
    )
    train_scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=gamma, last_epoch=-1
    )
    return optim_d, optim_g, train_scheduler_d, train_scheduler_g


def main():
    """Train the conditional DCGAN on the CIFAR-100 training partition.

    Parses the command line, loads the ``partition`` arrays under
    ``root_dir/cifar100/data``, trains ``Discriminator`` / ``Generator`` with a
    BCE adversarial loss, writes a grid of generated samples after every
    epoch and saves the final D and G checkpoints under
    ``<save_path>/cifar100/cdcgan/<aug|no_aug>/<run_idx>``.
    """
    parser = argparse.ArgumentParser(description='setting for cifar100')
    parser.add_argument('--classifier_epochs', type=int, default=200, help='classifier epochs')
    # NOTE(review): --print_epoch is parsed but not consulted; stats are
    # printed after every epoch.
    parser.add_argument('--print_epoch', type=int, default=5,
                        help='print model training stats per print_epoch_splitai during splitai training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_step', type=int, default=1, help='batch accumulation steps')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    parser.add_argument('--num_class', type=int, default=100, help='num class')
    parser.add_argument('--data_aug', type=distutils.util.strtobool, default=True, help='turn on data augmentation')
    parser.add_argument('--nz', type=int, default=100, help='Number of dimensions for input noise.')
    parser.add_argument('--run_idx', type=int, default=100, help='idx running')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')

    args = parser.parse_args()
    print(dict(args._get_kwargs()))

    batch_size = args.batch_size
    num_class = args.num_class
    classifier_epochs = args.classifier_epochs
    num_worker = args.num_worker
    nz = args.nz

    partition_dir = os.path.join(root_dir, 'cifar100', 'data', 'partition')
    checkpoint_path = os.path.join(args.save_path, 'cifar100', 'cdcgan',
                                   'aug' if args.data_aug else 'no_aug', str(args.run_idx))
    print(checkpoint_path)
    # Create the output directory up front: save_image() writes into it and
    # must not depend on a checkpoint having been saved first.
    os.makedirs(checkpoint_path, exist_ok=True)

    def load_partition(name):
        return np.load(os.path.join(partition_dir, name))

    train_data = load_partition('train_data.npy')
    train_label = load_partition('train_label.npy')

    # print first 20 labels for each subset, for checking with other experiments
    # (only the label arrays are needed for this; the other subsets are not
    # used for GAN training).
    print(load_partition('tr_label.npy')[:20])
    print(load_partition('te_label.npy')[:20])
    print(load_partition('test_label.npy')[:20])
    print(load_partition('ref_label.npy')[:20])

    # Pick the augmenting or plain training transform.
    train_transform = transform_train_aug if args.data_aug else transform_train
    trainset = Cifardata(train_data, train_label, train_transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_worker)

    NUM_LABELS = num_class
    SAMPLE_SIZE = 64

    best_acc = 0.00

    # Build the models for the requested number of classes so that they agree
    # with the one-hot label width used below.
    model_d = Discriminator(classes=NUM_LABELS, channels=3)
    model_g = Generator(classes=NUM_LABELS, channels=3)

    # NOTE(review): these IC-GAN modules are built from an empty config and
    # are not used anywhere below; kept as-is, confirm whether they can go.
    config = {

    }
    G = icgan.Generator(**{**config, "embedded_optimizers": False}).to(device)
    D = icgan.Discriminator(**{**config, "embedded_optimizers": False}).to(device)

    iter_per_epoch = len(trainloader)

    model_d.to(device, torch.float)
    model_g.to(device, torch.float)

    optim_d, optim_g, train_scheduler_d, train_scheduler_g = get_opt_and_lrsch(args, model_d, model_g,
                                                                               classifier_epochs, iter_per_epoch,
                                                                               0)

    print("training sets: {:d}".format(len(trainset)))

    # Binary cross entropy on the discriminator output.
    adversarial_loss = torch.nn.BCELoss()

    # Identity matrix used as a one-hot lookup table; built once, indexed per batch.
    label_eye = torch.eye(NUM_LABELS, dtype=torch.float, device=device)

    # Fixed latent vectors and labels used to visualise generator progress.
    # Labels cycle through the classes so there are always SAMPLE_SIZE rows,
    # even when NUM_LABELS < SAMPLE_SIZE.
    fixed_noise = torch.randn(SAMPLE_SIZE, nz, dtype=torch.float, device=device)
    fixed_labels = label_eye[torch.arange(SAMPLE_SIZE, device=device) % NUM_LABELS]

    # Target values for real and fake samples.
    real_label = 1.
    fake_label = 0.

    # Training, per batch:
    # - Discriminator: one backward pass on real samples (target 1) and one on
    #   detached fake samples (target 0), then a single optimizer step, i.e.
    #   maximise log(D(x)) + log(1 - D(G(z))).
    # - Generator: fake samples are scored against the *real* target, i.e.
    #   minimise BCE(D(G(z)), 1), which maximises log(D(G(z))).
    print('Training...')

    epoch = 0  # keeps the final checkpoint well-defined if no epoch runs
    for epoch in range(1, classifier_epochs + 1):
        fakeD_meter = AverageMeter()
        realD_meter = AverageMeter()
        g_meter = AverageMeter()

        for train_x, train_y in trainloader:
            # ---- Discriminator update ----
            model_d.zero_grad()

            batch = train_x.to(device)
            one_hot_labels = label_eye[train_y]
            b_size = batch.size(0)

            # Real samples, target = real_label.
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            output = model_d(batch, one_hot_labels).view(-1)
            errorD_real = adversarial_loss(output, label)
            errorD_real.backward()
            realD_meter.update(output.mean().item(), output.size(0))

            # Fake samples, target = fake_label. The fake batch is detached so
            # this pass produces no generator gradients.
            noise = torch.randn(b_size, nz, device=device)
            fake_batch = model_g(noise, one_hot_labels)
            label.fill_(fake_label)
            output = model_d(fake_batch.detach(), one_hot_labels).view(-1)
            errorD_fake = adversarial_loss(output, label)
            errorD_fake.backward()
            fakeD_meter.update(output.mean().item(), output.size(0))

            optim_d.step()

            # ---- Generator update ----
            model_g.zero_grad()
            label.fill_(real_label)
            # Re-score the same fake batch with the just-updated discriminator.
            output = model_d(fake_batch, one_hot_labels).view(-1)
            errorG = adversarial_loss(output, label)
            errorG.backward()
            g_meter.update(output.mean().item(), output.size(0))

            optim_g.step()

        # Per-epoch statistics: the meters hold mean discriminator outputs.
        lr = get_learning_rate(optim_g)
        print(
            "Epoch: [{:d} | {:d}]: learning rate:{:.4f}. loss: mean D(fake) = {:.4f}, mean D(real) = {:.4f}, mean G(fake) = {:.4f}".format(
                epoch, classifier_epochs, lr[0], fakeD_meter.avg, realD_meter.avg, g_meter.avg
            )
        )
        train_scheduler_d.step()
        train_scheduler_g.step()

        # Early discriminator checkpoint after the first epoch; it is
        # overwritten by the final save below.
        if epoch == 1:
            save_checkpoint_d({
                'epoch': epoch,
                'state_dict': model_d.state_dict(),
                'best_acc': best_acc,
                'optimizer': optim_d.state_dict(),
            }, False, best_acc, checkpoint=checkpoint_path)

        # Sample grid for this epoch; no gradients are needed here.
        with torch.no_grad():
            g_out = model_g(fixed_noise, fixed_labels).view(SAMPLE_SIZE, 3, 32, 32).cpu()
        save_image(
            g_out, '{}/fake_samples_{}.png'.format(
                checkpoint_path, epoch
            )
        )

    # save the last
    save_checkpoint_d({
        'epoch': epoch,
        'state_dict': model_d.state_dict(),
        'best_acc': best_acc,
        'optimizer': optim_d.state_dict(),
    }, False, best_acc, checkpoint=checkpoint_path)

    save_checkpoint_g({
        'epoch': epoch,
        'state_dict': model_g.state_dict(),
        'best_acc': best_acc,
        'optimizer': optim_g.state_dict(),
    }, False, best_acc, checkpoint=checkpoint_path)

    with torch.no_grad():
        g_out = model_g(fixed_noise, fixed_labels).view(SAMPLE_SIZE, 3, 32, 32).cpu()
    save_image(
        g_out, '{}/fake_samples.png'.format(
            checkpoint_path
        )
    )


if __name__ == '__main__':
    main()
