import random
import numpy as np

import flwr as fl
from flwr.server.strategy import FedAvg
from flwr.common import parameters_to_ndarrays, FitIns
from typing import Dict, Tuple, Optional

from src.core.utils import set_seed


class BaseStrategy(FedAvg):
    """FedAvg strategy that keeps a project-side ``server`` object in sync.

    On top of plain FedAvg this strategy:

    * re-seeds the RNGs at the start of every fit round (``seed + round``),
    * loads the aggregated global weights into ``server`` after each round,
    * adds ``server_round`` (and, when a scheduler is enabled, the server
      optimizer's current learning rate) to every client's fit config,
    * runs centralized evaluation through ``server.evaluate(mode="agg")``.

    Args:
        server: Project server object. It must expose ``config`` (with
            ``config['seed']`` and ``config['Training']['use_scheduler']``),
            ``load_parameters(ndarrays)``, ``evaluate(mode=...)``, a writable
            ``round`` attribute and, when the scheduler is enabled,
            ``optimizer.param_groups`` (presumably a torch optimizer -- confirm).
        **kwargs: Forwarded unchanged to ``FedAvg.__init__``.
    """

    def __init__(self, server, **kwargs):
        super().__init__(**kwargs)
        self.server = server
        # Base seed; every fit round reseeds with ``seed + server_round``.
        self.seed = self.server.config['seed']

    def aggregate_fit(self, rnd, results, failures):
        """Aggregate client updates with FedAvg and load them into ``server``.

        Args:
            rnd: Current server round.
            results: Successful client fit results.
            failures: Failed client fits.

        Returns:
            The ``(parameters, metrics)`` tuple produced by ``FedAvg``;
            ``parameters`` may be ``None`` when FedAvg had nothing to aggregate,
            in which case the server model is left untouched.
        """
        aggregated_parameters, metrics = super().aggregate_fit(rnd, results, failures)

        # Keep the server's model in step with the new global weights.
        if aggregated_parameters is not None:
            self.server.load_parameters(parameters_to_ndarrays(aggregated_parameters))

        return aggregated_parameters, metrics

    def configure_fit(self, server_round, parameters, client_manager):
        """Configure the next round of training.

        Re-seeds the RNGs, builds the per-round fit config and samples the
        clients that will train in this round.

        Args:
            server_round: Current server round.
            parameters: Current global model parameters sent to the clients.
            client_manager: Client manager to sample the clients from.

        Returns:
            A list of ``(client, FitIns)`` pairs; every sampled client receives
            the same ``FitIns``.
        """
        # Round-dependent seed, presumably so that client sampling and any other
        # seeded randomness is reproducible per round.
        set_seed(self.seed + server_round)

        # Copy the user-supplied config so the keys added below never mutate a
        # dict that ``on_fit_config_fn`` may hand out again in later rounds.
        config = {}
        if self.on_fit_config_fn is not None:
            config = dict(self.on_fit_config_fn(server_round))

        if self.server.config['Training']['use_scheduler']:
            # Forward the server optimizer's current (scheduled) LR to clients.
            config['current_lr'] = self.server.optimizer.param_groups[0]['lr']

        config['server_round'] = server_round

        fit_ins = FitIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        return [(client, fit_ins) for client in clients]

    def evaluate(self, rnd, parameters, config=None):
        """Evaluate the global model centrally on the server.

        Args:
            rnd: Current server round; stored on ``server.round``.
            parameters: Global model parameters to evaluate.
            config: Unused; kept so callers that pass it keep working.

        Returns:
            ``(loss, {"accuracy": acc})`` taken from the ``"agg/loss"`` and
            ``"agg/top-1-acc"`` entries of ``server.evaluate(mode="agg")``,
            each defaulting to ``0.0`` when the entry is missing.
        """
        self.server.round = rnd
        self.server.load_parameters(parameters_to_ndarrays(parameters))
        eval_dict = self.server.evaluate(mode="agg")

        loss = eval_dict.get("agg/loss", 0.0)
        acc = eval_dict.get("agg/top-1-acc", 0.0)

        return loss, {"accuracy": acc}