import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Fixed seed so the random draws made later in this script are reproducible.
np.random.seed(0)

# sns.set() installs seaborn's theme, which also sets font.family to its own
# default ('sans-serif'). Apply it first and the explicit rc overrides after,
# otherwise the serif font request is silently overwritten.
sns.set()
plt.rc('text', usetex=True)  # text rendering requires a LaTeX installation
plt.rc('font', family='serif')

fig, ax = plt.subplots()

def _simulate_separation(d_e, var, rounds):
    """Run the Monte Carlo simulation behind ``plot_result``.

    For every environment count E = 1..d_e and every trial:
      * draw a (d_e, E) matrix ``U`` = shared random offset column + 2 * noise,
      * reduce it with a thin SVD to the (E, E) factor ``diag(s) @ Vt``,
      * solve ``factor.T @ z = sqrt(var) * ones(E)`` and record ``1 / ||z||``.

    Returns:
        x: ``np.arange(1, d_e + 1)``, the environment counts.
        y: float array of shape ``(d_e, rounds)``; ``y[i, t]`` is the value of
           trial ``t`` for ``x[i]`` environments.
    """
    std = np.sqrt(var)
    x = np.arange(1, d_e + 1)
    y = np.empty((x.size, rounds))
    for i, n_env in enumerate(x):
        rhs = np.full(n_env, std)
        for trial in range(rounds):
            # Offset column first, then per-environment noise; keep this draw
            # order so seeded runs stay reproducible.
            base = np.random.randn(d_e, 1)
            U = base + np.random.randn(d_e, n_env) * 2
            _, s, vt = np.linalg.svd(U, full_matrices=False)
            factor = np.diag(s) @ vt  # (n_env, n_env) since n_env <= d_e
            y[i, trial] = 1 / np.linalg.norm(np.linalg.solve(factor.T, rhs))
    return x, y


def plot_result(d_c, d_e, var, color='black', rounds=20):
    """Simulate and plot the projected mean separation for one configuration.

    Draws on the module-level ``ax``:
      * the mean over trials for each environment count E,
      * a shaded band of +/- 1.96 standard deviations across trials (a spread
        band of individual trials, not a standard-error CI of the mean),
      * a dotted horizontal reference line at ``sqrt(d_c)``,
      * when ``var == 1``, the dashed curve ``sqrt(d_e - E)``.

    Args:
        d_c: Sets the ``sqrt(d_c)`` reference line; shown in the legend label.
        d_e: Number of rows of each simulated matrix and the largest
            environment count plotted.
        var: Shown as sigma_e^2 in the legend; its square root fills the
            right-hand side of the solve.
        color: Matplotlib color shared by all artists of this configuration.
        rounds: Number of random trials per environment count.
    """
    x, y = _simulate_separation(d_e, var, rounds)
    mean = y.mean(axis=1)
    band = 1.96 * y.std(axis=1)

    ax.plot(x, mean, label=rf'$d_c={d_c}, d_e={d_e}, \sigma_e^2={var}$', color=color, alpha=.6)
    ax.fill_between(x, mean - band, mean + band, color=color, alpha=.1)
    # Draw on the same axes as the other artists rather than via plt's
    # implicit "current axes".
    ax.axhline(y=np.sqrt(d_c), color=color, linestyle=':')
    if var == 1:
        ax.plot(x, np.sqrt(d_e - x), linestyle='--', color=color)

# One curve family per row, laid out as (d_c, d_e, var, color).
hyperparams = [
    (40, 200, 3, 'blue'),
    (100, 200, 1, 'red'),
    (50, 300, 3, 'orange'),
    (60, 300, 1, 'green'),
]

# Draw every configuration onto the shared figure.
for d_c, d_e, var, color in hyperparams:
    plot_result(d_c=d_c, d_e=d_e, var=var, color=color)

# Axis labels, title and a fixed y-range for all curve families.
ax.set_xlabel('Environments observed', fontsize=16)
ax.set_ylabel(r'$\tilde\mu$', fontsize=16)
ax.set_title('Projected Environmental Mean Separation', fontsize=18)
ax.set_ylim(0, 40)

# Legend-only proxy for the dashed sqrt(d_e - E) reference curves: an empty
# line adds a legend entry without drawing anything, so it stays invisible
# regardless of the y-limits (unlike a real line parked off-screen).
ax.plot([], [], color='black', linestyle='--', label=r'$\sqrt{d_e - E}$')
ax.legend()

plt.tight_layout()
plt.show()