from DataManipulation import load_from_file
import matplotlib.pyplot as plt
import matplotlib


def plot_values(ax, p, values, labels, title, y_label, labels_to_ignore=None):
    """Draw one line per series onto ``ax``, skipping ignored labels.

    Args:
        ax: Matplotlib axes to draw on.
        p: Indexable of x-data; ``p[i]`` is the x data for series ``i``.
        values: Sequence of y-data; one line is drawn per entry.
        labels: Legend label per series; ``labels[i]`` names ``values[i]``.
        title: Axes title.
        y_label: Y-axis label (the x-axis is always labelled "p").
        labels_to_ignore: Labels whose series are not drawn. Defaults to none.
    """
    ignored = () if labels_to_ignore is None else labels_to_ignore
    ax.grid()
    ax.set_xlabel("p")
    ax.set_ylabel(y_label)
    ax.set_title(title)
    # pyplot.get_cmap works across Matplotlib versions; matplotlib.cm.get_cmap
    # was deprecated in 3.7 and removed in 3.9.
    cmap = plt.get_cmap('CMRmap')
    for i, series in enumerate(values):
        if labels[i] in ignored:
            continue
        # Cycle through the seven shades 0.1, 0.2, ..., 0.7 of the colormap so
        # more than seven series no longer raise IndexError.
        color = cmap((i % 7 + 1) / 10)
        ax.plot(p[i], series, label=labels[i], linestyle='-', color=color)


def fibonacci_array(n, zero_start=False):
    """Return the first ``n`` terms of the sequence 1, 2, 3, 5, 8, ...

    Each term after the second is the sum of the two before it.

    Args:
        n: Number of terms to return. Non-positive values yield ``[]``.
        zero_start: If true, prepend 0 and drop the last term so the result
            still has ``n`` elements (0, 1, 2, 3, 5, ...).

    Returns:
        A list of ints.
    """
    seq = []
    current, following = 1, 2
    # For n <= 0 the loop never runs, so the result is [] (the old code
    # returned [1, 2] for n == 0).
    while len(seq) < n:
        seq.append(current)
        current, following = following, current + following
    if zero_start and seq:
        seq = [0] + seq[:-1]
    return seq


def plot():
    """Render the Leduc result plots once per CFR iteration count.

    The iteration counts come from ``fibonacci_array(9, True)``; one results
    file under ``results/br_sbr/`` is plotted for each count.
    """
    iteration_counts = fibonacci_array(9, True)
    for cfr_iters in iteration_counts:
        results_path = f"results/br_sbr/all_leduc_holdem_lp_cfr_iter={cfr_iters}"
        leduc_cd_responses_tests_plot(results_path)


def leduc_cd_responses_tests_plot(file_name, labels_to_ignore=None):
    """Show side-by-side gain and exploitability plots for one results file.

    Args:
        file_name: Path passed to ``load_from_file``. The loaded object is
            indexed with the keys "p", "gain", "expl" and "labels".
        labels_to_ignore: Series labels to omit from both panels and the legend.
    """
    ignored = [] if labels_to_ignore is None else labels_to_ignore
    data = load_from_file(file_name)
    # NOTE: this updates the global rcParams, so later figures are affected too.
    plt.rcParams.update({'font.size': 14, 'font.family': 'Times New Roman'})
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
    fig.subplots_adjust(bottom=0.17, left=0.08, right=0.99, top=0.9)
    plot_values(ax1, data["p"], data["gain"], data["labels"], "", "Gain", ignored)
    plot_values(ax2, data["p"], data["expl"], data["labels"], "", "Exploitability", ignored)
    # One legend column per line actually drawn, so ignored series do not
    # leave empty slots in the expanded legend row.
    handles, _ = ax2.get_legend_handles_labels()
    # The shared legend is anchored to the right-hand axes but stretched to
    # the left (x0=-1.2, width 2.2 in axes coordinates) to span both panels.
    ax2.legend(bbox_to_anchor=(-1.2, 1.02, 2.2, .102), loc='lower left',
               ncol=max(1, len(handles)), mode="expand", borderaxespad=0.,
               handlelength=1, handletextpad=0.3, borderpad=0.1)
    plt.show()
    # Free the figure: show() does not release it on non-interactive
    # backends, and plot() creates one figure per results file.
    plt.close(fig)


if __name__ == "__main__":
    # Default entry point: one figure per CFR iteration count.
    plot()
    # Alternative: plot a single results file with some series hidden, e.g.
    # leduc_cd_responses_tests_plot(f"results/br_sbr/all_leduc_holdem_lp_cfr_iter={3}", ["max-margin", "unsafe", "resolving"])
