import copy

from torch.utils.data import ConcatDataset

from .base import *
import federated


class StandaloneController(FederatedController):
    """Centralised baseline controller.

    Instead of training per client, it pools every client's data and trains
    the server model through ``federated.train_fed`` with a single loader
    entry keyed ``-1``.
    """

    def train(self, num_rounds):
        """Evaluate the current server model, then train it on pooled data.

        Args:
            num_rounds: Number of rounds to run; written into
                ``self.config.federated.num_rounds`` before training.
        """
        self.config.federated.num_rounds = num_rounds
        self.prepare_evaluation()
        # Evaluate the starting model once before any training happens.
        self.evaluate(self.server_model)

        if self.is_llm:
            # TODO: wrong after a client is removed. ``leave`` in this class
            # only updates ``client_loaders``, so ``eval_loader`` presumably
            # still contains the departed client's data unless the base class
            # rebuilds it -- confirm.
            pooled = self.eval_loader
        else:
            train_parts = [loaders["train"].dataset for loaders in self.client_loaders.values()]
            test_parts = [loaders["test"].dataset for loaders in self.client_loaders.values()]
            # Train split is shuffled, test split keeps its order.
            pooled = {
                "train": utils.get_dataloader(
                    ConcatDataset(train_parts),
                    batch_size=self.config.local.train_batch_size,
                    shuffle=True,
                ),
                "test": utils.get_dataloader(
                    ConcatDataset(test_parts),
                    batch_size=self.config.local.eval_batch_size,
                    shuffle=False,
                ),
            }

        # One entry under key -1; presumably train_fed treats it as a single
        # pseudo-client standing in for the whole federation -- confirm.
        federated.train_fed(
            self.server_model,
            {-1: pooled},
            self.config,
            eval_fn=self.evaluate,
            device=self.device,
            is_llm=self.is_llm,
            tokenizer=self.tokenizer,
        )

    def leave(self, unlearned_client):
        """Drop ``unlearned_client`` and reset the server model.

        Delegates to the base controller, removes the client's loaders
        (``KeyError`` if absent), then restores a deep copy of
        ``self.init_server_model`` so later training starts from scratch.
        """
        super().leave(unlearned_client)
        del self.client_loaders[unlearned_client]
        # Original author's note: this reset "can skip".
        self.server_model = copy.deepcopy(self.init_server_model)

    def join(self, client_id, dataloader, attach_to=None):
        """Add a client; delegates to ``FederatedController.join``."""
        super().join(client_id, dataloader, attach_to)