import copy
from cv2 import log
import numpy as np

import torch
from torch.optim import SGD

from core.function import gather_flat_hyper_params ,get_trainable_hyper_params 
from utils.Fed import FedAvg, FedAvgGradient, FedAvgP
from utils.Fed import *
from core.function import *
from core.SGDClient import SGDClient
from core.SVRGClient import SVRGClient
from core.NewThetaClient import NewThetaClient
from core.Client import Client
from core.ClientManage import ClientManage

class NewClientManage(ClientManage):
    """Client manager that runs local updates on three coupled variable groups.

    For every client id in ``client_idx`` this manager:

    1. resets a private copy of ``net_glob`` ("y") and of ``net_glob_theta``
       ("theta") to the state dicts captured at construction;
    2. runs ``args.inner_ep`` local steps, each updating theta, the client's
       hyper-parameters ("x", in place in ``hyper_params[idx]``) and y;
    3. reports how far theta and y moved from ``theta_old`` / ``param_old``,
       scaled by ``1 / (step_size * inner_ep)``.

    Theta and y are coupled through ``gamma``-weighted proximal terms.
    """

    def __init__(self, args, net_glob, net_glob_theta, client_idx, dataset, dict_users, hyper_params, param, theta, gamma, ck) -> None:
        """
        Args:
            args: run configuration. ``args.inner_ep`` (the number of local
                steps per client) is read by :meth:`client_job`; the whole
                object is also forwarded to ``SVRGClient``.
            net_glob: model whose parameters are the "y" variables.
            net_glob_theta: model whose parameters are the "theta" variables.
            client_idx: iterable of client ids handled by this manager.
            dataset: forwarded to each ``SVRGClient``.
            dict_users: forwarded to each ``SVRGClient``.
            hyper_params: indexable by client id. Each entry is that client's
                hyper-parameters and is updated in place (presumably a
                tensor, since ``.detach()`` is called on it).
            param: iterable of tensors snapshotted as ``param_old``, the
                reference point for the y displacement.
            theta: iterable of tensors snapshotted as ``theta_old``, the
                reference point for the theta displacement.
            gamma: weight of the proximal coupling term between y and theta.
            ck: weight applied to the ``grad_d_out_d_y`` term in the y update.
        """
        super().__init__(args, net_glob, client_idx, dataset, dict_users)
        # NOTE(review): possibly redundant with ClientManage.__init__; kept
        # because the base class is not visible here.
        self.client_idx = client_idx
        self.args = args
        self.dataset = dataset
        self.dict_users = dict_users

        # Private copies, so local updates never touch the caller's models.
        self.net_glob_theta = copy.deepcopy(net_glob_theta)
        self.net_glob = copy.deepcopy(net_glob)
        # NOTE(review): these are the parameters of the caller's *original*
        # models, not of the private copies above. They are handed to
        # SVRGClient unchanged. Confirm this aliasing is intended.
        self.param = list(net_glob.parameters())
        self.theta = list(net_glob_theta.parameters())

        # Frozen snapshots: the reference point for the displacements that
        # client_job returns.
        self.param_old = [x.clone().detach() for x in param]
        self.theta_old = [x.clone().detach() for x in theta]

        self.ck = ck
        self.gamma = gamma
        # Every client starts from these states. state_dict() tensors share
        # storage with the original models' parameters.
        self.state_dict_net = net_glob.state_dict()
        self.state_dict_net_theta = net_glob_theta.state_dict()
        self.hyper_params = hyper_params

    def client_job(self, eta):
        """Run local updates for every client and return scaled displacements.

        For each client, the private y/theta models are reset to the initial
        state dicts. Then ``self.args.inner_ep`` local steps are taken. Each
        step first queries all gradients, then updates theta, x and y, in
        that order.

        Args:
            eta: indexable step sizes. ``eta[0]`` is for y, ``eta[1]`` for
                theta and ``eta[2]`` for x.

        Returns:
            ``(h_y, h_theta)``. Each holds one list per client; each list has
            one tensor per parameter, computed as
            ``(old - new) / (step_size * inner_ep)``, where ``old`` comes from
            ``self.param_old`` / ``self.theta_old``.

        Raises:
            ValueError: if ``self.args.inner_ep`` is less than 1. In that
                case no local step would run and the scaling would divide
                by zero.
        """
        inner_ep = self.args.inner_ep
        if inner_ep < 1:
            raise ValueError(f"args.inner_ep must be >= 1, got {inner_ep}")

        h_y = []
        h_theta = []

        for idx in self.client_idx:
            # Every client starts from the same initial y and theta.
            self.net_glob_theta.load_state_dict(self.state_dict_net_theta)
            self.net_glob.load_state_dict(self.state_dict_net)

            for _ in range(inner_ep):
                # NOTE(review): a fresh SVRGClient is built every local step
                # from self.net_glob. If SVRGClient copies the network, the y
                # updates applied to client.net do not carry over to the next
                # step. Confirm.
                client = SVRGClient(self.args, idx, self.net_glob, self.dataset, self.dict_users, self.param,
                                    self.hyper_params[idx], self.theta)
                hyper_param = self.hyper_params[idx]
                theta = list(self.net_glob_theta.parameters())
                param = list(self.net_glob.parameters())

                # All gradients are taken before any update. The zero_grad
                # calls presumably keep .grad buffers from accumulating across
                # the different gradient queries (TODO confirm against the
                # client helpers).
                self.net_glob_theta.zero_grad()
                theta_grad = client.grad_d_in_d_y(self.net_glob_theta)
                self.net_glob.zero_grad()
                self.net_glob_theta.zero_grad()
                direct_grad_x = client.grad_d_in_d_x()
                indirect_grad2_x = client.grad_d_in_d_x(self.net_glob_theta)
                self.net_glob.zero_grad()
                self.net_glob_theta.zero_grad()
                direct_grad_y = client.grad_d_in_d_y()
                indirect_grad1_y = client.grad_d_out_d_y()

                # The updates are plain arithmetic on Parameters. no_grad
                # avoids building throwaway autograd graphs through them.
                with torch.no_grad():
                    self._update_theta(theta_grad, param, eta[1])
                    self._update_hyper_param(hyper_param, direct_grad_x, indirect_grad2_x, eta[2])
                    # NOTE(review): `theta` aliases the Parameters just
                    # updated, so y sees this step's new theta.
                    self._update_y(client, direct_grad_y, indirect_grad1_y, theta, eta[0])

            h_thetai = self._scaled_displacement(self.net_glob_theta.parameters(), self.theta_old,
                                                 eta[1] * inner_ep)
            h_yi = self._scaled_displacement(client.net.parameters(), self.param_old,
                                             eta[0] * inner_ep)
            h_y.append(h_yi)
            h_theta.append(h_thetai)

        return h_y, h_theta

    def _update_theta(self, theta_grad, param, step):
        """Take a gradient step on theta, pulled toward the current y (``param``) by ``gamma``."""
        for k, p in enumerate(self.net_glob_theta.parameters()):
            p.data = p.data - step * (theta_grad[k] + self.gamma * (p.data - param[k]))

    @staticmethod
    def _update_hyper_param(hyper_param, direct_grad_x, indirect_grad2_x, step):
        """Take an in-place gradient step on x (``hyper_param``), one leading index at a time."""
        # Writes go through a detached view: the storage of hyper_param is
        # modified in place without being recorded by autograd.
        hp = hyper_param.detach()
        for k in range(len(hyper_param)):
            hp[k] = hp[k] - step * (direct_grad_x[k] - indirect_grad2_x[k])

    def _update_y(self, client, direct_grad_y, indirect_grad1_y, theta, step):
        """Take a gradient step on y (``client.net``), scaling the outer term by ``ck``."""
        for k, p in enumerate(client.net.parameters()):
            p.data = p.data - step * (direct_grad_y[k] + self.ck * indirect_grad1_y[k]
                                      + self.gamma * (theta[k] - p.data))

    @staticmethod
    def _scaled_displacement(params, reference, scale):
        """Return ``(reference[k] - params[k]) / scale`` for each parameter tensor."""
        return [(reference[k] - p.data) / scale for k, p in enumerate(params)]
            


    
