# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at

#   http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.



import math
import torch
import numpy
import argparse
from scipy.io import arff
# import weka.core.jvm
# import weka.core.converters
import re
import copy
from collections import Counter
from collections import defaultdict
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn import metrics
from scipy.spatial.distance import cdist
from numpy import dot
from numpy.linalg import norm
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


import pickle

from sklearn.cluster import kmeans_plusplus

from sklearn.metrics import rand_score
from Common_functions import *
from multiprocessing import Pool
import subprocess

def run_client(t, client_id, min_round_before_agg, batch_size=64, save_path='Save_models/'):
    """Run Evaluate_Performance_Local.py for one (round, client) pair in a child process.

    Args:
        t: Round index, forwarded as ``--t``.
        client_id: Client index, forwarded as ``--client_id``.
        min_round_before_agg: Forwarded as ``--min_round_before_agg``.
        batch_size: Forwarded as ``--batch_size``.
        save_path: Forwarded as ``--save_path``.

    Returns:
        The child process's exit code (0 on success). A failure is not raised, so a
        single failed job does not raise out of ``Pool.starmap`` in ``main``. ``main``
        instead re-runs any job whose result file cannot be read.
    """
    import sys

    # Build an argument list and run it without a shell. Values such as save_path
    # are then passed verbatim: they are never word-split or interpreted as shell
    # syntax. Use the interpreter running this script rather than whatever
    # "python3" resolves to on PATH. Fall back to "python3" if sys.executable is
    # unavailable.
    cmd = [
        sys.executable or "python3",
        "Evaluate_Performance_Local.py",
        "--t", str(t),
        "--client_id", str(client_id),
        "--batch_size", str(batch_size),
        "--min_round_before_agg", str(min_round_before_agg),
        "--save_path", str(save_path),
    ]
    return subprocess.run(cmd, check=False).returncode

def _accs_file_path(save_path, client_id, t, batch_size):
    """Return the pickle path holding client ``client_id``'s accuracies for round ``t``."""
    # NOTE(review): the name is kept verbatim, including the missing '_' before
    # 'and_Round_'. Presumably Evaluate_Performance_Local.py writes exactly this
    # name — confirm before changing it.
    return (save_path + 'Accs_for_client_' + str(client_id) + 'and_Round_' + str(t)
            + '_bs_' + str(batch_size) + '.pkl')


def _load_client_round(t, client_id, min_round_before_agg, batch_size, save_path):
    """Load one client/round accuracy dict, re-running that job once if it cannot be read.

    A second failure to open or unpickle the file propagates to the caller.
    """
    path = _accs_file_path(save_path, client_id, t, batch_size)
    try:
        with open(path, 'rb') as fp:
            return pickle.load(fp)
    except (OSError, EOFError, pickle.UnpicklingError,
            AttributeError, ImportError, IndexError):
        # The file is missing or unreadable. Besides UnpicklingError, the pickle docs
        # list these as possible unpickling errors. Fall through and regenerate it.
        pass
    # run_client already spawns its own child process, so a dedicated Pool is not
    # needed just to re-run one job.
    run_client(t, client_id, min_round_before_agg, batch_size, save_path)
    with open(path, 'rb') as fp:
        return pickle.load(fp)


def main(args):
    """Run every (client, round) evaluation job, then gather and persist the accuracies.

    ``args`` must provide ``num_rounds``, ``num_clients``, ``min_round_before_agg``,
    ``batch_size`` and ``save_path``.

    Phase 1 runs ``run_client`` through a process pool. It submits all rounds of one
    client at a time, then moves on to the next client.

    Phase 2 reads each per-client/per-round result pickle, deletes it, and collects
    the accuracies into one list per round. A job whose result is missing or
    unreadable is re-run once first. ``App_1``/``App_3`` accuracies are collected
    only for rounds ``t >= min_round_before_agg``.

    The collected lists are pickled into a single summary file under ``save_path``.
    """
    import os  # needed for os.remove; not among this file's explicit imports

    num_rounds = args.num_rounds
    num_clients = args.num_clients
    min_round_before_agg = args.min_round_before_agg
    batch_size = args.batch_size
    save_path = args.save_path

    with Pool() as pool:
        for client_id in range(num_clients):
            tasks = [(t, client_id, min_round_before_agg, batch_size, save_path)
                     for t in range(num_rounds)]
            pool.starmap(run_client, tasks)

    acc_app_1_vec = []
    acc_app_3_vec = []
    acc_evo_inst_vec = []
    acc_snapshot_vec = []
    acc_ifca_vec = []
    acc_flsc_vec = []

    for t in range(num_rounds):
        aggregating = t >= min_round_before_agg
        acc_app_1 = []
        acc_app_3 = []
        acc_evo_inst = []
        acc_snapshot = []
        acc_ifca = []
        acc_flsc = []

        for client_id in range(num_clients):
            acc_to_save = _load_client_round(t, client_id, min_round_before_agg,
                                             batch_size, save_path)

            # Each result file is removed after it has been read. A failure to
            # delete it is reported but does not stop the aggregation.
            try:
                os.remove(_accs_file_path(save_path, client_id, t, batch_size))
            except OSError:
                print("Unable to delete the file")

            if aggregating:
                acc_app_1.append(acc_to_save['App_1_SVM'])
                acc_app_3.append(acc_to_save['App_3_SVM'])
            acc_evo_inst.append(acc_to_save['No_mem_evo_SVM_acc'])
            acc_snapshot.append(acc_to_save['No_mem_snapshot_SVM_acc'])
            acc_ifca.append(acc_to_save['IFCA_SVM_acc'])
            acc_flsc.append(acc_to_save['FLSC_SVM_acc'])

        if aggregating:
            acc_app_1_vec.append(acc_app_1)
            acc_app_3_vec.append(acc_app_3)
        acc_evo_inst_vec.append(acc_evo_inst)
        acc_snapshot_vec.append(acc_snapshot)
        acc_ifca_vec.append(acc_ifca)
        acc_flsc_vec.append(acc_flsc)

    experiment_accs = {
        'acc_app_1_vec': acc_app_1_vec,
        'acc_app_3_vec': acc_app_3_vec,
        'acc_evo_inst_vec': acc_evo_inst_vec,
        'acc_snapshot_vec': acc_snapshot_vec,
        'acc_ifca_vec': acc_ifca_vec,
        'acc_flsc_vec': acc_flsc_vec,
    }
    # NOTE(review): the space in 'Accuracy_save_for _bs_' looks like a typo. It is
    # kept so that anything already reading this summary file still finds it.
    out_path = (save_path + 'Accuracy_save_for _bs_' + str(batch_size)
                + '_num_rounds_' + str(num_rounds)
                + '_min_rounds_before_agg_' + str(min_round_before_agg) + '.pkl')
    with open(out_path, 'wb') as fp:
        pickle.dump(experiment_accs, fp)


if __name__ == "__main__":
    # Command-line options forwarded to main(); each entry is (flag, type, default).
    cli = argparse.ArgumentParser()
    for flag, kind, fallback in (
        ('--save_path', str, 'Save_models/'),
        ('--batch_size', int, 64),
        ('--num_rounds', int, 200),
        ('--num_clients', int, 100),
        ('--min_round_before_agg', int, 10),
    ):
        cli.add_argument(flag, type=kind, default=fallback)
    main(cli.parse_args())