import matplotlib.pyplot as plt
import json
from experiment_stats import EXPERIMENT_STATS_DIR
from caltech101_experiments import ALL_IMAGE_FILEPATHS
import os 
import numpy as np 

# Directory every generated figure is written to. The `__main__` code creates
# it if missing and builds each save_path by plain string concatenation
# (save_dir + "<name>.png"), so the trailing slash is required.
PLOTS_DIR = "Plots/"

def parse_ys_labels_top2_sparse(L):
    """Transpose per-rank 4-value rows into four labelled series.

    Each row of ``L`` holds, in column order, the SVD, greedy, sparse
    setting 1 and sparse setting 2 values for one rank.

    Args:
        L: Non-empty sequence of rows; the first row must have exactly 4
            entries (checked with ``assert``).

    Returns:
        A list of ``(values, label)`` pairs, one per column, in column order.
    """
    assert len(L) >= 1 and len(L[0]) == 4

    column_labels = (
        'SVD',
        'Greedy $k$-CSS$_{1, 2}$',
        'Regular $k$-CSS$_{1, 2}$ Setting1',
        'Regular $k$-CSS$_{1, 2}$ Setting2',
    )
    # Column j of every row becomes the j-th series.
    return [([row[col] for row in L], name) for col, name in enumerate(column_labels)]

def parse_ys_labels_all_settings(L):
    """Transpose rows of (SVD, greedy, sparse settings...) into labelled series.

    Column 0 of each row is the SVD value, column 1 the greedy value, and
    every remaining column one sparse-embedding setting. The number of
    settings is taken from the first row.

    Args:
        L: Non-empty sequence of rows; the first row needs at least 2
            entries (checked with ``assert``).

    Returns:
        ``[(svd_values, 'SVD'), (greedy_values, ...), (setting1_values, ...), ...]``
        where setting labels are numbered from 1.
    """
    assert len(L) >= 1 and len(L[0]) >= 2

    svd_values, greedy_values = [], []
    setting_columns = [[] for _ in range(len(L[0]) - 2)]

    for row in L:
        svd_values.append(row[0])
        greedy_values.append(row[1])
        # Columns 2.. map onto setting_columns 0..; a row longer than the
        # first one raises IndexError here, a shorter one just contributes less.
        for setting_idx, value in enumerate(row[2:]):
            setting_columns[setting_idx].append(value)

    ys_labels = [(svd_values, 'SVD'), (greedy_values, 'Greedy $k$-CSS$_{1, 2}$')]
    ys_labels.extend(
        (column, 'Regular $k$-CSS$_{1, 2}$ Setting%d' % setting_num)
        for setting_num, column in enumerate(setting_columns, start=1)
    )
    return ys_labels

def parse_ys_labels_greedy_sparse_settings(L):
    """Transpose rows of (greedy, sparse settings...) into labelled series.

    Column 0 of each row is the greedy value and every remaining column one
    sparse-embedding setting; the number of settings comes from the first row.

    Args:
        L: Non-empty sequence of rows; the first row needs at least 1 entry
            (checked with ``assert``).

    Returns:
        ``[(greedy_values, ...), (setting1_values, ...), ...]`` with setting
        labels numbered from 1.
    """
    assert len(L) >= 1 and len(L[0]) >= 1

    greedy_values = []
    setting_columns = [[] for _ in range(len(L[0]) - 1)]

    for row in L:
        greedy_values.append(row[0])
        # Columns 1.. map onto setting_columns 0..; extra columns raise IndexError.
        for setting_idx, value in enumerate(row[1:]):
            setting_columns[setting_idx].append(value)

    return [(greedy_values, 'Greedy $k$-CSS$_{1, 2}$')] + [
        (column, 'Regular $k$-CSS$_{1, 2}$ Setting%d' % setting_num)
        for setting_num, column in enumerate(setting_columns, start=1)
    ]


def parse_ys_yerrs_labels_top2_sparse(means, stds):
    """Pair top-2-sparse mean series with their std series.

    ``means`` and ``stds`` are both parsed with
    ``parse_ys_labels_top2_sparse``; the labels are taken from ``means``.

    Returns:
        A list of ``(ys, yerrs, label)`` triples suitable for errorbar plots.
    """
    mean_series = parse_ys_labels_top2_sparse(means)
    std_series = parse_ys_labels_top2_sparse(stds)
    assert len(mean_series) == len(std_series)
    return [(ys, yerrs, label) for (ys, label), (yerrs, _) in zip(mean_series, std_series)]

def parse_ys_yerrs_labels_all_settings(means, stds):
    """Pair all-settings mean series with their std series.

    ``means`` and ``stds`` are both parsed with
    ``parse_ys_labels_all_settings``; the labels are taken from ``means``.

    Returns:
        A list of ``(ys, yerrs, label)`` triples suitable for bar plots with
        error bars.
    """
    mean_series = parse_ys_labels_all_settings(means)
    std_series = parse_ys_labels_all_settings(stds)
    assert len(mean_series) == len(std_series)
    return [(ys, yerrs, label) for (ys, label), (yerrs, _) in zip(mean_series, std_series)]

def make_line_plot(x, ys_labels, xlabel, ylabel, title, save_path):
    """Plot each ``(y, label)`` series as a line over ``x`` and save the figure.

    The legend's upper-left corner is pinned to the figure's lower-left
    corner (figure coordinates), placing it outside the axes; the figure is
    saved with ``bbox_inches="tight"`` and then closed.
    """
    for values, name in ys_labels:
        plt.plot(x, values, label=name)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    figure = plt.gcf()
    plt.legend(loc="upper left", bbox_to_anchor=(0, 0), bbox_transform=figure.transFigure)
    plt.title(title)

    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

def make_errorbar_plot(x, ys_yerrs_labels, xlabel, ylabel, title, save_path):
    """Plot each ``(y, yerr, label)`` series as an errorbar line and save the figure.

    The legend goes in the upper-right corner of the axes; the current
    figure is closed after saving.
    """
    for values, errors, name in ys_yerrs_labels:
        plt.errorbar(x=x, y=values, yerr=errors, label=name)

    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend(loc="upper right")

    plt.savefig(save_path)
    plt.close()

def make_bar_plot(x, ys_labels, xlabel, ylabel, title, save_path, width=0.75):
    """Draw a grouped bar chart (one bar per series at each x) and save it.

    Series are drawn in the order given: the i-th series is shifted
    ``i * width`` to the right of ``x``. Each series' label is attached to
    its bars and shown in a legend so the groups can be told apart.

    Args:
        x: Numeric x positions (e.g. test ranks); converted with ``np.array``.
        ys_labels: Iterable of ``(y, label)`` pairs, each ``y`` holding one
            value per entry of ``x``.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        title: Figure title.
        save_path: Output image path passed to ``plt.savefig``.
        width: Width of every bar, also used as the horizontal shift between
            consecutive series. Defaults to the previously hard-coded 0.75.
    """
    # Force scientific notation on the y tick labels of the current axes.
    plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    x = np.array(x)
    num_series = 0
    offset = 0.0
    for y, label in ys_labels:
        # label= was previously dropped, leaving the bar groups unidentifiable.
        plt.bar(x + offset, y, width=width, label=label)
        offset += width
        num_series += 1

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # plt.legend() warns when no labelled artists exist, so skip it if nothing was drawn.
    if num_series:
        plt.legend()
    plt.title(title)
    plt.savefig(save_path)
    plt.close()

def make_bar_plot_with_yerr(x, ys_yerrs_labels, xlabel, ylabel, title, save_path, width=0.75):
    """Draw a grouped bar chart with error bars and save it.

    Series are drawn in the order given: the i-th series is shifted
    ``i * width`` to the right of ``x``. Each series' label is attached to
    its bars and shown in a legend so the groups can be told apart.

    Args:
        x: Numeric x positions (e.g. test ranks); converted with ``np.array``.
        ys_yerrs_labels: Iterable of ``(y, yerr, label)`` triples, each ``y``
            and ``yerr`` holding one value per entry of ``x``.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        title: Figure title.
        save_path: Output image path passed to ``plt.savefig``.
        width: Width of every bar, also used as the horizontal shift between
            consecutive series. Defaults to the previously hard-coded 0.75.
    """
    # Force scientific notation on the y tick labels of the current axes.
    plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    x = np.array(x)
    num_series = 0
    offset = 0.0
    for y, yerr, label in ys_yerrs_labels:
        # label= was previously dropped, leaving the bar groups unidentifiable.
        plt.bar(x + offset, y, yerr=yerr, width=width, label=label)
        offset += width
        num_series += 1

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # plt.legend() warns when no labelled artists exist, so skip it if nothing was drawn.
    if num_series:
        plt.legend()
    plt.title(title)
    plt.savefig(save_path)
    plt.close()

def _load_stats(path):
    """Return the parsed JSON document stored at ``path``.

    The ``with`` block closes the file as soon as it has been read; the
    earlier ``json.load(open(path))`` pattern left every handle open until
    garbage collection.
    """
    with open(path) as stats_file:
        return json.load(stats_file)


def _make_top2_plots(test_ranks, error_mins, error_means, error_stds, title, path_prefix, percent=False):
    """Save the min-error line plot and mean/std errorbar plot (top-2 sparse settings).

    Writes ``<path_prefix>_min_error[_percent]_lineplot.png`` and
    ``<path_prefix>_mean_std_error[_percent]_errorbarplot.png``. When
    ``percent`` is true, " Percent" is also appended to both y-axis labels.
    """
    label_suffix = " Percent" if percent else ""
    file_suffix = "_percent" if percent else ""

    # min error line plot
    make_line_plot(
        x=test_ranks,
        ys_labels=parse_ys_labels_top2_sparse(error_mins),
        xlabel="Rank",
        ylabel="Entrywise l1 Error Min" + label_suffix,
        title=title,
        save_path=path_prefix + "_min_error" + file_suffix + "_lineplot.png",
    )

    # mean, std errorbar plot
    make_errorbar_plot(
        x=test_ranks,
        ys_yerrs_labels=parse_ys_yerrs_labels_top2_sparse(error_means, error_stds),
        xlabel="Rank",
        ylabel="Entrywise l1 Error Mean & Std" + label_suffix,
        title=title,
        save_path=path_prefix + "_mean_std_error" + file_suffix + "_errorbarplot.png",
    )


def _make_all_settings_plots(test_ranks, error_mins, error_means, error_stds, work_means, span_means,
                             title, path_prefix, percent=False):
    """Save the four all-settings bar plots: min error, mean/std error, work and span.

    Error plots are written to ``<path_prefix>_min_error[_percent]_all_settings_barplot.png``
    and ``<path_prefix>_mean_std_error[_percent]_all_settings_barplot.png``.
    Work/span plots go to ``<path_prefix>_average_{work,span}_all_settings_barplot.png``.
    ``percent`` only affects the error plots' y labels and filenames.
    """
    label_suffix = " Percent" if percent else ""
    file_suffix = "_percent" if percent else ""

    # min error bar plot
    make_bar_plot(
        x=test_ranks,
        ys_labels=parse_ys_labels_all_settings(error_mins),
        xlabel="Rank",
        ylabel="Entrywise l1 Error Min" + label_suffix,
        title=title,
        save_path=path_prefix + "_min_error" + file_suffix + "_all_settings_barplot.png",
    )

    # mean, std bar plot
    make_bar_plot_with_yerr(
        x=test_ranks,
        ys_yerrs_labels=parse_ys_yerrs_labels_all_settings(error_means, error_stds),
        xlabel="Rank",
        ylabel="Entrywise l1 Error Mean & Std" + label_suffix,
        title=title,
        save_path=path_prefix + "_mean_std_error" + file_suffix + "_all_settings_barplot.png",
    )

    # work bar plot (no SVD column, hence the greedy/sparse parser)
    make_bar_plot(
        x=test_ranks,
        ys_labels=parse_ys_labels_greedy_sparse_settings(work_means),
        xlabel="Rank",
        ylabel="Average Work",
        title=title,
        save_path=path_prefix + "_average_work_all_settings_barplot.png",
    )

    # span bar plot
    make_bar_plot(
        x=test_ranks,
        ys_labels=parse_ys_labels_greedy_sparse_settings(span_means),
        xlabel="Rank",
        ylabel="Average Span",
        title=title,
        save_path=path_prefix + "_average_span_all_settings_barplot.png",
    )


def main():
    """Generate every figure under PLOTS_DIR from the experiment-stats JSON files."""
    save_dir = PLOTS_DIR
    # exist_ok replaces the racy isdir()-then-mkdir() check.
    os.makedirs(save_dir, exist_ok=True)

    dataset_names = ["bcsstk13s", "isolet_transpose", "forest_cover"]
    titles = ["Bcsstk13", "Isolet", "Forest Cover"]

    # bcsstk13s, isolet_transpose, forest_cover (top2 sparse embedding settings)
    for dataset_name, title in zip(dataset_names, titles):
        stats = _load_stats(EXPERIMENT_STATS_DIR + "%s_experiment_stats.json" % dataset_name)
        _make_top2_plots(
            test_ranks=stats['test_ranks'],
            error_mins=stats['error_mins'],
            error_means=stats['error_means'],
            error_stds=stats['error_stds'],
            title=title,
            path_prefix=save_dir + dataset_name,
        )

    # caltech101 individual images (top2 sparse embedding settings)
    for image_num, _ in ALL_IMAGE_FILEPATHS:
        title = "Caltech101 Image %d" % image_num
        stats = _load_stats(EXPERIMENT_STATS_DIR + "caltech101_image%d_experiment_stats.json" % image_num)
        _make_top2_plots(
            test_ranks=stats['test_ranks'],
            error_mins=stats['error_mins'],
            error_means=stats['error_means'],
            error_stds=stats['error_stds'],
            title=title,
            path_prefix=save_dir + "caltech101_image%d" % image_num,
        )

    # caltech101 image set (top2 sparse embedding settings), errors in percent
    image_set_stats = _load_stats(EXPERIMENT_STATS_DIR + "caltech101_image_set_experiment_stats.json")
    _make_top2_plots(
        test_ranks=image_set_stats['test_ranks'],
        error_mins=image_set_stats['image_set_mean_error_mins_percent'],
        error_means=image_set_stats['image_set_mean_error_means_percent'],
        error_stds=image_set_stats['image_set_mean_error_stds_percent'],
        title="Caltech101",
        path_prefix=save_dir + "caltech101_image_set",
        percent=True,
    )

    # bcsstk13s, isolet_transpose, forest_cover (all settings)
    for dataset_name, title in zip(dataset_names, titles):
        stats = _load_stats(EXPERIMENT_STATS_DIR + "%s_experiment_stats_all_settings.json" % dataset_name)
        _make_all_settings_plots(
            test_ranks=stats['test_ranks'],
            error_mins=stats['error_mins'],
            error_means=stats['error_means'],
            error_stds=stats['error_stds'],
            work_means=stats['work_means'],
            span_means=stats['span_means'],
            title=title,
            path_prefix=save_dir + dataset_name,
        )

    # caltech101 image set (all settings), errors in percent
    image_set_stats = _load_stats(EXPERIMENT_STATS_DIR + "caltech101_image_set_experiment_stats_all_settings.json")
    _make_all_settings_plots(
        test_ranks=image_set_stats['test_ranks'],
        error_mins=image_set_stats['image_set_mean_error_mins_percent'],
        error_means=image_set_stats['image_set_mean_error_means_percent'],
        error_stds=image_set_stats['image_set_mean_error_stds_percent'],
        work_means=image_set_stats['image_set_mean_work_means'],
        span_means=image_set_stats['image_set_mean_span_means'],
        title="Caltech101",
        path_prefix=save_dir + "caltech101_image_set",
        percent=True,
    )


if __name__ == "__main__":
    main()



        
