import copy
from .base import *


class RetrainingController(FederatedController):
    """Federated controller that handles a client leaving by resetting the global model.

    When a client leaves, its data loader is removed from ``client_loaders``
    and ``server_model`` is replaced by a fresh deep copy of
    ``init_server_model``. Presumably the next ``train`` call then retrains
    from scratch on the remaining clients (TODO confirm against
    ``FederatedController.train``).
    """

    def train(self, num_rounds):
        """Train for ``num_rounds`` rounds using the base controller.

        This override adds no logic of its own. It always returns ``None``,
        whatever the parent returns.
        """
        parent = super()
        parent.train(num_rounds)

    def leave(self, unlearned_client):
        """Remove ``unlearned_client`` and reset the global model to its initial state.

        Args:
            unlearned_client: identifier used both by the base ``leave`` and
                as the key/index into ``self.client_loaders`` (assumed to be
                a dict key or list index; verify against the base class).
        """
        # The base controller does its own bookkeeping first. That
        # implementation is not visible here, so the call order is preserved.
        super().leave(unlearned_client)
        # Drop the departing client's loader. If client_loaders is a dict or
        # list, pop raises when the entry is absent (container type assumed).
        self.client_loaders.pop(unlearned_client)
        # Use an independent copy so later updates to server_model cannot
        # alter init_server_model.
        fresh_model = copy.deepcopy(self.init_server_model)
        self.server_model = fresh_model

    def join(self, client_id, dataloader, attach_to=None):
        """Add a client through the base controller.

        ``attach_to`` is forwarded positionally as the third argument to the
        parent's ``join``. This override adds no logic of its own.
        """
        parent = super()
        parent.join(client_id, dataloader, attach_to)