import copy
# from cv2 import log
import numpy as np
import time
import torch

from core.function import gather_flat_hyper_params
from utils.Fed import FedAvg, FedAvgGradient, FedAvgP
from core.SGDClient_hr import SGDClient
from core.SVRGClient_hr import SVRGClient
from core.Client_hr import Client
from core.ClientManage import ClientManage


class ClientManageHR(ClientManage):
    """Client manager that runs local gradient steps for each selected client.

    For every client id in ``client_idx`` an ``SVRGClient`` is built from the
    per-client entries of ``net_glob_sum``, ``params``, ``hyper_params`` and
    ``v``. Three gradient quantities are collected from it (see ``client_job``).
    """

    def __init__(self, args, net_glob_sum, client_idx, dataset, dict_users, hyper_params, params, v) -> None:
        """Store the per-client state consumed by ``client_job``.

        Args:
            args: Run configuration. ``args.inner_ep`` sets the number of local
                steps per client; ``args`` is also forwarded to each client.
            net_glob_sum: Indexable by client id; entry ``idx`` is the network
                handed to client ``idx``.
            client_idx: Iterable of client ids processed by ``client_job``.
            dataset: Forwarded unchanged to every client.
            dict_users: Forwarded unchanged to every client (presumably maps
                client id to sample indices -- confirm against ClientManage).
            hyper_params: Indexable by client id; entry ``idx`` is passed to
                client ``idx``.
            params: Indexable by client id; entry ``idx`` is passed to client
                ``idx``.
            v: Indexable by client id; entry ``idx`` is passed to client ``idx``.
        """
        super().__init__(args, net_glob_sum, client_idx, dataset, dict_users)

        # NOTE(review): these may duplicate what ClientManage.__init__ already
        # stores; they are kept so this class does not depend on the base
        # class's attribute names.
        self.client_idx = client_idx
        self.args = args
        self.dataset = dataset
        self.dict_users = dict_users
        self.net_glob_sum = net_glob_sum
        self.params = params
        self.v = v
        self.hyper_params = hyper_params

    def client_job(self):
        """Run ``args.inner_ep`` local steps on every client and gather results.

        Returns:
            Tuple ``(h_y, h_v, h_x)`` of lists. Each list has one entry per id
            in ``self.client_idx``, in the same order. The entries are the
            final-step results of ``grad_d_in_d_y()``, ``grad_v_R()`` and
            ``grad_f_bar()``, respectively.

        Raises:
            ValueError: if at least one client is processed while
                ``args.inner_ep < 1``. No step would run, so there would be no
                result to report.
        """
        h_y = []
        h_v = []
        h_x = []
        for idx in self.client_idx:
            client = SVRGClient(self.args, idx, self.net_glob_sum[idx], self.dataset, self.dict_users,
                                self.params[idx], self.hyper_params[idx], self.v[idx])
            inner_grad, v_grad, x_update = self._run_local_steps(client)
            h_y.append(inner_grad)
            h_v.append(v_grad)
            h_x.append(x_update)
        return h_y, h_v, h_x

    def _run_local_steps(self, client):
        """Call the client's three gradient methods ``args.inner_ep`` times; return the last results."""
        n_steps = self.args.inner_ep
        if n_steps < 1:
            # Without this guard the loop below never binds its results and
            # the return would fail with an opaque UnboundLocalError.
            raise ValueError(f"args.inner_ep must be >= 1 to produce client updates, got {n_steps!r}")
        # k = 0, 1, ..., inner_ep - 1
        for _ in range(n_steps):
            inner_grad = client.grad_d_in_d_y()  # update for y
            v_grad = client.grad_v_R()  # update for v
            x_update = client.grad_f_bar()  # update for x
        # NOTE(review): no parameter update is applied between steps here, so
        # earlier iterations only affect the result if the client methods are
        # stateful (e.g. sample a fresh minibatch or mutate the client) -- confirm.
        return inner_grad, v_grad, x_update


