'''
Author: 
Email: 
Date: 2020-10-12 22:14:36
LastEditTime: 2021-05-30 18:37:55
Description: 
    The model is a tree structure that uses the road and lane conditions.
'''

import os
import numpy as np
import copy
from matplotlib import pyplot as plt

import torch
from torch import nn
from torch.autograd import Variable

from structure import TreeParser
from utils import CUDA, CPU


# Node-type labels. The integer values double as column indices into the
# node classifier's logits (LANE/ROAD columns are masked during decoding).
VEHICLE, QUAD, EMPTY, LANE, ROAD = range(5)


class SamplerEncoder(nn.Module):
    """Map a root feature to a reparameterized latent sample plus its KL terms.

    The output concatenates ``z`` and the per-dimension KL divergence
    ``-0.5 * (1 + logvar - mu^2 - exp(logvar))`` along dim 1, so it has
    ``2 * z_size`` columns; the caller splits it back with ``torch.chunk``.
    """

    def __init__(self, z_size, feature_size, hidden_size):
        super(SamplerEncoder, self).__init__()
        self.mlp1 = nn.Linear(feature_size, hidden_size)
        self.mlp2mu = nn.Linear(hidden_size, z_size)
        self.mlp2var = nn.Linear(hidden_size, z_size)
        self.tanh = nn.Tanh()

    def forward(self, root_input):
        """Return ``cat([z, kld_element], dim=1)`` for a batch of root features."""
        encode = self.tanh(self.mlp1(root_input))
        mu = self.mlp2mu(encode)
        logvar = self.mlp2var(encode)

        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
        std = logvar.div(2).exp()
        # randn_like allocates on std's own device and dtype, so the noise always
        # matches wherever the model lives (no deprecated Variable/.data and no
        # separate device move that could disagree with std).
        eps = torch.randn_like(std)
        z = mu + std * eps

        kld_element = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
        return torch.cat([z, kld_element], dim=1)


class VehicleEncoder(nn.Module):
    """Encode a VEHICLE leaf from its orientation and its 2-D alpha offset."""

    def __init__(self, input_size, feature_size, hidden_size):
        super(VehicleEncoder, self).__init__()
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, feature_size),
            nn.Tanh(),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, vehicle, alpha):
        # Column 2 (orientation) is the only vehicle property used; the lane
        # direction is not considered for now.
        orientation = vehicle[:, 2:3]
        return self.encoder(torch.cat((orientation, alpha), dim=1))


class EmptyEncoder(nn.Module):
    """Parameter-free encoder for EMPTY leaves: always a 1 x feature_size zero row."""

    def __init__(self, feature_size):
        super(EmptyEncoder, self).__init__()
        self.feature_size = feature_size

    def forward(self):
        zero_feature = torch.zeros(1, self.feature_size)
        # CUDA() is the project device helper imported from utils
        # (presumably moves the tensor to the GPU when one is available).
        return CUDA(zero_feature)


class QuadEncoder(nn.Module):
    """Encode a QUAD node from its own vehicle, its 2-D alpha and four child features.

    The node's own properties (orientation + 2-D alpha, ``input_size`` wide)
    and each child feature are projected to ``hidden_size`` by separate MLPs,
    summed, squashed with tanh and projected to ``feature_size``.
    """

    def __init__(self, input_size, feature_size, hidden_size):
        super(QuadEncoder, self).__init__()
        self.common = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
        )

        # One projection MLP per child slot (0..3).
        self.children_list = nn.ModuleList()
        for _ in range(4):
            one_child = nn.Sequential(
                nn.Linear(feature_size, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, hidden_size),
            )
            self.children_list.append(one_child)

        self.tanh = nn.Tanh()
        self.output = nn.Sequential(
            nn.Linear(hidden_size, feature_size),
            nn.Tanh()
        )

    def forward(self, child_0, child_1, child_2, child_3, vehicle, alpha):
        # Like VehicleEncoder, the only vehicle property is the orientation
        # (column 2). Slicing 2:3 keeps the input width equal to input_size
        # (1 orientation + 2 alpha) even when vehicle rows carry extra columns.
        vehicle_pro = vehicle[:, 2:3]
        feature = torch.cat([vehicle_pro, alpha], dim=1)
        x = self.common(feature)
        children = (child_0, child_1, child_2, child_3)
        for child_net, child in zip(self.children_list, children):
            x = x + child_net(child)
        x = self.tanh(x)
        return self.output(x)


class LaneEncoder(nn.Module):
    """Encode a LANE node from its 1-D split ratio and its two child features."""

    def __init__(self, input_size, feature_size, hidden_size):
        super(LaneEncoder, self).__init__()
        self.alpha_net = nn.Sequential(nn.Linear(input_size, hidden_size))

        # Projection MLPs for the left (0) and right (1) child features.
        self.children_list = nn.ModuleList(
            nn.Sequential(
                nn.Linear(feature_size, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, hidden_size),
            )
            for _ in range(2)
        )

        self.tanh = nn.Tanh()
        self.output = nn.Sequential(nn.Linear(hidden_size, feature_size), nn.Tanh())

    def forward(self, child_0, child_1, alpha):
        hidden = self.alpha_net(alpha)
        hidden = hidden + self.children_list[0](child_0) + self.children_list[1](child_1)
        return self.output(self.tanh(hidden))


class RoadEncoder(nn.Module):
    """Encode the ROAD node from its absolute xywh box and its single child feature."""

    def __init__(self, input_size, feature_size, hidden_size):
        super(RoadEncoder, self).__init__()
        self.road_center = nn.Sequential(
            nn.Linear(input_size, hidden_size), nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.child_net = nn.Sequential(
            nn.Linear(feature_size, hidden_size), nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.tanh = nn.Tanh()
        self.output = nn.Sequential(nn.Linear(hidden_size, feature_size), nn.Tanh())

    def forward(self, child_0, xywh):
        summed = self.road_center(xywh) + self.child_net(child_0)
        return self.output(self.tanh(summed))


class ScenarioEncoder(nn.Module):
    """Bundle of the per-node-type encoders used to fold a scene tree bottom-up.

    The camelCase methods are thin entry points over the sub-modules; they are
    presumably looked up by name by ``TreeParser`` (not visible here).
    """

    def __init__(self, z_size, feature_size, hidden_size):
        super(ScenarioEncoder, self).__init__()
        vehicle_prop_size = 1  # orientation only
        alpha_xy_size = 2      # 2-D split ratio of a VEHICLE/QUAD cell
        alpha_x_size = 1       # 1-D split ratio of a LANE cell
        xywh_size = 4          # absolute ROAD box
        leaf_input_size = vehicle_prop_size + alpha_xy_size

        self.vehicle_encoder = VehicleEncoder(input_size=leaf_input_size, feature_size=feature_size, hidden_size=hidden_size)
        self.quad_encoder = QuadEncoder(input_size=leaf_input_size, feature_size=feature_size, hidden_size=hidden_size)
        self.empty_encoder = EmptyEncoder(feature_size=feature_size)
        self.lane_encoder = LaneEncoder(input_size=alpha_x_size, feature_size=feature_size, hidden_size=hidden_size)
        self.road_encoder = RoadEncoder(input_size=xywh_size, feature_size=feature_size, hidden_size=hidden_size)
        self.sampler_encoder = SamplerEncoder(z_size=z_size, feature_size=feature_size, hidden_size=hidden_size)

    def vehicleEncoder(self, vehicle, alpha):
        encoder = self.vehicle_encoder
        return encoder(vehicle, alpha)

    def emptyEncoder(self):
        encoder = self.empty_encoder
        return encoder()

    def quadEncoder(self, child_0, child_1, child_2, child_3, vehicle, alpha):
        encoder = self.quad_encoder
        return encoder(child_0, child_1, child_2, child_3, vehicle, alpha)

    def laneEncoder(self, child_0, child_1, alpha):
        encoder = self.lane_encoder
        return encoder(child_0, child_1, alpha)

    def roadEncoder(self, child_0, xywh):
        encoder = self.road_encoder
        return encoder(child_0, xywh)

    def samplerEncoder(self, child_0):
        encoder = self.sampler_encoder
        return encoder(child_0)


class SamplerDecoder(nn.Module):
    """Map a latent vector (e.g. sampled noise) back to a tanh-bounded root feature."""

    def __init__(self, z_size, feature_size, hidden_size):
        super(SamplerDecoder, self).__init__()
        to_hidden = nn.Linear(z_size, hidden_size)
        to_feature = nn.Linear(hidden_size, feature_size)
        self.decoder = nn.Sequential(to_hidden, nn.Tanh(), to_feature, nn.Tanh())

    def forward(self, input_feature):
        return self.decoder(input_feature)


class VehicleDecoder(nn.Module):
    """Decode a VEHICLE feature into (vehicle properties, 2-D alpha).

    ``input_size[0]`` is the property width and ``input_size[1]`` the alpha
    width; alpha is tanh-bounded to [-1, 1].
    """

    def __init__(self, feature_size, input_size, hidden_size):
        super(VehicleDecoder, self).__init__()
        prop_size = input_size[0]
        alpha_size = input_size[1]
        self.vehicle_decoder = nn.Sequential(
            nn.Linear(feature_size, hidden_size), nn.Tanh(),
            nn.Linear(hidden_size, prop_size),
        )
        self.alpha_decoder = nn.Sequential(nn.Linear(feature_size, alpha_size), nn.Tanh())

    def forward(self, parent_feature):
        return self.vehicle_decoder(parent_feature), self.alpha_decoder(parent_feature)


class QuadDecoder(nn.Module):
    """Decode a QUAD feature into four child features, its vehicle properties and 2-D alpha."""

    def __init__(self, feature_size, hidden_size, input_size):
        super(QuadDecoder, self).__init__()
        self.common = nn.Sequential(
            nn.Linear(feature_size, hidden_size), nn.Tanh(),
            nn.Linear(hidden_size, hidden_size), nn.Tanh(),
        )
        # One head per child slot (0..3).
        self.children_list = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden_size, feature_size), nn.Tanh())
            for _ in range(4)
        )
        self.vehicle_decoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.Tanh(),
            nn.Linear(hidden_size, input_size[0]),
        )
        # alpha is tanh-bounded to [-1, 1]
        self.alpha_decoder = nn.Sequential(nn.Linear(hidden_size, input_size[1]), nn.Tanh())

    def forward(self, parent_feature):
        shared = self.common(parent_feature)
        children = [head(shared) for head in self.children_list]
        vehicle_prop = self.vehicle_decoder(shared)
        alpha = self.alpha_decoder(shared)
        return (*children, vehicle_prop, alpha)


class LaneDecoder(nn.Module):
    """Decode a LANE feature into its two child features and its 1-D split ratio."""

    def __init__(self, feature_size, hidden_size, input_size):
        super(LaneDecoder, self).__init__()
        self.common = nn.Sequential(
            nn.Linear(feature_size, hidden_size), nn.Tanh(),
            nn.Linear(hidden_size, hidden_size), nn.Tanh(),
        )
        # Heads for the left (0) and right (1) children.
        self.children_list = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden_size, feature_size), nn.Tanh())
            for _ in range(2)
        )
        self.alpha_net = nn.Sequential(nn.Linear(hidden_size, input_size), nn.Tanh())

    def forward(self, parent_feature):
        shared = self.common(parent_feature)
        left, right = (head(shared) for head in self.children_list)
        return left, right, self.alpha_net(shared)


class RoadDecoder(nn.Module):
    """Decode a ROAD feature into its single child feature and an (unbounded) xywh box."""

    def __init__(self, feature_size, hidden_size, input_size):
        super(RoadDecoder, self).__init__()
        self.common = nn.Sequential(
            nn.Linear(feature_size, hidden_size), nn.Tanh(),
            nn.Linear(hidden_size, hidden_size), nn.Tanh(),
        )
        self.child_net = nn.Sequential(nn.Linear(hidden_size, feature_size), nn.Tanh())
        self.xywh_net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.Tanh(),
            nn.Linear(hidden_size, input_size),
        )

    def forward(self, parent_feature):
        shared = self.common(parent_feature)
        return self.child_net(shared), self.xywh_net(shared)


class NodeClassifier(nn.Module):
    """Three-layer MLP returning raw (un-normalized) node-type logits."""

    def __init__(self, feature_size, hidden_size, class_size=1):
        super(NodeClassifier, self).__init__()
        widths = [feature_size, hidden_size, hidden_size]
        layers = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(width_in, width_out), nn.Tanh()]
        layers.append(nn.Linear(hidden_size, class_size))
        self.classifier = nn.Sequential(*layers)

    def forward(self, input_feature):
        return self.classifier(input_feature)


class ScenarioDecoder(nn.Module):
    """Per-node-type decoders, the node-type classifier and the per-sample loss terms.

    The camelCase wrappers and ``*LossEstimator`` methods are thin entry points
    over the sub-modules; they are presumably looked up by name by
    ``TreeParser`` during decoding (not visible here).
    """

    def __init__(self, z_size, feature_size, hidden_size, class_size):
        super(ScenarioDecoder, self).__init__()
        vehicle_prop_size = 1  # orientation only
        alpha_xy_size = 2      # 2-D split ratio of a VEHICLE/QUAD cell
        alpha_x_size = 1       # 1-D split ratio of a LANE cell
        xywh_size = 4          # absolute ROAD box
        leaf_sizes = [vehicle_prop_size, alpha_xy_size]

        self.vehicle_decoder = VehicleDecoder(feature_size=feature_size, hidden_size=hidden_size, input_size=leaf_sizes)
        self.quad_decoder = QuadDecoder(feature_size=feature_size, hidden_size=hidden_size, input_size=leaf_sizes)
        self.lane_decoder = LaneDecoder(feature_size=feature_size, hidden_size=hidden_size, input_size=alpha_x_size)
        self.road_decoder = RoadDecoder(feature_size=feature_size, hidden_size=hidden_size, input_size=xywh_size)
        self.sampler_decoder = SamplerDecoder(z_size=z_size, feature_size=feature_size, hidden_size=hidden_size)
        self.node_classifier = NodeClassifier(feature_size=feature_size, hidden_size=hidden_size, class_size=class_size)

        # Element-wise criteria; each estimator reduces per sample itself.
        self.mse_loss = nn.MSELoss(reduction='none')
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def vehicleDecoder(self, parent):
        module = self.vehicle_decoder
        return module(parent)

    def quadDecoder(self, parent):
        module = self.quad_decoder
        return module(parent)

    def laneDecoder(self, parent):
        module = self.lane_decoder
        return module(parent)

    def roadDecoder(self, parent):
        module = self.road_decoder
        return module(parent)

    def samplerDecoder(self, parent):
        module = self.sampler_decoder
        return module(parent)

    def nodeClassifier(self, parent):
        module = self.node_classifier
        return module(parent)

    def _mse_per_sample(self, predict, target, normalizer):
        """Element-wise MSE averaged over dim 1, divided by ``normalizer[0]``."""
        squared = self.mse_loss(predict, target)
        return squared.mean(dim=1) / normalizer[0]

    # The MSE estimators are kept as separate entry points (one per property
    # kind) even though they share the same computation.
    def vehicleLossEstimator(self, predict, gt, normalizer):
        """Property loss for VEHICLE: only the orientation column (2) of gt is supervised."""
        return self._mse_per_sample(predict, gt[:, 2:3], normalizer)

    def alpha1DLossEstimator(self, predict, gt, normalizer):
        """1-D alpha loss for LANE nodes."""
        return self._mse_per_sample(predict, gt, normalizer)

    def alpha2DLossEstimator(self, predict, gt, normalizer):
        """2-D alpha loss for VEHICLE and QUAD nodes."""
        return self._mse_per_sample(predict, gt, normalizer)

    def xywhLossEstimator(self, predict, gt, normalizer):
        """xywh loss for the ROAD node."""
        return self._mse_per_sample(predict, gt, normalizer)

    def ceLossEstimator(self, predict, gt, normalizer):
        """Node-type classification loss; the class index is column 0 of gt."""
        labels = gt[:, 0]
        return self.ce_loss(predict, labels) / normalizer[0]

    def vectorAdder(self, v1, v2):
        # In place: v1 is mutated and returned.
        return v1.add_(v2)

    def vectorDivider(self, nom, den):
        quotient = nom / den
        return quotient


class ConditionalSceneVAE(nn.Module):
    """Tree-structured scene VAE conditioned on road and lane layout.

    Only the ROAD node carries absolute information (its xywh box); every other
    node is placed relative to its parent cell through split ratios ("alpha"
    values, tanh-bounded to [-1, 1] by the decoders).
    """

    def __init__(self, z_dim):
        super(ConditionalSceneVAE, self).__init__()
        self.feature_size = 64
        self.hidden_size = 64
        self.z_dim = z_dim
        self.class_size = 5  # VEHICLE, QUAD, EMPTY, LANE, ROAD

        self.encoder = ScenarioEncoder(z_size=self.z_dim, feature_size=self.feature_size, hidden_size=self.hidden_size)
        self.decoder = ScenarioDecoder(z_size=self.z_dim, feature_size=self.feature_size, hidden_size=self.hidden_size, class_size=self.class_size)

        # Lane-split trees keyed by the number of lanes. Each node splits its
        # parent box along x at ratio 'ratio' (see divide_1D) and recurses into
        # 'left' / 'right'; a None subtree is a single lane. The ratios make
        # every entry produce equal-width lanes, e.g. for 3 lanes ratio 1/3
        # gives the left child 2/3 of the width, which ratio 0 then halves.
        # A top-level 'ratio' of None means the whole road is one lane.
        self.lane_tree = {
            1: {'ratio': None},
            2: {'ratio': 0.0, 'left': None, 'right': None},
            3: {'ratio': 1.0/3.0, 'left': {'ratio': 0.0, 'left': None, 'right': None}, 'right': None},
            4: {'ratio': 0.0, 'left': {'ratio': 0.0, 'left': None, 'right': None}, 'right': {'ratio': 0.0, 'left': None, 'right': None}},
            5: {'ratio': 1.0/5.0,
                    'left': {'ratio': 1.0/3.0, 'left': {'ratio': 0.0, 'left': None, 'right': None}, 'right': None},
                    'right': {'ratio': 0.0, 'left': None, 'right': None}},
            6: {'ratio': 0.0,
                    'left': {'ratio': 1.0/3.0, 'left': {'ratio': 0.0, 'left': None, 'right': None}, 'right': None},
                    'right': {'ratio': 1.0/3.0, 'left': {'ratio': 0.0, 'left': None, 'right': None}, 'right': None}}}

    def forward(self, batch):
        """Encode and reconstruct a batch of scene trees.

        Args:
            batch: iterable of scene trees accepted by ``TreeParser.parse``.

        Returns:
            (recon_loss, kldiv_loss, z_list): mean reconstruction loss, mean KL
            term, and the per-example latent samples stacked to [B, 1, Z].
        """
        # Initialize encode and decode parsers
        enc_parser = TreeParser(type='encode')
        dec_parser = TreeParser(type='decode')

        enc_parser_nodes = []  # list of fold nodes for encoding
        dec_parser_nodes = []  # list of fold nodes for decoding
        kld_parser_nodes = []
        z_list = []

        # Collect computation nodes recursively from encoding process
        for data in batch:
            enc_parser_nodes.append(enc_parser.parse(data))

        # apply the tree on the encoder model
        enc_parser_nodes = enc_parser.apply(self.encoder, [enc_parser_nodes])

        # split into a list per data
        enc_parser_nodes = torch.split(enc_parser_nodes[0], 1, 0)

        # collect computation nodes recursively from decoding process; each
        # encoder output row is cat([z, kld]) (see SamplerEncoder)
        for data, fnode in zip(batch, enc_parser_nodes):
            z, kl_div = torch.chunk(fnode, 2, 1)
            dec_parser_nodes.append(dec_parser.parse(data, z))
            z_list.append(z)
            kld_parser_nodes.append(kl_div)

        # apply the tree on the decoder model
        recon_loss = dec_parser.apply(self.decoder, [dec_parser_nodes])
        recon_loss = recon_loss[0].mean()    # avg. reconstruction loss per example

        kldiv_loss = torch.cat(kld_parser_nodes, dim=0).mean()  # [B, Z] -> [1]
        z_list = torch.stack(z_list)

        return recon_loss, kldiv_loss, z_list

    @staticmethod
    def left_or_bottom(alpha, x_or_y, w_or_h):
        """Return (centre, extent) of the lower part of an interval split at x + alpha*w/2.

        ``x_or_y`` is the interval centre and ``w_or_h`` its full extent.
        """
        new_x_or_y = x_or_y - w_or_h*(1-alpha)/4
        new_w_or_h = w_or_h*(1+alpha)/2
        return new_x_or_y, new_w_or_h

    @staticmethod
    def right_or_top(alpha, x_or_y, w_or_h):
        """Return (centre, extent) of the upper part of an interval split at x + alpha*w/2."""
        new_x_or_y = x_or_y + w_or_h*(1+alpha)/4
        new_w_or_h = w_or_h*(1-alpha)/2
        return new_x_or_y, new_w_or_h

    def divide_1D(self, road_x, road_y, road_w, road_h, alpha_x=None):
        """Split a box along x only; y and h are copied from the parent.

        Each coordinate is a 1-element array; returns (left, right) as [4,]
        arrays [x, y, w, h]. ``alpha_x`` is required despite its None default.
        """
        # x and w
        left_x, left_w = self.left_or_bottom(alpha_x, road_x, road_w)
        right_x, right_w = self.right_or_top(alpha_x, road_x, road_w)
        # y and h
        left_y = copy.deepcopy(road_y)
        left_h = copy.deepcopy(road_h)
        right_y = copy.deepcopy(road_y)
        right_h = copy.deepcopy(road_h)
        # collect the new coordinate
        left = np.concatenate([left_x, left_y, left_w, left_h], axis=0)        # [4,]
        right = np.concatenate([right_x, right_y, right_w, right_h], axis=0)   # [4,]
        return left, right

    def divide_2D(self, road_x, road_y, road_w, road_h, alpha_x=None, alpha_y=None):
        """Split a box into four quadrants at (x + alpha_x*w/2, y + alpha_y*h/2).

        Returns (left_top, right_top, left_bottom, right_bottom) as [4,] arrays
        [x, y, w, h]. Both alphas are required despite their None defaults.
        """
        # x and w
        left_x, left_w = self.left_or_bottom(alpha_x, road_x, road_w)
        right_x, right_w = self.right_or_top(alpha_x, road_x, road_w)

        # y and h
        bottom_y, bottom_h = self.left_or_bottom(alpha_y, road_y, road_h)
        top_y, top_h = self.right_or_top(alpha_y, road_y, road_h)

        # collect the new coordinate
        left_top = np.concatenate([left_x, top_y, left_w, top_h], axis=0)             # [4,]
        right_top = np.concatenate([right_x, top_y, right_w, top_h], axis=0)          # [4,]
        left_bottom = np.concatenate([left_x, bottom_y, left_w, bottom_h], axis=0)    # [4,]
        right_bottom = np.concatenate([right_x, bottom_y, right_w, bottom_h], axis=0) # [4,]

        return left_top, right_top, left_bottom, right_bottom

    def _plot_vehicle(self, vehicle_prop, alpha, xywh, direction):
        """Place a decoded vehicle inside ``xywh``, draw it, and return ``alpha`` on the CPU.

        The centre is (x + alpha_x*w/2, y + alpha_y*h/2) and the heading is
        ``vehicle_prop[0] * pi``; the vehicle is drawn as a 2 x 4 box plus an arrow,
        green when ``direction`` is truthy and red otherwise.
        """
        vehicle_prop = CPU(vehicle_prop)[0]
        alpha = CPU(alpha)[0]
        theta = vehicle_prop[0:1]*np.pi
        x = xywh[0:1] + alpha[0]*xywh[2:3]/2
        y = xywh[1:2] + alpha[1]*xywh[3:4]/2
        pose = np.concatenate([x, y, theta], axis=0)
        color = 'g' if direction else 'r'
        self.plot_box(self.get_box(pose, 2, 4), color)
        self.plot_arrow(pose, color, direction)
        return alpha

    def add_node_conditional(self, feature, xywh, conditional_tree, direction):
        """Recursively decode ``feature`` inside the box ``xywh`` and plot the result.

        While ``conditional_tree`` is not None the node is forced to be a LANE
        split with the tree's ratio. Once the tree is exhausted, the node type
        is predicted by the classifier restricted to VEHICLE, QUAD or EMPTY.

        Args:
            feature: decoder feature for this node (presumably a single row).
            xywh: numpy array [x, y, w, h] of the node's box.
            conditional_tree: lane-split subtree from ``self.lane_tree`` or None.
            direction: truthy -> green / forward arrows, falsy -> red / backward.

        Returns:
            'E' when an EMPTY leaf is decoded, otherwise None.
        """
        # LANE node is defined by a tree
        if conditional_tree is not None:
            alpha = conditional_tree['ratio']
            if alpha is None:
                # Single lane: decode this cell as an ordinary (non-lane) node.
                self.add_node_conditional(feature, xywh, None, direction)
            else:
                child_0_f, child_1_f, _ = self.decoder.lane_decoder(feature)
                # Lane separator: vertical line at the split position x + alpha*w/2.
                separate_line_x = xywh[0:1] + alpha*xywh[2:3]/2
                separate_line_y_1 = xywh[1:2]-xywh[3:4]/2
                separate_line_y_2 = xywh[1:2]+xywh[3:4]/2
                separate_line = np.array([
                    [separate_line_x, separate_line_y_1],
                    [separate_line_x, separate_line_y_2]
                ])
                plt.plot(separate_line[:, 0], separate_line[:, 1], 'k-', linewidth=1)
                left, right = self.divide_1D(xywh[0:1], xywh[1:2], xywh[2:3], xywh[3:4], alpha)
                self.add_node_conditional(child_0_f, left, conditional_tree['left'], direction)
                self.add_node_conditional(child_1_f, right, conditional_tree['right'], direction)
            return None

        # Only choose from QUAD, VEHICLE, EMPTY: mask out the LANE and ROAD logits.
        label_prob = self.decoder.nodeClassifier(feature)
        label_prob[:, LANE] = -9999.9
        label_prob[:, ROAD] = -9999.9
        label_prob = torch.softmax(label_prob, dim=-1)
        label = torch.argmax(label_prob)

        if label == VEHICLE:
            vehicle_prop, alpha = self.decoder.vehicle_decoder(feature)
            self._plot_vehicle(vehicle_prop, alpha, xywh, direction)
        elif label == QUAD:
            child_0_f, child_1_f, child_2_f, child_3_f, vehicle_prop, alpha = self.decoder.quad_decoder(feature)
            alpha = self._plot_vehicle(vehicle_prop, alpha, xywh, direction)
            left_top, right_top, left_bottom, right_bottom = self.divide_2D(xywh[0:1], xywh[1:2], xywh[2:3], xywh[3:4], alpha[0], alpha[1])
            # vertical separator (constant x) at x + alpha_x*w/2
            separate_line_x = xywh[0:1] + alpha[0]*xywh[2:3]/2
            separate_line_y_1 = xywh[1:2] - xywh[3:4]/2
            separate_line_y_2 = xywh[1:2] + xywh[3:4]/2
            separate_line_ver = np.array([[separate_line_x, separate_line_y_1], [separate_line_x, separate_line_y_2]])
            plt.plot(separate_line_ver[:, 0], separate_line_ver[:, 1], c='silver', zorder=0)
            # horizontal separator (constant y) at y + alpha_y*h/2
            separate_line_y = xywh[1:2] + alpha[1]*xywh[3:4]/2
            separate_line_x_1 = xywh[0:1] - xywh[2:3]/2
            separate_line_x_2 = xywh[0:1] + xywh[2:3]/2
            separate_line_hor = np.array([[separate_line_x_1, separate_line_y], [separate_line_x_2, separate_line_y]])
            plt.plot(separate_line_hor[:, 0], separate_line_hor[:, 1], c='silver', zorder=0)
            self.add_node_conditional(child_0_f, left_top, None, direction)
            self.add_node_conditional(child_1_f, right_top, None, direction)
            self.add_node_conditional(child_2_f, left_bottom, None, direction)
            self.add_node_conditional(child_3_f, right_bottom, None, direction)
        elif label == EMPTY:
            return 'E'
        return None

    def decode(self, idx, feature, position_scale, condition=None):
        """Decode a latent ``feature`` into a scene and plot it.

        The first node is always a ROAD. With a ``condition`` the road box and
        lane layout come from it; without one, the decoded road box is scaled
        by ``position_scale`` and handed to ``self.add_node``.

        Args:
            idx: unused.
            feature: latent code fed to the sampler decoder.
            position_scale: multiplier for the decoded road xywh (unconditioned path).
            condition: dict with 'num_lane' (key of ``self.lane_tree``),
                'xywh' (road box) and 'direction'.

        Raises:
            NotImplementedError: if ``condition`` is None and no ``add_node``
                method is available (this class does not define one).
        """
        if condition is None and not hasattr(self, 'add_node'):
            # Fail up front with a clear message instead of an AttributeError
            # after the decoder has already run.
            raise NotImplementedError(
                "unconditioned decoding requires an 'add_node' method; "
                "pass a condition with 'num_lane', 'xywh' and 'direction'")

        feature = self.decoder.samplerDecoder(feature)
        road_feature, xywh = self.decoder.road_decoder(feature)

        # use conditional road segment information or sampled segments
        if condition is not None:
            num_lane = condition['num_lane']
            road_xywh = condition['xywh']
            direction = condition['direction']
            conditional_tree = self.lane_tree[num_lane]
            self.add_node_conditional(road_feature, road_xywh, conditional_tree, direction)
        else:
            road_xywh = CPU(xywh)[0]
            road_xywh[0:2] = road_xywh[0:2] * position_scale
            road_xywh[2:4] = road_xywh[2:4] * position_scale
            self.add_node(road_feature, road_xywh)

        self.plot_xywh(road_xywh)

    def save_model(self, filename):
        """Save encoder and decoder state dicts to ``filename``."""
        state = {'encoder': self.encoder.state_dict(), 'decoder': self.decoder.state_dict()}
        torch.save(state, filename)

    def load_model(self, filename):
        """Load encoder and decoder state dicts written by ``save_model``.

        Raises:
            FileNotFoundError: if ``filename`` is not an existing file.
        """
        if not os.path.isfile(filename):
            raise FileNotFoundError(filename)
        checkpoint = torch.load(filename)
        self.encoder.load_state_dict(checkpoint['encoder'])
        self.decoder.load_state_dict(checkpoint['decoder'])

    @staticmethod
    def get_box(pose, width, height):
        """Return the closed outline (5 points, last == first) of a width x height box
        rotated by ``pose[2]`` and centred at ``pose[0:2]``."""
        xy = pose[0:2]
        o = pose[2]
        position = np.array([[-width/2, -height/2], [-width/2, height/2], [width/2, height/2], [width/2, -height/2], [-width/2, -height/2]])
        rotation_matrix = np.array([[np.cos(o), -np.sin(o)], [np.sin(o), np.cos(o)]])
        position = rotation_matrix.dot(position.T).T + xy
        return position

    @staticmethod
    def plot_box(points, color):
        """Draw a polyline through ``points`` (N x 2) with a solid line of ``color``."""
        plt.plot(points[:, 0], points[:, 1], color+'-')

    @staticmethod
    def plot_arrow(pose, color, direction):
        """Draw a short arrow at ``pose[0:2]`` along the rotated +y axis (or -y if not direction)."""
        scale = 0.5
        xy = pose[0:2]
        o = pose[2]
        if direction:
            position = scale*np.array([[0.0, 1.0]]).T
        else:
            position = scale*np.array([[0.0, -1.0]]).T
        rotation_matrix = np.array([[np.cos(o), -np.sin(o)], [np.sin(o), np.cos(o)]])
        position = rotation_matrix.dot(position)[:, 0]
        plt.arrow(xy[0], xy[1], position[0], position[1], color=color, width=0.3)

    @staticmethod
    def plot_xywh(xywh):
        """Draw the axis-aligned road box [x, y, w, h] as a thick black outline."""
        road_x, road_y, road_w, road_h = xywh[0:1], xywh[1:2], xywh[2:3], xywh[3:4]
        road_show = np.array([
            [road_x-road_w/2, road_y-road_h/2],
            [road_x-road_w/2, road_y+road_h/2],
            [road_x+road_w/2, road_y+road_h/2],
            [road_x+road_w/2, road_y-road_h/2],
            [road_x-road_w/2, road_y-road_h/2]
        ])
        # NOTE(review): original note said "table and line should also be
        # rotated" — the road box is currently drawn axis-aligned.
        plt.plot(road_show[:, 0], road_show[:, 1], 'k-', linewidth=2)
