
from .mmt import *

class SuboptMMTController(MMTController):
    """MMT controller that arranges the clients in a complete N-ary tree.

    Clients are the leaves of the tree and every internal node aggregates the
    client loaders of its children. When the number of clients is not a power
    of ``N``, the leaf level is padded with empty dummy leaves.
    """

    def __init__(self, *args, N=3, **kwargs):
        """Create the controller.

        Args:
            N: Branching factor, i.e. the number of children of every internal
                node of the tree. Must be an integer >= 2.
            *args, **kwargs: Forwarded unchanged to ``MMTController.__init__``.
        """
        # Assigned before the parent constructor runs, presumably because it
        # may call construct_tree() (which reads self.N) -- confirm in .mmt.
        self.N = N
        super().__init__(*args, **kwargs)

    def construct_tree(self):
        """Build a complete N-ary tree over the federated clients, bottom-up.

        Leaves are numbered ``0 .. N**(height-1) - 1``. Leaf ``i`` holds
        ``[self.client_loaders[i]]`` when client ``i`` exists, otherwise it is
        an empty dummy leaf. Internal nodes are numbered consecutively after
        the leaves, level by level. Each internal node holds the concatenation
        of its children's lists, in child order.

        Returns:
            A ``(parent_tree, data_tree)`` tuple:

            - ``parent_tree`` maps every non-root node id to its parent's id.
            - ``data_tree`` maps every node id to the list of client loaders
              in its subtree.

        Raises:
            ValueError: If ``self.N < 2`` or the configured number of clients
                is less than 1.
        """
        branching = self.N
        num_clients = self.config.federated.num_clients
        if branching < 2:
            raise ValueError(f"branching factor N must be >= 2, got {branching}")
        if num_clients < 1:
            raise ValueError(f"num_clients must be >= 1, got {num_clients}")

        # Smallest depth with branching**depth >= num_clients, computed with
        # exact integer arithmetic. math.log(num_clients, N) is inexact
        # (e.g. math.log(125, 5) == 3.0000000000000004), and taking ceil() of
        # it would add a spurious extra level of dummy nodes.
        depth = 0
        while branching ** depth < num_clients:
            depth += 1
        tree_height = depth + 1

        num_leaves = branching ** depth
        parent_tree = {}    # {child node id: immediate parent node id}
        data_tree = {}      # {node id: list of client loaders in its subtree}

        # Leaf level: real clients keep their own id; missing ids become empty
        # dummy leaves so that the leaf count is exactly N**(tree_height - 1).
        for leaf in range(num_leaves):
            if leaf in self.client_loaders:
                data_tree[leaf] = [self.client_loaders[leaf]]
            else:
                data_tree[leaf] = []

        # Internal levels, built bottom-up. Node ids continue consecutively
        # after the level below, so the k-th node of a level owns the
        # `branching` consecutive nodes starting at level_start + k * branching.
        level_start = 0         # id of the first node of the level below
        level_size = num_leaves  # number of nodes in the level below
        next_id = num_leaves
        for _ in range(tree_height - 1):
            for k in range(level_size // branching):
                parent = next_id
                data_tree[parent] = []
                first_child = level_start + k * branching
                for child in range(first_child, first_child + branching):
                    parent_tree[child] = parent
                    data_tree[parent].extend(data_tree[child])
                next_id += 1
            level_start += level_size
            level_size //= branching

        return parent_tree, data_tree