"""
Licensed Materials - Property of IBM
Restricted Materials of IBM
20190891
© Copyright IBM Corp. 2021 All Rights Reserved.
"""
"""
Module where the tier-based fusion algorithms are implemented.
"""
import copy
import logging
import numpy as np
import operator
import sys
from ibmfl.model.model_update import ModelUpdate
from ibmfl.aggregator.fusion.iter_avg_fusion_handler import IterAvgFusionHandler
import ibmfl.util.fl_metrics as fl_metrics
from itertools import groupby
import torch as th
import pandas as pd
from sklearn.cluster import KMeans
from scipy.spatial import distance
from ibmfl.message.message import Message
from ibmfl.message.message_type import MessageType
from ibmfl.exceptions import ModelUpdateException, FusionException
import random

logger = logging.getLogger(__name__)


class TierFedAvgFusionHandler(IterAvgFusionHandler):
    """
    Class for iterative averaging based fusion algorithms.
    An iterative fusion algorithm here referred to a fusion algorithm that
    sends out queries at each global round to registered parties for
    information, and use the collected information from parties to update
    the global model.
    The type of queries sent out at each round is the same. For example,
    at each round, the aggregator send out a query to request local model's
    weights after parties local training ends.
    The iterative algorithms can be terminated at any global rounds.

    In this class, the aggregator requests local model's weights from all
    parties at each round, and the averaging aggregation is performed over
    collected model weights. The global model's weights then are updated by
    the mean of all collected local models' weights.
    """

    def __init__(self, hyperparams,
                 protocol_handler,
                 data_handler=None,
                 data_handlers=None,
                 fl_models=None,
                 shapley_value_test_model=None,
                 **kwargs):
        """
        Initializes a TierFedAvgFusionHandler object with provided information,
        such as protocol handler, fl_models, data handlers and hyperparams.

        :param hyperparams: Hyperparameters used for training. The `global` \
        section is read for `rounds`, `select_random`, \
        `random_parties_selected_per_tier`, `parties_selected_per_tier`, \
        `pre_training_rounds`, `termination_accuracy`, `tiers` and `tokens`.
        :type hyperparams: `dict`
        :param protocol_handler: Protocol handler used for handling learning \
        algorithm's request for communication.
        :type protocol_handler: `ProtoHandler`
        :param data_handler: data handler that will be used to obtain data
        :type data_handler: `DataHandler`
        :param data_handlers: data handlers indexed by tier id; defaults to \
        an empty list
        :type data_handlers: `list`
        :param fl_models: models to be trained, one entry per tier; the first \
        one is forwarded to the parent handler
        :type fl_models: `list` of `model.FLModel`
        :param shapley_value_test_model: model stored on the handler as \
        `shapley_value_test_model`
        :param kwargs: Additional arguments to initialize a fusion handler.
        :type kwargs: `Dict`
        :return: None
        """
        # Avoid a shared mutable default and tolerate a missing model list.
        data_handlers = [] if data_handlers is None else data_handlers
        fl_models = [] if fl_models is None else fl_models

        super().__init__(hyperparams,
                         protocol_handler,
                         data_handler,
                         data_handlers,
                         fl_models[0] if fl_models else None,
                         **kwargs)
        self.name = "Tier-FedAvg"
        self._eps = 1e-6
        self.params_global = hyperparams.get('global') or {}
        self.params_local = hyperparams.get('local') or None
        self.rounds = self.params_global.get('rounds') or 1
        self.curr_round = 0
        self.data_handlers = data_handlers
        self.global_accuracy = -1
        self.global_acc_threshold = 0.5
        self.tier_update_frequency = 0
        self.personalized_model_collection_frequency = 25
        self.utility_improvement = 0.0

        # `get(...) or True` could never yield False; only fall back to True
        # when the option is absent so an explicit False is honoured.
        select_random = self.params_global.get('select_random')
        self.select_random = True if select_random is None else select_random

        # self.token_to_pay = self.params_global.get('token_to_pay') or 1
        self.token_to_pay = 1
        self.random_parties_selected_per_tier = self.params_global.get('random_parties_selected_per_tier')
        self.parties_selected_per_tier = self.params_global.get('parties_selected_per_tier')
        self.previous_model_updates = []
        self.pre_training_rounds = self.params_global.get('pre_training_rounds') or 0
        self.termination_accuracy = self.params_global.get(
            'termination_accuracy')
        self.current_tiers = self.params_global.get(
            'tiers')
        # NOTE(review): this history is keyed by int tier ids while the two
        # structures below use str(tier_id) keys -- confirm which is intended.
        self.shapley_value_history = {tier_id: dict() for tier_id in range(self.current_tiers)}
        self.tier_client_idx = {str(tier_id): [] for tier_id in range(self.current_tiers)}
        self.shapley_value_per_tier_per_party = {str(tier_id): dict() for tier_id in range(self.current_tiers)}
        self.shapley_value_test_model = shapley_value_test_model
        # self.tokens = self.params_global.get('tokens') or 0
        self.tokens = self.params_global.get('tokens') or 100
        self.tokens_party = {}
        self.max_global_acc_per_tier = {}
        self.cur_global_acc_per_tier = {}
        self.participated_rounds_per_tier_per_party = {}
        self.available_free_tokens_per_tier = {}

        self.fl_models = fl_models
        # Keep one entry per model (None when the model is missing or not
        # fitted) so the list stays aligned with the tier order. The previous
        # loop replaced the whole list with None and then crashed.
        self.model_updates = [
            fl_model.get_model_update()
            if fl_model and fl_model.is_fitted() else None
            for fl_model in fl_models
        ]

        self.current_model_weights_per_tier = [
            model_update.get('weights') if model_update else None
            for model_update in self.model_updates
        ]
        logger.info('[FARAZ] len(self.current_model_weights_per_tier): ' + str(len(self.current_model_weights_per_tier)))
        # The handler-wide weights mirror the first tier's weights.
        self.current_model_weights = \
            self.current_model_weights_per_tier[0] \
            if self.current_model_weights_per_tier else None

        if self.evidencia:
            # Imported only when evidencia is enabled; the name is not used
            # further inside this method.
            from ibmfl.evidencia.util.hashing import hash_model_update  # noqa: F401

    # def get_cosine_scores():
    def get_model_logits(self, model):
        """
        Returns the last entry of a model's weights.

        :param model: either a `ModelUpdate` (its `weights` entry is used) \
        or the weights container itself
        :return: the final element of the weights; presumably the parameters \
        of the output layer -- TODO confirm against the model definition
        """
        weights = model.get('weights') if isinstance(model, ModelUpdate) else model
        return weights[-1]
    
    def get_cosine_similarities(self):
        """
        Returns the cosine similarities between each client model and the
        global model of the tier it was trained in.

        `self.previous_model_updates[i]` is read as a `(models, parties)`
        pair for tier `i`.

        :return: mapping `client_id -> [(cosine_similarity, tier_index), ...]`
        :rtype: `dict`
        """
        logger.info('[FARAZ] Calculating cosine similarities')

        similarities = {}
        for tier_index, tier_entry in enumerate(self.previous_model_updates):
            models = tier_entry[0]
            parties = tier_entry[1]
            logger.info('[FARAZ] clients: ' + str(parties) + 'in tier: ' + str(tier_index))
            tier_logits = self.get_model_logits(self.current_model_weights_per_tier[tier_index])

            for position, client_id in enumerate(parties):
                client_logits = self.get_model_logits(models[position])
                score = 1. - distance.cosine(client_logits, tier_logits)
                # One list per client, extended for every tier it appears in.
                similarities.setdefault(client_id, []).append((score, tier_index))

        logger.info('[FARAZ] similarities: ' + str(similarities))
        return similarities
    
    def get_personalized_model_acc(self):
        """
        Asks the protocol handler for the personalized model accuracy of
        every registered party.

        :return: whatever `self.ph.get_personalized_model_acc` returns; \
        presumably a dictionary of the form {client_id: personalized_model_acc}
        """
        announcement = '[FARAZ] Collecting personalized model accuracies from clients: ' + str(self.get_registered_parties())
        logger.info(announcement)
        parties = self.get_registered_parties()
        return self.ph.get_personalized_model_acc(parties)
    
    def measure_average_accuracies_from_clients(self):
        """
        Measures the average accuracy of the tier models as reported by the
        clients currently assigned to each tier.

        :return: mapping `tier_id -> mean accuracy`; tiers without clients \
        are omitted, and the mapping is empty when no tier weights exist
        :rtype: `dict`
        """
        average_accuracy = {}
        if not self.current_model_weights_per_tier:
            logger.error('No model updates available')
            return average_accuracy

        # Every tier's clients receive the model updates of all tiers.
        model_updates = [
            ModelUpdate(weights=tier_weights)
            for tier_weights in self.current_model_weights_per_tier
        ]
        for tier_id, tier_clients in self.tier_client_idx.items():
            if tier_clients in ([], None):
                logger.error('No clients available for tier: ' + str(tier_id))
                continue

            average_accuracies = self.ph.measure_average_accuracies_from_clients(
                tier_clients, {'model_updates': model_updates})
            logger.info('Local accuracies of tier ' + str(tier_id) + ' are: ' + str(average_accuracies))
            # Only the first element of the reply is used; each entry in it
            # is indexed by the numeric tier id.
            tier_scores = [acc[int(tier_id)] for acc in average_accuracies[0]]
            average_accuracy[tier_id] = np.mean(tier_scores)
            logger.info('Average accuracy of tier ' + str(tier_id) + ' is: ' + str(average_accuracy[tier_id]))

        return average_accuracy
    
    def measure_global_accuracy(self):
        """
        Evaluates every tier model on the aggregator-side test sets.

        Each model in `self.fl_models` is evaluated on the test split of
        `self.data_handler` (logged only) and on the test split of every
        entry of `self.data_handlers`; the `accuracy_score` values of the
        latter are pooled.

        :return: mapping `str(tier_id) -> mean of all pooled accuracies`
        :rtype: `dict`
        """
        # Letter used in the log line for the dataset of each tier (A..Z).
        dataset_letters = {index: chr(ord('A') + index) for index in range(26)}
        accuracies = []

        for tier_no, fl_model in enumerate(self.fl_models):
            if self.data_handler and fl_model:
                (_, _), test_data = self.data_handler.get_data()
                eval_results = fl_model.evaluate(test_data)
                logger.info('[FARAZ] global accuracy on IID dataset: ' + str(eval_results))

            for tier_id in range(self.current_tiers):
                if not (self.data_handlers[tier_id] and self.fl_models[tier_id]):
                    continue
                (_, _), test_data = self.data_handlers[tier_id].get_data()
                eval_results = fl_model.evaluate(test_data)
                accuracies.append(eval_results['accuracy_score'])
                logger.info('[FARAZ] tier {} global accuracy on D{}'.format(tier_no, dataset_letters[tier_id]) + ' dataset: ' + str(eval_results))

        # NOTE(review): every tier is assigned the same value, the mean over
        # all (model, dataset) pairs -- confirm a per-tier figure is not
        # what was intended.
        return {
            str(tier_id): np.mean(accuracies)
            for tier_id in range(self.current_tiers)
        }
    
    def give_tokens_by_marginal_contribution(self):
        """
        Hands out free tokens of each tier according to the Shapley-value
        ranking of that tier's parties.

        For a tier with `n` ranked parties the party at rank `idx`
        (0 = highest Shapley value) receives
        `int((n - idx) * (available / (n * (n + 1) / 2)) * 0.5)` tokens, i.e.
        at most about half of the tier's free tokens are distributed. The
        tier's remaining free tokens are written back to
        `self.available_free_tokens_per_tier`.

        Tiers with no free tokens are skipped; the other tiers are still
        processed.

        :return: None
        """
        # Snapshot the items because the dict is updated inside the loop.
        for tier_no, available_tokens in list(self.available_free_tokens_per_tier.items()):
            if available_tokens <= 0:
                logger.info('[FARAZ] Free Tokens are not enough for Accuracy Contribution: ' + str(available_tokens))
                # Only this tier is out of tokens; a `return` here would
                # silently skip every remaining tier.
                continue

            shapley_values_per_party = self.shapley_value_per_tier_per_party[tier_no]
            logger.info('[FARAZ] shapley_values_per_party: ' + str(shapley_values_per_party))

            logger.info('[DEBUG] acc_parties: ' + str(shapley_values_per_party))
            sorted_party = sorted(shapley_values_per_party.items(), key=operator.itemgetter(1), reverse=True)
            logger.info('[FARAZ] sorted_party: ' + str(sorted_party))

            # Total number of selected parties
            n = len(shapley_values_per_party)
            # Ranking denominator: 1 + 2 + ... + n
            d = n * (n + 1) / 2

            logger.info('[DEBUG] self.tokens_party, BEFORE giving free tokens for Acc Contribution: ' + str(self.tokens_party))
            logger.info('[DEBUG] available_free_tokens, BEFORE giving free tokens for Acc Contribution: ' + str(available_tokens))

            given_free_tokens = 0
            for idx, (party_id, _) in enumerate(sorted_party):
                free_token = int((n - idx) * (available_tokens / d) * 0.5)
                self.tokens_party[party_id] += free_token
                given_free_tokens += free_token
            available_tokens -= given_free_tokens
            self.available_free_tokens_per_tier[tier_no] = available_tokens

            logger.info('[DEBUG] available_free_tokens, AFTER giving free tokens for Acc Contribution: ' + str(available_tokens))
            logger.info('[DEBUG] self.tokens_party, AFTER giving free tokens for Acc Contribution: ' + str(self.tokens_party))

    def give_tokens_by_participation_record(self):
        """
        Hands out free tokens of each tier according to how many rounds each
        party of that tier has participated in.

        Parties are ranked by participation count, descending when
        `self.utility_improvement > 0` and ascending otherwise. The party at
        rank `idx` of `n` receives
        `int((n - idx) / (n * (n + 1) / 2) * available)` tokens, and the
        tier's remaining free tokens are written back to
        `self.available_free_tokens_per_tier`.

        Tiers with no free tokens are skipped; the other tiers are still
        processed.

        :return: None
        """
        # Snapshot the items because the dict is updated inside the loop.
        for tier_no, available_tokens in list(self.available_free_tokens_per_tier.items()):
            if available_tokens <= 0:
                logger.info('[FARAZ] Free Tokens are not enough for Algorithm 1: ' + str(available_tokens))
                # Only this tier is out of tokens; a `return` here would
                # silently skip every remaining tier.
                continue

            participated_rounds_per_party = self.participated_rounds_per_tier_per_party[tier_no]
            sorted_party = sorted(participated_rounds_per_party.items(),
                                  key=operator.itemgetter(1),
                                  reverse=self.utility_improvement > 0)
            logger.info('[FARAZ] sorted_party: ' + str(sorted_party))

            # Total number of parties
            n = len(participated_rounds_per_party)
            # Ranking denominator: 1 + 2 + ... + n
            d = n * (n + 1) / 2

            logger.info('[DEBUG] self.tokens_party, BEFORE giving free tokens for Participation: ' + str(self.tokens_party))
            logger.info('[DEBUG] available_free_tokens, BEFORE giving free tokens for Participation: ' + str(available_tokens))

            given_free_tokens = 0
            for idx, (party_id, _) in enumerate(sorted_party):
                free_token = int((n - idx) / d * available_tokens)
                self.tokens_party[party_id] += free_token
                given_free_tokens += free_token
            available_tokens -= given_free_tokens
            self.available_free_tokens_per_tier[tier_no] = available_tokens

            logger.info('[DEBUG] available_free_tokens, AFTER giving free tokens: ' + str(available_tokens))
            logger.info('[DEBUG] self.tokens_party, AFTER giving free tokens: ' + str(self.tokens_party))
        
    def reimburse_tokens_by_utility(self):
        """
        Reimburses free tokens of each tier to the parties assigned to it,
        based on how the tier's accuracy changed.

        For every tier in `self.available_free_tokens_per_tier` the method
        compares `self.cur_global_acc_per_tier` with
        `self.max_global_acc_per_tier`, derives a reimbursement amount from
        the tier's free tokens and splits it evenly (integer division) among
        the parties listed in `self.tier_client_idx` for that tier. It
        updates `self.tokens_party`, `self.available_free_tokens_per_tier`,
        `self.max_global_acc_per_tier` and `self.utility_improvement`.

        :return: None
        """
        no_of_parties = len(self.get_available_parties())
        for tier_no, avaialble_tokens in self.available_free_tokens_per_tier.items():
            
            # Skip at first.
            # When no maximum is stored yet (0) and a current accuracy exists,
            # seed the maximum with it and do not reimburse for this tier.
            if self.max_global_acc_per_tier[tier_no] == 0 and self.cur_global_acc_per_tier.get(tier_no) is not None:
                self.max_global_acc_per_tier[tier_no] = self.cur_global_acc_per_tier[tier_no]
                continue
            else:
                t_max = 1.0
                i_max = 1.0
                utility_improvement = 0
                reduction_ratio = 0
                reimbursed_token = 0

                # utility_improvement: 0.0~1.0
                # Relative gain of the current accuracy over the stored
                # maximum, clamped to [0, 1].
                if self.cur_global_acc_per_tier.get(tier_no) is not None:
                    utility_improvement = max(0.0, (self.cur_global_acc_per_tier[tier_no] - self.max_global_acc_per_tier[tier_no])/self.max_global_acc_per_tier[tier_no])
                    utility_improvement = min(1.0, utility_improvement)
                    logger.info('[FARAZ] utility_improvement: ' + str(utility_improvement))

                    logger.info('[DEBUG] self.tokens_party, BEFORE reimbursing free tokens: ' + str(self.tokens_party))
                    logger.info('[DEBUG] available_free_tokens, BEFORE reimbursing free token: ' + str(avaialble_tokens))

                    # Update Max Accuracy achieved until current round.
                    # NOTE(review): without an improvement the stored maximum
                    # is reset to 0, so the next call takes the seeding
                    # branch above for this tier -- confirm this is intended.
                    if self.cur_global_acc_per_tier[tier_no] > self.max_global_acc_per_tier[tier_no]:
                        self.max_global_acc_per_tier[tier_no] = self.cur_global_acc_per_tier[tier_no]
                    else:
                        self.max_global_acc_per_tier[tier_no] = 0

                # Calculate Reimbursed Token based on Utility improvement.
                utility_improvement = min(i_max, utility_improvement)
                # Stored on the handler; after the loop it holds the value of
                # the last tier that reached this point.
                self.utility_improvement = utility_improvement
                # Ratio is 0 on an improvement; otherwise the improvement is
                # 0 here and the expression evaluates to t_max (1.0).
                if utility_improvement > 0:
                    reduction_ratio = 0
                else:
                    reduction_ratio = t_max * (1 - (utility_improvement/i_max))
                logger.info('[FARAZ] reduction_ratio: ' + str(reduction_ratio))
                # NOTE(review): this overrides the ratio chosen above whenever
                # the tier holds more free tokens than there are available
                # parties, which makes the reimbursement roughly
                # `no_of_parties` tokens even after an improvement -- confirm.
                if avaialble_tokens > no_of_parties:
                    reduction_ratio = no_of_parties/avaialble_tokens
                reimbursed_token = int(avaialble_tokens * reduction_ratio)
                logger.info('[FARAZ] reimbursed_token: ' + str(reimbursed_token))

                # Reimburse Token to ALL Consumers.
                logger.info('tier_client_idx: ' + str(self.tier_client_idx))
                if self.tier_client_idx.get(str(tier_no)):
                    lst_parties = self.tier_client_idx[str(tier_no)]
                    total_parties = len(lst_parties)
                    #logger.info('[DEBUG] total_parties = len(lst_parties): ' + str(total_parties))
                    # Equal integer share per party; any remainder stays in
                    # the tier's free-token pool.
                    for party in lst_parties:
                        self.tokens_party[party] += int(reimbursed_token/total_parties)
                        avaialble_tokens -= int(reimbursed_token/total_parties)
                    self.available_free_tokens_per_tier.update({tier_no: avaialble_tokens})
                
                logger.info('[DEBUG] available_free_tokens, AFTER reimbursing free tokens: ' + str(avaialble_tokens))
                logger.info('[DEBUG] self.tokens_party, AFTER reimbursing free tokesn: ' + str(self.tokens_party))
        
    def update_tier_client_idx(self, selected_clients):
        """
        Merges the newly selected clients into `self.tier_client_idx`.

        For each tier the selected clients are first removed from the tier's
        current member list and then appended again, so they end up at the
        end of the list; members that were not selected stay in place.

        :param selected_clients: mapping `str(tier_no) -> list of clients`; \
        must contain an entry for every tier
        :type selected_clients: `dict`
        :return: None
        """
        for tier_no in range(self.current_tiers):
            tier_key = str(tier_no)
            chosen = selected_clients[tier_key]

            # Drop the chosen clients from a non-empty member list.
            members = self.tier_client_idx.get(tier_key)
            if members:
                for client in chosen:
                    if client in members:
                        members.remove(client)

            # Re-add them; an empty or missing list is replaced by a new one.
            for client in chosen:
                members = self.tier_client_idx.get(tier_key)
                if not members:
                    self.tier_client_idx[tier_key] = [client]
                elif client not in members:
                    members.append(client)

        logger.info('[DEBUG] self.tier_client_idx: ' + str(self.tier_client_idx))
        #delete all clients from selected_clients in self.tier_client_idx
        
    def get_f1_scores_on_IID_data(self):
        """
        Returns, for every party of the first entry in
        `self.previous_model_updates`, its weighted F1 score on each class of
        the aggregator's test set.

        The test samples of `self.data_handler` are grouped by label; each
        party's model update is loaded into `self.fl_models[1]` and used to
        predict every group.

        NOTE(review): loading party updates into `self.fl_models[1]`
        overwrites that model's state -- confirm a scratch model is not
        expected here.

        :return: mapping `party_id -> [f1 per class, in sorted label order]`
        :rtype: `dict`
        """
        x_test = self.data_handler.x_test
        y_test = self.data_handler.y_test
        # Start at index 0; the previous `range(1, ...)` dropped the first
        # test sample.
        test_dataset = [(x_test[i], y_test[i]) for i in range(len(y_test))]

        # Group the samples by label (groupby needs the data sorted by key).
        samples_by_label = {
            label: [sample for sample, _ in group]
            for label, group in groupby(
                sorted(test_dataset, key=lambda pair: pair[1]),
                key=lambda pair: pair[1])
        }
        # The per-class inputs do not depend on the party, so build the
        # tensors once instead of once per party.
        inputs_by_label = {
            label: th.tensor(np.array(samples))
            for label, samples in samples_by_label.items()
        }

        party_updates = self.previous_model_updates[0][0]
        party_ids = self.previous_model_updates[0][1]
        evaluation_model = self.fl_models[1]

        f1_scores = {}
        for party_no in range(len(party_ids)):
            evaluation_model.update_model(party_updates[party_no])
            f1_scores_per_class = []
            for label, inputs in inputs_by_label.items():
                y_preds = np.argmax(evaluation_model.predict(inputs), axis=1)
                # Every sample of the group carries the same true label.
                y_true = [label] * len(y_preds)
                metrics = fl_metrics.get_multi_label_classification_metrics(y_preds, y_true)
                f1_scores_per_class.append(metrics['f1 weighted'])
            f1_scores[party_ids[party_no]] = f1_scores_per_class

        return f1_scores
    
    def select_clients_on_basis_of_f1_Scores(self, f1_scores):
        """
        Clusters the clients into tiers by their per-class F1 scores and
        re-initialises each tier's weights from its members.

        The (up to) ten classes with the highest variance across clients are
        used as features for K-Means; the cluster label of a client becomes
        its tier id. Clients are appended to `self.tier_client_idx`, and
        `self.current_model_weights_per_tier` is set to the layer-wise
        average of the members' weights taken from
        `self.previous_model_updates[0]`.

        :param f1_scores: mapping `client_id -> list of per-class F1 scores`
        :type f1_scores: `dict`
        :return: None
        """
        df = pd.DataFrame.from_dict(f1_scores, orient='index')
        variance_df = df.var(axis=0).nlargest(10)
        top_variant_features = df.loc[:, variance_df.keys().tolist()].values
        # One cluster per tier: the labels index per-tier structures, so a
        # hard-coded 2 only worked when exactly two tiers were configured.
        tier_idx = KMeans(n_clusters=self.current_tiers, random_state=0).fit_predict(top_variant_features)

        # Rows of `df` follow the insertion order of `f1_scores`, so labels
        # and client ids line up.
        for client_id, cluster in zip(f1_scores, tier_idx):
            tier_key = str(cluster)
            members = self.tier_client_idx.get(tier_key)
            if members is None:
                self.tier_client_idx[tier_key] = [client_id]
            else:
                members.append(client_id)

        reference_updates = self.previous_model_updates[0][0]
        reference_parties = self.previous_model_updates[0][1]
        for tier_id, party_ids in self.tier_client_idx.items():
            tier_weights = [
                reference_updates[reference_parties.index(party_id)].get('weights')
                for party_id in party_ids
            ]
            if not tier_weights:
                # No members: keep the tier's current weights untouched.
                continue
            if len(tier_weights) == 1:
                averaged = tier_weights[0]
            else:
                # The previous condition was inverted (several members ->
                # first member's weights, one member -> np.mean). Average
                # layer by layer; assumes each `weights` entry is a sequence
                # of per-layer arrays, as get_model_logits indexes it with
                # [-1] -- TODO confirm.
                averaged = [np.mean(layer, axis=0) for layer in zip(*tier_weights)]
            self.current_model_weights_per_tier[int(tier_id)] = averaged
            
    
    def get_party_preferences(self):
        """
        Sends the current tier models to all registered parties and returns
        their tier preferences.

        :return: party preferences as returned by \
        `self.ph.get_party_preferences`
        """
        registered_parties = self.get_registered_parties()
        logger.info('[FARAZ] Sending requests to parties to get their preferences')

        # Always bind the list: it used to be defined only when tier weights
        # existed, which raised UnboundLocalError otherwise.
        model_updates = [
            ModelUpdate(weights=tier_weights)
            for tier_weights in (self.current_model_weights_per_tier or [])
        ]

        party_preferences = self.ph.get_party_preferences(registered_parties, {'model_updates': model_updates})
        logger.info('[FARAZ] party_preferences: ' + str(party_preferences))

        return party_preferences
    
    def get_shapley_value_for_party_in_tier(self, party_id, tier_id):
        """
        Returns shapley value for party in tier

        :param party_id: party id
        :type party_id: `int`
        :param tier_id: tier id
        :type tier_id: `int`
        :return: shapley value
        :rtype: `float`
        """
        if self.shapley_value_history.get(tier_id) is None:
            #if tier does not exist return MAX
            return sys.float_info.min
        elif self.shapley_value_history[tier_id].get(party_id) is None:
            #if party record does not exist in this tier return MAX
            return sys.float_info.min
        elif self.tokens_party[party_id] <= self.token_to_pay:
            #if party does not have tokens return MAX
            return sys.float_info.min
        else:
            return self.shapley_value_history[tier_id][party_id]
    
    def _select_ranked_and_random(self, candidates, tier_key):
        """
        Ranks `candidates` by their Shapley value in tier `tier_key` and
        returns the top `parties_selected_per_tier` of them plus
        `random_parties_selected_per_tier` random ones from the rest; all
        candidates are returned when there are not more than that.
        """
        top_k = self.parties_selected_per_tier
        random_k = self.random_parties_selected_per_tier
        ranked = sorted(
            candidates,
            key=lambda party: self.get_shapley_value_for_party_in_tier(party, tier_key),
            reverse=self.utility_improvement > 0)
        if top_k + random_k >= len(ranked):
            return ranked
        top = ranked[:top_k]
        remaining = list(set(ranked) - set(top))
        # Duplicated candidates can leave fewer than `random_k` parties.
        return top + random.sample(remaining, min(random_k, len(remaining)))

    def select_clients_from_preferences_contributions(self, party_preferences):
        """
        Selects clients for the current round on the basis of their
        preferences and contributions to the tier-level global model.

        Parties are grouped by the tiers they prefer; within a tier they are
        ranked by Shapley value (descending when
        `self.utility_improvement > 0`, ascending otherwise) and the top
        ranked parties plus a few random ones are kept. A tier nobody
        preferred is filled from all responding parties when its accuracy is
        below `self.global_acc_threshold`, and from a random sample of the
        available parties otherwise. The result is also merged into
        `self.tier_client_idx`.

        :param party_preferences: pair `(tiers, parties)` where `tiers[i]` \
        lists the tier ids preferred by `parties[i]`
        :type party_preferences: `tuple` or `list`
        :return: selected clients, keyed by `str(tier_no)`
        :rtype: `dict`
        """
        logger.info('[FARAZ] Selecting clients from preferences and contributions')

        tiers = party_preferences[0]
        all_parties = party_preferences[1]

        # Group the parties by preferred tier. Tier ids are normalised to
        # strings because every other lookup below uses str(tier_no).
        candidates_per_tier = {}
        for index, preferred_tiers in enumerate(tiers):
            for tier_id in preferred_tiers:
                candidates_per_tier.setdefault(str(tier_id), []).append(all_parties[index])

        sorted_clients = {
            tier_key: self._select_ranked_and_random(candidates, tier_key)
            for tier_key, candidates in candidates_per_tier.items()
        }

        # Make sure every tier ends up with clients.
        logger.info('[FARAZ] cur_global_acc_per_tier: ' + str(self.cur_global_acc_per_tier))
        for tier_no in range(self.current_tiers):
            tier_key = str(tier_no)
            if sorted_clients.get(tier_key) is not None:
                continue

            chosen = None
            if self.cur_global_acc_per_tier[tier_key] < self.global_acc_threshold:
                # Under-performing tier without volunteers: rank all
                # responding parties for this tier. The earlier code reused
                # the candidate list and selection of whichever tier was
                # processed last (or an unbound name) and passed an int
                # tier id to the Shapley lookup.
                chosen = self._select_ranked_and_random(all_parties, tier_key)
            if not chosen:
                available = list(self.get_available_parties())
                wanted = self.parties_selected_per_tier + self.random_parties_selected_per_tier
                # random.sample raises if asked for more than is available.
                chosen = random.sample(available, min(wanted, len(available)))
            sorted_clients[tier_key] = chosen

        logger.info('sorted_clients: ' + str(sorted_clients))
        self.update_tier_client_idx(sorted_clients)
        return sorted_clients
    
    def select_random_clients(self):
        """
        Selects random clients for the current round

        :return: selected clients
        :rtype: `dict`
        """
        selected_parties = {}
        for tier_no in range(0, self.current_tiers):
            selected_parties[str(tier_no)] = random.sample(self.get_available_parties(), self.parties_selected_per_tier + self.random_parties_selected_per_tier)
        self.update_tier_client_idx(selected_parties)
        return selected_parties
    
    def select_clients_per_tier(self):
        """
        Selects the clients of each tier for the current round.

        The per-tier selection sizes are refreshed first (fixed 48 + 2 on a
        tier-update round, configured values otherwise). The strategy then
        depends on the handler state:

        * `select_random` set: random selection, and the flag is cleared;
        * current round equals `pre_training_rounds`: clustering by F1 score;
        * current round is past `pre_training_rounds`: preferences and \
        contributions;
        * otherwise: all registered parties under key `0`.

        :return: selected clients per tier
        :rtype: `dict`
        """
        is_tier_update_round = (
            self.tier_update_frequency != 0
            and self.curr_round % self.tier_update_frequency == 0)
        if is_tier_update_round:
            self.parties_selected_per_tier = 48
            self.random_parties_selected_per_tier = 2
        else:
            self.random_parties_selected_per_tier = self.params_global.get('random_parties_selected_per_tier')
            self.parties_selected_per_tier = self.params_global.get('parties_selected_per_tier')

        registered_parties = self.get_registered_parties()

        if self.select_random:
            selected_parties = self.select_random_clients()
            self.select_random = False
            logger.info('tier_client_idx 3 : ' + str(self.tier_client_idx))
            return selected_parties

        if self.curr_round == self.pre_training_rounds:
            self.select_clients_on_basis_of_f1_Scores(self.get_f1_scores_on_IID_data())
            logger.info('tier_client_idx 1 : ' + str(self.tier_client_idx))
            return self.tier_client_idx

        if self.curr_round > self.pre_training_rounds:
            preferences = self.get_party_preferences()
            self.tier_client_idx = self.select_clients_from_preferences_contributions(preferences)
            logger.info('tier_client_idx 2 : ' + str(self.tier_client_idx))
            return self.tier_client_idx

        # Still pre-training: every registered party, under integer key 0.
        logger.info('tier_client_idx 4 : ' + str(self.tier_client_idx))
        return {0: registered_parties}
    
    def get_selected_parties(self):
        """
        Returns the parties selected for the current round.

        Thin wrapper that delegates to `select_clients_per_tier`.

        :return: the selection produced by `select_clients_per_tier`
        :rtype: `dict`
        """
        return self.select_clients_per_tier()
    
    def start_global_training_by_tier(self):
        """
        Starts an iterative, per-tier global federated learning training
        process.

        Before the first round, the per-party token balances (only when
        `self.tokens` is non-zero) and the per-tier bookkeeping dictionaries
        are initialised. Every round then:

        1. wraps the current tier weights into `ModelUpdate` objects;
        2. selects parties per tier via `get_selected_parties`;
        3. once pre-training is over, counts the participation of every
           selected party and charges it `self.token_to_pay` tokens, which
           are credited to the tier's free-token pool;
        4. queries the selected parties tier by tier and fuses the replies
           into the tier models;
        5. measures accuracy and, after pre-training (or while
           `self.select_random` is set), computes shapley values and
           redistributes tokens.

        The loop runs until `reach_termination_criteria` returns True; the
        state is saved after every round.

        :return: None
        """
        registered_parties = self.get_registered_parties()

        if self.tokens != 0:
            for party in registered_parties:
                self.tokens_party[party] = self.tokens

        for i in range(self.current_tiers):
            tier_key = str(i)
            self.shapley_value_history[tier_key] = {}
            self.participated_rounds_per_tier_per_party[tier_key] = {}
            self.available_free_tokens_per_tier[tier_key] = 0
            self.max_global_acc_per_tier[tier_key] = 0.0
            self.cur_global_acc_per_tier[tier_key] = 0.0
            for party_id in registered_parties:
                self.shapley_value_history[tier_key][party_id] = 0
                self.participated_rounds_per_tier_per_party[tier_key][party_id] = 0

        self.curr_round = 0
        while not self.reach_termination_criteria(self.curr_round):
            if self.curr_round == 100:
                # NOTE(review): select_clients_per_tier reassigns both of
                # these attributes on every call, so this shift looks like
                # it is overwritten before it is used -- confirm intent.
                self.random_parties_selected_per_tier = self.random_parties_selected_per_tier - 1
                self.parties_selected_per_tier = self.parties_selected_per_tier + 1

            logger.info('[FARAZ] Starting round: ' + str(self.curr_round))

            # construct one ModelUpdate per tier model
            if self.current_model_weights_per_tier:
                self.model_updates = [
                    ModelUpdate(weights=current_model_weights)
                    for current_model_weights in self.current_model_weights_per_tier
                ]
            else:
                self.model_updates = None

            # log to Evidentia
            if self.model_updates and self.evidencia:
                self.evidencia.add_claim("sent_global_model",
                                         "{}, '\"{}\"'".format(self.curr_round + 1,
                                         hash_model_update_by_tier(self.model_updates)))

            lst_replies = {}

            selected_parties = self.get_selected_parties()
            # query parties for each tier separately
            logger.info('[FARAZ] Initiating training requests')
            logger.info('[FARAZ] selected_parties: ' + str(selected_parties))

            for tier_id in selected_parties.keys():
                tier_parties = selected_parties[tier_id]

                if self.curr_round > self.pre_training_rounds:
                    for party_id in tier_parties:
                        self.participated_rounds_per_tier_per_party[tier_id][party_id] += 1
                        # Charge the party that actually participates. The
                        # previous code indexed a stale loop variable here,
                        # so only the last party of the tier was charged.
                        self.tokens_party[party_id] -= self.token_to_pay
                        self.available_free_tokens_per_tier[tier_id] += self.token_to_pay

                payload = {'hyperparams': {'local': self.params_local, 'tier': tier_id},
                           'model_updates': self.model_updates
                           }

                reply = self.query_parties(payload, tier_parties, True)
                lst_replies[tier_id] = reply

            self.previous_model_updates = lst_replies

            # log to Evidentia
            if self.evidencia:
                updates_hashes = []
                # NOTE(review): each reply is indexed elsewhere in this
                # class as (model_updates, parties); iterating it directly
                # would also walk the parties element -- confirm intended.
                for tier_id in lst_replies.keys():
                    for model_updates in lst_replies[tier_id]:
                        for update in model_updates:
                            updates_hashes.append(hash_model_update(update))
                            self.evidencia.add_claim("received_model_update_hashes",
                                                     "{}, '{}'".format(self.curr_round + 1,
                                                     str(updates_hashes).replace('\'', '"')))

            self.update_weights(lst_replies)
            # Push the fused weights into the aggregator-side tier models
            for i in range(0, self.current_tiers):
                self.fl_models[i].update_model(ModelUpdate(weights=self.current_model_weights_per_tier[i]))

            self.measure_global_accuracy()
            self.cur_global_acc_per_tier = self.measure_average_accuracies_from_clients()

            if self.curr_round > self.pre_training_rounds or self.select_random:
                self.cal_shapley_value(lst_replies)

                self.reimburse_tokens_by_utility()

                # Give tokens by marginal contributions
                self.give_tokens_by_marginal_contribution()
                # Give free tokens to all producers by previous participation
                self.give_tokens_by_participation_record()

                logger.info(f'[FARAZ] shapley_value_per_tier_per_party: {self.shapley_value_per_tier_per_party}')

                logger.info('[FARAZ] self.tokens_party: ' + str(self.tokens_party))
                logger.info('[FARAZ] self.participated_rounds_party: ' + str(self.participated_rounds_per_tier_per_party))
                logger.info('[FARAZ] available_free_tokens: ' + str(self.available_free_tokens_per_tier))
                logger.info('[FARAZ] average_accuracies: ' + str(self.cur_global_acc_per_tier))
                # A frequency of 0 disables collection instead of raising
                # ZeroDivisionError (same convention as tier_update_frequency).
                if (self.personalized_model_collection_frequency != 0
                        and self.curr_round % self.personalized_model_collection_frequency == 0):
                    personalized_model_accuracies = self.get_personalized_model_acc()
                    logger.info('[FARAZ] personalized_model_accuracies: ' + str(personalized_model_accuracies))

            self.curr_round += 1
            self.save_current_state()
    
    def cal_aggregation_weight(self, lst_parties):
        """
        Returns the uniform aggregation weight of one party within a tier.

        :param lst_parties: parties that contributed to the tier
        :type lst_parties: `list`
        :return: the reciprocal of the number of parties
        :rtype: `float`
        """
        party_count = len(lst_parties)
        return 1 / party_count

    def flatten_list_of_numpy(self, numpy_list):
        """
        Flattens a sequence of arrays into a single 1-dimensional array.

        Every item is flattened in row-major order and the pieces are
        concatenated in the order given. The input is never modified and the
        result never shares memory with it, because `np.concatenate` always
        allocates a new array. That is why no defensive deep copy of the
        (potentially large) parameter list is made.

        :param numpy_list: arrays (or array-likes / scalars) to flatten
        :type numpy_list: `list`
        :return: 1-dimensional array holding all elements
        :rtype: `np.ndarray`
        :raises ValueError: if `numpy_list` is empty (from `np.concatenate`)
        """
        # np.ravel accepts ndarrays as well as scalars and nested lists
        return np.concatenate([np.ravel(item) for item in numpy_list])

    def get_gradient_on_test_data(self, aggregated_model_paramter):
        """
        Computes the gradient of an aggregated model on the aggregator's
        test dataset.

        The given weights are first loaded into
        `self.shapley_value_test_model`; the gradient is then requested from
        that model using the second element returned by
        `self.data_handler.get_data()`.

        :param aggregated_model_paramter: parameter of aggregated_model. list of array
        :type aggregated_model_paramter: `list`
        :return: gradient of aggregated_model_paramter running on the aggregator test dataset. list of array
        :rtype: `list`
        """
        scoring_model = self.shapley_value_test_model
        scoring_model.update_model(ModelUpdate(weights=aggregated_model_paramter))

        # Only the second element (the test split) is needed here.
        _, test_dataset = self.data_handler.get_data()
        return self.shapley_value_test_model.get_gradient(train_data=test_dataset)
        

    def cal_shapley_value(self, lst_replies):
        """
        Calculates a shapley value for every party that replied in this
        round.

        For a party with parameters ``p`` in a tier whose aggregated
        parameters are ``a`` the value is::

            -(1 / N) * w * dot(g, p - a)

        where ``N`` is the number of aggregator test points, ``w`` the
        aggregation weight of the tier (see `cal_aggregation_weight`) and
        ``g`` the flattened gradient of the aggregated tier model on the
        aggregator test data.

        Tiers without an aggregated model, without an entry in
        `lst_replies`, or without any model update are skipped before the
        gradient is computed.

        The results are stored in `self.shapley_value_per_tier_per_party`
        (rebuilt on every call) and mirrored into
        `self.shapley_value_history`; nothing is returned.

        :param lst_replies: info of parties update
        :type lst_replies: `dict[tier_id, tuple(lst_model_updates, lst_parties)]`
        :raises ValueError: if gradient and parameter sizes differ
        :return: None
        """
        # 2-layer dict: tier id (str) -> party id -> shapley value
        self.shapley_value_per_tier_per_party = {str(tier_id): dict() for tier_id in range(self.current_tiers)}

        (_), test_dataset = self.data_handler.get_data()
        test_data_points_num = len(test_dataset[0])
        logger.info('Calculating shapley value for each party')

        for tier_id in range(self.current_tiers):
            tier_key = str(tier_id)
            if tier_id >= len(self.current_model_weights_per_tier):
                continue
            if tier_key not in lst_replies:
                continue

            lst_model_updates = lst_replies[tier_key][0]
            lst_parties = lst_replies[tier_key][1]
            if len(lst_model_updates) == 0:
                continue

            # Everything below depends on the tier only, so it is computed
            # once per tier rather than once per party.
            aggregated_model_paramter = self.current_model_weights_per_tier[tier_id]
            gradients = self.get_gradient_on_test_data(copy.deepcopy(aggregated_model_paramter))
            gradients_flatten = self.flatten_list_of_numpy(gradients)
            aggregation_weight = self.cal_aggregation_weight(lst_parties=lst_parties)

            for i, model_update in enumerate(lst_model_updates):
                # Deep copy: the conversion helper may mutate its input.
                party_parameter = self.fusion_collected_responses(modelUpdates=copy.deepcopy(model_update))
                party_id = lst_parties[i]
                normalized_parameter = party_parameter - aggregated_model_paramter

                normalized_parameter_flatten = self.flatten_list_of_numpy(normalized_parameter)
                if len(gradients_flatten) != len(normalized_parameter_flatten):
                    raise ValueError('parameter size is not the same for calculating shapley value')

                shapley_value = - (1/test_data_points_num) * aggregation_weight * np.dot(gradients_flatten, normalized_parameter_flatten)
                self.shapley_value_per_tier_per_party[tier_key][party_id] = shapley_value
                # Keep the history in sync for later client selection. Only
                # the entry just computed can have changed, so there is no
                # need to re-copy every tier and party on each iteration.
                self.shapley_value_history[tier_key][party_id] = shapley_value

    def fusion_collected_responses(self, modelUpdates, key='weights'):
        """
        Extracts the values stored under `key` in a single model update and
        returns them as a numpy array.

        A plain `np.array` conversion is tried first. If that raises, the
        conversion falls back to
        `IterAvgFusionHandler.transform_update_to_np_array`.

        :param modelUpdates: A model update of type `ModelUpdate`.
        :type modelUpdates:  `ModelUpdate`
        :param key: A key indicating which values of the update to extract.
        :type key: `str`
        :return: the extracted values
        :rtype: `np.ndarray`
        """
        try:
            return np.array(modelUpdates.get(key))
        except Exception:
            # Fallback for updates the direct conversion cannot handle.
            raw_values = modelUpdates.get(key)
            return IterAvgFusionHandler.transform_update_to_np_array(raw_values)

    def update_weights(self, lst_model_updates):
        """
        Updates the per-tier global model weights from the collected party
        replies.

        The replies of every tier are fused with
        `self.fusion_collected_responses_by_tier`. Once the current round is
        past `pre_training_rounds`, the fused weights replace only the
        weights of the tier they belong to. Up to and including
        `pre_training_rounds`, the fused weights are assigned to every tier.

        :param lst_model_updates: collected replies, keyed by tier id
        :type lst_model_updates: `dict`
        :return: None
        """
        for tier_key in lst_model_updates.keys():
            fused = self.fusion_collected_responses_by_tier(lst_model_updates[tier_key])
            if self.curr_round > self.pre_training_rounds:
                self.current_model_weights_per_tier[int(tier_key)] = fused
                continue
            # Pre-training: every tier shares the same fused weights.
            for tier_idx in range(self.current_tiers):
                self.current_model_weights_per_tier[tier_idx] = fused


    def fusion_collected_responses_by_tier(self, modelUpdates, key='weights'):
        """
        Computes the average of the model updates of one tier, weighted by
        each party's training sample count.

        Every update contributes its values stored under `key`, scaled by
        its `train_counts` divided by the total count (plus `self._eps`, to
        avoid a division by zero).

        :param modelUpdates: List of model updates of type `ModelUpdate` \
        to be averaged, or a tuple whose first element is that list \
        (the `(model_updates, parties)` shape of a query reply).
        :type modelUpdates:  `list` or `tuple`
        :param key: A key indicating what values the method will aggregate over.
        :type key: `str`
        :return: results after aggregation
        :rtype: `np.ndarray`
        :raises FusionException: if a model update cannot be read
        """
        # Query replies arrive as (model_updates, parties); only the
        # updates are needed for fusion.
        if isinstance(modelUpdates, tuple):
            modelUpdates = modelUpdates[0]

        w = []
        n_k = []
        for update in modelUpdates:
            try:
                # Honour the requested key instead of always reading 'weights'.
                w.append(np.array(update.get(key)))
                n_k.append(update.get('train_counts'))
            except ModelUpdateException as ex:
                logger.exception(ex)
                raise FusionException("Model updates are not appropriate for this fusion method.  Check local training.") from ex

        n_norm = np.asarray(n_k) / (np.sum(n_k) + self._eps)
        weights = np.sum([w_i * share for w_i, share in zip(w, n_norm)], axis=0)

        return weights

    def reach_termination_criteria(self, curr_round):
        """
        Tells whether global training should stop.

        Training stops once `curr_round` has reached `self.rounds`. Before
        that, the decision is delegated to `self.terminate_with_metrics`.

        :param curr_round: Number of global rounds that already run
        :type curr_round: `int`
        :return: True when training should stop, otherwise the result of \
        `terminate_with_metrics`
        :rtype: `boolean`
        """
        if curr_round < self.rounds:
            return self.terminate_with_metrics(curr_round)

        logger.info('Reached maximum global rounds. Finish training :) ')
        return True
    
    def send_global_models(self):
        """
        Sends the current global tier models to all available parties.

        The parties are read from the protocol handler, the models from
        `get_global_models`, and both are handed to
        `self.ph.sync_model_parties` with the models under the
        `'model_updates'` payload key.

        :return: None
        """
        recipients = self.ph.get_available_parties()
        global_models = self.get_global_models()

        logger.info('Sync Global Models' + str(global_models))
        self.ph.sync_model_parties(recipients, {'model_updates': global_models})
        
    def get_global_model(self):
        """
        Wraps `self.current_model_weights` into a `ModelUpdate`.

        :return: model update carrying `self.current_model_weights`
        :rtype: `ModelUpdate`
        """
        current_weights = self.current_model_weights
        return ModelUpdate(weights=current_weights)
    
    def get_global_models(self):
        """
        Wraps the current weights of every tier into a `ModelUpdate`.

        :return: one model update per tier, ordered by tier index
        :rtype: `List of ModelUpdate`
        """
        return [
            ModelUpdate(weights=self.current_model_weights_per_tier[tier_idx])
            for tier_idx in range(self.current_tiers)
        ]

    def get_current_metrics(self):
        """Returns metrics pertaining to current state of fusion handler

        The result holds the configured number of rounds (`'rounds'`), the
        current round (`'curr_round'`) and `self.global_accuracy` (`'acc'`).

        :return: metrics
        :rtype: `dict`
        """
        return {
            'rounds': self.rounds,
            'curr_round': self.curr_round,
            'acc': self.global_accuracy,
        }

    @staticmethod
    def transform_update_to_np_array(update):
        """
        Transform a update of type list of numpy.ndarray to a numpy.ndarray 
        of numpy.ndarray.
        This method is a way to resolve the ValueError raised by numpy when 
        all the numpy.ndarray inside the provided list have the same 
        first dimension.

        A example of the possible case:
        a = [b, c], where a is of type list, b and c is of type numpy.ndarray.
        When b.shape[0] == c.shape[0] and b.shape[1] != c.shape[1], 
        the following line of code will cause numpy to raise a ValueError: 
        Could not broadcast input array from shape XXX(b.shape) into shape XX (c.shape).

        np.array(a)

        :param update: The input list of numpy.ndarray.
        :type update: `list`
        :return: the resulting update of type numpy.ndarray
        :rtype: `np.ndarray`
        """
        if update[0].shape[0]!= 2:
            update.append(np.zeros((2,)))
            update = np.array(update)
        else:
            update.append(np.zeros((3,)))
            update = np.array(update)
        return update[:-1]
