import os
import argparse
from pydoc import cli
from utils import get_path
from pathlib import Path


import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import rc

# Embed fonts as TrueType (Type 42) instead of matplotlib's default Type 3
# in PDF and PostScript output.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})
# Non-interactive backend: figures are only ever written to files here.
plt.switch_backend('agg')

# Command-line options. Most of them only select which saved result file
# get_path() resolves; they mirror the training script's options.
parser = argparse.ArgumentParser(description='plot script')
parser.add_argument("--algorithm", help='algorithms for comparison',
                    type=str, default='Federico')
parser.add_argument("--n_neighbors", help="Number of neighbors of Federico to get model from in each round.",
                    type=int, default=1)
parser.add_argument("--n_components", help='number of components in the mixture of distribution', type=int,
                    default=2)
parser.add_argument('--n_runs', help='number of trials', type=int, default=2)
# NOTE(review): --class_seed and --seed default to [] (not an int) when they are
# omitted; that value is passed to get_path() and formatted into the output
# directory name. Confirm this sentinel is intended.
parser.add_argument("--class_seed", help='Random seed for data partition',
                    type=int, default=[])
parser.add_argument("--seed", help='Random seeds for model', type=int, default=[])
parser.add_argument('--dataset', type=str, default='mnist')
parser.add_argument("--frac", help="fraction of training dataset used",
                    type=float, default=0.5)
parser.add_argument("--n_clusters", help="number of components/clusters; default is -1",
                    type=int, default=-1)
parser.add_argument("--major_percent", help="Percentage of majority class for client data partition.",
                    type=float, default=0.8)
parser.add_argument("--private_model_type", help="Private model architecture.",
                    type=str, choices=['LeNet5', 'MLP', 'CNN', 'CNN1', 'CNN2', 'ResNet18'], default="MLP")
parser.add_argument('--result_path', type=str, default='./results')
parser.add_argument("--n_clients", help="Number of clients.",
                    type=int, default=2)
parser.add_argument("--use_private_SGD", help="[int as bool] Use private SGD or not.",
                    type=int, default=0)
parser.add_argument("--optimizer", help="Optimizer.",
                    type=str, default='adam')
parser.add_argument("--lr", help="Learning rate.",
                    type=float, default=0.001)
parser.add_argument("--delta", help="delta parameter for DP SGD.",
                    type=float, default=0.00001)
parser.add_argument("--noise_multiplier", help="Gaussian noise deviation for DP SGD.",
                    type=float, default=1.0)
parser.add_argument("--l2_norm_clip", help="L2 norm maximum for clipping in DP SGD.",
                    type=float, default=1.0)
parser.add_argument("--n_rounds", help="Number of FL rounds.",
                    type=int, default=300)
parser.add_argument("--batch_size", help="Batch size during training.",
                    type=int, default=50)
parser.add_argument("--cw_ratio", help="Ratio of component weight training data in each client",
                    type=float, default=0.2)
# Overwritten by the driver loop below for each momentum value studied.
parser.add_argument("--cw_momentum", help="Momentum update for client weights",
                    type=float, default=0.9)
# Implicit string concatenation instead of a backslash continuation inside the
# literal, which embedded a long run of indentation spaces in the help text.
parser.add_argument("--greedy_eps",
                    help="Epsilon parameter for the epsilon-greedy sampling in training. "
                         "Smaller epsilon results in more greedy-like sampling",
                    type=float, default=1.0)
args = parser.parse_args()

# Plot setup
cmap = plt.get_cmap('jet')  # hsv, jet
markers = [None, 'o', '^', 's', 'd', 'P', 'v', '*', 'X', 'h', 'D', "1", "2", 0, 1, 2, 3, 4]   # may be not enough
font = {'family': 'monospace',
        'size': 12}
rc('font', **font)
markersize = 9



def get_colors_list(client_major_classes, colormap=None):
    """Assign one colour per distinct major-class combination.

    Clients that share the same major class(es) form a cluster and receive the
    same colour. Cluster ids are assigned in order of first appearance, so the
    colouring is reproducible and does not depend on set/hash ordering.

    Args:
        client_major_classes: Sequence (list or numpy array) with one entry per
            client. An entry is either a hashable label or a list/ndarray of
            classes; the latter are converted to tuples so they can be hashed.
        colormap: Callable mapping a 1-D array of floats in [0, 1] to an array
            of colour rows (e.g. a matplotlib colormap). Defaults to the
            module-level ``cmap``.

    Returns:
        (colors_list, client_cluster_list): ``colors_list[i]`` is the colour of
        client ``i`` and ``client_cluster_list[i]`` is its integer cluster id.
        Both are empty lists when ``client_major_classes`` is empty.
    """
    # len() rather than truthiness: numpy arrays have no unambiguous bool value.
    if len(client_major_classes) == 0:
        return [], []
    if colormap is None:
        colormap = cmap

    # Lists/arrays are unhashable; turn each combination into a tuple key.
    if isinstance(client_major_classes[0], (np.ndarray, list)):
        client_major_classes = [tuple(combo) for combo in client_major_classes]

    # setdefault hands out the next free id the first time a combination is seen.
    cluster_ids = {}
    client_cluster_list = [cluster_ids.setdefault(combo, len(cluster_ids))
                           for combo in client_major_classes]

    # Evenly spaced samples across the colormap, one per cluster.
    colors = colormap(np.linspace(0, 1.0, len(cluster_ids)))
    colors_list = [colors[c] for c in client_cluster_list]
    return colors_list, client_cluster_list



def plot_rounds_client_weights(client_id, client_weights, file_path, client_major_classes, momentum,
                               algorithm=None):
    """Plot one client's per-round weights, one line per column, and save it.

    Args:
        client_id: Id of the client being plotted. Not used in the figure itself.
        client_weights: 2-D array of shape (n_rounds, n_components); column ``i``
            is drawn as line ``i``.
        file_path: Destination passed to ``plt.savefig``.
        client_major_classes: Per-client major classes, used to colour lines by
            cluster. Ignored when ``algorithm == 'DFedEM'``, where lines are
            mixture components and get one colour each.
        momentum: Client-weight momentum value, shown as beta in the title. The
            legend is only drawn when it equals 1.0.
        algorithm: Algorithm name selecting colours and labels; defaults to
            ``args.algorithm``.
    """
    if algorithm is None:
        algorithm = args.algorithm

    n_rounds, n_components = client_weights.shape
    x = np.arange(n_rounds)
    # Roughly ten markers per line; every point for short runs.
    markevery = max(1, n_rounds // 10)

    if algorithm == 'DFedEM':
        # Lines are mixture components: one colour per component.
        colors_list = cmap(np.linspace(0, 1.0, n_components))
        label_fmt = 'component {}'
    else:
        # Lines are clients: clients with the same major classes share a colour.
        colors_list, _ = get_colors_list(client_major_classes)
        label_fmt = 'client {}'

    fig = plt.figure(figsize=(4, 4))
    try:
        for i in range(n_components):
            plt.plot(x, client_weights[:, i], label=label_fmt.format(i),
                     color=colors_list[i],
                     # Cycle markers so more lines than markers does not raise.
                     marker=markers[i % len(markers)],
                     markersize=markersize, markevery=markevery)

        plt.grid(True)
        plt.xlabel('Rounds')
        plt.ylabel('Client Weights')
        plt.title(r'$\beta = {}$'.format(momentum))
        # Only the beta = 1.0 panel carries a legend.
        if momentum == 1.0:
            plt.legend(loc="center right")
        plt.savefig(file_path, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

# Output directory for this dataset / cluster setting / seed combination.
res_dir = os.path.join('res_plots', args.dataset, 'cw_momentum study', 'n_clusters_{}'.format(args.n_clusters),
                       "component_weights_plots", 'class_seed_{}_seed_{}'.format(args.class_seed, args.seed))
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(res_dir, exist_ok=True)

# Client whose weights are plotted, and how many leading rounds are shown.
client_id = 5
n_plot_rounds = 149

# Raw weights of client_id for each momentum value (collected, not used further here).
clien_weights = []
cw_momentums = [0.2, 0.6, 1.0]
for cw_momentum in cw_momentums:
    # get_path() receives args; presumably the result file depends on
    # args.cw_momentum, hence the mutation before the call.
    args.cw_momentum = cw_momentum
    load_file = get_path(args, algorithm=args.algorithm, seed=args.seed, class_seed=args.class_seed)
    # NOTE(review): allow_pickle=True can execute arbitrary code embedded in the
    # file; only load trusted, locally produced result files.
    load_results = np.load(load_file, allow_pickle=True)
    try:
        client_major_classes = load_results['client_major_classes']
        # Indexed below as [client][round][component]: the plot expects a 2-D
        # (rounds, components) slice for a single client.
        client_weights_all = load_results['component_weights']
    finally:
        # For .npz archives np.load returns an NpzFile holding an open file handle.
        _close = getattr(load_results, 'close', None)
        if _close is not None:
            _close()
    clien_weights.append(client_weights_all[client_id])
    file_title = 'momentum_{}.pdf'.format(cw_momentum)
    file_path = os.path.join(res_dir, file_title)
    plot_rounds_client_weights(client_id, client_weights_all[client_id][:n_plot_rounds], file_path,
                               client_major_classes, cw_momentum)