import os
import argparse
from pydoc import cli
from utils import get_path
from pathlib import Path
import csv


import numpy as np
# import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import rc

# Embed fonts as TrueType (Type 42) rather than Type 3 in PDF/PostScript output.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})
# Non-interactive backend: this script only writes figures to files.
plt.switch_backend('agg')

# Command-line options. The parsed namespace is passed to `get_path` below to
# locate saved result files. Several options (e.g. --plot_algorithms,
# --result_path) are not read directly by this script; presumably they exist so
# that `get_path` builds the same path as the training run -- TODO confirm.
parser = argparse.ArgumentParser(description='plot script')
parser.add_argument('--plot_algorithms', action='store_true')
parser.add_argument("--algorithms", help='algorithms for comparison', nargs='+',
                    type=str, default=['Regular', 'Joint', 'FedAvg', 'DFedEM', 'Federico', 'FedFomo', 'FedAvg+', 'CFL'])
parser.add_argument("--n_neighbors", help="Number of neighbors of Federico to get model from in each round.",
                    type=int, default=1)
parser.add_argument("--n_components", help='number of components in the mixture of distribution', type=int,
                    default=2)
parser.add_argument('--plot_client_weights', action='store_true')
parser.add_argument('--dataset', type=str, default='mnist')
parser.add_argument("--frac", help="fraction of training dataset used",
                    type=float, default=0.5)
parser.add_argument("--n_clusters", help="number of components/clusters; default is -1",
                    type=int, default=-1)
parser.add_argument("--private_model_type", help="Private model architecture.",
                    type=str, choices=['LeNet5', 'MLP', 'CNN', 'CNN1', 'CNN2','ResNet18'], default="MLP")
parser.add_argument('--result_path', type=str, default='./results')
parser.add_argument("--n_clients", help="Number of clients.",
                    type=int, default=2)
parser.add_argument("--use_private_SGD", help="[int as bool] Use private SGD or not.",
                    type=int, default=0)
parser.add_argument("--optimizer", help="Optimizer.",
                    type=str, default='adam')
parser.add_argument("--lr", help="Learning rate.",
                    type=float, default=0.001)
parser.add_argument("--delta", help="delta parameter for DP SGD.",
                    type=float, default=0.00001)
parser.add_argument("--noise_multiplier", help="Gaussian noise deviation for DP SGD.",
                    type=float, default=1.0)
parser.add_argument("--l2_norm_clip", help="L2 norm maximum for clipping in DP SGD.",
                    type=float, default=1.0)
parser.add_argument("--n_rounds", help="Number of FL rounds.",
                    type=int, default=300)
parser.add_argument("--batch_size", help="Batch size during training.",
                    type=int, default=50)
parser.add_argument("--cw_ratio", help="Ratio of component weight training data in each client",
                    type=float, default=0.2)
parser.add_argument("--cw_momentum", help="Momentum update for client weights",
                    type=float, default=0.9)
# Implicit string concatenation instead of a backslash continuation inside the
# literal, which embedded the next line's indentation into the help text.
parser.add_argument("--greedy_eps",
                    help="Epsilon parameter for the epsilon-greedy sampling in training. "
                         "Smaller epsilon results in more greedy-like sampling",
                    type=float, default=1.0)
args = parser.parse_args()

# Plot setup
cmap = plt.get_cmap('jet')  # hsv, jet
markers = [None, 'o', '^', 's', 'd', 'P', 'v', '*', 'X', 'h', 'D', "1", "2", 0, 1, 2, 3, 4]   # may be not enough
font = {'family': 'monospace',
        'size': 12}
rc('font', **font)
markersize = 9


def plot_single_acc(global_acc, args, file_path, title):
    """Save one accuracy-vs-round curve to ``file_path``.

    Args:
        global_acc: 1-D sequence of accuracies, one value per round.
        args: parsed command-line arguments; not used here.
        file_path: destination passed to ``plt.savefig``.
        title: figure title.
    """
    rounds = np.arange(len(global_acc))
    plt.figure(figsize=(4, 4))  # square canvas
    plt.plot(rounds, global_acc)
    plt.title(title)
    plt.xlabel('Rounds')
    plt.ylabel('Accuracy')
    plt.grid(True)
    plt.savefig(file_path, bbox_inches='tight')
    plt.close()


def plot_all_acc(global_acc, args, file_path, labels, title):
    """Overlay several accuracy-vs-round curves in one figure and save it.

    Args:
        global_acc: array-like of shape (n_runs, n_rounds); row ``i`` is the
            accuracy curve of run ``i``.
        args: parsed command-line arguments; not used here, kept so callers
            stay unchanged.
        file_path: destination passed to ``plt.savefig``.
        labels: legend label for each row of ``global_acc`` (needs n_runs entries).
        title: not drawn (the title call is disabled); kept for callers.
    """
    global_acc = np.asarray(global_acc)
    n_runs, n_rounds = global_acc.shape
    x = np.arange(n_rounds)
    # One evenly spaced colormap sample per run so every curve gets its own color.
    colors = cmap(np.linspace(0, 1.0, n_runs))
    # Roughly ten markers per curve; short curves get a marker on every point.
    markevery = max(n_rounds // 10, 1)

    plt.figure(figsize=(6, 4))  # ratio
    for i, (curve, color) in enumerate(zip(global_acc, colors)):
        # Cycle the marker list so more runs than available markers no longer
        # raises IndexError.
        plt.plot(x, curve, label=labels[i], color=color,
                 marker=markers[i % len(markers)], markersize=markersize,
                 markevery=markevery)

    plt.grid(True)
    plt.xlabel('Rounds')
    plt.ylabel('Accuracy')
    plt.legend(loc="lower right", borderaxespad=0)
    plt.savefig(file_path, bbox_inches='tight')
    plt.close()



# load the results
# Sweep Federico's number of neighbors and plot its test-size-weighted accuracy
# for each setting, individually and overlaid.

res_dir = os.path.join('res_plots', args.dataset, 'hyper-param-sensitivity', 'n_neighbors')
os.makedirs(res_dir, exist_ok=True)  # no exists/makedirs race
args.seed = 0
args.class_seed = 22
n_neighborss = [0, 1, 3, 5, 7]

all_global_accuracies = []
for n_neighbors in n_neighborss:
    args.n_neighbors = n_neighbors
    load_file = get_path(args, algorithm='Federico', class_seed=args.class_seed, seed=args.seed)
    # Results are indexed by key, so np.load returns an NpzFile that holds an open
    # file handle; read the arrays and close it instead of leaking one per run.
    with np.load(load_file) as load_results:
        client_test_sizes = load_results['client_test_sizes']
        customized_accuracy = load_results['customized_accuracy']
    # customized_accuracy (n_clients, n_rounds), test_size_norm (n_clients):
    # weight each client's accuracy by its share of test samples -> (n_rounds,).
    test_size_norm = client_test_sizes / client_test_sizes.sum()
    # Only the first 149 rounds are plotted (hard-coded cut-off; presumably so
    # runs of different length line up -- confirm).
    global_acc = np.dot(customized_accuracy.T, test_size_norm)[:149]
    file_name = os.path.join(res_dir, '{}_neighbors.pdf'.format(args.n_neighbors))
    plot_single_acc(global_acc, args, file_name, title='Accuracy with {} neighbors'.format(n_neighbors))
    all_global_accuracies.append(global_acc)
all_global_accuracies = np.stack(all_global_accuracies)
file_name = os.path.join(res_dir, 'neighbors.pdf')
plot_all_acc(all_global_accuracies, args, file_name, ["M={}".format(m) for m in n_neighborss],
             'Accuracy with different number of neighbors')

# Sweep the client-weight momentum with the number of neighbors fixed at 3 and
# plot Federico's test-size-weighted accuracy for each setting.
all_global_accuracies = []
args.n_neighbors = 3
cw_momentums = [0.2, 0.4, 0.6, 0.8, 1.0]
res_dir = os.path.join('res_plots', args.dataset, 'hyper-param-sensitivity', 'cw_momentum')
os.makedirs(res_dir, exist_ok=True)  # no exists/makedirs race
for cw_momentum in cw_momentums:
    args.cw_momentum = cw_momentum
    load_file = get_path(args, algorithm='Federico', class_seed=args.class_seed, seed=args.seed)
    # Results are indexed by key, so np.load returns an NpzFile that holds an open
    # file handle; read the arrays and close it instead of leaking one per run.
    with np.load(load_file) as load_results:
        client_test_sizes = load_results['client_test_sizes']
        customized_accuracy = load_results['customized_accuracy']
    # customized_accuracy (n_clients, n_rounds), test_size_norm (n_clients):
    # weight each client's accuracy by its share of test samples -> (n_rounds,).
    test_size_norm = client_test_sizes / client_test_sizes.sum()
    # Only the first 149 rounds are plotted (hard-coded cut-off; presumably so
    # runs of different length line up -- confirm).
    global_acc = np.dot(customized_accuracy.T, test_size_norm)[:149]
    file_name = os.path.join(res_dir, '{}_cw_momentum.pdf'.format(cw_momentum))
    plot_single_acc(global_acc, args, file_name, title='Accuracy with momentum {}'.format(cw_momentum))
    all_global_accuracies.append(global_acc)
all_global_accuracies = np.stack(all_global_accuracies)
file_name = os.path.join(res_dir, 'momentums.pdf')
plot_all_acc(all_global_accuracies, args, file_name, ["momentum={}".format(m) for m in cw_momentums],
             'Accuracy with different momentum')
