
import os
import sys

# Project root, taken from the environment (empty string when unset).
BASE_PATH = os.environ.get("BASE_PATH", "")
# Drop a single trailing slash so paths built as f"{BASE_PATH}/data/..." do not
# contain "//". (An empty string never ends with '/', so no extra guard is needed.)
if BASE_PATH.endswith('/'):
    BASE_PATH = BASE_PATH[:-1]
# Directory holding the equities data; defaults to the current directory.
EQUITIES_DATA_PATH = os.environ.get("EQUITIES_DATA_PATH", ".")
# Put the project root on the import path so the `src.` package imports that
# follow can be found under BASE_PATH.
sys.path.append(BASE_PATH)
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
from src.dataloaders.preprocessing_equities import EquityEnv
from src.dataloaders.base import SequenceDataset
# Global NumPy RNG seed for reproducible sampling in this module.
SEED_NR = 0
np.random.seed(SEED_NR)
# NOTE: torch's RNG is not seeded here (torch.manual_seed is not called).

# Approximate trading-day horizons, assuming 21 trading days per month.
DAYS_IN_1MO = 21
DAYS_IN_2MO = 2 * DAYS_IN_1MO     # 42
DAYS_IN_7MO = 7 * DAYS_IN_1MO     # 147
DAYS_IN_12MO = 12 * DAYS_IN_1MO   # 252
DAYS_IN_13MO = 13 * DAYS_IN_1MO   # 273
DAYS_IN_36MO = 36 * DAYS_IN_1MO   # 756
DAYS_IN_60MO = 60 * DAYS_IN_1MO   # 1260
import numpy as np
from scipy.stats import rankdata

def apply_transform(data, transform):
    """Apply a named signed-log or differencing transform to ``data``.

    Differencing transforms work along axis 0 (time). ``data`` may therefore be
    a 1-D series of shape (T,) or a 2-D panel of shape (T, N), and the output
    has the same shape as the input. The first row is repeated before
    differencing, so the leading output entries are 0 rather than dropped.

    Parameters
    ----------
    data : array-like
        Input values; axis 0 is treated as time.
    transform : str
        One of:
        - 'none': return ``data`` unchanged (the same object).
        - 'log': signed log, ``sign(x) * log(1 + |x|)``.
        - 'delta': first difference ``x[t] - x[t-1]`` (0 at t=0).
        - 'delta_delta': second difference (0 at t=0).
        - 'delta_log': first difference of ``log(0.1 + x)``.
        - 'delta_delta_log': second difference of ``log(0.1 + x)``.

    Returns
    -------
    np.ndarray
        Transformed values (``data`` itself for 'none').

    Raises
    ------
    ValueError
        If ``transform`` is not one of the names above.
    """
    if transform == 'none':
        return data
    if transform == 'log':
        return np.sign(data) * np.log(1.0 + np.abs(data))

    # transform name -> (difference order, apply log(0.1 + x) before differencing)
    diff_specs = {
        'delta': (1, False),
        'delta_delta': (2, False),
        'delta_log': (1, True),
        'delta_delta_log': (2, True),
    }
    if transform not in diff_specs:
        raise ValueError(f"Invalid transform: {transform}")
    order, use_log = diff_specs[transform]

    data = np.asarray(data)
    # Repeat the first row `order` times along the time axis so the output keeps
    # the input length. np.concatenate along axis 0 is used instead of
    # np.insert(data, 0, data[0]): without an axis, np.insert flattens N-D input,
    # which used to destroy (T, N) panels. Empty input simply yields empty output.
    padded = np.concatenate([data[:1]] * order + [data], axis=0)
    if use_log:
        # NOTE: produces NaN for values < -0.1 (log of a non-positive number).
        padded = np.log(0.1 + padded)
    return np.diff(padded, n=order, axis=0)

def transform_raw_feature(
        monthly_data,                # (time, asset) matrix; monthly or daily depending on `daily_data`
        monthly_train_start,         # if daily_data=False, these are monthly indices;
        monthly_train_end,           # if daily_data=True, these are daily indices
        daily_to_monthly_idx,
        num_days,
        num_assets,
        feature_name="unknown",
        normalize=True,
        transform='none',
        cross_sectional_rank=False,
        daily_data=False,
        ):
    """Transform a raw (time, asset) feature matrix and return it at daily frequency.

    Modes
    -----
    cross_sectional_rank=True
        Each row is ranked across assets (average ranks for ties) and scaled to
        [-1, 1] via ``2 * (rank - 1) / (num_assets - 1) - 1``. ``normalize`` is
        ignored. The companion output is the per-row median of the raw data,
        passed through ``apply_transform(transform)`` and broadcast across
        ``num_assets`` columns.
    cross_sectional_rank=False
        ``apply_transform(transform)`` is applied to the matrix. If ``normalize``
        is set, the result is z-scored with a single global mean/std taken from
        rows ``[monthly_train_start:monthly_train_end]``. The companion output
        is all zeros.

    Frequency
    ---------
    daily_data=False
        Input rows are months. Results are expanded to ``(num_days, num_assets)``
        via ``daily_to_monthly_idx`` (day index -> month index or None). Days
        with no month stay 0.
    daily_data=True
        Input rows are already days and are returned without expansion. The
        train indices then refer to daily rows.

    Returns
    -------
    daily_feat, daily_median_feat : np.ndarray
        The transformed feature and its companion (median) channel.
    """
    data = monthly_data

    # ---------------------------
    # 1) Cross-sectional rank vs. elementwise transform (+ optional z-score)
    # ---------------------------
    if cross_sectional_rank:
        # Rank every row across assets in one vectorized call; ties get average ranks.
        ranks = rankdata(data, method='average', axis=1)      # values in [1..N]
        ranks_scaled = (ranks - 1) / (num_assets - 1)         # -> [0, 1]
        data_trans = (2.0 * ranks_scaled - 1.0).astype(np.float32)  # -> [-1, 1]

        median_data = np.median(data, axis=1)

        # Back-fill the two entries just before the first non-zero median with
        # that median. NOTE(review): presumably this avoids a spurious jump at the
        # start of valid data for delta-type transforms — confirm intent.
        # Guarded: an all-zero median series previously raised IndexError.
        nonzero_idx = np.flatnonzero(median_data != 0)
        if nonzero_idx.size > 0:
            first_nz = nonzero_idx[0]
            if 1 < first_nz < len(median_data) - 2:
                median_data[first_nz - 1] = median_data[first_nz]
                median_data[first_nz - 2] = median_data[first_nz]

        median_data = np.asarray(apply_transform(median_data, transform))
        # Broadcast the per-row median across all asset columns.
        median_data = np.tile(median_data[:, np.newaxis], (1, num_assets))

        # Ranks are already bounded in [-1, 1]; `normalize` is not applied here.
        data_norm = data_trans
    else:
        data_trans = apply_transform(data, transform)
        median_data = np.zeros_like(data_trans)

        if normalize:
            # Global (all-asset) z-score using training rows only, to avoid look-ahead.
            train_slice = data_trans[monthly_train_start:monthly_train_end, :]
            mean_ = np.mean(train_slice)
            std_ = np.std(train_slice)
            if std_ == 0:
                # Constant feature on the training slice: only centre it instead of
                # dividing by zero (which previously produced NaN/inf).
                std_ = 1.0
            data_norm = (data_trans - mean_) / std_
        else:
            data_norm = data_trans

    # ---------------------------
    # 2) If daily_data=True, skip expansion; else expand monthly->daily
    # ---------------------------
    if daily_data:
        daily_feat = data_norm
        daily_median_feat = median_data
    else:
        daily_feat = np.zeros((num_days, num_assets), dtype=np.float32)
        daily_median_feat = np.zeros((num_days, num_assets), dtype=np.float32)
        for d_idx in range(num_days):
            m_idx = daily_to_monthly_idx.get(d_idx)
            if m_idx is not None:
                daily_feat[d_idx] = data_norm[m_idx]
                daily_median_feat[d_idx] = median_data[m_idx]
            # Days without a mapped month keep their 0.0 initialisation.

    # ---------------------------
    # 3) Debug logging (stats of the raw training slice and of the result)
    # ---------------------------
    print("++++++++++++++++++++++++++++++++++++++++++++++++++++++")
    # Same row indexing for monthly and daily inputs.
    original_data_train = data[monthly_train_start:monthly_train_end, :]

    if cross_sectional_rank:
        transform_str = "cross-sectional rank transform (scaled to [-1,1])"
    else:
        transform_str = ""
    transform_str = transform_str + "_" + transform

    # Rank mode never standardizes, so report that accurately.
    if normalize and not cross_sectional_rank:
        norm_str = "z-score standardized"
    else:
        norm_str = "not standardized"

    print(f"Feature: {feature_name}")
    print(f"Transform: {transform_str}")
    print(f"Standardization: {norm_str}")
    print(f"daily_data={daily_data}, returning shape={daily_feat.shape}")

    print(f"Original mean (train slice): {original_data_train.mean():.4f}")
    print(f"Original std (train slice):  {original_data_train.std():.4f}")
    print("Original min (train slice):", np.min(original_data_train),
          "max:", np.max(original_data_train))
    print("Original median:", np.median(original_data_train))
    print("Original 90th percentile:", np.percentile(original_data_train, 90))
    print("Original 75th percentile:", np.percentile(original_data_train, 75))
    print("Original 25th percentile:", np.percentile(original_data_train, 25))
    print("Original 10th percentile:", np.percentile(original_data_train, 10))
    print("Original 5th percentile:", np.percentile(original_data_train, 5))

    print("Resulting data stats:")
    print("  mean:", np.mean(daily_feat), "std:", np.std(daily_feat))
    print("  min:", np.min(daily_feat), "max:", np.max(daily_feat))
    print("  median:", np.median(daily_feat))
    print("++++++++++++++++++++++++++++++++++++++++++++++++++++++")
    print("NEW median:", np.median(daily_median_feat), "min:", np.min(daily_median_feat), "max:", np.max(daily_median_feat))

    return daily_feat, daily_median_feat


def make_return_feature(asset_return, alpha_price, train_end_idx,
                        window_hi, window_lo, feature_name="momentum", use_cross_sectional_rank=False):
    """Build a windowed cumulative-return (momentum / reversal) feature.

    For each day ``t >= window_hi``, the raw feature is the compounded return
    over rows ``[t - window_hi, t - window_lo)``, i.e. ``prod(1 + r) - 1``.
    Rows ``t < window_hi`` are left at 0.

    Parameters
    ----------
    asset_return : np.ndarray
        Daily returns, shape (T, N). NaNs are treated as 0.
    alpha_price : float
        Scale inside the signed-log transform (non-rank mode only).
    train_end_idx : int
        End of the training rows (exclusive), used for the z-score statistics
        (non-rank mode only).
    window_hi : int
        Older window bound in days (e.g. 42 for ~2 months).
    window_lo : int
        Closer window bound in days (e.g. 21 for ~1 month).
    feature_name : str
        Used only for logging.
    use_cross_sectional_rank : bool
        Selects the output mode (see Returns).

    Returns
    -------
    feat_norm : np.ndarray, shape (T, N)
        Non-rank mode: ``sign(x) * log(1 + alpha_price * |x|)``, z-scored with a
        global mean/std over training rows ``[window_hi, train_end_idx)``.
        Rank mode: each row ranked across assets and scaled to [-1, 1].
    median_return : np.ndarray, shape (T, N)
        Rank mode: first difference ('delta') of the per-day cross-sectional
        median of the raw cumulative return, broadcast across all N columns.
        Non-rank mode: the same array as ``feat_norm``.
    """
    T, N = asset_return.shape
    # Replace NaNs in returns with 0.0, then precompute gross returns once.
    cleaned_returns = np.nan_to_num(asset_return, nan=0.0)
    gross_returns = 1.0 + cleaned_returns

    # 1) Raw compounded return over the window [t - window_hi, t - window_lo).
    feat = np.zeros((T, N), dtype=np.float32)
    for t in range(window_hi, T):
        feat[t] = np.prod(gross_returns[t - window_hi:t - window_lo, :], axis=0) - 1.0

    if not use_cross_sectional_rank:
        # 2) Signed-log transform.
        feat_trans = np.sign(feat) * np.log(1.0 + alpha_price * np.abs(feat))

        # 3) Standardize with training rows only, skipping the first `window_hi`
        #    rows (incomplete windows); clamp in case train_end_idx < window_hi.
        valid_start_idx = min(window_hi, train_end_idx)
        train_vals = feat_trans[valid_start_idx:train_end_idx, :]
        if train_vals.size == 0:
            # No complete training window: previously this silently produced an
            # all-NaN feature. Leave the transformed values un-standardized instead.
            print(f"[{feature_name}] WARNING: empty training slice, skipping standardization")
            mean_, std_ = 0.0, 1.0
        else:
            mean_ = np.mean(train_vals)
            std_ = np.std(train_vals)
            if std_ == 0:
                # Constant feature: centre only, avoid division by zero.
                std_ = 1.0

        feat_norm = (feat_trans - mean_) / std_
        # NOTE(review): in this mode the companion "median" channel is the
        # feature itself (transform_raw_feature returns zeros in its non-rank
        # mode) — confirm this duplication is intended.
        median_return = feat_norm
    else:
        # Rank each row across assets (average ranks for ties), scale to [-1, 1].
        ranks = rankdata(feat, method='average', axis=1)   # values in [1..N]
        ranks_scaled = (ranks - 1) / (N - 1)
        feat_norm = (2.0 * ranks_scaled - 1.0).astype(np.float32)

        # Per-day cross-sectional median of the raw return, first-differenced,
        # then broadcast to every asset column.
        median_delta = apply_transform(np.median(feat, axis=1), 'delta')
        median_return = np.tile(median_delta[:, np.newaxis], (1, N))

        # Full-sample statistics of the ranked feature, used only for logging.
        mean_ = np.mean(feat_norm)
        std_ = np.std(feat_norm)

    print("++++++++++++++++++++++++++++++++++++++++++++++++++++++")
    print("Original mean:", np.mean(feat), "Original std:", np.std(feat))
    print("50 quantile:", np.percentile(feat, 50))
    print("90 quantile:", np.percentile(feat, 90))
    print("10 quantile:", np.percentile(feat, 10))
    print("5 quantile:", np.percentile(feat, 5))
    print("1 quantile:", np.percentile(feat, 1))
    print("99 quantile:", np.percentile(feat, 99))
    print("95 quantile:", np.percentile(feat, 95))
    print(f"[{feature_name}] shape={feat_norm.shape}, train-mean={mean_:.4f}, "
          f"train-std={std_:.4f}, window=[{window_hi}:{window_lo}]")
    print("++++++++++++++++++++++++++++++++++++++++++++++++++++++")
    assert feat_norm.shape == asset_return.shape
    return feat_norm, median_return

# =============================================================================
# 1. Data Creation, Saving, and Loading Functions
# =============================================================================
def create_equity_dataset(eq_env, config, feature_set,
                          train_frac=0.7, val_frac=0.15, test_frac=0.15,
                          alpha_price=100, alpha_vol=0.5, ):
    """
    Loads raw data from eq_env, filters out stocks with any NaN values along the time 
    dimension, applies log transformations and global normalization (using training stats),
    and splits the data into train/val/test splits.
    """
    # Map variable names to descriptions
    accounting_var_desc = {
        'atq': 'Total Assets (Quarterly)',
        'actq': 'Current Assets - Total (Quarterly)',
        'cheq': 'Cash and Short-Term Investments (Quarterly)',
        'lctq': 'Current Liabilities - Total (Quarterly)',
        'dlcq': 'Debt in Current Liabilities (Quarterly)',
        'txpq': 'Income Taxes Payable (Quarterly)',
        'saleq': 'Sales/Turnover (Net) (Quarterly)',
        'niq': 'Net Income (Loss) (Quarterly)',
        'dpq': 'Depreciation and Amortization (Quarterly)',
        'ibq': 'Income Before Extraordinary Items (Quarterly)',
        'txdbq': 'Income Tax Deferred (Quarterly)',
        'ppegtq': 'Property, Plant and Equipment - Total (Gross) (Quarterly)',
        'invtq': 'Inventories - Total (Quarterly)',
        'dlttq': 'Long-Term Debt - Total (Quarterly)',
        'seqq': 'Stockholders Equity - Parent (Quarterly)',
        'ivaoq': 'Investment and Advances - Other (Quarterly)',
        'mibq': 'Noncontrolling Interest - Redeemable (Quarterly)',
        'pstkq': 'Preferred/Preference Stock (Capital) - Total (Quarterly)',
        'ceqq': 'Common/Ordinary Equity - Total (Quarterly)',
        'cogsq': 'Cost of Goods Sold (Quarterly)',
        'tieq': 'Total Interest Expense (Quarterly)',
        'xsgaq': 'Selling, General and Administrative Expenses (Quarterly)',
        'revtq': 'Revenue - Total (Quarterly)',
        'oiadpq': 'Operating Income After Depreciation (Quarterly)',
        'ltq': 'Liabilities - Total (Quarterly)',
        'txditcq': 'Deferred Taxes and Investment Tax Credit (Quarterly)',
        'wcapchy': 'Change in Working Capital (Quarterly)',
        'capxy': 'Capital Expenditures (Quarterly)',
        'cshoq': 'Common Shares Outstanding (Quarterly)',
        'ajexq': 'Adjustment Factor (Quarterly)',
        'xrdq': 'R&D Expense (Quarterly)'
    }
    # start time 1990-01-01
    train_beg_date=config["dataset_config"]["train_beg_date"],
    train_end_date=config["dataset_config"]["train_end_date"],
    val_beg_date=config["dataset_config"]["val_beg_date"],
    val_end_date=config["dataset_config"]["val_end_date"],
    test_beg_date=config["dataset_config"]["test_beg_date"],
    test_end_date=config["dataset_config"]["test_end_date"],
    n_assets = config["stocks_sample"]
    start_time = train_beg_date
    end_time = test_end_date
    raw = eq_env.load_raw_full()
    permnos = raw["permnos"]
    
    data = raw["daily_crsp_tensor"].astype(np.float32)  # shape: (T, N, 2)
    
    daily_dates = raw["daily_dates"] # Only use first 2000 stocks 
    
    # last observed date: 2023-03
    start_idx = np.where(daily_dates >= start_time)[0][0]
    train_end_idx_original = np.where(daily_dates <= train_end_date)[0][-1]
    val_end_idx_original = np.where(daily_dates <= val_end_date)[0][-1]
    end_idx = np.where(daily_dates <= end_time)[0][-1]
    data = data[start_idx:end_idx+1]
    num_days = end_idx - start_idx + 1
    train_end_idx = train_end_idx_original - start_idx
    val_end_idx = val_end_idx_original - start_idx
    
    # Filter out stocks with any NaNs along time and feature dimensions.
    # Create combined filter for both NaN and volume requirements
    nan_filter = ~np.isnan(data).any(axis=(0, 2))
    print("Valid stocks (no NaNs):", nan_filter.sum())
    # Use Mcap instead of volume
    # Get market cap at the end of training period instead of the beginning
    train_end = config["dataset_config"]["train_end_date"]  # e.g. 20201231
    train_end_year = int(str(train_end)[:4])              # 2020
    train_end_month = int(str(train_end)[4:6])            # month from train_end

    # Calculate the last quarter end based on train_end_date
    # Map months 1-3 to Q1 (0331), 4-6 to Q2 (0630), 7-9 to Q3 (0930), 10-12 to Q4 (1231)
    if train_end_month <= 3:
        last_quarter_end = train_end_year * 10000 + 331      # Q1
    elif train_end_month <= 6:
        last_quarter_end = train_end_year * 10000 + 630      # Q2
    elif train_end_month <= 9:
        last_quarter_end = train_end_year * 10000 + 930      # Q3
    else:
        last_quarter_end = train_end_year * 10000 + 1231     # Q4
    
    candidate_idx = np.where(raw["dates"] <= last_quarter_end)[0][-1]
    var_names = raw['accounting_vars']
    var_dict = {name: idx for idx, name in enumerate(var_names)}
    cshoq      = raw["compustat_tensor"][:,:,var_dict['cshoq']]
    market_prices = data[..., 0]
    mcap_first_qtr_valid_only = market_prices[train_end_idx,nan_filter]*cshoq[candidate_idx,nan_filter]
    # mcap_first_qtr_valid_only = market_prices[0, nan_filter] * cshoq[candidate_idx, nan_filter]
    # valid_stocks is a (N,) boolean array
    # nan_filter is also a (N,) boolean array
    valid_stocks = nan_filter.copy()
    # ---------------------------------------------------
    # 1) Extract the True indices from nan_filter
    # ---------------------------------------------------
    candidate_indices = np.where(nan_filter)[0]
    assert mcap_first_qtr_valid_only.shape[0] == candidate_indices.shape[0]
    # ---------------------------------------------------
    # 2) Filter out any NaNs
    # ---------------------------------------------------
    valid_mcap_mask = ~np.isnan(mcap_first_qtr_valid_only)
    mcap_filtered = mcap_first_qtr_valid_only[valid_mcap_mask]

    # These are the original stock indices (in [0..N)) that remain after NaN removal
    candidate_indices_filtered = candidate_indices[valid_mcap_mask]
    #n_assets = 500
    # ---------------------------------------------------
    # 3) Sort & pick top 550
    # ---------------------------------------------------
    # Double check that this filtering is correct TODO april 9
    n_stocks_after_nan = mcap_filtered.shape[0]
    assert n_stocks_after_nan > n_assets
    # Sort ascending
    sorted_indices = np.argsort(mcap_filtered)  
    # Pick the last 550 (largest)
    top_assets_indices = sorted_indices[-n_assets:]     

    # The actual stock indices for these top 550 
    selected_stock_indices = candidate_indices_filtered[top_assets_indices]

    # ---------------------------------------------------
    # 4) Build final boolean mask in the full stock universe
    # ---------------------------------------------------
    big_mcap_filter = np.zeros_like(valid_stocks, dtype=bool)  # shape (N,)
    big_mcap_filter[selected_stock_indices] = True

    # ---------------------------------------------------
    # 5) Print median market cap of selected stocks
    # ---------------------------------------------------
    # Get the market cap values for the selected stocks
    all_mcaps = mcap_filtered
    selected_mcaps = mcap_filtered[top_assets_indices]
    all_mcaps = all_mcaps/1000
    selected_mcaps = selected_mcaps/1000
    median_mcap = np.median(selected_mcaps)
    min_mcap = np.min(selected_mcaps)
    max_mcap = np.max(selected_mcaps)
    
    print(f"Selected stocks market cap statistics:")
    print(f"  Median (billion USD): {median_mcap:.2f}")
    print(f"  Min (billion USD): {min_mcap:.2f}")
    print(f"  Max (billion USD): {max_mcap:.2f}")
    print(f"  Total number of stocks selected: {len(selected_mcaps)}")
    print(f"  Total number of stocks in all: {len(all_mcaps)}")
    print(f"Median market cap of all stocks (billion USD): {np.median(all_mcaps):.2f}")
    
    # -----------------x----------------------------------
    # 5) Combine with valid_stocks
    # ---------------------------------------------------
    valid_stocks = valid_stocks & big_mcap_filter

    print(f"After NaN+marketcap filter: using {valid_stocks.sum()} stocks (top 550 by Q1 mcap).")
    valid_permnos = permnos[valid_stocks]
    num_assets = sum(valid_stocks)
    # Apply combined filter
    use_cross_sectional_rank = True
    data = data[:, valid_stocks]
    features = []
    T, N, _ = data.shape
    data[...,0] = np.abs(data[...,0])
    price_data = data[..., 0]
    asset_return = data[..., 3]
    momentum_windows = {
        "return_1week": (7, 0),
        "30day": (30,0),
        "100day": (100,0),
        "252day": (252,0),
        "ST_Rev":  (21,0), #(30,0), #(21,  0),    # prior 1 month
        "r2_1":    (42,21), #(42, 21),    # 2->1 months
        "r12_2":   (252, 42),   # 12->2 months
        "r12_7":   (252, 147),  # 12->7 months
        "r36_13":  (756, 273),  # 36->13 months
        "LT_Rev":  (1260, 273)  # 60->13 months
    }

    monthly_dates = raw["dates"]       # array of monthly dates
    daily_dates   = raw["daily_dates"] # array of daily dates

    # Create a mapping from daily index to monthly index
    daily_to_monthly_idx = {}
    target_daily_dates = daily_dates[start_idx:end_idx+1]
    for i, d_date in enumerate(target_daily_dates):
        monthly_idx = None
        # Find the most recent monthly date <= current daily date
        for m_idx, m_date in enumerate(monthly_dates):
            if m_date <= d_date:
                monthly_idx = m_idx
            else:
                break
        daily_to_monthly_idx[i] = monthly_idx


    for feature in feature_set:
        # If it's one of the momentum/reversal features in our dictionary
        if feature in momentum_windows:
            (window_hi, window_lo) = momentum_windows[feature]
            feat, median_feat = make_return_feature(
                asset_return=asset_return,
                alpha_price=alpha_price,
                train_end_idx=train_end_idx,
                window_hi=window_hi,
                window_lo=window_lo,
                feature_name=feature,
                use_cross_sectional_rank=use_cross_sectional_rank
            )
            assert feat.shape == asset_return.shape
            assert median_feat.shape == asset_return.shape
            features.append(feat)
            features.append(median_feat)
        elif feature == "suv":
            # Load SUV data based on validation year
            suv_path = f"{BASE_PATH}/data/equities/suv_{str(config['dataset_config']['val_beg_date'])[:4]}.npy"
            median_path = str(suv_path).replace(".npy", "_median.npy")
            print(f"Loading SUV data from: {suv_path}")
            suv = np.load(suv_path)
            median_suv = np.load(median_path)
            median_suv = np.tile(median_suv, (1, suv.shape[1]))
            # extend to same shape as suv
            # Apply the same filtering and slicing as other data # REMOVED: suv = suv[start_idx:end_idx+1, valid_stocks]
            assert suv.shape == asset_return.shape, f"SUV shape {suv.shape} doesn't match asset_return shape {asset_return.shape}"
            features.append(suv)
            features.append(median_suv)
            

        elif feature == "beta":
            # Load Beta data based on validation year
            beta_path = f"{BASE_PATH}/data/equities/beta_{str(config['dataset_config']['val_beg_date'])[:4]}.npy"
            print(f"Loading Beta data from: {beta_path}")
            median_path = str(beta_path).replace(".npy", "_median.npy")
            median_beta = np.load(median_path)
            
            beta = np.load(beta_path)
            median_beta = np.tile(median_beta, (1, beta.shape[1]))
            
            # Apply the same filtering and slicing as other data # REMOVED: beta = beta[start_idx:end_idx+1, valid_stocks]
            assert beta.shape == asset_return.shape, f"Beta shape {beta.shape} doesn't match asset_return shape {asset_return.shape}"
            features.append(beta)
            features.append(median_beta)
            
        elif feature == "residual_volatility":
            residual_vol_path = f"{BASE_PATH}/data/equities/residual_vol_{str(config['dataset_config']['val_beg_date'])[:4]}.npy"
            print(f"Loading Residual Volatility data from: {residual_vol_path}")
            residual_vol = np.load(residual_vol_path)
            median_path = str(residual_vol_path).replace(".npy", "_median.npy")
            median_residual_vol = np.load(median_path)
            median_residual_vol = np.tile(median_residual_vol, (1, residual_vol.shape[1]))
            assert residual_vol.shape == asset_return.shape, f"Residual Volatility shape {residual_vol.shape} doesn't match asset_return shape {asset_return.shape}"
            assert median_residual_vol.shape == asset_return.shape, f"Median Residual Volatility shape {median_residual_vol.shape} doesn't match asset_return shape {asset_return.shape}"
            
            features.append(residual_vol)
            features.append(median_residual_vol)
        
        elif feature == "spread":
            asklo = data[..., 1]
            bidhi = data[..., 2]
            spread = (bidhi - asklo) / (bidhi + asklo)   # (T, N)
            # compute rolling 30 day mean spread
            spread_30d = np.zeros_like(spread)
            for t in range(T):
                start_idx_spread = max(0, t-30)
                end_idx_spread = t
                spread_window = spread[start_idx_spread:end_idx_spread]
                if spread_window.size > 0:
                    mean_spread = np.mean(spread_window)
                    spread_30d[t] = mean_spread
                else:
                    spread_30d[t] = 0.0
            # rank transform
            spread_30d_ranked, spread_30d_median = transform_raw_feature(
                    monthly_data=spread_30d,
                    monthly_train_start=0,
                    monthly_train_end=train_end_idx,
                    daily_to_monthly_idx=daily_to_monthly_idx,
                    num_days=num_days,
                    num_assets=num_assets,
                    feature_name="spread",
                    normalize=True,
                    transform='delta_log',
                    cross_sectional_rank=use_cross_sectional_rank,
                    daily_data=True
                )
            features.append(spread_30d_ranked)
            features.append(spread_30d_median)
            assert spread_30d_ranked.shape == asset_return.shape
            

        elif feature == "debug":
            feat = np.ones_like(asset_return)
            feat2 = np.ones_like(asset_return)
            features.append(feat)
            features.append(feat2)

        elif feature == "Rel2High":
            window_hi = 252  # 1 year
            rel2high_raw = np.zeros_like(price_data, dtype=np.float32)
            T, N = price_data.shape
            for t_ in range(T):
                start_idx_rel2high = max(0, t_ - window_hi)
                max_p = np.max(price_data[start_idx_rel2high : t_+1], axis=0)
                ratio = np.where(max_p > 0.0, price_data[t_, :] / max_p, 0.0)
                rel2high_raw[t_] = ratio

            rel2high_norm, rel2high_median =transform_raw_feature(
                    monthly_data=rel2high_raw,
                    monthly_train_start=0,
                    monthly_train_end=train_end_idx,
                    daily_to_monthly_idx=daily_to_monthly_idx,
                    num_days=num_days,
                    num_assets=num_assets,
                    feature_name="rel2high",
                    normalize=True,
                    transform='delta_log',
                    cross_sectional_rank=use_cross_sectional_rank,
                    daily_data=True
                )
            assert rel2high_norm.shape == asset_return.shape
            features.append(rel2high_norm)
            features.append(rel2high_median)
            
        if feature == "return":
            asset_return = data[...,3]
            norm_return, norm_median = transform_raw_feature(
                    monthly_data=asset_return,
                    monthly_train_start=0,
                    monthly_train_end=train_end_idx,
                    daily_to_monthly_idx=daily_to_monthly_idx,
                    num_days=num_days,
                    num_assets=num_assets,
                    feature_name="return",
                    normalize=True,
                    transform='delta',
                    cross_sectional_rank=use_cross_sectional_rank,
                    daily_data=True
                )
            features.append(norm_return)
            features.append(norm_median)
        elif feature == "volume":
            volume = data[..., 4]
            
            norm_volume, norm_median_volume = transform_raw_feature(
                    monthly_data=volume,
                    monthly_train_start=0,
                    monthly_train_end=train_end_idx,
                    daily_to_monthly_idx=daily_to_monthly_idx,
                    num_days=num_days,
                    num_assets=num_assets,
                    feature_name="volume",
                    normalize=True,
                    transform='delta_delta_log',
                    cross_sectional_rank=use_cross_sectional_rank,
                    daily_data=True
                )
            features.append(norm_volume)
            features.append(norm_median_volume)
        elif feature == "day_of_week":
            # Convert dates to day of week (0=Monday, 4=Friday)
            import pandas as pd
            filtered_dates = daily_dates[start_idx:end_idx+1]
            # Convert integer dates to datetime objects
            datetime_dates = pd.to_datetime(filtered_dates, format='%Y%m%d')
            # Get day of week (0=Monday, 6=Sunday)
            day_of_week = datetime_dates.dayofweek.values
            
            # Create Monday indicator (1 if Monday, 0 otherwise)
            is_monday = (day_of_week == 0).astype(np.float32)
            # Create Friday indicator (1 if Friday, 0 otherwise)
            is_friday = (day_of_week == 4).astype(np.float32)
            
            # Expand to match stock dimension
            monday_feature = np.tile(is_monday[:, np.newaxis], (1, N))
            friday_feature = np.tile(is_friday[:, np.newaxis], (1, N))
            
            # Add both features
            features.append(monday_feature)
            features.append(friday_feature)
            print("Added day of week features (Monday and Friday indicators)")
        elif feature == "risk_free_rate":
            rf_rate = raw["rf_daily"][:,0]
            # backfill Nans by propagating last valid observation forward
            import pandas as pd
            rf_pandas = pd.Series(rf_rate)
            rf_pandas = rf_pandas.fillna(method="ffill")
            rf_rate = rf_pandas.values
            norm_rf_rate = (rf_rate - np.mean(rf_rate[start_idx:train_end_idx_original])) / np.std(rf_rate[start_idx:train_end_idx_original])
            # extend to have second dimension equal to data[1]
            norm_rf_rate = np.tile(norm_rf_rate[:, np.newaxis], (1, N))
            # only use valid dates
            norm_rf_rate = norm_rf_rate[start_idx:end_idx+1]
            
            features.append(norm_rf_rate)
        elif feature == "trailing_volatility":
            # Calculate trailing volatility for each stock using multiple window sizes
            window_sizes = [7, 50]  # Multiple rolling windows (in days)
            asset_return = data[..., 3]  # Get return data
            
            for window_size in window_sizes:
                # Initialize volatility array (same shape as asset_return)
                volatility = np.zeros_like(asset_return)
                
                # Vectorized implementation of rolling volatility calculation
                # For time points with enough history, use full window
                for t in range(window_size, T):
                    volatility[t] = np.std(asset_return[t-window_size:t], axis=0)
                
                # For earlier time points, use all available data
                for t in range(1, window_size):
                    volatility[t] = np.std(asset_return[:t], axis=0)
                
                # For t=0, use volatility from t=1 (can't compute with no data)
                volatility[0] = volatility[1]

                # Clip values at 99.99 percentile
                percentile_99 = np.percentile(volatility, 99.99)
                volatility = np.clip(volatility, 0, percentile_99)

                # instead use the transform_raw_feature function
                norm_volatility, norm_median_volatility = transform_raw_feature(
                    monthly_data=volatility,
                    monthly_train_start=window_size,
                    monthly_train_end=train_end_idx,
                    daily_to_monthly_idx=daily_to_monthly_idx,
                    num_days=num_days,
                    num_assets=num_assets,
                    feature_name="volatility",
                    normalize=True,
                    transform='delta_log',
                    cross_sectional_rank=use_cross_sectional_rank,
                    daily_data=True
                )

                features.append(norm_volatility)
                features.append(norm_median_volatility)
                print("median volatility:", norm_median_volatility.min(), norm_median_volatility.max())
                print("vol min:", volatility.min(), "vol max:", volatility.max())
                print("norm vol min:", norm_volatility.min(), "normvol max:", norm_volatility.max())
                
        
        elif feature == "yearly_accounting_vars":
            """
            Fama-French 4 Factor Model Definitions:

            2. Size Factor (SMB)
            3. Value Factor (HML)
            4. Profitability Factor (RMW)
            5. Investment Factor (CMA)
            """

            feature_mapping = {
                # Trading Frictions
                "at": {"transform": "delta_delta_log"}, # Total assets
                "size": {"transform": "delta_delta_log"}, # Size
                "lturnover": {"transform": "delta_delta_log"}, # Turnover
                # Intangibles
                "oa": {'transform': 'delta'},
                "ol": {'transform': 'delta_log'},
                "pcm": {'transform': 'delta_log'},

                # Profitability
                "prof": {'transform': 'delta_log'},
                "cto": {'transform': 'delta_log'},
                "fc2y": {'transform': 'delta_log'},
                "op": {'transform': 'delta_log'},
                "pm": {'transform': 'delta_log'},
                "d2a": {"transform": "delta_log"},
                "rna": {"transform": "delta_log"},

                # Investment
                "investment": {"transform": "delta_log"}, 
                "noa": {'transform': 'delta_log'},
                "dpi2a": {'transform': 'delta_log'},

                #Value
                "a2me": {'transform': 'delta_log'},
                "c": {'transform': 'delta_log'},
                "cf": {'transform': 'delta'},
                "cf2p": {'transform': 'delta_log'},
                "lev": {"transform": "delta_log"},
                "q": {'transform': 'delta_log'},
                "e2p": {'transform': 'delta_log'},
                "beme": {"transform": "delta_log"},  
            }
            # Things to add
            # AC
            # ATO
            # D2P   ADD
            # E2P
            # Investment
            # NI
            # RNA    ADD
            # ROA    ADD
            # ROE    ADD
            # S2P
            # SGA2S   ADD
            # Using monthly accounting variables instead of yearly
            var_names = raw["accounting_vars"]
            var_tensor = raw["compustat_tensor"]  # shape: (monthly_dates, all_stocks, accounting_vars)

            # Restrict to valid stocks
            var_tensor = var_tensor[:, valid_stocks, :]  # (monthly_dates, sum(valid_stocks), accounting_vars)
            print("All accounting variables:", var_names)
            var_dict = {name: idx for idx, name in enumerate(var_names)}

            ids = raw["permnos"][valid_stocks]
            
            #needed_vars = ['atq', 'seqq', 'oiadpq', 'cshoq', 'revtq', 'cogsq', 'dlttq', 'dpq', 'dlcq']
            needed_vars = [
                # Already existing from before:
                'atq','seqq','cshoq','revtq','cogsq','dlttq','dpq','dlcq','cheq',
                'niq','wcapchy','capxy','ibq','txdbq','saleq','oiadpq','xsgaq',
                'ppegtq','invtq','ceqq','pstkq','mibq','ivaoq', 'tieq'
                # For Q, we need 'txdbq' (deferred taxes), or if you store it in 'txditcq', etc.
                # If you do not have 'txdbq' but 'txditcq', adjust accordingly.

                # *We skip anything specific to NI (Net Share Issues).
            ]
            for var_name in needed_vars:
                assert var_name in var_dict, f"Required accounting variable '{var_name}' not found."

            # Dictionary to store cleaned variables for factor calculation
            processed_vars = {}

            # Process each accounting variable
            for _, var_name in enumerate(needed_vars): # Is this processing actually correct? No this needs to be fixed
                #var_names
                i = var_dict[var_name]
                var_data = var_tensor[:, :, i].copy()  # shape: (monthly_dates, #stocks)

                # Forward/backward fill NaNs for each stock
                for stock_idx in range(var_data.shape[1]):
                    stock_series = var_data[:, stock_idx]
                    # Forward fill
                    last_valid_value = None
                    for t in range(stock_series.shape[0]):
                        if not np.isnan(stock_series[t]):
                            last_valid_value = stock_series[t]
                        elif last_valid_value is not None:
                            stock_series[t] = last_valid_value
                    # Backward fill for initial NaNs
                    first_valid_idx = np.where(~np.isnan(stock_series))[0]
                    if len(first_valid_idx) > 0:
                        first_valid_value = stock_series[first_valid_idx[0]]
                        for t in range(first_valid_idx[0]):
                            stock_series[t] = first_valid_value

                # Handle any series that remain all NaN
                if np.isnan(var_data).any():
                    global_median = np.nanmedian(var_data)
                    var_data[np.isnan(var_data)] = global_median
                    #print("Count of NaN:", np.sum(np.isnan(var_data)))

                # If this var is needed for Fama-French factors, store the cleaned version
                if var_name in needed_vars:
                    processed_vars[var_name] = var_data

                # (Optional) you can still do a quick local normalization if you like:
                monthly_train_end = 0
                monthly_train_start = 0
                for m_date in monthly_dates:
                    if m_date <= daily_dates[start_idx]:
                        monthly_train_start += 1
                    if m_date <= daily_dates[train_end_idx_original]:
                        monthly_train_end += 1
                    else:
                        break

            #num_days   = end_idx - start_idx + 1
            num_assets = sum(valid_stocks)
            # Common helper variables
            total_assets = processed_vars['atq']
            atq  = processed_vars['atq']
            dlttq = processed_vars['dlttq']
            dlcq  = processed_vars['dlcq']
            dpq   = processed_vars['dpq']
            book_value = processed_vars['seqq']
            cshoq      = processed_vars['cshoq']
            cheq = processed_vars['cheq']  # shape (num_months, num_assets)
            niq     = processed_vars['niq']     # net income
            dpq     = processed_vars['dpq']     # depreciation
            wcapchy = processed_vars['wcapchy'] # change in working capital
            delta_ncw = processed_vars['wcapchy'] 
            capxy   = processed_vars['capxy']   # capital expenditures
            seqq    = processed_vars['seqq']
            ibq   = processed_vars['ibq']   # income before extraord. items
            txdbq = processed_vars['txdbq'] # deferred taxes
            saleq = processed_vars['saleq']
            ppegtq = processed_vars['ppegtq']
            invtq  = processed_vars['invtq']
            xsgaq = processed_vars['xsgaq']
            market_prices = data[..., 0]
            # monthly -> daily alignment
            # build monthly market_cap
            MAX_CAP = 5_000_000
            over_cap_count = 0
            market_cap = np.zeros_like(book_value)
            num_m = book_value.shape[0]
            
            for t_ in range(num_m):
                c_idx = 0
                for d_idx, d_date in enumerate(target_daily_dates):
                    if d_date <= monthly_dates[t_]:
                        c_idx = d_idx
                    else:
                        break
                market_cap[t_, :] = market_prices[c_idx,:]*cshoq[t_,:]
                # Clip the market cap
                if market_cap[t_, :].max() > MAX_CAP:
                    over_cap_count += 1
                market_cap[t_, :] = np.minimum(market_cap[t_, :], MAX_CAP)
            
            daily_vol = data[..., 4]
            
            num_months = monthly_dates.shape[0]

            monthly_vol = np.zeros((num_months, num_assets), dtype=np.float32)
            for day_i in range(num_days):
                m_idx = daily_to_monthly_idx[day_i]
                if m_idx is not None:
                    monthly_vol[m_idx,:] += daily_vol[day_i,:]
            atq_lag = np.zeros_like(atq)
            atq_lag[12:,:] = atq[:-12,:]
            atq_lag[:12,:]  = atq[:12,:]  # or 1? Up to you
            # end
        
            for key in feature_mapping.keys():
                if key == "at":
                    raw_feature = total_assets
                elif key == "beme":
                    #print(f"Over cap count: {over_cap_count}")
                    
                    # 1) Create 'book_to_market' and 'value_missing_ind' arrays
                    market_to_book = np.zeros_like(book_value)
                    value_missing_ind = np.zeros_like(book_value)

                    # 2) For each [t_, j], if either < 1 => ratio = 1, missing_ind = 1
                    for t_ in range(num_m):
                        for j in range(book_value.shape[1]):
                            if book_value[t_, j] < 1.0 or market_cap[t_, j] < 1.0 or book_value[t_, j] > market_cap[t_, j]*10 or market_cap[t_, j] > book_value[t_, j]*150:
                                market_to_book[t_, j] = 10
                                value_missing_ind[t_, j] = 1.0
                            else:
                                market_to_book[t_, j] =  market_cap[t_, j] / book_value[t_, j]
                    market_to_book = np.clip(market_to_book, 0.3, 100)
                    raw_feature = market_to_book
                elif key == "rna":
                    oiadpq = processed_vars['oiadpq']
                    operating_prof = oiadpq / np.maximum(total_assets, 10)
                    raw_feature = operating_prof
                elif key == "investment":
                    # This would make more sense to have a year long lookback
                    asset_growth = np.zeros_like(total_assets)
                    for t_ in range(12, total_assets.shape[0]):
                        prev_ = total_assets[t_-12,:]
                        curr_ = total_assets[t_,:]
                        asset_growth[t_,:] = (curr_ - prev_)/np.maximum((prev_ +curr_)/2, 10)
                    raw_feature = asset_growth
                elif key == "size":
                    raw_feature = market_cap
                elif key == 'lev':
                    lev_raw = (dlttq + dlcq)/np.maximum(dlttq+dlcq+seqq,10)
                    # Clip to 0
                    lev_raw[lev_raw < 0] = 0
                    raw_feature = lev_raw
                elif key == "d2a":
                    d2a_raw = dpq/np.maximum(atq,10)
                    d2a_raw = np.clip(d2a_raw, 0, 0.3)
                    raw_feature = d2a_raw
                elif key == "lturnover":
                    lturnover_raw = monthly_vol/np.maximum(cshoq,10)
                    
                    # clip at 1 and 99 percentile
                    lturnover_raw = np.clip(lturnover_raw, np.percentile(lturnover_raw, 2), np.percentile(lturnover_raw, 98))
                    raw_feature = lturnover_raw
                elif key == "c":
                    c_raw = cheq / np.maximum(atq, 1)
                    # plausible range 0..2
                    c_raw = np.clip(c_raw, 0, 2)
                    raw_feature = c_raw
                elif key == "cf":
                    cf_numer = niq + dpq - wcapchy - capxy
                    cf_denom = np.maximum(seqq, 1)
                    cf_raw   = cf_numer / cf_denom
                    # plausible range -2..2
                    cf_raw = np.clip(cf_raw, -2, 2)
                    raw_feature = cf_raw
                elif key == "cf2p":
                    cf2p_num = ibq + dpq + txdbq
                    cf2p_raw = cf2p_num / np.maximum(market_cap, 1)
                    # plausible range -5..5
                    cf2p_raw = np.clip(cf2p_raw, -5, 5)
                    raw_feature = cf2p_raw
                elif key == "cto":
                    

                    cto_raw = saleq / np.maximum(atq, 1) # atq_lag
                    # plausible range 0..10
                    cto_raw = np.clip(cto_raw, 0, 10)
                    raw_feature = cto_raw
                elif key == "a2me":
                    # A2ME raw ratio
                    a2me_raw = processed_vars['atq'] / np.maximum(market_cap, 1.0)  # avoid dividing by <1
                    # clamp to [0..20]
                    a2me_raw = np.clip(a2me_raw, 0, 20)
                    raw_feature = a2me_raw
                elif key == "dpi2a":
                    ppegt_lag = np.zeros_like(ppegtq)
                    invt_lag  = np.zeros_like(invtq)
                    at_lag    = np.zeros_like(atq)

                    # Shift by 12 months
                    ppegt_lag[12:] = ppegtq[:-12]
                    invt_lag[12:]  = invtq[:-12]
                    at_lag[12:]    = atq[:-12]

                    # For t=0, set to same or a fallback
                    ppegt_lag[:12] = ppegtq[:12]
                    invt_lag[:12]  = invtq[:12]
                    at_lag[:12]    = atq[:12]

                    delta_ppegt = ppegtq - ppegt_lag
                    delta_invt  = invtq  - invt_lag

                    dpi2a_raw = (delta_ppegt + delta_invt) / np.maximum(at_lag, 1.0)
                    # clamp -1..1
                    dpi2a_raw = np.clip(dpi2a_raw, 0, 1)
                    raw_feature = dpi2a_raw
                elif key == "e2p":
                    e2p_raw = processed_vars['ibq'] / np.maximum(market_cap, 1.0)
                    # clamp -2..2
                    e2p_raw = np.clip(e2p_raw, -2, 2)
                    raw_feature = e2p_raw
                elif key == "fc2y":
                    fc_raw = (xsgaq) / np.maximum(saleq, 1.0)
                    
                    # clamp 0..2
                    fc_raw = np.clip(fc_raw, 0, 5)    
                    raw_feature = fc_raw
                elif key == "noa":
                    op_assets = processed_vars['atq'] - processed_vars['cheq'] - processed_vars['ivaoq']
                    op_liab   = (processed_vars['atq'] 
                                - processed_vars['dlcq'] - processed_vars['dlttq']
                                - processed_vars['mibq'] - processed_vars['pstkq'] - processed_vars['ceqq'])

                    noa_raw = (op_assets - op_liab) / np.maximum(atq, 1.0) # atq_lag
                    noa_raw = np.clip(noa_raw, -1, 2)
                    raw_feature = noa_raw
                elif key == "oa":
                    oa_raw = (delta_ncw - dpq) / np.maximum(atq, 1.0)  # atq_lag
                    oa_raw = np.clip(oa_raw, -1, 1)
                    raw_feature = oa_raw
                elif key == "ol":
                    ol_raw = (processed_vars['cogsq'] + processed_vars['xsgaq']) / np.maximum(processed_vars['atq'], 1.0)
                    ol_raw = np.clip(ol_raw, 0, 2)
                    raw_feature = ol_raw
                elif key == "op":
                    op_numer = (processed_vars['revtq'] - processed_vars['cogsq'] 
                    - processed_vars['tieq'] - processed_vars['xsgaq'])
                    op_denom = np.maximum(processed_vars['seqq'], 1.0)
                    op_raw   = op_numer / op_denom
                    op_raw   = np.clip(op_raw, -2, 2)
                    raw_feature = op_raw
                elif key == "pcm":
                    pcm_numer = processed_vars['saleq'] - processed_vars['cogsq']
                    pcm_denom = np.maximum(processed_vars['saleq'], 1.0)
                    pcm_raw   = pcm_numer / pcm_denom
                    pcm_raw   = np.clip(pcm_raw, 0, 1)
                    raw_feature = pcm_raw
                elif key == "pm":
                    pm_raw = processed_vars['oiadpq'] / np.maximum(processed_vars['saleq'], 1.0)
                    pm_raw = np.clip(pm_raw, -1, 1)
                    raw_feature = pm_raw
                elif key == "prof":
                    gross_profit = processed_vars['saleq'] - processed_vars['cogsq']
                    prof_raw     = gross_profit / np.maximum(processed_vars['seqq'], 1.0)
                    prof_raw     = np.clip(prof_raw, -2, 2)
                    raw_feature = prof_raw
                elif key == "q":
                    q_numer = (processed_vars['atq'] 
                            + market_cap  # from daily price * cshoq
                            - processed_vars['cheq']
                            - processed_vars['txdbq'])
                    q_denom = np.maximum(processed_vars['atq'], 1.0)
                    q_raw   = q_numer / q_denom
                    q_raw   = np.clip(q_raw, 0, 10)
                    raw_feature = q_raw
                else:
                    raise ValueError(f"Unknown feature: {key}")
                    
                processed_feature, processed_feature_median = transform_raw_feature(
                    monthly_data=raw_feature,
                    monthly_train_start=monthly_train_start,
                    monthly_train_end=monthly_train_end,
                    daily_to_monthly_idx=daily_to_monthly_idx,
                    num_days=num_days,
                    num_assets=num_assets,
                    feature_name=key,
                    normalize = True,
                    transform = feature_mapping[key]['transform'],
                    cross_sectional_rank=use_cross_sectional_rank
                )
                features.append(processed_feature)
                features.append(processed_feature_median)


    print(f"Filtered data: using {N} stocks (of original) with no NaNs.")
    # Stack features.
    
    try:
        X = np.stack(features, axis=-1)  # shape: (T, N, 42)
    except Exception as e:
        for i, feature in enumerate(features):
            print(f"Feature {i} shape: {feature.shape}")
        breakpoint()
    
    # Sanity checks on X
    if np.isnan(X).any():
        breakpoint()
    assert not np.isnan(X).any(), "X contains NaN values!"
    
    # Print min/max for each feature and global min/max
    print("===== FEATURE MIN/MAX VALUES =====")
    global_min, global_max = np.inf, -np.inf
    for i in range(X.shape[2]):
        feat_min = X[:, :, i].min()
        feat_max = X[:, :, i].max()
        global_min = min(global_min, feat_min)
        global_max = max(global_max, feat_max)
        print(f"Feature {i}: min={feat_min:.4f}, max={feat_max:.4f}")
    print(f"Global: min={global_min:.4f}, max={global_max:.4f}")
    print("==================================")
    
    # Check for large values
    if np.any(X > 100):
        print("WARNING: Values > 100 detected in X!")
        breakpoint()
    
    # Check for very negative values
    if np.any(X < -100):
        print("WARNING: Values < -100 detected in X!")
        breakpoint()
    
    
    risk_free_rate = raw["rf_daily"][start_idx:end_idx+1,0]
    # expand to shape (T, N)
    risk_free_rate = np.tile(risk_free_rate, (N, 1)).T
    Y = asset_return - risk_free_rate
    splits = {
        "train": {"X": X[:train_end_idx], "Y": Y[:train_end_idx]},
        "val":   {"X": X[train_end_idx+1:val_end_idx], "Y": Y[train_end_idx+1:val_end_idx]},
        "test":  {"X": X[val_end_idx:], "Y": Y[val_end_idx:]}
    }
    
    train_market_return = Y[val_end_idx:]  # shape (T, N) T = 2011, N= 3373
    #train_market_return = Y[val_end_idx:]
    n_random_start_times = 1
    sharpe_ratios = []
    large_return_vec = []
    for i in range(n_random_start_times):
        random_start_time = np.random.randint(0, train_market_return.shape[0] - 248)
        train_market_return_new = train_market_return[random_start_time:random_start_time+248,:]
        train_market_return_new = train_market_return_new.mean(axis=1)
        large_return_vec.append(train_market_return_new)
        train_market_return_mean = train_market_return_new.mean()
        train_market_return_std = train_market_return_new.std()
        train_market_sharpe = train_market_return_mean / train_market_return_std
        train_market_yearly_sharpe = train_market_sharpe * np.sqrt(252)
        sharpe_ratios.append(train_market_yearly_sharpe)
    mean_sharpe_ratio = np.mean(sharpe_ratios)
    print(f"Train market yearly sharpe (500 samples): {mean_sharpe_ratio}") # Sharpe = 0.98

    #
    #concatenate all the large return vectors
    large_return_vec = np.concatenate(large_return_vec, axis=0)
    # Compute the market sharpe ratio for the entire train set
    train_market_return_mean_full = large_return_vec.mean()
    train_market_return_std_full = large_return_vec.std()
    train_market_sharpe_full = train_market_return_mean_full / train_market_return_std_full
    train_market_yearly_sharpe_full = train_market_sharpe_full * np.sqrt(252)
    print(f"Train market yearly sharpe (500 samples post combined): {train_market_yearly_sharpe_full}") # 0.60
    
    
    # Compute the market sharpe ratio for the entire train set
    train_market_return_mean_computed = train_market_return.mean(axis=1)
    train_market_return_mean_full = train_market_return_mean_computed.mean()
    train_market_return_std_full = train_market_return_mean_computed.std()
    train_market_sharpe_full = train_market_return_mean_full / train_market_return_std_full
    train_market_yearly_sharpe_full = train_market_sharpe_full * np.sqrt(252)
    print(f"Train market yearly sharpe: {train_market_yearly_sharpe_full}") # Sharpe = 0.25
    print("Yearly mean: ", train_market_return_mean_full*252)
    print("Yearly std: ", train_market_return_std_full*np.sqrt(252))

    dataset_stats = {
        "median_mcap_billions": median_mcap,
        "min_mcap_billions": min_mcap,
        "max_mcap_billions": max_mcap,
        "valid_permnos": valid_permnos,
        "market_caps": selected_mcaps
    }
    return splits, dataset_stats

def save_dataset(dataset_dict, stats, save_path):
    """Serialize dataset splits and statistics into a single ``.npz`` archive.

    Both objects are pickled to byte strings and stored as the archive entries
    ``"splits"`` and ``"stats"``; ``load_saved_dataset`` reverses this.

    Args:
        dataset_dict: Split dictionary (e.g. ``{"train": {"X": ..., "Y": ...}, ...}``).
        stats: Arbitrary picklable statistics object.
        save_path: Destination path (str / ``os.PathLike``) or a writable file
            object accepted by ``np.savez``.

    Returns:
        The path that was actually written. ``np.savez`` appends ``.npz`` to
        string/path destinations that lack it, so this may differ from
        ``save_path``. If a file object was passed, it is returned unchanged.
    """
    np.savez(save_path, splits=pickle.dumps(dataset_dict), stats=pickle.dumps(stats))

    written_path = save_path
    if isinstance(save_path, (str, os.PathLike)):
        candidate = os.fspath(save_path)
        # Mirror np.savez's own rule so the reported path matches the file on disk.
        if isinstance(candidate, str):
            written_path = candidate if candidate.endswith(".npz") else candidate + ".npz"
    print(f"Dataset saved to {written_path}")
    return written_path

def load_saved_dataset(save_path):
    """Restore the ``(splits, stats)`` pair stored by ``save_dataset``.

    The archive holds two entries, ``"splits"`` and ``"stats"``, each a pickled
    byte string; every entry is extracted with ``.item()`` and unpickled.

    Args:
        save_path: Path of the ``.npz`` archive to read.

    Returns:
        Tuple ``(splits, stats)`` of the unpickled objects.
    """
    # NOTE(review): pickle.loads can execute arbitrary code -- only load
    # archives that come from a trusted source.
    with np.load(save_path, allow_pickle=True) as archive:
        splits, stats = (pickle.loads(archive[entry].item()) for entry in ("splits", "stats"))
    print(f"Dataset loaded from {save_path}")
    return splits, stats

# ==================================
# 2. PyTorch Dataset for Equity Data
# ==================================
class EquityData(Dataset):
    """Fixed-length windows over one split of the equity panel.

    Every item is a contiguous window of ``sequence_length`` timesteps that
    covers all stocks of the split. Items of the ``"val"`` and ``"test"`` splits
    are drawn with a seed of ``eval_seed + idx`` (so they are reproducible) and
    leave the global NumPy RNG state untouched; ``"train"`` items draw from the
    global NumPy RNG.
    """

    def __init__(self, splits, config, split="train", sequence_length=50, stocks_sample=20, steps_per_epoch=1000):
        """
        Args:
            splits: dict output from create_equity_dataset(); ``splits[split]``
                must provide ``"X"`` of shape (T_split, N, n_features) and
                ``"Y"`` of shape (T_split, N).
            config: mapping; ``eval_seed`` (default 0) and
                ``median_mcap_billions`` (default 0) are read from it.
            split: one of "train", "val", "test".
            sequence_length: number of timesteps in each window (e.g. 50).
            stocks_sample: stored for reference only; every item currently
                contains all N stocks of the split.
            steps_per_epoch: value reported by ``__len__``.

        Raises:
            AssertionError: if ``split`` is not a key of ``splits``.
            ValueError: if the split has fewer than ``sequence_length`` timesteps.
        """
        assert split in splits, f"Split {split} not found in dataset splits."
        self.X = splits[split]["X"]  # shape: (T_split, N, n_features)
        self.Y = splits[split]["Y"]  # shape: (T_split, N)
        self.split = split
        self.eval_seed = config.get("eval_seed", 0)
        self.median_mcap_billions = config.get("median_mcap_billions", 0)
        self.sequence_length = sequence_length
        self.stocks_sample = stocks_sample
        self.steps_per_epoch = steps_per_epoch
        self.T, self.N, self.n_features = self.X.shape
        # Largest valid (inclusive) window start index.
        self.max_start = self.T - self.sequence_length

        if self.max_start < 0:
            raise ValueError("Sequence length is longer than available timesteps in split.")

    def __len__(self):
        return self.steps_per_epoch

    def __getitem__(self, idx):
        """Return one window as ``(X, Y, dummy)``.

        Returns:
            X: float32 tensor of shape (n_features, N, sequence_length).
            Y: float32 tensor of shape (N, sequence_length, 1).
            dummy: float32 scalar tensor equal to 0.
        """
        deterministic = self.split in ("val", "test")
        if deterministic:
            # Evaluation windows depend only on (eval_seed, idx); remember the
            # global RNG state so training-time randomness is not disturbed.
            old_state = np.random.get_state()
            np.random.seed(self.eval_seed + idx)
        try:
            # randint's upper bound is exclusive; +1 makes the last window reachable.
            start_idx = np.random.randint(0, self.max_start + 1)
        finally:
            if deterministic:
                np.random.set_state(old_state)

        end_idx = start_idx + self.sequence_length
        X_window = self.X[start_idx:end_idx]  # (sequence_length, N, n_features)
        Y_window = self.Y[start_idx:end_idx]  # (sequence_length, N)

        # torch.tensor copies, so callers never alias the stored split arrays.
        #    X => [nr_features, nr_stocks, nr_timesteps]
        #    Y => [nr_stocks, nr_timesteps, 1]
        X = torch.tensor(X_window, dtype=torch.float32).permute(2, 1, 0)
        Y = torch.tensor(Y_window, dtype=torch.float32).permute(1, 0).unsqueeze(2)
        return X, Y, torch.tensor(0).float()

    def get_linear_regression_data(self, num_samples=1000):
        """Draw flattened (window, next-step target) pairs for a linear baseline.

        Each sample picks a random start and a random stock. The features are
        that stock's ``sequence_length`` timesteps, flattened to length
        ``sequence_length * n_features``. The target is the stock's ``Y`` at the
        timestep immediately after the window.

        Returns:
            Tuple of NumPy arrays ``(X, Y)`` built from the collected samples.
        """
        X_samples = []
        Y_samples = []
        # Largest (inclusive) start that still leaves one timestep for the target.
        max_start = self.T - self.sequence_length - 1
        for _ in range(num_samples):
            # randint's upper bound is exclusive; +1 keeps the final start reachable.
            start_idx = np.random.randint(0, max_start + 1)
            stock_idx = np.random.randint(0, self.N)
            x_sample = self.X[start_idx:start_idx+self.sequence_length, stock_idx, :].flatten()
            y_sample = self.Y[start_idx+self.sequence_length, stock_idx]
            X_samples.append(x_sample)
            Y_samples.append(y_sample)
        return np.array(X_samples), np.array(Y_samples)

# =============================================================================
# 3. Adapted EquityDataset Class
# =============================================================================
class EquityDataset(SequenceDataset):
    """Equity sequence dataset exposing train/val/test ``EquityData`` splits.

    ``setup()`` either loads previously saved splits from ``config["data_path"]``
    (when ``load_data`` is set and that file exists) or builds them from raw
    daily price data via ``EquityEnv`` + ``create_equity_dataset``, optionally
    saving the result.
    """

    _name_ = "equities_dataset"

    def setup(self):
        """Build ``self.config`` and create ``dataset_train``/``dataset_val``/``dataset_test``."""
        # Load configuration.
        self._load_config()
        # NOTE(review): if SequenceDataset defines ``_collate_arg_names`` as a
        # class-level list, this append mutates state shared by every
        # instance; confirm against the base class.
        if not self._collate_arg_names:
            self._collate_arg_names.append("valid_indices")

        # Try to load saved equity dataset splits.
        if self.config.get("load_data", False) and os.path.exists(self.config["data_path"]):
            print(f"Loading dataset from {self.config['data_path']}")
            splits, global_stats = load_saved_dataset(self.config["data_path"])
            print(f"Dataset loaded from {self.config['data_path']}")
        else:
            ds_cfg = self.config["dataset_config"]
            # Create an EquityEnv instance using the date ranges from the config.
            eq_env = EquityEnv(
                daily_price_data_path=ds_cfg["daily_price_data_path"],
                train_beg_date=ds_cfg["train_beg_date"],
                train_end_date=ds_cfg["train_end_date"],
                val_beg_date=ds_cfg["val_beg_date"],
                val_end_date=ds_cfg["val_end_date"],
                test_beg_date=ds_cfg["test_beg_date"],
                test_end_date=ds_cfg["test_end_date"],
            )
            # Create dataset splits and compute global statistics.
            splits, global_stats = create_equity_dataset(
                eq_env,
                self.config,
                feature_set=self.config["feature_set"],
                train_frac=self.config.get("train_frac", 0.7),
                val_frac=self.config.get("val_frac", 0.15),
                test_frac=self.config.get("test_frac", 0.15),
                alpha_price=self.config.get("alpha_price", 100),
                alpha_vol=self.config.get("alpha_vol", 0.5),
            )
            # Save the dataset if configured to do so.
            if self.config.get("save_data", False):
                save_dataset(splits, global_stats, self.config["data_path"])

        self.global_stats = global_stats
        self.config["median_mcap_billions"] = global_stats["median_mcap_billions"]

        sequence_length = self.config.get("sequence_length", 50)
        stocks_sample = self.config.get("stocks_sample", 20)
        total_steps = self.config["steps_per_epoch"]

        def make_split(split, frac_key):
            # Each split gets a share of the per-epoch step budget proportional
            # to its data fraction, floored at one step so no split is empty.
            return EquityData(
                splits,
                self.config,
                split=split,
                sequence_length=sequence_length,
                stocks_sample=stocks_sample,
                steps_per_epoch=max(1, int(total_steps * self.config[frac_key])),
            )

        # Construction order (train, val, test) is kept as-is in case
        # EquityData consumes global RNG state during construction.
        self.dataset_train = make_split("train", "train_frac")
        self.dataset_val = make_split("val", "val_frac")
        self.dataset_test = make_split("test", "test_frac")

    def _load_config(self):
        """Populate ``self.config`` from attributes set on this instance.

        Reads ``dataset_config`` (with attribute-style field access) plus the
        split fractions, ``alpha_price``/``alpha_vol``, ``sequence_length``,
        ``stocks_sample``, ``load_data``/``save_data``, ``data_path``,
        ``steps_per_epoch``, ``num_states`` and ``feature_set``. If any of these
        attributes is missing, a hard-coded default configuration is used.
        """
        try:
            self.config = {
                "dataset_config": {
                    "daily_price_data_path": self.dataset_config.daily_price_data_path,
                    "train_beg_date": self.dataset_config.train_beg_date,
                    "train_end_date": self.dataset_config.train_end_date,
                    "val_beg_date": self.dataset_config.val_beg_date,
                    "val_end_date": self.dataset_config.val_end_date,
                    "test_beg_date": self.dataset_config.test_beg_date,
                    "test_end_date": self.dataset_config.test_end_date,
                },
                "train_frac": self.train_frac,
                "val_frac": self.val_frac,
                "test_frac": self.test_frac,
                "alpha_price": self.alpha_price,
                "alpha_vol": self.alpha_vol,
                "sequence_length": self.sequence_length,
                "stocks_sample": self.stocks_sample,
                "load_data": self.load_data,  # True -> load a pre-saved dataset.
                "save_data": self.save_data,  # True -> save after creating the dataset.
                "data_path": self.data_path,  # Path to save/load the dataset.
                "steps_per_epoch": self.steps_per_epoch,
                "num_states": self.num_states,
                "feature_set": self.feature_set,
            }
        except AttributeError as err:
            # A required attribute was not provided: fall back to defaults.
            # Only AttributeError is caught so genuine errors are not masked,
            # and no debugger is entered so non-interactive runs don't hang.
            print("Using default configuration for EquityDataset.")
            print(f"  (missing attribute: {err})")
            self.config = {
                "_name_": "equities_dataset",
                "dataset_config": {
                    "daily_price_data_path": f"{EQUITIES_DATA_PATH}/daily_price_data_n_equities_50.npz",
                    "train_beg_date": 19900101,
                    "train_end_date": 20201231,
                    "val_beg_date": 20210101,
                    "val_end_date": 20211231,
                    "test_beg_date": 20220101,
                    "test_end_date": 20221231,
                },
                "train_frac": 0.7,
                "val_frac": 0.15,
                "test_frac": 0.15,
                "alpha_price": 100,
                "alpha_vol": 0.5,
                "sequence_length": 50,
                "stocks_sample": 20,
                "load_data": False,
                "save_data": True,
                "data_path": "./equity_dataset.npz",
                "steps_per_epoch": 300,
                "num_states": 0,
                "feature_set": ["return", "volume"],
            }

    def init(self):
        # Optionally implement any additional initialization.
        pass

# =============================================================================
# 4. Main: Instantiate the EquityDataset using a config dict.
# =============================================================================
if __name__ == "__main__":
    from types import SimpleNamespace

    # EquityDataset._load_config reads ``self.dataset_config.<field>`` with
    # attribute access, so the nested section is a SimpleNamespace; a plain
    # dict would raise AttributeError and silently switch to the defaults.
    # (Assumes SequenceDataset stores constructor kwargs as attributes
    # unchanged; confirm against the base class.)
    config = {
        "_name_": "equities_dataset",
        "dataset_config": SimpleNamespace(
            daily_price_data_path=f"{EQUITIES_DATA_PATH}/daily_price_data_n_equities_50.npz",
            train_beg_date=19900101,
            train_end_date=20201231,
            val_beg_date=20210101,
            val_end_date=20211231,
            test_beg_date=20220101,
            test_end_date=20221231,
        ),
        "train_frac": 0.7,
        "val_frac": 0.15,
        "test_frac": 0.15,
        "alpha_price": 100,
        "alpha_vol": 0.5,
        "sequence_length": 50,
        "stocks_sample": 20,
        "load_data": False,
        "save_data": True,
        "data_path": "./equity_dataset.npz",
        "steps_per_epoch": 1000,
        # _load_config also requires these two keys.
        "num_states": 0,
        "feature_set": ["return", "volume"],
    }

    ds = EquityDataset(**config)
    # Build the splits unless the base-class constructor already did.
    if getattr(ds, "dataset_train", None) is None:
        ds.setup()

    # Inspect one training batch. Each item is (X, Y, placeholder); the
    # item getter defined earlier in this file (presumably
    # EquityData.__getitem__) permutes X to (nr_features, N, sequence_length)
    # and Y to (N, sequence_length, 1).
    train_loader = DataLoader(ds.dataset_train, batch_size=16, shuffle=False)
    for X_batch, Y_batch, _ in train_loader:
        print("Batch X shape:", X_batch.shape)  # (batch_size, nr_features, N, sequence_length)
        print("Batch Y shape:", Y_batch.shape)  # (batch_size, N, sequence_length, 1)
        break

