import copy
from cv2 import log
import numpy as np

import torch
from torch.optim import SGD

from core.function import gather_flat_hyper_params ,get_trainable_hyper_params 
from utils.Comm import ComTopo, communication, communication1
from core.function import *
from core.SGDClient import SGDClient
from core.SVRGClient import SVRGClient
from core.NewThetaClient import NewThetaClient
from core.Client import Client
from core.ClientManage import ClientManage

class NewClientManage(ClientManage):
    """Client manager that runs local bilevel (x, y, v) updates via SVRGClient.

    Each call to :meth:`client_job` builds a fresh ``SVRGClient`` for every
    managed client id and updates that client's hyper-parameters (``x``) in
    place according to the algorithm selected by ``args.alg``.
    """

    # Algorithms whose x-update follows a per-client moving average of the
    # hyper-gradient estimate.
    _MOMENTUM_ALGS = ('D-SOBA', 'SPARKLE', 'SPARKLE-E')

    def __init__(self, args, net_glob_sum, client_idx, dataset, dict_users, hyper_params, params, v, ma_terms) -> None:
        """Store the per-client state used by :meth:`client_job`.

        Args:
            args: experiment configuration. ``args.gamma`` is kept as the
                x step-size container (only ``gamma[0]`` is used).
            net_glob_sum: per-client networks, indexed by client id.
            client_idx: client ids handled by this manager.
            dataset: dataset forwarded to the base class and each client.
            dict_users: forwarded to the base class and each client
                (presumably client id -> sample indices; verify against
                SVRGClient).
            hyper_params: per-client hyper-parameter tensors, indexed by
                client id; mutated in place by ``client_job``.
            params: per-client lower-level parameters, indexed by client id.
            v: per-client auxiliary variables, indexed by client id.
            ma_terms: per-client moving-average terms, indexed by client id;
                entries are replaced by ``client_job`` for momentum algorithms.
        """
        super().__init__(args, net_glob_sum, client_idx, dataset, dict_users)

        self.client_idx = client_idx
        self.args = args
        self.dataset = dataset
        self.dict_users = dict_users
        self.net_glob_sum = net_glob_sum
        self.gamma = args.gamma
        self.params = params
        self.v = v
        self.ma_terms = ma_terms
        self.hyper_params = hyper_params

    def _step_hyper_param(self, hyper_param, direction):
        """Apply ``x[i] <- x[i] - gamma[0] * direction[i]`` to ``hyper_param`` in place.

        Writes go through ``hyper_param.detach()``, which shares storage with
        ``hyper_param`` but is not tracked by autograd, so the update is not
        recorded in the computation graph.
        """
        step_size = self.gamma[0]
        storage = hyper_param.detach()
        for i in range(len(hyper_param)):
            storage[i] = storage[i] - step_size * direction[i]

    def client_job(self, args):
        """Run ``self.args.inner_ep`` local iterations on every managed client.

        In each iteration the client computes the inner (y) gradient, the v
        gradient and the hyper-gradient estimate, then updates its
        hyper-parameters in place:

        * ``args.alg`` in D-SOBA / SPARKLE / SPARKLE-E: the client's own
          moving-average term is refreshed with weight ``args.momentum`` and
          used as the step direction.
        * ``args.alg == 'SLDBO'``: the raw hyper-gradient estimate is used.
        * any other value: hyper-parameters are left unchanged.

        Args:
            args: run configuration providing ``alg`` and ``momentum``
                (``inner_ep`` is read from ``self.args``).

        Returns:
            ``(h_y, h_v, hyper_params, ma_terms)``. The first three are lists
            ordered like ``self.client_idx``. Per client, they hold the inner
            gradient and the v gradient from the last local iteration, and the
            hyper-parameters (updated in place). ``ma_terms`` is
            ``self.ma_terms``.

        Raises:
            ValueError: if a client is processed while ``self.args.inner_ep``
                is < 1, because no gradients would have been computed.
        """
        h_y = []
        h_v = []
        hyper_params = []

        for idx in self.client_idx:
            if self.args.inner_ep < 1:
                raise ValueError(
                    f"args.inner_ep must be >= 1 to produce client gradients, got {self.args.inner_ep}"
                )
            client = SVRGClient(self.args, idx, self.net_glob_sum[idx], self.dataset, self.dict_users,
                                self.params[idx], self.hyper_params[idx], self.v[idx])
            hyper_param = self.hyper_params[idx]

            for _ in range(self.args.inner_ep):
                # Gradients are recomputed every iteration; only the last
                # iteration's values are reported back to the caller.
                inner_grad = client.grad_d_in_d_y()
                v_grad = client.grad_v_R()
                x_update = client.grad_f_bar()

                if args.alg in self._MOMENTUM_ALGS:
                    self.ma_terms[idx] = [
                        (1 - args.momentum) * ma + args.momentum * x_update[i]
                        for i, ma in enumerate(self.ma_terms[idx])
                    ]
                    # Each client steps along its OWN moving average
                    # (not client 0's).
                    self._step_hyper_param(hyper_param, self.ma_terms[idx])
                elif args.alg == 'SLDBO':
                    self._step_hyper_param(hyper_param, x_update)

            h_y.append(inner_grad)
            h_v.append(v_grad)
            hyper_params.append(hyper_param)

        return h_y, h_v, hyper_params, self.ma_terms
            


    
