import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import json
import random
import time

import torch
import dataset
from server import Server
from client import Client
import pandas as pd
import matplotlib.pyplot as plt


if __name__ == '__main__':
    # 存储容器，用于绘制图像
    accs = []  # 存放准确率
    losses = []  # 存放损失

    # 载入配置文件
    with open('config.json', 'r') as f:
        conf = json.load(f)

    train_datasets, eval_datasets, client_idcs = dataset.get_dataset('../data/', conf['type'])  # 获取训练数据和测试数据
    server = Server(conf, eval_datasets)  # 创建服务器
    clients = []  # 客户端列表
    for c in range(conf['no_models']):  # 创建客户端
        clients.append(Client(conf, server.global_model, train_datasets, client_idcs[c], c))

    for e in range(conf['global_rounds']):  # 进行全局轮数
        candidates = random.sample(clients, conf['k'])  # 随机选取k个客户端
        weight_accumulator = {}  # 创建计算好的参数字典，其值为本地模型计算的变化量之和
        for name, params in server.global_model.state_dict().items():
            weight_accumulator[name] = torch.zeros_like(params)  # 初始化上面的参数字典，大小和全局模型相同

        for c in candidates:  # 逐一获取本地模型更新的变化量并进行累加
            diff = c.local_train(server.global_model)  # 进行本地训练并计算差值字典
            # if e >= 1:
            #     print(diff)
            for name, params in server.global_model.state_dict().items():
                weight_accumulator[name].add_(diff[name])

        server.model_aggregrate(weight_accumulator)  # 模型聚合
        acc, loss = server.model_eval()  # 进行全局模型测试
        accs.append(acc)
        losses.append(loss)
        print('全局模型：第{}轮完成！准确率：{:.2f} loss: {:.2f}'.format(e, acc, loss))

    # 将准确率信息存储在txt文件中用于绘图
    curTime = time.strftime('%Y%m%d_%H%M%S');
    plt.plot([i for i in range(len(accs))], accs, label='Acc')
    plt.legend()
    plt.xlabel('Global Rounds')
    plt.ylabel('Accuracy')
    plt.title("Full FedAvg_momentum 0.0001,Drichlet alpha=0.1 10/10")
    plt.savefig("./results/{}.jpg".format(curTime))
    plt.show()
    df = pd.DataFrame([accs, losses])  # 计入表格
    df.to_csv("./results/data_{}.csv".format(curTime))  # 存入文件并加上时间戳
