import math
import torch
import numpy
import argparse
from scipy.io import arff
# import weka.core.jvm
# import weka.core.converters
import re
import copy
from collections import Counter
from collections import defaultdict
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn import metrics
from scipy.spatial.distance import cdist
from numpy import dot
from numpy.linalg import norm
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


import pickle

from sklearn.cluster import kmeans_plusplus

from sklearn.metrics import rand_score
from Common_functions import *
from multiprocessing import Pool
import subprocess


def _load_pickle(path):
    """Return the object unpickled from the file at ``path``."""
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserialising; only point this at files written by a trusted run.
    with open(path, 'rb') as fp:
        return pickle.load(fp)


def _round_artifact(save_path, stem, t, batch_size):
    """Build the path ``<save_path><stem><t>_bs_<batch_size>.pkl``."""
    return save_path + stem + str(t) + '_bs_' + str(batch_size) + '.pkl'


def _install_cluster_params(local_model, client_id, clusters, weights,
                            intercepts):
    """Put the aggregate of the client's cluster onto the classifier.

    The first cluster (by position) whose members contain ``client_id`` wins;
    its position is used as the key into ``weights`` and ``intercepts``.

    Returns:
        True if a cluster containing ``client_id`` was found and its
        parameters were assigned, False if the classifier was left untouched.
    """
    # Positional indexing is deliberate: the position in `clusters` is the
    # lookup key for both aggregate tables.
    for idx in range(len(clusters)):
        if client_id in clusters[idx]:
            local_model.classifier.coef_ = weights[idx]
            local_model.classifier.intercept_ = intercepts[idx]
            return True
    return False


def main(args):
    """Score one client's test split under every clustering scheme of round t.

    The client's saved model encodes its test samples once; the encoder is
    then dropped and only the classifier's ``coef_`` / ``intercept_`` are
    swapped between evaluations.  For each scheme the per-cluster aggregates
    of round ``t`` are loaded and the aggregate of the client's cluster is
    scored with ``classifier_score_modded_feats``.  All accuracies are
    written as one dict to
    ``<save_path>Accs_for_client_<id>and_Round_<t>_bs_<batch_size>.pkl``.

    Result keys, in evaluation order: ``App_1_SVM`` and ``App_3_SVM`` (only
    when ``t >= min_round_before_agg``), ``IFCA_SVM_acc``, ``FLSC_SVM_acc``,
    ``No_mem_evo_SVM_acc``, ``No_mem_snapshot_SVM_acc``.

    If the client belongs to no cluster of a scheme, the classifier is reset
    to the parameters it was loaded with (and a warning is printed) so the
    score never silently reuses the aggregate of the previously evaluated
    scheme.

    Args:
        args: Namespace providing ``save_path``, ``batch_size``,
            ``client_id``, ``t`` and ``min_round_before_agg``.

    Raises:
        FileNotFoundError: If the model, the test pickles (``test_x`` /
            ``test_y`` in the working directory) or any round artifact is
            missing.
    """
    client_id = args.client_id
    t = args.t
    batch_size = args.batch_size
    min_round_before_agg = args.min_round_before_agg
    save_path = args.save_path

    # NOTE(security): torch.load unpickles the whole model object; load only
    # trusted checkpoints.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects whole-module pickles -- confirm the pinned torch version.
    local_model = torch.load(
        save_path + 'SVM_output_layer_Model_for_client_' + str(client_id) + '.pt')
    local_model.encoder.eval()

    print("t: ", t, "client_id: ", client_id)

    all_test_x = _load_pickle('test_x')
    all_test_y = _load_pickle('test_y')
    test = all_test_x[client_id]
    test_labels = all_test_y[client_id]

    # Features are computed once; every scheme below differs only in the
    # classifier parameters, so the encoder can be released right away.
    feats = local_model.encode(test)
    local_model.encoder = None

    # Snapshot of the client's own classifier parameters, used as the
    # fallback when a scheme has no cluster for this client.
    local_params = {
        name: getattr(local_model.classifier, name)
        for name in ('coef_', 'intercept_')
        if hasattr(local_model.classifier, name)
    }

    def score_clustered(label, list_stem, weights_stem, intercepts_stem):
        """Load one scheme's round artifacts and score the client with them."""
        clusters = _load_pickle(
            _round_artifact(save_path, list_stem, t, batch_size))
        weights = _load_pickle(
            _round_artifact(save_path, weights_stem, t, batch_size))
        intercepts = _load_pickle(
            _round_artifact(save_path, intercepts_stem, t, batch_size))

        found = _install_cluster_params(
            local_model, client_id, clusters, weights, intercepts)
        if not found:
            for name, value in local_params.items():
                setattr(local_model.classifier, name, value)
            print("warning: client", client_id, "is in no", label,
                  "cluster; scoring with its own local classifier")
        return classifier_score_modded_feats(local_model, feats, test_labels)

    acc_to_save = {}

    if t >= min_round_before_agg:
        acc_to_save['App_1_SVM'] = score_clustered(
            'Approach_1',
            'Approach_1_SVM_Cluster_list',
            'Approach_1_SVM_Aggregated_Weights',
            'Approach_1_SVM_Aggregated_Intercepts')
        acc_to_save['App_3_SVM'] = score_clustered(
            'Approach_3',
            'Approach_3_SVM_Cluster_list',
            'Approach_3_SVM_Aggregated_Weights',
            'Approach_3_SVM_Aggregated_Intercepts')

    acc_to_save['IFCA_SVM_acc'] = score_clustered(
        'IFCA',
        'IFCA_SVM_Cluster_list',
        'IFCA_SVM_Aggregated_Weights',
        'IFCA_SVM_Aggregated_Intercepts')

    # FLSC: the membership table is indexed by client and yields several
    # cluster ids; the client is scored with the mean of their aggregates.
    flsc_membership = _load_pickle(
        save_path + 'C_t_flsc for round ' + str(t) + '__')
    flsc_weights = _load_pickle(
        _round_artifact(save_path, 'FLSC_SVM_Aggregated_Weights', t, batch_size))
    flsc_intercepts = _load_pickle(
        _round_artifact(save_path, 'FLSC_SVM_Aggregated_Intercepts', t, batch_size))

    overlap_ids = flsc_membership[client_id]
    num_overlap = len(overlap_ids)
    # `0. * x` creates fresh zero accumulators shaped like the aggregates, so
    # the in-place adds below never modify the loaded tables.
    coef_sum = 0. * flsc_weights[overlap_ids[0]]
    intercept_sum = 0. * flsc_intercepts[overlap_ids[0]]
    for cluster_id in overlap_ids:
        coef_sum += flsc_weights[cluster_id]
        intercept_sum += flsc_intercepts[cluster_id]

    local_model.classifier.coef_ = coef_sum / num_overlap
    local_model.classifier.intercept_ = intercept_sum / num_overlap
    acc_to_save['FLSC_SVM_acc'] = classifier_score_modded_feats(
        local_model, feats, test_labels)

    acc_to_save['No_mem_evo_SVM_acc'] = score_clustered(
        'No_mem_evo',
        'No_mem_evo_SVM_Cluster_list',
        'No_mem_evo_SVM_Cluster_Agg_weights',
        'No_mem_evo_SVM_Cluster_Agg_intercept')
    acc_to_save['No_mem_snapshot_SVM_acc'] = score_clustered(
        'No_mem_snapshot',
        'No_mem_snapshot_SVM_Cluster_list',
        'No_mem_snapshot_SVM_Cluster_Agg_weights',
        'No_mem_snapshot_SVM_Cluster_Agg_intercept')

    out_path = (save_path + 'Accs_for_client_' + str(client_id) + 'and_Round_'
                + str(t) + '_bs_' + str(batch_size) + '.pkl')
    with open(out_path, 'wb') as fp:
        pickle.dump(acc_to_save, fp)

if __name__ == "__main__":
    # Command-line entry point: parse the evaluation settings and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_path', type=str, default='Save_models/',
                        help='prefix prepended verbatim to every artifact '
                             'file name (include the trailing separator)')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='batch size embedded in the artifact file names')
    # client_id and t are formatted into every file name; without a value they
    # would become the string "None" and fail later with a misleading
    # file-not-found error, so argparse rejects their absence up front.
    parser.add_argument('--client_id', type=int, required=True,
                        help='client whose model and test split are evaluated')
    parser.add_argument('--t', type=int, required=True,
                        help='round whose aggregates are evaluated')
    parser.add_argument('--min_round_before_agg', type=int, default=30,
                        help='Approach 1 and Approach 3 are evaluated only '
                             'when t is at least this value')

    args = parser.parse_args()
    main(args)