'''
Author: 
Email:
Date: 2020-10-12 22:14:36
LastEditTime: 2021-05-30 18:39:56
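Description: Train the ConditionalSceneVAE on the vehicle scene-tree dataset (--type 0) or sample conditional scenes from a saved checkpoint and plot them (--type 1).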
'''

import numpy as np
import argparse
from matplotlib import pyplot as plt

import torch
import torch.utils.data
from torch.autograd import Variable
from utils import CUDA, kaiming_init

from data_loader import VehicleSceneTreeDataset
from model import ConditionalSceneVAE


# ---------------------------------------------------------------------------
# Hyper-parameters: module-level globals read by train() / test() below.
# ---------------------------------------------------------------------------
batch_size = 128        # DataLoader batch size used in train()
lr = 0.001              # Adam learning rate used in train()
z_dim = 32              # latent size: passed to ConditionalSceneVAE and used for the sampled z in test()
epochs = 10000          # number of training epochs
save_iter = 10          # write a checkpoint every `save_iter` epochs
print_iter = 1          # print epoch-averaged losses every `print_iter` epochs
sample_num = 10         # NOTE(review): not referenced in the visible code; test() uses its own local count
num_workers = 0         # DataLoader worker processes (0 = load in the main process)
position_scale = 40     # passed to the dataset and to model.decode(); presumably a position normalizer — TODO confirm
beta_max = 0.5          # upper bound of the KL weight: beta = beta_max * beta_scheduler(...)
gamma = 10.0            # sigmoid temperature passed to beta_scheduler (larger = slower ramp)
epoch_th = 100          # epoch at which beta_scheduler returns 0.5

# The model is built once, at import time, and shared by train() and test().
# CUDA() comes from utils and presumably moves the model to the GPU when one is
# available — TODO confirm. Weights are then initialized by applying
# kaiming_init to every submodule via Module.apply.
model = CUDA(ConditionalSceneVAE(z_dim))
model.apply(kaiming_init)


def test(sample_number=10, model_path='./models/conditional.model.pth',
         output_dir='./samples'):
    """Sample scenes from a trained checkpoint and save one plot per sample.

    For each of ``sample_number`` samples a new matplotlib figure is opened.
    For every condition in the fixed list below, ``model.decode`` is called with
    a freshly drawn standard-normal latent ``z``. The figure is then saved as
    ``<output_dir>/<sample index>_tree_plot.png`` and closed.

    Args:
        sample_number: number of figures to generate (default 10, as before).
        model_path: checkpoint path passed to ``model.load_model``.
        output_dir: directory the PNG files are written to; created if missing.
    """
    import os

    model.load_model(model_path)
    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    # Latent buffer reused across decode calls; refilled in place below.
    z = Variable(CUDA(torch.zeros(1, z_dim)))

    # Conditions passed to model.decode(). From the values, 'xywh' is a box
    # whose third entry equals num_lane * 3.7, and 'direction' is a 0/1 flag.
    # Their exact semantics live in ConditionalSceneVAE.decode — TODO confirm.
    conditions = [
        {'num_lane': 2, 'xywh': np.array([-3.7, 0.0, 3.7*2, 30]), 'direction': 0},
        {'num_lane': 2, 'xywh': np.array([3.7, 0.0, 3.7*2, 30]), 'direction': 1},
        {'num_lane': 2, 'xywh': np.array([-3.7, -30, 3.7*2, 30]), 'direction': 0},
        {'num_lane': 1, 'xywh': np.array([3.7*3/2, -30, 3.7, 30]), 'direction': 1},
        {'num_lane': 1, 'xywh': np.array([3.7*1/2, -35, 3.7, 20]), 'direction': 1},
    ]

    for s_i in range(sample_number):
        plt.figure(figsize=(4, 6))
        for c_i in conditions:
            # Draw a fresh standard-normal latent for every condition.
            z_init = np.random.normal(0.0, 1.0, size=(1, z_dim))
            z.data.copy_(CUDA(torch.from_numpy(z_init)))
            # NOTE(review): decode presumably draws into the current figure.
            model.decode(s_i, z, position_scale, c_i)
        print('[{}/{}]'.format(s_i, sample_number))
        plt.tight_layout()
        plt.xlim([-20, 20])
        plt.ylim([-60, 30])
        plt.savefig(os.path.join(output_dir, str(s_i) + '_tree_plot.png'), dpi=300)
        # Release the figure; otherwise every sample stays open for the
        # lifetime of the process.
        plt.close()


def beta_scheduler(epoch, gamma=10.0, epoch_th=100):
    """Return the sigmoid warm-up weight for the KL-divergence term.

    The weight is ``1 / (1 + exp((epoch_th - epoch) / gamma))``. It rises from
    near 0 well before ``epoch_th``, is exactly 0.5 at ``epoch_th``, and
    approaches 1 afterwards. Larger ``gamma`` gives a more gradual ramp.
    """
    exponent = (epoch_th - epoch) / gamma
    return 1.0 / (1.0 + np.exp(exponent))


def train(save_path='./models/tree.model.pth',
          data_path='./data/vehicle_sbtree_dataset.npy'):
    """Train the global ``model`` with a sigmoid-annealed KL weight.

    Each epoch iterates over the whole dataset and minimizes
    ``recon_loss + beta * kldiv_loss`` with Adam, where
    ``beta = beta_max * beta_scheduler(epoch, gamma, epoch_th)``. Epoch-averaged
    losses are printed every ``print_iter`` epochs. A checkpoint is written to
    ``save_path`` every ``save_iter`` epochs.

    NOTE(review): test() loads './models/conditional.model.pth' by default, not
    this file. Rename or copy the checkpoint before sampling.

    Args:
        save_path: checkpoint path passed to ``model.save_model``; its parent
            directory is created if missing.
        data_path: dataset file passed to ``VehicleSceneTreeDataset``.
    """
    import os

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    dataset = VehicleSceneTreeDataset(data_path, position_scale)
    train_iter = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=dataset.collate_fn,
    )

    # Make sure the checkpoint directory exists before the first save.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.train()
    for e_i in range(epochs):
        total_loss_list = []
        kldiv_loss_list = []
        recon_loss_list = []
        # The KL weight is fixed for the whole epoch.
        beta = beta_max * beta_scheduler(e_i, gamma, epoch_th)
        for batch_data in train_iter:
            # The third output (root nodes) is not needed for the loss.
            recon_loss, kldiv_loss, _ = model(batch_data)
            total_loss = recon_loss + beta * kldiv_loss

            total_loss_list.append(total_loss.item())
            kldiv_loss_list.append(kldiv_loss.item())
            recon_loss_list.append(recon_loss.item())

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        avg_total_loss = np.mean(total_loss_list)
        avg_kldiv_loss = np.mean(kldiv_loss_list)
        avg_recon_loss = np.mean(recon_loss_list)
        if (e_i + 1) % print_iter == 0:
            print('Epoch: {}, Beta: {}, Total loss: {}, KLD: {}, Recon loss: {}'.format(
                e_i, beta, avg_total_loss, avg_kldiv_loss, avg_recon_loss))

        if (e_i + 1) % save_iter == 0:
            model.save_model(save_path)


# Command-line entry point: --type 0 trains the model, --type 1 samples from it.
parser = argparse.ArgumentParser(description='Train T-VAE model')
parser.add_argument('--type', type=int, default=0, choices=[0, 1],
                    help='0: train the model (default); 1: sample scenes with test()')


# Guarded so that importing this module does not parse argv or start training.
if __name__ == '__main__':
    args = parser.parse_args()
    if args.type == 0:
        train()
    elif args.type == 1:
        test()
