import argparse
import importlib
import json
import math
import os
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torch.utils.data
from datasets.datasets_pro import get_dataloader
import numpy as np

# Command-line interface. Every option is declared as (flag, type, default)
# and registered with the parser in the order listed here.
parser = argparse.ArgumentParser()
for _flag, _type, _default in (
    ('--model', str, 'generator_pro_wo_kp'),
    ('--dis_model', str, 'discriminator_pro'),
    ('--z_dim', int, 256),
    ('--tau', float, 0.01),
    ('--n_keypoints', int, 6),
    ('--lr_gen', float, 1e-4),
    ('--lr_disc', float, 4e-4),
    # number of updates to discriminator for every update to generator
    ('--disc_iters', int, 1),
    ('--num_workers', int, 6),
    ('--data_root', str, '../data/celebaHQ'),
    ('--class_name', str, 'celebaHQ'),
    ('--n_embedding', int, 64),
    # 0 starts a fresh run; any other value is the epoch checkpoint to resume from
    ('--checkpoint', int, 0),
):
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()

# Every run logs into log/<name>, where <name> encodes the options that
# identify the experiment (models, keypoint count, tau, class, embedding size).
args.log = os.path.join(
    'log',
    '{0}_{1}_k{2}_tau{3}_{4}_{5}'.format(args.model, args.dis_model, args.n_keypoints,
                                         args.tau, args.class_name, args.n_embedding),
)

os.makedirs(args.log, exist_ok=True)
parameters_path = os.path.join(args.log, 'parameters.json')
if args.checkpoint == 0:
    # Fresh run: record the full configuration next to the logs.
    with open(parameters_path, 'wt') as f:
        json.dump(args.__dict__, f, indent=2)
else:
    # Resumed run: start from the configuration saved by the original run.
    with open(parameters_path, 'rt') as f:
        old_para = json.load(f)
    log_dir = args.log
    # Seed the namespace with the saved values ONLY. argparse leaves attributes
    # that already exist on the namespace untouched unless the corresponding
    # flag is given on the command line, so explicit flags still win while
    # everything else is restored from the original run. (Previously the saved
    # values were overwritten with the current args, defaults included, which
    # made the restore a no-op.)
    t_args = argparse.Namespace(**old_para)
    args = parser.parse_args(namespace=t_args)
    # Keep the directory computed for this invocation.
    args.log = log_dir

# Both lists are indexed by the training stage: entry i gives the image
# resolution and the batch size used while training stage i.
image_sizes = [64, 128, 256, 512]
batch_sizes = [128, 64, 24, 8]

# First GPU when CUDA is available, CPU otherwise.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# The generator class is looked up in the module models.<args.model>.
model = importlib.import_module('models.' + args.model)
generator_options = {
    'z_dim': args.z_dim,
    'n_keypoints': args.n_keypoints,
    'n_embedding': args.n_embedding,
    'tau': args.tau,
}
generator = model.Generator(generator_options).to(device)

# The discriminator class is looked up in the module models.<args.dis_model>.
dis_model = importlib.import_module('models.' + args.dis_model)
discriminator = dis_model.Discriminator({}).to(device)

# Optimizers are created on the bare modules, before the DataParallel wrapping below.
optim_disc = torch.optim.Adam(discriminator.parameters(), lr=args.lr_disc, betas=(0.5, 0.9))
trainable_gen_params = [p for p in generator.parameters() if p.requires_grad]
optim_gen = torch.optim.Adam(trainable_gen_params, lr=args.lr_gen, betas=(0.5, 0.9))

# From here on the underlying networks are reachable as `.module`.
generator = torch.nn.DataParallel(generator)
discriminator = torch.nn.DataParallel(discriminator)


def gradient_penalty(images, output, weight=10):
    """Return a penalty pushing the per-sample gradient norm of ``output`` towards 1.

    Args:
        images: input tensor that ``output`` was computed from; it must require
            grad. Its first dimension is treated as the batch dimension.
        output: tensor computed from ``images`` (here: the discriminator output).
        weight: multiplier applied to the mean penalty.

    Returns:
        Scalar tensor ``weight * mean((||d output / d images||_2 - 1) ** 2)``,
        where the norm is taken per sample over all non-batch dimensions.
        The gradient is computed with ``create_graph=True`` so the penalty can
        itself be back-propagated.
    """
    batch_size = images.shape[0]
    gradients = torch.autograd.grad(
        outputs=output,
        inputs=images,
        # ones_like keeps grad_outputs on the device and in the dtype of `output`.
        grad_outputs=torch.ones_like(output),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten everything but the batch dimension to get one norm per sample.
    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()


def freeze_grad(freeze):
    """Freeze or unfreeze three named sub-modules of the generator.

    Sets ``requires_grad`` to ``not freeze`` on every parameter of the
    generator's direct children named in ``target_names``; all other children
    are left untouched.
    """
    target_names = ('gen_keypoints_embedding_noise', 'gen_keypoints_layer', 'gen_background_embedding')
    trainable = not freeze
    for child_name, child in generator.module.named_children():
        if child_name not in target_names:
            continue
        for param in child.parameters():
            param.requires_grad = trainable


def train_one_epoch():
    """Train the discriminator and generator for one pass over the data.

    Reads the module-level ``stage``, ``alpha`` and ``freeze`` values set by
    the epoch loop. The discriminator is updated on every batch and the
    generator on every ``args.disc_iters``-th batch. The pass stops early once
    ``batch_index`` exceeds 10000.

    Returns:
        ``(disc_loss, gen_loss)``: the discriminator loss averaged over the
        discriminator updates actually performed and halved, and the generator
        loss averaged over the generator updates actually performed.
    """
    discriminator.train()
    generator.train()
    total_disc_loss = 0
    total_gen_loss = 0
    n_disc_steps = 0
    n_gen_steps = 0
    image_size = image_sizes[stage]
    batch_size = batch_sizes[stage]

    data_loader = get_dataloader(args.data_root, image_size=image_size, class_name=args.class_name,
                                 batch_size=batch_size, num_workers=args.num_workers,
                                 pin_memory=True, drop_last=True)

    freeze_grad(freeze=freeze)

    def sample_noise():
        # One standard-normal tensor per noise input declared by the generator.
        return {'input_noise{}'.format(noise_i): torch.randn(batch_size, *noise_shape).to(device)
                for noise_i, noise_shape in enumerate(generator.module.noise_shapes)}

    for batch_index, real_batch in enumerate(data_loader):
        optim_disc.zero_grad()
        optim_gen.zero_grad()

        # update discriminator
        real_batch = {'img': real_batch['img'].to(device), 'stage': stage, 'alpha': alpha}
        # The gradient penalty differentiates the discriminator output w.r.t. the real images.
        real_batch['img'].requires_grad_()
        input_batch = sample_noise()
        input_batch['stage'] = stage
        input_batch['alpha'] = alpha
        # NOTE(review): the fake batch is not detached, so this backward pass
        # presumably also fills generator gradients; they are cleared by a
        # zero_grad() call before the generator is next stepped.
        fake_batch = generator(input_batch, stage, alpha)
        d_real_out = discriminator(real_batch, stage, alpha)
        d_fake_out = discriminator(fake_batch, stage, alpha)
        disc_loss = (F.softplus(d_fake_out).mean() + F.softplus(-d_real_out).mean()
                     + gradient_penalty(real_batch['img'], d_real_out))
        disc_loss.backward()
        total_disc_loss += disc_loss.item()
        n_disc_steps += 1
        optim_disc.step()

        # update generator
        if batch_index % args.disc_iters == 0:
            optim_disc.zero_grad()
            optim_gen.zero_grad()
            input_batch = sample_noise()
            fake_batch = generator(input_batch, stage, alpha, requires_penalty=True)
            d_fake_out = discriminator(fake_batch, stage, alpha)
            gen_loss = F.softplus(-d_fake_out).mean()
            # Add whichever optional regularisers the generator returned.
            for penalty_key in ('penalty_on_keypoints', 'penalty_on_face', 'penalty_on_bg', 'plr'):
                if penalty_key in fake_batch:
                    gen_loss = gen_loss + fake_batch[penalty_key].mean()
            gen_loss.backward()
            total_gen_loss += gen_loss.item()
            n_gen_steps += 1
            optim_gen.step()

        if batch_index > 10000:
            break

    # Average over the updates that really happened. For disc_iters == 1 and no
    # early break this equals the previous normalisation by len(data_loader).
    # The extra / 2 on the discriminator loss is kept from the original code
    # (presumably because it sums a real and a fake term).
    return total_disc_loss / n_disc_steps / 2, total_gen_loss / n_gen_steps


def _save_image_grid(images, out_path, keypoints=None):
    """Draw ``images`` into an 8x8 grid and save the figure to ``out_path``.

    When ``keypoints`` is given, the keypoints of image ``i`` are scattered on
    top of it, using column 1 as the x coordinate and column 0 as the y
    coordinate, coloured by keypoint index.
    """
    figure = plt.figure(figsize=(8, 8))
    grid = gridspec.GridSpec(8, 8)
    grid.update(wspace=0.1, hspace=0.1)

    for cell, image in enumerate(images):
        axis = plt.subplot(grid[cell])
        plt.axis('off')
        axis.set_xticklabels([])
        axis.set_yticklabels([])
        axis.set_aspect('equal')
        plt.imshow(image)
        if keypoints is not None:
            plt.scatter(keypoints[cell, :, 1], keypoints[cell, :, 0],
                        c=list(range(args.n_keypoints)), s=20, marker='+')

    plt.savefig(out_path, bbox_inches='tight')
    plt.close(figure)


def evaluate(test_input_batch):
    """Save sample grids for the current epoch into ``<args.log>/eval``.

    Generates images from ``test_input_batch`` with the generator in eval mode
    and writes two figures built from at most the first 64 samples:
    ``<epoch>.png`` with the images alone and ``<epoch>_keypoints.png`` with
    the generated keypoints drawn on top. Reads the module-level ``stage``,
    ``alpha`` and ``epoch`` values.
    """
    eval_dir = os.path.join(args.log, 'eval')
    os.makedirs(eval_dir, exist_ok=True)

    generator.eval()
    with torch.no_grad():
        generated = generator(test_input_batch, stage, alpha)['img'].cpu().numpy()
    # Move the channel axis last for imshow, then map with x * 0.5 + 0.5 and
    # clip to [0, 1] (assumes the generator outputs roughly [-1, 1] - TODO confirm).
    samples = generated[:64].transpose((0, 2, 3, 1)) * 0.5 + 0.5
    samples = np.clip(samples, 0, 1)

    _save_image_grid(samples, os.path.join(eval_dir, '{}.png'.format(epoch)))

    image_size = samples.shape[-2]
    # Scale and shift applied to the raw keypoints to place them on the image
    # (presumably they are normalised to [-1, 1] - verify against the model).
    half_extent = image_size / 2 - 0.5
    with torch.no_grad():
        raw_keypoints = generator.module.gen_keypoints(test_input_batch).cpu().numpy()[:64]
    keypoints = raw_keypoints * half_extent + half_extent

    _save_image_grid(samples, os.path.join(eval_dir, '{}_keypoints.png'.format(epoch)),
                     keypoints=keypoints)


def build_schedules(n_epochs, offset):
    """Build the per-epoch stage, alpha and freeze schedules.

    Args:
        n_epochs: number of epochs of each stage, in stage order.
        offset: number of epochs at the start of each stage during which alpha
            ramps linearly from 0 to 1; alpha is 1 for the rest of the stage.

    Returns:
        ``(stage_list, alpha_list, freeze_list)``: three 1-D NumPy arrays with
        one entry per epoch. ``stage_list`` holds the integer stage index,
        ``alpha_list`` the alpha value, and ``freeze_list`` a bool that is True
        during the first ``offset`` epochs of every stage except the first.
    """
    # np.repeat yields an integer array directly; the former
    # `.astype(np.int)` fails on NumPy >= 1.24, where the alias was removed.
    stage_list = np.repeat(np.arange(len(n_epochs)), n_epochs)

    alpha_chunks = []
    for stage_epochs in n_epochs:
        alpha_chunks.append(np.linspace(0, 1, offset))
        alpha_chunks.append(np.linspace(1, 1, stage_epochs - offset))
    alpha_list = np.concatenate(alpha_chunks)

    freeze_list = [False] * n_epochs[0]
    for stage_epochs in n_epochs[1:]:
        freeze_list += [True] * offset
        freeze_list += [False] * (stage_epochs - offset)
    freeze_list = np.array(freeze_list).reshape(-1)

    return stage_list, alpha_list, freeze_list


if __name__ == '__main__':
    writer = SummaryWriter(os.path.join(args.log, 'runs'))
    checkpoint_dir = os.path.join(args.log, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    if args.checkpoint != 0:
        # Resume: restore both networks and both optimizers from the requested
        # epoch, then continue with the following epoch.
        checkpoint = torch.load(os.path.join(checkpoint_dir, 'epoch_{}.model'.format(args.checkpoint)),
                                map_location=lambda storage, location: storage)
        generator.module.load_state_dict(checkpoint['generator'])
        discriminator.module.load_state_dict(checkpoint['discriminator'])
        optim_gen.load_state_dict(checkpoint['optim_gen'])
        optim_disc.load_state_dict(checkpoint['optim_disc'])
        args.checkpoint += 1

    n_epochs = [100, 100, 100, 100]
    offset = 50
    stage_list, alpha_list, freeze_list = build_schedules(n_epochs, offset)

    # NOTE: epoch, stage, alpha and freeze are module-level names on purpose;
    # train_one_epoch() and evaluate() read them as globals.
    for epoch in range(args.checkpoint, 999):

        if epoch < len(stage_list):
            stage = stage_list[epoch]
            alpha = alpha_list[epoch]
            freeze = freeze_list[epoch]
        else:
            # Past the schedule: keep training the last stage, fully blended.
            stage = stage_list[-1]
            alpha = 1
            freeze = False

        disc_loss, gen_loss = train_one_epoch()
        writer.add_scalars('loss', {'disc_loss': disc_loss,
                                    'gen_loss': gen_loss}, epoch + 1)
        test_input_batch = {'input_noise{}'.format(noise_i): torch.randn(batch_sizes[stage]*2, *noise_shape).to(device)
                            for noise_i, noise_shape in enumerate(generator.module.noise_shapes)}
        evaluate(test_input_batch)
        # Checkpoint every epoch; raise the modulus to save less often.
        if (epoch + 1) % 1 == 0:
            torch.save(
                {
                    'generator': generator.module.state_dict(),
                    'discriminator': discriminator.module.state_dict(),
                    'optim_gen': optim_gen.state_dict(),
                    'optim_disc': optim_disc.state_dict(),
                },
                os.path.join(checkpoint_dir, 'epoch_{}.model'.format(epoch))
            )

    # Flush pending events and release the event file.
    writer.close()
