from utils import *
from datetime import datetime
import time, random, os
import tensorflow as tf
from tensorflow import keras



def PIDF_main(Application, configs, ST_node, SV_Data, Models_dict:callable, train_local_model:callable, evaluate_models:callable):
    """Run the physics-informed decentralized federated learning (PIDFL) experiment.

    For every iteration in ``range(num_iterations)[from_iteration:to_iteration]`` the
    data ``ST_node[iteration]`` is split across ``N_NODES`` nodes (equal contiguous
    shards of ``nst // N_NODES`` samples, or Dirichlet-sized contiguous shards when
    ``non_IID`` is set) and optionally perturbed with Gaussian noise. Then, for every
    physics weight, a fresh model per node is trained for ``COMMUNICATION_ROUNDS``
    rounds: each node trains locally, and then node ``i`` replaces its parameters with
    the plain mean of the pre-aggregation parameters of every node ``j`` with
    ``A[i, j]`` truthy.

    Args:
        Application: Application name. Used in output names; noise is never added for
            'AirQ' and 'DrugDiffusion'.
        configs: Experiment configuration; the keys read are listed at the top of the body.
        ST_node: Per-iteration training data. Each entry is iterated into fields; field 0
            is perturbed when ``noise_on_input`` is set, otherwise the last
            ``num_true_values`` fields are.
        SV_Data: Validation data forwarded to ``evaluate_models``.
        Models_dict: Zero-argument factory returning a new model exposing ``summary``,
            ``get_weights`` and ``set_weights``.
        train_local_model: Called as ``train_local_model(model, data, EPOCHS,
            use_physics=..., physics_weight=...)``.
        evaluate_models: Called as ``evaluate_models(models, SV_Data)``; expected to
            return one loss per model.

    Progress and result tables are printed. With ``save_log`` they are also written as
    CSV files, and per-round parameters are saved via ``save_parameters``, into a
    timestamped directory under ``root_path``. The Delta/Gap summary rows use physics
    weight 0.0 as the reference and are NaN when 0.0 is not in ``physics_weights``.
    """
    N_NODES = configs['N_NODES']
    EPOCHS = configs['EPOCHS']
    COMMUNICATION_ROUNDS = configs['COMMUNICATION_ROUNDS']
    physics_weights = configs['physics_weights']
    num_iterations = configs['num_iterations']
    from_iteration = configs['from_iteration']
    to_iteration = configs['to_iteration']
    non_IID = configs['non_IID']
    save_log = configs['save_log']
    shuffle = configs['shuffle']
    Dirichlet_alpha = configs['Dirichlet_alpha']
    noise = configs['noise']
    noise_on_input = configs['noise_on_input']
    nst = configs['nst']
    num_true_values = configs['num_true_values']
    A = configs['A']
    root_path = configs['root_path']

    # Noise is never injected for these applications.
    add_noise = noise > 0 and Application not in ['AirQ', 'DrugDiffusion']
    # Defined up front so it exists even if no iteration/physics weight is processed.
    status = "IID" if not non_IID else "NonIID"

    def _perturb(fields):
        """Add zero-mean Gaussian noise, in place, to the selected entries of ``fields``.

        Field 0 is perturbed when ``noise_on_input`` is set, otherwise the last
        ``num_true_values`` fields are. Each field receives noise of its own shape with
        standard deviation ``noise * std(field)``.
        """
        indices = [0] if noise_on_input else [-t for t in range(1, 1 + num_true_values)]
        for idx in indices:
            field = fields[idx]
            fields[idx] = field + noise*np.std(field)*np.random.randn(*field.shape)

    def _split_local_datasets(It_data, iteration):
        """Partition one iteration's data into ``N_NODES`` tuples of float32 tensors."""
        # Materialize every field once instead of once per node.
        columns = [list(el) for el in It_data]
        local_datasets = []
        if not non_IID:
            node_samples = nst // N_NODES
            for node in range(N_NODES):
                fields = [tf.convert_to_tensor(col[node*node_samples: (node+1)*node_samples], dtype=tf.float32) for col in columns]
                if add_noise:
                    _perturb(fields)
                local_datasets.append(tuple(fields))
        else:
            proportions = list(np.random.dirichlet([Dirichlet_alpha] * N_NODES))
            print(f"Iteration {iteration} Data, Non-IID Proportions: ", proportions)
            taken_index = 0
            for node, proportion in enumerate(proportions):
                n_samples = int(proportion*nst)
                fields = [tf.convert_to_tensor(col[taken_index: taken_index + n_samples], dtype=tf.float32) for col in columns]
                if add_noise and n_samples > 1:
                    _perturb(fields)
                local_datasets.append(tuple(fields))
                taken_index += n_samples
                print(f"Data Length for Node {node}:", len(fields[0]))
        return local_datasets

    def _summary_row(label):
        """Return a gap-table row with only the 'Application' label filled in."""
        return {'Application': label, 'N_Nodes': "", 'N_ITs': "", 'N_CRs': "", 'Noise': "", 'It': ""}

    pw_time_consumption = {pw: [] for pw in physics_weights}  # per-round wall time per physics weight
    pw_name_map = {pw: 'PW' + str(pw) for pw in physics_weights}  # gap-table column per physics weight

    results = []  # one row per (iteration, physics weight)
    CR_results = []  # one row per (iteration, round, physics weight): mean loss over nodes
    CR_nodes = []  # one row per (iteration, round, physics weight): loss of every node
    gap_results = []  # one row per iteration: final mean loss of every physics weight

    # Creating timestamp for the directory name
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    prefix = "Shuffle_PIDFL_output" if shuffle else "PIDFL_output"
    output_dir = f"{root_path}/{prefix}_{Application}_{N_NODES}_{num_iterations}_{COMMUNICATION_ROUNDS}_{noise}_{timestamp}"

    if save_log:
        os.makedirs(output_dir, exist_ok=True)

    print("Adjacency Matrix is as follows:")
    print(A)

    start_time = time.time()
    for iteration in range(num_iterations)[from_iteration: to_iteration]:
        print(f"Iteration {iteration+1}/{num_iterations}")
        # Only the row of iteration 0 carries the run metadata; later rows leave it blank.
        if iteration == 0:
            res_dict = {'Application': Application, 'N_Nodes': N_NODES,
                        'N_ITs': num_iterations, 'N_CRs': COMMUNICATION_ROUNDS, 'Noise': noise,
                        'It': iteration}
        else:
            res_dict = {'Application': "", 'N_Nodes': "",
                        'N_ITs': "", 'N_CRs': "", 'Noise': "",
                        'It': iteration}
        for pw in physics_weights:
            res_dict[pw_name_map[pw]] = 0

        #### Defining Physical Models: a fresh model per node for every physics weight
        physics_models = {pw: [Models_dict() for _ in range(N_NODES)] for pw in physics_weights}

        local_datasets = _split_local_datasets(ST_node[iteration], iteration)

        # Training Loop
        for physics_weight in physics_weights:
            print(f"\nPhysics Weight: {physics_weight}")

            weight_models = physics_models[physics_weight]

            print(weight_models[0].summary())

            pw_total_start = time.time()
            for communication_round in range(COMMUNICATION_ROUNDS):
                print(f"\nCommunication Round: {communication_round}")
                pw_start_time = time.time()
                for i, model in enumerate(weight_models):
                    # Local training (nodes without data are skipped)
                    if len(local_datasets[i][0]):
                        train_local_model(model, local_datasets[i], EPOCHS, use_physics=physics_weight > 0, physics_weight=physics_weight)

                params_before = [model.get_weights() for model in weight_models]
                losses_before = evaluate_models(weight_models, SV_Data)

                print("Losses Before Aggregation:")
                for index, loss in enumerate(losses_before):
                    print(f"{index}: {loss}")
                print("===================================")

                # Synchronous consensus step: node i takes the plain mean of the
                # pre-aggregation parameters of every node j with A[i, j] truthy
                # (only its own parameters if its row of A is empty).
                for i, model in enumerate(weight_models):
                    neighbours = [j for j in range(N_NODES) if A[i, j]] or [i]
                    averaged_params = [np.mean([params_before[j][k] for j in neighbours], axis=0) for k in range(len(params_before[i]))]
                    model.set_weights(averaged_params)

                params_adapted = [model.get_weights() for model in weight_models]
                losses_after = evaluate_models(weight_models, SV_Data)

                print("Losses After Aggregation:")
                for index, loss in enumerate(losses_after):
                    print(f"{index}: {loss}")
                print("===================================")

                round_time = time.time() - pw_start_time
                pw_time_consumption[physics_weight].append(round_time)
                print("\nTime tanken for this CR:", round_time)

                if save_log:
                    save_parameters(f"{output_dir}", communication_round, iteration, params_before, params_adapted, losses_before, losses_after, physics_weight)

                # The models are unchanged since losses_after was computed; reuse it.
                cr_losses = losses_after
                cr_average_loss = np.mean(cr_losses)
                CR_results.append({'Iteration': iteration+1, 'CR': communication_round, 'Physics Weight': physics_weight, 'Average Loss': cr_average_loss})

                cr_nodes_dict = {'Iteration': iteration+1, 'CR': communication_round, 'Physics Weight': physics_weight}
                for i in range(N_NODES):
                    cr_nodes_dict[i] = cr_losses[i]

                CR_nodes.append(cr_nodes_dict)
                print(f"CR {communication_round} loss values:", cr_nodes_dict)
                print("==========================================")

            # Evaluate
            average_loss = np.mean(evaluate_models(weight_models, SV_Data))
            print(f"Average loss with lambda_p= {physics_weight}: {average_loss}")

            # 'Time' is the wall time spent on this physics weight over all rounds.
            results.append({'Iteration': iteration+1, 'Physics Weight': physics_weight, 'Average Loss': average_loss, 'Time': (time.time() - pw_total_start)})

            res_dict[pw_name_map[physics_weight]] = average_loss

            results_df = pd.DataFrame(results)
            CR_results_df = pd.DataFrame(CR_results)
            CR_nodes_df = pd.DataFrame(CR_nodes)
            print(f"\nFinal Results:\n{results_df}")
            print("---------------------------")
            if save_log:
                results_df.to_csv(f'{output_dir}/PIDFL_{Application}_{status}_{N_NODES}_{num_iterations}_{COMMUNICATION_ROUNDS}_{noise}_{timestamp}.csv')
                CR_results_df.to_csv(f'{output_dir}/PIDFL_CRresults_{Application}_{status}_{N_NODES}_{num_iterations}_{COMMUNICATION_ROUNDS}_{noise}_{timestamp}.csv')
                CR_nodes_df.to_csv(f'{output_dir}/PIDFL_CRNodes_{Application}_{status}_{N_NODES}_{num_iterations}_{COMMUNICATION_ROUNDS}_{noise}_{timestamp}.csv')

        gap_results.append(res_dict)
        print("Gap Results:", gap_results)

        print("\nPhysics Weights Time Consumption in each iteration:")
        for pw, times in pw_time_consumption.items():
            print(pw, "=>", times)

        print("\nAverage Total Time Consumption:")
        for pw, times in pw_time_consumption.items():
            avg = np.mean(times)
            print(f"Average for PW {pw} => {avg/60} Minutes")

    if gap_results:
        gap_df = pd.DataFrame(gap_results)
        mean_loss = {pw: np.mean(gap_df[pw_name_map[pw]]) for pw in physics_weights}
        std_loss = {pw: np.std(gap_df[pw_name_map[pw]]) for pw in physics_weights}
        # Physics weight 0.0 is the reference for the Delta/Gap rows (NaN when absent).
        baseline = mean_loss.get(0.0, np.nan)

        res_dict_mean = _summary_row("Mean_Loss")
        res_dict_STD = _summary_row("STD_Loss")
        res_dict_Delta = _summary_row("Delta_Mean_Loss")
        ## Gap(%) = 100 * (mean_pw - mean_0) / mean_0
        res_dict_Gap = _summary_row("Gap_Mean_Loss")
        for pw in physics_weights:
            name = pw_name_map[pw]
            res_dict_mean[name] = mean_loss[pw]
            res_dict_STD[name] = std_loss[pw]
            res_dict_Delta[name] = mean_loss[pw] - baseline
            if pw == 0.0:
                res_dict_Gap[name] = 0.0
            elif baseline:
                res_dict_Gap[name] = 100 * (mean_loss[pw] - baseline) / baseline
            else:
                res_dict_Gap[name] = np.nan
        gap_results.extend([res_dict_mean, res_dict_STD, res_dict_Delta, res_dict_Gap])

        gap_df = pd.DataFrame(gap_results)
        print("Gap Values:\n", gap_df)
        if save_log:
            gap_df.to_csv(f'{output_dir}/PIDFL_GapResults_{Application}_{status}_{N_NODES}_{num_iterations}_{COMMUNICATION_ROUNDS}_{noise}_{timestamp}.csv')
    else:
        print("No iterations were run; gap summary skipped.")

    finish_time = time.time()
    print(f"Time taken by the PIDFL Algorithm: {(finish_time-start_time)/60} Minutes")

    
def Fed_Avg(Application, configs, ST_node, SV_Data, Models_dict:callable, train_local_model:callable, evaluate_models:callable):
    """Run the Federated Averaging (FedAvg) baseline.

    For every iteration in ``range(num_iterations)[from_iteration:to_iteration]`` the
    data ``ST_node[iteration]`` is split across ``N_NODES`` clients (equal contiguous
    shards, or Dirichlet-sized contiguous shards when ``non_IID`` is set) and optionally
    perturbed with Gaussian noise. A fresh server model and one model per client are
    created, and all clients start from the server's initial weights. In each
    communication round:

    1. ``min(max(int(C * N_NODES), 2), N_NODES)`` randomly chosen clients train locally.
    2. The server takes the data-size-weighted average ``sum_j n_j / sum(n) * w_j`` of
       all client models (uniform weights if every client is empty).
    3. The result is broadcast back to every client.

    Args:
        Application: Application name. Used in output names; noise is never added for
            'AirQ' and 'DrugDiffusion'.
        configs: Experiment configuration; the keys read are listed at the top of the body.
        ST_node: Per-iteration training data. Each entry is iterated into fields; field 0
            is perturbed when ``noise_on_input`` is set, otherwise the last
            ``num_true_values`` fields are.
        SV_Data: Validation data forwarded to ``evaluate_models``.
        Models_dict: Zero-argument factory returning a new model exposing ``summary``,
            ``get_weights`` and ``set_weights``.
        train_local_model: Called as ``train_local_model(model, data, EPOCHS)``.
        evaluate_models: Called with the client models and ``SV_Data``; expected to
            return one loss per model.

    Progress and result tables are printed. With ``save_log`` they are also written as
    CSV files, and per-round parameters are saved via ``save_parameters``, into a
    timestamped directory under ``root_path``.
    """
    N_NODES = configs['N_NODES']
    EPOCHS = configs['EPOCHS']
    COMMUNICATION_ROUNDS = configs['COMMUNICATION_ROUNDS']
    num_iterations = configs['num_iterations']
    from_iteration = configs['from_iteration']
    to_iteration = configs['to_iteration']
    non_IID = configs['non_IID']
    save_log = configs['save_log']
    shuffle = configs['shuffle']
    Dirichlet_alpha = configs['Dirichlet_alpha']
    noise = configs['noise']
    noise_on_input = configs['noise_on_input']
    nst = configs['nst']
    num_true_values = configs['num_true_values']
    C = configs['C']
    root_path = configs['root_path']

    # Noise is never injected for these applications.
    add_noise = noise > 0 and Application not in ['AirQ', 'DrugDiffusion']
    status = "IID" if not non_IID else "NonIID"
    # Clients sampled per round: fraction C of all clients, at least 2, never more than exist.
    m = min(max(int(C*N_NODES), 2), N_NODES)

    def _perturb(fields):
        """Add zero-mean Gaussian noise, in place, to the selected entries of ``fields``.

        Field 0 is perturbed when ``noise_on_input`` is set, otherwise the last
        ``num_true_values`` fields are. Each field receives noise of its own shape with
        standard deviation ``noise * std(field)``.
        """
        indices = [0] if noise_on_input else [-t for t in range(1, 1 + num_true_values)]
        for idx in indices:
            field = fields[idx]
            fields[idx] = field + noise*np.std(field)*np.random.randn(*field.shape)

    def _split_local_datasets(It_data, iteration):
        """Partition one iteration's data into ``N_NODES`` tuples of float32 tensors."""
        # Materialize every field once instead of once per client.
        columns = [list(el) for el in It_data]
        local_datasets = []
        if not non_IID:
            node_samples = nst // N_NODES
            for node in range(N_NODES):
                fields = [tf.convert_to_tensor(col[node*node_samples: (node+1)*node_samples], dtype=tf.float32) for col in columns]
                if add_noise:
                    _perturb(fields)
                local_datasets.append(tuple(fields))
        else:
            proportions = list(np.random.dirichlet([Dirichlet_alpha] * N_NODES))
            print(f"Iteration {iteration} Data, Non-IID Proportions: ", proportions)
            taken_index = 0
            for node, proportion in enumerate(proportions):
                n_samples = int(proportion*nst)
                fields = [tf.convert_to_tensor(col[taken_index: taken_index + n_samples], dtype=tf.float32) for col in columns]
                if add_noise and n_samples > 1:
                    _perturb(fields)
                local_datasets.append(tuple(fields))
                taken_index += n_samples
                print(f"Data Length for Node {node}:", len(fields[0]))
        return local_datasets

    results = []  # one row per iteration
    CR_results = []  # one row per (iteration, round): mean loss over clients
    CR_nodes = []  # one row per (iteration, round): loss of every client

    # Creating timestamp for the directory name
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    prefix = "Shuffle_FedAvg_output" if shuffle else "FedAvg_output"
    output_dir = f"{root_path}/{prefix}_{Application}_{N_NODES}_{COMMUNICATION_ROUNDS}_{noise}_{timestamp}"

    if save_log:
        os.makedirs(output_dir, exist_ok=True)

    start_time = time.time()
    for iteration in range(num_iterations)[from_iteration: to_iteration]:
        print(f"Iteration {iteration+1}/{num_iterations}")

        client_models = {i: Models_dict() for i in range(N_NODES)}
        server_model = Models_dict()

        print(server_model.summary())

        # FedAvg starts every client from the same (server) initialization.
        initial_weights = server_model.get_weights()
        for model in client_models.values():
            model.set_weights(initial_weights)

        local_datasets = _split_local_datasets(ST_node[iteration], iteration)

        # Aggregation coefficients n_j / sum(n) (they sum to 1); uniform if no client has data.
        sizes = [len(dataset[0]) for dataset in local_datasets]
        total_size = sum(sizes)
        if total_size:
            coeffs = [size / total_size for size in sizes]
        else:
            coeffs = [1.0 / N_NODES] * N_NODES

        for communication_round in range(COMMUNICATION_ROUNDS):
            print(f"\nCommunication Round: {communication_round}")

            selected_keys = random.sample(list(client_models.keys()), k=m)
            print(selected_keys)
            for key in selected_keys:
                # Local training (clients without data are skipped)
                if len(local_datasets[key][0]):
                    train_local_model(client_models[key], local_datasets[key], EPOCHS)

            params_before = [model.get_weights() for model in client_models.values()]
            losses_before = evaluate_models(client_models.values(), SV_Data)

            print("Losses Before Aggregation:")
            for index, loss in enumerate(losses_before):
                print(f"{index}: {loss}")
            print("===================================")

            ## Weighted average on the server side: w = sum_j (n_j / sum(n)) * w_j
            averaged_params = [np.sum([coeff * params[k] for coeff, params in zip(coeffs, params_before)], axis=0) for k in range(len(params_before[0]))]

            server_model.set_weights(averaged_params)

            global_weights = server_model.get_weights()
            for model in client_models.values():
                model.set_weights(global_weights)

            params_adapted = [model.get_weights() for model in client_models.values()]

            losses_list = evaluate_models(client_models.values(), SV_Data)
            cr_loss = np.mean(losses_list)
            CR_results.append({'Iteration': iteration+1, 'CR': communication_round, 'Average Loss': cr_loss})

            cr_nodes_dict = {'Iteration': iteration+1, 'CR': communication_round}
            for i in range(N_NODES):
                cr_nodes_dict[i] = losses_list[i]

            CR_nodes.append(cr_nodes_dict)
            print(f"CR {communication_round} loss values:", cr_nodes_dict)

            # Save parameters and weights to file
            if save_log:
                save_parameters(f"{output_dir}", communication_round, iteration, params_before, params_adapted, losses_before, losses_list, None)

            print("Losses After Aggregation:")
            for index, loss in enumerate(losses_list):
                print(f"{index}: {loss}")
            print("==========================================")

        # Evaluate the models
        average_loss = np.mean(evaluate_models(client_models.values(), SV_Data))
        print(f"Average loss in iteration {iteration+1}: {average_loss}")
        results.append({'Iteration': iteration+1, 'Average Loss': average_loss})

        results_df = pd.DataFrame(results)
        print(f"\nFinal Results:\n{results_df}")

        CR_results_df = pd.DataFrame(CR_results)
        print(f"\nFinal CR Results:\n{CR_results_df}")

        CR_nodes_df = pd.DataFrame(CR_nodes)

        if save_log:
            results_df.to_csv(f'{output_dir}/FedAvg_{Application}_{status}_{N_NODES}_{timestamp}.csv')
            CR_results_df.to_csv(f'{output_dir}/FedAvg_CRResults_{Application}_{status}_{N_NODES}_{timestamp}.csv')
            CR_nodes_df.to_csv(f'{output_dir}/FedAvg_CRNodes_{Application}_{status}_{N_NODES}_{timestamp}.csv')

    finish_time = time.time()
    print(f"Time taken by the FedAvg Algorithm: {(finish_time-start_time)/60} Minutes")


def CL_Base_model(Application, configs, ST_node, SV_Data, Models_dict:callable, train_local_model:callable, evaluate_models:callable):
    """Train and evaluate centralized baseline models on the full data of each iteration.

    For every iteration in ``range(num_iterations)[from_iteration:to_iteration]``, all of
    ``ST_node[iteration]`` is converted to float32 tensors and optionally perturbed with
    Gaussian noise. Field 0 is perturbed when ``noise_on_input`` is set, otherwise the
    last ``num_true_values`` fields are; noise is never added for 'AirQ' and
    'DrugDiffusion'. One fresh model per physics weight is then trained with
    ``train_local_model(model, data, EPOCHS)`` and evaluated on ``SV_Data``. The noisy
    data is a local copy: ``ST_node`` itself is not modified.

    After all iterations a gap table is printed, and saved when ``save_log`` is set. It
    has one row per iteration plus the mean loss, the standard deviation, the difference
    to physics weight 0.0 and the relative gap (%) to physics weight 0.0. The last two
    are NaN when 0.0 is not among ``physics_weights``.
    """
    N_NODES = configs['N_NODES']
    EPOCHS = configs['EPOCHS']
    COMMUNICATION_ROUNDS = configs['COMMUNICATION_ROUNDS']
    physics_weights = configs['physics_weights']
    num_iterations = configs['num_iterations']
    from_iteration = configs['from_iteration']
    to_iteration = configs['to_iteration']
    save_log = configs['save_log']
    shuffle = configs['shuffle']
    noise = configs['noise']
    noise_on_input = configs['noise_on_input']
    num_true_values = configs['num_true_values']
    root_path = configs['root_path']

    def _summary_row(label):
        """Return a gap-table row with only the 'Application' label filled in."""
        return {'Application': label, 'N_Nodes': "", 'N_ITs': "", 'N_CRs': "", 'Noise': "", 'It': ""}

    pw_name_map = {pw: 'PW' + str(pw) for pw in physics_weights}  # gap-table column per physics weight

    results = []  # one row per (iteration, physics weight)
    gap_results = []  # one row per iteration: loss of every physics weight

    # Creating timestamp for the directory name
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    prefix = "Shuffle_CL_Base_output" if shuffle else "CL_Base_output"
    output_dir = f"{root_path}/{prefix}_{Application}_{num_iterations}_{noise}_{timestamp}"

    if save_log:
        os.makedirs(output_dir, exist_ok=True)

    start_time = time.time()
    for iteration in range(num_iterations)[from_iteration: to_iteration]:
        print(f"Iteration {iteration+1}/{num_iterations}")

        # Only the row of iteration 0 carries the run metadata; later rows leave it blank.
        if iteration == 0:
            res_dict = {'Application': Application, 'N_Nodes': N_NODES,
                        'N_ITs': num_iterations, 'N_CRs': COMMUNICATION_ROUNDS, 'Noise': noise,
                        'It': iteration}
        else:
            res_dict = {'Application': "", 'N_Nodes': "",
                        'N_ITs': "", 'N_CRs': "", 'Noise': "",
                        'It': iteration}
        for pw in physics_weights:
            res_dict[pw_name_map[pw]] = 0

        ## Build a (possibly noisy) copy of this iteration's data; ST_node is left untouched.
        data = [tf.convert_to_tensor(el, dtype=tf.float32) for el in ST_node[iteration]]

        if noise > 0 and Application not in ['AirQ', 'DrugDiffusion']:
            indices = [0] if noise_on_input else [-t for t in range(1, 1 + num_true_values)]
            for idx in indices:
                field = data[idx]
                # Noise of the field's own shape, std = noise * std(field).
                data[idx] = field + noise*np.std(field)*np.random.randn(*field.shape)

        train_data = tuple(data)

        physics_models = {pw: Models_dict() for pw in physics_weights}  ## Model Definition

        for physics_weight in physics_weights:
            print(f"\nPhysics Weight: {physics_weight}")
            print(physics_models[physics_weight].summary())

            # NOTE(review): physics_weight is not forwarded to train_local_model, so every
            # model in this loop is trained the same way; confirm this baseline is meant to
            # ignore the physics loss.
            train_local_model(physics_models[physics_weight], train_data, EPOCHS)

            # Evaluate the models
            average_loss = evaluate_models([physics_models[physics_weight]], SV_Data)[0]
            print(f"Average loss in iteration {iteration} for physical weight {physics_weight}: {average_loss}")

            res_dict[pw_name_map[physics_weight]] = average_loss

            results.append({'Iteration': iteration+1, 'Physical Weight': physics_weight, 'Average Loss': average_loss})

            results_df = pd.DataFrame(results)
            print(f"\nFinal Results:\n{results_df}")
            if save_log:
                results_df.to_csv(f'{output_dir}/CL_BaseModel_{Application}_{num_iterations}_{noise}_{timestamp}.csv')

        gap_results.append(res_dict)
        print("Gap Results:", gap_results)

    if gap_results:
        gap_df = pd.DataFrame(gap_results)
        mean_loss = {pw: np.mean(gap_df[pw_name_map[pw]]) for pw in physics_weights}
        std_loss = {pw: np.std(gap_df[pw_name_map[pw]]) for pw in physics_weights}
        # Physics weight 0.0 is the reference for the Delta/Gap rows (NaN when absent).
        baseline = mean_loss.get(0.0, np.nan)

        res_dict_mean = _summary_row("Mean_Loss")
        res_dict_STD = _summary_row("STD_Loss")
        res_dict_Delta = _summary_row("Delta_Mean_Loss")
        ## Gap(%) = 100 * (mean_pw - mean_0) / mean_0
        res_dict_Gap = _summary_row("Gap_Mean_Loss")
        for pw in physics_weights:
            name = pw_name_map[pw]
            res_dict_mean[name] = mean_loss[pw]
            res_dict_STD[name] = std_loss[pw]
            res_dict_Delta[name] = mean_loss[pw] - baseline
            if pw == 0.0:
                res_dict_Gap[name] = 0.0
            elif baseline:
                res_dict_Gap[name] = 100 * (mean_loss[pw] - baseline) / baseline
            else:
                res_dict_Gap[name] = np.nan
        gap_results.extend([res_dict_mean, res_dict_STD, res_dict_Delta, res_dict_Gap])

        gap_df = pd.DataFrame(gap_results)
        print("Gap Values:\n", gap_df)
        if save_log:
            gap_df.to_csv(f'{output_dir}/CL_BaseModel_GapResults_{Application}_{num_iterations}_{noise}_{timestamp}.csv')
    else:
        print("No iterations were run; gap summary skipped.")

    finish_time = time.time()
    print(f"Time taken by the CL_Base Algorithm: {(finish_time-start_time)/60} Minutes")


def Segmented_Gossip(Application, configs, ST_node, SV_Data, Models_dict:callable, train_local_model:callable, evaluate_models:callable):
    """Run the Segmented Gossip decentralized baseline.

    For every iteration in ``range(num_iterations)[from_iteration:to_iteration]`` the
    data ``ST_node[iteration]`` is split across ``N_NODES`` clients (equal contiguous
    shards, or Dirichlet-sized contiguous shards when ``non_IID`` is set) and optionally
    perturbed with Gaussian noise, and a fresh model is created per client. In each
    communication round every client trains locally; then, for every client ``i``:

    - its list of weight arrays is cut into ``min(S, #arrays)`` contiguous segments
      that together cover all arrays;
    - for each segment, up to ``R`` neighbours (``p != i`` with ``A[i, p]`` truthy) are
      sampled at random;
    - the segment is replaced by the data-size-weighted average of client ``i``'s and
      the sampled neighbours' pre-aggregation segments.

    All exchanges in a round read the same pre-aggregation snapshot.

    Args:
        Application: Application name. Used in output names; noise is never added for
            'AirQ' and 'DrugDiffusion'.
        configs: Experiment configuration; the keys read are listed at the top of the body.
        ST_node: Per-iteration training data. Each entry is iterated into fields; field 0
            is perturbed when ``noise_on_input`` is set, otherwise the last
            ``num_true_values`` fields are.
        SV_Data: Validation data forwarded to ``evaluate_models``.
        Models_dict: Zero-argument factory returning a new model exposing ``summary``,
            ``get_weights`` and ``set_weights``.
        train_local_model: Called as ``train_local_model(model, data, EPOCHS)``.
        evaluate_models: Called with the client models and ``SV_Data``; expected to
            return one loss per model.

    Progress and result tables are printed. With ``save_log`` they are also written as
    CSV files, and per-round parameters are saved via ``save_parameters``, into a
    timestamped directory under ``root_path``.
    """
    N_NODES = configs['N_NODES']
    EPOCHS = configs['EPOCHS']
    COMMUNICATION_ROUNDS = configs['COMMUNICATION_ROUNDS']
    num_iterations = configs['num_iterations']
    from_iteration = configs['from_iteration']
    to_iteration = configs['to_iteration']
    non_IID = configs['non_IID']
    save_log = configs['save_log']
    shuffle = configs['shuffle']
    Dirichlet_alpha = configs['Dirichlet_alpha']
    noise = configs['noise']
    noise_on_input = configs['noise_on_input']
    nst = configs['nst']
    num_true_values = configs['num_true_values']
    A = configs['A']
    R = configs['R']  # peers sampled per segment
    S = configs['S']  # requested number of segments per model
    root_path = configs['root_path']

    # Noise is never injected for these applications.
    add_noise = noise > 0 and Application not in ['AirQ', 'DrugDiffusion']
    status = "IID" if not non_IID else "NonIID"

    def _perturb(fields):
        """Add zero-mean Gaussian noise, in place, to the selected entries of ``fields``.

        Field 0 is perturbed when ``noise_on_input`` is set, otherwise the last
        ``num_true_values`` fields are. Each field receives noise of its own shape with
        standard deviation ``noise * std(field)``.
        """
        indices = [0] if noise_on_input else [-t for t in range(1, 1 + num_true_values)]
        for idx in indices:
            field = fields[idx]
            fields[idx] = field + noise*np.std(field)*np.random.randn(*field.shape)

    def _split_local_datasets(It_data, iteration):
        """Partition one iteration's data into ``N_NODES`` tuples of float32 tensors."""
        # Materialize every field once instead of once per client.
        columns = [list(el) for el in It_data]
        local_datasets = []
        if not non_IID:
            node_samples = nst // N_NODES
            for node in range(N_NODES):
                fields = [tf.convert_to_tensor(col[node*node_samples: (node+1)*node_samples], dtype=tf.float32) for col in columns]
                if add_noise:
                    _perturb(fields)
                local_datasets.append(tuple(fields))
        else:
            proportions = list(np.random.dirichlet([Dirichlet_alpha] * N_NODES))
            print(f"Iteration {iteration} Data, Non-IID Proportions: ", proportions)
            taken_index = 0
            for node, proportion in enumerate(proportions):
                n_samples = int(proportion*nst)
                fields = [tf.convert_to_tensor(col[taken_index: taken_index + n_samples], dtype=tf.float32) for col in columns]
                if add_noise and n_samples > 1:
                    _perturb(fields)
                local_datasets.append(tuple(fields))
                taken_index += n_samples
                print(f"Data Length for Node {node}:", len(fields[0]))
        return local_datasets

    results = []  # one row per iteration
    CR_results = []  # one row per (iteration, round): mean loss over clients
    CR_nodes = []  # one row per (iteration, round): loss of every client

    # Creating timestamp for the directory name
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    prefix = "Shuffle_SGossip_output" if shuffle else "SGossip_output"
    output_dir = f"{root_path}/{prefix}_{Application}_{N_NODES}_{COMMUNICATION_ROUNDS}_{noise}_{timestamp}"

    if save_log:
        os.makedirs(output_dir, exist_ok=True)

    start_time = time.time()
    for iteration in range(num_iterations)[from_iteration: to_iteration]:
        print(f"Iteration {iteration+1}/{num_iterations}")

        client_models = {i: Models_dict() for i in range(N_NODES)}

        print(client_models[0].summary())

        local_datasets = _split_local_datasets(ST_node[iteration], iteration)

        for communication_round in range(COMMUNICATION_ROUNDS):
            print(f"\nCommunication Round: {communication_round}")

            for key in client_models:
                # Local training (clients without data are skipped)
                if len(local_datasets[key][0]):
                    train_local_model(client_models[key], local_datasets[key], EPOCHS)

            # Snapshot of every client's locally trained parameters. All gossip exchanges
            # below read from it, so the result does not depend on aggregation order.
            params_before = [model.get_weights() for model in client_models.values()]
            losses_before = evaluate_models(client_models.values(), SV_Data)

            print("Losses Before Aggregation:")
            for index, loss in enumerate(losses_before):
                print(f"{index}: {loss}")
            print("===================================")

            ## Aggregating the segmented weights once all the models are locally trained
            for key in client_models:
                n_arrays = len(params_before[key])
                n_segments = max(1, min(S, n_arrays))
                # Contiguous [start, end) ranges over the weight arrays covering all of them.
                bounds = [(seg * n_arrays) // n_segments for seg in range(n_segments + 1)]
                neighbours = [peer for peer in client_models if peer != key and A[key, peer]]

                model_segments = []
                for start, end in zip(bounds[:-1], bounds[1:]):
                    # Each segment is pulled from its own random subset of neighbours.
                    members = [key] + random.sample(neighbours, k=min(R, len(neighbours)))
                    sizes = [len(local_datasets[member][0]) for member in members]
                    total_size = sum(sizes)
                    if total_size:
                        coeffs = [size / total_size for size in sizes]
                    else:
                        coeffs = [1.0 / len(members)] * len(members)
                    for k in range(start, end):
                        model_segments.append(np.sum([coeff * params_before[member][k] for coeff, member in zip(coeffs, members)], axis=0))

                client_models[key].set_weights(model_segments)

            ## Recording the parameters after Aggregation
            params_adapted = [model.get_weights() for model in client_models.values()]

            losses_after = evaluate_models(client_models.values(), SV_Data)
            cr_loss = np.mean(losses_after)
            CR_results.append({'Iteration': iteration+1, 'CR': communication_round, 'Average Loss': cr_loss})

            cr_nodes_dict = {'Iteration': iteration+1, 'CR': communication_round}
            for i in range(N_NODES):
                cr_nodes_dict[i] = losses_after[i]
            CR_nodes.append(cr_nodes_dict)
            print(f"CR {communication_round} loss values:", cr_nodes_dict)

            if save_log:
                # Save parameters and weights to file
                save_parameters(f"{output_dir}", communication_round, iteration, params_before, params_adapted, losses_before, losses_after, None)

            print("Losses After Aggregation:")
            for index, loss in enumerate(losses_after):
                print(f"{index}: {loss}")
            print("==========================================")

        # Evaluate the models
        average_loss = np.mean(evaluate_models(client_models.values(), SV_Data))
        print(f"Average loss in iteration {iteration+1}: {average_loss}")
        results.append({'Iteration': iteration+1, 'Average Loss': average_loss})

        results_df = pd.DataFrame(results)
        print(f"\nFinal Results:\n{results_df}")

        CR_results_df = pd.DataFrame(CR_results)
        CR_nodes_df = pd.DataFrame(CR_nodes)
        print(f"\nFinal CR Results:\n{CR_results_df}")
        if save_log:
            results_df.to_csv(f'{output_dir}/SGossip_{Application}_{status}_{N_NODES}_{timestamp}.csv')
            CR_results_df.to_csv(f'{output_dir}/SGossip_CRResults_{Application}_{status}_{N_NODES}_{timestamp}.csv')
            CR_nodes_df.to_csv(f'{output_dir}/SGossip_CRNodes_{Application}_{status}_{N_NODES}_{timestamp}.csv')

    finish_time = time.time()
    print(f"Time taken by Segmented Gossip process: {(finish_time-start_time)/60} Minutes")

