'''
Author: 
Email: 
Date: 2020-10-12 22:14:36
LastEditTime: 2021-05-30 19:46:14
Description: 
    This file implements the decoder part of T-VAE, since only decoding of the latent code is needed here.
'''

import os
import numpy as np
import copy

import torch
from torch import nn

from matplotlib import pyplot as plt
from utils import CPU


# Node-type labels of the scenario tree.
# The integer values are used directly as column indices into the output of
# the node classifier (see TreeVAE.add_node_conditional), so they must match
# the class ordering the classifier produces -- presumably the ordering used
# when the checkpoint was trained (TODO confirm against the training code).
VEHICLE = 0
QUAD = 1
EMPTY = 2
LANE = 3
ROAD = 4


class SamplerDecoder(nn.Module):
    """Turn a sampled latent code into a feature vector.

    The network is Linear -> Tanh -> Linear -> Tanh, so every component of
    the returned feature is squashed by the final Tanh.

    Args:
        z_size: size of the last dimension of the latent code.
        feature_size: size of the last dimension of the produced feature.
        hidden_size: width of the single hidden layer.
    """

    def __init__(self, z_size, feature_size, hidden_size):
        super().__init__()
        # The attribute name and layer positions define the state_dict keys
        # ("decoder.0.*", "decoder.2.*"); keep them stable for checkpoints.
        layers = [
            nn.Linear(z_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, feature_size),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, input_feature):
        """Return the decoded feature for ``input_feature``."""
        return self.decoder(input_feature)


class VehicleDecoder(nn.Module):
    """Predict a vehicle's property vector and its relative offset (alpha).

    Args:
        feature_size: size of the incoming node feature.
        input_size: two-element sequence ``[property_size, alpha_size]``
            giving the widths of the two outputs.
        hidden_size: hidden width of the property head.
    """

    def __init__(self, feature_size, input_size, hidden_size):
        super().__init__()
        prop_size, alpha_size = input_size[0], input_size[1]

        # Property head: one hidden layer, linear (unbounded) output.
        prop_layers = [
            nn.Linear(feature_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, prop_size),
        ]
        self.vehicle_decoder = nn.Sequential(*prop_layers)

        # Alpha head: a single layer; the Tanh keeps alpha inside [-1, 1].
        alpha_layers = [
            nn.Linear(feature_size, alpha_size),
            nn.Tanh(),
        ]
        self.alpha_decoder = nn.Sequential(*alpha_layers)

    def forward(self, parent_feature):
        """Return ``(vehicle_prop, alpha)`` decoded from ``parent_feature``."""
        return (
            self.vehicle_decoder(parent_feature),
            self.alpha_decoder(parent_feature),
        )


class QuadDecoder(nn.Module):
    """Decode a QUAD node into four child features plus a vehicle pose.

    A shared two-layer trunk feeds six heads: four child-feature heads, a
    vehicle-property head and an alpha head.

    Args:
        feature_size: size of the incoming feature and of each child feature.
        hidden_size: width of the shared trunk and of the property head.
        input_size: two-element sequence ``[property_size, alpha_size]``.
    """

    NUM_CHILDREN = 4

    def __init__(self, feature_size, hidden_size, input_size):
        super().__init__()
        prop_size, alpha_size = input_size[0], input_size[1]

        self.common = nn.Sequential(
            nn.Linear(feature_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
        )

        # One independent head per child quadrant.
        child_heads = [
            nn.Sequential(nn.Linear(hidden_size, feature_size), nn.Tanh())
            for _ in range(self.NUM_CHILDREN)
        ]
        self.children_list = nn.ModuleList(child_heads)

        self.vehicle_decoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, prop_size),
        )

        # The Tanh keeps alpha inside [-1, 1].
        self.alpha_decoder = nn.Sequential(
            nn.Linear(hidden_size, alpha_size),
            nn.Tanh(),
        )

    def forward(self, parent_feature):
        """Return ``(child_0, child_1, child_2, child_3, vehicle_prop, alpha)``."""
        shared = self.common(parent_feature)
        child_features = [head(shared) for head in self.children_list]
        vehicle_prop = self.vehicle_decoder(shared)
        alpha = self.alpha_decoder(shared)
        return (*child_features, vehicle_prop, alpha)


class LaneDecoder(nn.Module):
    """Decode a LANE node into two child features and a split ratio alpha.

    Args:
        feature_size: size of the incoming feature and of each child feature.
        hidden_size: width of the shared two-layer trunk.
        input_size: width of the alpha output.
    """

    NUM_CHILDREN = 2

    def __init__(self, feature_size, hidden_size, input_size):
        super().__init__()

        self.common = nn.Sequential(
            nn.Linear(feature_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
        )

        # One independent head per child (index 0 and index 1).
        child_heads = [
            nn.Sequential(nn.Linear(hidden_size, feature_size), nn.Tanh())
            for _ in range(self.NUM_CHILDREN)
        ]
        self.children_list = nn.ModuleList(child_heads)

        # Alpha is squashed by a Tanh, i.e. it lies in [-1, 1].
        self.alpha_net = nn.Sequential(
            nn.Linear(hidden_size, input_size),
            nn.Tanh(),
        )

    def forward(self, parent_feature):
        """Return ``(child_0, child_1, alpha)`` decoded from ``parent_feature``."""
        shared = self.common(parent_feature)
        first_child, second_child = (head(shared) for head in self.children_list)
        return first_child, second_child, self.alpha_net(shared)


class RoadDecoder(nn.Module):
    """Decode a ROAD node into one child feature and an ``xywh`` vector.

    Args:
        feature_size: size of the incoming feature and of the child feature.
        hidden_size: width of the shared trunk and of the xywh head.
        input_size: width of the xywh output.
    """

    def __init__(self, feature_size, hidden_size, input_size):
        super().__init__()

        trunk = [
            nn.Linear(feature_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
        ]
        self.common = nn.Sequential(*trunk)

        child_head = [
            nn.Linear(hidden_size, feature_size),
            nn.Tanh(),
        ]
        self.child_net = nn.Sequential(*child_head)

        # No final activation: the xywh output is unbounded.
        xywh_head = [
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, input_size),
        ]
        self.xywh_net = nn.Sequential(*xywh_head)

    def forward(self, parent_feature):
        """Return ``(child_0, xywh)`` decoded from ``parent_feature``."""
        shared = self.common(parent_feature)
        return self.child_net(shared), self.xywh_net(shared)


class NodeClassifier(nn.Module):
    """Score the node type of a feature vector.

    A three-layer MLP with Tanh activations between layers; the output is
    left as raw (un-normalised) scores, one per class.

    Args:
        feature_size: size of the incoming node feature.
        hidden_size: width of both hidden layers.
        class_size: number of output scores (defaults to 1).
    """

    def __init__(self, feature_size, hidden_size, class_size=1):
        super().__init__()
        layers = [
            nn.Linear(feature_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, class_size),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, input_feature):
        """Return the class scores for ``input_feature``."""
        return self.classifier(input_feature)


class ScenarioDecoder(nn.Module):
    """Bundle of all per-node decoders, the node classifier and the losses.

    Args:
        z_size: size of the latent code consumed by the sampler decoder.
        feature_size: size of the node features passed between decoders.
        hidden_size: hidden width shared by all sub-networks.
        class_size: number of node classes scored by the classifier.
    """

    def __init__(self, z_size, feature_size, hidden_size, class_size):
        super().__init__()

        # Output widths of the individual regression heads.
        vehicle_prop_size = 1
        alpha_xy_size = 2
        alpha_x_size = 1
        xywh_size = 4

        pose_sizes = [vehicle_prop_size, alpha_xy_size]
        shared = dict(feature_size=feature_size, hidden_size=hidden_size)

        # Sub-modules are created in a fixed order so that seeded parameter
        # initialisation and state_dict layout stay reproducible.
        self.vehicle_decoder = VehicleDecoder(input_size=list(pose_sizes), **shared)
        self.quad_decoder = QuadDecoder(input_size=list(pose_sizes), **shared)
        self.lane_decoder = LaneDecoder(input_size=alpha_x_size, **shared)
        self.road_decoder = RoadDecoder(input_size=xywh_size, **shared)

        self.sampler_decoder = SamplerDecoder(z_size=z_size, **shared)
        self.node_classifier = NodeClassifier(class_size=class_size, **shared)

        # All losses are element-wise; reduction happens in the estimators.
        self.mse_loss = nn.MSELoss(reduction='none')
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    # ---- thin camelCase wrappers around the sub-modules -------------------

    def vehicleDecoder(self, parent):
        """Run the vehicle decoder on ``parent``."""
        return self.vehicle_decoder(parent)

    def quadDecoder(self, parent):
        """Run the quad decoder on ``parent``."""
        return self.quad_decoder(parent)

    def laneDecoder(self, parent):
        """Run the lane decoder on ``parent``."""
        return self.lane_decoder(parent)

    def roadDecoder(self, parent):
        """Run the road decoder on ``parent``."""
        return self.road_decoder(parent)

    def samplerDecoder(self, parent):
        """Run the sampler decoder on ``parent``."""
        return self.sampler_decoder(parent)

    def nodeClassifier(self, parent):
        """Run the node classifier on ``parent``."""
        return self.node_classifier(parent)

    # ---- loss estimators ----------------------------------------------------
    # NOTE: even though these all use the MSE loss, they are kept as separate
    # public methods on purpose.

    def _normalized_mse(self, predict, target, normalizer):
        """Element-wise MSE averaged over dim 1, divided by ``normalizer[0]``."""
        per_element = self.mse_loss(predict, target)
        return per_element.mean(dim=1) / normalizer[0]

    def vehicleLossEstimator(self, predict, gt, normalizer):
        """Property loss for VEHICLE nodes.

        Only column 2 of ``gt`` is compared; the x and y columns are
        deliberately left out of this loss.
        """
        return self._normalized_mse(predict, gt[:, 2:3], normalizer)

    def alpha1DLossEstimator(self, predict, gt, normalizer):
        """Loss of the 1D alpha used by LANE nodes."""
        return self._normalized_mse(predict, gt, normalizer)

    def alpha2DLossEstimator(self, predict, gt, normalizer):
        """Loss of the 2D alpha used by VEHICLE and QUAD nodes."""
        return self._normalized_mse(predict, gt, normalizer)

    def xywhLossEstimator(self, predict, gt, normalizer):
        """Loss of the xywh vector of ROAD nodes."""
        return self._normalized_mse(predict, gt, normalizer)

    def ceLossEstimator(self, predict, gt, normalizer):
        """Cross-entropy loss of the node classifier.

        The class index is read from column 0 of ``gt``.
        """
        target = gt[:, 0]
        return self.ce_loss(predict, target) / normalizer[0]

    # ---- small arithmetic helpers -------------------------------------------

    def vectorAdder(self, v1, v2):
        """Add ``v2`` into ``v1`` IN PLACE and return ``v1``."""
        v1.add_(v2)
        return v1

    def vectorDivider(self, nom, den):
        """Return ``nom / den``."""
        quotient = nom / den
        return quotient


class TreeVAE(nn.Module):
    """Decoder-only T-VAE: turns a latent code into a list of vehicle poses.

    Only the ROAD node contains absolute information (its ``xywh`` box); all
    other nodes only carry information relative to the box of their parent.

    A box is described by ``(x, y, w, h)`` where ``(x, y)`` is the box centre.
    A split ratio ``alpha`` in [-1, 1] places a separator at
    ``x + alpha * w / 2`` (and likewise for ``y`` / ``h``).
    """

    def __init__(self, z_dim):
        """Build the decoder networks and the lane-layout lookup table.

        Args:
            z_dim: size of the latent code accepted by :meth:`decode`.
        """
        super(TreeVAE, self).__init__()
        self.feature_size = 64
        self.hidden_size = 64
        self.z_dim = z_dim
        self.class_size = 5

        # Per-call decoding state. decode() resets both; they are initialised
        # here so add_node_conditional() can also be driven directly.
        self.plot = False
        self.pose_collector = []

        self.decoder = ScenarioDecoder(z_size=self.z_dim, feature_size=self.feature_size, hidden_size=self.hidden_size, class_size=self.class_size)

        # Definition of conditions: number of lanes -> binary split tree.
        # 'ratio' is the alpha of the split (0 => equal halves, 1/3 => the
        # left part is twice as wide as the right one, 1/5 => 3:2), a child
        # of None is a single lane. Fresh dicts are built for every entry so
        # no sub-tree object is shared between layouts.
        def two_lanes():
            return {'ratio': 0.0, 'left': None, 'right': None}

        def three_lanes():
            return {'ratio': 1.0/3.0, 'left': two_lanes(), 'right': None}

        self.lane_tree = {
            1: {'ratio': None},
            2: two_lanes(),
            3: three_lanes(),
            4: {'ratio': 0.0, 'left': two_lanes(), 'right': two_lanes()},
            5: {'ratio': 1.0/5.0, 'left': three_lanes(), 'right': two_lanes()},
            6: {'ratio': 0.0, 'left': three_lanes(), 'right': three_lanes()},
        }

    # equations to calculate the centre and size of the two parts of a
    # segment (centre x, size w) that is split at x + alpha*w/2
    @staticmethod
    def left_or_bottom(alpha, x_or_y, w_or_h):
        """Return ``(centre, size)`` of the lower-coordinate part of a split."""
        new_x_or_y = x_or_y - w_or_h*(1-alpha)/4
        new_w_or_h = w_or_h*(1+alpha)/2
        return new_x_or_y, new_w_or_h

    @staticmethod
    def right_or_top(alpha, x_or_y, w_or_h):
        """Return ``(centre, size)`` of the higher-coordinate part of a split."""
        new_x_or_y = x_or_y + w_or_h*(1+alpha)/4
        new_w_or_h = w_or_h*(1-alpha)/2
        return new_x_or_y, new_w_or_h

    def divide_1D(self, road_x, road_y, road_w, road_h, alpha_x=None):
        """Split a box along x only; y and h are inherited from the parent.

        Args:
            road_x, road_y, road_w, road_h: length-1 array slices of the
                parent box (they are concatenated along axis 0).
            alpha_x: split ratio. Despite the ``None`` default it must be a
                number; ``None`` is not handled here.

        Returns:
            ``(left, right)``, each a new ``[x, y, w, h]`` array.
        """
        # x and w
        left_x, left_w = self.left_or_bottom(alpha_x, road_x, road_w)
        right_x, right_w = self.right_or_top(alpha_x, road_x, road_w)
        # np.concatenate always allocates a fresh array, so the parent's y/h
        # can be passed in directly without copying them first.
        left = np.concatenate([left_x, road_y, left_w, road_h], axis=0)       # [4,]
        right = np.concatenate([right_x, road_y, right_w, road_h], axis=0)    # [4,]
        return left, right

    def divide_2D(self, road_x, road_y, road_w, road_h, alpha_x=None, alpha_y=None):
        """Split a box along both x and y into four quadrants.

        Args:
            road_x, road_y, road_w, road_h: length-1 array slices of the
                parent box.
            alpha_x, alpha_y: split ratios. Despite the ``None`` defaults
                both must be numbers; ``None`` is not handled here.

        Returns:
            ``(left_top, right_top, left_bottom, right_bottom)``, each a new
            ``[x, y, w, h]`` array.
        """
        # x and w
        left_x, left_w = self.left_or_bottom(alpha_x, road_x, road_w)
        right_x, right_w = self.right_or_top(alpha_x, road_x, road_w)

        # y and h
        bottom_y, bottom_h = self.left_or_bottom(alpha_y, road_y, road_h)
        top_y, top_h = self.right_or_top(alpha_y, road_y, road_h)

        # collect the new coordinates
        left_top = np.concatenate([left_x, top_y, left_w, top_h], axis=0)             # [4,]
        right_top = np.concatenate([right_x, top_y, right_w, top_h], axis=0)          # [4,]
        left_bottom = np.concatenate([left_x, bottom_y, left_w, bottom_h], axis=0)    # [4,]
        right_bottom = np.concatenate([right_x, bottom_y, right_w, bottom_h], axis=0) # [4,]

        return left_top, right_top, left_bottom, right_bottom

    @staticmethod
    def _plot_segment(x_1, y_1, x_2, y_2, *args, **kwargs):
        """Draw the straight segment (x_1, y_1) -> (x_2, y_2) with pyplot."""
        segment = np.array([[x_1, y_1], [x_2, y_2]])
        plt.plot(segment[:, 0], segment[:, 1], *args, **kwargs)

    def _collect_pose(self, vehicle_prop, alpha, xywh, direction):
        """Convert decoder outputs to an absolute pose and record it.

        The pose ``[x, y, theta]`` is placed at the alpha offset inside
        ``xywh`` and appended to ``self.pose_collector`` with shape (1, 3);
        when plotting is on, the vehicle box and heading arrow are drawn.

        Returns:
            The alpha of the first batch element after ``CPU`` conversion, so
            the caller can reuse it for splitting.
        """
        # CPU() comes from utils; presumably it moves a tensor to a numpy
        # array on the host -- TODO confirm. Only batch element 0 is used.
        vehicle_prop = CPU(vehicle_prop)[0]
        alpha = CPU(alpha)[0]
        theta = vehicle_prop[0:1]*np.pi
        x = xywh[0:1] + alpha[0]*xywh[2:3]/2
        y = xywh[1:2] + alpha[1]*xywh[3:4]/2
        pose = np.concatenate([x, y, theta], axis=0)[None]
        self.pose_collector.append(pose)
        if self.plot:
            color = 'g' if direction else 'r'
            # vehicles are drawn as fixed 2 x 4 boxes
            self.plot_box(self.get_box(pose[0], 2, 4), color)
            self.plot_arrow(pose[0], color, direction)
        return alpha

    def add_node_conditional(self, feature, xywh, conditional_tree, direction):
        """Recursively decode ``feature`` inside the box ``xywh``.

        Args:
            feature: node feature tensor; only batch element 0 is decoded.
            xywh: array ``[x, y, w, h]`` of the box this node occupies.
            conditional_tree: lane-split tree from ``self.lane_tree``, or
                ``None`` once the lane level has been reached.
            direction: truthy / falsy flag that only affects plotting
                (colour and arrow orientation).

        Returns:
            ``'E'`` when this node is classified as EMPTY, otherwise ``None``.
            Decoded poses are appended to ``self.pose_collector``.
        """
        # LANE nodes are dictated by the conditional tree, not by the classifier
        if conditional_tree is not None:
            alpha = conditional_tree['ratio']
            if alpha is None:
                # only one lane: decode the content of the whole box
                self.add_node_conditional(feature, xywh, None, direction)
            else:
                child_0_f, child_1_f, _ = self.decoder.lane_decoder(feature)
                if self.plot:
                    separate_line_x = xywh[0:1] + alpha*xywh[2:3]/2
                    self._plot_segment(
                        separate_line_x, xywh[1:2]-xywh[3:4]/2,
                        separate_line_x, xywh[1:2]+xywh[3:4]/2,
                        'k-', linewidth=1)
                left, right = self.divide_1D(xywh[0:1], xywh[1:2], xywh[2:3], xywh[3:4], alpha)
                self.add_node_conditional(child_0_f, left, conditional_tree['left'], direction)
                self.add_node_conditional(child_1_f, right, conditional_tree['right'], direction)
            return None

        # Inside a lane only QUAD, VEHICLE and EMPTY are allowed. The label
        # is the best-scoring of those classes for batch element 0 (the same
        # element every other read in this method uses). Softmax is monotonic,
        # so taking the argmax of the raw scores selects the same class.
        scores = self.decoder.nodeClassifier(feature)
        leaf_labels = [VEHICLE, QUAD, EMPTY]
        label = leaf_labels[int(torch.argmax(scores[0, leaf_labels]))]

        if label == VEHICLE:
            vehicle_prop, alpha = self.decoder.vehicle_decoder(feature)
            self._collect_pose(vehicle_prop, alpha, xywh, direction)
        elif label == QUAD:
            child_0_f, child_1_f, child_2_f, child_3_f, vehicle_prop, alpha = self.decoder.quad_decoder(feature)
            alpha = self._collect_pose(vehicle_prop, alpha, xywh, direction)
            quadrants = self.divide_2D(xywh[0:1], xywh[1:2], xywh[2:3], xywh[3:4], alpha[0], alpha[1])
            if self.plot:
                # vertical separator (constant x)
                separate_line_x = xywh[0:1] + alpha[0]*xywh[2:3]/2
                self._plot_segment(
                    separate_line_x, xywh[1:2] - xywh[3:4]/2,
                    separate_line_x, xywh[1:2] + xywh[3:4]/2,
                    c='silver', zorder=0)
                # horizontal separator (constant y)
                separate_line_y = xywh[1:2] + alpha[1]*xywh[3:4]/2
                self._plot_segment(
                    xywh[0:1] - xywh[2:3]/2, separate_line_y,
                    xywh[0:1] + xywh[2:3]/2, separate_line_y,
                    c='silver', zorder=0)
            # children are ordered left_top, right_top, left_bottom, right_bottom
            children = (child_0_f, child_1_f, child_2_f, child_3_f)
            for child_feature, child_xywh in zip(children, quadrants):
                self.add_node_conditional(child_feature, child_xywh, None, direction)
        elif label == EMPTY:
            return 'E'
        return None

    def add_node(self, feature, xywh):
        """Unconditional tree expansion -- not available in this file.

        ``decode`` calls this hook when no ``condition`` is given, but this
        decoder-only module ships no implementation of it. Subclasses may
        override it.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            'Unconditional decoding is not implemented in this decoder-only '
            'module: pass a `condition` to decode() or override add_node().')

    def decode(self, feature, position_scale, condition=None, plot=False):
        """Decode a latent code into vehicle poses.

        Args:
            feature: latent code tensor fed to the sampler decoder.
            position_scale: factor applied to the decoded road ``xywh``;
                only used when ``condition`` is ``None``.
            condition: optional dict with keys ``'num_lane'`` (a key of
                ``self.lane_tree``), ``'xywh'`` (road box) and
                ``'direction'``. When given, the road box and lane layout
                come from it instead of from the decoder.
            plot: when true, draw the scene with matplotlib while decoding.

        Returns:
            ``self.pose_collector``: list of ``[x, y, theta]`` arrays of
            shape (1, 3), one per decoded vehicle.

        Raises:
            NotImplementedError: if ``condition`` is ``None`` and
                ``add_node`` has not been overridden.
        """
        self.plot = plot
        # the first node will always be a ROAD, then we apply the scale to get the real position
        self.pose_collector = []
        feature = self.decoder.samplerDecoder(feature)
        road_feature, xywh = self.decoder.road_decoder(feature)

        # use conditional road segment information or sampled segments
        if condition is not None:
            num_lane = condition['num_lane']
            road_xywh = condition['xywh']
            direction = condition['direction']
            conditional_tree = self.lane_tree[num_lane]
            self.add_node_conditional(road_feature, road_xywh, conditional_tree, direction)
        else:
            road_xywh = CPU(xywh)[0]
            road_xywh[0:2] = road_xywh[0:2] * position_scale
            road_xywh[2:4] = road_xywh[2:4] * position_scale
            self.add_node(road_feature, road_xywh)
        if self.plot:
            self.plot_xywh(road_xywh)
        return self.pose_collector

    def load_model(self, filename):
        """Load the decoder weights from the checkpoint ``filename``.

        The checkpoint must be a dict holding the decoder state dict under
        the key ``'decoder'``.

        Raises:
            FileNotFoundError: if ``filename`` is not an existing file.
        """
        if not os.path.isfile(filename):
            raise FileNotFoundError('checkpoint file not found: {}'.format(filename))
        # Load onto the CPU first so checkpoints saved on a GPU can be read on
        # CPU-only hosts; load_state_dict copies the values onto whatever
        # device the parameters already live on.
        checkpoint = torch.load(filename, map_location='cpu')
        self.decoder.load_state_dict(checkpoint['decoder'])

    @staticmethod
    def get_box(pose, width, height):
        """Return the closed outline (5 x 2 points) of a rotated rectangle.

        Args:
            pose: ``[x, y, orientation]``; orientation is in radians.
            width, height: side lengths of the rectangle.
        """
        xy = pose[0:2]
        o = pose[2]
        # corner list repeats the first corner so the outline is closed
        position = np.array([[-width/2, -height/2], [-width/2, height/2], [width/2, height/2], [width/2, -height/2], [-width/2, -height/2]])
        rotation_matrix = np.array([[np.cos(o), -np.sin(o)], [np.sin(o), np.cos(o)]])
        position = rotation_matrix.dot(position.T).T + xy
        return position

    @staticmethod
    def plot_box(points, color):
        """Draw the polyline ``points`` (N x 2) as a solid line in ``color``."""
        plt.plot(points[:, 0], points[:, 1], color+'-')

    @staticmethod
    def plot_arrow(pose, color, direction):
        """Draw a heading arrow at ``pose``.

        The arrow points along the rotated +y axis when ``direction`` is
        truthy and along the rotated -y axis otherwise.
        """
        scale = 0.5
        xy = pose[0:2]
        o = pose[2]
        if direction:
            position = scale*np.array([[0.0, 1.0]]).T
        else:
            position = scale*np.array([[0.0, -1.0]]).T
        rotation_matrix = np.array([[np.cos(o), -np.sin(o)], [np.sin(o), np.cos(o)]])
        position = rotation_matrix.dot(position)[:, 0]
        plt.arrow(xy[0], xy[1], position[0], position[1], color=color, width=0.3)

    @staticmethod
    def plot_xywh(xywh):
        """Draw the outline of the box ``xywh`` (centre x, y and size w, h)."""
        # plot road
        road_x, road_y, road_w, road_h = xywh[0:1], xywh[1:2], xywh[2:3], xywh[3:4]
        road_show = np.array([
            [road_x-road_w/2, road_y-road_h/2],
            [road_x-road_w/2, road_y+road_h/2],
            [road_x+road_w/2, road_y+road_h/2],
            [road_x+road_w/2, road_y-road_h/2],
            [road_x-road_w/2, road_y-road_h/2]
        ])
        plt.plot(road_show[:, 0], road_show[:, 1], 'k-', linewidth=2)
