from src.algorithms.base.server_base import BaseServer


class FOOGDServer(BaseServer):
    """Server class for FOOGD.

    Extends ``BaseServer`` with a checkpoint loader that restores each
    client's ``backbone`` and ``score_model`` weights.
    """

    def __init__(self, server_args):
        """Forward ``server_args`` unchanged to ``BaseServer``."""
        super().__init__(server_args)

    def load_checkpoint(self, checkpoint):
        """Restore client ``backbone`` and ``score_model`` weights.

        Two checkpoint layouts are accepted:

        * A mapping-like object with a ``"clients"`` entry. That entry is a
          sequence of per-client dicts, each holding ``"backbone"`` and
          ``"score_model"`` state dicts. Entry ``i`` is loaded into
          ``self.clients[i]``.
        * Anything else is treated as an indexable pair. ``checkpoint[0]`` is
          a ``backbone`` state dict and ``checkpoint[1]`` is a
          ``score_model`` state dict. The same weights are loaded into every
          client.

        Args:
            checkpoint: The loaded checkpoint object, in one of the layouts
                above.

        Raises:
            ValueError: If a per-client checkpoint holds a different number
                of entries than there are clients. Without this check,
                ``zip`` would silently stop at the shorter sequence and leave
                some clients unrestored.
        """
        if "clients" in checkpoint:
            client_checkpoints = list(checkpoint["clients"])
            clients = list(self.clients)
            # Fail loudly on a count mismatch instead of letting zip truncate.
            if len(client_checkpoints) != len(clients):
                raise ValueError(
                    f"Checkpoint contains {len(client_checkpoints)} client states "
                    f"but the server has {len(clients)} clients"
                )
            for client_checkpoint, client in zip(client_checkpoints, clients):
                client.backbone.load_state_dict(client_checkpoint["backbone"])
                client.score_model.load_state_dict(client_checkpoint["score_model"])
            return

        # Shared checkpoint: broadcast one pair of state dicts to all clients.
        backbone_state, score_model_state = checkpoint[0], checkpoint[1]
        for client in self.clients:
            client.backbone.load_state_dict(backbone_state)
            client.score_model.load_state_dict(score_model_state)
