import os
import pickle
from collections import defaultdict
import torch
from flwr.common import parameters_to_ndarrays, ndarrays_to_parameters, FitIns
from src.core.base import BaseStrategy



class FedProSubFinetuneStrategy(BaseStrategy):
    """Strategy that fine-tunes the aggregated global model on the server.

    Each round that yields aggregated parameters, this strategy:

    1. aggregates client updates via ``BaseStrategy.aggregate_fit``;
    2. blends the clients' reported ``alpha1/beta1`` and ``alpha2/beta2``
       metrics into the server's values with an exponential moving average;
    3. evaluates the aggregate (``mode="agg"``), fine-tunes it on the server,
       and returns the fine-tuned weights as the new global parameters.
    """

    def __init__(self, server, **kwargs):
        super().__init__(server, **kwargs)
        # round -> client_id -> metrics. Not populated by this class at the
        # moment; kept so any external readers of the attribute still work.
        self.client_metric_log = defaultdict(dict)

    def _aggregate_alpha_beta(self, results):
        """Fold the clients' mean alpha/beta into the server values via EMA.

        For each pair index ``i`` in (1, 2), only clients whose fit metrics
        contain BOTH ``alpha{i}`` and ``beta{i}`` contribute. If no client
        reports a given pair, the server's values for that pair are unchanged.
        The decay is ``Training.beta_ema_decay`` (default 0.999).
        """
        ema_decay = self.server.config['Training'].get('beta_ema_decay', 0.999)
        for idx in (1, 2):
            alpha_key, beta_key = f"alpha{idx}", f"beta{idx}"
            reported = [
                (fit_res.metrics[alpha_key], fit_res.metrics[beta_key])
                for _, fit_res in results
                if alpha_key in fit_res.metrics and beta_key in fit_res.metrics
            ]
            if not reported:
                continue

            mean_alpha = sum(alpha for alpha, _ in reported) / len(reported)
            mean_beta = sum(beta for _, beta in reported) / len(reported)
            setattr(
                self.server,
                alpha_key,
                ema_decay * getattr(self.server, alpha_key) + (1 - ema_decay) * mean_alpha,
            )
            setattr(
                self.server,
                beta_key,
                ema_decay * getattr(self.server, beta_key) + (1 - ema_decay) * mean_beta,
            )
            print(
                f"[Server] Aggregated α{idx}={getattr(self.server, alpha_key):.4f}, "
                f"β{idx}={getattr(self.server, beta_key):.4f}"
            )

    def aggregate_fit(self, rnd, results, failures):
        """Aggregate client updates, then fine-tune the result on the server.

        Args:
            rnd: Current federated round number.
            results: ``(client_proxy, fit_res)`` pairs from successful clients.
            failures: Failed client results, forwarded to ``BaseStrategy``.

        Returns:
            ``(parameters, metrics)``. If aggregation produced parameters, they
            are the server model's weights after fine-tuning; otherwise the base
            strategy's output is returned unchanged.
        """
        aggregated_parameters, metrics = super().aggregate_fit(rnd, results, failures)
        if not aggregated_parameters:
            return aggregated_parameters, metrics

        self.server.round = rnd
        self.server.load_parameters(parameters_to_ndarrays(aggregated_parameters))

        # alpha/beta are aggregated in every round that yields parameters;
        # there is no round threshold.
        self._aggregate_alpha_beta(results)

        # Evaluate the plain aggregate before fine-tuning; the returned dict is
        # not used here (presumably evaluate() logs internally — confirm).
        self.server.evaluate(mode="agg")
        self.server.fine_tune(epochs=self.server.finetune_epochs)

        fine_tuned_parameters = ndarrays_to_parameters(self.server.get_model_parameters())
        return fine_tuned_parameters, metrics

    def configure_fit(self, server_round, parameters, client_manager):
        """Configure the next round of training.

        Builds the per-round fit config (``on_fit_config_fn`` output if set,
        current learning rate when a scheduler is used, the prototype file
        path, and the server's alpha/beta values) and pairs it with every
        sampled client.

        Returns:
            List of ``(client, FitIns)`` tuples.
        """
        # Copy so we never mutate a dict owned by on_fit_config_fn.
        config = {}
        if self.on_fit_config_fn is not None:
            config = dict(self.on_fit_config_fn(server_round))

        if self.server.config['Training'].get('use_scheduler', False):
            config['current_lr'] = self.server.optimizer.param_groups[0]['lr']

        # Prototypes are written to disk; the config only carries the path.
        if self.server.prototypes is not None:
            proto_path = os.path.join(self.server.save_path, f"round{server_round}_proto.pt")
            torch.save(self.server.prototypes.cpu(), proto_path)
            config['prototype_path'] = proto_path
        else:
            print(f"[Warning] Prototypes not initialized yet for round {server_round}")
            # NOTE(review): None is not one of Flower's config Scalar types
            # (bool/bytes/float/int/str) and may be rejected when the config is
            # serialized; confirm how clients detect "no prototypes" first.
            config['prototype_path'] = None

        # Broadcast the server's current alpha/beta so clients stay in sync.
        for key in ('alpha1', 'alpha2', 'beta1', 'beta2'):
            config[key] = float(getattr(self.server, key))

        print(f"[Round {server_round}] Sending to clients: α1={config['alpha1']:.4f}, β1={config['beta1']:.4f}, α2={config['alpha2']:.4f}, β2={config['beta2']:.4f}")

        fit_ins = FitIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        return [(client, fit_ins) for client in clients]

    def evaluate(self, rnd, parameters, config=None):
        """Evaluate ``parameters`` on the server in ``"finetune"`` mode.

        Args:
            rnd: Current federated round number.
            parameters: Global model parameters to load before evaluating.
            config: Unused; accepted for signature compatibility.

        Returns:
            ``(loss, {"accuracy": acc})``, where missing entries in the server's
            evaluation dict default to 0.0.
        """
        print("*** eval..")
        self.server.round = rnd
        self.server.load_parameters(parameters_to_ndarrays(parameters))
        eval_dict = self.server.evaluate(mode="finetune")

        loss = eval_dict.get("finetune/loss", 0.0)
        acc = eval_dict.get("finetune/top-1-acc", 0.0)
        return loss, {"accuracy": acc}
