#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import random

import yaml
import time
from core.test import test_img
from utils.Comm import communication, communication1, ComTopo
from models.SvrgUpdate import LocalUpdate
from utils.options import args_parser
from utils.dataset_normal import load_data
from models.ModelBuilder import build_model
from core.ClientManage_hr import ClientManageHR
from utils.my_logging import Logger
from core.function import assign_hyper_gradient
from torch.optim import SGD
import matplotlib.pyplot as plt
import torch

import numpy as np
import copy
import os


# Wall-clock timestamp captured at import time; embedded in the default log file name.
start_time = int(time.time())


def _zeros_like_list(tensors):
    """Return a new list holding one zero tensor shaped like each entry of ``tensors``."""
    return [torch.zeros_like(t) for t in tensors]


def _split_parameters(model):
    """Split ``model``'s parameters by name into ``(backbone, header)`` lists.

    ``header`` receives every parameter whose name contains ``"header"``;
    ``backbone`` receives the rest. Both lists keep ``named_parameters()``
    order and hold the model's own Parameter objects (not copies).
    """
    backbone, header = [], []
    for name, p in model.named_parameters():
        (header if "header" in name else backbone).append(p)
    return backbone, header


def _assign_parameters(model, backbone, header):
    """Overwrite ``model``'s parameter data with ``backbone`` followed by ``header``.

    The i-th tensor yielded by ``model.parameters()`` takes ``backbone[i]``
    while ``i < len(backbone)`` and ``header[i - len(backbone)]`` afterwards.
    NOTE(review): this assumes every non-"header" parameter is registered
    before every "header" parameter in the model; confirm for each model
    returned by ``build_model``.
    """
    n_backbone = len(backbone)
    for pos, p in enumerate(model.parameters()):
        p.data = backbone[pos] if pos < n_backbone else header[pos - n_backbone]


def _average_models(template, models):
    """Return a deep copy of ``template`` whose parameters are the mean of ``models``' parameters."""
    avg = copy.deepcopy(template)
    with torch.no_grad():
        for p in avg.parameters():
            p.data.zero_()
    avg_params = list(avg.parameters())
    n_models = len(models)
    for model in models:
        for avg_p, p in zip(avg_params, model.parameters()):
            avg_p.data += p.data / n_models
    return avg


# Decentralized training driver.
#
# Every client keeps two model replicas (plus "_old" copies of each), which
# share their non-"header" parameters (``hyper_params``) but have separate
# "header" parameters (``params`` and ``thetas``). Each round:
#   1. ClientManageHR produces per-client estimates h_x / h_y / h_theta;
#   2. they are combined according to ``args.alg``;
#   3. every tensor takes a step ``t <- t - gamma * h``;
#   4. tensors are mixed with neighbours via ``communication1`` using the
#      matrix (I + ComTopo) / 2;
#   5. the average of the clients' models is evaluated and logged.
if __name__ == '__main__':
    num = 0
    # Seed every RNG source imported by this script for reproducible runs.
    random.seed(num)
    torch.manual_seed(num)
    np.random.seed(num)
    test_accuracies = []
    ck_bar = 2
    # parse param
    args = args_parser()
    dataset_train, dataset_test, dict_users, args.img_size, dataset_train_real = load_data(args)
    net_glob = build_model(args)
    gama = 50
    exp = 0.001
    if args.output is None:
        logs = Logger(f'./save/hr_dec_{args.alg}_{args.topoModel}_{args.dataset}_{args.model}_{args.epochs}_C{args.frac}_iid{args.iid}_alg_{args.alg}_topo_{args.topoModel}_'
                        f'ck_bar{ck_bar}_gama{gama}_exp{exp}_gamma{[args.gamma[0],args.gamma[1],args.gamma[2]]}_'
                        f'beta{[args.beta[0],args.beta[1],args.beta[2]]}_{start_time}.yaml')
    else:
        logs = Logger(args.output)

    comm_round = 0
    ck_1 = 0  # previous-round ck; stays 0 for the first round
    client_idx = range(args.num_users)

    # Per-client model replicas. The *_sum replicas are updated every round;
    # the *_old replicas are only handed to ClientManageHR.
    net_glob_sum, net_glob_sum_old = [], []
    net_glob_theta_sum, net_glob_theta_sum_old = [], []
    # Per-client tensor lists: hyper_params = non-"header" tensors,
    # params / thetas = "header" tensors of the two replica families.
    hyper_params, params, thetas = [], [], []
    hyper_params_old, params_old, thetas_old = [], [], []
    # Per-client estimator state, zero-initialised below.
    h_x, h_x_old, v_x, v_x_new = [], [], [], []
    h_y, h_y_old, v_y, v_y_new = [], [], [], []
    h_theta, h_theta_old, v_theta, v_theta_new = [], [], [], []

    for _ in range(args.num_users):
        for replicas in (net_glob_sum, net_glob_sum_old, net_glob_theta_sum, net_glob_theta_sum_old):
            replicas.append(copy.deepcopy(net_glob))

        hyper_param, param = _split_parameters(net_glob_sum[-1])
        hyper_params.append(hyper_param)
        params.append(param)
        thetas.append(_split_parameters(net_glob_theta_sum[-1])[1])

        hyper_param_old, param_old = _split_parameters(net_glob_sum_old[-1])
        hyper_params_old.append(hyper_param_old)
        params_old.append(param_old)
        thetas_old.append(_split_parameters(net_glob_theta_sum_old[-1])[1])

        # Header-shaped state (y / theta estimators) and backbone-shaped state (x).
        for state in (v_y, v_y_new, v_theta, v_theta_new, h_y, h_y_old, h_theta, h_theta_old):
            state.append(_zeros_like_list(param))
        for state in (h_x, h_x_old, v_x, v_x_new):
            state.append(_zeros_like_list(hyper_param))

    # All clients share one architecture, so the first client's lengths apply to all.
    n_hyper = len(hyper_params[0])
    n_param = len(params[0])

    while comm_round < args.epochs:
        round_start = time.time()
        ck = ck_bar * 1 / ((1+comm_round)**exp)
        if comm_round > 0:
            ck_1 = ck_bar * 1 / ((1+comm_round-1)**exp)

        client_manage = ClientManageHR(args, net_glob_sum, net_glob_theta_sum, net_glob_sum_old, net_glob_theta_sum_old, client_idx,
                                        dataset_train, dict_users, hyper_params, params, thetas, hyper_params_old, params_old,
                                        thetas_old, ck, ck_1, gama)

        # client do
        if args.alg == 'SUN-HR' and comm_round > 0:
            h_y_new, h_theta_new, h_x_new, h_y_old, h_theta_old, h_x_old = client_manage.client_job_HR(args)
        else:
            h_y_new, h_theta_new, h_x_new = client_manage.client_job(args)

        topo = ComTopo(args.num_users, args.topoModel)
        topo_i_w = (np.eye(args.num_users)+topo)/2

        if args.alg == 'SUN-SE':
            h_x = h_x_new
            h_y = h_y_new
            h_theta = h_theta_new

        elif args.alg == 'SUN-GT':
            for idx in range(args.num_users):
                for i in range(n_param):
                    h_y[idx][i] = h_y[idx][i] + h_y_new[idx][i]-h_y_old[idx][i]
                    h_theta[idx][i] = h_theta[idx][i] + h_theta_new[idx][i]-h_theta_old[idx][i]
                    h_y_old[idx][i] = h_y_new[idx][i]
                    h_theta_old[idx][i] = h_theta_new[idx][i]
            for idx in range(args.num_users):
                for i in range(n_hyper):
                    h_x[idx][i] = h_x[idx][i] + h_x_new[idx][i]-h_x_old[idx][i]
                    h_x_old[idx][i] = h_x_new[idx][i]
            # NOTE(review): the tracked h_y / h_theta / h_x computed above are
            # immediately replaced by mixing the *raw* new estimates, so only
            # the *_old bookkeeping survives. Confirm whether communication1
            # should be applied to h_y / h_theta / h_x instead.
            h_y = communication1(args, h_y_new, topo_i_w)
            h_theta = communication1(args, h_theta_new, topo_i_w)
            h_x = communication1(args, h_x_new, topo_i_w)

        elif args.alg == 'DSGDA-GT':
            h_y = communication1(args, h_y_new, topo_i_w)
            h_theta = communication1(args, h_theta_new, topo_i_w)
            h_x = communication1(args, h_x_new, topo_i_w)
            for idx in range(args.num_users):
                for i in range(n_param):
                    h_y[idx][i] = h_y[idx][i] + h_y_new[idx][i]-h_y_old[idx][i]
                    h_theta[idx][i] = h_theta[idx][i] + h_theta_new[idx][i]-h_theta_old[idx][i]
                    h_y_old[idx][i] = h_y_new[idx][i]
                    h_theta_old[idx][i] = h_theta_new[idx][i]
            for idx in range(args.num_users):
                for i in range(n_hyper):
                    h_x[idx][i] = h_x[idx][i] + h_x_new[idx][i]-h_x_old[idx][i]
                    h_x_old[idx][i] = h_x_new[idx][i]

        elif args.alg == 'SUN-HR':
            # v_new = h_new + (1 - beta) * (v - h_old), per block (x, y, theta).
            for idx in range(args.num_users):
                for i in range(n_hyper):
                    v_x_new[idx][i] = h_x_new[idx][i] + (1-args.beta[0]) * (v_x[idx][i]-h_x_old[idx][i])
                for i in range(n_param):
                    v_y_new[idx][i] = h_y_new[idx][i] + (1-args.beta[1]) * (v_y[idx][i]-h_y_old[idx][i])
                    v_theta_new[idx][i] = h_theta_new[idx][i] + (1-args.beta[2]) * (v_theta[idx][i]-h_theta_old[idx][i])
            # Track the change of v: h <- h + v_new - v.
            for idx in range(args.num_users):
                for i in range(n_hyper):
                    h_x[idx][i] = h_x[idx][i] + v_x_new[idx][i]-v_x[idx][i]
                for i in range(n_param):
                    h_y[idx][i] = h_y[idx][i] + v_y_new[idx][i]-v_y[idx][i]
                    h_theta[idx][i] = h_theta[idx][i] + v_theta_new[idx][i]-v_theta[idx][i]

            h_y = communication1(args, h_y, topo_i_w)
            h_theta = communication1(args, h_theta, topo_i_w)
            h_x = communication1(args, h_x, topo_i_w)
            # NOTE(review): v_x / v_y / v_theta are never advanced to v_*_new,
            # so they keep their zero initialisation every round. A plain
            # rebinding (v_x = v_x_new) would alias the lists that are
            # overwritten in place above, so a fix needs per-tensor copies;
            # confirm the intended recursion before changing it.

        comm_round += 1

        # Peak GPU memory report. Only meaningful (and only valid) on a CUDA device.
        mem_device = torch.device(str(args.device))
        if mem_device.type == 'cuda':
            current_memory = torch.cuda.max_memory_allocated(mem_device) / (1024**2)
            print('===================================current_memor:', current_memory)

        # Local step on every client: t <- t - gamma * h.
        for idx in range(args.num_users):
            for i in range(n_param):
                params[idx][i] = params[idx][i] - (args.gamma[1] * h_y[idx][i])
                thetas[idx][i] = thetas[idx][i] - (args.gamma[2] * h_theta[idx][i])

        for idx in range(args.num_users):
            for i in range(n_hyper):
                hyper_params[idx][i] = hyper_params[idx][i] - (args.gamma[0] * h_x[idx][i])

        # Gossip the updated tensors with neighbours.
        params = communication1(args, params, topo_i_w)
        hyper_params = communication1(args, hyper_params, topo_i_w)
        thetas = communication1(args, thetas, topo_i_w)

        # Write the mixed tensors back into both replica families.
        for idx in range(args.num_users):
            _assign_parameters(net_glob_sum[idx], hyper_params[idx], params[idx])
            _assign_parameters(net_glob_theta_sum[idx], hyper_params[idx], thetas[idx])

        roundtime = time.time() - round_start

        # testing: evaluate the parameter-wise average of the clients' models
        net_glob_avg = _average_models(net_glob, net_glob_sum)
        net_glob_avg.eval()
        acc_train, loss_train, F1_score_train = test_img(net_glob_avg, dataset_train_real, args)
        acc_test, loss_test, F1_score_test = test_img(net_glob_avg, dataset_test, args)
        print("Test acc/loss: {:.2f} {:.6f}".format(acc_test, loss_test),
                "Train acc/loss: {:.2f} {:.6f}".format(acc_train, loss_train),
                "Train F1 score:{:.2f}".format(F1_score_train),
                "Test F1 score:{:.2f}".format(F1_score_test),
                f"Comm round: {comm_round}", "time: {:.2f}s".format(roundtime))
        test_accuracies.append(acc_test)

        logs.logging(acc_test, acc_train, loss_test, loss_train, F1_score_test, F1_score_train, comm_round, roundtime)
        logs.save()

        # NOTE(review): comm_round was already incremented above, so this stops
        # after args.round + 1 rounds; confirm whether `>=` was intended.
        if args.round > 0 and comm_round > args.round:
            break