from __future__ import annotations

import copy
import datetime
import numpy as np
import sys
import math
import time
import torch
import torch.nn.functional as F
import models
from itertools import compress
from config import cfg
from collections import defaultdict
from jenkspy import JenksNaturalBreaks

from _typing import (
    DatasetType,
    OptimizerType,
    DataLoaderType,
    ModelType,
    MetricType,
    LoggerType,
    ClientType,
    ServerType
)

from models.api import create_simp_objective

from optimizer.api import create_optimizer
from .serverBase import ServerBase

from utils.api import (
    to_device,  
    collate
)

from data import (
    fetch_dataset, 
    split_dataset, 
    make_data_loader, 
    separate_dataset, 
    make_batchnorm_dataset, 
    make_batchnorm_stats
)


class ServerSimpFedIncen(ServerBase):

    def __init__(
        self, 
        model: ModelType,
        clients: dict[int, ClientType],
        dataset: DatasetType,
        # test_dataset: DatasetType,
    ) -> None:
        '''
        Build the server-side state of the simplified FedIncen algorithm.

        Models and optimizers are kept as CPU state dicts (server model, server
        optimizer, pricing objective and the objective's optimizer); fresh
        modules are rebuilt from them whenever they are needed.

        Args:
            model: global model whose weights initialize the server model.
            clients: mapping from client id to client object.
            dataset: dataset handed to ``ServerBase``.

        Incentive bookkeeping ("cg" = collaboration gain, "ig" = individual
        gain). Every per-round history starts with a placeholder entry for
        round 0:
            participation_client_ids: per round, ids of the clients willing to
                participate after seeing the pricing plan.
            active_client_ids: per round, ids of the participating clients that
                were selected to train on their local data.
            server_collaboration_gains: per round, the server gain (the negative
                of a loss value).
            server_aggregated_participation_clients_participation_costs /
            server_aggregated_active_clients_participation_costs: per round,
                the summed participation cost of those client groups.
            clients_participated_active_rounds /
            clients_participated_active_round_individual_gains: per client, the
                rounds it was active in and its individual gain in each of them.
        '''
        super().__init__(dataset=dataset)

        # server model and server optimizer, stored as state dicts
        self.server_model_state_dict = {name: tensor.cpu() for name, tensor in model.state_dict().items()}
        self.server_optimizer_state_dict = create_optimizer(model, 'server').state_dict()
        self.clients = clients

        # pricing objective parameterized by (theta_0, theta_1), plus its optimizer
        self.theta_0 = cfg['theta_0']
        self.theta_1 = cfg['theta_1']
        objective_model = create_simp_objective(self.theta_0, self.theta_1)
        self.fedincen_objective_model_state_dict = {
            name: tensor.cpu() for name, tensor in objective_model.state_dict().items()
        }
        self.fedincen_objective_model_optimizer_state_dict = create_optimizer(
            objective_model, 'fedincen_objective'
        ).state_dict()

        # per-round histories; index 0 is a placeholder for "before training"
        self.participation_client_ids = [[-1],]
        self.active_client_ids = [[-1],]
        self.server_collaboration_gains = [0]
        self.server_aggregated_participation_clients_participation_costs = [0]
        self.server_aggregated_active_clients_participation_costs = [0]

        self.active_client_losses = []
        self.malicious_client_ids_all = []

        num_clients = cfg['num_clients']
        # per-client histories: every client starts with round 0 and gain 0
        # (initiate_server_collaboration_gains later re-seeds the gains)
        self.clients_participated_active_rounds = [[0] for _ in range(num_clients)]
        self.clients_participated_active_round_individual_gains = [[0] for _ in range(num_clients)]

        self.malicious_clients_participation_costs = []
        self.benign_clients_participation_costs = []
        # value of (gradient of client model * (server model - client model)) from each
        # client's last active round, stored as a number so whole models need not be kept
        self.clients_last_active_round_marginal_gain_second_terms = [0] * num_clients
        self.overall_realized_profit = 0
        self.temp_round = 1
        self.theta_1_diff = []
        

    def initiate_server_collaboration_gains(
        self,
        dataset,
        logger,
        metric
    ):
        '''
        Seed the gain histories with the gain of the current server model.

        The value returned by the base-class evaluation is treated as a loss
        and negated into a gain. That gain becomes round 0 of
        ``server_collaboration_gains`` and the initial individual gain of every
        client.
        '''
        initial_loss = super().initiate_server_collaboration_gains(
            dataset=dataset,
            logger=logger,
            metric=metric,
            server_model_state_dict=self.server_model_state_dict,
        )
        initial_gain = self.convert_loss_to_gain(initial_loss)
        self.server_collaboration_gains[0] = initial_gain

        self.clients_participated_active_round_individual_gains = [
            [initial_gain] for _ in range(cfg['num_clients'])
        ]
        print(f'Initial server collaboration gain: {initial_gain}')


    def server_pricing_strategy(
            self, 
            lr, 
            malicious_client_ids, 
            global_epoch,
            client_participated_active_round_client_igs,
            client_participated_active_round_server_cgs,
            # client_last_active_round_client_igs,
            theta_1
        ):
        '''
        Take one optimizer step on the pricing objective and refresh (theta_0, theta_1).

        Steps:
            1. Rebuild the objective module from the stored state dict, with
               ``theta_1`` overwritten by the given value.
            2. Rebuild its optimizer from the stored state with learning rate ``lr``.
            3. Back-propagate the objective output, clip the gradient norm to 1,
               and step.
            4. Write the updated module/optimizer state and both theta values
               back to ``self``.
        '''
        objective_model = create_simp_objective(self.theta_0, self.theta_1)
        self.fedincen_objective_model_state_dict['theta_1'] = torch.tensor(theta_1)
        objective_model.load_state_dict(self.fedincen_objective_model_state_dict, strict=False)

        # NOTE(review): lr is used as-is (positive). A commented-out variant negated it
        # to maximize the objective — confirm which direction is intended.
        self.fedincen_objective_model_optimizer_state_dict['param_groups'][0]['lr'] = lr
        objective_optimizer = create_optimizer(objective_model, 'fedincen_objective')
        objective_optimizer.load_state_dict(self.fedincen_objective_model_optimizer_state_dict)

        objective_input = dict(
            active_client_ids=self.active_client_ids,
            client_participated_active_round_client_igs=client_participated_active_round_client_igs,
            client_participated_active_round_server_cgs=client_participated_active_round_server_cgs,
            clients_participated_active_rounds=self.clients_participated_active_rounds,
            server_collaboration_gains=self.server_collaboration_gains,
            clients_last_active_round_marginal_gain_second_terms=self.clients_last_active_round_marginal_gain_second_terms,
            malicious_client_ids=malicious_client_ids,
            active_client_losses=self.active_client_losses,
            malicious_client_ids_all=self.malicious_client_ids_all,
            global_epoch=global_epoch,
            theta_1_diff=self.theta_1_diff,
            theta_1=theta_1,
        )

        objective_optimizer.zero_grad()
        objective_model(objective_input).backward()
        torch.nn.utils.clip_grad_norm_(objective_model.parameters(), max_norm=1)
        objective_optimizer.step()

        self.fedincen_objective_model_optimizer_state_dict = objective_optimizer.state_dict()
        updated_state = {name: value.cpu() for name, value in objective_model.state_dict().items()}
        self.fedincen_objective_model_state_dict = updated_state
        self.theta_0 = updated_state['theta_0'].item()
        self.theta_1 = updated_state['theta_1'].item()

    def average(self, list_a):
        '''Return the arithmetic mean of ``list_a`` (ZeroDivisionError if it is empty).'''
        total = sum(list_a)
        count = len(list_a)
        return total / count
    
    def get_client_strategy_info_from_all_clients(self):
        '''
        Summarize every client's recent active rounds for the pricing strategy.

        Only the last ``cfg['pricing_interval']`` recorded active rounds of each
        client are used. A value of 0 keeps the whole history, because the
        slice ``[-0:]`` is the full list.

        Returns:
            Three lists, each indexed by client id:
                1. average individual gain of the client over those rounds,
                2. average server collaboration gain over those same rounds,
                3. individual gain of the client's most recent recorded round.
        '''
        window = cfg['pricing_interval']
        avg_client_igs = []
        avg_server_cgs = []
        last_client_igs = []
        for client_id in range(cfg['num_clients']):
            recent_igs = self.clients_participated_active_round_individual_gains[client_id][-window:]
            recent_rounds = self.clients_participated_active_rounds[client_id][-window:]
            recent_server_cgs = [self.server_collaboration_gains[rnd] for rnd in recent_rounds]

            last_client_igs.append(recent_igs[-1])
            avg_client_igs.append(self.average(recent_igs))
            avg_server_cgs.append(self.average(recent_server_cgs))

        return avg_client_igs, avg_server_cgs, last_client_igs
    
    def train(
        self,
        dataset: DatasetType,  
        optimizer: OptimizerType, 
        fedincen_objective_model_optimizer,
        metric: MetricType, 
        logger: LoggerType, 
        global_epoch: int,
        malicious_client_ids
    ):
        '''
        Run one communication round of the simplified FedIncen server.

        Steps:
            1. For every client, compute the gap ``z = avg server cg - avg client ig``
               over its recent active rounds. Log the gaps split into benign and
               malicious clients, with mean/std/var over the non-zero gaps.
            2. When malicious clients exist: after round 5, set ``theta_1`` from a
               2-class Jenks natural-breaks split of all gaps. After round 0, take
               one gradient step on the pricing objective.
            3. Decide participation with the pricing objective and let the base
               class select the active clients. Log how well benign and malicious
               clients were separated.
            4. Broadcast the server model, train every selected client locally and
               aggregate the client models into the server model.

        Args:
            dataset: full training dataset. Each client's shard is extracted with
                ``separate_dataset`` using ``client.data_split['train']``.
            optimizer: client optimizer; only its current learning rate is read.
            fedincen_objective_model_optimizer: optimizer of the pricing objective;
                only its current learning rate is read.
            metric: metric object forwarded to client training and logging.
            logger: logger receiving the round statistics.
            global_epoch: index of the current communication round.
            malicious_client_ids: ids of malicious clients. NOTE: this list is
                sorted in place, so the caller observes the sorted order.
        '''
        print(f'simpleFedIncen', flush=True)

        def _nonzero_summary(values):
            # (mean, std, var) over the non-zero entries of `values`; (0, 0, 0) if none
            arr = np.array(values)
            arr = arr[arr != 0]
            if len(arr) == 0:
                return 0, 0, 0
            return np.mean(arr), np.std(arr), np.var(arr)

        client_participated_active_round_client_igs, client_participated_active_round_server_cgs, _ = \
            self.get_client_strategy_info_from_all_clients()

        num_clients = cfg['num_clients']
        malicious_client_ids.sort()
        malicious_set = set(malicious_client_ids)

        # per-client gap between averaged server cg and averaged client ig
        z_diff = []
        malicious_client_z_diff = []
        all_z_diff = []
        for i in range(num_clients):
            gap = client_participated_active_round_server_cgs[i] - client_participated_active_round_client_igs[i]
            if i in malicious_set:
                malicious_client_z_diff.append(gap)
            else:
                z_diff.append(gap)
            all_z_diff.append(gap)

        z_diff_mean, z_diff_std, z_diff_var = _nonzero_summary(z_diff)
        malicious_client_z_diff_mean, malicious_client_z_diff_std, malicious_client_z_diff_var = \
            _nonzero_summary(malicious_client_z_diff)

        logger.append({'z_diff': z_diff}, 'train', n=len(z_diff))
        logger.append({'malicious_client_z_diff': malicious_client_z_diff}, 'train', n=len(malicious_client_z_diff))
        logger.append({'all_z_diff': all_z_diff}, 'train', n=len(all_z_diff))
        logger.append(
            {
                'z_diff_mean': z_diff_mean,
                'z_diff_std': z_diff_std,
                'z_diff_var': z_diff_var,
                'malicious_client_z_diff_mean': malicious_client_z_diff_mean,
                'malicious_client_z_diff_std': malicious_client_z_diff_std,
                'malicious_client_z_diff_var': malicious_client_z_diff_var,
            },
            'train',
        )

        has_malicious_clients = len(malicious_client_ids) > 0
        if global_epoch > 5 and has_malicious_clients:
            # 2-class Jenks natural breaks over all gaps; breaks_[1] becomes theta_1
            jnb = JenksNaturalBreaks(2)
            jnb.fit(all_z_diff)
            self.theta_1 = jnb.breaks_[1]

        print(f'jnb_split: {self.theta_1}', flush=True)
        if global_epoch > 0 and has_malicious_clients:
            # one gradient step on the pricing objective; refreshes self.theta_0 / self.theta_1
            self.server_pricing_strategy(
                lr=fedincen_objective_model_optimizer.param_groups[0]['lr'],
                malicious_client_ids=malicious_client_ids,
                global_epoch=global_epoch,
                client_participated_active_round_client_igs=client_participated_active_round_client_igs,
                client_participated_active_round_server_cgs=client_participated_active_round_server_cgs,
                theta_1=self.theta_1
            )

        objective_model = create_simp_objective(self.theta_0, self.theta_1)
        clients_indicator_list = objective_model.calculate_client_participation_decision(
            self.server_collaboration_gains,
            client_participated_active_round_client_igs,
            client_participated_active_round_server_cgs,
            malicious_client_ids
        )

        selected_client_ids, num_active_clients, cur_round_participation_client_ids = super().select_clients(
            clients=self.clients,
            clients_indicator_list=clients_indicator_list,
            global_epoch=global_epoch
        )

        # Separation quality:
        #   benign_identify_ratio    = share of participants that are benign
        #   malicious_identify_ratio = share of malicious clients that did not participate
        # Exact ratios; 0.0 when the denominator is zero.
        all_client_ids = set(range(num_clients))
        participation_set = set(cur_round_participation_client_ids)
        benign_hits = len((all_client_ids - malicious_set) & participation_set)
        malicious_hits = len(malicious_set & (all_client_ids - participation_set))
        num_participants = len(cur_round_participation_client_ids)
        benign_identify_ratio = benign_hits / num_participants if num_participants else 0.0
        malicious_identify_ratio = malicious_hits / len(malicious_client_ids) if has_malicious_clients else 0.0

        logger.append(
            {
                'benign_identify_ratio': benign_identify_ratio,
                'malicious_identify_ratio': malicious_identify_ratio,
            },
            'train',
        )

        self.participation_client_ids.append(copy.deepcopy(cur_round_participation_client_ids))
        self.active_client_ids.append(copy.deepcopy(selected_client_ids))

        logger.append(
            {'participation_client_ids': cur_round_participation_client_ids},
            'train',
            n=len(cur_round_participation_client_ids)
        )
        logger.append(
            {'active_client_ids': selected_client_ids},
            'train',
            n=len(selected_client_ids)
        )

        super().distribute_server_model_to_clients(
            server_model_state_dict=self.server_model_state_dict,
            clients=self.clients
        )
        start_time = time.time()
        lr = optimizer.param_groups[0]['lr']

        # Build the shards of the clients that will train, and only those, before
        # any client trains (in case local training mutates data shared with `dataset`).
        training_client_ids = [selected_client_ids[i] for i in range(num_active_clients)]
        client_datasets = {
            m: separate_dataset(dataset, self.clients[m].data_split['train'])
            for m in training_client_ids
        }

        for processed_client_count, m in enumerate(training_client_ids, start=1):
            # each client trains on its own copy of the shard
            dataset_m = copy.deepcopy(client_datasets[m])
            client = self.clients[m]
            if dataset_m is None:
                client.active = False
            else:
                client.active = True
                client.train(
                    dataset=dataset_m,
                    lr=lr,
                    metric=metric,
                    logger=logger,
                    client_id=m,
                    malicious_client_ids=malicious_client_ids
                )
            super().add_log(
                i=processed_client_count,
                num_active_clients=len(selected_client_ids),
                start_time=start_time,
                global_epoch=global_epoch,
                lr=lr,
                selected_client_ids=selected_client_ids,
                metric=metric,
                logger=logger,
            )
        super().update_server_model(clients=self.clients)
        return
    
    def convert_loss_to_gain(self, x):
        '''
        Convert a loss, or a collection of losses, into a gain by negation.

        Args:
            x: a scalar loss (Python number, NumPy scalar, 0-d ndarray or torch
               tensor), or a ``list`` / ``tuple`` / ``np.ndarray`` of losses.

        Returns:
            ``-x`` for scalars and tensors (tensors negate element-wise), or a
            new ``list`` with every element negated for list, tuple and
            (at least 1-d) ndarray inputs.
        '''
        # NumPy scalars and 0-d arrays are not iterable, so negate them directly
        # instead of routing them into the element-wise branch.
        if isinstance(x, np.generic) or (isinstance(x, np.ndarray) and x.ndim == 0):
            return -x
        if isinstance(x, (list, tuple, np.ndarray)):
            return [-i for i in x]
        return -x
    

    def check_torch(self, x):
        '''Report whether ``x`` is a tensor; unwrap tensors with ``.item()``, return anything else unchanged.'''
        if not torch.is_tensor(x):
            print ("The input object is not a Tensor.")
            return x
        print ("The input object is a Tensor.")
        return x.item()



    def evaluate_trained_model(
        self,
        dataset,
        batchnorm_dataset,
        logger,
        metric,
        global_epoch,
        malicious_client_ids
    ):
        '''
        Evaluate the round's models and record the incentive statistics.

        The base class returns the best and current server losses, the active
        clients' losses and their marginal-gain second terms. Losses are negated
        into gains (server collaboration gain "cg", client individual gain "ig").
        The method then:
            1. logs this round's gap ``server cg - client ig`` for benign and
               malicious active clients, with mean/std/var over the non-zero gaps;
            2. appends the best server gain to ``server_collaboration_gains`` and
               records, for each active client, the round, its ig and its
               marginal-gain second term;
            3. computes participation costs with the pricing objective and
               accumulates ``overall_realized_profit`` according to
               ``cfg['lambda']``: 0 -> gain increase only; 999999 -> costs only,
               scaled by ``cfg['lambda_for_infinite_situation']``; otherwise
               ``lambda * costs + gain increase``;
            4. logs the per-round cost and gain statistics.

        Args:
            dataset: evaluation dataset forwarded to the base class.
            batchnorm_dataset: forwarded to the base class (presumably used for
                batch-norm statistics — confirm in ServerBase).
            logger: logger receiving all statistics.
            metric: metric object forwarded to the base class.
            global_epoch: index of the current round; recorded as an active round
                of every active client.
            malicious_client_ids: ids of malicious clients.
        '''
        def _nonzero_summary(values):
            # (mean, std, var) over the non-zero entries of `values`; (0, 0, 0) if none
            arr = np.array(values)
            arr = arr[arr != 0]
            if len(arr) == 0:
                return 0, 0, 0
            return np.mean(arr), np.std(arr), np.var(arr)

        _better_model, best_server_loss, cur_round_server_loss, cur_round_active_clients_losses, cur_round_marginal_gain_second_terms = super().evaluate_trained_model(
            dataset=dataset,
            batchnorm_dataset=batchnorm_dataset,
            logger=logger,
            metric=metric,
            global_epoch=global_epoch,
            server_model_state_dict=self.server_model_state_dict,
            clients=self.clients,
            cur_round_active_client_ids=self.active_client_ids[-1]
        )
        best_server_gain = self.convert_loss_to_gain(best_server_loss)
        cur_round_server_cg = self.convert_loss_to_gain(copy.deepcopy(cur_round_server_loss))
        cur_round_active_clients_igs = self.convert_loss_to_gain(copy.deepcopy(cur_round_active_clients_losses))

        self.active_client_losses.append(copy.deepcopy(cur_round_active_clients_losses))

        cur_round_active_client_ids = self.active_client_ids[-1]
        cur_round_participation_clients_id = self.participation_client_ids[-1]

        # this round's gap between server cg and each active client's ig, split by client type
        is_malicious_client = []
        malicious_z_diff = []
        non_malicious_z_diff = []
        for i, active_client_id in enumerate(cur_round_active_client_ids):
            gap = cur_round_server_cg - cur_round_active_clients_igs[i]
            if active_client_id in malicious_client_ids:
                is_malicious_client.append(1)
                malicious_z_diff.append(gap)
            else:
                is_malicious_client.append(0)
                non_malicious_z_diff.append(gap)

        z_diff_mean, z_diff_std, z_diff_var = _nonzero_summary(non_malicious_z_diff)
        malicious_client_z_diff_mean, malicious_client_z_diff_std, malicious_client_z_diff_var = \
            _nonzero_summary(malicious_z_diff)

        logger.append({'cur_round_z_diff': non_malicious_z_diff}, 'train', n=len(non_malicious_z_diff))
        logger.append({'cur_round_malicious_client_z_diff': malicious_z_diff}, 'train', n=len(malicious_z_diff))
        logger.append(
            {
                'cur_round_z_diff_mean': z_diff_mean,
                'cur_round_z_diff_std': z_diff_std,
                'cur_round_z_diff_var': z_diff_var,
                'cur_round_malicious_client_z_diff_mean': malicious_client_z_diff_mean,
                'cur_round_malicious_client_z_diff_std': malicious_client_z_diff_std,
                'cur_round_malicious_client_z_diff_var': malicious_client_z_diff_var,
            },
            'train',
        )

        self.malicious_client_ids_all.append(copy.deepcopy(is_malicious_client))
        self.server_collaboration_gains.append(best_server_gain)
        for i, active_client_id in enumerate(cur_round_active_client_ids):
            self.clients_participated_active_rounds[active_client_id].append(global_epoch)
            self.clients_participated_active_round_individual_gains[active_client_id].append(cur_round_active_clients_igs[i])
            self.clients_last_active_round_marginal_gain_second_terms[active_client_id] = cur_round_marginal_gain_second_terms[i]

        objective_model = create_simp_objective(self.theta_0, self.theta_1)
        cur_round_participation_clients_participation_costs, cur_round_active_clients_participation_costs, cur_round_malicious_clients_participation_costs, cur_round_benign_clients_participation_costs = objective_model.calculate_client_participation_cost(
            cur_round_server_cg,
            cur_round_active_clients_igs,
            malicious_client_ids,
            cur_round_participation_clients_id=cur_round_participation_clients_id,
            cur_round_active_client_ids=cur_round_active_client_ids
        )

        participation_costs_sum = sum(cur_round_participation_clients_participation_costs)
        active_costs_sum = sum(cur_round_active_clients_participation_costs)
        malicious_costs_sum = sum(cur_round_malicious_clients_participation_costs)
        benign_costs_sum = sum(cur_round_benign_clients_participation_costs)

        self.malicious_clients_participation_costs.append(malicious_costs_sum)
        self.benign_clients_participation_costs.append(benign_costs_sum)
        self.server_aggregated_participation_clients_participation_costs.append(participation_costs_sum)
        self.server_aggregated_active_clients_participation_costs.append(active_costs_sum)

        # [-2] is the previous round's entry, since this round's best gain was just appended
        last_round_server_cg = self.server_collaboration_gains[-2]
        if cfg['lambda'] == 0:
            self.overall_realized_profit += (cur_round_server_cg - last_round_server_cg)
        elif cfg['lambda'] == 999999:
            # 999999 is the sentinel for an "infinite" lambda
            self.overall_realized_profit += cfg['lambda_for_infinite_situation'] * participation_costs_sum
        else:
            self.overall_realized_profit += cfg['lambda'] * participation_costs_sum + (cur_round_server_cg - last_round_server_cg)

        for i, client_id in enumerate(cur_round_participation_clients_id):
            logger.append(
                {'participation_cost': cur_round_participation_clients_participation_costs[i]},
                f'train_{client_id}_participation_cost',
            )

        logger.append(
            {
                'overall_realized_profit': self.overall_realized_profit,
                'server_collaboration_gains': cur_round_server_cg,
                'server_aggregated_participation_clients_participation_costs': participation_costs_sum,
                'server_aggregated_active_clients_participation_costs': active_costs_sum,
                'malicious_clients_participation_costs_sum': malicious_costs_sum,
                'malicious_clients_participation_costs_num': len(cur_round_malicious_clients_participation_costs),
                'benign_clients_participation_costs_sum': benign_costs_sum,
                'benign_clients_participation_costs_num': len(cur_round_benign_clients_participation_costs),
            },
            'train',
        )

        logger.append(
            {'participation_clients_participation_costs': cur_round_participation_clients_participation_costs},
            'train',
            n=len(cur_round_participation_clients_participation_costs)
        )

        logger.append(
            {'malicious_client_ids': malicious_client_ids},
            'train',
            n=len(malicious_client_ids)
        )

        # log a [0] placeholder when no malicious client incurred a cost this round
        if len(cur_round_malicious_clients_participation_costs) == 0:
            cur_round_malicious_clients_participation_costs = [0]
        logger.append(
            {'malicious_clients_participation_costs': cur_round_malicious_clients_participation_costs},
            'train',
            n=len(cur_round_malicious_clients_participation_costs)
        )

        logger.append(
            {'benign_clients_participation_costs': cur_round_benign_clients_participation_costs},
            'train',
            n=len(cur_round_benign_clients_participation_costs)
        )
        logger.append(
            {'active_clients_participation_costs': cur_round_active_clients_participation_costs},
            'train',
            n=len(cur_round_active_clients_participation_costs)
        )
        logger.append(
            {'active_clients_individual_gains': cur_round_active_clients_igs},
            'train',
            n=len(cur_round_active_clients_igs)
        )
        return
    










    # def train(
    #     self,
    #     dataset: DatasetType,  
    #     optimizer: OptimizerType, 
    #     fedincen_objective_model_optimizer,
    #     metric: MetricType, 
    #     logger: LoggerType, 
    #     global_epoch: int,
    #     malicious_client_ids
    # ):  
    #     print(f'simpleFedIncen', flush=True)
        
    #     client_participated_active_round_client_igs, client_participated_active_round_server_cgs, client_last_active_round_client_igs = \
    #         self.get_client_strategy_info_from_all_clients()
        
    #     z_diff = []
    #     malicious_client_z_diff = []
    #     all_z_diff = []
    #     malicious_client_ids.sort()
    #     for i in range(cfg['num_clients']):
    #         if i not in malicious_client_ids:
    #             z_diff.append(client_participated_active_round_server_cgs[i] - client_participated_active_round_client_igs[i])
    #         else:
    #             malicious_client_z_diff.append(client_participated_active_round_server_cgs[i] - client_participated_active_round_client_igs[i])

    #         all_z_diff.append(client_participated_active_round_server_cgs[i] - client_participated_active_round_client_igs[i])
        
    #     # z_diff = [client_participated_active_round_server_cgs[i] - client_participated_active_round_client_igs[i] for i in range(cfg['num_clients'])]
        
    #     temp = np.array(z_diff)
    #     temp = temp[temp != 0]
    #     if len(temp) == 0:
    #         z_diff_mean = 0
    #         z_diff_std = 0
    #         z_diff_var = 0
    #     else:
    #         z_diff_mean = np.mean(temp)
    #         z_diff_std = np.std(temp)
    #         z_diff_var = np.var(temp)
    #     print(f"{cfg['model_tag']} z_diff: {z_diff}", flush=True)
    #     print(f"{cfg['model_tag']} z_diff mean: {z_diff_mean}", flush=True)
    #     print(f"{cfg['model_tag']} z_diff std: {z_diff_std}", flush=True)
    #     print(f"{cfg['model_tag']} z_diff var: {z_diff_var}", flush=True)
        
    #     temp = np.array(malicious_client_z_diff)
    #     temp = temp[temp != 0]
    #     if len(temp) == 0:
    #         malicious_client_z_diff_mean = 0
    #         malicious_client_z_diff_std = 0
    #         malicious_client_z_diff_var = 0
    #     else:
    #         malicious_client_z_diff_mean = np.mean(temp)
    #         malicious_client_z_diff_std = np.std(temp)
    #         malicious_client_z_diff_var = np.var(temp)
    #     print(f"{cfg['model_tag']} malicious_client_z_diff: {malicious_client_z_diff}", flush=True)
    #     print(f"{cfg['model_tag']} malicious_client_z_diff mean: {malicious_client_z_diff_mean}", flush=True)
    #     print(f"{cfg['model_tag']} malicious_client_z_diff std: {malicious_client_z_diff_std}", flush=True)
    #     print(f"{cfg['model_tag']} malicious_client_z_diff var: {malicious_client_z_diff_var}", flush=True)


    #     logger.append(
    #         {   
    #             f'z_diff': z_diff,
    #         }, 
    #         'train', 
    #         n=len(z_diff)
    #     )

    #     logger.append(
    #         {   
    #             f'malicious_client_z_diff': malicious_client_z_diff,
    #         }, 
    #         'train', 
    #         n=len(malicious_client_z_diff)
    #     )

    #     logger.append(
    #         {   
    #             f'all_z_diff': all_z_diff,
    #         }, 
    #         'train', 
    #         n=len(all_z_diff)
    #     )

    #     logger.append(
    #         {   
    #             f'z_diff_mean': z_diff_mean,
    #             f'z_diff_std': z_diff_std,
    #             f'z_diff_var': z_diff_var,
    #             f'malicious_client_z_diff_mean': malicious_client_z_diff_mean,
    #             f'malicious_client_z_diff_std': malicious_client_z_diff_std,
    #             f'malicious_client_z_diff_var': malicious_client_z_diff_var,
    #         }, 
    #         'train', 
    #     )

    #     # update theta
    #     # self.server_pricing_strategy(
    #     #     lr=fedincen_objective_model_optimizer.param_groups[0]['lr'],
    #     #     malicious_client_ids=malicious_client_ids,
    #     #     global_epoch=global_epoch,

    #     # )

    #     last_round_server_cg = self.server_collaboration_gains[-1]
    #     cur_round_participation_client_ids = super().distribute_client_strategy_info_to_all_clients(
    #         global_epoch=global_epoch,
    #         client_last_active_round_client_igs=client_last_active_round_client_igs,
    #         client_last_acive_round_server_cgs=None,
    #         cur_client_last_last_active_round_server_cgs=None,
    #         last_round_server_cg=last_round_server_cg,
    #         theta_0=0,
    #         # theta_1=self.fedincen_objective_model_state_dict['theta'][1].item()
    #         theta_1=self.theta_1,
    #         malicious_client_ids=malicious_client_ids,
    #         client_participated_active_round_client_igs=client_participated_active_round_client_igs,
    #         client_participated_active_round_server_cgs=client_participated_active_round_server_cgs
    #     )

    #     # cur_round_participation_client_ids = []
    #     if global_epoch > 25 and len(malicious_client_ids) > 0:
    #         cur_round_participation_client_ids = []
    #         jnb = JenksNaturalBreaks(2)
    #         # cur_combine_list = z_diff + malicious_client_z_diff
    #         jnb.fit(all_z_diff)
    #         for i in range(cfg['num_clients']):
    #             # benign_client
    #             if all_z_diff[i] <= jnb.breaks_[1]:
    #                 cur_round_participation_client_ids.append(i)


    #     benign_clients = list(set(range(cfg['num_clients'])) - set(malicious_client_ids))
    #     intersect_num = len(np.intersect1d( np.array(benign_clients), np.array(cur_round_participation_client_ids)))
    #     # print('benign', intersect_num / benign_client_num, "\n")
    #     benign_identify_ratio = intersect_num / len(cur_round_participation_client_ids)

    #     intersect_num = len(np.intersect1d(np.array(malicious_client_ids), np.array(list(set(range(cfg['num_clients'])) - set(cur_round_participation_client_ids)))))
    #     # print('benign', intersect_num / benign_client_num, "\n")
    #     malicious_identify_ratio = intersect_num / len(malicious_client_ids)

    #     logger.append(
    #         {   
    #             f'benign_identify_ratio': benign_identify_ratio,
    #             f'malicious_identify_ratio': malicious_identify_ratio,
    #         }, 
    #         'train', 
    #         # n=len(all_z_diff)
    #     )
    #     # print(f"{cfg['model_tag']} server_collaboration_gains: {self.server_collaboration_gains}")
    #     self.participation_client_ids.append(copy.deepcopy(cur_round_participation_client_ids))
    #     print(f"{cfg['model_tag']} cur_round_participation_client_ids_length: {len(cur_round_participation_client_ids)}", flush=True)
    #     selected_client_ids, num_active_clients = super().select_clients(
    #         clients=self.clients, 
    #         cur_round_participation_client_ids=cur_round_participation_client_ids,
    #     )
    #     # print(f"{cfg['model_tag']} selected_client_ids_length: {len(selected_client_ids)}", flush=True)
    #     # print(f"{cfg['model_tag']} selected_client_ids: {selected_client_ids}", flush=True)
    #     self.active_client_ids.append(copy.deepcopy(selected_client_ids))
        
    #     logger.append(
    #         {
    #             f'participation_client_ids': cur_round_participation_client_ids,
    #         }, 
    #         'train', 
    #         n=len(cur_round_participation_client_ids)
    #     )

    #     logger.append(
    #         {
    #             f'active_client_ids': selected_client_ids,
    #         }, 
    #         'train', 
    #         n=len(selected_client_ids)
    #     )

    #     super().distribute_server_model_to_clients(
    #         server_model_state_dict=self.server_model_state_dict,
    #         clients=self.clients
    #     )
    #     start_time = time.time()
    #     lr = optimizer.param_groups[0]['lr']

    #     dataset_list = []
    #     for i in range(cfg['num_clients']):
    #         dataset_list.append(separate_dataset(dataset, self.clients[i].data_split['train']))

    #     processed_client_count = 0
    #     for i in range(num_active_clients):
    #         m = selected_client_ids[i]
    #         # dataset_m = separate_dataset(dataset, self.clients[m].data_split['train'])
    #         dataset_m = copy.deepcopy(dataset_list[m])
    #         if dataset_m is None:
    #             self.clients[m].active = False
    #         else:
    #             self.clients[m].active = True
    #             self.clients[m].train(
    #                 dataset=dataset_m, 
    #                 lr=lr, 
    #                 metric=metric, 
    #                 logger=logger,
    #                 client_id=m,
    #                 malicious_client_ids=malicious_client_ids
    #             )
    #         processed_client_count += 1
    #         super().add_log(
    #             i=processed_client_count,
    #             num_active_clients=len(selected_client_ids),
    #             start_time=start_time,
    #             global_epoch=global_epoch,
    #             lr=lr,
    #             selected_client_ids=selected_client_ids,
    #             metric=metric,
    #             logger=logger,
    #         )
        
    #     # super().distribute_server_model_to_rest_clients(
    #     #         server_model_state_dict=self.server_model_state_dict,
    #     #         clients=self.clients,
    #     #         rest_client_ids=rest_client_ids
    #     #     )
    #     # for i in range(len(rest_client_ids)):
    #     #     client_id = rest_client_ids[i]  
    #     #     # dataset_client_id = separate_dataset(dataset, self.clients[client_id].data_split['train'])
    #     #     dataset_client_id = copy.deepcopy(dataset_list[client_id])
    #     #     # if dataset_m is None:
    #     #     #     self.clients[m].active = False
    #     #     # else:
    #     #     #     self.clients[m].active = True
    #     #     self.clients[client_id].train(
    #     #         dataset=dataset_client_id, 
    #     #         lr=lr, 
    #     #         metric=metric, 
    #     #         logger=logger,
    #     #     )
    #     #     processed_client_count += 1
    #     #     super().add_log(
    #     #         i=processed_client_count,
    #     #         num_active_clients=cfg['num_clients'],
    #     #         start_time=start_time,
    #     #         global_epoch=global_epoch,
    #     #         lr=lr,
    #     #         selected_client_ids=selected_client_ids,
    #     #         metric=metric,
    #     #         logger=logger,
    #     #     )

    #     super().update_server_model(clients=self.clients) 
    #     return