#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import random

import yaml
import time
from core.test import test_img
from utils.Fed import FedAvg, FedAvgGradient
from models.SvrgUpdate import LocalUpdate
from utils.options import args_parser
from utils.dataset_normal import load_data
from models.ModelBuilder import build_model
from core.ClientManage_hr import ClientManageHR
from utils.my_logging import Logger
from core.function import assign_hyper_gradient
from torch.optim import SGD
import matplotlib.pyplot as plt
import torch

import numpy as np
import copy

start_time = int(time.time())


def _average_client_updates(client_updates, num_clients, num_tensors):
    """Average per-tensor client updates without mutating the client results.

    Args:
        client_updates: indexable of per-client sequences; ``client_updates[j][i]``
            is client ``j``'s update for tensor ``i``.
        num_clients: number of clients to average over (entries ``0..num_clients-1``).
        num_tensors: number of leading tensors per client to average.

    Returns:
        A list whose ``i``-th entry is the mean of ``client_updates[j][i]`` over
        ``j in range(num_clients)``.
    """
    return [
        sum(client_updates[j][i] for j in range(num_clients)) / num_clients
        for i in range(num_tensors)
    ]


def _server_step(tensors, updates, lr, inner_ep):
    """Return ``t - (lr * u) / inner_ep`` for each (tensor, update) pair.

    The arithmetic runs under ``torch.no_grad()`` so successive rounds do not
    chain an ever-growing autograd graph onto the previous round's tensors.
    Each result keeps the ``requires_grad`` flag of the tensor it replaces.
    """
    with torch.no_grad():
        stepped = [t - (lr * u) / inner_ep for t, u in zip(tensors, updates)]
    return [s.requires_grad_(t.requires_grad) for s, t in zip(stepped, tensors)]


if __name__ == '__main__':
    # Seed every RNG module the script imports so runs are reproducible.
    random.seed(0)
    torch.manual_seed(0)
    np.random.seed(0)
    test_accuracies = []

    args = args_parser()
    dataset_train, dataset_test, dict_users, args.img_size, dataset_train_real = load_data(args)
    net_glob = build_model(args)

    if args.output is None:
        logs = Logger(f'./save/hr_fed_{args.dataset}_{args.model}_{args.epochs}_C{args.frac}_iid{args.iid}_'
                      f'tau{args.inner_ep}_blo{not args.no_blo}_eta{[args.eta[0],args.eta[1],args.eta[2]]}_'
                      f'gamma{[args.gamma[0],args.gamma[1],args.gamma[2]]}_{start_time}.yaml')
    else:
        logs = Logger(args.output)

    # Outer (hyper) parameters: every parameter whose name lacks "header".
    # Inner parameters: the "header" parameters. Both keep named_parameters() order.
    hyper_param = [k for n, k in net_glob.named_parameters() if "header" not in n]
    param = [k for n, k in net_glob.named_parameters() if "header" in n]
    # One uniformly random tensor per inner parameter (same shape), moved to args.device.
    v = [torch.rand(k.shape).to(args.device) for k in param]

    # Number of clients sampled per round (at least one); invariant across rounds.
    num_clients = max(int(args.frac * args.num_users), 1)
    comm_round = 0

    # NOTE(review): `<= 2000` permits up to 2001 rounds, and the args.round check
    # below stops after args.round + 1 rounds — confirm the extra round is intended.
    while comm_round <= 2000:
        # Sample distinct clients for this round.
        client_idx = np.random.choice(range(args.num_users), num_clients, replace=False)
        client_manage = ClientManageHR(args, net_glob, client_idx, dataset_train, dict_users, hyper_param, param, v)

        # Clients: per-client update lists for param (h_y), v (h_v) and hyper_param (h_x).
        h_y, h_v, h_x = client_manage.client_job(args.eta)
        comm_round += 1

        # Server: average the client updates, then take one step per variable group.
        h_y_final = _average_client_updates(h_y, num_clients, len(param))
        h_v_final = _average_client_updates(h_v, num_clients, len(param))
        h_x_final = _average_client_updates(h_x, num_clients, len(hyper_param))
        # Slice-assign so the list objects themselves are preserved across rounds.
        param[:] = _server_step(param, h_y_final, args.gamma[0], args.inner_ep)
        v[:] = _server_step(v, h_v_final, args.gamma[1], args.inner_ep)
        hyper_param[:] = _server_step(hyper_param, h_x_final, args.gamma[2], args.inner_ep)

        # Write the updated tensors back into the model by name, using the same
        # "header" split that built the lists (no assumption about parameter order).
        hyper_iter, param_iter = iter(hyper_param), iter(param)
        for name, p in net_glob.named_parameters():
            p.data = next(param_iter) if "header" in name else next(hyper_iter)

        # Evaluate on the full training set and the test set.
        net_glob.eval()
        acc_train, loss_train = test_img(net_glob, dataset_train_real, args)
        acc_test, loss_test = test_img(net_glob, dataset_test, args)
        print("Test acc/loss: {:.2f} {:.2f}".format(acc_test, loss_test),
              "Train acc/loss: {:.2f} {:.2f}".format(acc_train, loss_train),
              f"Comm round: {comm_round}")
        test_accuracies.append(acc_test)

        logs.logging(client_idx, acc_test, acc_train, loss_test, loss_train, comm_round)
        logs.save()

        if args.round > 0 and comm_round > args.round:
            break

    # Plot test accuracy per communication round.
    plt.figure()
    plt.plot(range(len(test_accuracies)), test_accuracies)
    plt.ylabel('test_accuracies')
    plt.savefig(
        './save/fed_{}_{}_{}_C{}_iid{}_{}_{}_{}.png'.format(args.dataset, args.model, comm_round, args.frac, args.iid, args.gamma[0],args.gamma[1],args.gamma[2]))
