import copy
import os

import torch

from coopt.server_client.model_training import optimize_model

class Server:
    """Global (server-side) state: data, model, and per-round accuracy history.

    The server keeps a "global" view of the data (``global_x`` / ``global_y``)
    plus a learnable-looking ``global_optimal_data`` matrix. It trains the
    global model through ``optimize_model`` and serialises everything with
    ``torch.save``.
    """

    # Width of each row of ``global_optimal_data``, kept from the original
    # hard-coded value. Subclasses may override it.
    # NOTE(review): presumably a feature/embedding size that ``optimize_model``
    # expects -- confirm before changing.
    OPTIMAL_DATA_DIM = 512

    def __init__(self, args, dataset):
        """Initialise server state from ``dataset``.

        Args:
            args: Run configuration. It is forwarded unchanged to
                ``optimize_model``, and ``save_model`` reads
                ``args.global_ipc`` from it.
            dataset: Object exposing a ``targets`` attribute (a sequence or
                tensor of labels). The object itself is stored as
                ``global_x``.
        """
        self.args = args
        self.dataset = dataset
        # The dataset object itself (not a tensor of samples) is what gets
        # handed to ``optimize_model`` as the global inputs.
        self.global_x = dataset

        targets = dataset.targets
        if isinstance(targets, torch.Tensor):
            # ``torch.tensor(tensor)`` warns. ``clone().detach()`` is the
            # documented equivalent and still gives an independent copy.
            self.global_y = targets.clone().detach()
        else:
            self.global_y = torch.tensor(targets)

        num_samples = len(self.global_y)
        # One random row per label, initialised from N(0, 1).
        self.global_optimal_data = torch.randn(num_samples, self.OPTIMAL_DATA_DIM)
        # Per-sample integer index, initialised to 0. It is not read or
        # written by this class after construction.
        self.global_optimal_data_index = torch.zeros(num_samples, dtype=torch.int)

        # Set by ``optimize_global_model``; ``None`` until the first call.
        self.global_model = None
        # Accuracy histories, presumably appended to by the caller each round.
        self.all_round_client_avg_acc = []
        self.all_round_global_acc = []
        # NOTE(review): the meaning of keys 0/1/2 is not visible here; confirm
        # with the code that populates this dict.
        self.all_round_client = {0: [], 1: [], 2: []}

    def optimize_global_model(self):
        """Train the global model and return its best validation accuracy.

        Deep copies of the global inputs, optimal data and labels are passed
        in, so ``optimize_model`` cannot mutate the server's own copies. The
        returned model replaces ``self.global_model``.

        Returns:
            The best validation accuracy reported by ``optimize_model``.
        """
        self.global_model, global_best_val_acc = optimize_model(
            self.args,
            copy.deepcopy(self.global_x),
            copy.deepcopy(self.global_optimal_data),
            copy.deepcopy(self.global_y),
            self.global_model,
            'Global',
        )
        return global_best_val_acc

    def save_model(self, e_round, save_name='align_to_best'):
        """Serialise the global model, data and accuracy history to disk.

        The file is written to
        ``save/global_client_model_save/global_{save_name}_round_{e_round}_{args.global_ipc}.pt``
        (relative to the current working directory). Missing directories are
        created.

        The saved dict uses integer keys:
        0 model, 1 global_x, 2 global_optimal_data, 3 global_y,
        4 all_round_global_acc, 5 all_round_client_avg_acc.

        Args:
            e_round: Round identifier embedded in the file name.
            save_name: Tag embedded in the file name.
        """
        global_model_data = {
            0: self.global_model,
            1: self.global_x,
            2: self.global_optimal_data,
            3: self.global_y,
            4: self.all_round_global_acc,
            5: self.all_round_client_avg_acc,
        }
        save_path = f"save/global_client_model_save/global_{save_name}_round_{e_round}_{self.args.global_ipc}.pt"
        # torch.save does not create parent directories; without this the
        # first save in a fresh working directory raises FileNotFoundError.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(global_model_data, save_path)