import random
from logging import WARNING
from typing import List, Tuple, Union, Dict

import flwr as fl
import torch
from flwr.common import Metrics
from flwr.common import FitIns, FitRes
from flwr.common import EvaluateIns
from flwr.common import ndarrays_to_parameters
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import weighted_loss_avg, aggregate

from cocofl_client import cocofl_client
from dataset import EmnistDataset
from models import CNN
from util import get_filters, get_parameters, parameters_to_ndarrays

# First argument to CNN(...) in cocoFL_strategy; presumably the number of
# input image channels (1 would mean grayscale) -- confirm against models.CNN.
CHANNEL = 1
# Passed as the batch argument to every cocofl_client in cocofl_client_fn;
# presumably the local training batch size -- confirm in cocofl_client.
Batch = 16
# Passed to CNN(...) as ``outputs=``; presumably the number of FEMNIST
# classes -- confirm against the dataset labels.
CLASSES = 62

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Return the example-weighted mean of the clients' ``"accuracy"`` metric.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs, one per client. Each
            ``metrics_dict`` must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": sum(n_i * acc_i) / sum(n_i)}``.

    Raises:
        ZeroDivisionError: if ``metrics`` is empty or all weights are zero.
    """
    weighted_sum = 0
    total_examples = 0
    for count, client_metrics in metrics:
        # Scale each client's accuracy by how many examples produced it.
        weighted_sum += count * client_metrics["accuracy"]
        total_examples += count
    return {"accuracy": weighted_sum / total_examples}

class cocoFL_strategy(fl.server.strategy.FedAvg):
    """FedAvg variant for CoCoFL-style training with partially updated layers.

    Differences from plain ``FedAvg`` visible in this class:

    * The server keeps its own ``CNN`` (``self.global_model``) and sends its
      weights to clients. It ignores the ``parameters`` argument that Flower
      passes to ``configure_fit`` / ``configure_evaluate``.
    * Clients report the tensors they trained in
      ``fit_res.metrics["updated layer"]``. Each reported key is averaged,
      weighted by ``num_examples``, over the clients that reported it. Keys
      nobody reported keep their previous global value.
    * Each round's aggregated evaluation accuracy is appended to
      ``self.accuracy_record`` and printed.
    """

    def __init__(self, ff, fe, mfc, mec, mac, ACC=None, ClientsSelection=None):
        """Create the strategy.

        Args:
            ff: ``fraction_fit`` forwarded to ``FedAvg``.
            fe: ``fraction_evaluate`` forwarded to ``FedAvg``; ``0.0`` makes
                ``configure_evaluate`` schedule no evaluation.
            mfc: ``min_fit_clients`` forwarded to ``FedAvg``.
            mec: ``min_evaluate_clients`` forwarded to ``FedAvg``.
            mac: ``min_available_clients`` forwarded to ``FedAvg``.
            ACC: Optional list that per-round test accuracies are appended to.
                Pass your own list to read the accuracies back after training.
                A new list is created when omitted.
            ClientsSelection: Currently unused; kept for call compatibility.
        """
        super().__init__(
            fraction_fit=ff,
            fraction_evaluate=fe,
            min_fit_clients=mfc,
            min_evaluate_clients=mec,
            min_available_clients=mac,
            evaluate_metrics_aggregation_fn=weighted_average,
        )
        # Plain values (no trailing commas), so numeric comparisons such as
        # ``fraction_evaluate_ == 0.0`` work as intended.
        self.fraction_fit_ = ff
        self.fraction_evaluate_ = fe
        self.min_fit_clients_ = mfc
        self.min_evaluate_clients_ = mec
        self.min_available_clients_ = mac
        self.global_model = CNN(CHANNEL, outputs=CLASSES)
        # Fresh list per instance unless the caller supplies one; a mutable
        # default argument would be shared by every instance created without ACC.
        self.accuracy_record = [] if ACC is None else ACC

    def record_test_accuracy(self, acc):
        """Append one round's aggregated test accuracy to ``accuracy_record``."""
        self.accuracy_record.append(acc)

    def initialize_parameters(self, client_manager: ClientManager):
        """Override: start training from the server-side ``global_model`` weights."""
        return ndarrays_to_parameters(get_parameters(self.global_model))

    def configure_fit(self, server_round: int, parameters, client_manager: ClientManager):
        """Override: send the current ``global_model`` weights to sampled clients.

        The ``parameters`` argument supplied by Flower is ignored. The weights
        sent come from ``get_filters(self.global_model)``.

        Returns:
            A list of ``(client, FitIns)`` pairs with an empty config.
        """
        # Re-seed the global ``random`` module each round, presumably so client
        # sampling is reproducible (assumes the client manager samples via
        # ``random`` -- TODO confirm for the Flower version in use).
        random.seed(server_round)
        sample_size, min_num_clients = super().num_fit_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)
        # Every client receives the same global weights, so serialize them once.
        global_parameters = ndarrays_to_parameters(get_filters(self.global_model))
        return [(client, FitIns(global_parameters, {})) for client in clients]

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ):
        """Override: merge the layers each client updated into ``global_model``.

        Returns:
            ``(Parameters, metrics)`` holding the new global weights. Returns
            ``(None, {})`` when there are no results, or when failures
            occurred and ``accept_failures`` is False.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}
        # Each client reports the tensors it trained under the "updated layer"
        # metric key (a dict of state-dict key -> value). Pair it with the
        # client's sample count for weighting.
        fit_results = [
            (fit_res.metrics["updated layer"], fit_res.num_examples)
            for _, fit_res in results
        ]
        # Aggregate once, after all client updates have been collected.
        new_model_dict = self.aggregate_updated_layer(fit_results)
        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")
        # strict=False: keys absent from new_model_dict are simply not loaded.
        self.global_model.load_state_dict(new_model_dict, strict=False)
        # Flower's aggregate_fit contract returns a Parameters object (as
        # initialize_parameters does), not a raw list of arrays.
        return ndarrays_to_parameters(get_parameters(self.global_model)), metrics_aggregated

    def configure_evaluate(self, server_round: int, parameters, client_manager: ClientManager):
        """Override: send the current ``global_model`` weights for evaluation.

        Returns an empty list (no evaluation this round) when
        ``fraction_evaluate`` is 0. Otherwise returns ``(client, EvaluateIns)``
        pairs with an empty config. The ``parameters`` argument is ignored.
        """
        if self.fraction_evaluate_ == 0.0:
            return []
        # Sample clients
        sample_size, min_num_clients = super().num_evaluation_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)
        # Every client evaluates the same global weights; serialize them once.
        global_parameters = ndarrays_to_parameters(get_filters(self.global_model))
        return [(client, EvaluateIns(global_parameters, {})) for client in clients]

    def aggregate_evaluate(self, server_round: int, results, failures):
        """Override: aggregate evaluation loss and accuracy for one round.

        The loss is averaged with each client weighted by ``num_examples``.
        Accuracy goes through ``evaluate_metrics_aggregation_fn`` with weight 1
        per client. The result is appended to ``accuracy_record`` and printed.

        Returns:
            ``(loss, metrics)``, or ``(None, {})`` when there are no results or
            when failures occurred and ``accept_failures`` is False.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}
        loss_aggregated = weighted_loss_avg(
            [(evaluate_res.num_examples, evaluate_res.loss) for _, evaluate_res in results]
        )
        metrics_aggregated = {}
        if self.evaluate_metrics_aggregation_fn:
            # NOTE(review): every client gets weight 1, so the "weighted"
            # aggregation fn yields a plain per-client mean. Confirm this is
            # intended rather than weighting by num_examples.
            eval_metrics = [(1, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
            self.record_test_accuracy(metrics_aggregated['accuracy'])
            print(f"COCOFL: Round {server_round}, test accuracy = {metrics_aggregated['accuracy']}")
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")
        return loss_aggregated, metrics_aggregated

    def aggregate_updated_layer(self, results: List[Tuple[Dict, int]]) -> Dict:
        """Build a new global state dict from per-client partial updates.

        Args:
            results: ``(local_dict, num_examples)`` pairs. ``local_dict`` maps
                state-dict keys to the values one client trained.

        Returns:
            A dict holding every key of ``global_model.state_dict()``. Keys
            that at least one client reported are replaced by Flower's
            ``aggregate`` (``num_examples``-weighted average) over the clients
            that reported them. All other keys keep their current value.
        """
        # Start from the current weights so frozen (unreported) layers persist.
        global_dict = dict(self.global_model.state_dict())
        # Group reported values by key: key -> [(value, num_examples), ...]
        received_layer_dict = {}
        for local_dict, num in results:
            for key, value in local_dict.items():
                received_layer_dict.setdefault(key, []).append((value, num))
        # NOTE(review): each value is handed to ``aggregate`` as if it were a
        # list of arrays, so it is iterated along its first axis and rebuilt
        # by torch.tensor. This assumes values are array-like with ndim >= 1
        # -- TODO confirm the client-side type.
        for key, updates in received_layer_dict.items():
            global_dict[key] = torch.tensor(aggregate(updates))
        return global_dict

def get_lf_number(cid):
    """Return the ``lf`` value for client ``cid``: 0 for even ids, 1 for odd ids.

    ``cid`` may be an int or a numeric string; it is converted with ``int``.
    """
    return int(cid) % 2

def get_drop_rate(cid):
    """Return the rate for client ``cid``: 0.5 for ids below 50, otherwise 1.0.

    ``cid`` may be an int or a numeric string; it is converted with ``int``.
    """
    return 0.5 if int(cid) < 50 else 1.0

def cocofl_client_fn(cid) -> cocofl_client:
    """Build the ``cocofl_client`` for client ``cid``.

    The client receives:

    * an ``EmnistDataset`` loaded from
      ``clientdata/femnist_client_<cid>_ALPHA_0.1.csv``;
    * an epoch count of 5 and the module-level ``Batch`` value;
    * its ``lf`` value from ``get_lf_number``;
    * its rate from ``get_drop_rate``.
    """
    epochs = 5
    lf_number = get_lf_number(cid)
    drop_rate = get_drop_rate(cid)
    csv_path = f"clientdata/femnist_client_{cid}_ALPHA_0.1.csv"
    client_data = EmnistDataset(csv_path)
    return cocofl_client(cid, client_data, epochs, Batch, lf_number, drop_rate)