import sys
import os
sys.path.append(os.environ.get("BASE_PATH", ""))
#rom scripts.notebooks.true_loss_level.get_transition_probabilities import load_model_corelogic
from hydra import compose, initialize
import numpy as np
import pickle
import pandas as pd
import torch
import json
import os
import argparse

BASE_PATH = os.environ.get("BASE_PATH", "")
if BASE_PATH.endswith("/"):
    # Drop a single trailing slash so the f-string paths built from BASE_PATH
    # do not contain "//".
    BASE_PATH = BASE_PATH[:-1]


def get_config():
    """Return the hard-coded default experiment, checkpoint and data locations.

    Both file paths are rooted at the module-level ``BASE_PATH``.
    """
    return {
        "experiment": "equities/attention_factors_equities",
        "checkpoint_path": f"{BASE_PATH}/outputs/outputs/2025-04-12/10-43-38/last.ckpt",
        "data_path": f"{BASE_PATH}/data/corelogic/loan_data_60000_unemployment_curr_balance_cir_spi.npz",
    }


def hydraload_corelogic(config, path):
    """Build a ``SequenceLightningModule`` from a composed config and load its weights.

    Args:
        config: Hydra/OmegaConf config (as produced by ``compose``). It is run
            through ``utils.train.process_config`` before use.
        path: Checkpoint file to load; also stored on
            ``config.train.pretrained_model_path``.

    Returns:
        The module returned by ``SequenceLightningModule.load_from_checkpoint``.
    """
    # Project-local imports are kept inside the function.
    from train import SequenceLightningModule
    import src.utils as utils

    config = utils.train.process_config(config)
    config.train.pretrained_model_path = path
    utils.train.print_config(config, resolve=True)

    # load_from_checkpoint instantiates the module itself (``config`` is
    # forwarded to the constructor), so no separate
    # ``SequenceLightningModule(config)`` call is needed beforehand.
    model = SequenceLightningModule.load_from_checkpoint(
        config.train.pretrained_model_path,
        config=config,
        strict=config.train.pretrained_model_strict_load,
    )
    return model

def load_model(experiment, checkpoint_path, nr_factors, **kwargs):
    """Compose the Hydra config for ``experiment`` and load a checkpoint into it.

    Args:
        experiment: Value for the ``experiment=`` Hydra override
            (e.g. ``"equities/attention_factors_equities"``).
        checkpoint_path: Lightning checkpoint file to load.
        nr_factors: Written to ``cfg.model.layer.n_factors`` when the config
            accepts it; otherwise a message is printed and it is ignored.
        **kwargs: Accepted for call-site compatibility; unused.

    Returns:
        The model returned by ``hydraload_corelogic``.
    """
    from hydra.core.global_hydra import GlobalHydra

    # Hydra's global state can be initialized only once per process and this
    # function is called once per evaluation year, so reuse it when present.
    # An explicit check (instead of a bare except) lets genuine initialization
    # errors, such as a wrong config_path, surface.
    if GlobalHydra.instance().is_initialized():
        print("Already initialized")
    else:
        # Hydra resolves a relative config_path against the calling module's directory.
        initialize(version_base=None, config_path="./../../../configs/")
    cfg = compose(config_name="config.yaml",
                  overrides=["experiment=" + experiment])
    try:
        cfg.model.layer.n_factors = nr_factors
    except Exception:  # OmegaConf raises different error types depending on the node
        print("No n_factors in config")
    model = hydraload_corelogic(cfg, checkpoint_path)
    return model


def check_causality(model, val_X, val_Y, t=20, tol=1e-5):
    """
    Check whether the model's predictions at time ``t`` depend on inputs from later timesteps.

    A scalar objective (negative mean/std ratio of prediction-weighted targets
    at ``t + 1``) is built from the predictions at ``t`` and differentiated
    with respect to the input. Any gradient mass on input positions after
    ``t`` (along the LAST axis of ``val_X``) means future information leaks
    into the prediction.

    Args:
        model: Module called as ``model((x, {}))``. Element ``[0]`` of its output
            is the prediction, which is indexed as ``[:, t, :]`` after
            ``squeeze(1)``. The in-file caller passes targets of shape
            ``(N, 1, T, 1)``.
        val_X: Model input (tensor or array-like). Time is assumed to be the last axis.
        val_Y: Targets, laid out like the predictions (time on axis 1 after ``squeeze(1)``).
        t: Time step to check; must satisfy ``0 <= t < T - 1``.
        tol: Threshold on the summed absolute future-input gradient above which
            the model is declared non-causal.

    Returns:
        ``True`` if causal up to ``t``, ``False`` if later inputs influence the
        prediction, or ``None`` if no gradient path from input to objective exists.

    Raises:
        ValueError: If ``t`` is out of range for the target's time axis.
    """
    # Run on the model's own device; fall back to CUDA/CPU for parameter-free modules.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the input directly on the target device, then mark it as requiring
    # grad, so it stays a leaf. ``torch.tensor(..., requires_grad=True).to(dev)``
    # returns a non-leaf copy whose ``.grad`` is never filled when devices differ.
    x = torch.as_tensor(val_X, dtype=torch.float32, device=device).detach().clone().requires_grad_(True)
    y = torch.as_tensor(val_Y, dtype=torch.float32, device=device).squeeze(1)

    num_steps = y.shape[1]
    if not 0 <= t < num_steps - 1:
        raise ValueError(
            f"t must satisfy 0 <= t < {num_steps - 1} so that t + 1 is a valid time index; got {t}"
        )

    y_pred = model((x, {}))[0].squeeze(1)

    # Prediction at t is used as a weight on the realized target at t + 1.
    weights = y_pred[:, t, :]
    next_returns = y[:, t + 1, :]
    # NOTE(review): the sum runs over dim 1 (the trailing axis for (N, 1, T, 1)
    # inputs), so mean/std below are taken across dim 0. Any scalar of the
    # t-predictions works for the causality test; confirm if reused for reporting.
    portfolio_returns = (weights * next_returns).sum(dim=1)

    mean_return = portfolio_returns.mean()
    std_return = portfolio_returns.std() + 1e-8  # avoid division by zero
    loss = -mean_return / std_return

    if not loss.requires_grad:
        print("No gradients were computed. Check model configuration.")
        return None

    # Differentiate w.r.t. the input only, so no parameter .grad buffers are touched.
    (grad_x,) = torch.autograd.grad(loss, x, allow_unused=True)
    if grad_x is None:
        print("No gradients were computed. Check model configuration.")
        return None

    future_timestep_grads = grad_x[..., t + 1:].abs().sum().item()
    print(f"Future timestep gradients after time step {t}: {future_timestep_grads}")

    if future_timestep_grads > tol:
        print(f"Model is NOT causal! Detected gradient influence from future time steps after {t}")
        return False
    print(f"Model is causal up to time step {t}")
    return True

import matplotlib.pyplot as plt
def plot_cumulative_returns(start_year, end_year, config_path, experiment_yaml,
                            returns_save_path,
                            data_path,
                            table_save_path,
                            year_by_year_path,
                            nr_factors,
                            model_name="attention_factors",
                            model_str="Attention Factor Model (K=30)",
                            path_identifier="factors_30",
                            check_model_causality=False
                            ):
    """Evaluate one model year by year and write summary tables and return plots.

    For every year ``y`` in ``[start_year, end_year]`` this loads the checkpoint
    ``config_path + str(y) + "_" + path_identifier + "/last.ckpt"`` via
    ``load_model``. It then runs the model on the first 246 rows of the ``"val"``
    split stored in ``data_path + str(y) + ".npz"``. The model output at day
    ``d`` is used as portfolio weights on the targets at day ``d + 1``.

    Files written:
        * ``table_save_path``: LaTeX summary table (gross, net-of-cost, market).
          The same numbers go to a dict saved next to it with ``.tex`` -> ``.npy``.
        * ``returns_save_path`` (PDF) with ``.pdf`` -> ``.npy`` gross cumulative
          returns, plus ``*_net.pdf`` / ``*_net.npy`` for returns after costs.
        * ``./plot/neurips/cumulative_market_returns_{start_year}_{end_year}.npy``.
        * ``year_by_year_path``: LaTeX table of per-year metrics.

    Args:
        start_year, end_year: Inclusive range of evaluation years.
        config_path: Prefix of the per-year checkpoint directories.
        experiment_yaml: Hydra ``experiment=`` override forwarded to ``load_model``.
        returns_save_path: Cumulative-return PDF path (expected to end in ``.pdf``).
        data_path: Prefix of the per-year ``.npz`` dataset files.
        table_save_path: Summary LaTeX table path (expected to end in ``.tex``).
        year_by_year_path: Year-by-year LaTeX table path.
        nr_factors: Forwarded to ``load_model``.
        model_name: ``"attention_factors"`` adds the model's factor portfolio to
            its predictions and reports its explained variance. Any other value
            reports an explained variance of 0.
        model_str: Display name used in the tables and plot legends.
        path_identifier: Suffix of the per-year checkpoint directory name.
        check_model_causality: If True, run ``check_causality`` (t=20) every year.

    Raises:
        ValueError: If ``end_year < start_year`` or prediction/target shapes differ.
    """
    if end_year < start_year:
        raise ValueError(f"end_year ({end_year}) must be >= start_year ({start_year})")

    def _annualized(daily):
        """Return (annualized Sharpe, annualized mean, annualized std) of daily returns."""
        mean_d = np.mean(daily)
        std_d = np.std(daily)
        sharpe_d = mean_d / (std_d + 1e-12)
        return sharpe_d * np.sqrt(252), mean_d * 252, std_d * np.sqrt(252)

    def _apply_year_ticks(num_days):
        """Place one x tick per year boundary, assuming days are spread evenly over years."""
        years = list(range(start_year, end_year + 2))
        days_per_year = num_days / (end_year - start_year + 1)
        positions = [int(k * days_per_year) for k in range(len(years))]
        positions[-1] = num_days - 1
        plt.xticks(positions, years, rotation=45, ha='right', fontsize='large')

    # Per-year daily arrays, used for the year-by-year table.
    year_stats = {}

    all_portfolio_returns = []
    all_market_returns = []
    # Daily metrics across all years (for the overall summary).
    all_daily_turnovers = []
    all_daily_leverages = []
    all_daily_short_fractions = []
    last_weights = None  # carries turnover across year boundaries

    causality_results = {}  # causality check outcome by year
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    for i in range(start_year, end_year + 1):
        date = str(i)
        print("CURRENT YEAR", date)
        config_path_date = config_path + date + "_" + path_identifier + "/last.ckpt"
        model = load_model(experiment_yaml, config_path_date, nr_factors)
        dataset_path = data_path + date + ".npz"

        # SECURITY NOTE: the archive stores pickled objects; only load trusted files.
        with np.load(dataset_path, allow_pickle=True) as data:
            splits = pickle.loads(data['splits'].item())
        print(f"Dataset loaded from {dataset_path}")

        val = splits["val"]
        # Evaluate on the first 246 time steps of the validation split.
        val_X = torch.tensor(val["X"][:246, :], dtype=torch.float32).to(device)
        val_Y = torch.tensor(val["Y"][:246, :], dtype=torch.float32).to(device)

        model.eval()
        model._state = None
        model = model.to(device)

        # X: reverse the three axes and add a leading batch axis.
        X = val_X.permute(2, 1, 0).unsqueeze(0)
        # Y: (T, N) -> (1, T, N, 1) -> (N, 1, T, 1)
        val_Y = val_Y.unsqueeze(-1).unsqueeze(0).permute(2, 0, 1, 3)

        if check_model_causality:
            print(f"\nPerforming causality check for year {date}...")
            causality_results[i] = check_causality(model, X.clone(), val_Y.clone(), t=20)
            if causality_results[i] is None:
                outcome = 'Inconclusive'
            elif causality_results[i]:
                outcome = 'Passed'
            else:
                outcome = 'Failed'
            print(f"Causality check for year {date}: {outcome}\n")

        with torch.no_grad():
            pred_Y, _, out_dict_2 = model((X, {}))
            if model_name == "attention_factors":
                state = out_dict_2['state'][0]
                explained_variance = state['explained_variance']
                # Axes (a, b, c) -> (c, a, b), plus a trailing singleton axis,
                # so the factor portfolio lines up with pred_Y.
                factor_portfolio = state['factor_portfolio'].permute(2, 0, 1).unsqueeze(-1)
                pred_Y = pred_Y + factor_portfolio
            else:
                explained_variance = torch.tensor(0.0)

        if pred_Y.shape != val_Y.shape:
            raise ValueError(
                f"Prediction shape {tuple(pred_Y.shape)} does not match target shape {tuple(val_Y.shape)}"
            )

        # -------------------------
        # Daily metrics
        # -------------------------
        W = pred_Y[:, 0, :-1, 0]  # (N, T-1): weights at day d, applied to returns at d+1
        daily_leverage = W.abs().sum(dim=0)  # (T-1,)
        daily_abs_short = (W * (W < 0)).abs().sum(dim=0)
        daily_short_fraction = torch.zeros_like(daily_leverage)
        nz_mask = daily_leverage != 0
        daily_short_fraction[nz_mask] = daily_abs_short[nz_mask] / daily_leverage[nz_mask]

        if last_weights is None:
            # Assume the first year starts from a flat book.
            last_weights = torch.zeros_like(W[:, 0])
        # NOTE(review): this compares asset slot k across years; confirm the
        # asset ordering is stable between yearly datasets.
        turnover_first_day = (W[:, 0] - last_weights).abs().sum()
        daily_turnover_intra = (W[:, 1:] - W[:, :-1]).abs().sum(dim=0)  # (T-2,)
        daily_turnover = torch.cat([turnover_first_day.unsqueeze(0), daily_turnover_intra], dim=0)  # (T-1,)
        last_weights = W[:, -1]  # carried into the next year

        all_daily_turnovers.extend(daily_turnover.cpu().numpy())
        all_daily_leverages.extend(daily_leverage.cpu().numpy())
        all_daily_short_fractions.extend(daily_short_fraction.cpu().numpy())

        next_day_returns = val_Y[:, 0, 1:, 0]  # (N, T-1)
        portfolio_returns = (W * next_day_returns).sum(dim=0).cpu().numpy()
        market_returns = next_day_returns.mean(dim=0).cpu().numpy()

        all_portfolio_returns.extend(portfolio_returns)
        all_market_returns.extend(market_returns)

        year_stats[i] = {
            "daily_returns": portfolio_returns,
            "daily_market_returns": market_returns,
            # Per-year turnover excludes the first-day rebalance from the prior
            # year's book; the aggregate figures include it.
            "daily_turnover": daily_turnover_intra.cpu().numpy(),
            "daily_leverage": daily_leverage.cpu().numpy(),
            "daily_short_fraction": daily_short_fraction.cpu().numpy(),
            "explained_variance": explained_variance.cpu().numpy(),
        }

    # -------------------------
    # Aggregate over all years
    # -------------------------
    all_portfolio_returns = np.array(all_portfolio_returns)
    all_market_returns = np.array(all_market_returns)
    all_daily_turnovers = np.array(all_daily_turnovers)
    all_daily_leverages = np.array(all_daily_leverages)
    all_daily_short_fractions = np.array(all_daily_short_fractions)

    sharpe_ratio_annualized, mean_return_annualized, std_return_annualized = _annualized(all_portfolio_returns)

    # Costs: 5 bp per unit of turnover plus 1 bp per unit of short fraction, per day.
    transaction_costs = 0.0005 * all_daily_turnovers + 0.0001 * all_daily_short_fractions
    all_portfolio_returns_net = all_portfolio_returns - transaction_costs
    (sharpe_ratio_annualized_net, mean_return_annualized_net,
     std_return_annualized_net) = _annualized(all_portfolio_returns_net)
    print(f"Sharpe ratio (yearly): {sharpe_ratio_annualized}")
    print(f"Mean return (yearly): {mean_return_annualized}")
    print(f"Std return (yearly): {std_return_annualized}")
    print(f"Sharpe ratio (yearly) net: {sharpe_ratio_annualized_net}")
    print(f"Mean return (yearly) net: {mean_return_annualized_net}")
    print(f"Std return (yearly) net: {std_return_annualized_net}")

    (sharpe_ratio_market_annualized, mean_market_return_annualized,
     std_market_return_annualized) = _annualized(all_market_returns)
    print(f"Sharpe ratio (yearly) market: {sharpe_ratio_market_annualized}")
    print(f"Mean market return (yearly): {mean_market_return_annualized}")
    print(f"Std market return (yearly): {std_market_return_annualized}")

    # Take covariance and market variance from the same (ddof=1) matrix so the
    # aggregate beta is consistent with the per-year beta below.
    cov_pm = np.cov(all_portfolio_returns, all_market_returns)
    beta = cov_pm[0, 1] / (cov_pm[1, 1] + 1e-12)
    print(f"Beta: {beta}")

    mean_turnover = np.mean(all_daily_turnovers)
    mean_leverage = np.mean(all_daily_leverages)
    mean_short_fraction = np.mean(all_daily_short_fractions)

    print(f"Mean Daily Turnover: {mean_turnover:.4f}")
    print(f"Mean Leverage: {mean_leverage:.4f}")
    print(f"Mean Short Fraction: {mean_short_fraction:.4f}")

    # Compound in float64 so long horizons do not lose precision.
    cumulative_returns = np.cumprod(1.0 + all_portfolio_returns.astype(np.float64))
    cumulative_returns_net = np.cumprod(1.0 + all_portfolio_returns_net.astype(np.float64))
    cumulative_market_returns = np.cumprod(1.0 + all_market_returns.astype(np.float64))
    print("cumulative returns", cumulative_returns[-1])
    print("cumulative market returns", cumulative_market_returns[-1])

    # -------------------------
    # 1) Summary table
    # -------------------------
    sr_model = f"{sharpe_ratio_annualized:.2f}"
    sr_model_net = f"{sharpe_ratio_annualized_net:.2f}"
    mr_model = f"{mean_return_annualized*100:.2f}"
    std_model = f"{std_return_annualized*100:.2f}"
    mr_model_net = f"{mean_return_annualized_net*100:.2f}"
    std_model_net = f"{std_return_annualized_net*100:.2f}"
    beta_model = f"{beta:.2f}"
    turnover_model = f"{mean_turnover:.2f}"
    leverage_model = f"{mean_leverage:.2f}"
    short_model = f"{mean_short_fraction:.2f}"

    sr_market = f"{sharpe_ratio_market_annualized:.2f}"
    mr_market = f"{mean_market_return_annualized*100:.2f}"
    std_market = f"{std_market_return_annualized*100:.2f}"

    beta_market = "1.00"
    turnover_market = "0.00"
    leverage_market = "1.00"
    short_market = "0.00"

    name = model_str

    os.makedirs(os.path.dirname(table_save_path) or ".", exist_ok=True)
    latex_table = f"""\\documentclass{{article}}
\\usepackage{{booktabs}}
\\usepackage{{graphicx}}
\\usepackage{{siunitx}}

\\begin{{document}}

\\begin{{table}}[ht]
  \\centering
  \\caption{{Summary Statistics (Jan. {start_year}- Dec. {end_year}). The Sharpe Ratio, Mean Return, and Std Dev Return are annualized. The Net Sharpe Ratio is the annualized Sharpe Ratio after accounting for transaction and shorting costs. Beta is relative to the market. Turnover, Leverage, and Short Fraction are daily averages.}}
  \\label{{tab:summary_stats}}
  \\resizebox{{\\textwidth}}{{!}}{{ 
  \\begin{{tabular}}{{l S[table-format=1.2] S[table-format=1.2] S[table-format=1.2] S[table-format=3.2] S[table-format=1.2] S[table-format=1.2] S[table-format=1.2] S[table-format=1.3] S[table-format=1.2] S[table-format=1.3]}}
    \\toprule
    Model         & {{Sharpe}} & {{Return \\%}} & {{Std Dev \\%}} & {{Net}} & {{Net Ret \\%}} & {{Net Std \\%}} & {{Beta}} & {{Daily}} & {{Leverage}} & {{Short}} \\\\
                  & {{Ratio}}  & {{}}       & {{Return}}  & {{Sharpe}} & {{}} & {{}} & {{}} & {{Turnover}} & {{}} & {{Fraction}} \\\\
    \\midrule
    {name}  & {sr_model} & {mr_model} & {std_model} & {sr_model_net} & {mr_model_net} & {std_model_net} & {beta_model} & {turnover_model} & {leverage_model} & {short_model} \\\\
    Market Portfolio         & {sr_market} & {mr_market} & {std_market} & {sr_market} & {mr_market} & {std_market} & {beta_market} & {turnover_market} & {leverage_market} & {short_market} \\\\
    \\bottomrule
  \\end{{tabular}}
  }}
\\end{{table}}

\\end{{document}}
"""
    with open(table_save_path, 'w') as f:
        f.write(latex_table)
    print(f"Summary statistics table saved to {table_save_path}")

    dict_results = {
        "sharpe_ratio_annualized": sharpe_ratio_annualized,
        "mean_return_annualized": mean_return_annualized,
        "std_return_annualized": std_return_annualized,
        "sharpe_ratio_annualized_net": sharpe_ratio_annualized_net,
        "mean_return_annualized_net": mean_return_annualized_net,
        "turnover_model": mean_turnover,
        "leverage_model": mean_leverage,
        "short_fraction_model": mean_short_fraction,
        "beta_model": beta,
    }
    # Same numbers as the table, saved next to it with .tex -> .npy.
    np.save(table_save_path.replace(".tex", ".npy"), dict_results)

    # -------------------------
    # 2) Cumulative gross returns plot (log scale)
    # -------------------------
    plt.figure(figsize=(12, 8))
    plt.plot(cumulative_returns, label=name)
    np.save(returns_save_path.replace(".pdf", ".npy"), cumulative_returns)

    market_returns_path = f"./plot/neurips/cumulative_market_returns_{start_year}_{end_year}.npy"
    # The target directory is not created anywhere else; np.save would fail without it.
    os.makedirs(os.path.dirname(market_returns_path), exist_ok=True)
    np.save(market_returns_path, cumulative_market_returns)
    plt.legend(fontsize='large')

    _apply_year_ticks(len(cumulative_returns))
    plt.xlabel("Year", fontsize='x-large')
    plt.ylabel("Cumulative Return (Log Scale)", fontsize='x-large')
    plt.yscale('log')
    plt.title(f"Cumulative Returns: Jan. {start_year} to Dec. {end_year}", fontsize='xx-large')
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.savefig(returns_save_path)
    plt.close()
    print(f"Cumulative returns plot saved to {returns_save_path}")

    # -------------------------
    # 3) Gross vs. net cumulative returns plot (linear scale)
    # -------------------------
    plt.figure(figsize=(12, 8))
    plt.plot(cumulative_returns, label=name)
    plt.plot(cumulative_returns_net, label=f"{name} (net)")
    np_net_returns_path = returns_save_path.replace(".pdf", "_net.npy")
    net_returns_save_path = returns_save_path.replace(".pdf", "_net.pdf")
    np.save(np_net_returns_path, cumulative_returns_net)

    plt.legend(fontsize='large')
    _apply_year_ticks(len(cumulative_returns))
    plt.xlabel("Year", fontsize='x-large')
    plt.ylabel("Cumulative Return", fontsize='x-large')
    plt.title(f"Cumulative Returns: Jan. {start_year} to Dec. {end_year}", fontsize='xx-large')
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.savefig(net_returns_save_path)
    plt.close()  # release the figure; this function may run once per seed
    print(f"Cumulative net returns plot saved to {net_returns_save_path}")

    # -------------------------
    # 4) Year-by-year table
    # -------------------------
    def compute_year_stats(daily_returns, daily_market_returns, daily_turnover,
                           daily_leverage, daily_short_fraction, explained_variance):
        """Summarize one year of daily series: annualized Sharpe/mean/std, averages, and beta."""
        if len(daily_returns) < 2:
            raise ValueError("Need at least two daily returns to compute yearly statistics.")
        sharpe_annual, mean_return_annual, std_return_annual = _annualized(daily_returns)

        if len(daily_returns) == len(daily_market_returns):
            cov_year = np.cov(daily_returns, daily_market_returns)
            beta_yr = cov_year[0, 1] / (cov_year[1, 1] + 1e-12)
        else:
            beta_yr = 0.0

        avg_turnover = np.mean(daily_turnover) if len(daily_turnover) else 0.0
        avg_leverage = np.mean(daily_leverage) if len(daily_leverage) else 0.0
        avg_short_fraction = np.mean(daily_short_fraction) if len(daily_short_fraction) else 0.0

        return {
            "sharpe": sharpe_annual,
            "return": mean_return_annual,
            "std": std_return_annual,
            "turnover": avg_turnover,
            "leverage": avg_leverage,
            "short_fraction": avg_short_fraction,
            "beta": beta_yr,
            "explained_variance": explained_variance,
        }

    results_by_year = {
        y: compute_year_stats(
            year_stats[y]["daily_returns"],
            year_stats[y]["daily_market_returns"],
            year_stats[y]["daily_turnover"],
            year_stats[y]["daily_leverage"],
            year_stats[y]["daily_short_fraction"],
            year_stats[y]["explained_variance"],
        )
        for y in range(start_year, end_year + 1)
    }

    # Both halves share ONE tabular whose column count comes from the first
    # half, so the first half takes the extra year when the count is odd
    # (e.g. 2002-2021 -> 10 + 10; 1998-2016 -> 10 + 9).
    n_years = end_year - start_year + 1
    split_year = start_year + (n_years + 1) // 2
    first_half_years = list(range(start_year, split_year))
    second_half_years = list(range(split_year, end_year + 1))

    def build_latex_block(year_list):
        """
        Return ``(header_line, data_rows)`` for the given years.

        The header looks like ``" & 98 & 99 & ... & 09 \\\\"``; each data row is
        ``"label & v1 & v2 & ... \\\\"`` for one metric.
        """
        header_str = " & " + " & ".join(f"{str(y)[-2:]:0>2}" for y in year_list) + " \\\\"

        row_sharpe = ["Sharpe"]
        row_return = ["Return \\%"]
        row_std = ["Std \\%"]
        row_turnover = ["Turnover"]
        row_leverage = ["Leverage"]
        row_short = ["Frac. Short"]
        row_beta = ["Beta"]
        row_explained_variance = ["Expl. Var."]
        for y in year_list:
            st = results_by_year[y]
            row_sharpe.append(f"{st['sharpe']:.2f}")
            row_return.append(f"{st['return']*100:.2f}")
            row_std.append(f"{st['std']*100:.2f}")
            row_turnover.append(f"{st['turnover']:.2f}")
            row_leverage.append(f"{st['leverage']:.2f}")
            row_short.append(f"{st['short_fraction']:.2f}")
            row_beta.append(f"{st['beta']:.2f}")
            row_explained_variance.append(f"{st['explained_variance']:.2f}")

        def row2latex(label_vals):
            return " & ".join(label_vals) + " \\\\"

        data_rows = [
            row2latex(row_sharpe),
            row2latex(row_return),
            row2latex(row_std),
            row2latex(row_turnover),
            row2latex(row_leverage),
            row2latex(row_short),
            row2latex(row_beta),
            row2latex(row_explained_variance),
        ]
        return header_str, data_rows

    hdr_first, rows_first = build_latex_block(first_half_years)

    explanation = (
        "Year-by-year performance metrics for the model tested on each given year, "
        "trained on the prior 8 years. Sharpe, Return, and Std are annualized from daily data. "
        "Turnover is the average daily turnover, Leverage is the average absolute weights, "
        "Fraction Short is the fraction of the portfolio short on average, and Beta is computed "
        "against the average of all assets taken as the market portfolio. The explained variance is the average of the per-stock explained variance, "
        f"computed as 1 - the ratio of the variance of the residuals to the variance of the original data. The model is the {name}, with the 500 largest assets in the CRSP dataset by market cap in the year before the test year."
    )

    # The second block is omitted entirely when there is only one year.
    if second_half_years:
        hdr_second, rows_second = build_latex_block(second_half_years)
        second_block = r"""% -- Second block (second_half_years) --
%\begin{tabular}{l*{""" + str(len(second_half_years)) + r"""}{S}}
\toprule
""" + hdr_second + r"""
\midrule
""" + "\n".join(rows_second) + "\n"
    else:
        second_block = ""

    year_by_year_tex = r"""\documentclass{article}
\usepackage{booktabs}
\usepackage{graphicx}
\usepackage{siunitx}
\usepackage{multirow}
\usepackage{array}
\usepackage{booktabs}

\begin{document}
\begin{table}[ht]
\scriptsize
\centering
\caption{""" + explanation + r"""}
\label{tab:year_by_year_performance}

% -- First block (first_half_years) --
\begin{tabular}{l*{""" + str(len(first_half_years)) + r"""}{S}}
\toprule
""" + hdr_first + r"""
\midrule
""" + "\n".join(rows_first) + r"""
%\end{tabular}

% -- Thicker line to separate parts --
%\vspace{-1pt}
%\specialrule{1.2pt}{0pt}{0pt}
%\vspace{-1pt}

""" + second_block + r"""\bottomrule
\end{tabular}

\end{table}
\end{document}
"""

    with open(year_by_year_path, "w") as f:
        f.write(year_by_year_tex)

    print(f"Year-by-year performance table saved to {year_by_year_path}")

    





def main():
    """Command-line entry point: evaluate one model over a range of years, once per seed.

    ``--model_name`` selects a default experiment YAML, the output file names,
    a display name, and the checkpoint-folder identifier.
    ``plot_cumulative_returns`` is then called for every requested seed.
    """
    parser = argparse.ArgumentParser(description="Evaluate equity prediction models.")
    cli_options = [
        ("--use_seed", dict(action="store_true",
                            help="Whether to use seeds for evaluation")),
        ("--seed_numbers", dict(type=str, default="0",
                                help="Comma-separated list of seed numbers to evaluate (used only if use_seed=True)")),
        ("--model_name", dict(type=str, default="set_MLP",
                              help="Name of the model architecture to evaluate.")),
        ("--nr_factors", dict(type=int, default=30,
                              help="Number of factors (used only for attention_factors model).")),
        ("--config_base_path", dict(type=str, default=f"{BASE_PATH}/outputs/outputs/",
                                    help="Base directory containing model output folders (usually dated).")),
        ("--date", dict(type=str, default="2025-05-10",
                        help="Date subfolder within config_base_path containing model checkpoints.")),
        ("--start_year", dict(type=int, default=2002,
                              help="First year of the evaluation period.")),
        ("--end_year", dict(type=int, default=2021,
                            help="Last year of the evaluation period.")),
        ("--plot_dir", dict(type=str, default="./plot/",
                            help="Directory to save output plots and tables.")),
        ("--data_path_base", dict(type=str, default=f"{BASE_PATH}/data/equities/equity_dataset_",
                                  help="Base path for input data files (year will be appended).")),
        ("--experiment_yaml", dict(type=str, default=None,
                                   help="Specific experiment YAML file path (overrides defaults based on model_name). Required if model_name is not predefined.")),
        ("--check_causality", dict(action="store_true",
                                   help="Run causality check on the model.")),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    model_name, nr_factors = args.model_name, args.nr_factors
    span = f"{args.start_year}_{args.end_year}"

    # Sequence models that share one naming scheme: name -> (default YAML, display name).
    simple_models = {
        "set_MLP": ("equities/equities_set_seq", "Set-Seq Model"),
        "set_MHA": ("equities/mha_equities_small_model", "MHA-Seq Model"),
        "set_MLP_no_set": ("equities/set_equities_small_model_no_set", "Seq Model (No Set)"),
        "set_gated_selection": ("equities/set_equities_small_model_gated_selection", "Gated Selection Model"),
    }

    if model_name == "attention_factors":
        tag = f"factors_{nr_factors}"
        experiment_yaml_default = "equities/attention_factors_equities"
        returns_name = f"returns_{span}_{tag}.pdf"
        table_name = f"summary_stats_{tag}_{span}.tex"
        year_by_year_name = f"year_by_year_performance_{tag}_{span}.tex"
        name = f"Attention Factor Model (K={nr_factors})"
        path_identifier = tag
    elif model_name == "attention_factors_no_factor_portfolio":
        tag = f"factors_{nr_factors}_no_factor_portfolio"
        experiment_yaml_default = "equities/attention_factors_equities_no_factor_portfolio"
        returns_name = f"returns_{span}_{tag}.pdf"
        table_name = f"summary_stats_{tag}_{span}.tex"
        year_by_year_name = f"year_by_year_performance_{tag}_{span}.tex"
        name = f"Attention Factor Model (K={nr_factors}) (No Factor Portfolio)"
        # Checkpoint folders use "factors" (plural) here, unlike the output files.
        path_identifier = f"factors_{nr_factors}_no_factors_portfolio"
    else:
        if model_name in simple_models:
            experiment_yaml_default, name = simple_models[model_name]
        else:
            print(f"Warning: Unrecognized model_name '{model_name}'. Using generic naming conventions.")
            experiment_yaml_default, name = None, model_name
        returns_name = f"returns_{span}_{model_name}.pdf"
        table_name = f"summary_stats_{model_name}_{span}.tex"
        year_by_year_name = f"year_by_year_performance_{model_name}.tex"
        path_identifier = model_name

    # An explicit --experiment_yaml wins over the per-model default.
    experiment_yaml = experiment_yaml_default if args.experiment_yaml is None else args.experiment_yaml
    if experiment_yaml is None:
        raise ValueError(
            f"Experiment YAML path must be provided via --experiment_yaml for the model '{model_name}' "
            "as it's not a predefined model with a default YAML."
        )

    seeds = [int(part.strip()) for part in args.seed_numbers.split(',')] if args.use_seed else [None]

    config_path_base = os.path.join(args.config_base_path, args.date) + "/"
    os.makedirs(args.plot_dir, exist_ok=True)

    def with_suffix(filename, ext, suffix):
        """Insert ``suffix`` before ``ext`` and place the file inside the plot directory."""
        return os.path.join(args.plot_dir, filename.replace(ext, f'{suffix}{ext}'))

    for seed in seeds:
        suffix = "" if seed is None else f"_seed_{seed}"
        label = 'default configuration' if seed is None else 'seed ' + str(seed)
        print(f"Evaluating with {label}")

        plot_cumulative_returns(
            args.start_year, args.end_year,
            config_path_base,
            experiment_yaml,
            with_suffix(returns_name, '.pdf', suffix),
            args.data_path_base,
            with_suffix(table_name, '.tex', suffix),
            with_suffix(year_by_year_name, '.tex', suffix),
            nr_factors,
            model_name,
            model_str=name,
            path_identifier=path_identifier + suffix,
            check_model_causality=args.check_causality,
        )


if __name__ == "__main__":
    main()