import matplotlib.pyplot as plt
import numpy as np
from os import listdir
from os.path import isfile, join
import itertools
from utils import utils
from collections import defaultdict
import pandas as pd


def run(folder, extra_results, sep, extension, criteria="-l",
        plot=True, save_plots=False, save_analysis=False, combine_metrics=False):
    """Summarize the results stored in ``folder`` and optionally plot them.

    Args:
        folder: Directory holding the result CSV files; one trailing "/" is removed.
        extra_results: Other result folders to overlay on the plots, or None.
        sep: Labels for the x axis, forwarded to the plotting routine.
        extension: Iterable of file extensions used when figures are saved.
        criteria: Separator passed to ``summarize`` to parse the file names.
        plot: Show the figures on screen.
        save_plots: Write the figures into ``folder``.
        save_analysis: Forwarded to ``summarize`` (writes the analysis file).
        combine_metrics: Use ``plot_combined`` (one figure per dataset) instead of
            ``plot_sep`` (one figure per dataset and metric).
    """
    # Only a single trailing slash is dropped, exactly one character.
    folder = folder[:-1] if folder.endswith("/") else folder

    datasets, method, metrics, plot_means, plot_stds = summarize(folder, criteria, save_analysis)

    # Nothing to draw when figures are neither shown nor saved.
    if not (plot or save_plots):
        return

    if combine_metrics:
        plot_combined(plot_means, plot_stds, datasets, metrics,
                      plot, save_plots, extension, folder, sep, criteria, extra_results)
        return

    plot_sep(plot_means, plot_stds, datasets, metrics, method,
             plot, save_plots, extension, folder, sep, criteria, extra_results)


def summarize(folder, criteria='-l', save_analysis=False):
    """Aggregate the result CSV files of ``folder`` into per-column means and stds.

    Every ``*.csv`` file in ``folder`` is considered, except hidden files and files
    whose name starts with ``analysis-`` or ``parameters-``. When more than one file
    is found they are ordered by the integer that follows the last occurrence of
    ``criteria`` in the file name (extension stripped).

    Args:
        folder: Directory containing the result files.
        criteria: Separator used to split the method name from the numeric suffix
            in each file name.
        save_analysis: When true, write the textual report to
            ``<folder>/analysis-<last component of folder>.csv``.

    Returns:
        A tuple ``(datasets, method, metrics, means, stds)``. ``datasets`` and
        ``metrics`` are whatever ``utils.read_header`` returns, ``method`` is the
        part of the first file name that precedes ``criteria``, and ``means`` /
        ``stds`` are NumPy arrays with one row per header entry and one column per
        "<dataset>.<metric>" pair (dataset-major order).

    Raises:
        IndexError: If no matching CSV file exists in ``folder``.
        ValueError: If a file name suffix cannot be parsed as an integer.
    """
    files = [f for f in listdir(folder)
             if isfile(join(folder, f))
             and f.endswith(".csv")
             and not f.startswith(('.', 'analysis-', 'parameters-'))]

    if len(files) > 1:
        files = sorted(files, key=lambda x: int(x.split(".")[0].split(criteria)[-1]))

    method = files[0].split(criteria)[0]

    # NOTE(review): the layout of the values returned here is defined in utils;
    # the loops below assume the flat order dataset -> fold -> metric.
    datasets, n_folds, headers, metrics = utils.read_header(files, folder)

    n_metrics = len(metrics)
    chunk = n_folds * n_metrics  # number of consecutive values belonging to one dataset

    report = []
    plot_means = []
    plot_stds = []

    for i in range(0, len(headers)):
        # Convert once per file instead of once per dataset chunk.
        local_max_values = np.array(headers[i]["max_value"])
        local_num_nodes = np.array(headers[i]["num_nodes"])

        datasets_max_values = defaultdict(list)
        datasets_num_nodes = defaultdict(list)

        for start in range(0, len(local_max_values), chunk):
            dataset_name = datasets[start // chunk]

            local_data = local_max_values[start:start + chunk]
            local_nodes = local_num_nodes[start:start + chunk]

            # Each group of ``n_metrics`` consecutive values is one fold.
            for fold_start in range(0, len(local_data), n_metrics):
                for offset in range(n_metrics):
                    key = dataset_name + "." + metrics[offset]
                    datasets_max_values[key].append(local_data[fold_start + offset])
                    datasets_num_nodes[key].append(local_nodes[fold_start + offset])

        datasets_max_values = pd.DataFrame(datasets_max_values)
        datasets_num_nodes = pd.DataFrame(datasets_num_nodes)

        # A column with too few samples yields NaN; it is reported as 0.
        means_max_values = datasets_max_values.mean().fillna(0).values
        std_max_values = datasets_max_values.std().fillna(0).values
        plot_means.append(means_max_values)
        plot_stds.append(std_max_values)

        report.append(files[i] + "\t" + "\t".join(datasets_max_values.keys()) + "\n")
        report.append("means_max_values\t" + "\t".join(map(str, means_max_values)) + "\n")
        report.append("std_max_values\t" + "\t".join(map(str, std_max_values)) + "\n")
        report.append("means_num_nodes\t" + "\t".join(map(str, datasets_num_nodes.mean().values)) + "\n")

    if save_analysis:
        # The context manager guarantees the report is flushed and the handle closed.
        with open(join(folder, "analysis-" + folder.split("/")[-1] + ".csv"), "w+") as output_file:
            output_file.write("".join(report))

    return datasets, method, metrics, np.array(plot_means), np.array(plot_stds)


def plot_sep(means, stds, datasets, metrics, method,
             plot, save_plots, extensions, folder, sep, criteria, extra_results=None):
    """Draw one error-bar figure per (dataset, metric) pair.

    The main run is drawn as a solid line with round markers; every folder in
    ``extra_results`` is summarized and drawn on the same axes with a line style
    and marker taken from the shared cycles.

    Args:
        means, stds: Arrays returned by ``summarize`` (one row per file, one column
            per "<dataset>.<metric>" pair).
        datasets, metrics: Dataset and metric names of the main run.
        method: Legend label for the main run.
        plot: Show each figure on screen.
        save_plots: Save each figure into ``folder``.
        extensions: Iterable of file extensions used when saving.
        folder: Output directory for saved figures.
        sep: x-axis labels (one per plotted point).
        criteria: File-name separator forwarded to ``summarize`` for the extras.
        extra_results: Optional list of additional result folders.
    """
    (extra_results, linestyles, markers,
     plot_means, plot_stds, sep_labels, sep_values) = init_plot_vars(extra_results, means, sep, stds)

    # summarize() re-reads every CSV of a folder and its result does not depend on
    # the figure being drawn, so run it once per extra folder rather than once per
    # (dataset, metric) pair. The transposed arrays are indexed by column below.
    extra_summaries = []
    for extra in extra_results:
        datasets_extra, method_extra, metrics_extra, means_extra, stds_extra = summarize(extra, criteria)
        extra_summaries.append((datasets_extra, method_extra, metrics_extra,
                                np.transpose(means_extra), np.transpose(stds_extra)))

    n_metrics = len(metrics)

    for i in range(len(datasets)):
        for j in range(n_metrics):
            title = datasets[i] + "." + metrics[j]
            init_plt(title, sep_labels, sep_values)

            for datasets_extra, method_extra, metrics_extra, means_extra, stds_extra in extra_summaries:
                # assumes datasets_extra / metrics_extra are NumPy arrays (element-wise ==) — TODO confirm
                # Columns are dataset-major with one slot per metric *of the extra
                # folder*, so the stride must be len(metrics_extra), not len(metrics).
                ind_dataset = (datasets_extra == (datasets[i])).nonzero()[0][0] * len(metrics_extra)
                ind_metric = (metrics_extra == (metrics[j])).nonzero()[0][0]
                column = ind_dataset + ind_metric

                plt.errorbar(sep_values, means_extra[column], stds_extra[column], label=method_extra,
                             linestyle=next(linestyles), marker=next(markers),
                             clip_on=False, markeredgewidth=2, capsize=5)

            row = i * n_metrics + j
            plt.errorbar(sep_values, plot_means[row], plot_stds[row], label=method,
                         clip_on=False, markeredgewidth=2, capsize=5, linestyle='-', marker='o')

            apply_output_options(title, extensions, folder, plot, save_plots)


def apply_output_options(title, extensions, folder, plot, save_plots):
    """Finish the current figure: add the legend, then save and/or show it.

    Args:
        title: Base name (without extension) of the saved files.
        extensions: Iterable of file extensions; one file is written per entry.
        folder: Directory in which the files are written.
        plot: Show the figure when true; otherwise close it.
        save_plots: Write the figure to disk before showing/closing it.
    """
    plt.legend(loc='best', fontsize=11.7)

    if save_plots:
        for ext in extensions:
            file_name = "{0}.{1}".format(title, ext)
            plt.savefig(join(folder, file_name), bbox_inches='tight', pad_inches=0)

    # A figure that is not displayed is closed right away.
    if not plot:
        plt.close()
        return

    plt.show()


def init_plt(title, sep_labels, sep_values):
    """Create a new figure with the common axis layout used by every plot.

    The y axis is limited to [0, 1] with horizontal grid lines and a tick every
    0.1; the x axis gets one tick per entry of ``sep_values`` labelled with the
    matching entry of ``sep_labels``.

    Args:
        title: Figure title.
        sep_labels: Text labels for the x ticks.
        sep_values: Numeric x positions of the ticks (same length as ``sep_labels``).
    """
    fig, ax = plt.subplots()
    ax.yaxis.grid()
    ax.set_ylim([0, 1])

    # Fix the tick locations *before* assigning their labels: set_xticklabels is
    # only meant to be used with fixed ticks, otherwise the labels are bound to
    # whatever the automatic locator produced and matplotlib emits a warning.
    ax.set_xticks(sep_values)
    ax.set_xticklabels(sep_labels)

    # Sets the global font family, so it also applies to later figures.
    plt.rc('font', family='serif')
    plt.title(title, fontsize=18)
    plt.yticks(np.linspace(0, 1, num=11))


def init_plot_vars(extra_results, means, sep, stds):
    """Build the values shared by the plotting routines.

    Args:
        extra_results: List of extra result folders, or None for no extras.
        means: Means with one row per file; transposed so rows become series.
        sep: x-axis labels; their count fixes the number of x positions.
        stds: Standard deviations, laid out like ``means``.

    Returns:
        ``(extra_results, linestyles, markers, plot_means, plot_stds, sep_labels,
        sep_values)`` where the two cycles endlessly repeat three line styles and
        three markers, and ``sep_values`` are ``len(sep)`` evenly spaced positions
        between 1 and 100.
    """
    extras = [] if extra_results is None else extra_results

    # Evenly spaced x positions, independent of the label text.
    tick_positions = np.linspace(1, 100, num=len(sep))

    marker_cycle = itertools.cycle(('x', 'D', '*'))
    linestyle_cycle = itertools.cycle(('-.', '--', ':'))

    return (extras, linestyle_cycle, marker_cycle,
            np.transpose(means), np.transpose(stds), sep, tick_positions)


def plot_combined(means, stds, datasets, metrics, plot, save_plots,
                  extensions, folder, sep, criteria, extra_results=None):
    """Draw one error-bar figure per dataset with all metrics on the same axes.

    For every metric the folders in ``extra_results`` are looked up first. An extra
    that provides more than one point is drawn as its own curve. If all extras
    together provide at most one point, that point is appended to the main series
    before it is drawn; otherwise the main series is drawn on its own.

    Args:
        means, stds: Arrays returned by ``summarize`` (one row per file, one column
            per "<dataset>.<metric>" pair).
        datasets, metrics: Dataset and metric names of the main run.
        plot: Show each figure on screen.
        save_plots: Save each figure into ``folder``.
        extensions: Iterable of file extensions used when saving.
        folder: Output directory for saved figures.
        sep: x-axis labels (one per plotted point).
        criteria: File-name separator forwarded to ``summarize`` for the extras.
        extra_results: Optional list of additional result folders.
    """
    (extra_results, linestyles, markers,
     plot_means, plot_stds, sep_labels, sep_values) = init_plot_vars(extra_results, means, sep, stds)

    # summarize() re-reads every CSV of a folder and does not depend on the
    # dataset/metric being drawn: run it once per extra folder, not once per curve.
    extra_summaries = []
    for extra in extra_results:
        datasets_extra, method_extra, metrics_extra, means_extra, stds_extra = summarize(extra, criteria)
        extra_summaries.append((datasets_extra, method_extra, metrics_extra,
                                np.transpose(means_extra), np.transpose(stds_extra)))

    n_metrics = len(metrics)

    for i in range(len(datasets)):
        title = datasets[i]
        init_plt(title, sep_labels, sep_values)

        for j in range(n_metrics):
            row = i * n_metrics + j
            plot_means_extra = np.array([])
            plot_stds_extra = np.array([])

            for datasets_extra, method_extra, metrics_extra, means_extra, stds_extra in extra_summaries:
                # assumes datasets_extra / metrics_extra are NumPy arrays (element-wise ==) — TODO confirm
                # Columns are dataset-major with one slot per metric of the extra
                # folder, so the dataset position must be scaled by that count
                # before the metric offset is added.
                ind_dataset = (datasets_extra == (datasets[i])).nonzero()[0][0] * len(metrics_extra)
                ind_metric = (metrics_extra == (metrics[j])).nonzero()[0][0]
                column = ind_dataset + ind_metric

                curve_means = means_extra[column]
                curve_stds = stds_extra[column]

                plot_means_extra = np.append(plot_means_extra, curve_means, axis=0)
                plot_stds_extra = np.append(plot_stds_extra, curve_stds, axis=0)

                # Only extras with several points form a curve of their own.
                if len(curve_means) > 1:
                    plt.errorbar(sep_values, curve_means, curve_stds, label=method_extra,
                                 linestyle=next(linestyles), marker=next(markers),
                                 clip_on=False, markeredgewidth=2, capsize=5)

            if len(plot_means_extra) <= 1:
                # At most one extra point in total: extend the main series with it.
                full_means = np.append(plot_means[row], plot_means_extra, axis=0)
                full_stds = np.append(plot_stds[row], plot_stds_extra, axis=0)
                plt.errorbar(sep_values, full_means, full_stds, label=metrics[j],
                             clip_on=False, markeredgewidth=2, capsize=5, linestyle='-', marker='o')
            else:
                plt.errorbar(sep_values, plot_means[row], plot_stds[row], label=metrics[j],
                             clip_on=False, markeredgewidth=2, capsize=5, linestyle='-', marker='o')

        # All metrics of this dataset are on the axes: add the legend and save/show.
        apply_output_options(title, extensions, folder, plot, save_plots)


if __name__ == '__main__':
    # Lists of result folders that can be passed as ``extra_results``; none of
    # them is used by the call below.
    extras = [
        "../results/uci-3T3F-50p-b32-nmax-new",
        "../results/uci-3T3F-50p-b32-nmax",
        "../results/uci-3T3F-50p-b32-push2-nn",
        "../results/uci-3T3F-50p-b32-push2-nn-8nodes",
        "../results/uci-3T3F-50p-b32-gamma",
    ]
    extras2 = [
        "../results/uci-3T3F-50p-b32-nmax-short",
        "../results/uci-3T3F-50p-b64-nmax-short",
    ]

    auto_enc = [
        "../results/autoencoder50_only_som_pur",
        "../results/autoencoder50_only_som_ce",
    ]

    teste = ["../results/semi2_autoencoder30_mnist"]

    # Other x-axis label set kept for reference: "64", "128", "256", "512"
    run(
        "../results/semi4_autoencoder10_b256_mnist",
        extra_results=None,
        sep=["100", "1000", "4000", "60000"],
        extension=["eps"],
        criteria="t_",
        plot=True,
        save_plots=False,
        save_analysis=True,
        combine_metrics=False,
    )
