import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc, rcParams
import matplotlib.font_manager as fm

# Register the Open Sans Light font file so that the 'Open Sans' family set
# below can be resolved. `FontManager.addfont` (matplotlib >= 3.2) replaces
# `font_manager.createFontList`, which was deprecated in 3.2 and removed in
# later releases; the old API is only used on versions lacking `addfont`.
# NOTE(review): the path is relative, so it is presumably resolved against the
# current working directory — confirm the script is run from the font's folder.
_FONT_FILE = 'OpenSans-Light.ttf'
if hasattr(fm.fontManager, 'addfont'):
    fm.fontManager.addfont(_FONT_FILE)
else:
    fm.fontManager.ttflist += fm.createFontList([_FONT_FILE])

# Global text style: 15 pt light-weight Open Sans for ordinary text and
# Computer Modern ('cm') for mathtext labels.
rcParams['font.size'] = 15
rcParams['font.family'] = 'Open Sans'
rcParams['font.weight'] = 'light'
rcParams['mathtext.fontset'] = 'cm'
import matplotlib.colors as mcolors


from Tools.compute_optimal_polynomial import compute_optimal_polynomial


# kappa: left end of the sampled spectrum [kappa, 1] (used by eig_gap).
# deg_max: polynomials of degree 1..deg_max are drawn in every panel.
kappa, deg_max = 0.1, 3
# Shared (bottom, top) y-range applied to every subplot.
y_limit_on_axis = [-2.5, 6]


def eig_gap(relative_gap, int_for_deg=2, *, n_points=50, lower=None):
    """Sample a spectrum made of two intervals separated by a gap.

    The range ``[lower, 1]`` is cut around its midpoint ``(1 + lower) / 2`` by
    a gap whose total width is ``relative_gap * (1 - lower)``; each remaining
    interval is sampled with ``n_points`` evenly spaced values.

    Args:
        relative_gap: Gap width as a fraction of ``1 - lower``.
        int_for_deg: Placement of the gap around the midpoint. ``2`` centres
            it (balanced intervals). ``3`` shifts it towards 1 by
            ``(1 - relative_gap ** 2) / 4`` of the range (unbalanced
            intervals); for ``relative_gap < sqrt(2) - 1`` the left half-gap
            is negative, so the two intervals overlap.
        n_points: Number of samples in each of the two intervals.
        lower: Left end of the spectrum; ``None`` uses the module-level
            ``kappa``.

    Returns:
        1-D array of ``2 * n_points`` values: the first ``n_points`` sample
        the left interval, the last ``n_points`` the right one.

    Raises:
        NotImplementedError: If ``int_for_deg`` is neither 2 nor 3.
    """
    if lower is None:
        lower = kappa

    if int_for_deg == 2:
        half_gap_left = half_gap_right = relative_gap / 2
    elif int_for_deg == 3:
        # Same total gap width, moved to the right by `shift`.
        shift = (1 - relative_gap ** 2) / 4
        half_gap_left = relative_gap / 2 - shift
        half_gap_right = relative_gap / 2 + shift
    else:
        raise NotImplementedError(
            "int_for_deg must be 2 or 3, got {!r}".format(int_for_deg))

    midpoint = (1 + lower) / 2
    width = 1 - lower
    left = np.linspace(lower, midpoint - half_gap_left * width, n_points)
    right = np.linspace(midpoint + half_gap_right * width, 1, n_points)
    return np.concatenate((left, right))


# Draw one panel per spectrum configuration: the polynomials returned by
# compute_optimal_polynomial for degrees 1..deg_max, evaluated on [0, 1], with
# the two sampled spectrum intervals shaded in grey.
fig = plt.figure(figsize=(14, 4.5), dpi=80)
balance_dict = {2: "Balanced", 3: "Unbalanced"}

colors = {1: mcolors.to_rgb('#66C2A5'), 2: mcolors.to_rgb('#FC8D62'), 3: mcolors.to_rgb('#984EA3')}

# Evaluation grid shared by every panel and every degree.
lambdas = np.linspace(0, 1, 100)
# First line drawn for each degree; used as explicit figure-legend handles.
legend_handles = {}

# Each panel: (degrees whose value at lambda = 0 is marked with a star,
#              int_for_deg passed to eig_gap, relative gap R).
panels = [({1, 2, 3}, 2, 0), ({2}, 2, .6), ({3}, 3, .6)]

for idx, (best_degrees, int_for_deg, relative_gap) in enumerate(panels, start=1):
    eig_list = eig_gap(relative_gap=relative_gap, int_for_deg=int_for_deg)
    ax = fig.add_subplot(1, 3, idx)
    ax.set_xlabel(r"$\lambda$", fontdict={'size': 20})
    if idx == 1:
        ax.set_ylabel(r"$\sigma_K^{\Lambda}(\lambda)$", rotation=0, loc="center",
                      labelpad=20, fontdict={'size': 20})
    ax.set_ylim(*y_limit_on_axis)

    starred = []
    for deg in range(1, deg_max + 1):
        coefs = compute_optimal_polynomial(eig_list=eig_list, deg=deg)
        # Columns are lambdas**0 .. lambdas**deg, so this evaluates
        # sum_k coefs[k] * lambda**k (coefficients in increasing-degree order).
        values = np.vander(lambdas, deg + 1, increasing=True) @ coefs
        lines = ax.plot(lambdas, values, color=colors[deg])
        legend_handles.setdefault(deg, lines[0])
        if deg in best_degrees:
            # Keep each value with its own colour; zipping against a set's
            # iteration order is not guaranteed to match degree order.
            starred.append((values[0], colors[deg]))
    # Stars slightly left of lambda = 0 mark the polynomial value at 0.
    for value_at_zero, c in starred:
        ax.plot([-.05], [value_at_zero], '*', color=c)

    # Shade each sampled interval over the band [-1, 1]; eig_gap returns the
    # left interval in the first half of the array and the right one after.
    half = len(eig_list) // 2
    ax.fill_between(eig_list[:half], 1, -1, color='gray', alpha=0.30)
    ax.fill_between(eig_list[half:], 1, -1, color='gray', alpha=0.30)

    ax.set_title("{} intervals with R={}".format(balance_dict[int_for_deg], relative_gap), fontsize=20)

# Lay out once, after every panel's labels and title exist; the bottom 10 %
# is reserved for the shared legend.
fig.tight_layout(rect=(0, 0.1, 1, .9))

# A single figure-level legend (previously re-created on every iteration).
# `prop` already sets the size, so the ignored `fontsize` kwarg is dropped.
degrees = range(1, deg_max + 1)
fig.legend([legend_handles[deg] for deg in degrees],
           [r"$\sigma_{}^\Lambda$".format(deg) for deg in degrees],
           loc='lower center', frameon=False, prop={'size': 20, 'style': 'normal'}, ncol=3)


plt.show()
