import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

import seaborn as sns
sns.set()

import glob
import os
import pickle
import re
import warnings


# A tuning run counts as complete only when its "train_loss" curve has exactly
# this many non-NaN entries (see extract_best_results).
N_RESULTS_EXPECTED = 300

# Homogeneity levels q; results live under "homogeneity=<q>" sub-directories.
HOMOGENEITIES = [0.1, 0.35, 0.6, 0.85]

# Local computation budgets B compared in the figures.
LOCAL_COMPUTATION_BUDGETS = [256, 512, 1024]

# Plot display name of each optimizer -> matplotlib marker used to draw it.
# Insertion order fixes the plotting (and legend) order.
_OPTIMIZER_MARKERS = {
    "SGD": "o",
    "SARAH": "x",
    "L-SGD": "v",
    "SCAFFOLD": "s",
    "BVR-L-SGD": "*",
}
OPTIMIZERS = list(_OPTIMIZER_MARKERS)
MARKERS = list(_OPTIMIZER_MARKERS.values())

# Metrics plotted for every setting.
CRITERIONS = ["train_loss", "train_acc", "test_loss", "test_acc"]


def does_match_budget(name, budget):
    """Return whether the run name ``name`` used local computation budget ``budget``.

    Two naming schemes are recognised:

    * single-local-step runs (``"K=1_"`` in ``name``): the budget is the
      mini-batch size, encoded as ``"b=<budget>"``;
    * multi-step runs: the mini-batch size must be 16 (``"b=16"``) and the
      budget is ``K * 16``, so the name must contain ``"K=<budget // 16>"``.

    Numbers are matched exactly: ``"K=16"`` does not match a name containing
    ``"K=160"``, and ``"b=256"`` does not match ``"b=2560"``.

    Raises:
        ValueError: if a multi-step run name does not use mini-batch size 16.
    """
    def has_field(key, value):
        # Forbid a digit right after the value so that e.g. "b=16" does not
        # match inside "b=160" (plain substring tests did).
        pattern = re.escape("{}={}".format(key, value)) + r"(?!\d)"
        return re.search(pattern, name) is not None

    if "K=1_" in name:
        return has_field("b", budget)
    # An explicit exception (unlike ``assert``) is not stripped by ``python -O``.
    if not has_field("b", 16):
        raise ValueError(
            "expected mini-batch size b=16 in multi-step run name: {}".format(name))
    return has_field("K", budget // 16)
        

def get_optimizer_name(name):
    """Map a run directory name to the optimizer display name used in plots.

    Substring tests on ``name`` are tried in this order:

    * ``"sgd"``: ``"SGD"`` for single-local-step runs (``"K=1_"``), else ``"L-SGD"``;
    * ``"sarah"``: ``"SARAH"`` for single-local-step runs, else ``"BVR-L-SGD"``;
    * ``"scaffold"``: ``"SCAFFOLD"``.

    Raises:
        ValueError: if ``name`` matches none of the known optimizers.
    """
    single_step = "K=1_" in name
    if "sgd" in name:
        return "SGD" if single_step else "L-SGD"
    if "sarah" in name:
        return "SARAH" if single_step else "BVR-L-SGD"
    if "scaffold" in name:
        return "SCAFFOLD"
    # ``assert False`` vanishes under ``python -O`` and the function would then
    # silently return None; raise explicitly instead.
    raise ValueError("unrecognized optimizer in run name: {}".format(name))


def get_homogeneity_name(value):
    """Return the results sub-directory name for homogeneity level ``value``.

    For example, ``0.35`` maps to ``"homogeneity=0.35"``.
    """
    return f"homogeneity={value}"


def extract_best_results(dir_path, seed, n_results_expected=None):
    """Select the tuning run with the lowest training loss for one seed.

    Every entry directly under ``dir_path`` is treated as one tuning run and
    must contain a ``seed=<seed>.pickle`` file holding a dict with (at least)
    the keys ``train_loss``, ``test_loss``, ``train_acc`` and ``test_acc``.
    A run is complete only if its ``train_loss`` has exactly
    ``n_results_expected`` non-NaN entries. Incomplete runs are scored +inf, so
    any complete run beats them. If no run is complete, a warning is issued and
    the first run (in sorted path order) is returned.

    Args:
        dir_path: directory whose entries are the tuning runs.
        seed: selects which ``seed=<seed>.pickle`` file is read in each run.
        n_results_expected: required number of non-NaN ``train_loss``
            entries; defaults to ``N_RESULTS_EXPECTED``.

    Returns:
        ``(results, run_dir)``: ``results`` maps the four metric names to the
        selected run's values; ``run_dir`` is the selected run's path.

    Raises:
        FileNotFoundError: if ``dir_path`` has no entries.
    """
    if n_results_expected is None:
        n_results_expected = N_RESULTS_EXPECTED

    # glob's order depends on the filesystem; sorting makes argmin's
    # tie-breaking (and therefore the selected run) reproducible.
    run_dirs = sorted(glob.glob(os.path.join(dir_path, "*")))
    if not run_dirs:
        raise FileNotFoundError("no tuning runs found in {}".format(dir_path))

    tuning_results_lst = []
    for run_dir in run_dirs:
        # NOTE: pickle.load can execute arbitrary code; only read trusted result files.
        with open(os.path.join(run_dir, "seed=" + str(seed) + ".pickle"), "rb") as f:
            tuning_results_lst.append(pickle.load(f))

    best_train_loss_lst = []
    n_complete = 0
    for tuning_results in tuning_results_lst:
        train_loss = tuning_results["train_loss"]
        if int(np.sum(~np.isnan(train_loss))) != n_results_expected:
            best_train_loss_lst.append(np.inf)
        else:
            n_complete += 1
            best_train_loss_lst.append(np.min(train_loss))

    best_tuning_id = int(np.argmin(best_train_loss_lst))
    if n_complete == 0:
        warnings.warn(
            "no complete tuning run in {} for seed={} (expected {} non-NaN "
            "train_loss entries); falling back to {}".format(
                dir_path, seed, n_results_expected, run_dirs[best_tuning_id]))

    best_tuned_results = tuning_results_lst[best_tuning_id]
    ans = {key: best_tuned_results[key]
           for key in ("train_loss", "test_loss", "train_acc", "test_acc")}
    return ans, run_dirs[best_tuning_id]


def mean_and_std(results):
    """Summarise several runs by their element-wise mean and standard deviation.

    Args:
        results: non-empty list of dicts; the keys of ``results[0]`` are
            looked up in every dict.

    Returns:
        A dict mapping each key of ``results[0]`` to ``(mean, std)``, computed
        across the runs (axis 0) with ``np.mean`` and ``np.std``
        (population standard deviation, ``ddof=0``).
    """
    summary = {}
    for key in results[0]:
        per_run = [run[key] for run in results]
        summary[key] = (np.mean(per_run, axis=0), np.std(per_run, axis=0))
    return summary


def _collect_results(base_dir, seeds):
    """Load and seed-aggregate the best tuned results for every setting.

    Args:
        base_dir: experiment directory containing one
            ``homogeneity=<q>`` sub-directory per homogeneity level, each
            holding one directory per optimizer configuration.
        seeds: seeds whose best runs are averaged with ``mean_and_std``.

    Returns:
        A dict mapping ``(homogeneity, budget)`` to
        ``{optimizer_name: {criterion: (mean, std)}}``.
    """
    results_dict = {(homogeneity, budget): {}
                    for homogeneity in HOMOGENEITIES
                    for budget in LOCAL_COMPUTATION_BUDGETS}
    for homogeneity in HOMOGENEITIES:
        homogeneity_dir = os.path.join(base_dir, get_homogeneity_name(homogeneity))
        # Sorted so that, should two directories map to the same optimizer
        # name, the one that ends up in results_dict is always the same.
        optimizer_dirs = sorted(glob.glob(os.path.join(homogeneity_dir, "*")))
        for budget in LOCAL_COMPUTATION_BUDGETS:
            for optimizer_dir in optimizer_dirs:
                run_name = os.path.basename(optimizer_dir)
                # Match on the directory's own name only, like get_optimizer_name,
                # so parent path components cannot influence the budget test.
                if not does_match_budget(run_name, budget):
                    continue
                per_seed = [extract_best_results(optimizer_dir, seed)[0] for seed in seeds]
                results_dict[(homogeneity, budget)][get_optimizer_name(run_name)] = \
                    mean_and_std(per_seed)
    return results_dict


def main():
    """Aggregate the tuned results of the experiment and draw every figure.

    Results are read from ``results/20210413/fc/cifar10`` and figures are
    written below ``figs/20210413/fc/cifar10``.
    """
    working_dir = "results"
    exp_name = "20210413"
    model_name = "fc"
    dataset_name = "cifar10"
    seeds = list(range(5))

    results_dict = _collect_results(
        os.path.join(working_dir, exp_name, model_name, dataset_name), seeds)

    save_dir = os.path.join("figs", exp_name, model_name, dataset_name)
    os.makedirs(save_dir, exist_ok=True)
    plot_best_results_by_homogeneities(results_dict=results_dict, save_base_dir=save_dir)
    plot_best_results_by_budgets(results_dict=results_dict, save_base_dir=save_dir)
    plot_results_by_n_communications(results_dict=results_dict, save_base_dir=save_dir)



def plot_best_results_by_homogeneities(results_dict, save_base_dir):
    """Plot the best value of each criterion against the homogeneity q.

    For every budget B and criterion, one figure is saved to
    ``<save_base_dir>/best_results_by_homogeneities/budget=<B>/<criterion>.png``
    with one error-bar line per optimizer. Each point is the seed-averaged
    curve at its best iteration (minimum for criteria whose name contains
    "loss", maximum otherwise). Its error bar is the seed standard deviation at
    that same iteration.

    Args:
        results_dict: maps ``(homogeneity, budget)`` to
            ``{optimizer_name: {criterion: (mean_curve, std_curve)}}``.
        save_base_dir: root directory for the saved figures.
    """
    def best_iteration(curve, maximize):
        # np.argmin/np.argmax report the first NaN as the "best" index; the
        # nan-aware variants skip NaNs. An all-NaN curve has no best point,
        # so fall back to index 0 (what np.argmin/np.argmax return there).
        curve = np.asarray(curve, dtype=float)
        if np.isnan(curve).all():
            return 0
        return int(np.nanargmax(curve) if maximize else np.nanargmin(curve))

    for budget in LOCAL_COMPUTATION_BUDGETS:
        save_dir = os.path.join(save_base_dir, "best_results_by_homogeneities", "budget="+str(budget))
        os.makedirs(save_dir, exist_ok=True)
        for eval_name in CRITERIONS:
            maximize = "loss" not in eval_name
            for marker, optimizer_name in zip(MARKERS, OPTIMIZERS):
                mean, std = [], []
                for homogeneity in HOMOGENEITIES:
                    mean_curve, std_curve = results_dict[(homogeneity, budget)][optimizer_name][eval_name]
                    best_iter = best_iteration(mean_curve, maximize)
                    mean.append(mean_curve[best_iter])
                    std.append(std_curve[best_iter])
                _, caps, bars = plt.errorbar(HOMOGENEITIES, mean, yerr=std, linewidth=4,
                                             label=optimizer_name, marker=marker, markersize=12)
                # Draw error bars and caps semi-transparent.
                for artist in (*bars, *caps):
                    artist.set_alpha(0.5)
            plt.xlabel('q', fontsize=24)
            plt.ylabel(eval_name.replace("_", " "), fontsize=24)
            plt.xticks(fontsize=18)
            plt.yticks(fontsize=18)
            plt.legend(fontsize=16)
            plt.title("$B = {}$".format(budget), fontsize=24)
            plt.tight_layout()
            plt.savefig(os.path.join(save_dir, eval_name+".png"))
            plt.close()


def plot_best_results_by_budgets(results_dict, save_base_dir):
    """Plot the best value of each criterion against the local budget B.

    For every homogeneity q and criterion, one figure is saved to
    ``<save_base_dir>/best_results_by_budgets/q=<q>/<criterion>.png`` with
    one error-bar line per optimizer. Each point is the seed-averaged curve at
    its best iteration (minimum for criteria whose name contains "loss",
    maximum otherwise). Its error bar is the seed standard deviation at that
    same iteration.

    Args:
        results_dict: maps ``(homogeneity, budget)`` to
            ``{optimizer_name: {criterion: (mean_curve, std_curve)}}``.
        save_base_dir: root directory for the saved figures.
    """
    def best_iteration(curve, maximize):
        # np.argmin/np.argmax report the first NaN as the "best" index; the
        # nan-aware variants skip NaNs. An all-NaN curve has no best point,
        # so fall back to index 0 (what np.argmin/np.argmax return there).
        curve = np.asarray(curve, dtype=float)
        if np.isnan(curve).all():
            return 0
        return int(np.nanargmax(curve) if maximize else np.nanargmin(curve))

    for homogeneity in HOMOGENEITIES:
        save_dir = os.path.join(save_base_dir, "best_results_by_budgets", "q="+str(homogeneity))
        os.makedirs(save_dir, exist_ok=True)
        for eval_name in CRITERIONS:
            maximize = "loss" not in eval_name
            for marker, optimizer_name in zip(MARKERS, OPTIMIZERS):
                mean, std = [], []
                for budget in LOCAL_COMPUTATION_BUDGETS:
                    mean_curve, std_curve = results_dict[(homogeneity, budget)][optimizer_name][eval_name]
                    best_iter = best_iteration(mean_curve, maximize)
                    mean.append(mean_curve[best_iter])
                    std.append(std_curve[best_iter])
                _, caps, bars = plt.errorbar(LOCAL_COMPUTATION_BUDGETS, mean, yerr=std, linewidth=4,
                                             label=optimizer_name, marker=marker, markersize=12)
                # Draw error bars and caps semi-transparent.
                for artist in (*bars, *caps):
                    artist.set_alpha(0.5)
            plt.xlabel('$B$', fontsize=24)
            plt.ylabel(eval_name.replace("_", " "), fontsize=24)
            plt.xticks(fontsize=18)
            plt.yticks(fontsize=18)
            plt.legend(fontsize=18)
            plt.title("$q = {}$".format(homogeneity), fontsize=24)
            plt.tight_layout()
            plt.savefig(os.path.join(save_dir, eval_name+".png"))
            plt.close()


def plot_results_by_n_communications(results_dict, save_base_dir, step=30):
    """Plot every optimizer's seed-averaged curve for each (q, B) setting.

    For every criterion, homogeneity q and budget B, one figure is saved to
    ``<save_base_dir>/results_by_n_communications/q=<q>/budget=<B>/<criterion>.png``.
    Each curve is sub-sampled to every ``step``-th entry, starting with the
    first. It is drawn with the seed standard deviation as error bars.

    Args:
        results_dict: maps ``(homogeneity, budget)`` to
            ``{optimizer_name: {criterion: (mean_curve, std_curve)}}``.
        save_base_dir: root directory for the saved figures.
        step: sub-sampling stride, also used to label the x ticks. Defaults
            to 30, the previously hard-coded value.
    """
    for eval_name in CRITERIONS:
        for homogeneity in HOMOGENEITIES:
            for budget in LOCAL_COMPUTATION_BUDGETS:
                save_dir = os.path.join(save_base_dir, "results_by_n_communications",
                                        "q="+str(homogeneity), "budget="+str(budget))
                os.makedirs(save_dir, exist_ok=True)

                # Number of plotted points; the x ticks follow the last curve
                # drawn (presumably all curves have equal length — TODO confirm).
                n_points = 0
                for marker, optimizer_name in zip(MARKERS, OPTIMIZERS):
                    mean_curve, std_curve = results_dict[(homogeneity, budget)][optimizer_name][eval_name]
                    mean, std = mean_curve[::step], std_curve[::step]
                    n_points = len(mean)
                    _, caps, bars = plt.errorbar(range(n_points), mean, yerr=std, linewidth=4,
                                                 marker=marker, markersize=12, label=optimizer_name)
                    # Draw error bars and caps semi-transparent.
                    for artist in (*bars, *caps):
                        artist.set_alpha(0.5)

                plt.xlabel('# Communication Rounds/10', fontsize=24)
                plt.ylabel(eval_name.replace("_", " "), fontsize=24)
                tick_positions = np.arange(n_points)
                # NOTE(review): point i is curve entry i*step but is labelled
                # i*step + step; confirm this offset matches how results are logged.
                plt.xticks(tick_positions, tick_positions * step + step, fontsize=18)
                plt.yticks(fontsize=18)
                plt.legend(fontsize=18)
                plt.title("$q = {}, B = {}$".format(homogeneity, budget), fontsize=24)
                plt.tight_layout()
                plt.savefig(os.path.join(save_dir, eval_name+".png"))
                plt.close()

        
if __name__ == "__main__":
    # When run as a script: load all results and write every figure via main().
    main()
