import math
import torch
import numpy
import argparse
from scipy.io import arff
# import weka.core.jvm
# import weka.core.converters
import re
import copy
from collections import Counter
from collections import defaultdict
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn import metrics
from scipy.spatial.distance import cdist
from numpy import dot
from numpy.linalg import norm
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


import pickle

from sklearn.cluster import kmeans_plusplus

from sklearn.metrics import rand_score
from Common_functions import *
from multiprocessing import Pool
import subprocess


def main(args):
    """Plot the mean per-round accuracy of every approach stored in a results pickle.

    Loads the accuracy dictionary saved under ``args.save_path`` for the given
    batch size / round configuration. Each approach's accuracy array is
    averaged over its last axis, and all curves are drawn on one figure,
    which is written to ``out.png`` in the current working directory.

    Args:
        args: Namespace-like object exposing ``num_rounds``,
            ``min_round_before_agg``, ``batch_size`` and ``save_path``
            (as produced by the argument parser in ``__main__``).

    Raises:
        FileNotFoundError: if the expected pickle file does not exist.
        KeyError: if the pickle lacks one of the expected accuracy keys.
    """
    num_rounds = args.num_rounds
    min_round_before_agg = args.min_round_before_agg
    batch_size = args.batch_size
    save_path = args.save_path

    # The path is built by plain concatenation, and the name is kept verbatim
    # (including the space in "save_for _bs"). Presumably both must match how
    # the writing script named the file; verify there before changing either.
    # save_path is therefore expected to end with a separator (default
    # 'Save_models/').
    file_path = (save_path + 'Accuracy_save_for _bs_' + str(batch_size)
                 + '_num_rounds_' + str(num_rounds)
                 + '_min_rounds_before_agg_' + str(min_round_before_agg)
                 + '.pkl')
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file. Only load result files produced by trusted experiment runs.
    with open(file_path, 'rb') as fp:
        experiment_accs = pickle.load(fp)

    # Approaches 1 and 3 are drawn against rounds min_round_before_agg ..
    # num_rounds-1, presumably because they only record accuracy once
    # aggregation starts. The other series are drawn against their index
    # 0..len-1. NOTE(review): confirm those cover every round from 0.
    rounds_after_agg = numpy.arange(min_round_before_agg, num_rounds)
    # Each entry is (pickle key, legend label, x values or None for 0..len-1).
    # Keeping the label next to its series prevents the legend from drifting
    # out of sync with the plot order.
    series = (
        ('acc_app_1_vec', "Approach 1", rounds_after_agg),
        ('acc_app_3_vec', "Approach 3", rounds_after_agg),
        ('acc_ifca_vec', "IFCA", None),
        ('acc_flsc_vec', "FLSC", None),
        ('acc_evo_inst_vec', "Evolutionary Inst", None),
        ('acc_snapshot_vec', "Snapshot Inst", None),
    )

    # Use a dedicated figure rather than pyplot's implicit current figure.
    # Repeated calls then do not stack curves on one plot, and the figure is
    # released afterwards.
    fig, ax = plt.subplots()
    try:
        for key, label, x_values in series:
            # Average over the last axis (the per-round entries being averaged).
            mean_acc = numpy.mean(experiment_accs[key], -1)
            if x_values is None:
                ax.plot(mean_acc, label=label)
            else:
                ax.plot(x_values, mean_acc, label=label)
        ax.set_xlabel('Rounds')
        ax.set_ylabel('Accuracy')
        ax.legend()
        fig.savefig("out.png")
    finally:
        plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point. Each option is read by main() and fed into
    # the name of the accuracy pickle it loads. Every spec is a
    # (flag, type, default) triple.
    cli = argparse.ArgumentParser()
    option_specs = (
        ('--save_path', str, 'Save_models/'),
        ('--batch_size', int, 64),
        ('--num_rounds', int, 200),
        ('--min_round_before_agg', int, 10),
    )
    for flag, value_type, default_value in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value)

    main(cli.parse_args())