import copy
from cv2 import log
import numpy as np

import torch
from torch.optim import SGD

from core.function import gather_flat_hyper_params ,get_trainable_hyper_params 
from utils.Comm import ComTopo, communication, communication1

from core.function import *
from core.SGDClient import SGDClient
from core.SVRGClient import SVRGClient
from core.NewThetaClient import NewThetaClient
from core.Client import Client
from core.ClientManage import ClientManage

class NewClientManage(ClientManage):
    """Combine per-client SVRGClient gradients into update directions.

    For each client index ``idx`` the manager holds:

    * a model ``net_glob_sum[idx]``, whose parameters receive the ``h_y``
      directions;
    * an auxiliary model ``net_glob_theta_sum[idx]``, whose parameters receive
      the ``h_theta`` directions;
    * a hyper-parameter tensor ``hyper_params[idx]``, which the
      ``client_job*`` methods update in place.

    The ``*_old`` counterparts describe the previous iterate and are only read
    by :meth:`client_job_HR`.
    """

    # Algorithms for which client_job_sl knows how to compute a direction.
    _SL_ALGORITHMS = ('DGD-T', 'DGT-T')

    def __init__(self, args, net_glob_sum, net_glob_theta_sum, net_glob_sum_old, net_glob_theta_sum_old,
                 client_idx, dataset, dict_users, hyper_params, params, thetas, hyper_params_old,
                 params_old, thetas_old, ck, ck_1, gama, v_x) -> None:
        """Store models, hyper-parameters and scalar coefficients for all clients.

        Args:
            args: run configuration; ``alg``, ``inner_ep`` and ``beta`` are read
                by this class.
            net_glob_sum, net_glob_theta_sum: per-client models, indexed by
                client id.
            net_glob_sum_old, net_glob_theta_sum_old: the same models at the
                previous iterate.
            client_idx: iterable of client ids this manager processes.
            dataset, dict_users: forwarded to every SVRGClient.
            hyper_params, params, thetas: per-client values forwarded to
                SVRGClient; ``hyper_params[idx]`` is also updated in place.
            hyper_params_old, params_old, thetas_old: previous-iterate
                counterparts.
            ck, ck_1: scalar weights used at the current / previous iterate.
            gama: coefficient multiplying ``theta - param`` in the y/theta
                directions.
            v_x: per-client hyper-gradient estimates, read and mutated by
                client_job_HR.
        """
        super().__init__(args, net_glob_sum, client_idx, dataset, dict_users)
        self.client_idx = client_idx
        self.args = args
        self.dataset = dataset
        self.dict_users = dict_users

        self.net_glob_theta_sum = net_glob_theta_sum
        self.net_glob_sum = net_glob_sum
        self.net_glob_theta_sum_old = net_glob_theta_sum_old
        self.net_glob_sum_old = net_glob_sum_old
        self.params = params
        self.thetas = thetas
        self.hyper_params = hyper_params
        self.params_old = params_old
        self.thetas_old = thetas_old
        self.hyper_params_old = hyper_params_old
        self.ck_1 = ck_1
        self.ck = ck
        self.gama = gama

        self.v_x = v_x
        # Only beta[2] is used (by client_job_HR).
        self.beta = args.beta

    # ------------------------------------------------------------------ helpers

    def _check_inner_ep(self):
        """Raise ValueError unless at least one inner step will run per client.

        With ``inner_ep < 1`` the per-client results are never assigned, and
        the caller would fail with an opaque UnboundLocalError.
        """
        if self.args.inner_ep < 1:
            raise ValueError(f"args.inner_ep must be >= 1, got {self.args.inner_ep!r}")

    @staticmethod
    def _zero_grad(*nets):
        """Call ``zero_grad()`` on every module in ``nets``."""
        for net in nets:
            net.zero_grad()

    @staticmethod
    def _pad_like(template, head):
        """Return ``[zeros_like(t) for t in template]`` with its leading entries replaced by ``head``.

        Entries past ``len(head)`` stay zero. An IndexError is raised if
        ``head`` is longer than ``template``.
        """
        padded = [torch.zeros_like(t) for t in template]
        for k, value in enumerate(head):
            padded[k] = value
        return padded

    @staticmethod
    def _descend(hyper_param, direction, lr):
        """In place, set ``hyper_param[k] = hyper_param[k] - lr * direction[k]`` for each leading index.

        Writes go through ``.detach()``, which shares storage with
        ``hyper_param``. This allows the in-place assignment even when
        ``hyper_param`` is a leaf that requires grad.
        """
        for k in range(len(hyper_param)):
            hyper_param.detach()[k] = hyper_param.detach()[k] - lr * direction[k]

    # --------------------------------------------------------------- client jobs

    def client_job_sl(self, gamma):
        """Collect ``SVRGClient.grad_d_t`` directions for the 'DGD-T' / 'DGT-T' algorithms.

        Args:
            gamma: unused; kept for signature parity with the other
                ``client_job*`` methods.

        Returns:
            list: one ``grad_d_t`` result per (client, inner step) pair.

        Raises:
            ValueError: if ``args.alg`` is not one of ``_SL_ALGORITHMS`` while
                there is work to do.
        """
        h_y = []
        for idx in self.client_idx:
            for _ in range(self.args.inner_ep):
                if self.args.alg not in self._SL_ALGORITHMS:
                    raise ValueError(
                        f"client_job_sl supports alg in {self._SL_ALGORITHMS}, got {self.args.alg!r}")
                client = SVRGClient(self.args, idx, self.net_glob_sum[idx], self.dataset, self.dict_users,
                                    self.params[idx], self.hyper_params[idx], self.thetas[idx])
                # NOTE(review): this appends once per inner step, unlike client_job, which
                # appends once per client. Nothing is updated between steps; confirm intended.
                h_y.append(client.grad_d_t(self.net_glob_sum[idx]))
        return h_y

    def client_job(self, gamma):
        """Run ``args.inner_ep`` local steps per client and return the last step's directions.

        Each inner step:

        1. builds a fresh SVRGClient and queries its gradients;
        2. updates ``hyper_params[idx]`` in place with step size ``gamma[0]``;
        3. forms the ``h_y`` / ``h_theta`` directions.

        When ``args.alg == 'DSGDA-GT'`` a different weighting with ``ck`` is
        used; otherwise ``gama * (theta - param)`` coupling terms are added.

        Args:
            gamma: sequence of step sizes; only ``gamma[0]`` is used.

        Returns:
            tuple ``(h_y, h_theta, hyper_params_record)``, each with one entry
            per client: the y-direction list, the theta-direction list and the
            (in-place updated) hyper-parameter tensor.

        Raises:
            ValueError: if ``args.inner_ep < 1`` while there are clients to
                process.
        """
        h_y = []
        h_theta = []
        hyper_params_record = []
        ck = self.ck
        gama = self.gama

        for idx in self.client_idx:
            self._check_inner_ep()
            net = self.net_glob_sum[idx]
            theta_net = self.net_glob_theta_sum[idx]

            for _ in range(self.args.inner_ep):
                client = SVRGClient(self.args, idx, net, self.dataset, self.dict_users,
                                    self.params[idx], self.hyper_params[idx], self.thetas[idx])
                hyper_param = self.hyper_params[idx]
                theta = list(theta_net.parameters())
                param = list(net.parameters())

                # Reset accumulated .grad between groups of gradient queries.
                theta_net.zero_grad()
                theta_grad = client.grad_d_in_d_y(theta_net, hyper_param)
                self._zero_grad(net, theta_net)
                direct_grad_x = client.grad_d_in_d_x(net, hyper_param)
                indirect_grad2_x = client.grad_d_in_d_x(theta_net, hyper_param)
                self._zero_grad(net, theta_net)
                direct_grad_y = client.grad_d_in_d_y(net, hyper_param)
                indirect_grad1_y = client.grad_d_out_d_y(net)

                if self.args.alg == 'DSGDA-GT':
                    h_thetai = self._pad_like(param, [-ck * theta_grad[k] for k in range(len(theta))])
                    h_xi = [ck * (direct_grad_x[k] - indirect_grad2_x[k]) for k in range(len(hyper_param))]
                    self._descend(hyper_param, h_xi, gamma[0])
                    h_yi = [indirect_grad1_y[k] - ck * direct_grad_y[k] for k in range(len(param))]
                else:
                    h_thetai = self._pad_like(
                        param, [theta_grad[k] + gama * (theta[k] - param[k]) for k in range(len(theta))])
                    h_xi = [direct_grad_x[k] - indirect_grad2_x[k] for k in range(len(hyper_param))]
                    self._descend(hyper_param, h_xi, gamma[0])
                    h_yi = [direct_grad_y[k] + ck * indirect_grad1_y[k] + gama * (theta[k] - param[k])
                            for k in range(len(param))]

            hyper_params_record.append(hyper_param)
            h_y.append(h_yi)
            h_theta.append(h_thetai)

        return h_y, h_theta, hyper_params_record

    def client_job_HR(self, gamma):
        """Like :meth:`client_job`, but step the hyper-parameters along a recursive estimate ``v_x``.

        In every inner step the x/y/theta directions are evaluated twice:

        * at the current iterate (``net_glob_sum``, ``net_glob_theta_sum``,
          ``hyper_params``);
        * at the previous iterate (the ``*_old`` attributes).

        Then, both in place:

        * ``v_x[idx] <- h_x + (1 - beta[2]) * (v_x[idx] - h_x_old)``
          (a STORM-style recursive-momentum update);
        * ``hyper_params[idx] -= gamma[0] * v_x[idx]``.

        Args:
            gamma: sequence of step sizes; only ``gamma[0]`` is used.

        Returns:
            tuple ``(h_y, h_theta, hyper_params, h_y_old, h_theta_old, v_x)``.
            The first five hold one entry per client, taken from its last inner
            step. ``v_x`` is ``self.v_x``, mutated in place.

        Raises:
            ValueError: if ``args.inner_ep < 1`` while there are clients to
                process.
        """
        h_y = []
        h_theta = []
        hyper_params = []
        h_y_old = []
        h_theta_old = []
        ck = self.ck
        ck_1 = self.ck_1
        gama = self.gama
        v_x = self.v_x
        beta = self.beta

        for idx in self.client_idx:
            self._check_inner_ep()
            net = self.net_glob_sum[idx]
            theta_net = self.net_glob_theta_sum[idx]
            net_old = self.net_glob_sum_old[idx]
            theta_net_old = self.net_glob_theta_sum_old[idx]
            hyper_param_old = self.hyper_params_old[idx]

            for _ in range(self.args.inner_ep):
                # ---- directions at the current iterate ----
                client = SVRGClient(self.args, idx, net, self.dataset, self.dict_users,
                                    self.params[idx], self.hyper_params[idx], self.thetas[idx])
                hyper_param = self.hyper_params[idx]
                theta = list(theta_net.parameters())
                param = list(net.parameters())

                # Zero the *current* nets between gradient queries.
                self._zero_grad(net, theta_net)
                theta_grad = client.grad_d_in_d_y(theta_net)
                self._zero_grad(net, theta_net)
                direct_grad_x = client.grad_d_in_d_x()
                indirect_grad2_x = client.grad_d_in_d_x(theta_net)
                self._zero_grad(net, theta_net)
                direct_grad_y = client.grad_d_in_d_y()
                indirect_grad1_y = client.grad_d_out_d_y()

                h_thetai = self._pad_like(
                    param, [theta_grad[k] + gama * (theta[k] - param[k]) for k in range(len(theta))])
                h_xi = [direct_grad_x[k] - indirect_grad2_x[k] for k in range(len(hyper_param))]
                h_yi = [direct_grad_y[k] + ck * indirect_grad1_y[k] + gama * (theta[k] - param[k])
                        for k in range(len(param))]

                # ---- the same directions at the previous iterate ----
                client_old = SVRGClient(self.args, idx, net_old, self.dataset, self.dict_users,
                                        self.params_old[idx], hyper_param_old, self.thetas_old[idx])
                theta_old = list(theta_net_old.parameters())
                param_old = list(net_old.parameters())

                self._zero_grad(net_old, theta_net_old)
                theta_grad_old = client_old.grad_d_in_d_y(theta_net_old, hyper_param_old)
                self._zero_grad(net_old, theta_net_old)
                # Direct terms are taken on net_old, mirroring the current-iterate
                # computation. Taking them on theta_net_old would make direct and
                # indirect x-gradients the same call, so h_xi_old would cancel out.
                direct_grad_x_old = client_old.grad_d_in_d_x(net_old, hyper_param_old)
                indirect_grad2_x_old = client_old.grad_d_in_d_x(theta_net_old, hyper_param_old)
                self._zero_grad(net_old, theta_net_old)
                direct_grad_y_old = client_old.grad_d_in_d_y(net_old, hyper_param_old)
                indirect_grad1_y_old = client_old.grad_d_out_d_y(net_old)

                h_thetai_old = self._pad_like(
                    param_old,
                    [theta_grad_old[k] + gama * (theta_old[k] - param_old[k]) for k in range(len(theta_old))])
                h_xi_old = [direct_grad_x_old[k] - indirect_grad2_x_old[k] for k in range(len(hyper_param_old))]
                h_yi_old = [direct_grad_y_old[k] + ck_1 * indirect_grad1_y_old[k]
                            + gama * (theta_old[k] - param_old[k])
                            for k in range(len(param_old))]

                # Recursive hyper-gradient estimate, then descend along it.
                for k in range(len(hyper_param)):
                    v_x[idx][k] = h_xi[k] + (1 - beta[2]) * (v_x[idx][k] - h_xi_old[k])
                self._descend(hyper_param, v_x[idx], gamma[0])

            h_y.append(h_yi)
            h_theta.append(h_thetai)
            hyper_params.append(hyper_param)
            h_y_old.append(h_yi_old)
            h_theta_old.append(h_thetai_old)

        return h_y, h_theta, hyper_params, h_y_old, h_theta_old, v_x
    
