
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import yaml
import time
from core.test import test_img
from utils.Fed import FedAvg, FedAvgGradient
from models.SvrgUpdate import LocalUpdate
from utils.options import args_parser
from utils.dataset_normal import load_data
from models.ModelBuilder import build_model
from core.NewClientManage import NewClientManage
from utils.my_logging import Logger
from core.function import assign_hyper_gradient
from torch.optim import SGD
import torch
import torch.nn.init as init
import numpy as np
import copy

# Hard upper bound on communication rounds for a single seeded run.
MAX_COMM_ROUNDS = 500
# Number of independent runs; each gets its own seed drawn in ``main``.
NUM_SEEDS = 10


def server_step(params, client_updates, step_size, inner_ep):
    """Average per-client updates and apply one server-side descent step.

    Args:
        params: list of tensors holding the current global values.
        client_updates: one list per client; each list holds update tensors
            aligned index-by-index with ``params``.
        step_size: server learning rate applied to the averaged update.
        inner_ep: number of local inner steps; the step is divided by it.

    Returns:
        A new list of tensors ``p - (step_size * mean_update) / inner_ep``.
        Neither ``params`` nor ``client_updates`` is modified.

    Raises:
        ValueError: if ``client_updates`` is empty.
    """
    if not client_updates:
        raise ValueError("client_updates must contain at least one client")
    num_clients = len(client_updates)
    new_params = []
    # The result is written straight into ``.data`` by the caller, so no
    # autograd graph is needed here.
    with torch.no_grad():
        for i, p in enumerate(params):
            # Sum in client order (same order as a sequential ``+=``), but
            # without mutating the first client's tensor.
            total = client_updates[0][i]
            for update in client_updates[1:]:
                total = total + update[i]
            mean_update = total / num_clients
            new_params.append(p - (step_size * mean_update) / inner_ep)
    return new_params


def run_seed(num, start_time):
    """Run one complete federated experiment seeded with ``num``.

    Parses the command-line options, loads data, builds the two models
    (``net_glob`` and its copy ``net_glob_theta``) and a Logger, then runs up
    to ``MAX_COMM_ROUNDS`` communication rounds, logging train/test accuracy
    and loss after every round.

    Args:
        num: seed applied to both the torch and numpy global RNGs.
        start_time: integer timestamp shared by all seeds of this invocation;
            used in the default log filename.

    Returns:
        List of test accuracies, one per completed round.
    """
    args = args_parser()
    torch.manual_seed(num)
    np.random.seed(num)

    # The image size returned by load_data is stored on ``args`` before
    # build_model(args) is called (presumably build_model reads it).
    dataset_train, dataset_test, dict_users, args.img_size, dataset_train_real = load_data(args)

    net_glob = build_model(args)
    net_glob_theta = copy.deepcopy(net_glob)

    ck_bar = 2
    gamma = 0.05
    expc = 0.01
    if args.output is None:
        # The seed is part of the name so the runs of one invocation do not
        # overwrite each other's log file.
        logs = Logger(f'./save/dc_cr_0.05_fed_{args.dataset}_{args.model}_{args.epochs}_C{args.frac}_iid{args.iid}_'
                      f'tau{args.inner_ep}_blo{not args.no_blo}_ck_bar{ck_bar}_gamma{gamma}_exp{expc}_eta{[args.eta[0],args.eta[1],args.eta[2]]}_'
                      f'gamma{[args.gamma[0],args.gamma[1],args.gamma[2]]}_seed{num}_{start_time}.yaml')
    else:
        # NOTE(review): every seed is given this same path; depending on how
        # Logger saves, later seeds may overwrite earlier ones — confirm.
        logs = Logger(args.output)

    # One hyper-parameter tensor per user, each an independent copy of the
    # same zero initialisation.
    hyper_param = torch.zeros(int(0.5 * len(dataset_train) * args.frac), requires_grad=True, device=args.device)
    hyper_params = [copy.deepcopy(hyper_param) for _ in range(args.num_users)]

    # Number of clients sampled each round (at least one).
    m = max(int(args.frac * args.num_users), 1)

    test_accuracies = []
    comm_round = 0
    while comm_round < MAX_COMM_ROUNDS:
        param = list(net_glob.parameters())
        theta = list(net_glob_theta.parameters())

        # Per-round timer, kept separate from ``start_time`` (the log name).
        round_start = time.time()
        # Decaying schedule handed to the clients: ck_bar / (1 + k) ** expc.
        ck = ck_bar / (1 + comm_round) ** expc

        client_idx = np.random.choice(range(args.num_users), m, replace=False)
        client_manage = NewClientManage(args, net_glob, net_glob_theta, client_idx, dataset_train, dict_users,
                                        hyper_params, param, theta, gamma, ck)
        # One update list per sampled client, aligned with ``param`` / ``theta``.
        h_y, h_theta = client_manage.client_job(args.eta)

        comm_round += 1

        # Server: average the client updates and step both parameter sets.
        param = server_step(param, h_y, args.gamma[0], args.inner_ep)
        theta = server_step(theta, h_theta, args.gamma[1], args.inner_ep)
        for p, new_value in zip(net_glob.parameters(), param):
            p.data = new_value
        for p, new_value in zip(net_glob_theta.parameters(), theta):
            p.data = new_value

        roundtime = time.time() - round_start

        # testing
        net_glob.eval()
        acc_train, loss_train = test_img(net_glob, dataset_train_real, args)
        acc_test, loss_test = test_img(net_glob, dataset_test, args)
        # Switch back so the next client phase starts in training mode, as
        # the first round does.
        net_glob.train()
        print("Test acc/loss: {:.2f} {:.6f}".format(acc_test, loss_test),
              "Train acc/loss: {:.2f} {:.6f}".format(acc_train, loss_train),
              f"Comm round: {comm_round}", "time: {:.2f}s".format(roundtime))
        test_accuracies.append(acc_test)

        logs.logging(client_idx, acc_test, acc_train, loss_test, loss_train, comm_round, roundtime)
        logs.save()

        # A positive args.round stops the run after exactly args.round rounds.
        if args.round > 0 and comm_round >= args.round:
            break

    return test_accuracies


def main():
    """Run the experiment once per seed; the seeds themselves are reproducible."""
    np.random.seed(0)
    seeds = np.random.choice(range(0, 100), NUM_SEEDS, replace=False)
    start_time = int(time.time())
    for num in seeds:
        run_seed(num, start_time)


if __name__ == '__main__':
    main()
        
