#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import random

import yaml
import time
from core.test import test_img
from utils.Comm import communication, communication1, ComTopo
from models.SvrgUpdate import LocalUpdate
from utils.options import args_parser
from utils.dataset_normal import load_data
from models.ModelBuilder import build_model
from core.ClientManage_hr import ClientManageHR
from utils.my_logging import Logger
from core.function import assign_hyper_gradient
from torch.optim import SGD
import matplotlib.pyplot as plt
import torch

import numpy as np
import copy
import os

# Reproducibly draw ten distinct candidate seeds from 0..99 and print them.
# The draw is informational only: the sweep below is pinned to the single
# hand-picked seed stored in `nums`.
np.random.seed(0)
candidate_seeds = np.random.choice(np.arange(100), size=10, replace=False)
print(candidate_seeds)
nums = [55]  # seeds actually iterated over by the sweep
# Adjust algorithm-related coefficients
# Hyper-parameter sweep. Each grid below currently holds a single value, so one
# configuration is run per seed in `nums`.
#   gamma = [gamma0, gamma1, gamma2]: step sizes applied to the non-"header"
#       parameters (hyper_params), the "header" parameters (params) and the
#       auxiliary "header" copies (thetas) respectively.
#   beta = [beta0, beta1, beta2]: coefficients of the SUN-HR recursive estimator
#       v_new = h_new + (1 - beta) * (v - h_old) for x, y and theta.
#   ck_bar: scale of the per-round coefficient ck = ck_bar / (1 + round) ** exp
#       that is handed to ClientManageHR.
if __name__ == '__main__':
    import itertools

    def _zeros_like(tensors):
        """Return a list of fresh zero tensors shaped like *tensors*."""
        return [torch.zeros_like(t) for t in tensors]

    def _hyper_of(model):
        """Return *model*'s parameters whose name does not contain "header"."""
        return [k for n, k in model.named_parameters() if "header" not in n]

    def _header_of(model):
        """Return *model*'s parameters whose name contains "header"."""
        return [k for n, k in model.named_parameters() if "header" in n]

    # Same iteration order as the original eight nested for-loops.
    sweep = itertools.product(
        [0.1],   # gamma0
        [0.1],   # gamma1
        [0.05],  # gamma2
        [0.01],  # beta0
        [0.01],  # beta1
        [0.01],  # beta2
        [2],     # ck_bar
        nums,    # random seeds
    )
    for gamma0, gamma1, gamma2, beta0, beta1, beta2, ck_bar, num in sweep:
        start_time = int(time.time())  # tags this run's log file name
        torch.manual_seed(num)
        np.random.seed(num)
        test_accuracies = []

        # parse params, load data and build the shared initial model
        args = args_parser()
        dataset_train, dataset_test, dict_users, args.img_size, dataset_train_real = load_data(args)
        net_glob = build_model(args)
        args.gamma = [gamma0, gamma1, gamma2]
        args.beta = [beta0, beta1, beta2]
        gama = 50  # passed through to ClientManageHR; its meaning is defined there
        print(f"Testing with gamma: {args.gamma}, beta: {args.beta}, ck_bar: {ck_bar}, gama: {gama}")
        exp = 0.001  # decay exponent of the ck schedule
        if args.output is None:
            logs = Logger(f'./save/hr_dec_{args.alg}_{args.topoModel}_{args.dataset}_{args.model}_{args.epochs}_C{args.frac}_iid{args.iid}_alg_{args.alg}_topo_{args.topoModel}_'
                          f'ck_bar{ck_bar}_gama{gama}_exp{exp}_gamma{[args.gamma[0],args.gamma[1],args.gamma[2]]}_'
                          f'beta{[args.beta[0],args.beta[1],args.beta[2]]}_{start_time}.yaml')
        else:
            logs = Logger(args.output)

        # Per-user state. Every user starts from an independent copy of net_glob.
        # After each round net_glob_sum[idx] and net_glob_theta_sum[idx] receive
        # the same non-"header" tensors but different "header" tensors (params vs
        # thetas). This script never writes the *_old replicas itself; they are
        # only passed to ClientManageHR.
        net_glob_sum, net_glob_sum_old = [], []
        net_glob_theta_sum, net_glob_theta_sum_old = [], []
        hyper_params, hyper_params_old = [], []
        params, params_old = [], []
        thetas, thetas_old = [], []
        # h_*: per-user update directions (x -> hyper_params, y -> params, theta -> thetas)
        h_x, h_x_old, h_y, h_y_old, h_theta, h_theta_old = [], [], [], [], [], []
        # v_*: state of the SUN-HR recursive estimator
        v_x, v_x_new, v_y, v_y_new, v_theta, v_theta_new = [], [], [], [], [], []
        for idx in range(args.num_users):
            user_net = copy.deepcopy(net_glob)
            user_net_old = copy.deepcopy(net_glob)
            user_net_theta = copy.deepcopy(net_glob)
            user_net_theta_old = copy.deepcopy(net_glob)
            net_glob_sum.append(user_net)
            net_glob_sum_old.append(user_net_old)
            net_glob_theta_sum.append(user_net_theta)
            net_glob_theta_sum_old.append(user_net_theta_old)

            hyper_param = _hyper_of(user_net)
            param = _header_of(user_net)
            hyper_params.append(hyper_param)
            hyper_params_old.append(_hyper_of(user_net_old))
            params.append(param)
            params_old.append(_header_of(user_net_old))
            thetas.append(_header_of(user_net_theta))
            thetas_old.append(_header_of(user_net_theta_old))

            # every buffer gets its own freshly allocated zero tensors
            for buf in (h_x, h_x_old, v_x, v_x_new):
                buf.append(_zeros_like(hyper_param))
            for buf in (h_y, h_y_old, v_y, v_y_new, h_theta, h_theta_old, v_theta, v_theta_new):
                buf.append(_zeros_like(param))

        # all users share net_glob's architecture, so these counts are common
        n_hyper = len(_hyper_of(net_glob))
        n_param = len(_header_of(net_glob))
        client_idx = range(args.num_users)
        ck_1 = 0
        comm_round = 0

        while comm_round < args.epochs:
            round_start = time.time()
            ck = ck_bar / (1 + comm_round) ** exp
            if comm_round > 0:
                ck_1 = ck_bar / comm_round ** exp  # value ck had in the previous round

            # The previous code passed params_old twice; the sixth "old" slot
            # mirrors thetas in the current-state group, so it receives thetas_old.
            client_manage = ClientManageHR(args, net_glob_sum, net_glob_theta_sum, net_glob_sum_old,
                                           net_glob_theta_sum_old, client_idx, dataset_train, dict_users,
                                           hyper_params, params, thetas, hyper_params_old, params_old,
                                           thetas_old, ck, ck_1, gama)

            # local client computation
            if args.alg == 'SUN-HR' and comm_round > 0:
                (h_y_new, h_theta_new, h_x_new,
                 h_y_old, h_theta_old, h_x_old) = client_manage.client_job_HR(args)
            else:
                h_y_new, h_theta_new, h_x_new = client_manage.client_job(args)

            topo = ComTopo(args.num_users, args.topoModel)
            topo_i_w = (np.eye(args.num_users) + topo) / 2

            if args.alg == 'SUN-SE':
                h_x = h_x_new
                h_y = h_y_new
                h_theta = h_theta_new

            elif args.alg == 'SUN-GT':
                for idx in range(args.num_users):
                    for i in range(n_param):
                        h_y[idx][i] = h_y[idx][i] + h_y_new[idx][i] - h_y_old[idx][i]
                        h_theta[idx][i] = h_theta[idx][i] + h_theta_new[idx][i] - h_theta_old[idx][i]
                        h_y_old[idx][i] = h_y_new[idx][i]
                        h_theta_old[idx][i] = h_theta_new[idx][i]
                    for i in range(n_hyper):
                        h_x[idx][i] = h_x[idx][i] + h_x_new[idx][i] - h_x_old[idx][i]
                        h_x_old[idx][i] = h_x_new[idx][i]
                # NOTE(review): the tracking update above is discarded here, because
                # h_* are rebound to the mixed raw h_*_new. Kept as-is; confirm
                # whether mixing h_y / h_theta / h_x was intended.
                h_y = communication1(args, h_y_new, topo_i_w)
                h_theta = communication1(args, h_theta_new, topo_i_w)
                h_x = communication1(args, h_x_new, topo_i_w)

            elif args.alg == 'DSGDA-GT':
                # NOTE(review): this mixes the fresh h_*_new rather than the previous
                # trackers h_* before adding the correction term. Confirm intent.
                h_y = communication1(args, h_y_new, topo_i_w)
                h_theta = communication1(args, h_theta_new, topo_i_w)
                h_x = communication1(args, h_x_new, topo_i_w)
                for idx in range(args.num_users):
                    for i in range(n_param):
                        h_y[idx][i] = h_y[idx][i] + h_y_new[idx][i] - h_y_old[idx][i]
                        h_theta[idx][i] = h_theta[idx][i] + h_theta_new[idx][i] - h_theta_old[idx][i]
                        h_y_old[idx][i] = h_y_new[idx][i]
                        h_theta_old[idx][i] = h_theta_new[idx][i]
                    for i in range(n_hyper):
                        h_x[idx][i] = h_x[idx][i] + h_x_new[idx][i] - h_x_old[idx][i]
                        h_x_old[idx][i] = h_x_new[idx][i]

            elif args.alg == 'SUN-HR':
                # recursive estimator: v_new = h_new + (1 - beta) * (v - h_old)
                for idx in range(args.num_users):
                    for i in range(n_hyper):
                        v_x_new[idx][i] = h_x_new[idx][i] + (1 - args.beta[0]) * (v_x[idx][i] - h_x_old[idx][i])
                    for i in range(n_param):
                        v_y_new[idx][i] = h_y_new[idx][i] + (1 - args.beta[1]) * (v_y[idx][i] - h_y_old[idx][i])
                        v_theta_new[idx][i] = h_theta_new[idx][i] + (1 - args.beta[2]) * (v_theta[idx][i] - h_theta_old[idx][i])
                # track the estimator: h += v_new - v
                for idx in range(args.num_users):
                    for i in range(n_hyper):
                        h_x[idx][i] = h_x[idx][i] + v_x_new[idx][i] - v_x[idx][i]
                    for i in range(n_param):
                        h_y[idx][i] = h_y[idx][i] + v_y_new[idx][i] - v_y[idx][i]
                        h_theta[idx][i] = h_theta[idx][i] + v_theta_new[idx][i] - v_theta[idx][i]
                h_y = communication1(args, h_y, topo_i_w)
                h_theta = communication1(args, h_theta, topo_i_w)
                h_x = communication1(args, h_x, topo_i_w)
                # NOTE(review): v_x / v_y / v_theta are never advanced to v_*_new
                # (the original had those assignments commented out), so they stay
                # zero. Confirm whether that is intended.

            comm_round += 1

            # Peak CUDA memory. The previous code hard-coded device 2, which raises
            # once CUDA is in use on a host without a third GPU. This measures the
            # run's configured device instead (assumes args.device as used by the
            # project's helpers; falls back to the current CUDA device).
            if torch.cuda.is_available():
                mem_device = getattr(args, 'device', None)
                if mem_device is None or torch.device(mem_device).type == 'cuda':
                    current_memory = torch.cuda.max_memory_allocated(mem_device) / (1024 ** 2)
                    print('===================================current_memor:', current_memory)

            # local descent step along the tracked directions
            for idx in range(args.num_users):
                for i in range(n_param):
                    params[idx][i] = params[idx][i] - args.gamma[1] * h_y[idx][i]
                    thetas[idx][i] = thetas[idx][i] - args.gamma[2] * h_theta[idx][i]
                for i in range(n_hyper):
                    hyper_params[idx][i] = hyper_params[idx][i] - args.gamma[0] * h_x[idx][i]

            # gossip with neighbours
            params = communication1(args, params, topo_i_w)
            hyper_params = communication1(args, hyper_params, topo_i_w)
            thetas = communication1(args, thetas, topo_i_w)

            # Write the mixed tensors back into the user models. The index split
            # assumes parameters() yields every non-"header" parameter before any
            # "header" one (NOTE(review): holds only if header modules come last).
            # Both replicas of a user receive the very same hyper tensors.
            for idx in range(args.num_users):
                for count, p in enumerate(net_glob_sum[idx].parameters()):
                    p.data = hyper_params[idx][count] if count < n_hyper else params[idx][count - n_hyper]
                for count, p in enumerate(net_glob_theta_sum[idx].parameters()):
                    p.data = hyper_params[idx][count] if count < n_hyper else thetas[idx][count - n_hyper]

            roundtime = time.time() - round_start

            # testing: evaluate the uniform average of the users' net_glob_sum models
            net_glob_avg = copy.deepcopy(net_glob)
            avg_params = list(net_glob_avg.parameters())
            with torch.no_grad():
                for avg_p in avg_params:
                    avg_p.data.zero_()
                for idx in range(args.num_users):
                    for avg_p, user_p in zip(avg_params, net_glob_sum[idx].parameters()):
                        avg_p.data += user_p.data / args.num_users

            net_glob_avg.eval()
            acc_train, loss_train, F1_score_train = test_img(net_glob_avg, dataset_train_real, args)
            acc_test, loss_test, F1_score_test = test_img(net_glob_avg, dataset_test, args)
            print("Test acc/loss: {:.2f} {:.6f}".format(acc_test, loss_test),
                  "Train acc/loss: {:.2f} {:.6f}".format(acc_train, loss_train),
                  "Train F1 score:{:.2f}".format(F1_score_train),
                  "Test F1 score:{:.2f}".format(F1_score_test),
                  f"Comm round: {comm_round}", "time: {:.2f}s".format(roundtime))
            test_accuracies.append(acc_test)

            logs.logging(acc_test, acc_train, loss_test, loss_train, F1_score_test, F1_score_train, comm_round, roundtime)
            logs.save()

            if args.round > 0 and comm_round > args.round:
                break

                                    #plt.figure()
                                    #plt.plot(range(len(test_accuracies)), test_accuracies)
                                    #plt.ylabel('test_accuracies')

                                    #os.makedirs('./save_SE', exist_ok=True)
                                    #plt.savefig(
                                    #    './save_SE/dec_{}_{}_{}_{}_{}_C{}_iid{}_{}_{}_{}_{}_{}_{}.png'.
                                    #    format(args.alg, args.topoModel, args.dataset, args.model, comm_round, args.frac, args.iid, 
                                    #            args.gamma[0],args.gamma[1],args.gamma[2], args.beta[0],args.beta[1],args.beta[2]))
                                    #plt.close()