import os
import pickle
from collections import defaultdict
from flwr.common import parameters_to_ndarrays, ndarrays_to_parameters
from src.core.base import BaseStrategy


class FedAvgFinetuneStrategy(BaseStrategy):
    """Strategy that fine-tunes the aggregated model on the server each round.

    Aggregation itself is delegated to ``BaseStrategy.aggregate_fit``
    (presumably FedAvg, per the class name). Each round, this class then:

    1. loads the aggregated weights into ``server``;
    2. if ``server.config['client_pseudo']`` is set, records the clients'
       fit metrics and logs their averages;
    3. evaluates the aggregate (``mode="agg"``);
    4. fine-tunes it for ``server.finetune_epochs`` epochs.

    The fine-tuned weights are returned instead of the plain aggregate.
    """

    # Per-client fit-metric keys that are averaged and logged (as
    # ``client/<key>``) when ``config['client_pseudo']`` is enabled.
    _PSEUDO_METRIC_KEYS = (
        "tot_samples",
        "tot_pseudo_acc",
        "masked_samples",
        "masked_pseudo_acc",
        "pmasked_samples",
        "pmasked_pseudo_acc",
    )

    def __init__(self, server, **kwargs):
        super().__init__(server, **kwargs)
        # round -> client id (int) -> metrics dict reported by that client's fit.
        self.client_metric_log = defaultdict(dict)

    def aggregate_fit(self, rnd, results, failures):
        """Aggregate client updates, then fine-tune the aggregate on the server.

        Args:
            rnd: Round number; stored on ``server.round``.
            results: ``(client_proxy, fit_res)`` pairs from successful clients.
            failures: Failed fit attempts, passed through to the base class.

        Returns:
            ``(parameters, metrics)``. When the base aggregation yields
            parameters, they are replaced by the server's fine-tuned weights;
            otherwise the base-class result is returned unchanged.
        """
        aggregated_parameters, metrics = super().aggregate_fit(rnd, results, failures)
        if not aggregated_parameters:
            return aggregated_parameters, metrics

        self.server.round = rnd
        self.server.load_parameters(parameters_to_ndarrays(aggregated_parameters))

        if self.server.config['client_pseudo']:
            self._record_client_metrics(rnd, results)

        # Evaluate the plain aggregate before fine-tuning; the returned dict is
        # not used here (presumably server.evaluate logs it itself - TODO confirm).
        self.server.evaluate(mode="agg")
        self.server.fine_tune(epochs=self.server.finetune_epochs)

        fine_tuned_parameters = ndarrays_to_parameters(self.server.get_model_parameters())
        return fine_tuned_parameters, metrics

    def _record_client_metrics(self, rnd, results):
        """Store each client's fit metrics for ``rnd``, persist them, and log averages.

        Clients that do not report every key in ``_PSEUDO_METRIC_KEYS`` are
        still stored in ``client_metric_log``. They are left out of the
        averages instead of aborting the round with a ``KeyError``.
        """
        values_by_key = {key: [] for key in self._PSEUDO_METRIC_KEYS}

        for client_proxy, fit_res in results:
            cid = int(client_proxy.cid)
            client_metrics = fit_res.metrics
            self.client_metric_log[rnd][cid] = client_metrics

            missing = [key for key in self._PSEUDO_METRIC_KEYS if key not in client_metrics]
            if missing:
                print(f"[round {rnd}] client {cid} did not report {missing}; "
                      "excluded from client metric averages")
                continue
            for key in self._PSEUDO_METRIC_KEYS:
                values_by_key[key].append(client_metrics[key])

        # The full history is rewritten every round, so the file on disk always
        # reflects every round recorded so far.
        client_log_save_path = os.path.join(self.server.save_path, 'client_metric_log.pkl')
        with open(client_log_save_path, 'wb') as f:
            pickle.dump(self.client_metric_log, f)

        # All lists share one length (clients are all-or-nothing above), so
        # checking the first is enough to avoid dividing by zero.
        reporting_clients = len(values_by_key[self._PSEUDO_METRIC_KEYS[0]])
        if self.server.use_wandb and reporting_clients > 0:
            # Unweighted mean over reporting clients.
            self.server.run.log(
                {f"client/{key}": sum(vals) / len(vals) for key, vals in values_by_key.items()},
                step=self.server.epoch + self.server.round,
            )

    def evaluate(self, rnd, parameters, config=None):
        """Evaluate ``parameters`` on the server in ``"finetune"`` mode.

        Args:
            rnd: Round number; stored on ``server.round``.
            parameters: Global model parameters to load into the server.
            config: Unused; accepted for call-site compatibility.

        Returns:
            ``(loss, {"accuracy": acc})``, read from the server's
            ``finetune/loss`` and ``finetune/top-1-acc`` entries. Each defaults
            to 0.0 when absent.
        """
        print("*** eval..")
        self.server.round = rnd
        self.server.load_parameters(parameters_to_ndarrays(parameters))
        eval_dict = self.server.evaluate(mode="finetune")

        loss = eval_dict.get("finetune/loss", 0.0)
        acc = eval_dict.get("finetune/top-1-acc", 0.0)
        return loss, {"accuracy": acc}
    
