# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at

#   http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"

import math
import torch
import numpy
import argparse
from scipy.io import arff
# import weka.core.jvm
# import weka.core.converters
import re
import copy
from collections import Counter
from collections import defaultdict
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn import metrics
from scipy.spatial.distance import cdist
from numpy import dot
from numpy.linalg import norm
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


import pickle

from sklearn.cluster import kmeans_plusplus

from sklearn.metrics import rand_score
from Common_functions import *
from multiprocessing import Pool
import subprocess

def run_client(client_id, gpu=0, batching=True, batch_size=64, cuda=True, save_path_encoder='Encoder_weights_saved', C=float('inf'), num_labeled_examples_available=1944,perform_baseline_exp=True, num_clusters=3, t=0, num_clusters_for_overlap=2):
    """Run ``Output_SVM_train_client.py`` for one client and wait for it to exit.

    Every parameter is forwarded to the client script as a ``--<name> <value>``
    command-line pair, with the value rendered through ``str()``.

    The command is handed to ``subprocess.run`` as an argument list (no shell),
    so values containing spaces or shell metacharacters reach the client script
    as a single, uninterpreted argument.

    Returns:
        None. A non-zero exit status of the client process is reported on
        stdout but does not raise, so the remaining clients keep running.
    """
    # Flag order mirrors the order the client script has always been called with.
    options = (
        ("--client_id", client_id),
        ("--gpu", gpu),
        ("--batching", batching),
        ("--batch_size", batch_size),
        ("--save_path_encoder", save_path_encoder),
        ("--cuda", cuda),
        ("--C", C),
        ("--num_labeled_examples_available", num_labeled_examples_available),
        ("--perform_baseline_exp", perform_baseline_exp),
        ("--num_clusters", num_clusters),
        ("--t", t),
        ("--num_clusters_for_overlap", num_clusters_for_overlap),
    )
    cmd = ["python3", "Output_SVM_train_client.py"]
    for flag, value in options:
        cmd.extend((flag, str(value)))

    completed = subprocess.run(cmd)
    if completed.returncode != 0:
        # Surface the failure here; otherwise it only shows up later as a
        # missing or stale model file when the server tries to load it.
        print("Client", client_id, "exited with status", completed.returncode)


import argparse
def str2bool(v):
    """Convert a command-line token to a bool (for use as an argparse ``type=``).

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.

    Returns:
        True for 'true', 't', 'yes', 'y', '1' and False for 'false', 'f',
        'no', 'n', '0'; matching is case-insensitive and ignores surrounding
        whitespace.

    Raises:
        argparse.ArgumentTypeError: if the string is not one of the above.
    """
    if isinstance(v, bool):
        return v
    token = v.strip().lower()
    # Real tuples are required here: ``x in ('true')`` is a *substring* test
    # against the string 'true', which accepts '', 'e', 'ru', ... as True.
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')





def _dump_pickle(path, obj):
    """Pickle ``obj`` into the file at ``path`` (opened in binary write mode)."""
    with open(path, 'wb') as fp:
        pickle.dump(obj, fp)


def _load_pickle(path):
    """Unpickle and return the object stored in the file at ``path``."""
    # NOTE(review): pickle.load can execute arbitrary code from the file. These
    # files are presumably written by this experiment's own client processes;
    # do not point save_path at untrusted data.
    with open(path, 'rb') as fp:
        return pickle.load(fp)


def _save_round_outputs(save_path, t, batch_size, outputs):
    """Pickle each ``{stem: obj}`` entry to ``<save_path><stem><t>_bs_<batch_size>.pkl``."""
    for stem, obj in outputs.items():
        _dump_pickle(save_path + stem + str(t) + '_bs_' + str(batch_size) + '.pkl', obj)


def _method_outputs(prefix, state):
    """Map a (cluster list, weights, intercepts) triple to its three file stems."""
    clusters_list, agg_weights, agg_intercepts = state
    return {
        prefix + '_SVM_Cluster_list': clusters_list,
        prefix + '_SVM_Aggregated_Weights': agg_weights,
        prefix + '_SVM_Aggregated_Intercepts': agg_intercepts,
    }


def _run_clients(num_clients, *client_args):
    """Call ``run_client(client_id, *client_args)`` for every client in a process pool."""
    with Pool() as pool:
        pool.starmap(run_client, [(client_id,) + client_args for client_id in range(num_clients)])


def _load_local_models(save_path, num_clients):
    """Load every client's saved SVM output-layer model from ``save_path``."""
    # NOTE(review): these files hold whole pickled model objects. Recent PyTorch
    # releases default torch.load to weights_only=True, which would reject
    # them -- confirm against the installed version.
    return [torch.load(save_path + 'SVM_output_layer_Model_for_client_' + str(i) + '.pt')
            for i in range(num_clients)]


def _complete_linkage_labels(dissimilarity, n_clusters=3):
    """Complete-linkage agglomerative clustering on a precomputed dissimilarity matrix."""
    try:
        # scikit-learn >= 1.2 spells the parameter ``metric``; ``affinity`` was
        # deprecated in 1.2 and removed in 1.4.
        model = AgglomerativeClustering(n_clusters=n_clusters, metric="precomputed", linkage='complete')
    except TypeError:
        # Older scikit-learn only knows ``affinity``.
        model = AgglomerativeClustering(n_clusters=n_clusters, affinity="precomputed", linkage='complete')
    return model.fit_predict(dissimilarity)


def _aggregate_by_cluster(labels, local_models):
    """Group client indices by cluster label and combine their SVM parameters.

    Returns ``(members, coef_aggs, intercept_aggs)``: ``members[j]`` lists the
    client indices labelled ``j``; the other two dicts hold the result of
    ``classifier_agg`` over the members' ``coef_`` / ``intercept_`` copies.
    """
    members = {j: [] for j in range(len(set(labels)))}
    for idx, label in enumerate(labels):
        members[label].append(idx)

    coef_aggs = {}
    intercept_aggs = {}
    for j in range(len(members)):
        coef_aggs[j] = classifier_agg(
            [copy.deepcopy(local_models[k].classifier.coef_) for k in members[j]])
        intercept_aggs[j] = classifier_agg(
            [copy.deepcopy(local_models[k].classifier.intercept_) for k in members[j]])
    return members, coef_aggs, intercept_aggs


def main(args):
    """Drive the federated clustering experiment on the server side.

    One client subprocess per client is launched through ``run_client``: first
    a single pass with ``C=inf``, after which ``C`` is set to the mean of the
    clients' ``classifier.C``; then once per round. After every round the
    saved client models are loaded from ``save_path``, clients are clustered
    on their unit-normalised ``[coef_, intercept_]`` vectors, and the cluster
    assignments plus per-cluster aggregates are pickled under ``save_path``
    for each aggregation scheme (no-memory evolutionary / snapshot, Approach
    1-3 and, when ``perform_baseline_exp`` is set, IFCA and FLSC).

    ``compute_E_hat_V_hat``, ``classifier_agg`` and the ``Perform_agg*``
    functions are not defined in this file; they are presumably provided by
    the ``Common_functions`` star import.

    Args:
        args: Namespace with ``num_clients``, ``batching``, ``batch_size``,
            ``num_rounds``, ``min_round_before_agg``, ``gpu``, ``cuda``,
            ``save_path_encoder``, ``perform_baseline_exp``, ``num_clusters``,
            ``num_clusters_for_overlap``, ``num_labeled_examples_available``
            and ``save_path``. ``save_path`` is used as a raw string prefix,
            so it must end with a path separator to denote a directory.

    Raises:
        ValueError: if the baselines are enabled and ``num_clients`` is smaller
            than ``num_clusters`` (the random IFCA initialisation could never
            populate every cluster).
    """
    num_clients = args.num_clients
    batching = args.batching
    batch_size = args.batch_size
    num_rounds = args.num_rounds
    min_round_before_agg = args.min_round_before_agg
    gpu = args.gpu
    cuda = args.cuda
    save_path_encoder = args.save_path_encoder
    perform_baseline_exp = args.perform_baseline_exp
    num_clusters = args.num_clusters
    num_clusters_for_overlap = args.num_clusters_for_overlap
    num_labeled_examples_available = args.num_labeled_examples_available
    save_path = args.save_path

    if cuda and not torch.cuda.is_available():
        print("CUDA is not available, proceeding without it...")
        cuda = False

    print("CUDA check:", torch.cuda.is_available())

    # Running (cluster list, aggregated weights, aggregated intercepts) state
    # of each aggregation scheme; all start empty.
    state_inst = (None, None, None)
    state_1 = (None, None, None)
    state_2 = (None, None, None)
    state_3 = (None, None, None)
    state_IFCA = (None, None, None)
    state_FLSC = (None, None, None)

    # Number of aggregation steps already folded into Approach 1 / Approach 2.
    count_1 = 0
    count_2 = 0
    os.environ["PYTHONWARNINGS"] = "ignore"

    # Cluster assignment carried from round to round; starts as one cluster.
    C_t = [0] * num_clients
    # Smoothed similarity matrix of the previous round and the last smoothing
    # weight; both are only read once a first round has completed.
    Phi_hat_t_1 = None
    alpha_t = None

    # Initial pass with C=inf (remaining run_client parameters keep their
    # defaults), then take the mean of the C values stored in the saved models.
    _run_clients(num_clients, gpu, batching, batch_size, cuda, save_path_encoder,
                 float('inf'), num_labeled_examples_available)
    C = numpy.mean([model.classifier.C for model in _load_local_models(save_path, num_clients)])

    # Start training
    for t in range(num_rounds):
        print("Round: ", t)

        _run_clients(num_clients, gpu, batching, batch_size, cuda, save_path_encoder, C,
                     num_labeled_examples_available, perform_baseline_exp, num_clusters, t,
                     num_clusters_for_overlap)
        local_models = _load_local_models(save_path, num_clients)

        # Evolutionary clustering
        # One row per client: flattened coef_ followed by intercept_, scaled to
        # unit length so W_new holds the pairwise dot products of unit vectors.
        feat_mat = numpy.array([
            list(numpy.concatenate((model.classifier.coef_.flatten(), model.classifier.intercept_)))
            for model in local_models
        ])
        for i in range(num_clients):
            feat_mat[i, :] = feat_mat[i, :] / numpy.linalg.norm(feat_mat[i, :])

        W_new = numpy.matmul(feat_mat, feat_mat.T)

        # NOTE(review): the previous revision also computed a k-means++ based
        # initial assignment (and a pairwise distance matrix) at this point but
        # never used either; the clustering below has always started from the
        # C_t carried over from the previous round.
        print("C_t Initial: ")
        print(C_t)

        for iter_t in range(5):
            E_hat_W_t, V_hat_W_t = compute_E_hat_V_hat(C_t, W_new)

            if t == 0:
                # No history yet: cluster on the current similarities alone.
                Phi_hat_t = numpy.asarray(W_new)
                # Rewritten on every iteration, so the file ends up holding the
                # assignment that entered the last iteration.
                _dump_pickle(save_path + 'Mem_Agg_Batchwise_EXPRTD_Training_Cluster_save_starting_v3', C_t)
            else:
                # Blend the previous smoothed matrix with the new similarities.
                alpha_t = sum(sum(V_hat_W_t)) / sum(sum((Phi_hat_t_1 - E_hat_W_t) ** 2 + V_hat_W_t))
                Phi_hat_t = alpha_t * Phi_hat_t_1 + (1. - alpha_t) * W_new

            C_t = _complete_linkage_labels(1. - Phi_hat_t)
            print("Iter t", iter_t, "   Corresponding C_t", C_t)
            if t > 0:
                print("Iter t", iter_t, "   Corresponding alpha_t", alpha_t)

        # Clustering of this round's similarities only (no smoothing).
        C_t_inst = _complete_linkage_labels(1. - W_new)
        print("t", t, "   Corresponding C_t_inst", C_t_inst)

        Phi_hat_t_1 = Phi_hat_t
        print("Cluster update:", t, C_t)
        _dump_pickle(save_path + 'Mem_Agg_Batchwise_EXPRTD_Training_Cluster_save_for_round_v3'
                     + str(t) + '_bs_' + str(batch_size), C_t)
        _dump_pickle(save_path + 'Mem_Agg_Batchwise_EXPRTD_Training_Cluster_inst_save_for_round_v3'
                     + str(t) + '_bs_' + str(batch_size), C_t_inst)

        # Aggregate per cluster
        members, coef_aggs, intercept_aggs = _aggregate_by_cluster(C_t, local_models)
        _save_round_outputs(save_path, t, batch_size, {
            'No_mem_evo_SVM_Cluster_list': members,
            'No_mem_evo_SVM_Cluster_Agg_weights': coef_aggs,
            'No_mem_evo_SVM_Cluster_Agg_intercept': intercept_aggs,
        })

        members, coef_aggs, intercept_aggs = _aggregate_by_cluster(C_t_inst, local_models)
        _save_round_outputs(save_path, t, batch_size, {
            'No_mem_snapshot_SVM_Cluster_list': members,
            'No_mem_snapshot_SVM_Cluster_Agg_weights': coef_aggs,
            'No_mem_snapshot_SVM_Cluster_Agg_intercept': intercept_aggs,
        })

        if t >= min_round_before_agg:
            first_agg_round = (t == min_round_before_agg)

            clusters_list, agg_weights, agg_intercepts = state_inst
            state_inst = Perform_agg(clusters_list, agg_weights, agg_intercepts, t,
                                     min_round_before_agg, C_t_inst, local_models)

            # Approach 1: fold every round in with weights count/(count+1) and
            # 1/(count+1); the first aggregation round passes None for both.
            if first_agg_round:
                w_first, w_second = None, None
            else:
                w_first, w_second = count_1 / (count_1 + 1), 1 / (count_1 + 1)
            clusters_list, agg_weights, agg_intercepts = state_1
            state_1 = Perform_agg_v2(clusters_list, agg_weights, agg_intercepts, C_t,
                                     w_first, w_second, local_models)
            count_1 += 1
            _save_round_outputs(save_path, t, batch_size, _method_outputs('Approach_1', state_1))

            # Approach 2: like Approach 1, but after the first aggregation round
            # it only updates when the smoothed and snapshot clusterings agree
            # exactly (Rand score 1). The current state is saved either way.
            if first_agg_round or rand_score(C_t, C_t_inst) == 1:
                if first_agg_round:
                    w_first, w_second = None, None
                else:
                    w_first, w_second = count_2 / (count_2 + 1), 1 / (count_2 + 1)
                clusters_list, agg_weights, agg_intercepts = state_2
                state_2 = Perform_agg_v2(clusters_list, agg_weights, agg_intercepts, C_t,
                                         w_first, w_second, local_models)
                count_2 += 1
            _save_round_outputs(save_path, t, batch_size, _method_outputs('Approach_2', state_2))

            # Approach 3: weights come from the last smoothing factor alpha_t.
            if first_agg_round:
                w_first, w_second = None, None
            else:
                w_first, w_second = alpha_t, 1 - alpha_t
            clusters_list, agg_weights, agg_intercepts = state_3
            state_3 = Perform_agg_v2(clusters_list, agg_weights, agg_intercepts, C_t,
                                     w_first, w_second, local_models)
            _save_round_outputs(save_path, t, batch_size, _method_outputs('Approach_3', state_3))

        if perform_baseline_exp:
            # IFCA baseline
            if t == 0:
                if num_clients < num_clusters:
                    # The rejection sampling below could never terminate.
                    raise ValueError('num_clients must be at least num_clusters for the IFCA baseline.')
                # Resample random assignments until every cluster is non-empty.
                while True:
                    C_t_IFCA = list(numpy.random.randint(low=0, high=num_clusters, size=(num_clients,)))
                    print('Sampled C_t', C_t_IFCA)
                    if all(k in C_t_IFCA for k in range(num_clusters)):
                        break
            else:
                # Read the per-client cluster choice files for this round.
                C_t_IFCA = [
                    _load_pickle(save_path + 'IFCA_cluster_choice_' + str(client_id) + 'and_Round_'
                                 + str(t) + '_bs_' + str(batch_size) + '.pkl')
                    for client_id in range(num_clients)
                ]

            print('IFCA-Predicted Clustering:', C_t_IFCA)
            _dump_pickle(save_path + 'C_t_ifca for round ' + str(t) + '__', C_t_IFCA)

            clusters_list, agg_weights, agg_intercepts = state_IFCA
            state_IFCA = Perform_agg_IFCA(clusters_list, agg_weights, agg_intercepts, C_t_IFCA,
                                          t / (t + 1), 1 / (t + 1), local_models, num_clusters)
            _save_round_outputs(save_path, t, batch_size, _method_outputs('IFCA', state_IFCA))

            # FLSC baseline: each client belongs to num_clusters_for_overlap clusters.
            if t == 0:
                C_t_FLSC = [
                    list(numpy.random.choice(num_clusters, num_clusters_for_overlap, replace=False))
                    for _ in range(num_clients)
                ]
            else:
                C_t_FLSC = [
                    _load_pickle(save_path + 'FLSC_cluster_choice_' + str(client_id) + 'and_Round_'
                                 + str(t) + '_bs_' + str(batch_size) + '.pkl')
                    for client_id in range(num_clients)
                ]

            print('FLSC-Predicted Clustering:', C_t_FLSC)
            _dump_pickle(save_path + 'C_t_flsc for round ' + str(t) + '__', C_t_FLSC)

            clusters_list, agg_weights, agg_intercepts = state_FLSC
            state_FLSC = Perform_agg_FLSC(clusters_list, agg_weights, agg_intercepts, C_t_FLSC,
                                          t / (t + 1), 1 / (t + 1), local_models, num_clusters,
                                          num_clients)
            _save_round_outputs(save_path, t, batch_size, _method_outputs('FLSC', state_FLSC))


if __name__ == "__main__":
    # Command-line options as (flag, value converter, default), registered in
    # this order. Boolean flags go through str2bool so that the textual
    # "True"/"False" values are parsed rather than treated as truthy strings.
    option_table = (
        ('--save_path', str, 'Save_models/'),
        ('--batch_size', int, 64),
        ('--num_rounds', int, 200),
        ('--num_clients', int, 100),
        ('--min_round_before_agg', int, 10),
        ('--save_path_encoder', str, 'Encoder_weights_saved'),
        ('--gpu', int, 0),
        ('--cuda', str2bool, True),
        ('--batching', str2bool, True),
        ('--perform_baseline_exp', str2bool, True),
        ('--num_clusters', int, 3),
        ('--num_clusters_for_overlap', int, 2),
        ('--num_labeled_examples_available', int, 1944),
    )

    parser = argparse.ArgumentParser()
    for option_flag, option_type, option_default in option_table:
        parser.add_argument(option_flag, type=option_type, default=option_default)

    # Parse sys.argv and hand the resulting namespace to the experiment driver.
    args = parser.parse_args()
    main(args)