#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import random

import yaml
import time
from core.test import test_img
from utils.Fed import FedAvg, FedAvgGradient
from models.SvrgUpdate import LocalUpdate
from utils.options import args_parser
from utils.dataset_normal import load_data
from models.ModelBuilder import build_model
from core.ClientManage_hr import ClientManageHR
from utils.my_logging import Logger
from core.function import assign_hyper_gradient
from torch.optim import SGD
import matplotlib.pyplot as plt
import torch

import numpy as np
import copy


# Wall-clock launch time, truncated to whole seconds.
start_time = int(time.time())

# Draw ten distinct integers from [0, 100) under a fixed generator state, so
# the same list of experiment seeds is produced on every launch.
np.random.seed(0)
nums = np.random.choice(range(100), size=10, replace=False)
# Entry point: run one complete federated training per seed in `nums`.
# The guard encloses the loop, so importing this module starts no training.
if __name__ == '__main__':
    for num in nums:
        # Seed both generators so a run is reproducible for its seed.
        torch.manual_seed(num)
        np.random.seed(num)
        test_accuracies = []

        # Per-seed timestamp for the default log filename. A dedicated name is
        # used so the per-round timer below cannot overwrite it; one stamp per
        # seed keeps the default log files of different seeds distinct.
        run_stamp = int(time.time())

        # parse param
        args = args_parser()
        dataset_train, dataset_test, dict_users, args.img_size, dataset_train_real = load_data(args)
        net_glob = build_model(args)
        net_glob_theta = copy.deepcopy(net_glob)

        # set hyper_params
        ck_bar = 2.7
        gama = 0.09
        exp = 0.00001
        if args.output is None:
            logs = Logger(f'./savegama0.09/hr_fed_{args.dataset}_{args.model}_{args.epochs}_C{args.frac}_iid{args.iid}_'
                          f'tau{args.inner_ep}_blo{not args.no_blo}_ck_bar{ck_bar}_gamma{gama}_exp{exp}_eta{[args.eta[0],args.eta[1],args.eta[2]]}_'
                          f'gamma{[args.gamma[0],args.gamma[1],args.gamma[2]]}_{run_stamp}.yaml')
        else:
            logs = Logger(args.output)

        comm_round = 0

        # Communication rounds: at most 400, or fewer when args.round is
        # positive (see the break at the end of the loop body).
        while comm_round < 400:
            round_start = time.time()

            # Split the global model's parameters by name: everything outside
            # the "header" is the outer variable (hyper_param); the "header"
            # parameters are the inner variable (param).
            named_params = list(net_glob.named_parameters())
            hyper_param = [p for name, p in named_params if "header" not in name]
            param = [p for name, p in named_params if "header" in name]
            # One uniformly random tensor per header parameter, on args.device.
            # NOTE(review): theta is re-drawn every round rather than read from
            # net_glob_theta — confirm this is intended.
            theta = [torch.rand(p.shape).to(args.device) for p in param]

            # number of clients taking part in this round
            m = max(int(args.frac * args.num_users), 1)

            # Coefficient decaying with the round index: ck_bar / (1 + k)^exp.
            ck = ck_bar / ((1 + comm_round) ** exp)

            # Sample m distinct clients.
            client_idx = np.random.choice(range(args.num_users), m, replace=False)

            # Snapshot both models' state around the manager construction and
            # load it back afterwards.
            # NOTE(review): state_dict() normally shares storage with the live
            # parameters, so this restore may be a no-op — confirm.
            state_dict_net = net_glob.state_dict()
            state_dict_net_theta = net_glob_theta.state_dict()

            client_manage = ClientManageHR(args, net_glob, net_glob_theta, client_idx, dataset_train, dict_users, hyper_param, param, theta, ck, gama)
            net_glob.load_state_dict(state_dict_net)
            net_glob_theta.load_state_dict(state_dict_net_theta)

            # client do
            # Each result is indexed below as [client][parameter index].
            h_y, h_theta, h_x = client_manage.client_job(args.eta)
            comm_round += 1

            # server do
            # Average each direction over the m clients (accumulating in place
            # into client 0's entry) and take one step per variable, scaled by
            # the matching args.gamma entry and divided by args.inner_ep.
            # no_grad keeps this bookkeeping out of the autograd graph.
            with torch.no_grad():
                for i in range(len(param)):
                    for j in range(1, m):
                        h_y[0][i] += h_y[j][i]
                        h_theta[0][i] += h_theta[j][i]
                    mean_h_y = h_y[0][i] / m
                    param[i] = param[i] - (args.gamma[0] * mean_h_y) / args.inner_ep
                    mean_h_theta = h_theta[0][i] / m
                    theta[i] = theta[i] - (args.gamma[1] * mean_h_theta) / args.inner_ep

                for i in range(len(hyper_param)):
                    for j in range(1, m):
                        h_x[0][i] += h_x[j][i]
                    mean_h_x = h_x[0][i] / m
                    hyper_param[i] = hyper_param[i] - (args.gamma[2] * mean_h_x) / args.inner_ep

            # Write the updated tensors back into both models. Tensors are
            # matched to parameters by name ("header" or not), so the result
            # does not depend on header parameters being last in the model.
            # Both models receive the same hyper_param tensors; net_glob takes
            # param for its header and net_glob_theta takes theta.
            for net, header_values in ((net_glob, param), (net_glob_theta, theta)):
                body_iter = iter(hyper_param)
                header_iter = iter(header_values)
                for name, p in net.named_parameters():
                    p.data = next(header_iter) if "header" in name else next(body_iter)

            roundtime = time.time() - round_start

            # testing
            net_glob.eval()
            acc_train, loss_train, F1_score_train = test_img(net_glob, dataset_train_real, args)
            acc_test, loss_test, F1_score_test = test_img(net_glob, dataset_test, args)
            print("Test acc/loss: {:.2f} {:.6f}".format(acc_test, loss_test),
                  "Train acc/loss: {:.2f} {:.6f}".format(acc_train, loss_train),
                  "Train F1 score:{:.2f}".format(F1_score_train),
                  "Test F1 score:{:.2f}".format(F1_score_test),
                  f"Comm round: {comm_round}", "time: {:.2f}s".format(roundtime))

            # Collected for the accuracy curve drawn after the last round.
            test_accuracies.append(acc_test)

            logs.logging(client_idx, acc_test, acc_train, loss_test, loss_train, F1_score_test, F1_score_train, comm_round, roundtime)
            logs.save()

            # Optional early stop once more than args.round rounds have run.
            if args.round > 0 and comm_round > args.round:
                break

        # Plot test accuracy per round for this seed and save it.
        # NOTE(review): the file name does not contain the seed, so each seed
        # overwrites the previous seed's figure — confirm whether intended.
        fig = plt.figure()
        plt.plot(range(len(test_accuracies)), test_accuracies)
        plt.ylabel('test_accuracies')
        plt.savefig(
            './save/fed_{}_{}_{}_C{}_iid{}_{}_{}_{}.png'.format(args.dataset, args.model, comm_round, args.frac, args.iid, args.gamma[0],args.gamma[1],args.gamma[2]))
        # Release the figure so figures do not pile up across seeds.
        plt.close(fig)
