from utils import fflow as flw
import numpy as np
import torch
import os
import multiprocessing
import matplotlib.pyplot as plt
import pickle


class MyLogger(flw.Logger):
    """Per-round metrics logger for a federated training run.

    On every call to :meth:`log` the logger evaluates the server and its
    clients, appends the results to ``self.output``, prints a summary and
    appends the same summary to a text log file in the working directory.

    NOTE(review): ``self.current_round`` and ``self.temp`` are not defined in
    this class; they are presumably provided by ``flw.Logger`` -- confirm.
    """

    def __init__(self):
        super().__init__()
        # Metric history; filled in lazily on the first call to log().
        self.output = {}
        # Path of the text log file. Resolved once on first use so the
        # options are not re-read on every round.
        self._log_path = None

    def get_output(self):
        """Return the dictionary of metrics accumulated so far."""
        return self.output

    @staticmethod
    def _empty_output(meta):
        """Return a fresh metrics dictionary with ``meta`` stored under 'meta'."""
        return {
            "meta": meta,
            "mean_curve": [],
            "var_curve": [],
            "train_losses": [],
            "test_accs": [],
            "backdoor_accs": [],
            "test_losses": [],
            "valid_accs": [],
            "client_accs": {},
            "mean_valid_accs": [],
            "all_selected_clients": [],
        }

    @staticmethod
    def _weighted_mean(weights, values, total):
        """Return ``sum(w * v) / total`` over the paired weights and values."""
        return sum(w * v for w, v in zip(weights, values)) / total

    def log(self, server=None):
        """Evaluate ``server`` for the current round and record the results.

        Does nothing when ``server`` is None. Otherwise the method:

        * runs ``server.test`` (on a CUDA device when ``"mp_"`` occurs in
          ``server.name``) and ``server.test_on_clients`` for the 'valid'
          and 'train' splits,
        * appends the resulting metrics to ``self.output``,
        * prints a formatted summary and appends it to the log file.

        Args:
            server: object exposing ``option``, ``name``, ``test``,
                ``test_on_clients``, ``client_vols``, ``data_vol``,
                ``selected_clients``, ``num_clients`` and ``clients``.
        """
        if server is None:
            return

        if not self.output:
            self.output = self._empty_output(server.option)
        out = self.output

        if "mp_" in server.name:
            test_metric, test_loss, test_backdoor = server.test(device=torch.device('cuda'))
        else:
            test_metric, test_loss, test_backdoor = server.test()

        # Only the validation metrics and the training losses are recorded.
        valid_metrics, _ = server.test_on_clients(self.current_round, 'valid')
        _, train_losses = server.test_on_clients(self.current_round, 'train')

        out['train_losses'].append(
            self._weighted_mean(server.client_vols, train_losses, server.data_vol))
        out['valid_accs'].append(valid_metrics)
        out['test_accs'].append(test_metric)
        out['backdoor_accs'].append(test_backdoor)
        out['test_losses'].append(test_loss)
        out['mean_valid_accs'].append(
            self._weighted_mean(server.client_vols, valid_metrics, server.data_vol))
        out['mean_curve'].append(np.mean(valid_metrics))
        # Despite the key name, this stores the standard deviation.
        out['var_curve'].append(np.std(valid_metrics))
        out['all_selected_clients'].append(
            [int(client_id) for client_id in server.selected_clients])

        # Extend each client's history with this round's value rather than
        # rebuilding every history from all previous rounds on each call.
        for cid in range(server.num_clients):
            out['client_accs'].setdefault(server.clients[cid].name, []).append(
                valid_metrics[cid])

        log_data = [
            self.temp.format("Training Loss:", out['train_losses'][-1]),
            self.temp.format("Testing Loss:", out['test_losses'][-1]),
            self.temp.format("Testing Accuracy:", out['test_accs'][-1]),
            self.temp.format("Backdoor Accuracy:", out['backdoor_accs'][-1]),
            self.temp.format("Validating Accuracy:", out['mean_valid_accs'][-1]),
            self.temp.format("Mean of Client Accuracy:", out['mean_curve'][-1]),
            self.temp.format("Std of Client Accuracy:", out['var_curve'][-1]),
            "Selected clients in this round:",
            str(out['all_selected_clients'][-1]),
        ]

        for line in log_data:
            print(line)

        if self._log_path is None:
            option = flw.read_option()
            self._log_path = (
                f"{option['task']}_training_log_"
                f"{option['num_of_code_words']}Words_{option['K_value']}K.txt"
            )
        with open(self._log_path, "a", encoding="utf-8") as log_file:
            log_file.write("\n".join(log_data) + "\n")

# Module-level logger instance, built once when this module is imported.
# NOTE(review): nothing in this file reads `logger`; presumably the framework
# (utils.fflow) looks it up from here -- confirm before renaming or moving it.
logger = MyLogger()


def main():
    """Configure the process environment and run one federated training job.

    Selects the 'spawn' multiprocessing start method, reads the run options,
    exports the MASTER_ADDR / MASTER_PORT / WORLD_SIZE environment variables,
    seeds the run, builds the server and runs it to completion, then prints
    the server's ``round_selected`` attribute and a completion message.
    """
    multiprocessing.set_start_method('spawn')

    # Run configuration as returned by the framework's option reader.
    option = flw.read_option()

    # NOTE(review): these look like the rendezvous variables used by
    # torch.distributed; nothing in this file reads them -- confirm.
    world_size = 3
    os.environ.update({
        'MASTER_ADDR': "localhost",
        'MASTER_PORT': '8888',
        'WORLD_SIZE': str(world_size),
    })

    # Seed first, then build the server from the same options.
    flw.setup_seed(option['seed'])
    server = flw.initialize(option)

    # Federated optimization happens entirely inside run().
    server.run()

    print(server.round_selected)
    print("Training Done!")


# Start training only when this file is executed as a script. main() selects
# the 'spawn' start method, under which child processes re-import the main
# module, so this guard keeps them from launching training again.
if __name__ == '__main__':
    main()