import os
import argparse
from pydoc import cli
from utils import get_path
from pathlib import Path
import csv
import pandas as pd


import numpy as np
# import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import rc

# Emit Type 42 (TrueType) fonts in PDF/PS output rather than matplotlib's
# default Type 3 fonts, and draw with the non-interactive Agg backend so the
# script works without a display.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})
plt.switch_backend('agg')

# Command-line options. Several of them (dataset, n_clusters, n_neighbors,
# cw_momentum, greedy_eps) also name the output files; the full namespace is
# handed to get_path() to locate saved results.
parser = argparse.ArgumentParser(description='plot script')
parser.add_argument('--plot_algorithms', action='store_true')
parser.add_argument("--algorithms", help='algorithms for comparison', nargs='+',
                    type=str, default=['Regular', 'Joint', 'FedAvg', 'DFedEM', 'Federico', 'FedFomo', 'FedAvg+', 'CFL'])
parser.add_argument("--n_neighbors", help="Number of neighbors of Federico to get model from in each round.",
                    type=int, default=1)
parser.add_argument("--n_components", help='number of components in the mixture of distribution', type=int,
                    default=2)
parser.add_argument('--plot_client_weights', action='store_true')
parser.add_argument('--n_runs', help='Number of trials.', type=int, default=2)
# NOTE(review): the default is empty, but the result-loading loop needs at
# least one class seed to have anything to plot.
parser.add_argument("--class_seeds", help='Random seeds for data partition',
                    nargs='+', type=int, default=[])
# NOTE(review): the result-loading loop reuses each class seed as the model
# seed, so this option looks unused by the plotting code — confirm.
parser.add_argument("--seeds", help='Random seeds for model',
                    nargs='+', type=int, default=[])
parser.add_argument('--dataset', type=str, default='mnist')
parser.add_argument("--frac", help="fraction of training dataset used",
                    type=float, default=0.5)
parser.add_argument("--n_clusters", help="number of components/clusters; default is -1",
                    type=int, default=-1)
parser.add_argument("--private_model_type", help="Private model architecture.",
                    type=str, choices=['LeNet5', 'MLP', 'CNN', 'CNN1', 'CNN2', 'ResNet18'], default="MLP")
parser.add_argument('--result_path', type=str, default='./results')
parser.add_argument("--n_clients", help="Number of clients.",
                    type=int, default=2)
parser.add_argument("--use_private_SGD", help="[int as bool] Use private SGD or not.",
                    type=int, default=0)
parser.add_argument("--optimizer", help="Optimizer.",
                    type=str, default='adam')
parser.add_argument("--lr", help="Learning rate.",
                    type=float, default=0.001)
parser.add_argument("--delta", help="delta parameter for DP SGD.",
                    type=float, default=0.00001)
parser.add_argument("--noise_multiplier", help="Gaussian noise deviation for DP SGD.",
                    type=float, default=1.0)
parser.add_argument("--l2_norm_clip", help="L2 norm maximum for clipping in DP SGD.",
                    type=float, default=1.0)
parser.add_argument("--n_rounds", help="Number of FL rounds.",
                    type=int, default=300)
parser.add_argument("--batch_size", help="Batch size during training.",
                    type=int, default=50)
parser.add_argument("--cw_ratio", help="Ratio of component weight training data in each client",
                    type=float, default=0.2)
parser.add_argument("--cw_momentum", help="Momentum update for client weights",
                    type=float, default=0.9)
parser.add_argument("--greedy_eps",
                    help="Epsilon parameter for the epsilon-greedy sampling in training. "
                         "Smaller epsilon results in more greedy-like sampling.",
                    type=float, default=1.0)
args = parser.parse_args()

# Plot styling shared by the figures produced in this script.
cmap = plt.get_cmap('jet')  # 'hsv' is an alternative colormap
# Marker styles indexed by algorithm position; the list may be shorter than
# the number of algorithms being compared.
markers = [None, 'o', '^', 's', 'd', 'P', 'v', '*', 'X', 'h', 'D', "1", "2",
           0, 1, 2, 3, 4]
markersize = 9
font = dict(family='monospace', size=12)
rc('font', **font)

# Per-run CSV rows, appended to by the result-loading loop.
data = []

def write_to_csv(file_name, data, columns):
    """Write a header row followed by data rows to a CSV file.

    Args:
        file_name: Destination path; an existing file is overwritten.
        data: Iterable of rows, each an iterable of cell values.
        columns: Header row written before the data rows.
    """
    # newline='' hands line-ending control to the csv module (as the csv docs
    # require); UTF-8 matches the default encoding pandas uses to read it back.
    with open(file_name, 'w', newline='', encoding='utf-8') as f:
        csvwriter = csv.writer(f)
        csvwriter.writerow(columns)
        csvwriter.writerows(data)


def plot_algorithms_accurcies(mean_accuracies, std_accuracies, args, file_path):
    """Plot per-round accuracy curves, one line per algorithm, and save the figure.

    Args:
        mean_accuracies: 2-D array of shape (n_algorithms, n_rounds); row ``i``
            is the curve labelled ``args.algorithms[i]``.
        std_accuracies: Array of the same shape; each curve gets a shaded band
            of mean +/- the corresponding row.
        args: Parsed CLI namespace; only ``args.algorithms`` is read (line
            labels and order).
        file_path: Output path passed to ``plt.savefig``.
    """
    n_algo, n_rounds = mean_accuracies.shape
    x = np.arange(n_rounds)
    colors = cmap(np.linspace(0, 1.0, n_algo))
    # About ten markers per curve; every point is marked on short runs.
    markevery = max(1, n_rounds // 10)
    plt.figure(figsize=(6, 4))  # ratio
    for i, algorithm in enumerate(args.algorithms):
        # 'Regular' and 'Joint' are drawn dashed (presumably reference baselines).
        linestyle = '--' if algorithm in ('Regular', 'Joint') else '-'
        # Cycle through the marker list so that comparing more algorithms than
        # there are markers does not raise IndexError.
        marker = markers[i % len(markers)]
        mean_i = mean_accuracies[i]
        std_i = std_accuracies[i]
        plt.plot(x, mean_i, label=algorithm, color=colors[i],
                 linestyle=linestyle,
                 marker=marker, markersize=markersize, markevery=markevery)
        plt.fill_between(x, mean_i - std_i, mean_i + std_i,
                         facecolor=colors[i], alpha=0.2)

    plt.grid(True)
    plt.xlabel('Rounds')
    plt.ylabel('Accuracy')
    # Legend outside the axes, to the right; bbox_inches='tight' keeps it in the file.
    plt.legend(bbox_to_anchor=(1.04, 0), loc="lower left", borderaxespad=0)
    plt.savefig(file_path, bbox_inches='tight')
    plt.close()


# One (n_algorithms, n_rounds_kept) array of global-accuracy curves per class seed.
all_seeds_global_accuracies = []

# NOTE(review): every curve is truncated to its first 149 rounds regardless of
# --n_rounds; confirm this cut-off is intended.
MAX_PLOT_ROUNDS = 149

if not args.class_seeds:
    # np.stack below cannot stack an empty list; fail with a usable message.
    parser.error('--class_seeds must list at least one seed')

# load the results
for class_seed in args.class_seeds:
    all_global_accuracies = []
    for algorithm in args.algorithms:
        # The data-partition seed is also used as the model seed.
        row_data = [algorithm, class_seed, class_seed, args.n_clusters]
        load_file = get_path(args, algorithm=algorithm,
                             class_seed=class_seed, seed=class_seed)
        load_results = np.load(load_file)
        # customized_accuracy: (n_clients, n_rounds); client_test_sizes: (n_clients,)
        client_test_sizes = load_results['client_test_sizes']
        test_size_norm = client_test_sizes / client_test_sizes.sum()
        # Average the per-client accuracies, weighted by test-set size -> one value per round.
        global_acc_single_run = np.dot(load_results['customized_accuracy'].T,
                                       test_size_norm)[:MAX_PLOT_ROUNDS]
        row_data += [global_acc_single_run[-1], np.max(global_acc_single_run)]
        data.append(row_data)
        all_global_accuracies.append(global_acc_single_run)
    all_seeds_global_accuracies.append(np.stack(all_global_accuracies))

res_dir = os.path.join('res_plots', args.dataset, 'n_clusters_{}'.format(args.n_clusters))
os.makedirs(res_dir, exist_ok=True)
file_stem = '{}_neighbors_cw_momentum_{}_eps_{}'.format(
    args.n_neighbors, args.cw_momentum, args.greedy_eps)

# Per-run table: one row per (class seed, algorithm).
csv_file_name = os.path.join(res_dir, 'global_acc_{}.csv'.format(file_stem))
write_to_csv(csv_file_name, data,
             columns=['algorithm', 'class_seed', 'seed', 'n_clusters',
                      'final global_acc', 'max global_acc'])

# Stack to (n_algorithms, n_seeds, n_rounds_kept), then mean/std over seeds.
pdf_file_name = os.path.join(res_dir, 'global_acc_{}.pdf'.format(file_stem))
all_seeds_global_accuracies = np.stack(all_seeds_global_accuracies, axis=1)
mean_global_accuracies = np.mean(all_seeds_global_accuracies, axis=1)
std_global_accuracies = np.std(all_seeds_global_accuracies, axis=1)
plot_algorithms_accurcies(mean_global_accuracies, std_global_accuracies, args, pdf_file_name)

# Row order of the summary table. Algorithms absent from the CSV become NaN
# rows; algorithms not listed here (e.g. 'Joint') are left out.
algorithms = ['FedAvg', 'FedAvg+', 'Regular', 'CFL', 'DFedEM', 'FedFomo', 'Federico']
# Read back the per-run CSV (not the PDF path) to aggregate across seeds.
res = pd.read_csv(csv_file_name, header=0)
aggregated_res = (res.groupby('algorithm')
                  .agg({'final global_acc': ['mean', 'std'],
                        'max global_acc': ['mean', 'std']})
                  * 100).round(2)  # percentages, two decimals
aggregated_res = aggregated_res.reindex(algorithms)
file_name = os.path.join(res_dir, 'aggregated_global_acc_{}.csv'.format(file_stem))