# Misc
import argparse
import os
import os.path as osp
import numpy as np
import matplotlib.pyplot as plt

# PyTorch
import torch
import torch.optim as optim

# Sibling Modules
from vae import VAE
from data_generator import DataGenerator

# Target value meaning "no label for this group"; entries equal to it are
# excluded from the accuracy counts.
# NOTE(review): presumably the same ignore value DataGenerator emits and the VAE
# loss skips — confirm against vae.py / data_generator.py.
_IGNORE_LABEL = 100

# Per-split metric names; every stats key is "<train|valid>_<name>".
_STAT_KEYS = ('loss', 'rec_loss', 'kl_loss', 'label_loss', 'label_acc')


def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='Train a Variational Autoencoder')
    parser.add_argument('--gpu_id', default=0, type=int,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--output_dir', '-o', default='result_vae/',
                        help='Directory to output the result')
    parser.add_argument('--epochs', '-e', default=100, type=int,
                        help='Number of epochs')
    parser.add_argument('--latent_n', '-z', default=8, type=int,
                        help='Dimension of encoded vector')
    parser.add_argument('--batch_size', '-batch', type=int, default=8,
                        help='Learning minibatch size')
    parser.add_argument('--beta', '-b', type=float, default=0.1,
                        help='Beta coefficient for the KL loss')
    parser.add_argument('--gamma', '-g', type=float, default=10,
                        help='Gamma coefficient for the classification loss')
    parser.add_argument('--alpha', '-a', type=float, default=1,
                        help='Alpha coefficient for the reconstruction loss')
    parser.add_argument('--augment_counter', type=int, default=10)
    return parser.parse_args()


def _print_config(args):
    """Print a banner summarising the run configuration."""
    print('\n###############################################')
    print('# GPU: \t\t\t{}'.format(args.gpu_id))
    print('# dim z: \t\t{}'.format(args.latent_n))
    print('# Minibatch-size: \t{}'.format(args.batch_size))
    print('# Epochs: \t\t{}'.format(args.epochs))
    print('# Alpha: \t\t{}'.format(args.alpha))
    print('# Beta: \t\t{}'.format(args.beta))
    print('# Gamma: \t\t{}'.format(args.gamma))
    print('# Augment Counter: \t{}'.format(args.augment_counter))
    print('# Out Folder: \t\t{}'.format(args.output_dir))
    print('###############################################\n')


def _select_device(gpu_id):
    """Return ``cuda:<gpu_id>`` for a non-negative id, or the CPU for a negative one."""
    if gpu_id < 0:
        return torch.device('cpu')
    return torch.device('cuda:{}'.format(gpu_id))


def _run_epoch(net, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return the epoch's metrics.

    When ``optimizer`` is given the network is put in training mode and its
    parameters are updated after every batch; otherwise it is put in eval mode
    and run without gradient tracking.

    Args:
        net: the VAE; called as ``net(traj, img)`` and must expose ``groups_n``.
        loader: iterable of batches indexed as (traj, img, cube_pos, labels, ...).
        criterion: callable returned by ``net.get_loss()``; returns
            ``(loss, rec, label_loss, kld)`` tensors.
        device: torch device the inputs are moved to.
        optimizer: optimizer to step, or ``None`` for evaluation.

    Returns:
        dict keyed by ``_STAT_KEYS``. The four losses are per-batch values
        averaged over the number of batches. ``label_acc`` is the percentage of
        correct label predictions over entries whose target is not
        ``_IGNORE_LABEL``, or NaN when no such entries were seen.
    """
    training = optimizer is not None
    net.train(training)  # toggles dropout/batch-norm style layers, if any

    n_batches = len(loader)
    results = {'loss': 0.0, 'rec_loss': 0.0, 'kl_loss': 0.0, 'label_loss': 0.0}
    correct = total = 0

    with torch.set_grad_enabled(training):
        for batch in loader:
            # batch[2] (cube positions) is not used by the model, so it is not copied.
            traj_in = batch[0].to(device)
            img_in = batch[1].to(device)
            label_in = batch[3].to(device)
            # Keep only index 0 along dim 1 of the image tensor.
            # NOTE(review): presumably the first camera/frame — confirm in DataGenerator.
            img_in = img_in[:, 0, :, :, :]

            if training:
                optimizer.zero_grad()
            traj_out, label_out, mus, logvars = net(traj_in, img_in)
            loss, rec, label_loss, kld = criterion(traj_in, traj_out, label_in,
                                                   label_out, mus, logvars)
            if training:
                loss.backward()
                optimizer.step()

            for key, value in (('loss', loss), ('rec_loss', rec),
                               ('kl_loss', kld), ('label_loss', label_loss)):
                results[key] += value.item() / n_batches

            # label_out holds one set of class scores per label group.
            for i in range(net.groups_n):
                _, predicted = torch.max(label_out[i], 1)
                target = label_in[:, i]
                valid = target != _IGNORE_LABEL
                total += int(valid.sum())
                correct += int(((predicted == target) & valid).sum())

    # Guard against an empty loader or a fully-masked epoch.
    results['label_acc'] = 100.0 * correct / total if total else float('nan')
    return results


def _format_acc(acc):
    """Format an accuracy percentage as a truncated integer, or 'nan' if undefined."""
    return str(int(acc)) if np.isfinite(acc) else 'nan'


def _plot_stats(stats, path):
    """Save a 1x4 figure of train/valid curves (rec, KL, label loss, accuracy) to ``path``."""
    panels = (('rec_loss', 'rec'), ('kl_loss', 'kl'),
              ('label_loss', 'label'), ('label_acc', 'acc'))
    epochs = range(len(stats['train_loss']))

    fig = plt.figure(figsize=(12, 4))
    for idx, (key, name) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, len(panels), idx)
        ax.plot(epochs, stats['train_' + key], label=name)
        ax.plot(epochs, stats['valid_' + key], label='v_' + name)
        ax.grid()
        ax.legend()
    fig.subplots_adjust(right=0.99, left=0.05, top=0.95)
    fig.savefig(path)
    plt.close(fig)  # release the figure; this script never displays it


def main():
    """Train the VAE on DataGenerator data, then save the weights and loss plots.

    Writes ``model_vae`` (the network's ``state_dict``) and ``loss`` (the
    metric curves figure) into ``--output_dir``. ``--gpu_id`` selects
    ``cuda:<id>``; a negative value runs on the CPU.
    """
    args = _parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    _print_config(args)

    data_generator = DataGenerator(args.batch_size, data_split=0.8, augment_counter=args.augment_counter,
                                   plot=False, DIRS=["data/"])
    trainloader, testloader, train_traj, test_traj, train_images, test_images, \
        train_cube_pos, test_cube_pos, train_labels, test_labels = data_generator.generate_data()

    print("Max Len: {0}, Max Val: {1}, Min Val: {2}".format(data_generator.max_len, np.max(train_traj), np.min(train_traj)))
    # batches * batch_size: an upper bound, since the last batch may be partial.
    print("# train datapoints: {0}".format(len(trainloader) * args.batch_size))
    print("# test datapoints: {0}".format(len(testloader) * args.batch_size))

    ################################
    ########### TRAINING ###########
    ################################

    stats = {'{}_{}'.format(split, key): []
             for split in ('train', 'valid') for key in _STAT_KEYS}

    device = _select_device(args.gpu_id)
    traj_ch = 14  # 7 joints * 2 (joint angles and efforts)
    net = VAE(alpha=args.alpha, beta=args.beta, gamma=args.gamma, latent_n=args.latent_n,
              groups=data_generator.groups, traj_len=data_generator.traj_len, traj_ch=traj_ch, device=device)
    net.to(device)

    criterion = net.get_loss()
    optimizer = optim.Adam(net.parameters())

    for epoch in range(args.epochs):  # loop over the dataset multiple times
        train = _run_epoch(net, trainloader, criterion, device, optimizer)
        valid = _run_epoch(net, testloader, criterion, device)

        for split, results in (('train', train), ('valid', valid)):
            for key in _STAT_KEYS:
                stats['{}_{}'.format(split, key)].append(results[key])

        print(("Ep: {0}\tT_L: {1}\tV_L: {2}\tT_RL: {3}\tV_RL: {4}\t"
               "T_LL: {5}\tV_LL: {6}\tT_KL: {7}\tV_KL: {8}\t"
               "T_A: {9} \t V_A: {10}").format(
                   epoch,
                   round(train['loss'], 1), round(valid['loss'], 1),
                   round(train['rec_loss'], 5), round(valid['rec_loss'], 5),
                   round(train['label_loss'], 2), round(valid['label_loss'], 2),
                   round(train['kl_loss'], 2), round(valid['kl_loss'], 2),
                   _format_acc(train['label_acc']), _format_acc(valid['label_acc'])))

    torch.save(net.state_dict(), osp.join(args.output_dir, "model_vae"))
    print('Finished Training. Model Saved.')

    _plot_stats(stats, osp.join(args.output_dir, "loss"))
                

if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()
