import copy
# from cv2 import log
import numpy as np
import torch
import copy
import time
from core.function import gather_flat_hyper_params
from utils.Comm import communication, communication1, ComTopo
from core.SGDClient_hr import SGDClient
from core.SVRGClient_hr import SVRGClient
from core.Client_hr import Client
from core.ClientManage import ClientManage


class ClientManageHR(ClientManage):
    """Compute per-client local estimates ``(h_y, h_theta, h_x)`` for a bilevel update.

    For each client id in ``client_idx`` an ``SVRGClient`` is built and queried via its
    ``grad_d_in_*`` / ``grad_d_out_*`` methods. The returned gradient lists are combined
    into three estimates:

    * ``h_x``, aligned with ``hyper_params[idx]``;
    * ``h_y``, aligned with the parameters of ``net_glob_sum[idx]`` whose name contains
      "header";
    * ``h_theta``, aligned with the "header" parameters of ``net_glob_theta_sum[idx]``.

    ``client_job_HR`` additionally evaluates the same estimates on the ``*_old`` models
    and snapshots, using coefficient ``ck_1`` instead of ``ck``.
    """

    def __init__(self, args, net_glob_sum, net_glob_theta_sum, net_glob_sum_old, net_glob_theta_sum_old, client_idx, dataset, dict_users,
                 hyper_params, params, thetas, hyper_params_old, params_old, thetas_old, ck, ck_1, gama) -> None:
        """Store per-client models, parameter snapshots and coefficients.

        Args:
            args: run configuration. ``args.inner_ep`` is read by both jobs.
            net_glob_sum, net_glob_theta_sum: per-client models for the current iterate,
                indexed by client id.
            net_glob_sum_old, net_glob_theta_sum_old: the same for the previous iterate.
            client_idx: iterable of client ids to process.
            dataset, dict_users: forwarded to ``SVRGClient`` (and to the base class).
            hyper_params, params, thetas: per-client values forwarded to ``SVRGClient``
                for the current iterate. ``hyper_params[idx]`` fixes the length of ``h_x``.
            hyper_params_old, params_old, thetas_old: the same for the previous iterate.
            ck, ck_1: scalar coefficients for the current / previous iterate.
            gama: the estimates add ``(1/gama) * (theta - param)`` coupling terms.
        """
        super().__init__(args, net_glob_sum, client_idx, dataset, dict_users)

        self.client_idx = client_idx
        self.args = args
        self.dataset = dataset
        self.dict_users = dict_users

        # Current-iterate models and snapshots.
        self.net_glob_theta_sum = net_glob_theta_sum
        self.net_glob_sum = net_glob_sum
        self.params = params
        self.thetas = thetas
        self.hyper_params = hyper_params

        # Previous-iterate models and snapshots (used only by client_job_HR).
        self.net_glob_theta_sum_old = net_glob_theta_sum_old
        self.net_glob_sum_old = net_glob_sum_old
        self.params_old = params_old
        self.thetas_old = thetas_old
        self.hyper_params_old = hyper_params_old

        self.ck = ck
        self.ck_1 = ck_1
        self.gama = gama

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _header_params(net):
        """Return the parameters of ``net`` whose name contains "header", in registration order."""
        return [p for name, p in net.named_parameters() if "header" in name]

    @staticmethod
    def _local_grads(client, net, theta_net, pass_net):
        """Zero both models' gradients and query ``client`` for six gradient lists.

        ``theta_net`` is passed explicitly to the theta-side queries. ``net`` is passed
        to the remaining queries only when ``pass_net`` is true. Otherwise the client's
        own model is used (presumably the ``net`` it was built with; confirm in
        SVRGClient). The query order is kept fixed in case the client is stateful.

        Returns:
            ``(theta_grad, direct_grad_x, indirect_grad1_x, indirect_grad2_x,
            direct_grad_y, indirect_grad1_y)``
        """
        net.zero_grad()
        theta_net.zero_grad()
        net_args = (net,) if pass_net else ()
        theta_grad = client.grad_d_in_d_y(theta_net)
        direct_grad_x = client.grad_d_in_d_x(*net_args)
        indirect_grad1_x = client.grad_d_out_d_x(*net_args)
        indirect_grad2_x = client.grad_d_in_d_x(theta_net)
        direct_grad_y = client.grad_d_in_d_y(*net_args)
        indirect_grad1_y = client.grad_d_out_d_y(*net_args)
        return (theta_grad, direct_grad_x, indirect_grad1_x, indirect_grad2_x,
                direct_grad_y, indirect_grad1_y)

    @staticmethod
    def _prox_estimates(theta_grad, direct_grad_x, indirect_grad1_x, indirect_grad2_x,
                        direct_grad_y, indirect_grad1_y, theta, param, ck, gama):
        """Combine gradients into ``(h_y, h_theta, h_x)`` using a ``1/gama`` coupling term.

            h_theta = theta_grad + (1/gama) * (theta - param)
            h_x     = direct_grad_x + ck * indirect_grad1_x - indirect_grad2_x
            h_y     = direct_grad_y + ck * indirect_grad1_y + (1/gama) * (theta - param)

        The list arguments are assumed to be element-wise aligned (same length and
        matching shapes).
        """
        inv_gama = 1 / gama
        h_theta = [g + inv_gama * (t - p) for g, t, p in zip(theta_grad, theta, param)]
        h_x = [dx + ck * ox - ix
               for dx, ox, ix in zip(direct_grad_x, indirect_grad1_x, indirect_grad2_x)]
        h_y = [dy + ck * oy + inv_gama * (t - p)
               for dy, oy, t, p in zip(direct_grad_y, indirect_grad1_y, theta, param)]
        return h_y, h_theta, h_x

    @staticmethod
    def _gt_estimates(theta_grad, direct_grad_x, indirect_grad1_x, indirect_grad2_x,
                      direct_grad_y, indirect_grad1_y, ck):
        """Combine gradients into ``(h_y, h_theta, h_x)`` for ``args.alg == 'DSGDA-GT'``.

            h_theta = ck * (-theta_grad)
            h_x     = indirect_grad1_x + ck * (direct_grad_x - indirect_grad2_x)
            h_y     = indirect_grad1_y + ck * direct_grad_y
        """
        h_theta = [ck * (-g) for g in theta_grad]
        h_x = [ox + ck * (dx - ix)
               for dx, ox, ix in zip(direct_grad_x, indirect_grad1_x, indirect_grad2_x)]
        h_y = [oy + ck * dy for dy, oy in zip(direct_grad_y, indirect_grad1_y)]
        return h_y, h_theta, h_x

    def _hr_estimates(self, idx, net, theta_net, params, hyper_params, thetas, ck):
        """Build a fresh SVRGClient on ``net`` and return its ``(h_y, h_theta, h_x)``.

        The proximal term reads the "header" parameters of the SAME ``net`` /
        ``theta_net`` pair that the gradients are taken from.
        """
        client = SVRGClient(self.args, idx, net, self.dataset, self.dict_users, params,
                            hyper_params, thetas)
        theta = self._header_params(theta_net)
        param = self._header_params(net)
        grads = self._local_grads(client, net, theta_net, pass_net=False)
        return self._prox_estimates(*grads, theta, param, ck, self.gama)

    # ------------------------------------------------------------------ jobs

    def client_job(self, args):
        """Compute local estimates for every client in ``self.client_idx``.

        The estimates are recomputed ``self.args.inner_ep`` times per client. No
        parameters are updated between repetitions, so only the last repetition's
        values are kept. If ``args.alg == 'DSGDA-GT'`` the gradient-tracking
        combination is used; otherwise the ``1/gama`` proximal combination is used.

        Args:
            args: configuration; only ``args.alg`` is read here.

        Returns:
            ``(h_y, h_theta, h_x)``: lists with one entry per client id, in iteration
            order. Each entry is a list of tensors.

        Raises:
            ValueError: if ``self.args.inner_ep < 1`` and there is at least one client.
        """
        h_y, h_theta, h_x = [], [], []
        ck = self.ck
        gama = self.gama

        for idx in self.client_idx:
            net = self.net_glob_sum[idx]
            theta_net = self.net_glob_theta_sum[idx]
            estimates = None
            for _ in range(self.args.inner_ep):
                client = SVRGClient(self.args, idx, net, self.dataset, self.dict_users,
                                    self.params[idx], self.hyper_params[idx], self.thetas[idx])
                theta = self._header_params(theta_net)
                param = self._header_params(net)
                grads = self._local_grads(client, net, theta_net, pass_net=True)
                if args.alg == 'DSGDA-GT':
                    estimates = self._gt_estimates(*grads, ck)
                else:
                    estimates = self._prox_estimates(*grads, theta, param, ck, gama)

            if estimates is None:
                raise ValueError("args.inner_ep must be >= 1 to compute client estimates")
            h_yi, h_thetai, h_xi = estimates
            h_x.append(h_xi)
            h_y.append(h_yi)
            h_theta.append(h_thetai)

        return h_y, h_theta, h_x

    def client_job_HR(self, args):
        """Compute local estimates on both the current and the previous (``*_old``) iterate.

        Per client, the proximal combination is evaluated:

        * on the current models and snapshots with coefficient ``ck``; and
        * on the ``*_old`` models and snapshots with coefficient ``ck_1``.

        The old-iterate coupling term uses the old models' "header" parameters.
        Both evaluations are repeated ``self.args.inner_ep`` times, and only the last
        repetition is kept.

        Args:
            args: unused; kept for interface compatibility.

        Returns:
            ``(h_y, h_theta, h_x, h_y_old, h_theta_old, h_x_old)``: one entry per
            client id, in iteration order.

        Raises:
            ValueError: if ``self.args.inner_ep < 1`` and there is at least one client.
        """
        h_y, h_theta, h_x = [], [], []
        h_y_old, h_theta_old, h_x_old = [], [], []

        for idx in self.client_idx:
            current = old = None
            for _ in range(self.args.inner_ep):
                current = self._hr_estimates(
                    idx, self.net_glob_sum[idx], self.net_glob_theta_sum[idx],
                    self.params[idx], self.hyper_params[idx], self.thetas[idx], self.ck)
                old = self._hr_estimates(
                    idx, self.net_glob_sum_old[idx], self.net_glob_theta_sum_old[idx],
                    self.params_old[idx], self.hyper_params_old[idx], self.thetas_old[idx],
                    self.ck_1)

            if current is None:
                raise ValueError("args.inner_ep must be >= 1 to compute client estimates")
            h_yi, h_thetai, h_xi = current
            h_yi_old, h_thetai_old, h_xi_old = old
            h_x.append(h_xi)
            h_y.append(h_yi)
            h_theta.append(h_thetai)
            h_x_old.append(h_xi_old)
            h_y_old.append(h_yi_old)
            h_theta_old.append(h_thetai_old)

        return h_y, h_theta, h_x, h_y_old, h_theta_old, h_x_old