import pickle
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rc
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


# Module-level plotting configuration; both calls take effect at import time
# and apply to every figure drawn afterwards.
# Switch matplotlib over to seaborn's default theme.
sns.set()
# Route all figure text through LaTeX.
# NOTE(review): usetex needs a working LaTeX installation on the machine -- confirm.
rc('text', usetex=True)


def tb_load(folder_names, tag):
    """Read one scalar series per TensorBoard log folder.

    Args:
        folder_names: iterable of paths, each handed to ``EventAccumulator``.
        tag: name of the scalar to extract from every folder.

    Returns:
        A list with one entry per folder, in input order. Each entry is a
        list of ``(step, value)`` tuples taken from the events returned by
        ``Scalars(tag)``.
    """

    def read_series(run_dir):
        # Reload() parses the event files and hands back the accumulator,
        # which is then queried for the requested tag.
        accumulator = EventAccumulator(run_dir).Reload()
        return [(ev.step, ev.value) for ev in accumulator.Scalars(tag)]

    return [read_series(run_dir) for run_dir in folder_names]


def _smooth_series(values, dt):
    """Return the moving average of ``values`` over the window [t - dt, t + dt)."""
    values = np.asarray(values)
    n = len(values)
    radius = max(dt, 0)
    # The window is clipped at both ends of the series. Its upper bound is
    # forced to be at least t + 1 so that dt == 0 means "no smoothing"
    # instead of averaging an empty slice (which would give NaN).
    return np.array([
        np.mean(values[max(t - radius, 0):min(max(t + radius, t + 1), n)])
        for t in range(n)
    ])


def plot_learning_curves(list_of_folder_lists, tag, dt=5):
    """Plot one mean +/- std learning curve per group of TensorBoard runs.

    For every folder list, the scalar ``tag`` is loaded from each folder with
    ``tb_load``, each run is smoothed with a moving average, and the mean
    across runs is drawn on the current matplotlib axes together with a
    shaded band of one standard deviation.

    Args:
        list_of_folder_lists: iterable of folder lists; each inner list holds
            the log folders of the runs that are aggregated into one curve.
        tag: scalar tag to read from every run.
        dt: half-width (in logged points) of the smoothing window; values
            below 1 disable smoothing.

    Returns:
        None. The curves are added to the current matplotlib figure.

    Raises:
        ValueError: if the runs of one group have different numbers of logged
            points, since they cannot be stacked into one matrix.
    """
    for folder_list in list_of_folder_lists:
        # Runs without any logged point cannot be unpacked or averaged.
        runs = [series for series in tb_load(folder_list, tag) if series]
        if not runs:
            # Nothing to draw for this group; skip instead of crashing on an
            # empty aggregate.
            continue

        smoothed_runs = []
        steps = ()
        for series in runs:
            # The x-axis uses the steps of the last run in the group.
            steps, values = zip(*series)
            smoothed_runs.append(_smooth_series(values, dt))

        # One column per run; a single stack avoids re-copying the matrix
        # for every additional run.
        total = np.stack(smoothed_runs, axis=1)
        mean1 = total.mean(axis=1)
        std1 = total.std(axis=1)
        plt.plot(steps, mean1)
        plt.fill_between(steps, mean1 - std1, mean1 + std1, alpha=0.41)


def make_histograms(list_of_pickle_lists, field, num_bins=50):
    """Print summary statistics and draw a histogram for each group of pickles.

    For every inner list, the values stored under ``field`` in each pickled
    object are pooled into one sample. The path of the last pickle seen is
    printed, followed by the mean, standard deviation, median and 5% quantile
    of the pooled sample (each rounded to one decimal), and the sample is
    drawn with ``sns.distplot``.

    Args:
        list_of_pickle_lists: iterable of lists of pickle file paths; each
            inner list yields one histogram.
        field: key looked up in every unpickled object.
        num_bins: number of histogram bins passed to seaborn.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling, so only
    pass files from a trusted source.
    NOTE(review): ``sns.distplot`` is deprecated in recent seaborn releases --
    confirm the installed version still provides it.
    """
    # Kept outside the outer loop on purpose: the printed path is whatever
    # pickle was opened most recently, even if the current group is empty.
    last_path = ''
    for pickle_paths in list_of_pickle_lists:
        pooled = []
        for last_path in pickle_paths:
            with open(last_path, 'rb') as handle:
                record = pickle.load(handle)
            pooled.extend(record[field])

        print(last_path)
        # Evaluated lazily, one at a time, so each statistic is printed
        # before the next one is computed.
        statistics = (
            np.mean,
            np.std,
            np.median,
            lambda sample: np.quantile(sample, 0.05),
        )
        for statistic in statistics:
            print(np.round(statistic(pooled), 1))

        sns.distplot(pooled, bins=num_bins, kde_kws={'linewidth': 3, 'bw_adjust': 1.0})
