#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import random

import yaml
import time
from core.test import test_img
from utils.Comm import communication, communication1, ComTopo
from models.SvrgUpdate import LocalUpdate
from utils.options import args_parser
from utils.dataset_normal import load_data
from models.ModelBuilder import build_model
from core.ClientManage_hr import ClientManageHR
from utils.my_logging import Logger
from core.function import assign_hyper_gradient
from torch.optim import SGD
import matplotlib.pyplot as plt
import torch

import numpy as np
import copy



# Wall-clock stamp taken at import time; it makes the default log file name unique.
start_time = int(time.time())

# Values of ``args.alg`` for which the training loop has a mixing rule.
SUPPORTED_ALGS = ('SLDBO', 'D-SOBA', 'SPARKLE', 'SPARKLE-E')


def zeros_like_each(tensors):
    """Return a NEW list holding one freshly allocated zero tensor per entry.

    Every call produces an independent list, so state built from it is never
    aliased with another piece of state.
    """
    return [torch.zeros_like(t) for t in tensors]


def apply_momentum(ma_term, h_x_new, momentum):
    """Blend ``h_x_new`` into ``ma_term`` in place.

    For every user ``u`` and layer ``l``:
    ``ma_term[u][l] = (1 - momentum) * ma_term[u][l] + momentum * h_x_new[u][l]``.
    """
    for idx in range(len(h_x_new)):
        for layer_idx in range(len(h_x_new[idx])):
            ma_term[idx][layer_idx] = ((1 - momentum) * ma_term[idx][layer_idx]
                                       + momentum * h_x_new[idx][layer_idx])


def apply_tracking(tracker, new, old):
    """Apply the correction ``tracker += new - old`` and then set ``old = new``.

    All three arguments are per-user lists of per-layer values. ``tracker``
    and ``old`` are updated in place, entry by entry.
    """
    for idx in range(len(tracker)):
        for i in range(len(tracker[idx])):
            tracker[idx][i] = tracker[idx][i] + new[idx][i] - old[idx][i]
            old[idx][i] = new[idx][i]


def gradient_step(values, directions, step_size):
    """In place: ``values[u][l] = values[u][l] - step_size * directions[u][l]``."""
    for idx in range(len(values)):
        for i in range(len(values[idx])):
            values[idx][i] = values[idx][i] - (step_size * directions[idx][i])


def load_parameters(net, hyper_values, header_values):
    """Write the flat value lists back into ``net``'s parameters.

    Parameters whose name contains ``"header"`` consume ``header_values`` in
    order; all others consume ``hyper_values`` in order. This mirrors the name
    filter used to build the lists, so it does not depend on header parameters
    being registered last in the model.
    """
    hyper_iter = iter(hyper_values)
    header_iter = iter(header_values)
    for name, p in net.named_parameters():
        p.data = next(header_iter) if "header" in name else next(hyper_iter)


def average_models(template, models):
    """Return a deep copy of ``template`` whose parameters are the mean over ``models``.

    Only ``parameters()`` are averaged; any buffers keep the values copied from
    ``template``.  NOTE(review): confirm that is acceptable for models that
    carry running statistics.
    """
    averaged = copy.deepcopy(template)
    with torch.no_grad():
        for p in averaged.parameters():
            p.data.zero_()
    avg_params = list(averaged.parameters())
    for model in models:
        user_params = list(model.parameters())
        for layer_idx in range(len(avg_params)):
            avg_params[layer_idx].data += user_params[layer_idx].data / len(models)
    return averaged


def main():
    """Run the decentralized training loop selected by ``args.alg``.

    Each communication round: clients compute local directions, the directions
    are combined according to the chosen algorithm (mixing through
    ``communication1`` with the matrix ``(I + topo) / 2``), every user's
    variables take a step and are mixed again, the per-user models are
    refreshed, and the parameter-average of all user models is evaluated and
    logged.

    Raises:
        ValueError: if ``args.alg`` is not one of ``SUPPORTED_ALGS``. Without
            this check an unknown name would match no branch and every round
            would step along all-zero directions.
    """
    seed = 0
    torch.manual_seed(seed)
    np.random.seed(seed)
    test_accuracies = []

    # parse param
    args = args_parser()
    if args.alg not in SUPPORTED_ALGS:
        raise ValueError(f"Unsupported args.alg {args.alg!r}; expected one of {SUPPORTED_ALGS}")

    dataset_train, dataset_test, dict_users, args.img_size, dataset_train_real = load_data(args)
    net_glob = build_model(args)

    # Per-user state. "header" parameters are the ones stepped with gamma[0]
    # (``params``); every other parameter is stepped with gamma[2]
    # (``hyper_params``); ``v`` has the same shapes as the header parameters.
    net_glob_sum, v, params, hyper_params = [], [], [], []
    ma_term, ma_term_new, ma_term_old = [], [], []
    h_y, h_y_old, h_v, h_v_old = [], [], [], []
    for idx in range(args.num_users):
        v_user = [torch.rand(k.shape).to(args.device)
                  for n, k in net_glob.named_parameters() if "header" in n]
        v.append(v_user)

        net_glob_sum.append(copy.deepcopy(net_glob))
        hyper_param = [k for n, k in net_glob_sum[idx].named_parameters() if "header" not in n]
        hyper_params.append(hyper_param)
        # Three independent lists: sharing one list object between them would
        # let an in-place write to ``ma_term`` silently change the other two.
        ma_term.append(zeros_like_each(hyper_param))
        ma_term_new.append(zeros_like_each(hyper_param))
        ma_term_old.append(zeros_like_each(hyper_param))

        # NOTE(review): header parameters are taken from ``net_glob`` (shared by
        # all users) rather than from ``net_glob_sum[idx]`` — confirm intended.
        param = [k for n, k in net_glob.named_parameters() if "header" in n]
        params.append(param)
        h_y.append(zeros_like_each(param))
        h_v.append(zeros_like_each(v_user))
        h_y_old.append(zeros_like_each(param))
        h_v_old.append(zeros_like_each(v_user))

    if args.output is None:
        logs = Logger(f'./save_cir_mm_sl/hr_{args.alg}_{args.topoModel}_{args.dataset}_{args.model}_{args.epochs}_iid{args.iid}_'
                      f'tau{args.inner_ep}_blo{not args.no_blo}_'
                      f'gamma{[args.gamma[0],args.gamma[1],args.gamma[2]]}_{start_time}.yaml')
    else:
        logs = Logger(args.output)

    comm_round = 0
    while comm_round < args.epochs:
        round_start = time.time()

        # Every user takes part in every round.
        client_idx = range(args.num_users)
        client_manage = ClientManageHR(args, net_glob_sum, client_idx, dataset_train,
                                       dict_users, hyper_params, params, v)
        h_y_new, h_v_new, h_x_new = client_manage.client_job()

        topo = ComTopo(args.num_users, args.topoModel)
        topo_i_w = (np.eye(args.num_users) + topo) / 2

        # NOTE(review): ``ma_term_new`` is never written after initialisation,
        # so the ``ma_term`` tracking correction below is always zero. Confirm
        # which quantity is supposed to feed it.
        if args.alg == 'SLDBO':
            # With momentum forced to 1 the moving average reduces to h_x_new.
            args.momentum = 1
            apply_momentum(ma_term, h_x_new, args.momentum)

            h_y = communication1(args, h_y, topo_i_w)
            h_v = communication1(args, h_v, topo_i_w)
            ma_term = communication1(args, ma_term, topo_i_w)

            apply_tracking(ma_term, ma_term_new, ma_term_old)
            apply_tracking(h_y, h_y_new, h_y_old)
            apply_tracking(h_v, h_v_new, h_v_old)

        elif args.alg == 'D-SOBA':
            apply_momentum(ma_term, h_x_new, args.momentum)
            # No tracking: the fresh local directions are mixed directly.
            h_y = communication1(args, h_y_new, topo_i_w)
            h_v = communication1(args, h_v_new, topo_i_w)
            ma_term = communication1(args, ma_term, topo_i_w)

        elif args.alg == 'SPARKLE':
            apply_momentum(ma_term, h_x_new, args.momentum)

            # Track first, then mix.
            apply_tracking(ma_term, ma_term_new, ma_term_old)
            apply_tracking(h_y, h_y_new, h_y_old)
            apply_tracking(h_v, h_v_new, h_v_old)

            h_y = communication1(args, h_y, topo_i_w)
            h_v = communication1(args, h_v, topo_i_w)
            ma_term = communication1(args, ma_term, topo_i_w)

        elif args.alg == 'SPARKLE-E':
            apply_momentum(ma_term, h_x_new, args.momentum)

            h_y = communication1(args, h_y, topo_i_w)
            h_v = communication1(args, h_v, topo_i_w)
            ma_term = communication1(args, ma_term, topo_i_w)

            # NOTE(review): the momentum update runs a second time after
            # mixing, as in the original script — confirm this is intended.
            apply_momentum(ma_term, h_x_new, args.momentum)
            apply_tracking(ma_term, ma_term_new, ma_term_old)
            apply_tracking(h_y, h_y_new, h_y_old)
            apply_tracking(h_v, h_v_new, h_v_old)

        comm_round += 1

        # update
        gradient_step(params, h_y, args.gamma[0])
        gradient_step(v, h_v, args.gamma[1])
        gradient_step(hyper_params, ma_term, args.gamma[2])

        # communication
        params = communication1(args, params, topo_i_w)
        hyper_params = communication1(args, hyper_params, topo_i_w)
        v = communication1(args, v, topo_i_w)

        # Push the mixed variables back into each user's model.
        for idx in range(args.num_users):
            load_parameters(net_glob_sum[idx], hyper_params[idx], params[idx])

        round_time = time.time() - round_start

        # testing: evaluate the parameter-average of all user models
        net_glob_avg = average_models(net_glob, net_glob_sum)
        net_glob_avg.eval()
        acc_train, loss_train = test_img(net_glob_avg, dataset_train_real, args)
        acc_test, loss_test = test_img(net_glob_avg, dataset_test, args)
        print("Test acc/loss: {:.2f} {:.2f}".format(acc_test, loss_test),
              "Train acc/loss: {:.2f} {:.2f}".format(acc_train, loss_train),
              f"Comm round: {comm_round}","time: {:.2f}".format(round_time))

        test_accuracies.append(acc_test)

        logs.logging(acc_test, acc_train, loss_test, loss_train, comm_round, round_time)
        logs.save()

        # Optional early stop once more than ``args.round`` rounds have run.
        if args.round > 0 and comm_round > args.round:
            break


if __name__ == '__main__':
    main()


   