import operator
from heapq import heapify, heappush, heappop
from typing import *

import numpy as np

from .mmt import MMTController

class HuffmanMMTController(MMTController):
    """MMT controller whose aggregation tree is a Huffman tree built over per-client
    unlearning probabilities.

    The two lowest-probability nodes are merged first, so clients with a low
    unlearning probability tend to sit deeper in the tree and clients with a high
    probability closer to the root. (Presumably this is so that likely unlearning
    requests touch fewer sub-models; confirm against the MMT design.)
    """

    def __init__(self, *args, unlearning_probs: List[float], **kwargs):
        # NOTE(review): this skips MMTController.__init__ and calls the next class in
        # the MRO instead. That looks deliberate: the parent seemingly builds its own
        # tree (with dummy nodes, per the client_mapping comment below), which this
        # subclass replaces. Confirm against MMTController before changing.
        super(MMTController, self).__init__(*args, **kwargs)
        self.unlearning_probs = unlearning_probs
        # parent_tree: child node index -> parent node index. Leaves are client
        # positions 0..n-1; internal nodes are numbered from num_clients upward.
        # construct_tree raises ValueError if unlearning_probs is empty or does not
        # have one entry per client.
        self.parent_tree, self.root_node_idx = self.construct_tree(self.unlearning_probs)
        print('parent_tree:', self.parent_tree)
        print('root node:', self.root_node_idx)
        # No dummy nodes are added, so leaf i corresponds to the i-th key of client_loaders.
        self.client_mapping = list(self.client_loaders.keys())
        # One slot per tree node (leaves and internal nodes), indexed 0..root_node_idx.
        self.sub_models = {i: None for i in range(self.root_node_idx + 1)}
        self.sub_accu_rounds = {i: 0 for i in range(self.root_node_idx + 1)}

    def construct_tree(self, probs: List[float]) -> Tuple[Dict[int, int], int]:
        """Build a Huffman tree over the clients' unlearning probabilities.

        Repeatedly pops the two lowest-probability nodes and makes them children of a
        new internal node whose probability is their sum. The new node is pushed back,
        and this repeats until a single node (the root) remains.

        Leaf ``i`` is the client at position ``i`` in ``probs``. Internal nodes are
        numbered consecutively starting at ``config.federated.num_clients``, so the
        root of an n-leaf tree is ``num_clients + n - 2``, or leaf 0 when n == 1.

        Args:
            probs: one unlearning probability per client.

        Returns:
            ``(parent_tree, root_node_idx)``. ``parent_tree`` maps every non-root
            node index to its parent's index, and ``root_node_idx`` is the root's index.

        Raises:
            ValueError: if ``probs`` is empty, or if its length differs from
                ``config.federated.num_clients``. A mismatch would make internal node
                indices collide with leaf indices, or leave clients out of the tree.
        """
        num_clients = self.config.federated.num_clients
        if not probs:
            raise ValueError('unlearning_probs must contain at least one probability')
        if len(probs) != num_clients:
            raise ValueError(
                f'expected {num_clients} unlearning probabilities (one per client), '
                f'got {len(probs)}'
            )

        # Probability comes first so tuples sort by it. The unique node index breaks
        # ties deterministically.
        heap = [(prob, idx) for idx, prob in enumerate(probs)]
        heapify(heap)
        parent_tree: Dict[int, int] = {}
        next_node_idx = num_clients
        while len(heap) > 1:
            left_prob, left_idx = heappop(heap)
            right_prob, right_idx = heappop(heap)
            parent_tree[left_idx] = next_node_idx
            parent_tree[right_idx] = next_node_idx
            heappush(heap, (left_prob + right_prob, next_node_idx))
            next_node_idx += 1

        # The only node left on the heap is the root (the single leaf when n == 1).
        root_node_idx = heap[0][1]
        return parent_tree, root_node_idx