
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import yaml
import time
from core.test import test_img
from utils.Fed import FedAvg, FedAvgGradient
from models.SvrgUpdate import LocalUpdate
from utils.options import args_parser
from utils.dataset import load_data
from models.ModelBuilder import build_model
from core.NewClientManage import NewClientManage
from utils.my_logging import Logger
from core.function import assign_hyper_gradient
from torch.optim import SGD
import torch

import numpy as np
import copy

# Integer wall-clock stamp taken once at startup; embedded in every seed's
# default log file name so runs from one invocation can be grouped.
start_time = int(time.time())
np.random.seed(0)
# Ten distinct seeds drawn from [0, 100); each one drives a full training run.
nums = np.random.choice(range(0, 100), 10, replace=False)


def init_hyper_params(num_classes, device):
    """Build the initial hyper-parameter dict for one run.

    Args:
        num_classes: length of the learnable ``'dy'`` and ``'ly'`` vectors.
        device: torch device the tensors are created on.

    Returns:
        dict with keys ``'dy'`` and ``'ly'`` (zero leaf tensors with
        ``requires_grad=True``) and ``'wy'`` (a fixed float32 weight vector,
        no grad).

    ``'wy'`` holds ``mu**-i`` for ``i = 0..9`` with ``mu = 0.1**(1/9)``, so
    its last entry is 10x its first. It is L2-normalised.
    """
    mu = 0.1 ** (1 / 9)
    # NOTE(review): 'wy' always has 10 entries regardless of num_classes;
    # confirm this script is only used with 10-class datasets.
    probability = np.array([mu ** -i for i in range(0, 10)])
    wy = probability / np.linalg.norm(probability)
    return {
        'dy': torch.zeros(num_classes, requires_grad=True, device=device),
        'ly': torch.zeros(num_classes, requires_grad=True, device=device),
        'wy': torch.tensor(wy, device=device, dtype=torch.float32),
    }


def average_client_updates(per_client, m, n_slots):
    """Return the element-wise mean of the first ``m`` clients' update lists.

    ``per_client[j][i]`` is client ``j``'s update for slot ``i``. Only slots
    ``0 .. n_slots-1`` are averaged. The input tensors are not modified
    (the previous in-place ``+=`` reduction overwrote client 0's results).
    """
    return [sum(per_client[j][i] for j in range(m)) / m for i in range(n_slots)]


def hyper_step(snapshot, direction, lr, inner_ep):
    """Take one descent step ``snapshot - lr * direction / inner_ep``.

    The step is computed without autograd tracking. The result is returned as
    a fresh leaf tensor with ``requires_grad=True``, the same kind of tensor
    ``init_hyper_params`` produces, so every round hands the clients
    hyper-parameters of identical kind.
    """
    with torch.no_grad():
        updated = snapshot - (lr * direction) / inner_ep
    return updated.requires_grad_()


def run_seed(num):
    """Run one complete federated training experiment seeded with ``num``.

    Parses the CLI arguments, loads data, builds the model pair
    (``net_glob`` / ``net_glob_theta``), then runs communication rounds. Each
    round samples clients, collects their update directions via
    ``NewClientManage.client_job``, averages them on the server, steps all
    variables, evaluates, and logs.
    """
    args = args_parser()
    torch.manual_seed(num)
    np.random.seed(num)

    dataset_train, dataset_test, dict_users, args.img_size, dataset_train_real = load_data(args)

    net_glob = build_model(args)
    net_glob_theta = copy.deepcopy(net_glob)

    ck_bar = 4
    gamma = 0.015
    exp = 0.1

    if args.output is None:
        # The seed is part of the name so the runs in one invocation do not
        # overwrite each other's logs.
        logs = Logger(f'./save/im_inner3_fed_{args.dataset}_{args.model}_{args.epochs}_C{args.frac}_iid{args.iid}_'
                      f'tau{args.inner_ep}_blo{not args.no_blo}_ck_bar{ck_bar}_gamma{gamma}_exp{exp}_eta{[args.eta[0],args.eta[1],args.eta[2]]}_'
                      f'gamma{[args.gamma[0],args.gamma[1],args.gamma[2]]}_{start_time}_seed{num}.yaml')
    else:
        # NOTE(review): with an explicit --output every seed logs to the same
        # path; confirm whether Logger appends or overwrites.
        logs = Logger(args.output)

    hyper_param = init_hyper_params(args.num_classes, args.device)
    # Clients sampled per round; invariant across rounds.
    m = max(int(args.frac * args.num_users), 1)

    comm_round = 0
    while comm_round <= 500:
        # Captured before the client phase; the server step below is applied
        # to these values.
        param = list(net_glob.parameters())
        theta = list(net_glob_theta.parameters())

        round_start = time.time()
        ck = ck_bar / ((1 + comm_round) ** exp)
        # Distinct clients for this round (no repeats).
        client_idx = np.random.choice(range(args.num_users), m, replace=False)
        # Pre-round snapshot of the learnable hyper-parameters. The server
        # step starts from these values even if the client phase changes
        # hyper_param.
        hyper_snapshot = {k: hyper_param[k].clone().detach() for k in ('dy', 'ly')}

        client_manage = NewClientManage(args, net_glob, net_glob_theta, client_idx, dataset_train,
                                        dict_users, hyper_param, param, theta, gamma, ck)
        # Assumed: each of h_y / h_theta / h_x is a per-client list of
        # per-slot tensors (indexed [client][slot]) — TODO confirm.
        h_y, h_theta, h_x = client_manage.client_job(args.eta)

        comm_round += 1

        # Server: average the client directions and take one step per variable.
        h_y_mean = average_client_updates(h_y, m, len(param))
        h_theta_mean = average_client_updates(h_theta, m, len(param))
        h_x_mean = average_client_updates(h_x, m, len(hyper_param))

        with torch.no_grad():
            new_param = [p - (args.gamma[0] * h) / args.inner_ep for p, h in zip(param, h_y_mean)]
            new_theta = [t - (args.gamma[1] * h) / args.inner_ep for t, h in zip(theta, h_theta_mean)]
        hyper_param['dy'] = hyper_step(hyper_snapshot['dy'], h_x_mean[0], args.gamma[2], args.inner_ep)
        hyper_param['ly'] = hyper_step(hyper_snapshot['ly'], h_x_mean[1], args.gamma[2], args.inner_ep)

        for p, value in zip(net_glob.parameters(), new_param):
            p.data = value
        for p, value in zip(net_glob_theta.parameters(), new_theta):
            p.data = value

        roundtime = time.time() - round_start

        # Evaluation of the y-model on the train and test splits.
        net_glob.eval()
        acc_train, loss_train = test_img(net_glob, dataset_train_real, args)
        acc_test, loss_test = test_img(net_glob, dataset_test, args)
        print("Test acc/loss: {:.2f} {:.6f}".format(acc_test, loss_test),
              "Train acc/loss: {:.2f} {:.6f}".format(acc_train, loss_train),
              f"Comm round: {comm_round}", "time: {:.2f}s".format(roundtime))

        logs.logging(client_idx, acc_test, acc_train, loss_test, loss_train, comm_round, roundtime)
        logs.save()

        if args.round > 0 and comm_round > args.round:
            break


if __name__ == '__main__':
    for num in nums:
        run_seed(num)
        
