from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# ---------- configuration ----------
# Plot cumulative returns of every model (plus the market baseline) for one
# seed over [start_year, end_year] and save the figure as a PDF.
seed = 1
start_year, end_year = 2002, 2021

# The per-model input files and the output PDF all encode the evaluation
# period; derive it from start_year/end_year so the two cannot disagree.
period = f"{start_year}_{end_year}"
returns_save_path = f"./plot/neurips/returns_{period}_all_seed_{seed}.pdf"

# ---------- data ----------
set_seq_returns  = np.load(f"./plot/returns_{period}_set_MLP_seed_{seed}.npy")
s4_returns       = np.load(f"./plot/returns_{period}_s4_simple_seed_{seed}.npy")
h3_returns       = np.load(f"./plot/returns_{period}_h3_seed_{seed}.npy")
mha_returns      = np.load(f"./plot/returns_{period}_mha_seed_{seed}.npy")
hyena_returns    = np.load(f"./plot/returns_{period}_hyena_seed_{seed}.npy")
longconv_returns = np.load(f"./plot/returns_{period}_long-conv_seed_{seed}.npy")
market_returns   = np.load("./plot/cumulative_market_returns.npy")
# NOTE(review): x-axis ticks below are computed from len(set_seq_returns);
# this assumes every series (including the market one) covers the same
# trading days — confirm, otherwise curves will be misaligned with the years.

# ---------- font & line sizes ----------
TITLE_FONTSIZE  = 25
LABEL_FONTSIZE  = 23
TICK_FONTSIZE   = 18
LEGEND_FONTSIZE = 18
LINEWIDTH       = 2.8
SHOW_TITLE      = False  # set True to draw a title above the axes

# Legend label of the curve drawn on top of all others.
HIGHLIGHT_LABEL = "Set‑Seq"

# (linestyle, legend label), in the same order as `all_returns` below.
# Every curve is solid and no colour is given, so curves are told apart only
# by matplotlib's default colour cycle. NOTE(review): this does not survive
# black-and-white printing; use distinct dash patterns ("--", "-.", ":", ...)
# if that is required.
styles = [
    ("-", HIGHLIGHT_LABEL),
    ("-", "LongConv"),
    ("-", "S4"),
    ("-", "H3"),
    ("-", "MHA"),
    ("-", "Hyena"),
    ("-", "Market (top 500 by MktCap)"),
]
all_returns = [
    set_seq_returns, longconv_returns, s4_returns,
    h3_returns, mha_returns, hyena_returns, market_returns,
]
# zip() silently drops the tail of the longer list, so a model added to only
# one of the two lists would vanish from the plot without any error.
if len(styles) != len(all_returns):
    raise ValueError(
        f"styles has {len(styles)} entries but all_returns has {len(all_returns)}"
    )

# ---------- plot ----------
plt.figure(figsize=(12, 8))

for (ls, label), series in zip(styles, all_returns):
    z = 10 if label == HIGHLIGHT_LABEL else 1  # keep the highlighted model on top
    plt.plot(series, linestyle=ls, linewidth=LINEWIDTH, label=label, zorder=z)

plt.legend(fontsize=LEGEND_FONTSIZE)

# x-axis ticks at (approximate) year boundaries: trading days are spread
# evenly across years, so tick positions are an approximation.
num_days      = len(set_seq_returns)
years         = np.arange(start_year, end_year + 2)   # one extra tick closes the last year
days_per_year = num_days / (end_year - start_year + 1)
positions     = (years - start_year) * days_per_year
positions[-1] = num_days - 1  # last tick sits on the final data point, not one past it

plt.xticks(positions.astype(int), years, rotation=45, ha='right', fontsize=TICK_FONTSIZE)
plt.yticks(fontsize=TICK_FONTSIZE)

plt.xlabel("Year", fontsize=LABEL_FONTSIZE)
plt.ylabel("Cumulative Return", fontsize=LABEL_FONTSIZE)
if SHOW_TITLE:
    plt.title(f"Cumulative Returns: Jan. {start_year} – Dec. {end_year}", fontsize=TITLE_FONTSIZE)

plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.tight_layout()

# savefig does not create missing directories; make sure the target exists.
Path(returns_save_path).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(returns_save_path)
plt.close()
print(f"Cumulative returns plot saved to {returns_save_path}")











