import os
import pickle
import matplotlib.pyplot as plt
# Import-time side effect: raise matplotlib's default font size to 22 for
# every figure drawn after this module has been imported.
plt.rcParams['font.size'] = 22

from matplotlib.pyplot import figure


def _series_label_and_stride(key):
    """Return ``(legend_label, stride)`` for a result key.

    ``stride`` is the sub-sampling step applied to the standard deviation
    before it is printed; ``None`` means "print it unchanged".  Unknown keys
    are labelled with the key itself.
    """
    known = {
        "vanilla": ("Oja's algorithm", 50),
        "data_drop": ("Downsampled", 10),
        "offline": ("Offline", None),
        "iid": ("IID", None),
    }
    return known.get(key, (key, None))


def _dump_pickle(path, payload):
    """Serialize *payload* to *path* with pickle."""
    with open(path, 'wb') as fp:
        pickle.dump(payload, fp)


def plot_and_save(experiment_name, output_dir, results, transition_eigengap, cov_eigengap, eigenvalues, num_states):
    """Dump every result series to disk and save one comparison plot.

    Args:
        experiment_name: Prefix used for every file written.
        output_dir: Directory that receives the files; created if missing.
        results: Mapping from a series name (string) to an indexable triple
            where ``[0]`` holds the error values (y-axis), ``[1]`` the
            standard deviations and ``[2]`` the sample indices (x-axis).
        transition_eigengap: Currently unused; kept for caller compatibility.
        cov_eigengap: Currently unused; kept for caller compatibility.
        eigenvalues: Currently unused; kept for caller compatibility.
        num_states: Currently unused; kept for caller compatibility.

    Side effects:
        * Writes ``<experiment_name>_<key>_errors.csv``, ``..._indices.csv``
          and ``..._stddev.csv`` for each key, plus ``<experiment_name>.png``.
          NOTE: despite the ``.csv`` suffix these files are pickles; the
          suffix is kept so existing readers still find them.
        * Prints a short summary of each series to stdout.
        * Enables ``text.usetex`` and sets the LaTeX preamble in the global
          matplotlib rcParams (this persists after the call).
        * Leaves the newly created figure open and current so the caller may
          still show or close it.
    """
    if output_dir:
        # exist_ok avoids the check-then-create race and targets the very
        # directory the files below are written to.
        os.makedirs(output_dir, exist_ok=True)

    # Plot experiment results on a fresh figure so that repeated calls never
    # draw on top of the curves of an earlier experiment.
    plt.figure(figsize=(15, 12), dpi=100)
    plt.grid(True)

    colour_list = ['blue', 'green', 'red', 'magenta', 'purple']
    linestyle_list = [':', '--', '-', '-.']
    for itr, key in enumerate(results):
        series = results[key]
        errors, stddev, indices = series[0], series[1], series[2]

        file_prefix = os.path.join(output_dir, experiment_name + "_" + key)
        _dump_pickle(file_prefix + "_errors.csv", errors)
        # The indices file used to receive the standard deviations by mistake.
        _dump_pickle(file_prefix + "_indices.csv", indices)
        _dump_pickle(file_prefix + "_stddev.csv", stddev)

        print(key, ", Number of indices : ", len(indices), ". Number of error values : ", len(errors))
        if len(errors):
            print(key, " initial error : ", errors[0])
            print(key, " final error : ", errors[-1])

        label, stride = _series_label_and_stride(key)
        print(label)
        print("Standard Deviation", stddev if stride is None else stddev[::stride])

        # Cycle through the style lists so more series than styles cannot
        # raise IndexError.
        plt.plot(
            indices,
            errors,
            label=label,
            linestyle=linestyle_list[itr % len(linestyle_list)],
            color=colour_list[itr % len(colour_list)],
            linewidth=3,
        )

    plt.xticks(fontweight='bold')
    plt.yticks(fontweight='bold')
    plt.rcParams["text.usetex"] = True
    # A single assignment: two consecutive assignments would drop amsmath.
    plt.rcParams["text.latex.preamble"] = r"\usepackage{amsmath} \usepackage{bm}"
    plt.legend(fontsize=30)
    plt.savefig(os.path.join(output_dir, experiment_name + ".png"))
