import copy
from cv2 import log
import numpy as np

import torch
from torch.optim import SGD

from core.function import gather_flat_hyper_params ,get_trainable_hyper_params 
from utils.Comm import ComTopo, communication, communication1
from core.function import *
from core.SGDClient import SGDClient
from core.SVRGClient import SVRGClient
from core.NewThetaClient import NewThetaClient
from core.Client import Client
from core.ClientManage import ClientManage

class NewClientManage(ClientManage):
    """Client manager that runs local y / v / x update steps on each client.

    For every client index it builds an ``SVRGClient`` and performs
    ``args.inner_ep`` rounds of updates, stepping that client's
    hyper-parameter tensor in place (see ``client_job``).
    """

    def __init__(self, args, net_glob_sum, client_idx, dataset, dict_users, hyper_params, params, v, ma_terms) -> None:
        super().__init__(args, net_glob_sum, client_idx, dataset, dict_users)

        self.client_idx = client_idx
        self.args = args
        self.dataset = dataset
        self.dict_users = dict_users
        self.net_glob_sum = net_glob_sum
        # Step sizes; client_job only reads gamma[0] (the hyper-parameter step).
        self.gamma = args.gamma
        # Per-client state, indexed by client id (the values in client_idx).
        self.params = params
        self.v = v
        self.ma_terms = ma_terms
        self.hyper_params = hyper_params

    def client_job(self, args):
        """Run the local updates on every managed client.

        For each index in ``self.client_idx`` an ``SVRGClient`` is built, and
        ``self.args.inner_ep`` iterations are performed. Each iteration:

        1. computes the y gradient via ``client.grad_d_in_d_y()``;
        2. computes the v gradient via ``client.grad_v_R()``;
        3. computes the x direction via ``client.grad_f_bar()`` and steps the
           client's hyper-parameter tensor in place with step ``gamma[0]``.
           For 'D-SOBA' / 'SPARKLE' / 'SPARKLE-E' the step follows the
           client's own moving average of the x direction (weight
           ``args.momentum``). For 'SLDBO' it follows the raw x direction.
           Any other ``args.alg`` leaves the hyper-parameters unchanged.

        Args:
            args: run configuration. Only ``args.alg`` and ``args.momentum``
                are read from it; every other setting comes from ``self.args``.

        Returns:
            A tuple ``(h_y, h_v, hyper_params, ma_terms)``:

            - ``h_y`` holds, per client, the y gradient from its *last* inner
              iteration.
            - ``h_v`` holds, per client, the v gradient from its *last* inner
              iteration.
            - ``hyper_params`` holds the same hyper-parameter objects stored in
              ``self.hyper_params``, after the in-place updates.
            - ``ma_terms`` is ``self.ma_terms``.

        Raises:
            ValueError: if there is at least one client and
                ``self.args.inner_ep < 1``. Without an inner iteration there
                is no gradient to report.
        """
        h_y = []
        h_v = []
        hyper_params = []

        for idx in self.client_idx:
            if self.args.inner_ep < 1:
                raise ValueError(
                    f"args.inner_ep must be >= 1, got {self.args.inner_ep}")

            client = SVRGClient(self.args, idx, self.net_glob_sum[idx], self.dataset, self.dict_users,
                                self.params[idx], self.hyper_params[idx], self.v[idx])
            hyper_param = self.hyper_params[idx]

            for _ in range(self.args.inner_ep):
                # update y
                inner_grad = client.grad_d_in_d_y()
                # update v
                v_grad = client.grad_v_R()
                # update x
                x_update = client.grad_f_bar()

                if args.alg in ['D-SOBA', 'SPARKLE', 'SPARKLE-E']:
                    momentum = args.momentum
                    self.ma_terms[idx] = [
                        (1 - momentum) * old + momentum * new
                        for old, new in zip(self.ma_terms[idx], x_update)
                    ]
                    # Each client steps along its OWN moving average. The
                    # previous code read ma_terms[0], i.e. client 0's term.
                    direction = self.ma_terms[idx]
                elif args.alg == 'SLDBO':
                    direction = x_update
                else:
                    direction = None

                if direction is not None:
                    # detach() shares storage with hyper_param, so these writes
                    # update the hyper-parameters in place without autograd
                    # tracking them.
                    hp_data = hyper_param.detach()
                    for count in range(len(hyper_param)):
                        hp_data[count] = hp_data[count] - self.gamma[0] * direction[count]

            h_y.append(inner_grad)
            h_v.append(v_grad)
            hyper_params.append(hyper_param)

        return h_y, h_v, hyper_params, self.ma_terms
            


    
