'''
Author: 
Email: 
Date: 2020-10-12 22:14:36
LastEditTime: 2021-05-30 18:38:13
Description: 
    Modified from GRAINS: https://github.com/ManyiLi12345/GRAINS
    We define the class of scenario tree to parse the scenario data.
'''

import collections
import torch
from torch.autograd import Variable


class Node(object):
    """A placeholder for one recorded operation, resolved later by ``Fold.apply``.

    Adapted from https://github.com/nearai/torchfold (the original project is
    out of date).

    Attributes:
        op: name of the operation (``Fold.apply`` looks it up on the network
            with ``getattr``).
        step: depth index; all operations sharing a step are evaluated
            together, independently of the other steps.
        index: position of this operation among the calls to ``op`` recorded
            at ``step``.
        args: the positional arguments recorded for the operation.
        split_idx: which output to select when the operation returns several
            values, or -1 when the operation's result is used as a whole.
        batch: True unless ``nobatch()`` was called on this node.
    """
    def __init__(self, op, step, index, *args):
        self.op, self.step, self.index = op, step, index
        self.args = args
        self.split_idx = -1
        self.batch = True

    def split(self, num):
        """Return ``num`` fresh nodes, the i-th one selecting output i of this op."""
        pieces = [Node(self.op, self.step, self.index, *self.args) for _ in range(num)]
        for position, piece in enumerate(pieces):
            piece.split_idx = position
        return tuple(pieces)

    def nobatch(self):
        """Flag this node as non-batched and return it so calls can be chained."""
        self.batch = False
        return self

    def get(self, values):
        """Fetch this node's result from the ``values[step][op]`` table."""
        per_op = values[self.step][self.op]
        if self.split_idx < 0:
            # Single-output op: results are indexed by call position directly.
            return per_op[self.index]
        # Multi-output op: first pick the output, then the call position.
        return per_op[self.split_idx][self.index]

    def __repr__(self):
        fields = (self.step, self.index, self.op)
        return "<Step: %d, Index: %d, Node op: %s>" % fields


class Fold(object):
    """Record operations lazily and evaluate them in depth-wise batches.

    Adapted from https://github.com/nearai/torchfold (the original project is
    out of date). Each ``add`` call records one operation and returns a
    ``Node`` placeholder. ``apply`` then evaluates the recorded graph step by
    step (depth by depth). All calls to the same op at the same step have their
    arguments concatenated along dimension 0 and run through the network once.
    The result is then chunked back into one piece per recorded call.
    """
    def __init__(self):
        # step -> op name -> list of argument tuples, one per recorded call
        self.steps = collections.defaultdict(lambda: collections.defaultdict(list))
        # op name -> argument tuple -> Node, so repeated identical calls are shared
        self.cached_nodes = collections.defaultdict(dict)
        # number of add() calls, including those answered from the cache
        self.total_nodes = 0

    def add(self, op, *args):
        """Record the call ``op(*args)`` and return the ``Node`` standing for it.

        Args:
            op: name of a method on the network later passed to ``apply``.
            *args: ``Node`` results of earlier calls, ``int`` values, or
                tensors (``Variable``).

        Returns:
            The ``Node`` for this call. Repeating a call with the same ``op``
            and the same arguments returns the cached node.

        Raises:
            ValueError: if an argument is not a ``Node``, ``int`` or ``Variable``.
        """
        self.total_nodes += 1
        if not all(isinstance(arg, (Node, int, Variable)) for arg in args):
            raise ValueError("All args should be Variable, int or Node, but got: %s" % str(args))
        cache = self.cached_nodes[op]
        if args not in cache:
            # A node runs one step after the deepest Node it depends on.
            step = max((arg.step + 1 for arg in args if isinstance(arg, Node)), default=0)
            calls = self.steps[step][op]
            cache[args] = Node(op, step, len(calls), *args)
            calls.append(args)
        return cache[args]

    def _batch_args(self, arg_lists, values):
        """Build batched arguments from per-call arguments and computed ``values``.

        Args:
            arg_lists: iterable whose items are either a single ``Node`` or a
                sequence holding one argument per call (i.e. one argument
                position across all calls).
            values: the step -> op -> results table built by ``apply``.

        Returns:
            A list with one batched argument per item of ``arg_lists``.

        Raises:
            ValueError: if a nobatch argument position holds different nodes.
        """
        res = []
        for arg in arg_lists:
            if isinstance(arg, Node):
                res.append(arg.get(values))
                continue

            first = arg[0]
            if isinstance(first, Node):
                if first.batch:
                    res.append(torch.cat([x.get(values) for x in arg], 0))
                else:
                    # A nobatch argument is passed once, unconcatenated, so every
                    # call must share the very same node. Check every other call,
                    # starting with the second one (index 1).
                    for other in arg[1:]:
                        if other != first:
                            raise ValueError("Can not use more then one of nobatch argument, got: %s." % str(arg))
                    res.append(first.get(values))
            elif isinstance(first, int):
                # add() accepts ints, but torch.cat cannot concatenate Python
                # ints; gather them into a LongTensor (on the default device).
                res.append(torch.tensor(list(arg), dtype=torch.long))
            else:
                # Plain tensors: concatenate along the batch dimension.
                res.append(torch.cat(arg, 0))
        return res

    def apply(self, nn, nodes):
        """Evaluate every recorded operation and return the requested results.

        Args:
            nn: object providing one method per op name used in ``add``; each
                method is called once per step with the batched arguments.
            nodes: items to return, each either a ``Node`` or a sequence of
                ``Node``s (a sequence is concatenated along dimension 0).

        Returns:
            A list with one result per item of ``nodes``.

        Raises:
            ValueError: if batching the arguments of an operation fails.
        """
        values = {}
        for step in sorted(self.steps.keys()):
            values[step] = {}
            for op, calls in self.steps[step].items():
                func = getattr(nn, op)
                try:
                    batched_args = self._batch_args(zip(*calls), values)
                except Exception as err:
                    raise ValueError("Error while executing node %s[%d] with args: %s" % (op, step, calls)) from err
                # Split the output into one chunk per recorded call. The size of
                # the first batched argument is not a reliable count: a nobatch
                # first argument has a size unrelated to the number of calls.
                num_calls = len(calls)
                res = func(*batched_args)
                if isinstance(res, (tuple, list)):
                    values[step][op] = [torch.chunk(x, num_calls) for x in res]
                else:
                    values[step][op] = torch.chunk(res, num_calls)
        return self._batch_args(nodes, values)
