import torch
import copy
from transformers import RobertaTokenizer

class Server:
    """Federated-learning server that owns the global model.

    Averages client state dicts into the global model (unweighted
    element-wise mean) and runs the global model over a public dataset.
    """

    # Hugging Face checkpoint used to tokenize the public-set text; override
    # on the class or instance to use a different tokenizer.
    tokenizer_name = 'roberta-large'

    def __init__(self, global_model, args):
        """Store the global model and the run configuration.

        Args:
            global_model: torch module whose weights ``aggregate`` overwrites
                and which ``get_global_logits`` evaluates.
            args: run configuration, stored as ``self.args``.
        """
        self.global_model = global_model
        self.args = args
        # Loaded lazily on first use and then reused across calls.
        self._tokenizer = None

    def aggregate(self, client_states):
        """Average client state dicts and load the result into the global model.

        Every entry becomes the unweighted element-wise mean over all clients.
        Non-floating entries (e.g. integer counters) are averaged and cast back
        to their original dtype, which truncates toward zero.

        Args:
            client_states: non-empty sequence of state dicts sharing the same
                keys, shapes and dtypes. The caller's tensors are not modified.

        Returns:
            The global model's state dict after loading the averaged weights.

        Raises:
            ValueError: if ``client_states`` is empty.
        """
        num_clients = len(client_states)
        if num_clients == 0:
            raise ValueError('aggregate() requires at least one client state')
        # deepcopy keeps the first state's mapping type (and any attributes
        # attached to it) while leaving the caller's tensors untouched, so the
        # in-place additions below are safe.
        global_state = copy.deepcopy(client_states[0])
        for key in list(global_state):
            total = global_state[key]
            original_dtype = total.dtype
            for state in client_states[1:]:
                total += state[key]
            # True division yields a float tensor for integer/bool inputs;
            # cast back so the aggregated dict keeps the model's dtypes.
            global_state[key] = (total / num_clients).to(original_dtype)
        self.global_model.load_state_dict(global_state)
        return self.global_model.state_dict()

    def _get_tokenizer(self):
        """Return the public-set tokenizer, loading it only on the first call."""
        tokenizer = getattr(self, '_tokenizer', None)
        if tokenizer is None:
            tokenizer = RobertaTokenizer.from_pretrained(self.tokenizer_name)
            self._tokenizer = tokenizer
        return tokenizer

    def get_global_logits(self, public_loader, device):
        """Run the global model over a public dataset without tracking gradients.

        Switches the global model to eval mode and leaves it there.

        Args:
            public_loader: iterable of batches; each batch is indexed with
                ``'text'`` and the result is passed to the tokenizer.
            device: device the tokenized inputs are moved to.

        Returns:
            List holding the model's output for each batch, moved to CPU, in
            loader order.
        """
        self.global_model.eval()
        tokenizer = self._get_tokenizer()
        logits = []
        with torch.no_grad():
            for batch in public_loader:
                inputs = tokenizer(batch['text'], padding=True, truncation=True, return_tensors='pt').to(device)
                # NOTE(review): only input_ids reach the model, so padded
                # positions are not masked. If global_model accepts an
                # attention mask, consider passing inputs['attention_mask'] too.
                output = self.global_model(inputs['input_ids'])
                logits.append(output.cpu())
        return logits
