import itertools
import operator
from typing import *

import numpy as np

from .mmt import MMTController



class ShannonFanoMMTController(MMTController):
    """MMT controller whose node tree is built with Shannon-Fano splitting.

    Clients are sorted by their unlearning probability, and the sorted list is
    recursively cut into two contiguous parts whose probability sums are as
    close as possible. Leaves of the resulting binary tree are client indices;
    internal nodes receive fresh indices above the client indices.
    """

    def __init__(self, *args, unlearning_probs: List[float], **kwargs):
        # NOTE(review): super(MMTController, self) skips MMTController.__init__
        # and runs the next __init__ in the MRO instead; presumably intentional so
        # the parent's own setup is replaced by the attributes set below — confirm.
        super(MMTController, self).__init__(*args, **kwargs)
        self.unlearning_probs = unlearning_probs
        self.parent_tree, self.root_node_idx = self.construct_tree(self.unlearning_probs)
        print('parent_tree:', self.parent_tree)
        print('root node:', self.root_node_idx)
        # No dummy nodes are added, hence client_mapping is just the keys of client_loaders.
        self.client_mapping = list(self.client_loaders.keys())
        # One entry per node index 0..root_node_idx (the root has the largest index).
        self.sub_models = {i: None for i in range(self.root_node_idx + 1)}
        self.sub_accu_rounds = {i: 0 for i in range(self.root_node_idx + 1)}

    def construct_tree(self, probs: List[float]) -> Tuple[Dict[int, int], int]:
        """Build a binary tree over the clients with Shannon-Fano coding.

        1. Sort the clients by probability in ascending order.
        2. Split the sorted range where the absolute difference between the
           sum of the left part and the sum of the right part is minimized
           (the earliest split wins on ties).
        3. Recurse on both parts until every part holds a single client.

        Args:
            probs: one probability per client; ``probs[i]`` belongs to client
                ``i``. Assumed non-negative.

        Returns:
            ``(parent_tree, root_node_idx)``: ``parent_tree`` maps every
            non-root node index to its parent's index, and ``root_node_idx``
            is the index of the root (the client index itself when there is
            only one client).

        Raises:
            ValueError: if ``probs`` is empty.
        """
        if len(probs) == 0:
            raise ValueError('construct_tree needs at least one probability')

        client_n_prob = sorted(enumerate(probs), key=operator.itemgetter(1))
        # prefix_sum[i] == sum of the i+1 smallest probabilities.
        prefix_sum = list(itertools.accumulate(prob for _, prob in client_n_prob))

        print('prefix sum:', np.around(prefix_sum, 3))
        # Internal nodes are numbered above every client index. Including
        # len(probs) keeps them from colliding with leaf indices should probs
        # hold more entries than num_clients.
        first_free_idx = max(self.config.federated.num_clients, len(probs))
        parent_tree, root_node_idx, _ = self.split_tree(client_n_prob,
                                                        prefix_sum,
                                                        left=0,
                                                        right=len(prefix_sum),
                                                        highest_root_node_idx=first_free_idx - 1,
                                                        )
        return parent_tree, root_node_idx

    def split_tree(self, client_n_prob, prefix_sum, left, right, highest_root_node_idx):
        """Recursively split ``client_n_prob[left:right]`` into a binary subtree.

        Args:
            client_n_prob: ``(client_idx, prob)`` pairs sorted by ``prob``.
            prefix_sum: running sums of the sorted probabilities;
                ``prefix_sum[i]`` is the sum of ``client_n_prob[0..i]``.
            left: first position (inclusive) of the range to split.
            right: end position (exclusive) of the range to split.
            highest_root_node_idx: largest node index assigned so far; new
                internal nodes get indices above it.

        Returns:
            ``(parent_tree, root_node_idx, highest_root_node_idx)``: the
            child-to-parent map of this subtree, the index of its root, and
            the updated largest assigned node index.

        Raises:
            ValueError: if the range ``[left, right)`` is empty.
        """
        if right - left < 1:
            raise ValueError(f'cannot split an empty range [{left}, {right})')
        if right - left == 1:   # only one client: it is its own (leaf) root
            return {}, client_n_prob[left][0], highest_root_node_idx

        # Sum of everything before `left`, so prefix_sum[i] - offset is the
        # sum of positions [left, i].
        offset = prefix_sum[left - 1] if left > 0 else 0
        best_split_i = None
        smallest_diff_sum = None
        all_diff = []
        # Cutting after position i puts [left, i] on the left and [i + 1, right)
        # on the right; i stops at right - 2 so neither side is ever empty.
        for i in range(left, right - 1):
            left_sum = prefix_sum[i] - offset
            right_sum = prefix_sum[right - 1] - prefix_sum[i]
            diff_sum = abs(left_sum - right_sum)
            all_diff.append(diff_sum)
            # Strict '<' keeps the earliest split on ties.
            if smallest_diff_sum is None or diff_sum < smallest_diff_sum:
                best_split_i = i + 1
                smallest_diff_sum = diff_sum

        print("left:", left, "right:", right, "diff:", all_diff)
        left_parent_tree, left_root_node_idx, highest_root_node_idx = self.split_tree(client_n_prob,
                                                                                      prefix_sum,
                                                                                      left=left,
                                                                                      right=best_split_i,
                                                                                      highest_root_node_idx=highest_root_node_idx,
                                                                                      )
        right_parent_tree, right_root_node_idx, highest_root_node_idx = self.split_tree(client_n_prob,
                                                                                        prefix_sum,
                                                                                        left=best_split_i,
                                                                                        right=right,
                                                                                        highest_root_node_idx=highest_root_node_idx,
                                                                                        )
        # The new internal node takes the next index above everything seen so far.
        root_node_idx = max(left_root_node_idx, right_root_node_idx, highest_root_node_idx) + 1
        parent_tree = {
            left_root_node_idx: root_node_idx,
            right_root_node_idx: root_node_idx,
        }
        parent_tree.update(left_parent_tree)
        parent_tree.update(right_parent_tree)

        return parent_tree, root_node_idx, root_node_idx

                
                

