from typing import *
from .mmt import MMTController

class GreedyController(MMTController):
    """Controller whose aggregation tree is a chain built from a client subscription order.

    The last client in ``subscription_seq`` hangs off the first internal node,
    whose index is ``config.federated.num_clients``. Every earlier client is then
    merged with the subtree built so far under a fresh internal node whose index
    is one larger than the previous one. The largest internal index is the root.
    """

    def __init__(self, *args, subscription_seq: List[int], **kwargs):
        # NOTE(review): ``super(MMTController, self)`` skips MMTController.__init__
        # and runs the next __init__ in the MRO instead. Presumably deliberate,
        # since this class builds its own tree below; confirm before changing.
        super(MMTController, self).__init__(*args, **kwargs)
        # Assumes distinct probabilities, so the subscription order is meaningful.
        self.subscription_seq = subscription_seq
        self.parent_tree, self.root_node_idx = self.construct_tree(self.subscription_seq)
        # No dummy nodes are added, hence client_mapping is just the keys of client_loaders.
        self.client_mapping = list(self.client_loaders.keys())
        # One slot per node index 0..root_node_idx: presumably every client leaf
        # (0..num_clients-1) plus every internal node created by construct_tree.
        self.sub_models = {i: None for i in range(self.root_node_idx + 1)}
        self.sub_accu_rounds = {i: 0 for i in range(self.root_node_idx + 1)}
        print('subscription seq:', self.subscription_seq)
        print('parent_tree:', self.parent_tree)

    def construct_tree(self, subscription_seq, idx: int = 0):
        """Build the parent map of the aggregation tree for ``subscription_seq[idx:]``.

        Given the clients' subscription order, the tree starts from the last entry
        (the client with the furthest expiry). Each earlier client is then merged,
        together with the sub-FL node built so far, into a new internal node.

        Implemented iteratively, so long subscription orders do not hit Python's
        recursion limit.

        Args:
            subscription_seq: client indices in subscription order.
            idx: position in ``subscription_seq`` where the tree starts; earlier
                entries are ignored.

        Returns:
            ``(parent_tree, root_node_idx)``. ``parent_tree`` maps every client in
            ``subscription_seq[idx:]`` and every non-root internal node to its
            parent. ``root_node_idx`` is the index of the root internal node.

        Raises:
            IndexError: if ``idx`` does not point inside ``subscription_seq``
                (including when the sequence is empty).
        """
        num_entries = len(subscription_seq)
        if not 0 <= idx < num_entries:
            raise IndexError(
                f"construct_tree: idx={idx} is out of range for a subscription "
                f"sequence of length {num_entries}"
            )

        # First internal node: parent of the last client in the sequence.
        root_node_idx = self.config.federated.num_clients
        last_pos = num_entries - 1
        parent_tree = {subscription_seq[last_pos]: root_node_idx}

        # Walk backwards. Each step creates a new root that adopts both the previous
        # root and the next-earlier client. The visiting order keeps the dict's
        # insertion order (visible when the tree is printed) unchanged.
        for pos in range(last_pos - 1, idx - 1, -1):
            new_root_idx = root_node_idx + 1
            parent_tree[root_node_idx] = new_root_idx
            parent_tree[subscription_seq[pos]] = new_root_idx
            root_node_idx = new_root_idx
        return parent_tree, root_node_idx