from .fedbase import BasicServer, BasicClient
import copy
import os.path

from .fedbase import BasicServer, BasicClient
import numpy as np
from utils import fmodule
import utils.system_simulator as ss
import torch
import collections
import json
class Server(BasicServer):
    """Server that always selects the first clients and logs per-round metrics.

    Each round appends the selected client ids and the metrics of those
    clients to ``FedAvgLog.json`` in the current working directory. The
    attributes and methods used here but not defined in this class
    (``num_rounds``, ``iterate``, ``test``, ``outFunc``, ``save_log``,
    ``stream_log``, ``global_lr_scheduler``, ``selected_clients``, ``clients``,
    ``model``) are expected to come from ``BasicServer``.
    """

    def __init__(self, option, model, clients, test_data=None):
        """Forward all constructor arguments unchanged to ``BasicServer``."""
        super(Server, self).__init__(option, model, clients, test_data)

    def run(self):
        """Run ``self.num_rounds`` rounds, reporting and logging after each.

        Per round, in order: ``iterate()``, evaluation on every client,
        evaluation via ``self.test()``, ``outFunc`` reporting, ``save_log``,
        the learning-rate scheduler, and finally one record appended to the
        log file. A log file left over from a previous run is deleted first.
        """
        filename = 'FedAvgLog.json'
        # Remove any stale log without a separate existence check, so there is
        # no window between "check" and "remove" in which the file can vanish.
        try:
            os.remove(filename)
        except FileNotFoundError:
            pass

        # Named ``round_idx`` so the builtin ``round`` is not shadowed.
        for round_idx in range(1, self.num_rounds + 1):
            self.current_round = round_idx
            self.iterate()
            test_metric, save_metric = self.test_on_clients()
            global_acc = float(self.test()['accuracy'])
            self.outFunc(
                round_idx,
                global_acc,
                test_metric['accuracy'],
                test_metric['loss'],
            )
            self.save_log(self.stream_log)
            # decay learning rate
            self.global_lr_scheduler(round_idx)
            selected = [int(i) for i in self.selected_clients]

            # Serialize before opening the file so that a serialization error
            # cannot leave a half-written record behind. The record layout is:
            # "Round_<n>", the selected ids, their metrics, then a blank line.
            record = 'Round_{}\n{}\n{}\n\n'.format(
                self.current_round,
                json.dumps(selected),
                json.dumps(save_metric),
            )
            with open(filename, 'a') as f:
                f.write(record)

    def test_on_clients(self, dataflag='valid'):
        """Evaluate ``self.model`` on every client and collect the metrics.

        Args:
            dataflag: Passed through to each client's ``test`` method.

        Returns:
            A pair ``(all_metrics, save_metrics)`` of ``defaultdict(list)``
            objects mapping metric name to a list of values in client order.
            ``all_metrics`` covers every client; ``save_metrics`` covers only
            clients whose index is in ``self.selected_clients``.
        """
        # Build the lookup once instead of scanning the list for every client.
        # Ids are cast to int exactly as ``run`` does when it logs them.
        selected_ids = {int(i) for i in self.selected_clients}
        all_metrics = collections.defaultdict(list)
        save_metrics = collections.defaultdict(list)
        for cid, c in enumerate(self.clients):
            client_metrics = c.test(self.model, dataflag)
            is_selected = cid in selected_ids
            for met_name, met_val in client_metrics.items():
                all_metrics[met_name].append(met_val)
                if is_selected:
                    save_metrics[met_name].append(met_val)
        return all_metrics, save_metrics

    def sample(self):
        """Return (and print) the indices of the first 20 clients.

        When fewer than 20 clients exist, only the existing indices are
        returned, so no id points past the end of ``self.clients``.
        """
        num_selected = min(20, len(self.clients))
        selected_clients = list(range(num_selected))
        print(selected_clients)
        return selected_clients
class Client(BasicClient):
    """Client that adds nothing to ``BasicClient`` beyond its own class name."""

    def __init__(self, option, name='', train_data=None, valid_data=None):
        """Pass every argument straight through to ``BasicClient.__init__``."""
        super().__init__(option, name, train_data, valid_data)
