import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Apply the seaborn theme FIRST: sns.set() rewrites font rcParams (its `font`
# argument defaults to 'sans-serif'), so running it after the rc calls below
# would silently discard the serif font choice.
sns.set()
# Render all text through LaTeX with a serif font, overriding the seaborn default.
plt.rc('text', usetex=True)
plt.rc('font', family='serif')

# Indices of the result files to overlay; file i is read from
# f'resultsnew{i}_seed5.npy' in the current working directory.
include = [0, 1, 2, 3, 4]
# Stack the per-file arrays along a new leading axis, so results[k] is the
# array loaded for include[k]. All files must have the same shape.
results = np.stack([np.load(f'resultsnew{i}_seed5.npy') for i in include])
# Column 0 of each plotted row is drawn later as a flat ERM baseline; columns
# 1.. are plotted against E = 2, 3, ... environments seen.
xs = np.arange(results.shape[-1] - 1) + 2

fig, ax = plt.subplots()
labels = ['Test (no shift)', 'Test (shift)']
markers = ['o', '^', 's', '*', 'P']
# Row layout of each result array. NOTE(review): rows 3 and 5 are assumed to be
# the no-shift / shift test accuracies and row 7 the optimal invariant
# classifier, inferred only from the legend labels -- confirm with the producer.
first_test_row = 3
optimal_row = 7
# x position of the E = d_e reference line (hard-coded; presumably the value of
# d_e used in the experiment -- confirm).
d_e = 6

for res_ind, result in enumerate(results):
    # Restart the colour cycle so every run reuses the same colour per label;
    # runs are distinguished by marker instead.
    ax.set_prop_cycle(None)
    # Cycle markers so including more runs than markers does not raise IndexError.
    marker = markers[res_ind % len(markers)]
    for i, label in enumerate(labels):
        ind = 2 * i + first_test_row
        # Only the first run adds legend entries; underscore-prefixed labels
        # are skipped by the legend.
        ax.plot(xs, result[ind][1:], label=(f'IRM {label}' if res_ind == 0 else '__nolegend__'),
                marker=marker, alpha=.75)
        if res_ind == 0:
            # Column 0 holds a single ERM value, drawn as a horizontal reference line.
            ax.plot(xs, np.repeat(result[ind, 0], len(xs)), label=f'ERM {label}',
                    linestyle='--', color=f'C{len(labels) * 2 - i}', linewidth=1.5)

ax.plot(xs, results[0, optimal_row][1:], label='Optimal Invariant Classifier', linestyle='--', linewidth=1.5)
ax.axvline(x=d_e, color='black', label=r'$E=d_e$', alpha=.75, linestyle=':')
ax.set_ylim(-.01, 1.01)
ax.set_xlabel('Environments Seen (E)', fontsize=16)
ax.set_ylabel('Accuracy', fontsize=16)
ax.legend(loc='center left', fontsize=12)
plt.show()