'''
Author: Wenhao Ding
Email: wenhaod@andrew.cmu.edu
Date: 2020-10-12 22:14:36
LastEditTime: 2021-05-30 16:34:06
Description: 
    This file contains the code that describes a scenario tree.
    Modified from GRAINS: https://github.com/ManyiLi12345/GRAINS
'''

import numpy as np
from enum import Enum

import torch
from torch.autograd import Variable

from utils import CUDA
from torch_fold import Fold


class NodeType(Enum):
    """ Kinds of node that may appear in a scenario tree.

        The integer values matter outside this enum: raw tree data stores them under
        data['type'], Node uses them as the class-index target of its label, and Tree
        counts nodes in a list indexed by these values.
    """
    BOX = 0      # leaf carrying box properties and alpha
    QUAD = 1     # four children plus its own properties and alpha
    EMPTY = 2    # leaf without any properties
    PLANE = 3    # four children plus its own properties and alpha
    WORLD = 4    # four children plus a world xywh


class Node(object):
    """ Universal node of a scenario tree.

        A node has one of the NodeType kinds (BOX, QUAD, EMPTY, PLANE, WORLD) and up to four
        children stored in child_0 .. child_3. The optional property arrays are wrapped into
        float tensors with an extra leading axis of size 1.
        NOTE: this class will be called inside of dataloader, therefore we cannot convert data to torch.cuda.Variable.
        NOTE: however, we must wrap the tensor into a variable, otherwise the gradient flow breaks in the parser.
    """
    def __init__(self, child_0=None, child_1=None, child_2=None, child_3=None, prop=None, xywh=None, alpha=None, node_type=None):
        """ Build a node.
            Input:
                child_0 .. child_3: child nodes (left as None for leaf nodes)
                prop: numpy array of node properties, stored as self.prop
                xywh: numpy array [x, y, w, h], stored as self.xywh
                alpha: numpy array [alpha_x, alpha_y] or [alpha_x], stored as self.alpha
                node_type: the NodeType of this node; required despite its None default
            Raises:
                ValueError: if node_type is None
            NOTE: the prop / xywh / alpha attributes only exist when the matching argument is given.
        """
        # Fail early with a clear message; a missing type would otherwise surface as an
        # opaque AttributeError when the classification label is built below.
        if node_type is None:
            raise ValueError('Node requires a node_type.')

        # Three kinds of properties. `[None]` adds a leading axis of size 1, which requires
        # numpy arrays (plain lists would fail here).
        # There is no sanity check that the given properties match the node type.
        if prop is not None:
            self.prop = Variable(torch.FloatTensor(prop[None]))  # Tree passes prop[2:] (x, y dropped)
        if xywh is not None:
            self.xywh = Variable(torch.FloatTensor(xywh[None]))  # [x, y, w, h]
        if alpha is not None:
            self.alpha = Variable(torch.FloatTensor(alpha[None]))  # [alpha_x, alpha_y] or [alpha_x]

        # children nodes
        self.child_0 = child_0
        self.child_1 = child_1
        self.child_2 = child_2
        self.child_3 = child_3

        # store the node type
        self.node_type = node_type
        # CELoss expects an index target, shaped (1, 1)
        self.label = torch.LongTensor([[self.node_type.value]])

    def is_world(self):
        return self.node_type == NodeType.WORLD

    def is_plane(self):
        return self.node_type == NodeType.PLANE

    def is_quad(self):
        return self.node_type == NodeType.QUAD

    def is_box(self):
        return self.node_type == NodeType.BOX

    def is_empty(self):
        return self.node_type == NodeType.EMPTY


class Tree(object):
    """ A scenario tree built from a nested dict of node descriptions.

        The structure of the tree is fully defined by the data; this class only wraps every
        dict level into a Node and normalizes geometric quantities. Quad-level positions are
        expected to be stored as relative positions already.

        Attributes:
            pose_scale: np.array([position_scale, position_scale, 2*pi]) used for normalization
            root: the root Node
            node_count: FloatTensor of per-type node counts, indexed by NodeType value
                        ([BOX, QUAD, EMPTY, PLANE, WORLD])
            node_num: FloatTensor of shape (1,) holding the total number of nodes
    """
    def __init__(self, data, position_scale):
        """ Convert `data` to a tree of Node.
            Input:
                data: the nested dict. Each level has a 'type' entry holding a NodeType value.
                      QUAD, PLANE and WORLD levels hold their four children under the keys 0..3.
                      BOX, QUAD and PLANE levels hold 'prop' and 'alpha' arrays; WORLD holds 'xywh'.
                position_scale: scale used to normalize the world node's xywh
                                (rotations are normalized by 2*pi instead)
            Raises:
                ValueError: if a level has an unknown 'type'
            NOTE: `data` is left unmodified; normalization is applied to float copies.
        """
        # scale tensor
        self.pose_scale = np.array([position_scale, position_scale, 2*np.pi])

        # per-type node counters, indexed by NodeType value
        node_count = [0] * len(NodeType)

        def normalize_prop(raw_prop):
            # Work on a float copy: dividing in place would mutate the caller's data (building
            # a second Tree from the same dict would normalize the rotation again) and would
            # silently truncate the result for integer arrays.
            prop = np.array(raw_prop, dtype=np.float64)
            prop[2] = prop[2] / self.pose_scale[2]  # normalize rotation
            return prop

        def normalize_xywh(raw_xywh):
            # float copy for the same reasons as normalize_prop
            xywh = np.array(raw_xywh, dtype=np.float64)
            xywh[0:2] = xywh[0:2] / self.pose_scale[0:2]
            xywh[2:4] = xywh[2:4] / self.pose_scale[0:2]
            return xywh

        def add_children(node_data):
            # children are stored under the integer keys 0..3, built in that order
            return [add_node(node_data[i]) for i in range(4)]

        def add_node(node_data):
            node_type = node_data['type']
            if node_type == NodeType.BOX.value:
                node_count[NodeType.BOX.value] += 1
                # only prop[2:] (rotation onwards) is kept on the node
                return Node(prop=normalize_prop(node_data['prop'])[2:], alpha=node_data['alpha'],
                            node_type=NodeType.BOX)
            elif node_type == NodeType.QUAD.value:
                children = add_children(node_data)
                node_count[NodeType.QUAD.value] += 1
                # the first four positional parameters of Node are child_0 .. child_3
                return Node(*children, prop=normalize_prop(node_data['prop'])[2:], alpha=node_data['alpha'],
                            node_type=NodeType.QUAD)
            elif node_type == NodeType.EMPTY.value:
                node_count[NodeType.EMPTY.value] += 1
                return Node(node_type=NodeType.EMPTY)
            elif node_type == NodeType.PLANE.value:
                children = add_children(node_data)
                node_count[NodeType.PLANE.value] += 1
                # NOTE(review): unlike BOX/QUAD, the plane rotation is not normalized - confirm intended
                return Node(*children, prop=node_data['prop'][2:], alpha=node_data['alpha'],
                            node_type=NodeType.PLANE)
            elif node_type == NodeType.WORLD.value:
                node_count[NodeType.WORLD.value] += 1
                children = add_children(node_data)
                return Node(*children, xywh=normalize_xywh(node_data['xywh']), node_type=NodeType.WORLD)
            else:
                raise ValueError('Wrong node type.')

        self.root = add_node(data)
        self.node_count = torch.FloatTensor(np.array(node_count))
        self.node_num = torch.FloatTensor([sum(node_count)])


class TreeParser(object):
    """ Parse a scenario tree into Fold operations.

        A parser is either an encode parser or a decode parser (chosen at construction).
        Parsing only records operations in self.fold; they are executed later by apply(),
        which hands them to the given neural module with dynamic batching.
        The recorded operations mirror the tree structure, so this parser is specific to
        the BOX/QUAD/EMPTY/PLANE/WORLD tree layout.
    """
    def __init__(self, type):
        """ Input:
                type: 'encode' or 'decode' (checked when parse() is called)
        """
        # NOTE: `type` shadows the builtin; the name is kept so keyword callers keep working
        self.type = type
        self.fold = Fold()

    @staticmethod
    def _children(node):
        """ Return the four children of a QUAD/PLANE/WORLD node, in order. """
        return (node.child_0, node.child_1, node.child_2, node.child_3)

    def apply(self, nn, nodes):
        """ Apply current parsed tree to a given neural module with dynamic batching.
            Returns whatever Fold.apply returns for `nodes`.
        """
        return self.fold.apply(nn, nodes)

    def parse(self, tree, feature=None):
        """ Record the operations for `tree` in self.fold.
            Input:
                tree: a Tree instance
                feature: encoder output; only used by a 'decode' parser
            Return:
                the Fold node returned by encode_structure or decode_structure
            Raises:
                ValueError: if the parser type is neither 'encode' nor 'decode'
        """
        if self.type == 'encode':
            return self.encode_structure(tree)
        elif self.type == 'decode':
            return self.decode_structure(tree, feature)
        else:
            raise ValueError('Wrong parser type.')

    def encode_structure(self, tree):
        """ Encode the data (tree) to a parsed structure for pytorch, recursively from the root.
            Input:
                tree: a tree structure containing the data of one sample
            Return:
                encoder_output: the 'samplerEncoder' node added to Fold on top of the root encoding
            Raises:
                ValueError: if a node has an unknown type
        """
        def encode_node(node):
            if node.is_box():
                return self.fold.add('boxEncoder', CUDA(node.prop), CUDA(node.alpha))
            elif node.is_quad():
                children = [encode_node(child) for child in self._children(node)]
                return self.fold.add('quadEncoder', *children, CUDA(node.prop), CUDA(node.alpha))
            elif node.is_empty():
                # we must have an empty encoder to return a tensor
                # this encoder will always return a zero tensor
                return self.fold.add('emptyEncoder')
            elif node.is_plane():
                children = [encode_node(child) for child in self._children(node)]
                return self.fold.add('planeEncoder', *children, CUDA(node.prop), CUDA(node.alpha))
            elif node.is_world():
                children = [encode_node(child) for child in self._children(node)]
                return self.fold.add('worldEncoder', *children, CUDA(node.xywh))
            else:
                raise ValueError('Wrong node type!')

        # start from the root node
        encoding = encode_node(tree.root)
        # the sample encoder does not belong to the tree, therefore we should add it last
        # the output of encoder contains two parts: [z, kld_elements]
        encoder_output = self.fold.add('samplerEncoder', encoding)
        return encoder_output

    def decode_structure(self, tree, feature):
        """ Decode `feature` along the ground-truth tree and record the reconstruction loss.
            During the training stage, the tree structure is fixed, otherwise it would be impossible
            to calculate the loss; therefore the tree data is an input.
            Input:
                tree: a tree structure containing the ground truth of each node
                feature: feature from the encoder
            Return:
                recon_loss: the Fold node summing (via 'vectorAdder') the losses of all nodes
            Raises:
                ValueError: if a node has an unknown type
        """
        # Per-type losses are normalized by tree.node_count[i], where i is the NodeType value
        # (0 box, 1 quad, 3 plane, 4 world); the label loss is normalized by tree.node_num.
        def decode_node(node, feature):
            if node.is_box():
                box_prop, alpha = self.fold.add('boxDecoder', feature).split(2)
                label = self.fold.add('nodeClassifier', feature)
                prop_loss = self.fold.add('boxLossEstimator', box_prop, CUDA(node.prop), CUDA(tree.node_count[0][None]))
                alpha_loss = self.fold.add('alpha2DLossEstimator', alpha, CUDA(node.alpha), CUDA(tree.node_count[0][None]))
                label_loss = self.fold.add('ceLossEstimator', label, CUDA(node.label), CUDA(tree.node_num[None]))
                loss = self.fold.add('vectorAdder', prop_loss, alpha_loss)
                return self.fold.add('vectorAdder', loss, label_loss)
            elif node.is_quad() or node.is_plane() or node.is_world():
                if node.is_world():
                    *child_features, xywh = self.fold.add('worldDecoder', feature).split(5)
                    # normalized by the number of world nodes
                    loss = self.fold.add('xywhLossEstimator', xywh, CUDA(node.xywh), CUDA(tree.node_count[4][None]))
                else:
                    # quad and plane nodes share the same decoding scheme
                    decoder_name, count_index = ('quadDecoder', 1) if node.is_quad() else ('planeDecoder', 3)
                    *child_features, box_prop, alpha = self.fold.add(decoder_name, feature).split(6)
                    prop_loss = self.fold.add('boxLossEstimator', box_prop, CUDA(node.prop), CUDA(tree.node_count[count_index][None]))
                    alpha_loss = self.fold.add('alpha2DLossEstimator', alpha, CUDA(node.alpha), CUDA(tree.node_count[count_index][None]))
                    loss = self.fold.add('vectorAdder', prop_loss, alpha_loss)

                label = self.fold.add('nodeClassifier', feature)
                label_loss = self.fold.add('ceLossEstimator', label, CUDA(node.label), CUDA(tree.node_num[None]))

                # four children, each decoded from its own slice of this node's decoder output
                child_losses = [decode_node(child, child_feature)
                                for child, child_feature in zip(self._children(node), child_features)]

                loss = self.fold.add('vectorAdder', loss, label_loss)
                for child_loss in child_losses:
                    loss = self.fold.add('vectorAdder', loss, child_loss)
                return loss
            elif node.is_empty():
                # we dont need an empty decoder, only the classification loss
                label = self.fold.add('nodeClassifier', feature)
                return self.fold.add('ceLossEstimator', label, CUDA(node.label), CUDA(tree.node_num[None]))
            else:
                # previously fell through and returned None into the loss accumulation
                raise ValueError('Wrong node type!')

        # the sample decoder does not belong to the tree, therefore we should add it at the beginning
        feature = self.fold.add('samplerDecoder', feature)

        # add nodes to Fold according to the tree structure
        recon_loss = decode_node(tree.root, feature)
        return recon_loss
