import json
import numbers
from pathlib import Path

from torch.utils.tensorboard import SummaryWriter

from ..compressors import BaseCompressor

class BaseClient:
    """Abstract base for a client identified by ``id``.

    Construction stores ``id`` and then delegates, in this order, to
    :meth:`format_client_data` and :meth:`format_client_config`. Both hooks
    raise ``NotImplementedError`` here and must be provided by subclasses.
    """

    def __init__(self, id, data, config):
        # `id` shadows the builtin, but it is part of the public constructor
        # signature (callers may pass it by keyword), so the name is kept.
        self.id = id
        # Data is formatted before config; subclasses may depend on this order.
        self.format_client_data(data)
        self.format_client_config(config)

    def format_client_data(self, data):
        """Store/transform this client's ``data``. Subclasses must override."""
        raise NotImplementedError()

    def format_client_config(self, *args, **kwargs):
        """Store/transform this client's configuration. Subclasses must override."""
        raise NotImplementedError()

class BaseServer:
    """Abstract server that owns a set of clients and reports metrics.

    Subclasses must implement :meth:`create_clients`, :meth:`run` and
    :meth:`metric`. The constructor stores its configuration on the instance
    and then calls :meth:`create_clients` with ``all_client_data``.
    """

    def __init__(
        self,
        all_client_data,
        d: int,
        compressor: BaseCompressor,
        tb_writer: SummaryWriter,
        results_dir: Path,
    ):
        """Store the server configuration, then build the clients.

        Args:
            all_client_data: Forwarded unchanged to :meth:`create_clients`.
            d: Integer stored as ``self.d`` (presumably the model/problem
                dimension -- TODO confirm against subclasses).
            compressor: Compressor instance stored as ``self.compressor``.
            tb_writer: TensorBoard writer used by :meth:`log_results`.
            results_dir: Directory stored as ``self.results_dir``; not read
                by this base class.
        """
        self.d = d
        self.compressor = compressor
        self.tb_writer = tb_writer
        self.results_dir = results_dir
        # Attributes are assigned first so create_clients() implementations
        # can already use them.
        self.create_clients(all_client_data)

    def create_clients(self, all_client_data):
        """Build the server's clients from ``all_client_data``. Must be overridden."""
        raise NotImplementedError

    def run(self):
        """Execute the server's main procedure. Must be overridden."""
        raise NotImplementedError

    def log_results(self, results: dict, idx: int):
        """Write the latest value of each tracked metric to TensorBoard.

        Args:
            results: Mapping from a TensorBoard tag to either a history of
                values (any sequence; its last element is logged) or a single
                numeric scalar (logged as-is).
            idx: Global step recorded with every scalar.

        Metrics whose history is still empty are skipped instead of raising
        ``IndexError``.
        """
        for tag, history in results.items():
            if isinstance(history, numbers.Number):
                latest = history
            elif len(history) == 0:
                # Nothing recorded yet for this metric, so there is no value to log.
                continue
            else:
                latest = history[-1]
            self.tb_writer.add_scalar(tag, latest, idx)
        # TODO: optionally persist `results` (e.g. as JSON) under `self.results_dir`.

    def metric(self, model):
        """Compute an error metric for the current ``model``. Must be overridden."""
        raise NotImplementedError