import copy
# from cv2 import log
import numpy as np
import torch
import copy
import time
from core.function import gather_flat_hyper_params
from utils.Fed import FedAvg, FedAvgGradient, FedAvgP
from core.SGDClient_hr import SGDClient
from core.SVRGClient_hr import SVRGClient
from core.Client_hr import Client
from core.ClientManage import ClientManage


class ClientManageHR(ClientManage):
    """Client manager that runs the local ``theta`` / ``x`` / ``y`` updates of each client.

    Three groups of tensors are updated:

    * ``x``: the first ``len(hyper_param)`` parameters of the client network
      (``client.net``), updated with step size ``eta[2]``;
    * ``y``: the remaining parameters of ``client.net``, updated with step
      size ``eta[0]``;
    * ``theta``: the parameters of ``net_glob_theta`` after its first
      ``len(hyper_param)`` entries, updated with step size ``eta[1]``.
      ``theta`` and ``y`` are coupled through the weight ``gamma``.
    """

    def __init__(self, args, net_glob, net_glob_theta, client_idx, dataset, dict_users, hyper_param, param, theta, ck, gamma) -> None:
        """Store private copies of the global state used by ``client_job``.

        Args:
            args: run configuration; ``client_job`` reads ``args.inner_ep``.
            net_glob: global network handed to every ``SVRGClient``; its
                ``header`` parameters play the role of ``y`` in the theta step.
            net_glob_theta: global network whose non-hyper parameters hold ``theta``.
            client_idx: client ids processed by ``client_job``.
            dataset, dict_users: forwarded to each ``SVRGClient``.
            hyper_param, param, theta: tensor lists for ``x``, ``y`` and
                ``theta``. They are cloned and also kept as the "old" reference
                values for the displacement computed in ``client_job``.
            ck: weight applied to the ``grad_d_out_*`` gradient terms.
            gamma: weight of the ``theta``/``y`` coupling term.
        """
        super().__init__(args, net_glob, client_idx, dataset, dict_users)

        self.client_idx = client_idx
        self.args = args
        self.dataset = dataset
        self.dict_users = dict_users

        # Private working copies; client_job mutates these, never the caller's nets.
        self.net_glob_theta = copy.deepcopy(net_glob_theta)
        self.net_glob = copy.deepcopy(net_glob)
        self.param = [x.clone().detach() for x in param]
        self.theta = [x.clone().detach() for x in theta]
        self.hyper_param = [x.clone().detach() for x in hyper_param]

        # Reference values for the per-client displacement (h_x, h_y, h_theta).
        self.param_old = [x.clone().detach() for x in param]
        self.theta_old = [x.clone().detach() for x in theta]
        self.hyper_param_old = [x.clone().detach() for x in hyper_param]
        self.ck = ck
        self.gamma = gamma
        # Snapshots used to reset the working networks before each client.
        # state_dict() shares storage with the live module, so deep-copy it
        # to keep later in-place changes to the caller's nets from leaking in.
        self.state_dict_net = copy.deepcopy(net_glob.state_dict())
        self.state_dict_net_theta = copy.deepcopy(net_glob_theta.state_dict())

    def client_job(self, eta):
        """Run ``args.inner_ep`` local steps on every client and return their displacements.

        Args:
            eta: indexable step sizes: ``eta[0]`` for ``y``, ``eta[1]`` for
                ``theta`` and ``eta[2]`` for ``x``.

        Returns:
            ``(h_y, h_theta, h_x)``: one entry per client, in ``client_idx``
            order. Each entry is a list holding, per tensor,
            ``(old - new) / (eta[i] * args.inner_ep)``, i.e. the total local
            movement rescaled to a single step's direction.

        Raises:
            ValueError: if ``args.inner_ep`` is smaller than 1 (no local step
                would run and the rescaling would divide by zero).
        """
        inner_ep = self.args.inner_ep
        if inner_ep < 1:
            raise ValueError(f"args.inner_ep must be >= 1, got {inner_ep}")

        n_hyper = len(self.hyper_param)
        h_y, h_theta, h_x = [], [], []

        for idx in self.client_idx:
            # Every client starts from the same global snapshot.
            self.net_glob_theta.load_state_dict(self.state_dict_net_theta)
            self.net_glob.load_state_dict(self.state_dict_net)

            for _ in range(inner_ep):
                # A fresh client per local step, built from the current working net.
                client = SVRGClient(self.args, idx, self.net_glob, self.dataset, self.dict_users, self.param,
                                    self.hyper_param, self.theta)
                self._step_theta_x_y(client, eta)

            theta_params = list(self.net_glob_theta.parameters())[n_hyper:]
            client_params = list(client.net.parameters())
            h_theta.append(self._scaled_displacement(theta_params, self.theta_old, eta[1] * inner_ep))
            h_x.append(self._scaled_displacement(client_params[:n_hyper], self.hyper_param_old, eta[2] * inner_ep))
            h_y.append(self._scaled_displacement(client_params[n_hyper:], self.param_old, eta[0] * inner_ep))

        return h_y, h_theta, h_x

    def _step_theta_x_y(self, client, eta):
        """Apply one local update of ``theta``, then ``x``, then ``y``.

        All six gradients are requested from ``client`` before any parameter
        is written. The ``y`` step then reads the already-updated ``theta``.
        """
        ck, gamma = self.ck, self.gamma
        n_hyper = len(self.hyper_param)

        # Parameter objects (not copies): later .data writes are visible through them.
        theta = list(self.net_glob_theta.header.parameters())
        y_ref = list(self.net_glob.header.parameters())
        self.net_glob.zero_grad()
        self.net_glob_theta.zero_grad()
        theta_grad = client.grad_d_in_d_y(self.net_glob_theta)
        direct_grad_x = client.grad_d_in_d_x()
        indirect_grad1_x = client.grad_d_out_d_x()
        indirect_grad2_x = client.grad_d_in_d_x(self.net_glob_theta)
        direct_grad_y = client.grad_d_in_d_y()
        indirect_grad1_y = client.grad_d_out_d_y()

        # Parameter writes need no autograd history; without this, terms such
        # as ``p.data - y_ref[i]`` build throwaway graphs through live parameters.
        with torch.no_grad():
            theta_params = list(self.net_glob_theta.parameters())[n_hyper:]
            for i, p in enumerate(theta_params):
                p.data = p.data - eta[1] * (theta_grad[i] + gamma * (p.data - y_ref[i]))

            client_params = list(client.net.parameters())
            # x: the leading n_hyper parameters of the client network.
            for i, p in enumerate(client_params[:n_hyper]):
                p.data = p.data - eta[2] * (direct_grad_x[i] + ck * indirect_grad1_x[i] - indirect_grad2_x[i])
            # y: everything after x; pulled toward the freshly updated theta.
            for i, p in enumerate(client_params[n_hyper:]):
                p.data = p.data - eta[0] * (direct_grad_y[i] + ck * indirect_grad1_y[i]
                                            + gamma * (theta[i] - p.data))

    @staticmethod
    def _scaled_displacement(params, old, scale):
        """Return ``(old[i] - params[i]) / scale`` for every tensor in ``params``."""
        return [(old[i] - p.data) / scale for i, p in enumerate(params)]


