#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import yaml
import time
from core.test import test_img
from utils.Comm import ComTopo, communication, communication1
from models.SvrgUpdate import LocalUpdate
from utils.options import args_parser
from utils.dataset_normal import load_data
from models.ModelBuilder import build_model
from core.NewClientManage import NewClientManage
from utils.my_logging import Logger
from core.function import assign_hyper_gradient
from torch.optim import SGD
import torch

import numpy as np
import copy


 
# Decentralized bilevel training driver.
#
# Each round: every client runs local work via NewClientManage.client_job, the
# per-client direction estimates h_y / h_v are combined according to args.alg,
# each client's model and auxiliary vectors v are stepped and then mixed with
# communication1 over the topology, and the uniform average of all client
# models is evaluated and logged.


def zeros_like_all(tensors):
    """Return a fresh list with ``torch.zeros_like(t)`` for every tensor in *tensors*."""
    return [torch.zeros_like(t) for t in tensors]


def apply_tracking_correction(h, h_new, h_old):
    """Apply ``h <- h + h_new - h_old`` element-wise, then set ``h_old <- h_new``.

    All three arguments are per-client lists of per-parameter tensors. ``h`` and
    ``h_old`` are updated in place (their inner lists are rebound element by
    element); afterwards ``h_old`` references the tensors of ``h_new``.
    """
    for h_user, new_user, old_user in zip(h, h_new, h_old):
        for i, (new_t, old_t) in enumerate(zip(new_user, old_user)):
            h_user[i] = h_user[i] + new_t - old_t
            old_user[i] = new_t


def average_models(template, params_per_user, num_users):
    """Return a deep copy of *template* holding the uniform mean of the clients' parameters.

    ``params_per_user[idx]`` must list client ``idx``'s parameter tensors in the
    same order as ``template.parameters()``. *template* itself is not modified.
    """
    net_avg = copy.deepcopy(template)
    with torch.no_grad():
        avg_params = list(net_avg.parameters())
        for p in avg_params:
            p.data.zero_()
        for idx in range(num_users):
            for avg_p, user_p in zip(avg_params, params_per_user[idx]):
                # Divide each contribution before accumulating (running mean).
                avg_p.data += user_p.data / num_users
    return net_avg


start_time = int(time.time())
if __name__ == '__main__':
    num = 0
    args = args_parser()
    torch.manual_seed(num)
    np.random.seed(num)

    dataset_train, dataset_test, dict_users, args.img_size, dataset_train_real = load_data(args)

    net_glob = build_model(args)

    if args.output is None:
        logs = Logger(f'./save/dc_dec_{args.alg}_{args.topoModel}_{args.dataset}_{args.model}_{args.epochs}_C{args.frac}_iid{args.iid}_'
                      f'tau{args.inner_ep}_blo{not args.no_blo}_'
                      f'gamma{[args.gamma[0],args.gamma[1],args.gamma[2]]}_{start_time}.yaml')
    else:
        logs = Logger(args.output)

    # One lambda_x vector and one ma_term vector per client, each of length
    # len(dataset_train) // num_users.
    per_user_len = len(dataset_train) // args.num_users
    lambda_x = torch.zeros(per_user_len, requires_grad=True, device=args.device)
    lambdas_x = [copy.deepcopy(lambda_x) for _ in range(args.num_users)]
    ma_term = torch.zeros(per_user_len, requires_grad=True, device=args.device)
    ma_terms = [copy.deepcopy(ma_term) for _ in range(args.num_users)]

    comm_round = 0
    client_idx = list(range(args.num_users))

    # Per client: a random initial v (one tensor per model parameter, drawn on
    # the CPU then moved to the device) plus zeroed tracking buffers for it.
    v, h_v, h_v_old = [], [], []
    for _ in range(args.num_users):
        v_user = [torch.rand(k.shape).to(args.device) for _, k in net_glob.named_parameters()]
        h_v.append(zeros_like_all(v_user))
        h_v_old.append(zeros_like_all(v_user))
        v.append(v_user)

    # Each client trains its own deep copy of the initial model.
    net_glob_sum = [copy.deepcopy(net_glob) for _ in client_idx]
    params, h_y, h_y_old = [], [], []
    for idx in client_idx:
        param = list(net_glob_sum[idx].parameters())
        params.append(param)
        h_y.append(zeros_like_all(param))
        # h_y_old needs its own buffer: sharing h_y's list would make the
        # tracking correction read the very entry it is overwriting.
        h_y_old.append(zeros_like_all(param))
    test_accuracies = []

    while comm_round < args.epochs:
        # Per-round timer (kept separate from the module-level start_time,
        # which names the log file).
        round_start = time.time()

        client_manage = NewClientManage(args, net_glob_sum, client_idx, dataset_train, dict_users,
                                        lambdas_x, params, v, ma_terms)
        h_y_new, h_v_new, lambdas_x, ma_terms = client_manage.client_job(args)

        # Mixing weights (I + topo) / 2. Rebuilt every round in case ComTopo
        # is time-varying — TODO confirm whether it can be hoisted.
        topo = ComTopo(args.num_users, args.topoModel)
        topo_i_w = (np.eye(args.num_users) + topo) / 2

        if args.alg in ('SLDBO', 'SPARKLE-E'):
            # Mix first, then apply the tracking correction.
            h_y = communication1(args, h_y, topo_i_w)
            h_v = communication1(args, h_v, topo_i_w)
            apply_tracking_correction(h_y, h_y_new, h_y_old)
            apply_tracking_correction(h_v, h_v_new, h_v_old)
        elif args.alg == 'D-SOBA':
            # Mix the fresh local directions directly; no tracking state.
            # NOTE: ma_terms are not mixed across clients here.
            h_y = communication1(args, h_y_new, topo_i_w)
            h_v = communication1(args, h_v_new, topo_i_w)
        elif args.alg == 'SPARKLE':
            # Apply the tracking correction first, then mix.
            apply_tracking_correction(h_y, h_y_new, h_y_old)
            apply_tracking_correction(h_v, h_v_new, h_v_old)
            h_y = communication1(args, h_y, topo_i_w)
            h_v = communication1(args, h_v, topo_i_w)
        # NOTE(review): any other args.alg leaves h_y / h_v unchanged this round.

        comm_round += 1

        # Step each client's model along -gamma[1] * h_y.
        for idx in range(args.num_users):
            for param, direction in zip(net_glob_sum[idx].parameters(), h_y[idx]):
                param.data = param.data - (args.gamma[1] * direction)

        # Step each client's v along -gamma[2] * h_v (in place in v[idx]).
        for idx in range(args.num_users):
            v_user, hv_user = v[idx], h_v[idx]
            for i in range(len(params[idx])):
                v_user[i] = v_user[i] - (args.gamma[2] * hv_user[i])

        # Mix the stepped model parameters and v with communication1
        # (presumably a neighbour-weighted average under topo_i_w).
        for idx in client_idx:
            params[idx] = list(net_glob_sum[idx].parameters())
        params = communication1(args, params, topo_i_w)
        v = communication1(args, v, topo_i_w)

        # Load the mixed parameters back into each client's model.
        for idx in range(args.num_users):
            for p, mixed in zip(net_glob_sum[idx].parameters(), params[idx]):
                p.data = mixed

        roundtime = time.time() - round_start

        # Evaluate the uniform average of all client models.
        net_glob_avg = average_models(net_glob, params, args.num_users)

        acc_train, loss_train = test_img(net_glob_avg, dataset_train_real, args)
        acc_test, loss_test = test_img(net_glob_avg, dataset_test, args)
        print("Test acc/loss: {:.2f} {:.6f}".format(acc_test, loss_test),
              "Train acc/loss: {:.2f} {:.6f}".format(acc_train, loss_train),
              f"Comm round: {comm_round}", "time: {:.2f}s".format(roundtime))
        print(f"Memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        print(f"Memory cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
        test_accuracies.append(acc_test)

        logs.logging(acc_test, acc_train, loss_test, loss_train, comm_round, roundtime)
        logs.save()

        # Optional early stop once comm_round exceeds args.round (when > 0).
        if args.round > 0 and comm_round > args.round:
            break

