# Misc
import argparse
import os
import os.path as osp
import numpy as np
import matplotlib.pyplot as plt

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# Sibling Modules
from aae import Encoder, Decoder, Discriminator
from data_generator import DataGenerator

def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='Train Adversarial Autoencoder')
    parser.add_argument('--gpu_id', default=0, type=int,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--output_dir', '-o', default='result_aae/',
                        help='Directory to output the result')
    parser.add_argument('--epochs', '-e', default=100, type=int,
                        help='Number of epochs')
    parser.add_argument('--latent_n', '-z', default=8, type=int,
                        help='Dimention of encoded vector')
    parser.add_argument('--batch_size', '-batch', type=int, default=8,
                        help='Learning minibatch size')
    parser.add_argument('--beta', '-b', type=float, default=0,
                        help='Beta coefficient for the KL loss')
    parser.add_argument('--gamma', '-g', type=float, default=10,
                        help='Gamma coefficient for the classification loss')
    parser.add_argument('--alpha', '-a', type=float, default=1,
                        help='Alpha coefficient for the reconstruction loss')
    parser.add_argument('--augment_counter', type=int, default=10)
    return parser.parse_args()


def _select_device(gpu_id):
    """Translate ``--gpu_id`` into a torch device.

    A negative id selects the CPU, as the option's help text promises.
    A non-negative id selects that CUDA device; when CUDA is unavailable
    the CPU is used instead and a notice is printed.
    """
    if gpu_id < 0:
        return torch.device("cpu")
    if not torch.cuda.is_available():
        print("CUDA is not available; falling back to CPU.")
        return torch.device("cpu")
    return torch.device("cuda:{}".format(gpu_id))


def _sample_prior(batch_size, latent_n, device):
    """Draw ``batch_size`` latent codes from the uniform prior on [-2, 2)."""
    return torch.rand(batch_size, latent_n, device=device) * 4 - 2


def _plot_losses(curves, out_path):
    """Plot the per-epoch mean of each loss curve side by side and save it.

    ``curves`` is a sequence of ``(label, per_epoch_losses)`` pairs where
    ``per_epoch_losses`` holds one list of logged batch losses per epoch.
    """
    fig = plt.figure(figsize=(9, 4))
    for col, (label, per_epoch_losses) in enumerate(curves, start=1):
        ax = fig.add_subplot(1, len(curves), col)
        # Averaging epoch by epoch (instead of np.mean(..., axis=1)) also
        # works when no epoch was run.
        means = [np.mean(batch_losses) for batch_losses in per_epoch_losses]
        ax.plot(range(len(means)), means, label=label)
        ax.grid()
        ax.legend()

    fig.subplots_adjust(right=0.99, left=0.05, top=0.95)
    fig.savefig(out_path)
    plt.close(fig)


def main():
    """Train the adversarial autoencoder and save the models and loss plot.

    Parses the command-line options, builds the data loaders through
    ``DataGenerator``, and trains the encoder ``Q``, decoder ``P`` and
    latent discriminator ``D``. Every batch runs three phases:

    1. reconstruction (plus the gamma-weighted classification loss) that
       updates ``P`` and ``Q``;
    2. a discriminator update of ``D`` on prior samples vs. encoded samples;
    3. a generator update of ``Q`` that tries to fool ``D``.

    The state dicts of ``Q``, ``P`` and ``D`` and a figure with the
    per-epoch mean losses are written to ``--output_dir``.

    Raises:
        RuntimeError: if the training loader yields no batches.
    """
    args = _parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    print('\n###############################################')
    print('# GPU: \t\t\t{}'.format(args.gpu_id))
    print('# dim z: \t\t{}'.format(args.latent_n))
    print('# Minibatch-size: \t{}'.format(args.batch_size))
    print('# Epochs: \t\t{}'.format(args.epochs))
    print('# Alpha: \t\t{}'.format(args.alpha))
    print('# Beta: \t\t{}'.format(args.beta))
    print('# Gamma: \t\t{}'.format(args.gamma))
    print('# Augment Counter: \t{}'.format(args.augment_counter))
    print('# Out Folder: \t\t{}'.format(args.output_dir))
    print('###############################################\n')

    data_generator = DataGenerator(args.batch_size, data_split=0.8, augment_counter=args.augment_counter,
                                   plot=False, DIRS=["data/"])
    trainloader, testloader, train_traj, test_traj, train_images, test_images, \
    train_cube_pos, test_cube_pos, train_labels, test_labels = data_generator.generate_data()

    print("Max Len: {0}, Max Val: {1}, Min Val: {2}".format(data_generator.max_len, np.max(train_traj), np.min(train_traj)))
    print("# train datapoints: {0}".format(len(trainloader) * args.batch_size))
    print("# test datapoints: {0}".format(len(testloader) * args.batch_size))

    # The end-of-epoch report reads the losses of the last batch, so an empty
    # loader would otherwise surface as an obscure NameError.
    if len(trainloader) == 0:
        raise RuntimeError("The training loader is empty; nothing to train on.")

    ################################
    ########### TRAINING ###########
    ################################

    rec_losses = []
    disc_losses = []
    gen_losses = []

    device = _select_device(args.gpu_id)
    traj_ch = 14 # 7 joints * 2 (joint angles and efforts)
    Q = Encoder(latent_n=args.latent_n, groups=data_generator.groups, traj_len=data_generator.traj_len, traj_ch=traj_ch, device=device)
    P = Decoder(latent_n=args.latent_n, groups=data_generator.groups, traj_len=data_generator.traj_len, traj_ch=traj_ch, device=device)
    D = Discriminator(latent_n=args.latent_n, hidden=64)

    Q.to(device)
    P.to(device)
    D.to(device)

    # Set learning rates
    gen_lr = 0.0005
    reg_lr = 0.0005

    # encode/decode optimizers
    optim_P = torch.optim.Adam(P.parameters(), lr=gen_lr)
    optim_Q_enc = torch.optim.Adam(Q.parameters(), lr=gen_lr)
    # regularizing optimizers
    optim_Q_gen = torch.optim.Adam(Q.parameters(), lr=reg_lr)
    optim_D = torch.optim.Adam(D.parameters(), lr=reg_lr)

    # Label value that is excluded from the classification loss and accuracy.
    ignore_label = 100
    mse = nn.MSELoss(reduction="none")
    cross_entropy = nn.CrossEntropyLoss(ignore_index=ignore_label)
    n_groups = len(data_generator.groups)
    eps = 1e-15  # keeps log() finite when the discriminator saturates

    for epoch in range(args.epochs):  # loop over the dataset multiple times

        rec_losses.append([])
        disc_losses.append([])
        gen_losses.append([])
        correct = total = 0

        # TRAIN
        for batch_idx, batch in enumerate(trainloader):

            # batch is [trajectory, images, cube position, labels]; the cube
            # position is not used during training.
            traj_in = batch[0].to(device)
            img_in = batch[1].to(device)
            label_in = batch[3].to(device)
            img_in = img_in[:, 0, :, :, :]

            # ---- Reconstruction (+ classification) phase ----
            optim_P.zero_grad()
            optim_Q_enc.zero_grad()

            z, label_out, img_embed = Q(traj_in, img_in)   # encode to z
            traj_out = P(z, img_embed)                     # decode to X reconstruction
            per_element = args.alpha * mse(traj_out, traj_in)
            # Sum the squared error per sample, then average over the batch.
            recon_loss = per_element.view(per_element.shape[0], -1).sum(dim=-1).mean()

            if args.gamma:
                for i in range(n_groups):
                    target = label_in[:, i]
                    recon_loss = recon_loss + args.gamma * cross_entropy(label_out[i], target)

                    predicted = label_out[i].detach().argmax(dim=1)
                    valid = target != ignore_label
                    total += int(valid.sum().item())
                    correct += int(((predicted == target) & valid).sum().item())

            recon_loss.backward()
            optim_P.step()
            optim_Q_enc.step()

            # ---- Discriminator phase ----
            # The prior is uniform in [-2, 2); D learns to tell prior samples
            # from encoded samples, which constrains the latent distribution.
            Q.eval()
            optim_D.zero_grad()

            z_real = _sample_prior(traj_in.size(0), args.latent_n, device)
            D_real = D(z_real)

            # The encoder is fixed in this phase, so no graph is built for it.
            with torch.no_grad():
                z_fake, _, _ = Q(traj_in, img_in)
            D_fake = D(z_fake)

            D_loss = -torch.mean(torch.log(D_real + eps) + torch.log(1 - D_fake + eps))

            D_loss.backward()
            optim_D.step()

            # ---- Generator phase ----
            # Back to training mode: Q.eval() above must not leak into the
            # generator step or into the next batch's reconstruction step.
            Q.train()
            # Drop the reconstruction gradients still stored on Q so that this
            # step applies the adversarial gradient only.
            optim_Q_gen.zero_grad()

            z_fake, _, _ = Q(traj_in, img_in)
            D_fake = D(z_fake)

            G_loss = -torch.mean(torch.log(D_fake + eps))

            G_loss.backward()
            optim_Q_gen.step()

            if batch_idx % 10 == 0:
                rec_losses[-1].append(round(recon_loss.item(), 3))
                disc_losses[-1].append(round(D_loss.item(), 3))
                gen_losses[-1].append(round(G_loss.item(), 3))

        # Generator-style loss of D on fresh prior samples, for reporting only.
        with torch.no_grad():
            real = _sample_prior(traj_in.size(0), args.latent_n, device)
            Greal = -torch.mean(torch.log(D(real) + eps))

        # total is 0 when gamma is disabled or every label was ignored.
        acc = 100 * correct / total if total else 0
        print(("Ep: {0}\tRec loss: {1}\tDiscrm loss: {2}\tGen fake: {3}\t\tGen real: {4}\tAcc: {5}").format(
            epoch,
            round(recon_loss.item(), 3),
            round(D_loss.item(), 3),
            round(G_loss.item(), 3),
            round(Greal.item(), 3),
            round(acc, 3)))

    torch.save(Q.state_dict(), osp.join(args.output_dir, "encoder_aae"))
    torch.save(P.state_dict(), osp.join(args.output_dir, "decoder_aae"))
    torch.save(D.state_dict(), osp.join(args.output_dir, "discriminator_aae"))
    # Announce success only once the files have actually been written.
    print('Finished Training. Models Saved.')

    _plot_losses(
        [("rec", rec_losses), ("disc", disc_losses), ("gen", gen_losses)],
        osp.join(args.output_dir, "loss"),
    )

# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
