'''
Author:
Email: 
Date: 2020-10-12 22:14:36
LastEditTime: 2021-05-30 16:33:20
Description: 
    This file implements VAE, T-VAE, and GVAE models.
'''

import os
import numpy as np
from matplotlib import pyplot as plt
from nltk import CFG, Nonterminal

import torch
from torch import nn
from torch.autograd import Variable

from structure import TreeParser
from utils import CUDA, COLOR, CPU


# Integer ids of the tree node types. They are compared against the argmax of
# the node classifier output (see TreeVAE.add_node) and used as
# cross-entropy targets in TreeVAE.decode, so they double as class indices.
BOX = 0
QUAD = 1
EMPTY = 2
PLANE = 3
WORLD = 4


class SamplerEncoder(nn.Module):
    """Turn a root feature into a sampled latent code and its KL terms.

    A tanh hidden layer feeds two linear heads that produce the mean and the
    log-variance of a diagonal Gaussian; a latent sample is then drawn with
    the reparameterisation trick.

    Args:
        z_size: dimension of the latent code.
        feature_size: dimension of the incoming root feature.
        hidden_size: width of the shared hidden layer.
    """
    def __init__(self, z_size, feature_size, hidden_size):
        super(SamplerEncoder, self).__init__()
        self.mlp1 = nn.Linear(feature_size, hidden_size)
        self.mlp2mu = nn.Linear(hidden_size, z_size)
        self.mlp2var = nn.Linear(hidden_size, z_size)
        self.tanh = nn.Tanh()

    def forward(self, root_input):
        """Return ``cat([z, mu, kld], dim=1)`` for ``root_input``.

        The three parts each have ``z_size`` columns, so callers can recover
        them with ``torch.chunk(out, 3, dim=1)``.
        """
        encode = self.tanh(self.mlp1(root_input))
        mu = self.mlp2mu(encode)
        logvar = self.mlp2var(encode)

        std = logvar.div(2).exp()
        # randn_like draws the noise with std's own dtype and device, so the
        # sample can never end up on a different device than the parameters
        # (the previous Variable/.data.new() + global CUDA helper could).
        eps = torch.randn_like(std)
        z = mu + std*eps

        # element-wise KL(N(mu, std^2) || N(0, 1)); reduction is left to the caller
        KLD_element = -0.5*(1+logvar-mu.pow(2)-logvar.exp())
        return torch.cat([z, mu, KLD_element], dim=1)


class BoxEncoder(nn.Module):
    """Encode a leaf box from its properties and its relative position.

    The network is a three-layer MLP (input -> hidden -> hidden -> feature)
    with a tanh after every linear layer, so outputs lie in [-1, 1].
    """
    def __init__(self, input_size, feature_size, hidden_size):
        super().__init__()
        widths = [input_size, hidden_size, hidden_size, feature_size]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.Tanh())
        self.encoder = nn.Sequential(*stack)

    def forward(self, prop, alpha):
        """Concatenate ``prop`` and ``alpha`` column-wise and encode them."""
        joined = torch.cat((prop, alpha), dim=1)
        return self.encoder(joined)


class EmptyEncoder(nn.Module):
    """Parameter-free encoder for an empty node.

    It always emits an all-zero feature row of width ``feature_size``.
    """
    def __init__(self, feature_size):
        super().__init__()
        self.feature_size = feature_size

    def forward(self):
        """Return a ``[1, feature_size]`` zero tensor passed through ``CUDA``."""
        placeholder = torch.zeros(1, self.feature_size)
        return CUDA(placeholder)


class QuadEncoder(nn.Module):
    """Encode a QUAD node from its own properties and its four children.

    The node's ``[prop, alpha]`` vector and each child feature are projected
    to ``hidden_size`` by separate sub-networks, summed, and mapped to a
    ``feature_size`` output.

    Args:
        input_size: width of the concatenated ``[prop, alpha]`` vector.
        feature_size: width of every child feature and of the output.
        hidden_size: width of the internal representation.
    """
    def __init__(self, input_size, feature_size, hidden_size):
        super(QuadEncoder, self).__init__()
        self.common = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
        )

        # one projection per child slot; slots are not interchangeable
        self.children_list = nn.ModuleList()
        for c_i in range(4):
            one_child = nn.Sequential(
                nn.Linear(feature_size, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, hidden_size),
            )
            self.children_list.append(one_child)

        # hidden -> hidden -> feature.  The first layer used to emit
        # feature_size values into a layer declared with hidden_size inputs,
        # which only worked when both sizes were equal; for that case the
        # parameter shapes (and therefore saved checkpoints) are unchanged.
        self.output = nn.Sequential(
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, feature_size),
        )

    def forward(self, child_0, child_1, child_2, child_3, prop, alpha):
        """Return the node feature of shape ``[batch, feature_size]``."""
        feature = torch.cat([prop, alpha], dim=1)
        x = self.common(feature)
        # accumulate the children in slot order
        for branch, child in zip(self.children_list, (child_0, child_1, child_2, child_3)):
            x = x + branch(child)
        return self.output(x)


class PlaneEncoder(nn.Module):
    """Encode a PLANE node from its own properties and its four children.

    The node's ``[prop, alpha]`` vector and each child feature are projected
    to ``hidden_size`` by separate sub-networks, summed, and mapped to a
    ``feature_size`` output.

    Args:
        input_size: width of the concatenated ``[prop, alpha]`` vector.
        feature_size: width of every child feature and of the output.
        hidden_size: width of the internal representation.
    """
    def __init__(self, input_size, feature_size, hidden_size):
        super(PlaneEncoder, self).__init__()
        self.common = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
        )

        # one projection per child slot; slots are not interchangeable
        self.children_list = nn.ModuleList()
        for c_i in range(4):
            one_child = nn.Sequential(
                nn.Linear(feature_size, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, hidden_size),
            )
            self.children_list.append(one_child)

        # hidden -> hidden -> feature.  The first layer used to emit
        # feature_size values into a layer declared with hidden_size inputs,
        # which only worked when both sizes were equal; for that case the
        # parameter shapes (and therefore saved checkpoints) are unchanged.
        self.output = nn.Sequential(
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, feature_size)
        )

    def forward(self, child_0, child_1, child_2, child_3, prop, alpha):
        """Return the node feature of shape ``[batch, feature_size]``."""
        feature = torch.cat([prop, alpha], dim=1)
        x = self.common(feature)
        # accumulate the children in slot order
        for branch, child in zip(self.children_list, (child_0, child_1, child_2, child_3)):
            x = x + branch(child)
        return self.output(x)


class WorldEncoder(nn.Module):
    """Encode the WORLD (root) node from its xywh vector and four children.

    The ``xywh`` vector and each child feature are projected to
    ``hidden_size`` by separate sub-networks, summed, and mapped to a
    ``feature_size`` output.

    Args:
        input_size: width of the ``xywh`` vector.
        feature_size: width of every child feature and of the output.
        hidden_size: width of the internal representation.
    """
    def __init__(self, input_size, feature_size, hidden_size):
        super(WorldEncoder, self).__init__()
        self.xywh_net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
        )

        # one projection per child slot; slots are not interchangeable
        self.children_list = nn.ModuleList()
        for c_i in range(4):
            one_child = nn.Sequential(
                nn.Linear(feature_size, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, hidden_size),
            )
            self.children_list.append(one_child)

        # hidden -> hidden -> feature.  The first layer used to emit
        # feature_size values into a layer declared with hidden_size inputs,
        # which only worked when both sizes were equal; for that case the
        # parameter shapes (and therefore saved checkpoints) are unchanged.
        self.output = nn.Sequential(
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, feature_size),
        )

    def forward(self, child_0, child_1, child_2, child_3, xywh):
        """Return the root feature of shape ``[batch, feature_size]``."""
        x = self.xywh_net(xywh)
        # accumulate the children in slot order
        for branch, child in zip(self.children_list, (child_0, child_1, child_2, child_3)):
            x = x + branch(child)
        return self.output(x)


class ScenarioEncoder(nn.Module):
    """Container holding one encoder per node type plus the latent sampler.

    Each camelCase method simply forwards to the matching sub-module.
    NOTE(review): these method names are presumably looked up by the tree
    parser when it applies the encoder -- confirm before renaming any.
    """
    def __init__(self, z_size, feature_size, hidden_size):
        super().__init__()
        box_prop_size, alpha_xy_size, xywh_size = 4, 2, 4
        # BOX / QUAD / PLANE nodes are fed [prop, alpha] concatenated
        node_input_size = box_prop_size + alpha_xy_size
        sizes = dict(feature_size=feature_size, hidden_size=hidden_size)

        self.box_encoder = BoxEncoder(input_size=node_input_size, **sizes)
        self.quad_encoder = QuadEncoder(input_size=node_input_size, **sizes)
        self.empty_encoder = EmptyEncoder(feature_size=feature_size)
        self.plane_encoder = PlaneEncoder(input_size=node_input_size, **sizes)
        self.world_encoder = WorldEncoder(input_size=xywh_size, **sizes)
        self.sampler_encoder = SamplerEncoder(z_size=z_size, **sizes)

    def boxEncoder(self, prop, alpha):
        """Encode a BOX leaf."""
        encoded = self.box_encoder(prop, alpha)
        return encoded

    def emptyEncoder(self):
        """Produce the feature of an EMPTY leaf."""
        encoded = self.empty_encoder()
        return encoded

    def quadEncoder(self, child_0, child_1, child_2, child_3, prop, alpha):
        """Encode a QUAD node from its children and own properties."""
        children = (child_0, child_1, child_2, child_3)
        return self.quad_encoder(*children, prop, alpha)

    def planeEncoder(self, child_0, child_1, child_2, child_3, prop, alpha):
        """Encode a PLANE node from its children and own properties."""
        children = (child_0, child_1, child_2, child_3)
        return self.plane_encoder(*children, prop, alpha)

    def worldEncoder(self, child_0, child_1, child_2, child_3, xywh):
        """Encode the WORLD root from its children and xywh vector."""
        children = (child_0, child_1, child_2, child_3)
        return self.world_encoder(*children, xywh)

    def samplerEncoder(self, child_0):
        """Sample a latent code from the root feature."""
        encoded = self.sampler_encoder(child_0)
        return encoded


class SamplerDecoder(nn.Module):
    """Map a latent code back to a root feature vector.

    Three linear layers (z -> hidden -> hidden -> feature), each followed by
    tanh.
    """
    def __init__(self, z_size, feature_size, hidden_size):
        super().__init__()
        widths = [z_size, hidden_size, hidden_size, feature_size]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.Tanh())
        self.decoder = nn.Sequential(*stack)

    def forward(self, input_feature):
        """Return the decoded feature for ``input_feature``."""
        return self.decoder(input_feature)


class BoxDecoder(nn.Module):
    """Recover a box's properties and relative offset from its feature.

    ``input_size`` is indexed as ``[prop_width, alpha_width]``.
    """
    def __init__(self, feature_size, input_size, hidden_size):
        super().__init__()
        prop_width = input_size[0]
        alpha_width = input_size[1]
        # sigmoid: orientation and colour were normalised to [0, 1]
        self.box_decoder = nn.Sequential(
            nn.Linear(feature_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, prop_width),
            nn.Sigmoid(),
        )
        # tanh: alpha lives in [-1, 1]
        self.alpha_decoder = nn.Sequential(
            nn.Linear(feature_size, alpha_width),
            nn.Tanh(),
        )

    def forward(self, parent_feature):
        """Return ``(box_prop, alpha)`` decoded from ``parent_feature``."""
        return self.box_decoder(parent_feature), self.alpha_decoder(parent_feature)


class QuadDecoder(nn.Module):
    """Expand a QUAD feature into four child features plus its own box.

    A shared trunk maps the parent feature to ``hidden_size``; four
    independent heads emit the child features, and two further heads emit the
    node's box properties (sigmoid) and alpha (tanh).
    ``input_size`` is indexed as ``[prop_width, alpha_width]``.
    """
    def __init__(self, feature_size, hidden_size, input_size):
        super().__init__()
        self.common = nn.Sequential(
            nn.Linear(feature_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
        )
        self.children_list = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, feature_size),
            )
            for _ in range(4)
        ])
        prop_width, alpha_width = input_size[0], input_size[1]
        self.box_decoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, prop_width),
            nn.Sigmoid(),
        )
        # tanh: alpha lives in [-1, 1]
        self.alpha_decoder = nn.Sequential(
            nn.Linear(hidden_size, alpha_width),
            nn.Tanh(),
        )

    def forward(self, parent_feature):
        """Return ``(child_0, child_1, child_2, child_3, box_prop, alpha)``."""
        trunk = self.common(parent_feature)
        kids = [head(trunk) for head in self.children_list]
        box_prop = self.box_decoder(trunk)
        alpha = self.alpha_decoder(trunk)
        return kids[0], kids[1], kids[2], kids[3], box_prop, alpha


class PlaneDecoder(nn.Module):
    """Expand a PLANE feature into four child features plus its own box.

    A shared trunk maps the parent feature to ``hidden_size``; four
    independent heads emit the child features, and two further heads emit the
    node's box properties (sigmoid) and alpha (tanh).

    Args:
        feature_size: width of the parent feature and of each child feature.
        hidden_size: width of the shared trunk output.
        input_size: indexed as ``[prop_width, alpha_width]``.
    """
    def __init__(self, feature_size, hidden_size, input_size):
        super(PlaneDecoder, self).__init__()
        self.common = nn.Sequential(
            nn.Linear(feature_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
        )

        self.children_list = nn.ModuleList()
        for c_i in range(4):
            one_child = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, feature_size),
            )
            self.children_list.append(one_child)

        # Both heads consume the trunk output, which is hidden_size wide.
        # They were declared with feature_size inputs, which only worked when
        # the two sizes were equal (parameter shapes are unchanged in that
        # case); this now matches QuadDecoder.
        self.box_decoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, input_size[0]),
            nn.Sigmoid(),   # use sigmoid because we normalized the orientation and color to [0, 1]
        )
        self.alpha_decoder = nn.Sequential(
            nn.Linear(hidden_size, input_size[1]),
            nn.Tanh(), # alpha is in [-1, 1]
        )

    def forward(self, parent_feature):
        """Return ``(child_0, child_1, child_2, child_3, box_prop, alpha)``."""
        feature = self.common(parent_feature)
        child_0 = self.children_list[0](feature)
        child_1 = self.children_list[1](feature)
        child_2 = self.children_list[2](feature)
        child_3 = self.children_list[3](feature)
        box_prop = self.box_decoder(feature)
        alpha = self.alpha_decoder(feature)
        return child_0, child_1, child_2, child_3, box_prop, alpha


class WorldDecoder(nn.Module):
    """Expand the root feature into four child features and an xywh vector.

    A shared trunk maps the parent feature to ``hidden_size``; four
    independent heads emit the child features and a final head emits
    ``input_size`` unbounded values (no output activation).
    """
    def __init__(self, feature_size, hidden_size, input_size):
        super().__init__()
        self.common = nn.Sequential(
            nn.Linear(feature_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
        )
        self.children_list = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, feature_size),
            )
            for _ in range(4)
        ])
        self.xywh_net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, input_size),
        )

    def forward(self, parent_feature):
        """Return ``(child_0, child_1, child_2, child_3, xywh)``."""
        trunk = self.common(parent_feature)
        kids = [head(trunk) for head in self.children_list]
        xywh = self.xywh_net(trunk)
        return kids[0], kids[1], kids[2], kids[3], xywh


class NodeClassifier(nn.Module):
    """Predict unnormalised class scores for a node feature.

    Three linear layers with tanh between them; the last layer has no
    activation, so the output is raw logits of width ``class_size``.
    """
    def __init__(self, feature_size, hidden_size, class_size=1):
        super().__init__()
        layers = [
            nn.Linear(feature_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, class_size),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, input_feature):
        """Return the logits for ``input_feature``."""
        return self.classifier(input_feature)


class ScenarioDecoder(nn.Module):
    """Container holding one decoder per node type, the node classifier and
    the per-node loss helpers.

    Each camelCase method forwards to a sub-module or computes one loss term.
    NOTE(review): these method names are presumably looked up by the tree
    parser when it applies the decoder -- confirm before renaming any.
    """
    def __init__(self, z_size, feature_size, hidden_size, class_size):
        super().__init__()
        box_prop_size, alpha_xy_size, xywh_size = 4, 2, 4
        node_output_size = [box_prop_size, alpha_xy_size]
        sizes = dict(feature_size=feature_size, hidden_size=hidden_size)

        self.box_decoder = BoxDecoder(input_size=node_output_size, **sizes)
        self.quad_decoder = QuadDecoder(input_size=[box_prop_size, alpha_xy_size], **sizes)
        self.plane_decoder = PlaneDecoder(input_size=[box_prop_size, alpha_xy_size], **sizes)
        self.world_decoder = WorldDecoder(input_size=xywh_size, **sizes)

        self.sampler_decoder = SamplerDecoder(z_size=z_size, **sizes)
        self.node_classifier = NodeClassifier(class_size=class_size, **sizes)

        # element-wise losses; every estimator below reduces them itself
        self.mse_loss = nn.MSELoss(reduction='none')
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def boxDecoder(self, parent):
        """Decode a BOX leaf."""
        decoded = self.box_decoder(parent)
        return decoded

    def quadDecoder(self, parent):
        """Decode a QUAD node."""
        decoded = self.quad_decoder(parent)
        return decoded

    def planeDecoder(self, parent):
        """Decode a PLANE node."""
        decoded = self.plane_decoder(parent)
        return decoded

    def worldDecoder(self, parent):
        """Decode the WORLD root."""
        decoded = self.world_decoder(parent)
        return decoded

    def samplerDecoder(self, parent):
        """Map a latent code to the root feature."""
        decoded = self.sampler_decoder(parent)
        return decoded

    def nodeClassifier(self, parent):
        """Return node-type logits for a feature."""
        logits = self.node_classifier(parent)
        return logits

    # NOTE: the three MSE-based estimators are identical in code but are kept
    # as separate entry points, one per kind of quantity.
    def boxLossEstimator(self, predict, gt, normalizer):
        """Per-row mean squared error of box properties over ``normalizer[0]``.

        NOTE(review): an earlier comment said the loss of x and y is not
        computed; nothing in this method slices columns -- confirm upstream.
        """
        per_element = self.mse_loss(predict, gt)
        return per_element.mean(dim=1)/normalizer[0]

    def alpha2DLossEstimator(self, predict, gt, normalizer):
        """Per-row mean squared error of the 2D alpha over ``normalizer[0]``."""
        per_element = self.mse_loss(predict, gt)
        return per_element.mean(dim=1)/normalizer[0]

    def xywhLossEstimator(self, predict, gt, normalizer):
        """Per-row mean squared error of xywh over ``normalizer[0]``."""
        per_element = self.mse_loss(predict, gt)
        return per_element.mean(dim=1)/normalizer[0]

    def ceLossEstimator(self, predict, gt, normalizer):
        """Per-row cross entropy of the node type over ``normalizer[0]``.

        The target class index is read from the first column of ``gt``.
        """
        target = gt[:, 0]
        return self.ce_loss(predict, target)/normalizer[0]

    def vectorAdder(self, v1, v2):
        """Add ``v2`` into ``v1`` in place and return ``v1``."""
        summed = v1.add_(v2)
        return summed

    def vectorDivider(self, nom, den):
        """Return ``nom / den``."""
        quotient = nom/den
        return quotient


class TreeVAE(nn.Module):
    """ Only the World node contains absolute information, other nodes only have relative information
    """
    def __init__(self, z_dim):
        """Build the encoder/decoder pair and print a short model summary.

        Args:
            z_dim: dimension of the latent code.
        """
        super().__init__()
        self.feature_size = 128
        self.hidden_size = 128
        self.z_dim = z_dim
        self.class_size = 5

        shared = dict(z_size=self.z_dim, feature_size=self.feature_size, hidden_size=self.hidden_size)
        self.encoder = ScenarioEncoder(**shared)
        self.decoder = ScenarioDecoder(class_size=self.class_size, **shared)

        # NOTE(review): the printed model name does not match the class name
        # -- confirm whether that is intended.
        summary = (
            ('\tModel name:', 'Stick Breaking VAE'),
            ('\tDimension of z:', self.z_dim),
            ('\tSize of feature:', self.feature_size),
            ('\tSize of hidden layer:', self.hidden_size),
        )
        print(COLOR.GREEN+'Model Info:')
        for caption, value in summary:
            print(caption, value)
        print(COLOR.WHITE+'')

    def forward(self, batch):
        """Encode and decode a batch of trees and return the training terms.

        Args:
            batch: iterable of tree samples accepted by ``TreeParser.parse``.

        Returns:
            ``(recon_loss, kldiv_loss, mu_list)``: the mean reconstruction
            loss, the mean of the element-wise KL terms, and the stacked
            per-sample latent means.
        """
        enc_parser = TreeParser(type='encode')
        dec_parser = TreeParser(type='decode')

        # build one encoding graph per sample, then evaluate them together
        encode_roots = [enc_parser.parse(sample) for sample in batch]
        encoded = enc_parser.apply(self.encoder, [encode_roots])

        # one [1, 3*Z] row per sample: z, mu and the KL terms side by side
        per_sample = torch.split(encoded[0], 1, dim=0)

        decode_roots = []
        kl_terms = []
        means = []
        for sample, packed in zip(batch, per_sample):
            z, mu, kl_div = torch.chunk(packed, 3, dim=1)
            decode_roots.append(dec_parser.parse(sample, z))
            means.append(mu)
            kl_terms.append(kl_div)

        # evaluating the decoding graphs yields the reconstruction loss
        decoded = dec_parser.apply(self.decoder, [decode_roots])
        recon_loss = decoded[0].mean()

        kldiv_loss = torch.cat(kl_terms, dim=0).mean()
        return recon_loss, kldiv_loss, torch.stack(means)

    # Centre and extent of the left/bottom and right/top parts obtained by splitting an interval (centre x, width w) at ratio alpha
    @staticmethod
    def left_or_bottom(alpha, x_or_y, w_or_h):
        new_x_or_y = x_or_y - w_or_h*(1-alpha)/4
        new_w_or_h = w_or_h*(1+alpha)/2
        return new_x_or_y, new_w_or_h
    
    @staticmethod
    def right_or_top(alpha, x_or_y, w_or_h):
        new_x_or_y = x_or_y + w_or_h*(1+alpha)/4
        new_w_or_h = w_or_h*(1-alpha)/2
        return new_x_or_y, new_w_or_h 

    # Split a parent rectangle into four quadrants along x and y. NOTE: the None defaults of alpha_x/alpha_y are not handled below, so both ratios must be provided.
    def divide_2D(self, road_x, road_y, road_w, road_h, alpha_x=None, alpha_y=None):
        # x and w
        left_x, left_w = self.left_or_bottom(alpha_x, road_x, road_w)
        right_x, right_w = self.right_or_top(alpha_x, road_x, road_w)

        # y and h
        bottom_y, bottom_h = self.left_or_bottom(alpha_y, road_y, road_h)
        top_y, top_h = self.right_or_top(alpha_y, road_y, road_h)

        # collect the new coordinate
        left_top = torch.cat([left_x, top_y, left_w, top_h], dim=0)             # [4,]
        right_top = torch.cat([right_x, top_y, right_w, top_h], dim=0)          # [4,]
        left_bottom = torch.cat([left_x, bottom_y, left_w, bottom_h], dim=0)    # [4,]
        right_bottom = torch.cat([right_x, bottom_y, right_w, bottom_h], dim=0) # [4,]
        
        return left_top, right_top, left_bottom, right_bottom

    def add_node(self, feature, xywh, scene_parameter):
        """Recursively decode one node and record its pose and colour.

        Args:
            feature: feature of the node to decode (assumed to be a single
                row, since the decoded outputs are indexed with ``[0, ...]``).
            xywh: 1-D tensor ``[x, y, w, h]`` of the region the node occupies.
            scene_parameter: dict of lists that is appended to in place
                (keys ``box_poses``/``box_colors``/``plane_poses``/``plane_colors``).

        Returns:
            The raw classifier logits of this node.

        Raises:
            ValueError: if the predicted type is none of BOX/QUAD/EMPTY/PLANE.
        """
        label_prob = self.decoder.nodeClassifier(feature)
        scores = torch.softmax(label_prob, dim=-1)
        scores[:, WORLD] = -9999.9  # a WORLD node can never be predicted here
        label = torch.argmax(scores)

        # an empty node contributes nothing to the scene
        if label == EMPTY:
            return label_prob

        # torch ops only from here on, so the gradient chain is not broken
        if label == BOX:
            box_prop, alpha = self.decoder.box_decoder(feature)
            children = ()
            pose_key, color_key = 'box_poses', 'box_colors'
        elif label == QUAD:
            child_0_f, child_1_f, child_2_f, child_3_f, box_prop, alpha = self.decoder.quad_decoder(feature)
            children = (child_0_f, child_1_f, child_2_f, child_3_f)
            pose_key, color_key = 'box_poses', 'box_colors'
        elif label == PLANE:
            child_0_f, child_1_f, child_2_f, child_3_f, box_prop, alpha = self.decoder.plane_decoder(feature)
            children = (child_0_f, child_1_f, child_2_f, child_3_f)
            pose_key, color_key = 'plane_poses', 'plane_colors'
        else:
            print('Node type:', label)
            raise ValueError('Wrong Type')

        # first property is the orientation normalised to [0, 1]; the rest is colour
        theta = box_prop[0, 0:1]*2*np.pi
        color = box_prop[:, 1:4]
        # alpha places the object relative to the region centre
        x = xywh[0:1] + alpha[0, 0]*xywh[2:3]/2
        y = xywh[1:2] + alpha[0, 1]*xywh[3:4]/2
        pose = torch.cat([x, y, theta], dim=0)[None]
        scene_parameter[pose_key].append(pose)
        scene_parameter[color_key].append(color)

        if children:
            quadrants = self.divide_2D(xywh[0:1], xywh[1:2], xywh[2:3], xywh[3:4], alpha[0, 0], alpha[0, 1])
            # children map to left_top, right_top, left_bottom, right_bottom in that order
            for child_feature, child_xywh in zip(children, quadrants):
                self.add_node(child_feature, child_xywh, scene_parameter)
        return label_prob

    def add_node_use_knowledge(self, feature, xywh, scene_parameter, color_collector=None, parent_xy=None):
        """Recursively decode one node while collecting knowledge-loss terms.

        Besides filling ``scene_parameter`` like ``add_node``, this variant
        gathers the colours of the boxes under each PLANE in
        ``self.color_dict`` and the thresholded parent-child distances in
        ``self.distance_list``. It relies on ``self.color_dict``,
        ``self.counter``, ``self.distance_list`` and ``self.dist_threshold``
        having been initialised (``decode`` does that).

        Args:
            feature: feature of the node to decode (assumed to be a single
                row, since the decoded outputs are indexed with ``[0, ...]``).
            xywh: 1-D tensor ``[x, y, w, h]`` of the region the node occupies.
            scene_parameter: dict of lists that is appended to in place.
            color_collector: list receiving the colour of BOX/QUAD nodes, or
                ``None`` to skip colour collection.
            parent_xy: ``[x, y]`` of the parent node, or ``None`` to skip the
                distance term.

        Returns:
            The raw classifier logits of this node.

        Raises:
            ValueError: if the predicted type is none of BOX/QUAD/EMPTY/PLANE.
        """
        label_prob = self.decoder.nodeClassifier(feature)
        label_prob_softmax = torch.softmax(label_prob, dim=-1)
        label_prob_softmax[:, WORLD] = -9999.9  # a WORLD node can never be predicted here
        label = torch.argmax(label_prob_softmax)

        # we should not use numpy to break the gradient chain
        if label == BOX:
            box_prop, alpha = self.decoder.box_decoder(feature)
            theta = box_prop[0, 0:1]*2*np.pi
            color = box_prop[:, 1:4]
            x = xywh[0:1] + alpha[0, 0]*xywh[2:3]/2
            y = xywh[1:2] + alpha[0, 1]*xywh[3:4]/2
            pose = torch.cat([x, y, theta], dim=0)[None]
            scene_parameter['box_poses'].append(pose)
            scene_parameter['box_colors'].append(color)
            # for color loss
            if color_collector is not None:
                color_collector.append(color)
            # for distance loss: only the part beyond the threshold is penalised
            if parent_xy is not None:
                distance = torch.sqrt((parent_xy[0]-x)**2 + (parent_xy[1]-y)**2) - self.dist_threshold
                truncated_distance = torch.nn.functional.relu(distance)
                self.distance_list.append(truncated_distance)
        elif label == QUAD:
            child_0_f, child_1_f, child_2_f, child_3_f, box_prop, alpha = self.decoder.quad_decoder(feature)
            theta = box_prop[0, 0:1]*2*np.pi
            color = box_prop[:, 1:4]
            x = xywh[0:1] + alpha[0, 0]*xywh[2:3]/2
            y = xywh[1:2] + alpha[0, 1]*xywh[3:4]/2
            pose = torch.cat([x, y, theta], dim=0)[None]
            scene_parameter['box_poses'].append(pose)
            scene_parameter['box_colors'].append(color)
            # for color loss
            if color_collector is not None:
                color_collector.append(color)
            # for distance loss: only the part beyond the threshold is penalised
            if parent_xy is not None:
                distance = torch.sqrt((parent_xy[0]-x)**2 + (parent_xy[1]-y)**2) - self.dist_threshold
                truncated_distance = torch.nn.functional.relu(distance)
                self.distance_list.append(truncated_distance)
            quadrants = self.divide_2D(xywh[0:1], xywh[1:2], xywh[2:3], xywh[3:4], alpha[0, 0], alpha[0, 1])
            # children map to left_top, right_top, left_bottom, right_bottom;
            # they keep reporting colours to the collector this node received
            for child_feature, child_xywh in zip((child_0_f, child_1_f, child_2_f, child_3_f), quadrants):
                self.add_node_use_knowledge(child_feature, child_xywh, scene_parameter, color_collector, [x, y])
        elif label == EMPTY:
            pass
        elif label == PLANE:
            child_0_f, child_1_f, child_2_f, child_3_f, box_prop, alpha = self.decoder.plane_decoder(feature)
            theta = box_prop[0, 0:1]*2*np.pi
            color = box_prop[:, 1:4]
            x = xywh[0:1] + alpha[0, 0]*xywh[2:3]/2
            y = xywh[1:2] + alpha[0, 1]*xywh[3:4]/2
            pose = torch.cat([x, y, theta], dim=0)[None]
            scene_parameter['plane_poses'].append(pose)
            scene_parameter['plane_colors'].append(color)
            quadrants = self.divide_2D(xywh[0:1], xywh[1:2], xywh[2:3], xywh[3:4], alpha[0, 0], alpha[0, 1])
            # Reserve this plane's colour list and its id *before* recursing
            # and hold the list in a local.  Looking it up through
            # self.counter on every child call broke as soon as a descendant
            # was itself a PLANE: the descendant reused the same key
            # (overwriting this list) and then advanced the counter, so the
            # next lookup here raised KeyError.
            plane_colors = []
            self.color_dict[self.counter] = plane_colors
            self.counter += 1
            for child_feature, child_xywh in zip((child_0_f, child_1_f, child_2_f, child_3_f), quadrants):
                self.add_node_use_knowledge(child_feature, child_xywh, scene_parameter, plane_colors, [x, y])
        else:
            print('Node type:', label)
            raise ValueError('Wrong Type')
        return label_prob

    def decode(self, feature, position_scale, use_kg=False):
        """Decode a latent code into scene parameters.

        Args:
            feature: latent code fed to the sampler decoder.
            position_scale: scale applied to the decoded world xywh.
                NOTE(review): the scaled value is immediately replaced by a
                fixed ``[0, 0, 20, 20]`` world below, so this argument has no
                effect on the result at the moment.
            use_kg: when true, also return the knowledge loss.

        Returns:
            ``scene_parameter`` -- a dict whose values are the collected
            tensors concatenated along dim 0 with a leading unit dim, or
            ``None`` for keys with no entries -- and, when ``use_kg`` is set,
            the sum of the branch-type, colour-consistency and
            relative-distance losses as a second element.
        """
        # the first node will always be a WORLD, then we apply the scale to get the real position
        scene_parameter = {'box_poses': [], 'box_colors': [], 'plane_poses': [], 'plane_colors': []}
        feature = self.decoder.samplerDecoder(feature)
        child_0_f, child_1_f, child_2_f, child_3_f, xywh = self.decoder.world_decoder(feature)
        world_xywh = xywh[0] * position_scale # [0, 0, 20, 20]
        world_xywh = CUDA(torch.tensor([0.0, 0.0, 20.0, 20.0])) # use groundtruth

        x = world_xywh[0:1]
        y = world_xywh[1:2]
        w = world_xywh[2:3]
        h = world_xywh[3:4]

        # per-call state filled by add_node_use_knowledge
        self.color_dict = {}
        self.counter = 0 # to count the number of plate

        self.distance_list = [] # to collect the relative distance of box
        self.dist_threshold = 3 # distances up to this threshold are not penalised

        # the data is generated in this order, thus it should not be changed
        label_prob_1 = self.add_node_use_knowledge(child_0_f, torch.cat([x+w/4, y+h/4, w/2, h/2]), scene_parameter) # [5, 5]   - 1
        label_prob_2 = self.add_node_use_knowledge(child_1_f, torch.cat([x+w/4, y-h/4, w/2, h/2]), scene_parameter) # [5, -5]  - 2
        label_prob_3 = self.add_node_use_knowledge(child_2_f, torch.cat([x-w/4, y+h/4, w/2, h/2]), scene_parameter) # [-5, 5]  - 3
        label_prob_4 = self.add_node_use_knowledge(child_3_f, torch.cat([x-w/4, y-h/4, w/2, h/2]), scene_parameter) # [-5, -5] - 4

        # process scene parameters
        for k_i in scene_parameter.keys():
            if len(scene_parameter[k_i]) > 0: # there may be no plane (or no box)
                scene_parameter[k_i] = torch.cat(scene_parameter[k_i], axis=0)[None] # [1, N, 3]
            else:
                scene_parameter[k_i] = None

        if not use_kg:
            return scene_parameter

        # branch type loss
        predicted_label = torch.cat([label_prob_1, label_prob_2, label_prob_3, label_prob_4], dim=0) # [4, 5]
        target_label = CUDA(torch.tensor([PLANE, EMPTY, EMPTY, PLANE]))
        branch_loss = torch.nn.functional.cross_entropy(predicted_label, target_label)

        # color consistent loss
        color_loss = 0
        for branch_colors in self.color_dict.values():
            if len(branch_colors) == 0: # no box in current plate
                continue
            # make colors close to the mean value
            colors_in_one_branch = torch.cat(branch_colors, dim=0)
            mean_color = torch.mean(colors_in_one_branch, dim=0, keepdim=True).repeat(colors_in_one_branch.shape[0], 1)
            color_loss += torch.nn.functional.mse_loss(colors_in_one_branch, mean_color.detach())

        # relative distance loss; the list is empty when no box has a parent
        # (e.g. every child of the world is a BOX or EMPTY) and torch.cat
        # would raise on an empty list, so fall back to a zero term
        if self.distance_list:
            distance_loss = torch.cat(self.distance_list, dim=0).mean()
        else:
            distance_loss = world_xywh.new_zeros(())

        return scene_parameter, branch_loss + color_loss + distance_loss

    def save_model(self, filename):
        state = {'encoder': self.encoder.state_dict(), 'decoder': self.decoder.state_dict()}
        torch.save(state, filename)

    def load_model(self, filename):
        if os.path.isfile(filename):
            checkpoint = torch.load(filename)
            self.encoder.load_state_dict(checkpoint['encoder'])
            self.decoder.load_state_dict(checkpoint['decoder'])
        else:
            raise FileNotFoundError

    @staticmethod
    def get_box(pose, width, height):
        xy = pose[0:2]
        o = pose[2]
        position = np.array([[-width/2, -height/2], [-width/2, height/2], [width/2, height/2], [width/2, -height/2], [-width/2, -height/2]])
        rotation_matrix = np.array([[np.cos(o), -np.sin(o)], [np.sin(o), np.cos(o)]])
        position = rotation_matrix.dot(position.T).T + xy
        return position

    @staticmethod
    def plot_box(points, color):
        """Draw ``points`` (columns x, y) as a solid line in ``color``.

        ``color`` is a matplotlib colour character; ``'-'`` is appended to it
        to form the format string.
        """
        xs, ys = points[:, 0], points[:, 1]
        line_format = color+'-'
        plt.plot(xs, ys, line_format)
    
    @staticmethod
    def plot_arrow(pose, color, direction):
        """Draw an arrow of length 0.5 at ``pose`` along its orientation.

        Args:
            pose: indexable as ``[x, y, orientation]``.
            color: matplotlib colour of the arrow.
            direction: truthy to point along the rotated +y axis, falsy to
                point along the rotated -y axis.
        """
        scale = 0.5
        origin, heading = pose[0:2], pose[2]
        unit = np.array([[0.0, 1.0]]).T if direction else np.array([[0.0, -1.0]]).T
        offset = scale*unit
        cos_o, sin_o = np.cos(heading), np.sin(heading)
        rotation = np.array([[cos_o, -sin_o], [sin_o, cos_o]])
        offset = rotation.dot(offset)[:, 0]
        plt.arrow(origin[0], origin[1], offset[0], offset[1], color=color, width=0.3)

    @staticmethod
    def plot_xywh(xywh):
        """Draw the outline of the rectangle ``[x, y, w, h]`` as a black line.

        ``x, y`` is the rectangle centre and ``w, h`` its full extents; the
        first corner is repeated so the outline is closed.
        """
        road_x, road_y = xywh[0:1], xywh[1:2]
        road_w, road_h = xywh[2:3], xywh[3:4]
        left, right = road_x-road_w/2, road_x+road_w/2
        bottom, top = road_y-road_h/2, road_y+road_h/2
        outline = np.array([
            [left, bottom],
            [left, top],
            [right, top],
            [right, bottom],
            [left, bottom],
        ])
        # table and line should also be rotated
        plt.plot(outline[:, 0], outline[:, 1], 'k-', linewidth=2)

class GridVAE(nn.Module):
    """ This simple VAE uses the coordinate of all objects.

    A scene is a tensor with ``object_num`` rows of ``prop_dim`` values each.
    Values 0:3 of every row are reconstructed by the unbounded ``localizer``
    head and values 3:6 by the sigmoid ``color_output`` head.
    """
    def __init__(self, z_dim):
        """Build the MLP encoder, the MLP decoder trunk and the two output heads.

        Args:
            z_dim: dimension of the latent code.
        """
        super(GridVAE, self).__init__()

        self.prop_dim = 6
        # decode() reads rows 0:2 as planes and the remaining rows as boxes
        self.object_num = 8+2
        self.hidden_size = 128
        self.z_dim = z_dim
        # the encoder emits mu and logvar concatenated along the last dimension
        self.encoder = nn.Sequential(
            nn.Linear(self.object_num*self.prop_dim, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.z_dim*2),
        )

        self.decoder = nn.Sequential(
            nn.Linear(self.z_dim, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
        )
        self.localizer = nn.Sequential(
            nn.Linear(self.hidden_size, self.object_num*3), # cannot use sigmoid here, because x and y have negative values
        )
        self.color_output = nn.Sequential(
            nn.Linear(self.hidden_size, self.object_num*3),
            nn.Sigmoid()
        )
        self.mse_loss = nn.MSELoss(reduction='sum')  

        print(COLOR.GREEN+'Model Info:')
        print('\tModel name:', 'Grid VAE')
        print('\tDimension of z:', self.z_dim)
        print('\tNumber of object:', self.object_num)
        print('\tSize of hidden layer:', self.hidden_size)
        print(COLOR.WHITE+'')

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + std*eps`` with ``eps ~ N(0, I)``.

        Args:
            mu: mean of the approximate posterior.
            logvar: log-variance of the approximate posterior, same shape as ``mu``.

        Returns:
            A sample with the shape of ``mu``.
        """
        std = logvar.div(2).exp()
        # randn_like draws the noise with the dtype and on the device of std,
        # so no explicit device transfer is required
        eps = torch.randn_like(std)
        return mu+std*eps

    def _decode_heads(self, z):
        """Run the decoder trunk and both heads; returns (pos_, color_), each [B, object_num, 3]."""
        feature = self.decoder(z)
        batch_size = feature.shape[0]
        pos_ = self.localizer(feature).reshape(batch_size, self.object_num, 3)
        color_ = self.color_output(feature).reshape(batch_size, self.object_num, 3)
        return pos_, color_

    def forward(self, prop):
        """Encode ``prop``, sample a latent code, reconstruct and compute the losses.

        Args:
            prop: tensor that can be reshaped to [B, object_num*prop_dim] and
                is indexed as [B, object_num, prop_dim] for the losses.

        Returns:
            Tuple ``(recon_loss, kld, mu[None])``: the average of the
            batch-normalised summed squared errors of both heads, the mean of
            the element-wise KL term, and the posterior mean with an extra
            leading dimension.
        """
        # [B, 10, 6] -> [B, 60]
        batch_size = prop.shape[0]
        prop_input = prop.reshape(batch_size, self.object_num*self.prop_dim)

        stats = self.encoder(prop_input)
        mu = stats[:, 0:self.z_dim]
        logvar = stats[:, self.z_dim:]
        z = self.reparametrize(mu, logvar)
        pos_, color_ = self._decode_heads(z)

        # losses
        pos_loss = self.mse_loss(prop[:, :, 0:3], pos_)/batch_size
        color_loss = self.mse_loss(prop[:, :, 3:6], color_)/batch_size
        recon_loss = (pos_loss+color_loss)/2.0
        kld = -0.5*(1+logvar-mu.pow(2)-logvar.exp())
        kld = kld.mean()
        return recon_loss, kld, mu[None]

    def decode(self, z, position_scale):
        """Map latent codes to de-normalised scene parameters.

        Args:
            z: latent codes, [B, z_dim].
            position_scale: factor applied to the first two pose values; the
                third pose value is multiplied by 2*pi.

        Returns:
            Dict with the keys 'plane_poses', 'plane_colors' (rows 0:2) and
            'box_poses', 'box_colors' (remaining rows).
        """
        pos_, color_ = self._decode_heads(z)

        # denormalization
        scale = CUDA(torch.tensor([position_scale, position_scale, 2*np.pi]))[None][None]
        pos_ = pos_ * scale

        scene_parameters = {}
        scene_parameters['plane_poses'] = pos_[:, 0:2, :]
        scene_parameters['plane_colors'] = color_[:, 0:2, :]
        scene_parameters['box_poses'] = pos_[:, 2:self.object_num, :]
        scene_parameters['box_colors'] = color_[:, 2:self.object_num, :]
        return scene_parameters

    def sample(self, num):
        """Decode ``num`` latent codes drawn from the standard normal prior.

        Returns:
            The raw (not de-normalised) output of the localizer head,
            [num, object_num, 3].
        """
        z = CUDA(torch.randn(num, self.z_dim))
        f = self.decoder(z)
        pos_ = self.localizer(f).view(num, self.object_num, 3)
        return pos_

    def save_model(self, filename):
        """Save the full ``state_dict`` under the key 'model_state' to ``filename``."""
        state = {'model_state': self.state_dict()}
        torch.save(state, filename)

    def load_model(self, filename):
        """Restore the weights from a checkpoint written by ``save_model``.

        Raises:
            FileNotFoundError: if ``filename`` is not an existing file.
        """
        if not os.path.isfile(filename):
            raise FileNotFoundError('Checkpoint file not found: {}'.format(filename))

        # Deserialize onto the CPU so that a checkpoint written on a GPU machine
        # can also be restored on a CPU-only host; load_state_dict() copies the
        # values into the existing parameters on whatever device they live on.
        checkpoint = torch.load(filename, map_location='cpu')
        self.load_state_dict(checkpoint['model_state'])


class GrammarVAE(nn.Module):
    """ This Grammar VAE model.

    The input is a sequence with ``rule_dim + attri_dim`` channels: the first
    ``rule_dim`` channels score the production rules of the grammar defined in
    ``__init__`` and the remaining channels carry the attributes. A 1-D CNN
    encoder maps the sequence to a latent code; an LSTM decoder emits, for
    each of ``max_length`` steps, rule logits, a 3-value pose and a 3-value
    colour.
    """
    def __init__(self, z_dim, max_length, rule_dim, attri_dim):
        """Build the encoder, the decoder and the context-free grammar.

        Args:
            z_dim: dimension of the latent code.
            max_length: number of steps produced by the decoder.
            rule_dim: number of rule channels; it is also the size of the
                classification head.
            attri_dim: number of attribute channels of the input.
        """
        super(GrammarVAE, self).__init__()
        self.hidden_size = 256
        self.max_length = max_length
        self.rule_dim = rule_dim
        self.attri_dim = attri_dim
        self.output_dim = rule_dim + attri_dim # 15
        self.z_dim = z_dim
        # three un-padded kernel-3 convolutions shorten the sequence by 6 steps;
        # the lengths in the comments are for a 25-step input
        self.encoder_cnn = nn.Sequential(
            nn.Conv1d(rule_dim+attri_dim, 32, kernel_size=3), # [32, 23]
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=3), # [64, 21]
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=3), # [128, 19]
            nn.ReLU(),
        )
        self.encoder_fc = nn.Sequential(
            # NOTE(review): 128*19 fixes the encoder input to 25 steps regardless
            # of max_length; kept as is so existing checkpoints stay loadable
            nn.Linear(128*19, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, z_dim*2),
        )

        self.lstm_size = 128
        self.linear_in = nn.Linear(self.z_dim, self.lstm_size)
        self.linear_out_classify = nn.Linear(self.lstm_size, rule_dim)
        self.linear_out_pose = nn.Linear(self.lstm_size, 3)
        self.linear_out_color = nn.Linear(self.lstm_size, 3)
        self.rnn = nn.LSTM(self.lstm_size, self.lstm_size, batch_first=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.mse_loss = nn.MSELoss()  
        self.ce_loss = nn.CrossEntropyLoss()  

        print(COLOR.GREEN+'Model Info:')
        print('\tModel name:', 'Grammar VAE')
        print('\tDimension of z:', self.z_dim)
        print('\tSize of hidden layer:', self.hidden_size)
        print(COLOR.WHITE+'')

        grammar = '''
            W -> P
            W -> B
            P -> P E
            P -> P B
            P -> B
            P -> E
            B -> B
            B -> E
            E -> E
        '''

        self.GCFG = CFG.fromstring(grammar)

    @staticmethod
    def get_mask(nonterminal, grammar, as_variable=False):
        """Mark the productions of ``grammar`` whose left-hand side is ``nonterminal``.

        Args:
            nonterminal: the ``nltk.Nonterminal`` to expand.
            grammar: grammar whose ``productions()`` are scanned in order.
            as_variable: if True the mask is converted to a float tensor,
                otherwise it stays a list of bools.

        Returns:
            The mask passed through the project's ``CUDA`` helper.

        Raises:
            ValueError: if ``nonterminal`` is not an ``nltk.Nonterminal``.
        """
        if not isinstance(nonterminal, Nonterminal):
            raise ValueError('Input must be instance of nltk.Nonterminal')

        mask = [rule.lhs() == nonterminal for rule in grammar.productions()]
        if as_variable:
            mask = torch.FloatTensor(mask)
        # NOTE(review): with as_variable=False a plain list is handed to CUDA();
        # whether that helper accepts lists is not visible here - confirm
        return CUDA(mask)

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + std*eps`` with ``eps ~ N(0, I)``.

        Args:
            mu: mean of the approximate posterior.
            logvar: log-variance of the approximate posterior, same shape as ``mu``.

        Returns:
            A sample with the shape of ``mu``.
        """
        std = logvar.div(2).exp()
        # randn_like draws the noise with the dtype and on the device of std,
        # so no explicit device transfer is required
        eps = torch.randn_like(std)
        return mu+std*eps

    def _run_decoder(self, z):
        """Unroll the LSTM decoder; returns (logits_, pose_, color_), each [B, max_length, *]."""
        x_ = self.relu(self.linear_in(z))
        # the same projected code is fed to the LSTM at every step
        x_ = x_.unsqueeze(1).expand(-1, self.max_length, -1)
        # explicit zero initial state, created on the device of the input
        hx = (x_.new_zeros(1, x_.size(0), self.lstm_size), x_.new_zeros(1, x_.size(0), self.lstm_size))

        x_, _ = self.rnn(x_, hx)
        x_ = self.relu(x_)
        logits_ = self.linear_out_classify(x_) # [B, max_length, rule_dim]
        pose_ = self.linear_out_pose(x_) # [B, max_length, 3]
        color_ = torch.sigmoid(self.linear_out_color(x_)) # [B, max_length, 3]
        return logits_, pose_, color_

    def forward(self, x):
        """Encode ``x``, sample a latent code, decode and compute the losses.

        Args:
            x: tensor [B, rule_dim+attri_dim, T]. The encoder's first linear
                layer requires T == 25, and the attribute loss compares
                against ``max_length`` decoded steps.

        Returns:
            Tuple ``(recon_loss, kld, mu[None])`` where ``recon_loss`` is the
            cross-entropy of the rule prediction plus 10 times the attribute
            MSE, ``kld`` the mean of the element-wise KL term and ``mu[None]``
            the posterior mean with an extra leading dimension.
        """
        # preprocess the data x
        logits = x[:, 0:self.rule_dim, :]   # [B, rule_dim, T]
        _, y = logits.max(1) # The rule index
        y = y.view(-1)
        attribute = x[:, self.rule_dim:, :]  # [B, attri_dim, T]

        # encoder
        h = self.encoder_cnn(x)
        h = h.view(x.size(0), -1) # flatten
        stats = self.encoder_fc(h)
        mu = stats[:, 0:self.z_dim]
        logvar = stats[:, self.z_dim:]
        z = self.reparametrize(mu, logvar)

        # decoder
        logits_, pose_, color_ = self._run_decoder(z)
        logits_ = logits_.view(-1, logits_.size(-1)) # [B*max_length, rule_dim]
        attribute_ = torch.cat([pose_, color_], dim=2)
        attribute_ = attribute_.transpose(1, 2)  # [B, 6, max_length]

        # reconstruction error
        ce_loss = self.ce_loss(logits_, y) 
        mse_loss = self.mse_loss(attribute, attribute_)
        recon_loss = ce_loss + 10*mse_loss

        # KL divergence
        kld = -0.5*(1+logvar-mu.pow(2)-logvar.exp())
        kld = kld.mean()
        return recon_loss, kld, mu[None]

    @staticmethod
    def parse_grammar(grammar_str, attribute):
        """Assign ``attribute`` to the plane or the box list according to a rule.

        Args:
            grammar_str: production rule; only its left-hand side is inspected
                (the right-hand side may have multiple objects).
            attribute: attribute vector decoded at the step that chose the rule.

        Returns:
            Tuple ``(plane_attributes, box_attributes)`` of lists: ``attribute``
            is placed in the first list for 'P', in the second for 'B', and in
            neither for 'W' and 'E'.

        Raises:
            ValueError: if the left-hand side is none of 'W', 'E', 'P', 'B'.
        """
        node_type = str(grammar_str.lhs())
        if node_type in ('W', 'E'):
            return [], []
        if node_type == 'P':
            return [attribute], []
        if node_type == 'B':
            return [], [attribute]
        raise ValueError('Invalid node type - {}'.format(node_type))

    def decode(self, z, position_scale):
        """Decode a latent code into scene parameters.

        Args:
            z: latent code, [1, z_dim] (only the first batch entry is parsed).
            position_scale: factor applied to the first two pose values.

        Returns:
            The dict produced by ``decode_from_x``.
        """
        logits_, pose_, color_ = self._run_decoder(z)
        x_ = torch.cat([logits_, pose_, color_], dim=2)

        return self.decode_from_x(x_, position_scale)

    def decode_from_x(self, x_, position_scale):
        """Parse a decoded sequence with the grammar and collect scene parameters.

        Starting from the symbol 'W', the most likely production allowed for
        the symbol on top of the stack is applied at every step. 'E' symbols
        are dropped without consuming a step. Parsing stops when the stack is
        empty or after ``max_length`` steps.

        Args:
            x_: tensor [B, T, rule_dim+6]; only batch entry 0 is used.
            position_scale: factor applied to the first two pose values; the
                third pose value is multiplied by 2*pi.

        Returns:
            Dict with the keys 'plane_poses', 'plane_colors', 'box_poses' and
            'box_colors'. Each value is a tensor with a leading dimension of 1,
            or None if no object of that kind was generated.
        """
        logits_ = x_[0, :, 0:self.rule_dim]    # [T, rule_dim]
        attribute_ = x_[0, :, self.rule_dim:]  # [T, 6]

        stack = [Nonterminal('W')] # start from symbol W
        rules = []
        plate_attribute = []
        box_attribute = []
        t = 0
        while len(stack) > 0:
            alpha = stack.pop()
            if str(alpha) == 'E':
                continue
            mask = self.get_mask(alpha, self.GCFG, as_variable=True)
            # Only the argmax is needed, so it is taken over the raw logits
            # with the impossible rules set to -inf. Going through
            # mask*exp(logits)/sum yields NaN when a logit overflows (0*inf)
            # or when all allowed logits underflow (0/0).
            scores = logits_[t].masked_fill(mask == 0, float('-inf'))
            _, idx = scores.max(-1) # argmax

            # select rule idx
            rule = self.GCFG.productions()[CPU(idx)] 
            rules.append(rule)
            p_a, b_a = self.parse_grammar(rule, attribute_[t])
            plate_attribute += p_a
            box_attribute += b_a

            # add rhs nonterminals to stack in reversed order
            for symbol in reversed(rule.rhs()):
                if isinstance(symbol, Nonterminal):
                    stack.append(symbol)
            t += 1
            if t == self.max_length:
                break

        # denormalization
        scale = CUDA(torch.tensor([position_scale, position_scale, 2*np.pi]))
        scene_parameters = {}

        if len(plate_attribute) > 0:
            plate_attribute = torch.stack(plate_attribute, dim=0) # [N, 6]
            plate_attribute[:, 0:3] = plate_attribute[:, 0:3] * scale
            scene_parameters['plane_poses'] = plate_attribute[:, 0:3][None]
            scene_parameters['plane_colors'] = plate_attribute[:, 3:6][None]
        else:
            scene_parameters['plane_poses'] = None
            scene_parameters['plane_colors'] = None
        if len(box_attribute) > 0:
            box_attribute = torch.stack(box_attribute, dim=0)     # [M, 6]
            box_attribute[:, 0:3] = box_attribute[:, 0:3] * scale
            scene_parameters['box_poses'] = box_attribute[:, 0:3][None]
            scene_parameters['box_colors'] = box_attribute[:, 3:6][None]
        else:
            scene_parameters['box_poses'] = None
            scene_parameters['box_colors'] = None

        return scene_parameters

    def save_model(self, filename):
        """Save the full ``state_dict`` under the key 'model_state' to ``filename``."""
        state = {'model_state': self.state_dict()}
        torch.save(state, filename)

    def load_model(self, filename):
        """Restore the weights from a checkpoint written by ``save_model``.

        Raises:
            FileNotFoundError: if ``filename`` is not an existing file.
        """
        if not os.path.isfile(filename):
            raise FileNotFoundError('Checkpoint file not found: {}'.format(filename))

        # Deserialize onto the CPU so that a checkpoint written on a GPU machine
        # can also be restored on a CPU-only host; load_state_dict() copies the
        # values into the existing parameters on whatever device they live on.
        checkpoint = torch.load(filename, map_location='cpu')
        self.load_state_dict(checkpoint['model_state'])
