import logging
import copy
from collections import OrderedDict
from .server_federated import FedServer
from ..misc import *

# Module-level logger for this server implementation. Note that
# ``ServerFedAT.logging_summary`` takes its own ``logger`` argument, which
# shadows this name inside that method.
logger = logging.getLogger(__name__)

class ServerFedAT(FedServer):
    """Federated server that aggregates client models through a general buffer and per-tier buffers.

    Two aggregation paths are provided:

    * a *general buffer* (``buffer`` / ``buffer_update``) that sums a batch of
      local models and loads their element-wise mean into the global model;
    * per-tier buffers (``setup_tiers`` / ``buffer_tier`` / ``tier_update``)
      that apply one tier's averaged model to the global model as a delta
      relative to the model that tier last received, scaled by the tier's
      share of all clients.
    """

    def __init__(self, weights, model, loss_fn, num_clients, device, **kwargs):
        """Create the server.

        Args:
            weights: Per-client weights; when ``None`` a uniform
                ``1 / num_clients`` weight is used for every client.
            model: The global model; must provide ``state_dict``,
                ``named_parameters`` and ``load_state_dict``.
            loss_fn, num_clients, device, **kwargs: Forwarded to ``FedServer``.
        """
        self.counter = 0
        self.global_step = 0
        if weights is None:
            weights = [1.0 / num_clients] * num_clients
        super().__init__(weights, model, loss_fn, num_clients, device, **kwargs)
        # Names of trainable parameters. State-dict entries not in this list
        # are overwritten (not delta-updated) in ``tier_update``.
        self.list_named_parameters = [name for name, _ in self.model.named_parameters()]

        ## General buffer for storing the first update
        self._reset_general_buffer()

    def _reset_general_buffer(self):
        """Set the general buffer to zeros (one tensor per state-dict entry) and its count to 0."""
        state = self.model.state_dict()
        self.general_buffer = OrderedDict(
            (name, torch.zeros_like(tensor)) for name, tensor in state.items()
        )
        self.general_buffer_size = 0

    def setup_tiers(self, num_tiers: int, tier_client_counts):
        """Initialize the tier buffers, tier models and tier update counters.

        Args:
            num_tiers: Number of tier buffers / models / counters to allocate.
            tier_client_counts: Number of clients in each tier, indexed by tier;
                ``tier_update`` divides a tier's buffer by this count.
        """
        # NOTE(review): ``self.num_tiers`` comes from ``tier_client_counts`` while
        # the lists below are sized by ``num_tiers``; callers should pass
        # consistent values.
        self.num_tiers = len(tier_client_counts)
        self.tier_client_counts = tier_client_counts
        state = self.model.state_dict()
        self.tier_buffers = [
            OrderedDict((name, torch.zeros_like(tensor)) for name, tensor in state.items())
            for _ in range(num_tiers)
        ]
        # NOTE(review): these are the tensors returned by ``state_dict()``; for a
        # torch module they share storage with the live model, so they follow
        # later in-place ``load_state_dict`` calls until ``tier_update`` replaces
        # them with a snapshot. Confirm whether a ``.clone()`` is intended here.
        self.tier_models = [OrderedDict(state) for _ in range(num_tiers)]
        self.tier_update_counters = [1] * num_tiers

    def buffer(self, local_model: dict):
        """Add one local model to the general buffer.

        Args:
            local_model: Mapping from state-dict entry name to tensor; must hold
                every key of the global model's state dict.
        """
        for name in self.model.state_dict():
            self.general_buffer[name] += local_model[name]
        self.general_buffer_size += 1

    def buffer_update(self):
        """Load the mean of the buffered local models into the global model.

        Every state-dict entry is replaced by the buffered sum divided by the
        number of buffered models. The general buffer is then reset (zeros,
        count 0) so that ``buffer`` can accumulate again. If nothing has been
        buffered, a warning is logged and the model is left unchanged (dividing
        by zero would otherwise load NaNs).
        """
        if self.general_buffer_size == 0:
            logger.warning("buffer_update called with an empty general buffer; global model left unchanged")
            return
        state = self.model.state_dict()
        # deepcopy keeps the state dict's container (and any metadata attached to it).
        self.global_state = copy.deepcopy(state)
        for name in state:
            self.global_state[name] = torch.div(self.general_buffer[name], self.general_buffer_size)
        self.model.load_state_dict(self.global_state)
        # Re-zero rather than empty the buffer: an empty dict made the next
        # ``buffer`` call raise KeyError and left a stale count behind.
        self._reset_general_buffer()

    def buffer_tier(self, local_model: dict, client_tier: int):
        """Add one local model to the buffer of ``client_tier``.

        Args:
            local_model: Mapping from state-dict entry name to tensor.
            client_tier: Index of the tier the client belongs to.
        """
        tier_buffer = self.tier_buffers[client_tier]
        for name in self.model.state_dict():
            tier_buffer[name] += local_model[name]

    def tier_update(self, client_tier: int):
        """Apply the buffered models of ``client_tier`` to the global model.

        With ``avg = tier_buffer / tier_client_counts[client_tier]``:

        * trainable parameters: ``global += (avg - tier_model) * count / num_clients``;
        * other state-dict entries: ``global = avg``.

        Afterwards the tier buffer is zeroed, the new global state becomes the
        tier's model, and it is loaded into ``self.model``. The divisor is the
        tier's full client count, so this assumes every client of the tier was
        buffered — confirm with the caller.
        """
        current_state = self.model.state_dict()
        self.global_state = copy.deepcopy(current_state)
        self.tier_update_counters[client_tier] += 1
        # NOTE(review): tier_update_counters is incremented but does not
        # currently influence the aggregation weight; confirm whether a
        # staleness-based cross-tier weight is intended.

        ## Update the global model
        param_names = set(self.list_named_parameters)
        tier_size = self.tier_client_counts[client_tier]
        tier_share = tier_size / self.num_clients
        tier_buffer = self.tier_buffers[client_tier]
        tier_model = self.tier_models[client_tier]
        for name in current_state:
            tier_avg = torch.div(tier_buffer[name], tier_size)
            if name in param_names:
                self.global_state[name] += (tier_avg - tier_model[name]) * tier_share
            else:
                self.global_state[name] = tier_avg

        ## Clean up the tier buffer and remember what this tier now holds
        for name in current_state:
            tier_buffer[name] = torch.zeros_like(current_state[name])
            tier_model[name] = self.global_state[name]
        self.model.load_state_dict(self.global_state)

    def logging_summary(self, cfg, logger):
        """Log the run summary and optionally append a one-line result record.

        Delegates to the ``log_summary`` of ``FedServer``'s parent class, logs
        the client learning rate, and, when ``cfg.summary_file`` is non-empty,
        appends dataset name, learning rate, test/best accuracy and elapsed time.
        """
        super(FedServer, self).log_summary(cfg, logger)
        logger.info("client_learning_rate=%s ", cfg.fed.args.optim_args.lr)
        if cfg.summary_file != "":
            with open(cfg.summary_file, "a") as f:
                f.write(
                    cfg.logginginfo.DataSet_name
                    + " ServerFedAT ClientLR "
                    + str(cfg.fed.args.optim_args.lr)
                    + " TestAccuracy "
                    + str(cfg.logginginfo.accuracy)
                    + " BestAccuracy "
                    + str(cfg.logginginfo.BestAccuracy)
                    + " Time "
                    + str(round(cfg.logginginfo.Elapsed_time, 2))
                    + "\n"
                )
