import os
import pickle
from collections import defaultdict
import torch
from flwr.common import parameters_to_ndarrays, ndarrays_to_parameters, FitIns
from src.core.base import BaseStrategy

from src.core.utils import set_seed

class FedOursFinetuneStrategy(BaseStrategy):
    """Federated strategy that fine-tunes the aggregated model on the server.

    After the base strategy aggregates client updates, the aggregated weights
    are loaded into ``self.server``, evaluated with ``mode="agg"``, fine-tuned
    for ``self.server.finetune_epochs`` epochs, and the fine-tuned weights are
    returned as the new global parameters. When ``client_pseudo`` is enabled in
    the server config, the per-client pseudo-label metrics reported in each fit
    result are also pickled to disk (and averaged into wandb if enabled).
    """

    # Metric names each client is expected to report in ``FitRes.metrics``
    # when ``client_pseudo`` is enabled. Each is averaged across clients and
    # logged to wandb as ``client/<name>``.
    _CLIENT_METRIC_KEYS = (
        'tot_samples',
        'tot_pseudo_acc',
        'masked_samples',
        'masked_pseudo_acc',
        'pmasked_samples',
        'pmasked_pseudo_acc',
    )

    def __init__(self, server, **kwargs):
        super().__init__(server, **kwargs)
        # round -> client id (int) -> metrics dict that client reported in fit.
        self.client_metric_log = defaultdict(dict)
        self.client_log_save_path = os.path.join(
            self.server.save_path, 'client_metric_log.pkl'
        )

    def aggregate_fit(self, server_round, results, failures):
        """Aggregate client updates, then evaluate and fine-tune on the server.

        Args:
            server_round: Current federated round number.
            results: ``(ClientProxy, FitRes)`` pairs from successful clients.
            failures: Failed fit results/exceptions; forwarded to the base
                strategy unchanged.

        Returns:
            ``(parameters, metrics)``. ``parameters`` holds the fine-tuned
            server weights. It is ``None`` when the base strategy produced no
            aggregate, and in that case the server model is not modified.
        """
        aggregated_parameters, metrics = super().aggregate_fit(
            server_round, results, failures
        )

        self.server.round = server_round

        if self.server.config.get("client_pseudo", False):
            self._log_client_metrics(server_round, results)

        # A FedAvg-style base strategy returns None when there is nothing to
        # aggregate (e.g. no results, or failures that are not accepted).
        # Converting None would crash, so propagate it instead. Flower's server
        # then keeps the previous global parameters.
        if aggregated_parameters is None:
            return None, metrics

        self.server.load_parameters(parameters_to_ndarrays(aggregated_parameters))
        # Evaluate the raw aggregate before fine-tuning. The returned dict is
        # not used here (presumably self.server.evaluate records it itself).
        self.server.evaluate(mode="agg")
        self.server.fine_tune(epochs=self.server.finetune_epochs)

        fine_tuned_parameters = ndarrays_to_parameters(
            self.server.get_model_parameters()
        )
        return fine_tuned_parameters, metrics

    def _log_client_metrics(self, rnd, results):
        """Record per-client pseudo-label metrics for round ``rnd``.

        Steps:
            1. Store each client's ``FitRes.metrics`` in
               ``self.client_metric_log[rnd]``, keyed by the integer client id.
            2. Re-pickle the whole log to ``self.client_log_save_path``.
            3. If wandb is enabled and at least one client reported, log the
               across-client mean of every key in ``_CLIENT_METRIC_KEYS``.

        Args:
            rnd: Round number used as the outer key of the log.
            results: ``(ClientProxy, FitRes)`` pairs from this round.

        Raises:
            KeyError: If a client's metrics lack one of ``_CLIENT_METRIC_KEYS``.
        """
        collected = {key: [] for key in self._CLIENT_METRIC_KEYS}

        for client_proxy, fit_res in results:
            client_metrics = fit_res.metrics
            self.client_metric_log[rnd][int(client_proxy.cid)] = client_metrics
            for key in self._CLIENT_METRIC_KEYS:
                collected[key].append(client_metrics[key])

        # Overwrite the file every round with the full history collected so far.
        with open(self.client_log_save_path, 'wb') as f:
            pickle.dump(self.client_metric_log, f)

        num_clients = len(collected['tot_pseudo_acc'])
        if self.server.use_wandb and num_clients > 0:
            self.server.run.log(
                {
                    f'client/{key}': sum(values) / num_clients
                    for key, values in collected.items()
                },
                step=self.server.epoch + self.server.round,
            )

    def configure_fit(self, server_round, parameters, client_manager):
        """Configure the next round of training.

        Steps:
            1. Reseed via ``set_seed(self.seed + server_round)``. This presumably
               makes the client sampling below reproducible per round.
            2. Build the fit config and save the server's prototypes to
               ``<save_path>/prototypes.pt``. Clients receive that path via
               ``config['prototype_path']``.
            3. Sample clients and give every one of them the same ``FitIns``.

        Returns:
            A list of ``(ClientProxy, FitIns)`` pairs.
        """
        set_seed(self.seed + server_round)

        config = {}
        if self.on_fit_config_fn is not None:
            # Copy so the keys added below never leak into a dict that the
            # user-supplied config function may cache or reuse.
            config = dict(self.on_fit_config_fn(server_round))

        if self.server.config['Training']['use_scheduler']:
            # Forward the server optimizer's current learning rate to clients.
            config['current_lr'] = self.server.optimizer.param_groups[0]['lr']

        proto_path = os.path.join(self.server.save_path, "prototypes.pt")
        torch.save(self.server.prototypes.cpu(), proto_path)
        config['prototype_path'] = proto_path
        config['server_round'] = server_round

        fit_ins = FitIns(parameters, config)

        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )
        return [(client, fit_ins) for client in clients]

    def evaluate(self, rnd, parameters, config=None):
        """Evaluate the global ``parameters`` on the server for round ``rnd``.

        Args:
            rnd: Current federated round number. It is stored on
                ``self.server.round``.
            parameters: Global model parameters to load and evaluate.
            config: Unused. It is accepted for call-site compatibility.

        Returns:
            ``(loss, {"accuracy": acc})``. The values are taken from the
            ``"finetune/loss"`` and ``"finetune/top-1-acc"`` entries of the
            server's evaluation dict, and each defaults to 0.0 if absent.
        """
        print("*** eval..")
        self.server.round = rnd
        self.server.load_parameters(parameters_to_ndarrays(parameters))
        eval_dict = self.server.evaluate(mode="finetune")

        loss = eval_dict.get("finetune/loss", 0.0)
        acc = eval_dict.get("finetune/top-1-acc", 0.0)
        return loss, {"accuracy": acc}