import socket
import time
import struct
import numpy as np

from control_algorithm.adaptive_tau import ControlAlgAdaptiveTauClient, ControlAlgAdaptiveTauServer
from data_reader.data_reader import get_data, get_data_train_samples
from models.get_model import get_model
from util.sampling import MinibatchSampling
from util.utils import send_msg, recv_msg

from scipy.stats import mode

# Configurations are in a separate config.py file
from config import SERVER_ADDR, SERVER_PORT, dataset_file_path, n_nodes

# Connect to the server once at start-up; this socket carries every message
# exchanged by the client main loop below (recv_msg / send_msg).
sock = socket.socket()
sock.connect((SERVER_ADDR, SERVER_PORT))

print('---------------------------------------------------------------------------')

# Configuration seen in the previous MSG_INIT_SERVER_TO_CLIENT message. The main
# loop compares the new message against these to decide whether the training
# data must be re-read before the next run.
batch_size_prev = None
total_data_prev = None
sim_prev = None

import csv

'''
def save_to_csv(case, mu_new, B, H, file_name='results.csv'):
    with open(file_name, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([case, mu_new, B, H])'''

def save_to_csv(case, mu_list, B, H, file_name='results.csv'):
    """Append one ``[case, mode(mu), B, H]`` row to a CSV file.

    Args:
        case: Label written in the first column.
        mu_list: Scalar, sequence or array of mu values. The most frequent
            value over all of its elements (``scipy.stats.mode``) is written
            in the second column, or an empty field when it has no elements.
        B: Value written as-is in the third column.
        H: Value written as-is in the fourth column.
        file_name: CSV file opened in append mode.
    """
    # Flatten so scalars, lists and N-d arrays are all treated as one sample
    # (scipy's mode would otherwise work column-wise on 2-D input).
    values = np.ravel(np.asarray(mu_list))
    if values.size > 0:
        # scipy < 1.11 returns the mode as a length-1 array, scipy >= 1.11
        # returns a scalar; ravel()[0] extracts the value in both cases.
        mu_mode = np.ravel(mode(values)[0])[0]
    else:
        mu_mode = None

    with open(file_name, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([case, mu_mode, B, H])


# Number of nodes from config.py. Used as the divisor M in the B/H estimates
# (gradient_square_sum, calculate_H2, loss_function) and in the CSV case label.
M = n_nodes  


def gradient_square_sum(gradients, M):
    """Return (sum of squared L2 norms of ``gradients``) / ``M``."""
    accumulated = 0
    for vector in gradients:
        accumulated += np.linalg.norm(vector) ** 2
    return accumulated / M

def global_gradient_square(global_grad):
    """Return the squared L2 norm of ``global_grad``."""
    magnitude = np.linalg.norm(global_grad)
    return magnitude ** 2

def calculate_H2(local_grads, global_grad, B_squared, M):
    """Return sum_i(||g_i||^2 - B_squared * ||global_grad||^2) / M.

    The result is not clamped and may be negative.
    """
    scaled_global = B_squared * np.linalg.norm(global_grad) ** 2
    residual_total = sum(np.linalg.norm(grad) ** 2 - scaled_global for grad in local_grads)
    return residual_total / M

def calculate_B_and_H(local_grads, global_grad, M):
    """Estimate the gradient-dissimilarity constants ``(B, H)``.

    ``B^2`` is the ratio of ``gradient_square_sum(local_grads, M)`` to the
    squared norm of ``global_grad``; ``H^2`` is then ``calculate_H2`` evaluated
    with that ``B^2``.

    Args:
        local_grads: Iterable of local gradient arrays.
        global_grad: Reference gradient array.
        M: Divisor used by both averages.

    Returns:
        Tuple ``(B, H)`` of non-negative values (NaN only if the inputs are).
    """
    local_grad_square_sum = gradient_square_sum(local_grads, M)
    global_grad_square_sum = global_gradient_square(global_grad)

    if global_grad_square_sum > 0:
        B_squared = local_grad_square_sum / global_grad_square_sum
    else:
        # A vanishing reference gradient would make B^2 inf/nan (division by
        # zero). Use B^2 = 0 so calculate_H2 attributes the whole averaged
        # squared local norm to H^2 instead.
        B_squared = 0.0

    H_squared = calculate_H2(local_grads, global_grad, B_squared, M)
    # H^2 can come out slightly negative from rounding, or clearly negative
    # when len(local_grads) > M; clamp at 0 (matching the non-negativity bound
    # on H^2 in the L-BFGS-B fit) so np.sqrt does not return NaN.
    H_squared = max(H_squared, 0.0)

    return np.sqrt(B_squared), np.sqrt(H_squared)



def loss_function(params, local_grads, global_grad, num_nodes=None):
    """Squared residual of ``sum ||g_i||^2 / num_nodes = B^2 * ||global_grad||^2 + H^2``.

    Used as the objective for ``scipy.optimize.minimize`` to fit ``B^2`` and
    ``H^2``.

    Args:
        params: Sequence ``[B_squared, H_squared]``.
        local_grads: Iterable of local gradient arrays.
        global_grad: Reference gradient array.
        num_nodes: Divisor for the summed squared local-gradient norms.
            Defaults to the module-level ``M`` so existing callers (including
            ``minimize(..., args=(local_grads, global_grad))``) behave as before.

    Returns:
        The scalar squared residual.
    """
    if num_nodes is None:
        num_nodes = M
    B_squared, H_squared = params[0], params[1]

    local_grad_square_sum = sum([np.linalg.norm(g)**2 for g in local_grads]) / num_nodes
    global_grad_square = np.linalg.norm(global_grad)**2
    right_side = B_squared * global_grad_square + H_squared

    return (local_grad_square_sum - right_side) ** 2

# Starting point [B^2, H^2] for the L-BFGS-B fit of loss_function in the main loop.
initial_params = [1.0, 1.0]  

from scipy.optimize import minimize

'''
def update_mu(w_local, w_global, grad_local):
    
    norm_diff = np.linalg.norm(w_local - w_global)
    if norm_diff > 0:
        mu_new = - np.dot(grad_local, (w_local - w_global)) / norm_diff ** 2  
    else:
        mu_new = 0  
    return max(mu_new, 0)  # ensure μ is non-negative

'''
'''
def update_mu(w_local, w_global, grad_local):

    diff = w_local - w_global
    norm_diff = np.linalg.norm(diff)
    if norm_diff > 0:
        mu_new = - grad_local / diff  
    else:
        mu_new = 0  
    return max(mu_new, 0) 


def update_mu(w_local, w_global, grad_local):
    diff = w_local - w_global  
    xi = diff  
    
    if np.all(w_local == w_global + xi) and np.all(grad_local <= 0):
        mu_new = 0  
    elif np.all(w_local == w_global - xi) and np.all(grad_local > 0):
        mu_new = 0  
    else:
        norm_diff = np.linalg.norm(diff)
        if norm_diff > 0:
            mu_new = - np.dot(grad_local, diff) / norm_diff ** 2  
        else:
            mu_new = 0  
    
    return max(mu_new, 0)  
'''

'''
def generate_xi_threshold_vector(grad_local, scale_factor=0.1):
    
    xi_threshold_vector = np.abs(grad_local) * scale_factor  
    return xi_threshold_vector

def generate_xi_threshold_vector_from_weights(w_local, w_global, scale_factor=0.1):

    diff = np.abs(w_local - w_global)  
    xi_threshold_vector = diff * scale_factor  
    return xi_threshold_vector

def generate_xi_threshold_vector_from_distribution(w_local, scale_factor=0.1):
    std_w = np.std(w_local)  
    xi_threshold_vector = np.random.normal(0, std_w * scale_factor, size=w_local.shape)  
    return np.abs(xi_threshold_vector)  

def dynamic_xi_threshold_update(xi_threshold_vector, loss_change, alpha=0.1):

    xi_threshold_vector += alpha * np.abs(loss_change)  
    return xi_threshold_vector
    
'''

'''
def generate_scalar_xi_threshold(grad_local, scale_factor=0.1):

    norm_grad = np.linalg.norm(grad_local)  
    xi_threshold = norm_grad * scale_factor  
    return xi_threshold

def generate_scalar_xi_from_loss(loss_change, base_value=0.01, scale_factor=0.1):
    xi_threshold = base_value + scale_factor * abs(loss_change)  
    return xi_threshold

xi_threshold = 0.01  
    
'''

def generate_xi_threshold_vector(grad_local, scale_factor=0.1):
    """Return the per-coordinate threshold ``|grad_local| * scale_factor``."""
    xi_threshold_vector = np.abs(grad_local) * scale_factor
    return xi_threshold_vector


def generate_scalar_xi_threshold(grad_local, scale_factor=0.1):
    """Return the scalar threshold ``||grad_local||_2 * scale_factor``."""
    norm_grad = np.linalg.norm(grad_local)
    xi_threshold = norm_grad * scale_factor
    return xi_threshold


def update_mu(w_local, w_global, grad_local, sparsity_threshold=1e-3, tolerance=1e-1):
    """Estimate a non-negative, per-coordinate proximal coefficient ``mu``.

    Args:
        w_local: Current local weights (array; any shape).
        w_global: Reference (global) weights, same shape as ``w_local``.
        grad_local: Local gradient, same shape as ``w_local``.
        sparsity_threshold: Coordinates with ``|w_local - w_global|`` at or
            below this value get no direct estimate; they receive the mean
            ``mu`` of the remaining coordinates.
        tolerance: How close ``w_local`` must be to ``w_global +/- xi`` for a
            coordinate to receive a direct estimate.

    Returns:
        Array shaped like ``w_local`` with ``mu >= 0``. It is all zeros when
        ``||w_local - w_global||`` does not exceed ``0.001 * ||grad_local||``.
    """
    diff = w_local - w_global
    print(f"diff, diff = {diff}")
    norm_diff = np.linalg.norm(diff)
    print(f"L2, norm_diff = {norm_diff}")

    mu_new = np.zeros_like(w_local)

    # Skip entirely when the local model has barely moved relative to the
    # size of the current gradient.
    xi_threshold = generate_scalar_xi_threshold(grad_local, scale_factor=0.001)
    if norm_diff <= xi_threshold:
        print(f"No significant difference: norm_diff = {norm_diff}, xi_threshold = {xi_threshold}")
        return mu_new

    xi_threshold_vector = generate_xi_threshold_vector(grad_local, scale_factor=0.1)

    # Coordinates whose drift is large enough to estimate mu directly.
    large_drift = np.abs(diff) > sparsity_threshold
    # A coordinate qualifies when it sits (within `tolerance`) at
    # w_global + xi with a non-positive gradient, or at w_global - xi with a
    # positive gradient.
    at_upper = (np.abs(w_local - (w_global + xi_threshold_vector)) < tolerance) & (grad_local <= 0)
    at_lower = (np.abs(w_local - (w_global - xi_threshold_vector)) < tolerance) & (grad_local > 0)
    update_mask = large_drift & (at_upper | at_lower)
    # diff is non-zero wherever large_drift holds for a non-negative threshold;
    # the extra check keeps mu = 0 for diff == 0 if a negative threshold is passed.
    divisible = update_mask & (diff != 0)
    mu_new[divisible] = -grad_local[divisible] / diff[divisible]
    updated_weights_count = int(np.count_nonzero(update_mask))

    # Coordinates with tiny drift inherit the mean mu of all other coordinates.
    small_drift = np.abs(diff) <= sparsity_threshold
    if np.any(small_drift):
        reference = mu_new[~small_drift]
        # With no large-drift coordinates the mean of an empty slice is NaN,
        # which would poison the weights through the proximal term; use 0.
        avg_mu = np.mean(reference) if reference.size > 0 else 0.0
        mu_new[small_drift] = avg_mu
        print(f"Applied sparse update for {np.sum(small_drift)} weights with avg_mu = {avg_mu}")

    print(f"Total updated weights: {updated_weights_count} out of {np.size(w_local)}")

    return np.maximum(mu_new, 0)


# Client main loop. Each pass of the outer loop handles one training run set up
# by MSG_INIT_SERVER_TO_CLIENT; the inner loop handles one round per
# MSG_WEIGHT_TAU_SERVER_TO_CLIENT until the server flags the last round.
# struct.error / socket.error (presumably raised by recv_msg/send_msg once the
# server closes the connection) end the script with 'Server has stopped'.
try:
    while True:
        msg = recv_msg(sock, 'MSG_INIT_SERVER_TO_CLIENT')
        # ['MSG_INIT_SERVER_TO_CLIENT', model_name, dataset, num_iterations_with_same_minibatch_for_tau_equals_one, step_size, batch_size,
        # total_data, use_control_alg, indices_this_node, read_all_data_for_stochastic, use_min_loss, sim, mu]
        # NOTE(review): index 7 is used below as a server-side control-algorithm
        # instance (isinstance check), and index 12 (mu) is currently not read.

        model_name = msg[1]
        dataset = msg[2]
        num_iterations_with_same_minibatch_for_tau_equals_one = msg[3]
        step_size = msg[4]
        batch_size = msg[5]
        total_data = msg[6]
        control_alg_server_instance = msg[7]
        indices_this_node = msg[8]
        read_all_data_for_stochastic = msg[9]
        use_min_loss = msg[10]
        sim = msg[11]
        #mu = msg[12]  # mu is used for the smooth SVM model, added for FedProx
        #print(f"Received mu: {mu}")

        model = get_model(model_name)
        model2 = get_model(model_name)   # Used for computing loss_w_prev_min_loss for stochastic gradient descent,
                                         # so that the state of model can be still used by control algorithm later.

        if hasattr(model, 'create_graph'):
            model.create_graph(learning_rate=step_size)
        if hasattr(model2, 'create_graph'):
            model2.create_graph(learning_rate=step_size)

        # Assume the dataset does not change
        # The full training set is (re-)read only when it is needed up front and
        # the batch/data configuration (or, for full-batch runs, the simulation
        # round) differs from the previous run; otherwise train_image/train_label
        # from the previous run are reused.
        if read_all_data_for_stochastic or batch_size >= total_data:
            if batch_size_prev != batch_size or total_data_prev != total_data or (batch_size >= total_data and sim_prev != sim):
                print('Reading all data samples used in training...')
                train_image, train_label, _, _, _ = get_data(dataset, total_data, dataset_file_path, sim_round=sim)

        batch_size_prev = batch_size
        total_data_prev = total_data
        sim_prev = sim

        # Full-batch: train on all of this node's indices. Mini-batch: draw
        # batches from the sampler inside the local-iteration loop.
        if batch_size >= total_data:
            sampler = None
            train_indices = indices_this_node
        else:
            sampler = MinibatchSampling(indices_this_node, batch_size, sim)
            train_indices = None  # To be defined later
        last_batch_read_count = None

        data_size_local = len(indices_this_node)

        if isinstance(control_alg_server_instance, ControlAlgAdaptiveTauServer):
            control_alg = ControlAlgAdaptiveTauClient()
        else:
            control_alg = None

        w_prev_min_loss = None
        w_last_global = None
        total_iterations = 0 

        msg = ['MSG_DATA_PREP_FINISHED_CLIENT_TO_SERVER']
        send_msg(sock, msg)

        # One pass per communication round.
        while True:
            print('---------------------------------------------------------------------------')

            msg = recv_msg(sock, 'MSG_WEIGHT_TAU_SERVER_TO_CLIENT')
            # ['MSG_WEIGHT_TAU_SERVER_TO_CLIENT', w_global, tau, is_last_round, prev_loss_is_min]
            w = msg[1]
            tau_config = msg[2]
            is_last_round = msg[3]
            prev_loss_is_min = msg[4]

            if prev_loss_is_min or ((w_prev_min_loss is None) and (w_last_global is not None)):
                w_prev_min_loss = w_last_global

            if control_alg is not None:
                control_alg.init_new_round(w)

            time_local_start = time.time()  # Only count this part as time for local iteration because the remaining part does not increase with tau

            # Perform local iteration
            grad = None
            loss_last_global = None   # Only the loss at starting time is from global model parameter
            loss_w_prev_min_loss = None

            tau_actual = 0

            # Raw gradients of every local iteration in this round; used for the
            # B/H estimates after the loop.
            local_gradients = []  

            for i in range(0, tau_config):

                # When batch size is smaller than total data, read the data here; else read data during client init above
                if batch_size < total_data:
                    if (not isinstance(control_alg, ControlAlgAdaptiveTauClient)) or (i != 0) or (train_indices is None) \
                            or (tau_config <= 1 and
                                (last_batch_read_count is None or
                                 last_batch_read_count >= num_iterations_with_same_minibatch_for_tau_equals_one)):

                        sample_indices = sampler.get_next_batch()

                        if read_all_data_for_stochastic:
                            train_indices = sample_indices
                        else:
                            train_image, train_label = get_data_train_samples(dataset, sample_indices, dataset_file_path)
                            train_indices = range(0, min(batch_size, len(train_label)))

                        last_batch_read_count = 0

                    last_batch_read_count += 1

                grad = model.gradient(train_image, train_label, w, train_indices)
                local_gradients.append(grad)

                #mu_new = update_mu(w, w_last_global, grad)

                if i == 0:
                    try:
                        # Note: This has to follow the gradient computation line above
                        loss_last_global = model.loss_from_prev_gradient_computation()
                        print('*** Loss computed from previous gradient computation')
                    except:
                        # Will get an exception if the model does not support computing loss
                        # from previous gradient computation
                        loss_last_global = model.loss(train_image, train_label, w, train_indices)
                        print('*** Loss computed from data')

                    w_last_global = w

                    if use_min_loss:
                        if (batch_size < total_data) and (w_prev_min_loss is not None):
                            # Compute loss on w_prev_min_loss so that the batch remains the same
                            loss_w_prev_min_loss = model2.loss(train_image, train_label, w_prev_min_loss, train_indices)


                # Per-coordinate proximal coefficient for this step. On the first
                # iteration w_last_global was just set to w, so update_mu sees a
                # zero weight difference.
                mu_new = update_mu(w, w_last_global, grad)
                print(f"update μ,mu_new = {mu_new}")

                # NOTE(review): the string below is disabled code (an expression
                # statement with no effect).
                '''if i == 1:
                    try:
                        # Note: This has to follow the gradient computation line above
                        loss_last_global = model.loss_from_prev_gradient_computation_1(w_t=w_last_global, mu=mu_new)
                        print('*** Loss computed from  gradient computation')
                    except:
                        # Will get an exception if the model does not support computing loss
                        # from previous gradient computation
                        loss_last_global = model.loss_1(train_image, train_label, w, train_indices, w_t=w_last_global, mu=mu_new)
                        print('*** Loss computed from data')

                    #w_last_global = w

                    if use_min_loss:
                        if (batch_size < total_data) and (w_prev_min_loss is not None):
                            # Compute loss on w_prev_min_loss so that the batch remains the same
                            loss_w_prev_min_loss = model2.loss_1(train_image, train_label, w_prev_min_loss, train_indices, w_t=w_last_global, mu=mu_new)'''



                # Gradient step with an element-wise proximal term; since mu_new is
                # non-negative, the term pulls w back toward the global weights
                # received at the start of this round.
                prox_term = mu_new * (w - w_last_global)

                w = w - step_size * (grad + prox_term)
                #w = w - step_size * grad 

                tau_actual += 1
                total_iterations += 1

                # NOTE(review): grad is left as the raw gradient (the proximal term is
                # not added), so control_alg below sees the gradient without it.
                #grad = grad + prox_term # Update grad to include proximal term

                if control_alg is not None:
                    is_last_local = control_alg.update_after_each_local(i, w, grad, total_iterations)

                    if is_last_local:
                        break

            # NOTE(review): mu_new is only assigned inside the loop above; if
            # tau_config is 0 it is undefined (first round) or stale (later
            # rounds) where it is used below.

            # Local operation finished, global aggregation starts
            time_local_end = time.time()
            time_all_local = time_local_end - time_local_start
            print('time_all_local =', time_all_local)


            # Estimate (B, H) from this round's local gradients, both in closed form
            # (calculate_B_and_H) and by an L-BFGS-B least-squares fit of
            # loss_function with B^2, H^2 >= 0. The "global" gradient here is the
            # mean of this client's per-iteration gradients.
            # NOTE(review): M is n_nodes (the configured node count), not the
            # number of entries in local_gradients — confirm this divisor is intended.
            if len(local_gradients) > 0:
                global_gradient = np.mean(local_gradients, axis=0)  
                B, H = calculate_B_and_H(local_gradients, global_gradient, M)
                result = minimize(loss_function, initial_params, args=(local_gradients, global_gradient), method='L-BFGS-B', bounds=[(0, None), (0, None)])

                B_squared_opt, H_squared_opt = result.x
                B_opt = np.sqrt(B_squared_opt)
                H_opt = np.sqrt(H_squared_opt)
                # NOTE(review): the closed-form and fitted values are printed with
                # the same "Client B:" label.
                print(f"Client B: {B}, H: {H}")
                print(f"Client B: {B_opt}, H: {H_opt}")

                # NOTE(review): the label uses M (n_nodes, identical on every client)
                # rather than a per-client id; only the mu_new of the last local
                # iteration is recorded.
                case_name = f"Client_{M}_Round_{total_iterations}"
                save_to_csv(case_name, mu_new, B, H)

            if control_alg is not None:
                control_alg.update_after_all_local(model, train_image, train_label, train_indices,
                                                   w, w_last_global, loss_last_global,mu_new)

            msg = ['MSG_WEIGHT_TIME_SIZE_CLIENT_TO_SERVER', w, time_all_local, tau_actual, data_size_local,
                   loss_last_global, loss_w_prev_min_loss,mu_new]
            send_msg(sock, msg)

            if control_alg is not None:
                control_alg.send_to_server(sock)

            if is_last_round:
                break

except (struct.error, socket.error):
    print('Server has stopped')
    pass
