import random
from logging import WARNING
from typing import List, Tuple, Union, Dict

import flwr as fl
import torch
from flwr.common import Metrics
from flwr.common import EvaluateIns, FitIns, FitRes
from flwr.common import ndarrays_to_parameters
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import weighted_loss_avg, aggregate

from STL_client import STL_client
from dataset import EmnistDataset
from models import CNN
from util import get_filters, get_parameters, parameters_to_ndarrays, generate_subnet_ordered, Param_by_layers, set_filters
# Configuration constants shared by the strategy and the client factory in this module.
CHANNEL: int = 1  # first positional argument to CNN(...); presumably input image channels — TODO confirm
Batch: int = 16  # 4th positional argument to STL_client in stl_client_fn (presumably the batch size)
CLASSES: int = 62  # passed to CNN(..., outputs=CLASSES)
TOTAL_ROUND: int = 100  # default round budget T used by get_active_params' phase schedule
WIDTH: int = 20  # get_active_params maps Param_by_layers[2] names to list(range(WIDTH)) in phase 2

def weighted_average(metrics: List[Tuple[int, "Metrics"]]) -> "Metrics":
    """Combine per-client accuracies into one weighted mean.

    Args:
        metrics: ``(weight, metrics_dict)`` pairs, one per client.  Each
            ``metrics_dict`` must contain an ``"accuracy"`` entry.  The weight
            is usually the client's number of examples.

    Returns:
        ``{"accuracy": sum(w_i * acc_i) / sum(w_i)}``.  When ``metrics`` is
        empty or the weights sum to zero, the mean is undefined.  In that case
        ``{"accuracy": 0.0}`` is returned instead of raising
        ``ZeroDivisionError``.
    """
    total_weight = sum(num_examples for num_examples, _ in metrics)
    if total_weight == 0:
        return {"accuracy": 0.0}
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_weight}

def get_active_params(model: CNN, server_round: int, scaler: float, T=TOTAL_ROUND):
    """Return the drop info sent to a client for ``server_round``.

    The ``T`` rounds are split into three phases:

    * ``server_round <= T / 3``: the drop info produced by
      ``generate_subnet_ordered(model, scaler)``.
    * ``server_round <= T * 2 / 3``: the same drop info, except that every
      parameter name in ``Param_by_layers[2]`` is mapped to
      ``list(range(WIDTH))``.
    * otherwise: ``None``.

    In the first two phases, a ``scaler >= 0.99`` short-circuits to ``{}``
    without calling ``generate_subnet_ordered``.  This presumably means
    "use the full model" — TODO confirm.
    """
    # Decide the phase first; the last phase needs no sub-network at all.
    if server_round <= T / 3:
        override_layer_group_2 = False
    elif server_round <= T * 2 / 3:
        override_layer_group_2 = True
    else:
        return None

    # Near-full ratio: hand back an empty mapping instead of a sub-network.
    if scaler >= 0.99:
        return {}

    drop_info, _ = generate_subnet_ordered(model, scaler)
    if override_layer_group_2:
        for param_name in Param_by_layers[2]:
            # A fresh list per name, so the entries never alias each other.
            drop_info[param_name] = list(range(WIDTH))
    return drop_info

class SLT_strategy(fl.server.strategy.FedAvg):
    """FedAvg strategy that ships sub-network drop info and personal models to clients.

    The server keeps its own global model (``self.global_model``).  It also
    keeps, per integer client id, the last parameters that client returned
    from fit (``self.personal_models``).  Each round's fit instructions carry
    the global weights plus a config holding the client's drop info (from
    ``get_active_params``) and its personal model.  Evaluation sends the
    global weights only.
    """

    def __init__(self, ff, fe, mfc, mec, mac, maxiter=100, ACC=None, ClientsSelection=None):
        """Create the strategy.

        Args:
            ff: ``fraction_fit`` forwarded to ``FedAvg``.
            fe: ``fraction_evaluate`` forwarded to ``FedAvg``; ``0.0`` makes
                ``configure_evaluate`` skip evaluation.
            mfc: ``min_fit_clients`` forwarded to ``FedAvg``.
            mec: ``min_evaluate_clients`` forwarded to ``FedAvg``.
            mac: ``min_available_clients`` forwarded to ``FedAvg``.  Personal
                models are pre-seeded for client ids ``0 .. mac - 1``.
            maxiter: round budget ``T`` passed to ``get_active_params``.
            ACC: optional list that each round's aggregated test accuracy is
                appended to.  The caller's list object itself is used, so it
                can be read after training.  If omitted, a fresh list is
                created per instance; a shared ``[]`` default would leak
                records between instances.
            ClientsSelection: accepted for backward compatibility; unused.
        """
        super().__init__(
            fraction_fit=ff,
            fraction_evaluate=fe,
            min_fit_clients=mfc,
            min_evaluate_clients=mec,
            min_available_clients=mac,
            evaluate_metrics_aggregation_fn=weighted_average,
        )
        # Plain scalars.  Trailing commas here would turn these into 1-tuples,
        # and then ``fraction_evaluate_ == 0.0`` could never be true.
        self.fraction_fit_ = ff
        self.fraction_evaluate_ = fe
        self.min_fit_clients_ = mfc
        self.min_evaluate_clients_ = mec
        self.min_available_clients_ = mac
        self.global_model = CNN(CHANNEL, outputs=CLASSES)
        self.accuracy_record = [] if ACC is None else ACC
        self.Maxiteration = maxiter
        # Every pre-seeded client starts from the weights of one separately
        # constructed CNN.  All entries share that object; aggregate_fit
        # replaces entries rather than mutating them.
        initial_parameters = get_filters(CNN(CHANNEL, outputs=CLASSES))
        self.personal_models = {i: initial_parameters for i in range(mac)}

    def record_test_accuracy(self, acc):
        """Append ``acc`` to ``self.accuracy_record``."""
        self.accuracy_record.append(acc)

    def initialize_parameters(self, client_manager: ClientManager):
        """Return the initial global weights as Flower ``Parameters`` (overrides FedAvg)."""
        return ndarrays_to_parameters(get_parameters(self.global_model))

    def configure_fit(self, server_round: int, parameters, client_manager: ClientManager):
        """Sample clients and build one ``FitIns`` per client (overrides FedAvg).

        The ``parameters`` argument is ignored; weights are always read from
        ``self.global_model``.  Each client's config carries:

        * ``'drop_info'``: ``get_active_params`` for the client's rate and round;
        * ``'personal model'``: the client's entry in ``self.personal_models``.

        Raises:
            KeyError: a sampled client id has no entry in ``self.personal_models``.
        """
        # Re-seed per round so client sampling is reproducible.  This assumes
        # the client manager samples through the ``random`` module — TODO confirm.
        random.seed(server_round)
        sample_size, min_num_clients = super().num_fit_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)
        config_fit_list = []
        for client in clients:
            cid = int(client.cid)
            drop_info = get_active_params(self.global_model, server_round, get_rate(cid), T=self.Maxiteration)
            params = get_filters(self.global_model)
            config = {
                'drop_info': drop_info,
                'personal model': self.personal_models[cid],
            }
            config_fit_list.append((client, FitIns(ndarrays_to_parameters(params), config)))
        return config_fit_list

    def aggregate_fit(self, server_round: int, results: List[Tuple[ClientProxy, FitRes]], failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]]):
        """Aggregate fit results with a sample-weighted average (overrides FedAvg).

        Each client's returned parameters are stored as its personal model,
        keyed by ``int(client.cid)``.  The aggregate is loaded into
        ``self.global_model``.

        Returns:
            ``(parameters, metrics)`` with ``parameters`` as Flower
            ``Parameters``.  Returns ``(None, {})`` when there are no results,
            or when failures are present and not accepted.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted.
        if not self.accept_failures and failures:
            return None, {}

        weights_results = []
        for client, fit_res in results:
            local_params = parameters_to_ndarrays(fit_res.parameters)
            self.personal_models[int(client.cid)] = local_params
            weights_results.append((local_params, fit_res.num_examples))
        set_filters(self.global_model, aggregate(weights_results))

        # Aggregate custom metrics once, if an aggregation fn was provided.
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")
        return ndarrays_to_parameters(get_filters(self.global_model)), metrics_aggregated

    def configure_evaluate(self, server_round: int, parameters, client_manager: ClientManager):
        """Sample clients and send each the current global weights (overrides FedAvg).

        The ``parameters`` argument is ignored; weights are read from
        ``self.global_model``.  Returns ``[]`` when ``fraction_evaluate_`` is 0.
        """
        if self.fraction_evaluate_ == 0.0:
            return []
        sample_size, min_num_clients = super().num_evaluation_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)
        # Serialize the global weights once; every client gets the same tensor payload.
        global_parameters = ndarrays_to_parameters(get_filters(self.global_model))
        return [(client, EvaluateIns(global_parameters, {})) for client in clients]

    def aggregate_evaluate(self, server_round: int, results, failures):
        """Aggregate client evaluation results (overrides FedAvg).

        The loss is averaged, weighted by each client's ``num_examples``.
        Metrics are combined by ``evaluate_metrics_aggregation_fn`` with a
        weight of 1 per client; with ``weighted_average`` that is a plain mean
        over clients.  The resulting accuracy is appended to
        ``self.accuracy_record`` and printed.

        Returns:
            ``(loss, metrics)``.  Returns ``(None, {})`` when there are no
            results, or when failures are present and not accepted.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted.
        if not self.accept_failures and failures:
            return None, {}
        loss_aggregated = weighted_loss_avg(
            [(evaluate_res.num_examples, evaluate_res.loss) for _, evaluate_res in results]
        )
        metrics_aggregated = {}
        if self.evaluate_metrics_aggregation_fn:
            # Weight 1 per client: every client's accuracy counts equally,
            # regardless of its test-set size.  NOTE(review): looks deliberate — confirm.
            eval_metrics = [(1, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
            self.record_test_accuracy(metrics_aggregated['accuracy'])
            print(f"SLT: Round {server_round}, test accuracy = {metrics_aggregated['accuracy']}")
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")
        return loss_aggregated, metrics_aggregated

def stl_client_fn(cid) -> STL_client:
    """Construct the ``STL_client`` for client ``cid``.

    Loads the client's CSV shard from
    ``clientdata/femnist_client_<cid>_ALPHA_0.1.csv`` into an
    ``EmnistDataset``.  It is passed to ``STL_client`` together with a fixed
    epoch count of 5, the module-level ``Batch`` value, and ``get_rate(cid)``.
    """
    local_epochs = 5
    client_rate = get_rate(cid)
    shard_path = f"clientdata/femnist_client_{cid}_ALPHA_0.1.csv"
    return STL_client(cid, EmnistDataset(shard_path), local_epochs, Batch, client_rate)

def get_rate(cid):
    """Return the rate used for client ``cid``.

    The value is passed as ``scaler`` to ``get_active_params`` and as the
    rate argument of ``STL_client``.  Every client currently gets 0.5, and
    ``cid`` is not used.
    """
    uniform_rate = 0.5
    return uniform_rate
