#!/usr/bin/env python
# make_results_table.py
#
# Aggregate per-seed summary statistics and emit a LaTeX table.

import os
import numpy as np
from pathlib import Path

# ------------------------------------------------------------------
# configuration ----------------------------------------------------
# ------------------------------------------------------------------
# Number of random seeds per trained model (seed files are numbered 1..L).
L = 5

# Out-of-sample window; both years are embedded in every input and
# output filename.
START_YEAR, END_YEAR = 2002, 2021

# Destination of the generated table; the directory is created up front.
OUT_DIR = Path("./plot/neurips")
OUT_DIR.mkdir(parents=True, exist_ok=True)
OUT_FILE = OUT_DIR.joinpath(
    f"results_table_{L}_seeds_{START_YEAR}_{END_YEAR}.tex"
)

# (filename stem, table label) pairs, listed in table-row order.
# "market_portfolio" is a baseline with a single stats file (no seeds).
# To rebuild only a subset of rows, trim this list, e.g. to
#     [("set_MLP", "Set-Seq (Ours)")]
MODELS = [
    ("long-conv", "Long-Conv"),
    ("s4_simple", "S4"),
    ("h3", "H3"),
    ("mha", "Transformer"),
    ("hyena", "Hyena"),
    ("market_portfolio", "Market Portfolio"),
    ("set_MLP", "Set-Seq (Ours)"),
]

# Statistic name inside each saved stats dict -> short key used when the
# table rows are assembled (listed in the same order as the table columns).
KEYS = dict(
    sharpe_ratio_annualized="sharpe",
    mean_return_annualized="ret",
    std_return_annualized="std",
    beta_model="beta",
    turnover_model="tov",
    short_fraction_model="short",
)

# ------------------------------------------------------------------
# helper -----------------------------------------------------------
# ------------------------------------------------------------------
def load_single(stat_path: Path):
    """Return the single object stored in one ``.npy`` file.

    The file is expected to hold one pickled object (here, a dict of summary
    statistics); ``.item()`` unwraps it from the array that ``np.load``
    returns. A missing file propagates ``FileNotFoundError`` from ``np.load``.

    NOTE: ``allow_pickle=True`` can execute arbitrary code while loading;
    only point this at stats files produced by this project's own pipeline.
    """
    loaded = np.load(stat_path, allow_pickle=True)
    return loaded.item()

def aggregate_model(
    stem: str,
    *,
    n_seeds: "int | None" = None,
    stats_dir: str = "./plot",
    ddof: int = 0,
) -> dict:
    """
    Aggregate the per-seed summary statistics of one model.

    Reads
    ``{stats_dir}/summary_stats_{stem}_{START_YEAR}_{END_YEAR}_seed_{k}.npy``
    for k = 1..n_seeds and reduces every statistic listed in ``KEYS`` across
    those files.

    Parameters
    ----------
    stem : str
        Filename stem of the model (first element of a ``MODELS`` pair).
    n_seeds : int, optional
        Number of seed files to read. By default 1 for ``"market_portfolio"``
        (a single run) and ``L`` for every other model.
    stats_dir : str or path-like, optional
        Directory holding the ``summary_stats_*.npy`` files (default
        ``"./plot"``).
    ddof : int, optional
        Delta degrees of freedom for the Sharpe-ratio spread: 0 (default) is
        the population std, 1 the sample std. With a single file and
        ``ddof=1`` the std is NaN.

    Returns
    -------
    dict
        Maps each *short* key (the values of ``KEYS``, e.g. ``"sharpe"``) to
        ``(mean, std)``. ``std`` is a float only for ``"sharpe"``; for every
        other key it is ``None``.

    Raises
    ------
    ValueError
        If the effective ``n_seeds`` is smaller than 1.
    FileNotFoundError
        If any expected seed file is missing.
    KeyError
        If a loaded stats dict lacks one of the ``KEYS`` entries.
    """
    if n_seeds is None:
        # The market-portfolio baseline is not trained, so only one file exists.
        n_seeds = 1 if stem == "market_portfolio" else L
    if n_seeds < 1:
        # Averaging over zero files would silently yield NaN for every stat.
        raise ValueError(f"n_seeds must be >= 1, got {n_seeds}")

    # Seed files are 1-indexed on disk.
    stats_per_seed = [
        load_single(Path(
            f"{stats_dir}/summary_stats_{stem}_{START_YEAR}_{END_YEAR}"
            f"_seed_{seed}.npy"
        ))
        for seed in range(1, n_seeds + 1)
    ]

    out = {}
    for full_key, short_key in KEYS.items():
        vals = np.array([d[full_key] for d in stats_per_seed], dtype=float)
        # Only the Sharpe ratio is reported with a spread in the table.
        std = vals.std(ddof=ddof) if full_key == "sharpe_ratio_annualized" else None
        out[short_key] = (vals.mean(), std)
    return out

# ------------------------------------------------------------------
# main --------------------------------------------------------------
# ------------------------------------------------------------------
# Build one LaTeX row per model. A model whose stats files are missing is
# reported and skipped instead of aborting the whole table.
rows = []
for stem, label in MODELS:
    try:
        m = aggregate_model(stem)
    except FileNotFoundError as e:
        print(f"[warn] missing stats for '{stem}': {e}")
        continue

    sharpe_mean, sharpe_std = m["sharpe"]
    if stem == "market_portfolio":
        # Single run: a "± 0.00" spread would wrongly suggest that seed
        # variability was measured, so print the point value only.
        sharpe_cell = f"${sharpe_mean:.2f}$"
    else:
        sharpe_cell = f"${sharpe_mean:.2f}\\,\\pm\\,{sharpe_std:.2f}$"

    rows.append(
        f"{label:<15} & "
        f"{sharpe_cell} & "
        f"{m['ret'][0] * 100:.1f} & "
        f"{m['std'][0] * 100:.2f} & "
        f"{m['beta'][0]:.3f} & "
        f"{m['tov'][0]:.2f} & "
        f"{m['short'][0]:.2f} \\\\"
    )

if not rows:
    print("[warn] no model statistics were found; the table body will be empty")

table_body = "\n".join(rows)

# Column notes:
# * Sharpe uses a plain `c` column because its cells are math-mode text
#   ("mean \pm std"), which a siunitx S column cannot parse as a number.
# * Every S column's table-format matches the printf precision used above
#   (Beta is printed with 3 decimals), and Return and Beta reserve room for a
#   minus sign because either can be negative.
latex = rf"""
\begin{{table}}[ht]
  \centering
  \caption{{Summary statistics for the equities task out of sample (Jan.~{START_YEAR}--Dec.~{END_YEAR}).
  Each model is trained {L} times on different random seeds; all values are the mean over those runs, and the Sharpe ratio is reported as $\text{{mean}}\pm\text{{std}}$ (the Market Portfolio is a single run, so only its value is shown).
  The Sharpe Ratio, Mean Return, and Std-Dev Return are annualized.
  Beta is relative to the market, while Daily Turnover and Short Fraction are daily averages.}}
  \label{{tab:summary_stats}}
  \resizebox{{\textwidth}}{{!}}{{%
  \begin{{tabular}}{{l c S[table-format=-2.1] S[table-format=2.2] S[table-format=-1.3] S[table-format=1.2] S[table-format=1.2]}}
    \toprule
    Model         & {{Sharpe}} & {{Return \%}} & {{Std Dev \%}} & {{Beta}} & {{Daily}} &  {{Short}} \\
                  & {{Ratio}}  &               & {{Return}}     &          & {{Turnover}} &  {{Fraction}} \\
    \midrule
{table_body}
    \bottomrule
  \end{{tabular}}
  }}%
\end{{table}}
"""

OUT_FILE.write_text(latex.strip() + "\n")
print(f"LaTeX table written to: {OUT_FILE.resolve()}")






