import argparse
import os
import math
import shutil
import random
import distutils.util
import numpy as np
import pandas as pd
import sys

import torchvision
import yaml
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from torch.autograd import Variable
from torchvision.utils import save_image

# from cifar100.MemGuard.memguard_run import model

# config_file = './../../env.yml'
config_file = './env.yml'
# Environment file describing where data and project sources live.
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
root_dir, src_dir = yamlfile['root_dir'], yamlfile['src_dir']

# Make the project source tree (plus its attack/ and models/ folders)
# importable for the project imports that follow.
sys.path.extend([
    src_dir,
    os.path.join(src_dir, 'attack'),
    os.path.join(src_dir, 'models'),
])
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf, TrainRecorder
from cifar_utils import transform_train, transform_train_aug, transform_test, Cifardata, DistillCifardata, WarmUpLR, \
    ModelwNorm
from cifar100.models.model_selector import get_network
from cifar100.generative_models.pcgtacgan import Discriminator, Generator, G_D
from cifar100.generative_models.TAC import losses

# Use the default CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def toggle_grad(model, on_or_off):
    """Enable or disable gradient tracking for every parameter of ``model``.

    Args:
        model: a ``torch.nn.Module``.
        on_or_off: value assigned to each parameter's ``requires_grad``.
    """
    flag = on_or_off
    for weight in model.parameters():
        weight.requires_grad = flag


class Distribution(torch.Tensor):
    """Tensor subclass that can refill itself in place from a fixed distribution.

    Used by ``prepare_z_y`` as a reusable buffer for generator noise and
    sampled class labels: call ``init_distribution`` once, then ``sample_()``
    whenever new values are needed.
    """
    # Init the params of the distribution
    def init_distribution(self, dist_type, **kwargs):
        """Record which distribution ``sample_`` should draw from.

        Args:
            dist_type: ``'normal'`` (needs ``mean`` and ``var`` kwargs) or
                ``'categorical'`` (needs ``num_categories``). Any other value is
                stored but makes ``sample_`` a no-op.
            **kwargs: distribution parameters; kept in ``dist_kwargs`` so that
                ``to()`` can re-create them on the converted copy.
        """
        self.dist_type = dist_type
        self.dist_kwargs = kwargs
        if self.dist_type == 'normal':
            self.mean, self.var = kwargs['mean'], kwargs['var']
        elif self.dist_type == 'categorical':
            self.num_categories = kwargs['num_categories']

    def sample_(self):
        """Overwrite this tensor's contents in place with fresh samples."""
        if self.dist_type == 'normal':
            # NOTE(review): Tensor.normal_(mean, std) takes a standard deviation
            # as its second argument, so ``var`` is effectively used as the std.
            # The two only coincide when var == 1 (prepare_z_y's default).
            self.normal_(self.mean, self.var)
        elif self.dist_type == 'categorical':
            # Integers drawn uniformly from [0, num_categories - 1].
            self.random_(0, self.num_categories)
            # return self.variable

    # Silly hack: overwrite the to() method to wrap the new object
    # in a distribution as well
    def to(self, *args, **kwargs):
        """Return a converted copy that is still a ``Distribution``.

        The converted tensor would not otherwise carry ``dist_type`` /
        ``dist_kwargs``, so a new ``Distribution`` is re-initialised with the
        same parameters and its data replaced by ``Tensor.to(*args, **kwargs)``.
        NOTE(review): relies on torch.Tensor subclassing behaviour; confirm it
        still holds on the installed PyTorch version.
        """
        new_obj = Distribution(self)
        new_obj.init_distribution(self.dist_type, **self.dist_kwargs)
        new_obj.data = super().to(*args, **kwargs)
        return new_obj


def prepare_z_y(G_batch_size, dim_z, nclasses, device='cuda', z_var=1.0):
    """Create the reusable latent-noise and fake-label buffers for the generator.

    Args:
        G_batch_size: number of rows in both buffers.
        dim_z: width of each latent noise vector.
        nclasses: labels are sampled from ``range(nclasses)``.
        device: device both buffers are moved to.
        z_var: passed as ``var`` to the normal distribution of the noise buffer.

    Returns:
        ``(z_, y_)``: float32 noise of shape ``(G_batch_size, dim_z)`` and long
        labels of shape ``(G_batch_size,)``, both ``Distribution`` tensors that
        are redrawn in place with ``sample_()``.
    """
    def _build(initial, target_dtype, dist_type, **params):
        # Wrap, attach the sampling parameters, then convert device/dtype.
        buf = Distribution(initial)
        buf.init_distribution(dist_type, **params)
        return buf.to(device, target_dtype)

    z_ = _build(torch.randn(G_batch_size, dim_z, requires_grad=False),
                torch.float32, 'normal', mean=0, var=z_var)
    y_ = _build(torch.zeros(G_batch_size, requires_grad=False),
                torch.long, 'categorical', num_categories=nclasses)
    return z_, y_


def _save_gan_checkpoints(netd_g, netg, step, opt):
    """Overwrite the 'last' discriminator and generator checkpoints in ``opt.checkpoint_path``."""
    save_checkpoint_d({
        'epoch': step,
        'state_dict': netd_g.state_dict(),
        'best_acc': 0,
    }, False, 0, checkpoint=opt.checkpoint_path)
    save_checkpoint_g({
        'epoch': step,
        'state_dict': netg.state_dict(),
    }, False, 0, checkpoint=opt.checkpoint_path)


def train_g(netd_g, netg, netc, dataset, step, opt):
    """Train generator ``netg`` and discriminator ``netd_g`` for ``opt.iter`` iterations.

    Every iteration draws one batch of ``opt.batch_size * opt.num_D_steps`` real
    samples and runs the step built by ``GAN_training_function``. A sample grid
    is written every 250 steps and the 'last' D/G checkpoints every 1000 steps
    (and once more at the end if the run did not finish on such a boundary).

    Args:
        netd_g: discriminator (must expose ``.optim``).
        netg: generator (must expose ``.optim``).
        netc: pretrained classifier used as a frozen teacher.
        dataset: map-style dataset of ``(image, label)`` pairs.
        step: initial step counter (used for logging and file names).
        opt: options; reads ``checkpoint_path``, ``batch_size``, ``nz``,
            ``num_class``, ``iter`` plus the fields needed by the helpers.

    Returns:
        The step counter after the last iteration.
    """
    if not os.path.isdir(opt.checkpoint_path):
        mkdir_p(opt.checkpoint_path)

    noise, fake_label = prepare_z_y(G_batch_size=opt.batch_size, dim_z=opt.nz, nclasses=opt.num_class)

    G_D_net = G_D(G=netg, D=netd_g, C=netc)
    train = GAN_training_function(G=netg, D=netd_g, C=netc, GD=G_D_net, z_=noise, y_=fake_label, config=opt)

    data_loader = sample_data(dataset, opt)
    pbar = tqdm(range(opt.iter), dynamic_ncols=True)

    for _ in pbar:
        image_c, label = next(data_loader)

        netg.train()
        netd_g.train()
        # The classifier is loaded from a pretrained checkpoint and its
        # parameters are frozen during training; keep it in eval mode so any
        # BatchNorm/Dropout layers neither use batch statistics nor update
        # their running stats from generated images.
        netc.eval()
        image_c = image_c.cuda()
        label = label.cuda()

        metrics = train(image_c, label)
        step += 1

        pbar.set_description(
            (', '.join(['itr: %d' % step]
                       + ['%s : %+4.3f' % (key, metrics[key])
                          for key in metrics]))
        )

        print('Epoch: [{:d} | {:d}], G_loss = {:.4f}, D_loss_real = {:.4f}, D_loss_fake = {:.4f}, C_loss = {:.4f}'.format(
            step, opt.iter, float(metrics['G_loss']), float(metrics['D_loss_real']),
            float(metrics['D_loss_fake']), float(metrics['C_loss'])
        ))
        if step % 250 == 0:
            test(netg, step, opt)
        if step % 1000 == 0:
            _save_gan_checkpoints(netd_g, netg, step, opt)

    # Persist the final weights when the run did not end on a checkpoint
    # boundary; otherwise up to 999 iterations of training would be lost.
    if step % 1000 != 0:
        _save_gan_checkpoints(netd_g, netg, step, opt)

    return step


def sample_data(dataset, opt):
    """Yield training batches forever, reshuffling on every pass over ``dataset``.

    Each batch holds ``opt.batch_size * opt.num_D_steps`` samples, so the
    training step can split it into one real chunk per discriminator update.
    Incomplete trailing batches are dropped.

    Args:
        dataset: map-style dataset accepted by ``torch.utils.data.DataLoader``.
        opt: options; reads ``batch_size``, ``num_D_steps`` and ``num_worker``.

    Yields:
        One DataLoader batch at a time, indefinitely.

    Raises:
        ValueError: on the first ``next()`` if ``dataset`` cannot fill even a
            single batch (instead of an opaque "generator raised
            StopIteration" RuntimeError).
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size * opt.num_D_steps, shuffle=True,
                                         num_workers=opt.num_worker, drop_last=True, pin_memory=True)
    if len(loader) == 0:
        raise ValueError(
            'dataset of size {} is too small for one batch of {} samples'.format(
                len(dataset), opt.batch_size * opt.num_D_steps))

    while True:
        # Each new pass over a shuffling DataLoader draws a fresh permutation.
        yield from loader


def GAN_training_function(G, D, C, GD, z_, y_, config):
    """Build the per-iteration training closure for the class-conditional GAN.

    Args:
        G: generator; must expose its optimizer as ``G.optim``.
        D: discriminator; must expose its optimizer as ``D.optim``.
        C: pretrained classifier; its parameters are kept frozen here.
        GD: combined module, called as
            ``GD(z, y_fake, x_real, y_real, train_G=False)`` ->
            ``(D_fake, D_real, mi, c_cls, pc_cls)`` for discriminator steps and
            ``GD(z, y_fake, train_G=True)`` -> ``(D_fake, mi, c_cls, pc_cls)``
            for generator steps.
        z_, y_: ``Distribution`` buffers (see ``prepare_z_y``) redrawn in place
            before every step.
        config: options; reads ``batch_size``, ``num_D_steps``,
            ``num_G_steps``, ``C_w`` (weight of the class/MI loss terms) and
            ``C_kd`` (mix between hard-label cross-entropy and KL divergence to
            the classifier's softmax).

    Returns:
        ``train(x, y)``: expects at least ``num_D_steps * batch_size`` real
        images/labels, runs ``num_D_steps`` discriminator updates and one
        generator update, and returns a dict with ``G_loss``, ``D_loss_real``,
        ``D_loss_fake``, ``C_loss`` and ``MI_loss`` (last value of each phase).
    """
    # Softmax temperature for the distillation terms (1.0 == plain softmax).
    temperature = 1.0

    def _kd_loss(logits, teacher_probs):
        # KL(teacher || student). log_softmax is the numerically stable form of
        # torch.log(torch.softmax(...)), which gives -inf (and NaN losses) once
        # a probability underflows to zero.
        # NOTE(review): F.kl_div's default reduction='mean' averages over every
        # element (batch * classes), not per sample; kept to preserve loss scale.
        return F.kl_div(F.log_softmax(logits / temperature, dim=1), teacher_probs)

    def train(x, y):
        bs = config.batch_size
        # One chunk of real data per discriminator step.
        x = torch.split(x, bs)
        y = torch.split(y, bs)

        # Discriminator phase: only D accumulates gradients.
        toggle_grad(D, True)
        toggle_grad(G, False)
        toggle_grad(C, False)

        for step_index in range(config.num_D_steps):
            # Reset every step; otherwise later D updates would also apply the
            # stale gradients from earlier steps.
            D.optim.zero_grad()
            z_.sample_()
            y_.sample_()
            y_fake = y_[:bs]
            D_fake, D_real, mi, c_cls, _ = GD(
                z_[:bs], y_fake, x[step_index], y[step_index], train_G=False
            )
            n_fake = D_fake.shape[0]

            D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
            # c_cls rows after the first n_fake are scored against the real
            # labels; the first n_fake rows of mi against the sampled labels.
            D_loss = (D_loss_real + D_loss_fake) + config.C_w * (
                F.cross_entropy(c_cls[n_fake:], y[step_index]) +
                F.cross_entropy(mi[:n_fake], y_fake)
            )
            D_loss.backward()
            D.optim.step()

        # Generator phase: only G accumulates gradients; C stays frozen.
        toggle_grad(D, False)
        toggle_grad(G, True)
        toggle_grad(C, False)

        G.optim.zero_grad()
        # NOTE(review): gradients from all num_G_steps passes are accumulated
        # into a single optimizer step (no zero_grad/step inside this loop).
        for _ in range(config.num_G_steps):
            z_.sample_()
            y_.sample_()
            y_fake = y_[:bs]
            D_fake, mi, c_cls, pc_cls = GD(z_[:bs], y_fake, train_G=True)
            G_loss = losses.generator_loss(D_fake)

            # The classifier's soft labels are a fixed target: no gradient
            # flows back through them.
            with torch.no_grad():
                teacher_probs = torch.softmax(pc_cls, dim=1)

            MI_loss = (1 - config.C_kd) * F.cross_entropy(mi, y_fake) + config.C_kd * _kd_loss(mi, teacher_probs)
            C_loss = (1 - config.C_kd) * F.cross_entropy(c_cls, y_fake) + config.C_kd * _kd_loss(c_cls, teacher_probs)

            (G_loss - config.C_w * MI_loss + config.C_w * C_loss).backward()

        G.optim.step()

        # Return G's loss and the components of D's loss.
        return {'G_loss': G_loss,
                'D_loss_real': D_loss_real,
                'D_loss_fake': D_loss_fake,
                'C_loss': C_loss,
                'MI_loss': MI_loss}

    return train


def denorm(x):
    """Affinely rescale ``x`` by ``(x + 1) / 2`` (maps [-1, 1] onto [0, 1]).

    Works for Python numbers and tensors alike.
    """
    shifted = x + 1
    return shifted / 2


def test(netg, step, opt):
    """Save a grid of class-conditional generator samples for visual inspection.

    For every class ``i`` in ``range(opt.num_class)`` this draws 10 fresh noise
    vectors of width ``opt.nz`` and generates 10 images labelled ``i``. All
    images (10 per row, one row per class) are rescaled with ``denorm`` and
    written to ``<opt.checkpoint_path>/fake_samples_<step>.jpg``.

    Side effects: leaves ``netg`` in eval mode and re-enables ``requires_grad``
    on all of its parameters, even if generation or saving fails.
    """
    netg.eval()
    toggle_grad(netg, False)
    try:
        # Generate on the generator's own device instead of assuming CUDA;
        # noise is still drawn from the CPU RNG and then moved, as before.
        gen_device = next(netg.parameters()).device
        per_class = []
        with torch.no_grad():
            for i in range(opt.num_class):
                fixed = torch.randn(10, opt.nz).to(gen_device)
                label = torch.full((10,), i, dtype=torch.long, device=gen_device)
                per_class.append(netg(fixed, label))
        if not per_class:
            # No classes -> nothing to draw.
            return
        # Concatenate once instead of growing a tensor inside the loop.
        fixed_input = torch.cat(per_class, dim=0)
        save_image(denorm(fixed_input), os.path.join(opt.checkpoint_path, f'fake_samples_{step:03d}.jpg'), nrow=10)
    finally:
        toggle_grad(netg, True)


def save_checkpoint_g(state, is_best, acc, checkpoint):
    """Save generator ``state`` to ``<checkpoint>/model_g_last.pth.tar``.

    Args:
        state: object handed to ``torch.save``.
        is_best: when true, also copy the file to ``model_g_best.pth.tar``.
        acc: unused; kept for signature compatibility.
        checkpoint: output directory, created when missing.
    """
    if not os.path.isdir(checkpoint):
        mkdir_p(checkpoint)
    last_path = os.path.join(checkpoint, 'model_g_last.pth.tar')
    torch.save(state, last_path)
    if not is_best:
        return
    shutil.copyfile(last_path, os.path.join(checkpoint, 'model_g_best.pth.tar'))


def save_checkpoint_d(state, is_best, acc, checkpoint):
    """Save discriminator ``state`` to ``<checkpoint>/model_d_last.pth.tar``.

    Args:
        state: object handed to ``torch.save``.
        is_best: when true, also copy the file to ``model_d_best.pth.tar``.
        acc: unused; kept for signature compatibility.
        checkpoint: output directory, created when missing.
    """
    if not os.path.isdir(checkpoint):
        mkdir_p(checkpoint)
    last_path = os.path.join(checkpoint, 'model_d_last.pth.tar')
    torch.save(state, last_path)
    if not is_best:
        return
    shutil.copyfile(last_path, os.path.join(checkpoint, 'model_d_best.pth.tar'))


def get_learning_rate(optimizer):
    """Return the current ``lr`` of each parameter group of ``optimizer``, in order."""
    return [group['lr'] for group in optimizer.param_groups]


def get_opt_and_lrsch(args, model_d, model_g, num_epoch, num_iter, warmup):
    """Create Adam optimizers and constant LR schedulers for D and G.

    Only ``model_d`` and ``model_g`` are used; ``args``, ``num_epoch``,
    ``num_iter`` and ``warmup`` are accepted for interface compatibility and
    ignored. Both optimizers use ``lr=1e-5`` and ``betas=(0.5, 0.999)``; the
    ExponentialLR schedulers use ``gamma=1.0``, so the learning rate never
    changes.

    Returns:
        ``(optim_d, optim_g, train_scheduler_d, train_scheduler_g)``.
    """
    adam_kwargs = dict(lr=0.00001, betas=(0.5, 0.999))
    optimizers = [optim.Adam(net.parameters(), **adam_kwargs) for net in (model_d, model_g)]
    schedulers = [
        torch.optim.lr_scheduler.ExponentialLR(opt_, gamma=1., last_epoch=-1)
        for opt_ in optimizers
    ]
    return (*optimizers, *schedulers)


def _parse_args():
    """Build the command-line parser for the GAN training run and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description='setting for cifar100')
    parser.add_argument('--model', type=str, default='mobilenetv3_small_50')
    # --classifier_epochs, --print_epoch and --batch_step are not read by this
    # script; they are kept so existing command lines keep working.
    parser.add_argument('--classifier_epochs', type=int, default=200, help='classifier epochs')
    parser.add_argument('--print_epoch', type=int, default=5,
                        help='print model training stats per print_epoch_splitai during splitai training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_step', type=int, default=1, help='batch accumulation steps')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    parser.add_argument('--num_class', type=int, default=100, help='num class')
    parser.add_argument('--data_aug', type=distutils.util.strtobool, default=True, help='turn on data augmentation')
    parser.add_argument('--run_idx', type=int, default=100, help='idx running')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')

    parser.add_argument('--nz', type=int, default=128, help='Number of dimensions for input noise.')
    parser.add_argument('--iter', default=60000, type=int,
                        help='number of training iterations (each runs num_D_steps discriminator updates '
                             'and one generator update)')
    parser.add_argument('--num_D_steps', default=2, type=int, help='num_D_steps.')
    parser.add_argument('--num_G_steps', default=1, type=int, help='num_G_steps.')

    parser.add_argument('--C_w', default=1.0, type=float, help='weight of classifier')
    parser.add_argument('--C_kd', default=0.5, type=float, help='weight of KD')
    return parser.parse_args()


def main():
    """Train the class-conditional GAN on the CIFAR-100 training partition.

    Reads ``<root_dir>/cifar100/data/partition/*.npy``, loads the pretrained
    classifier from
    ``<load_path>/cifar100/<model>/undefend/aug/<run_idx>/model_last.pth.tar``
    and writes samples/checkpoints under
    ``<save_path>/cifar100/pcgtacgan/{aug,no_aug}/<run_idx>``.
    """
    args = _parse_args()
    print(dict(args._get_kwargs()))
    num_class = args.num_class

    DATASET_PATH = os.path.join(root_dir, 'cifar100', 'data')
    partition_dir = os.path.join(DATASET_PATH, 'partition')
    checkpoint_path = os.path.join(args.save_path, 'cifar100', 'pcgtacgan',
                                   'aug' if args.data_aug else 'no_aug', str(args.run_idx))
    pretrain_model_path = os.path.join(
        args.load_path, 'cifar100', args.model, 'undefend', 'aug', str(args.run_idx)
    )
    print(checkpoint_path)
    # train_g and test() read the output directory from the options object.
    args.checkpoint_path = checkpoint_path

    train_data = np.load(os.path.join(partition_dir, 'train_data.npy'))
    train_label = np.load(os.path.join(partition_dir, 'train_label.npy'))

    # Print the first 20 labels of the other partitions so runs can be
    # cross-checked against other experiments using the same split.
    for label_file in ('tr_label.npy', 'te_label.npy', 'test_label.npy', 'ref_label.npy'):
        print(np.load(os.path.join(partition_dir, label_file))[:20])

    transform = transform_train_aug if args.data_aug else transform_train
    trainset = Cifardata(train_data, train_label, transform)

    # Construction order (D, G, classifier) matches the original, in case
    # model construction consumes the global RNG.
    model_d = Discriminator(n_classes=num_class, resolution=32, AC=True)
    model_g = Generator(n_classes=num_class, resolution=32, SN=True)
    # Use --num_class (not a hard-coded 100) so the teacher's output width
    # matches the label space of G and D.
    model_c = ModelwNorm(get_network(arch=args.model, num_classes=num_class))
    checkpoint = torch.load(os.path.join(pretrain_model_path, 'model_last.pth.tar'), map_location='cpu')
    model_c.load_state_dict(checkpoint['state_dict'])

    for model in (model_d, model_g, model_c):
        model.to(device, torch.float)

    train_g(model_d, model_g, model_c, trainset, 0, args)


if __name__ == '__main__':
    # Script entry point: training starts only when the file is executed directly.
    main()
