import copy
from cv2 import log
import numpy as np

import torch
from torch.optim import SGD

from core.function import gather_flat_hyper_params ,get_trainable_hyper_params 
from utils.Fed import FedAvg, FedAvgGradient, FedAvgP
from utils.Fed import *
from core.function import *
from core.SGDClient import SGDClient
from core.SVRGClient import SVRGClient
from core.NewThetaClient import NewThetaClient
from core.Client import Client
from core.ClientManage import ClientManage

class NewClientManage(ClientManage):
    """Run local update steps for a set of clients and report their progress.

    For every client index the manager resets two working networks, performs
    ``args.inner_ep`` local steps on the theta network, on the hyper-parameters
    ``dy``/``ly`` and on the client's network, and then reports how far each
    group of variables moved, scaled by its step size and the number of steps.
    """

    def __init__(self, args, net_glob, net_glob_theta, client_idx, dataset, dict_users, hyper_param, param, theta, gamma, ck) -> None:
        """Store the shared state used by :meth:`client_job`.

        Args:
            args: Run configuration; ``args.inner_ep`` is read by
                :meth:`client_job` and ``args`` is forwarded to each client.
            net_glob: Network whose deep copy is used as the working model.
            net_glob_theta: Network whose deep copy is used as the working
                theta model.
            client_idx: Iterable of client indices to process.
            dataset: Dataset forwarded to each client.
            dict_users: Forwarded to each client together with ``dataset``.
            hyper_param: Mapping with tensor entries ``'dy'``, ``'ly'`` and
                ``'wy'``; the entries are cloned, so the caller's tensors are
                not modified by this class.
            param: Iterable of tensors cloned into ``self.param_old``, the
                reference point for the reported change of the client network.
            theta: Iterable of tensors cloned into ``self.theta_old``, the
                reference point for the reported change of the theta network.
            gamma: Coefficient of the ``theta - parameter`` coupling term.
            ck: Coefficient applied to ``grad_d_out_d_y`` in the network step.
        """
        super().__init__(args, net_glob, client_idx, dataset, dict_users)
        self.client_idx = client_idx
        self.args = args
        self.dataset = dataset
        self.dict_users = dict_users
        self.gamma = gamma
        self.ck = ck

        # Working copies that client_job mutates; the caller's networks are
        # left untouched.
        self.net_glob = copy.deepcopy(net_glob)
        self.net_glob_theta = copy.deepcopy(net_glob_theta)

        # Parameter lists of the caller's networks (not of the working
        # copies); they are handed to every SVRGClient.
        self.param = list(net_glob.parameters())
        self.theta = list(net_glob_theta.parameters())

        # State dicts of the caller's networks, loaded into the working copies
        # at the start of every client.
        self.state_dict_net = net_glob.state_dict()
        self.state_dict_net_theta = net_glob_theta.state_dict()

        # Private copy of the hyper-parameters; only 'dy' and 'ly' are marked
        # as requiring gradients.
        hyper_keys = ('dy', 'ly', 'wy')
        self.hyper_param = {key: hyper_param[key].clone().detach() for key in hyper_keys}
        self.hyper_param['dy'].requires_grad_()
        self.hyper_param['ly'].requires_grad_()

        # Frozen reference values used to measure the progress of each client.
        self.param_old = [x.clone().detach() for x in param]
        self.theta_old = [x.clone().detach() for x in theta]
        self.hyper_param_old = {key: hyper_param[key].clone().detach() for key in hyper_keys}

    def client_job(self, eta):
        """Run the local steps of every client and collect scaled differences.

        Args:
            eta: Indexable with three step sizes: ``eta[0]`` for the client
                network, ``eta[1]`` for the theta network and ``eta[2]`` for
                the hyper-parameters.

        Returns:
            A tuple ``(h_y, h_theta, h_x)``. Each element is a list with one
            entry per client in ``self.client_idx``; each entry is a list of
            tensors ``(old - new) / (step_size * args.inner_ep)``, where ``old``
            comes from ``param_old``, ``theta_old`` and ``hyper_param_old``
            respectively. The ``h_x`` entries follow the key order of
            ``self.hyper_param`` (``'dy'``, ``'ly'``, ``'wy'``); ``'wy'`` is
            not updated by this method.

        Raises:
            ValueError: If there is at least one client and
                ``args.inner_ep`` is smaller than 1.
        """
        ck = self.ck
        gamma = self.gamma
        inner_ep = self.args.inner_ep

        h_y = []
        h_theta = []
        h_x = []

        for idx in self.client_idx:
            # Without at least one local step no client object exists to read
            # the updated parameters from. Checked per client so that an empty
            # client list still yields three empty lists.
            if inner_ep < 1:
                raise ValueError(
                    f"args.inner_ep must be at least 1, got {inner_ep!r}"
                )

            # Every client starts from the stored state dicts.
            # NOTE(review): self.hyper_param is not reset here, so the
            # hyper-parameter steps accumulate across clients -- confirm
            # this is intended.
            self.net_glob_theta.load_state_dict(self.state_dict_net_theta)
            self.net_glob.load_state_dict(self.state_dict_net)

            for _ in range(inner_ep):
                # NOTE(review): a fresh client is built for every local step;
                # the network step below only carries over between steps if
                # client.net refers to self.net_glob -- verify in SVRGClient.
                client = SVRGClient(self.args, idx, self.net_glob, self.dataset,
                                    self.dict_users, self.param,
                                    self.hyper_param, self.theta)

                theta = list(self.net_glob_theta.parameters())
                param = list(self.net_glob.parameters())
                self.net_glob.zero_grad()
                self.net_glob_theta.zero_grad()

                # All gradients are evaluated before any variable is changed.
                theta_grad = client.grad_d_in_d_y(self.net_glob_theta)
                direct_grad_x = client.grad_d_in_d_x()
                indirect_grad2_x = client.grad_d_in_d_x(self.net_glob_theta)
                direct_grad_y = client.grad_d_in_d_y()
                indirect_grad1_y = client.grad_d_out_d_y()

                # Step on the theta network.
                for i, theta_p in enumerate(self.net_glob_theta.parameters()):
                    theta_p.data = theta_p.data - eta[1] * (
                        theta_grad[i] + gamma * (theta_p.data - param[i])
                    )

                # Step on the hyper-parameters 'dy' and 'ly'.
                self.hyper_param['dy'] = self.hyper_param['dy'] - eta[2] * (
                    direct_grad_x[0] - indirect_grad2_x[0]
                )
                self.hyper_param['ly'] = self.hyper_param['ly'] - eta[2] * (
                    direct_grad_x[1] - indirect_grad2_x[1]
                )

                # Step on the client network; `theta` holds the parameter
                # objects of the theta network, so the values updated just
                # above are the ones used here.
                for i, net_p in enumerate(client.net.parameters()):
                    net_p.data = net_p.data - eta[0] * (
                        direct_grad_y[i]
                        + ck * indirect_grad1_y[i]
                        + gamma * (theta[i] - net_p.data)
                    )

            # Difference between the reference values and the values reached
            # after the local steps, scaled by step size and step count.
            h_theta.append([
                (-theta_p.data + self.theta_old[i]) / (eta[1] * inner_ep)
                for i, theta_p in enumerate(self.net_glob_theta.parameters())
            ])
            h_y.append([
                (-net_p.data + self.param_old[i]) / (eta[0] * inner_ep)
                for i, net_p in enumerate(client.net.parameters())
            ])
            h_x.append([
                (-value.data + self.hyper_param_old[name]) / (eta[2] * inner_ep)
                for name, value in self.hyper_param.items()
            ])

        return h_y, h_theta, h_x
            


    
