
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import yaml
import time
from core.test import test_img
from utils.Comm import ComTopo, communication, communication1
from models.SvrgUpdate import LocalUpdate
from utils.options import args_parser
from utils.dataset_normal import load_data
from models.ModelBuilder import build_model
from core.NewClientManage import NewClientManage
from utils.my_logging import Logger
from core.function import assign_hyper_gradient
from torch.optim import SGD
import torch
import torch.nn.init as init
import numpy as np
import copy
import matplotlib.pyplot as plt
import os


# Seed shared by torch and numpy so that data loading/partitioning and model
# initialisation are reproducible from run to run.
num = 0
# Launch timestamp (whole seconds); only used to make the default log filename unique.
start_time = int(time.time())

if __name__ == '__main__':
    args = args_parser()
    torch.manual_seed(num)
    np.random.seed(num)

    dataset_train, dataset_test, dict_users, args.img_size, dataset_train_real = load_data(args)

    # Reference model: every client's networks start as deep copies of it, and
    # it is the template into which client models are averaged for evaluation.
    net_glob = build_model(args)

    # Step-size schedule constants: ck = ck_bar / (1 + round) ** exp (see loop below).
    ck_bar = args.ck_bar
    gama = args.gama
    exp = args.exp

    if args.output is None:
        logs = Logger(f'./save/dc_dec_{args.alg}_{args.topoModel}_{args.dataset}_{args.model}_{args.epochs}_C{args.frac}_iid{args.iid}_'
                      f'tau{args.inner_ep}_blo{not args.no_blo}_ck_bar{ck_bar}_gamma{gama}_exp{exp}_'
                      f'gamma{[args.gamma[0],args.gamma[1],args.gamma[2]]}_{start_time}.yaml')
    else:
        logs = Logger(args.output)

    # One zero-initialised, trainable hyper-parameter vector per client, of
    # length len(dataset_train) / num_users. The `_old` list presumably holds
    # the previous iterate (TODO confirm against NewClientManage).
    hyper_dim = int(len(dataset_train) / args.num_users)
    hyper_params = []
    hyper_params_old = []
    for _ in range(args.num_users):
        hyper_params.append(torch.zeros(hyper_dim, requires_grad=True, device=args.device))
        hyper_params_old.append(torch.zeros(hyper_dim, requires_grad=True, device=args.device))

    # Per-client state. Each client owns four model copies (y, theta and their
    # "old" counterparts). The parameter lists below alias those models'
    # nn.Parameters, plus zero buffers with matching shapes for the direction
    # estimates h_y / h_theta (and their previous values) and v_x.
    net_glob_sum = []
    net_glob_sum_old = []
    net_glob_theta_sum = []
    net_glob_theta_sum_old = []
    params = []
    params_old = []
    thetas = []
    thetas_old = []
    h_y = []
    h_y_old = []
    h_theta = []
    h_theta_old = []
    v_x = []
    for idx in range(args.num_users):
        net_glob_sum.append(copy.deepcopy(net_glob))
        net_glob_sum_old.append(copy.deepcopy(net_glob))
        net_glob_theta_sum.append(copy.deepcopy(net_glob))
        net_glob_theta_sum_old.append(copy.deepcopy(net_glob))

        params.append(list(net_glob_sum[idx].parameters()))
        thetas.append(list(net_glob_theta_sum[idx].parameters()))
        params_old.append(list(net_glob_sum_old[idx].parameters()))
        thetas_old.append(list(net_glob_theta_sum_old[idx].parameters()))

        # All model copies share net_glob's architecture, so params[idx]
        # gives the right shapes for both the y and the theta buffers.
        h_y.append([torch.zeros_like(p) for p in params[idx]])
        h_y_old.append([torch.zeros_like(p) for p in params[idx]])
        h_theta.append([torch.zeros_like(p) for p in params[idx]])
        h_theta_old.append([torch.zeros_like(p) for p in params[idx]])
        # Iterating a 1-D tensor yields 0-dim elements, so v_x[idx] is a list of
        # 0-dim zero tensors, one per hyper-parameter entry.
        v_x.append([torch.zeros_like(x) for x in hyper_params[idx]])

    test_accuracies = []
    comm_round = 0
    ck_1 = 1  # previous round's step size; 1 before the first round
    client_idx = range(args.num_users)  # every client participates in every round
    while comm_round < args.epochs:
        round_start = time.time()

        # Decaying step size for this round and the one before it.
        ck = ck_bar * 1 / ((1+comm_round)**exp)
        if comm_round > 0:
            ck_1 = ck_bar * 1 / ((1+comm_round-1)**exp)

        client_manage = NewClientManage(args, net_glob_sum, net_glob_theta_sum, net_glob_sum_old, net_glob_theta_sum_old, client_idx,
                                        dataset_train, dict_users, hyper_params, params, thetas, hyper_params_old, params_old,
                                        thetas_old, ck, ck_1, gama, v_x)

        if args.alg == 'SUN-HR':
            if comm_round == 0:
                h_y_new, h_theta_new, hyper_params = client_manage.client_job(args.gamma)
            else:
                h_y_new, h_theta_new, hyper_params, h_y_old, h_theta_old, v_x = client_manage.client_job_HR(args.gamma)
        elif args.alg in ['DGD-T', 'DGT-T']:
            h_y_new = client_manage.client_job_sl(args.gamma)
        else:
            h_y_new, h_theta_new, hyper_params = client_manage.client_job(args.gamma)

        # Mixing matrix: the average of the identity and the (possibly
        # round-dependent) topology returned by ComTopo.
        topo = ComTopo(args.num_users, args.topoModel)
        topo_i_w = (np.eye(args.num_users)+topo)/2

        # NOTE(review): 'SUN-HR' (and any unrecognised alg) has no branch here,
        # so h_y / h_theta keep their initial zeros and the parameter step below
        # is a no-op, unless NewClientManage updates the models itself. Confirm.
        if args.alg == 'SUN-SE':
            h_y = h_y_new
            h_theta = h_theta_new

        elif args.alg == 'DGD-T':
            h_y = h_y_new

        elif args.alg == 'SUN-GT':
            for idx in range(args.num_users):
                for i in range(len(h_y[0])):
                    h_y[idx][i] = h_y[idx][i] + h_y_new[idx][i]-h_y_old[idx][i]
                    h_theta[idx][i] = h_theta[idx][i] + h_theta_new[idx][i]-h_theta_old[idx][i]
                    h_y_old[idx][i] = h_y_new[idx][i]
                    h_theta_old[idx][i] = h_theta_new[idx][i]

            # NOTE(review): these assignments discard the tracker values
            # computed in the loop above (only the *_old updates survive).
            # Was communication1(args, h_y, ...) intended? Confirm.
            h_y = communication1(args, h_y_new, topo_i_w)
            h_theta = communication1(args, h_theta_new, topo_i_w)

        elif args.alg == 'DGT-T':
            for idx in range(args.num_users):
                for i in range(len(h_y[0])):
                    h_y[idx][i] = h_y[idx][i] + h_y_new[idx][i]-h_y_old[idx][i]
                    h_y_old[idx][i] = h_y_new[idx][i]

            # NOTE(review): same overwrite as in 'SUN-GT' above. Confirm.
            h_y = communication1(args, h_y_new, topo_i_w)

        elif args.alg == 'DSGDA-GT':
            h_y = communication1(args, h_y_new, topo_i_w)
            h_theta = communication1(args, h_theta_new, topo_i_w)

            for idx in range(args.num_users):
                for i in range(len(h_y[0])):
                    h_y[idx][i] = h_y[idx][i] + h_y_new[idx][i]-h_y_old[idx][i]
                    h_theta[idx][i] = h_theta[idx][i] + h_theta_new[idx][i]-h_theta_old[idx][i]
                    h_y_old[idx][i] = h_y_new[idx][i]
                    h_theta_old[idx][i] = h_theta_new[idx][i]

        comm_round += 1

        # Local step on y and theta with step sizes args.gamma[1] / args.gamma[2].
        # NOTE(review): in round 0 these entries are nn.Parameters, and the
        # arithmetic runs outside torch.no_grad(). Unless communication1
        # detaches its output, the autograd graph may keep growing across rounds.
        for idx in range(args.num_users):
            for i in range(len(params[idx])):
                params[idx][i] = params[idx][i] - (args.gamma[1] * h_y[idx][i])
                thetas[idx][i] = thetas[idx][i] - (args.gamma[2] * h_theta[idx][i])

        # Gossip: mix every client's y and theta with its neighbours' via topo_i_w.
        params = communication1(args, params, topo_i_w)
        thetas = communication1(args, thetas, topo_i_w)

        # Write the mixed y parameters back into the client models.
        # NOTE(review): thetas are never written back into net_glob_theta_sum;
        # confirm that NewClientManage reads them from the `thetas` list.
        for idx in range(args.num_users):
            for p, new_value in zip(net_glob_sum[idx].parameters(), params[idx]):
                p.data = new_value

        roundtime = time.time() - round_start

        # Evaluate the uniform average of the clients' y models. Only
        # parameters are averaged; buffers keep net_glob's values.
        net_glob_avg = copy.deepcopy(net_glob)
        avg_params = list(net_glob_avg.parameters())
        with torch.no_grad():
            for avg_p in avg_params:
                avg_p.data.zero_()
        for idx in range(args.num_users):
            for avg_p, user_p in zip(avg_params, net_glob_sum[idx].parameters()):
                avg_p.data += user_p.data / args.num_users

        net_glob_avg.eval()
        acc_train, loss_train = test_img(net_glob_avg, dataset_train_real, args)
        acc_test, loss_test = test_img(net_glob_avg, dataset_test, args)
        print("Test acc/loss: {:.2f} {:.6f}".format(acc_test, loss_test),
              "Train acc/loss: {:.2f} {:.6f}".format(acc_train, loss_train),
              f"Comm round: {comm_round}", "time: {:.2f}s".format(roundtime))
        # Kept for plotting.
        test_accuracies.append(acc_test)

        logs.logging(acc_test, acc_train, loss_test, loss_train, comm_round, roundtime)
        logs.save()

        # Optional early stop once more than args.round rounds have completed.
        if args.round > 0 and comm_round > args.round:
            break

       