import sys
import os
import numpy as np
from scipy.stats import rankdata

# Make the project root importable so that `src.dataloaders...` (imported at
# the bottom of this section) resolves. BASE_PATH is read from the environment.
# NOTE(review): when BASE_PATH is unset it defaults to "", so an empty string is
# appended to sys.path and DATA_DIR below becomes "/data/equities" (filesystem
# root) -- set BASE_PATH explicitly before running.
BASE_PATH = os.environ.get("BASE_PATH", "")
sys.path.append(BASE_PATH)
from pathlib import Path
import numpy as np
# ---------------------------------------------------------------------------
# Configuration shared by every split processed in main()
# ---------------------------------------------------------------------------
# Output folder for the per-split feature files written by main().
DATA_DIR      = Path(f"{BASE_PATH}/data/equities")
# Import-time side effect: create the output directory if it doesn't exist.
DATA_DIR.mkdir(parents=True, exist_ok=True)
# Directory holding the raw daily price archive (defaults to the current dir).
EQUITIES_DATA_PATH = os.environ.get("EQUITIES_DATA_PATH", ".")
DAILY_PATH    = os.path.join(EQUITIES_DATA_PATH, "daily_price_data_n_equities_4000.npz")
# Number of largest-market-cap stocks kept per split (used as n_assets in main()).
KEEP_N_STOCKS = 500
# Project-local import; must come after the sys.path.append above.
from src.dataloaders.preprocessing_equities import EquityEnv

def compute_suv(daily_return, daily_volume, target_dates, save_path="suv_data.npy"):
    """
    Compute rank-normalized 'Standard Unexplained Volume' (SUV) per (day, stock).

    For every stock and every month M except the first month present in
    ``target_dates``, log(volume + 1) is regressed on the data of the month
    that precedes M in the sample:

        log_vol_t = b0 + b_pos * max(r_t, 0) + b_neg * max(-r_t, 0)

    The fitted coefficients are applied to the days of month M and the
    out-of-sample residual (actual - predicted) is divided by the standard
    deviation of the in-sample residuals of the preceding month.  So for days
    in March the model is fit on February and the March residual is scaled by
    the February residual std.

    Days keep a raw SUV of 0 when no estimate is available: the first month,
    a preceding month with fewer than 5 observations, or a preceding-month
    residual std below 1e-9.

    Post-processing and outputs:
      1. The cross-sectional median of the raw SUV, shape (T, 1), is saved
         next to ``save_path`` with a ``_median`` suffix.
      2. The panel is clipped to its 0.1 / 99.9 percentiles, computed over the
         whole (T, N) panel (all days and stocks at once).
      3. Each day is rank-normalized (ordinal ranks, ties broken by column
         order) to [-1, 1]; rows that are constant or contain NaN become 0.
      4. The rank-normalized panel is saved to ``save_path`` as float32.

    Parameters
    ----------
    daily_return : np.ndarray, shape (T, N)
        Daily returns; must not contain NaNs.
    daily_volume : np.ndarray, shape (T, N)
        Daily volumes; must not contain NaNs.
    target_dates : array-like of int, length T
        Dates in YYYYMMDD format, one per row of the inputs.
    save_path : str or path-like
        Destination of the rank-normalized (T, N) float32 array.

    Returns
    -------
    suv_array : np.ndarray, shape (T, N), float32
        The rank-normalized SUV panel.

    Raises
    ------
    AssertionError
        If ``daily_return`` or ``daily_volume`` contains NaNs.
    """
    T, N = daily_return.shape

    assert not np.isnan(daily_return).any(), "daily_return still contains NaNs"
    assert not np.isnan(daily_volume).any(), "daily_volume still contains NaNs"

    # YYYYMMDD // 100 == YYYYMM.
    yyyymm = (np.asarray(target_dates) // 100).astype(np.int32)

    # One boolean row-mask per month, in chronological order.  Built once
    # instead of once per stock (np.unique already returns sorted values).
    month_masks = [yyyymm == month for month in np.unique(yyyymm)]

    suv_array = np.zeros((T, N), dtype=np.float32)

    # Elementwise log(volume + 1), computed for the whole panel at once.
    log_volume = np.log(daily_volume + 1)

    for stock_idx in range(N):
        print(f"Processing stock {stock_idx} of {N}")
        vol_series = log_volume[:, stock_idx]
        ret_series = daily_return[:, stock_idx]

        # Regressors [1, positive part of return, negative part of return]
        # for every day; month slices are taken from this matrix below.
        design = np.column_stack([
            np.ones_like(vol_series),
            np.where(ret_series > 0, ret_series, 0.0),
            np.where(ret_series < 0, -ret_series, 0.0),
        ])

        # Fit on month i-1, apply to month i.  The first month is never a
        # "current" month, so its rows keep SUV = 0.
        for mask_prev, mask_curr in zip(month_masks[:-1], month_masks[1:]):
            vol_prev = vol_series[mask_prev]
            if len(vol_prev) < 5:
                # Not enough data for a stable 3-parameter regression.
                continue

            X_prev = design[mask_prev]
            beta, *_ = np.linalg.lstsq(X_prev, vol_prev, rcond=None)
            std_resid_prev = (vol_prev - X_prev @ beta).std()
            if std_resid_prev < 1e-9:
                # Cannot standardize by a (near-)zero residual std.
                continue

            # Out-of-sample residual for the current month, standardized by
            # the previous month's in-sample residual std.
            resid_curr = vol_series[mask_curr] - design[mask_curr] @ beta
            suv_array[mask_curr, stock_idx] = resid_curr / std_resid_prev

    # Cross-sectional median of the raw SUV, saved before any normalization.
    median_suv = np.median(suv_array, axis=1, keepdims=True)  # shape (T, 1)
    # Only touch the file extension: str.replace(".npy", ...) would also rewrite
    # directory names containing ".npy" and do nothing when the suffix is missing.
    root, ext = os.path.splitext(str(save_path))
    median_save_path = f"{root}_median{ext or '.npy'}"
    np.save(median_save_path, median_suv.astype(np.float32))
    print(f"Saved median SUV (shape {median_suv.shape}) to {median_save_path}")

    # Clip at the 0.1 and 99.9 percentiles of the whole panel.
    lower = np.percentile(suv_array, 0.1)
    upper = np.percentile(suv_array, 99.9)
    suv_array = np.clip(suv_array, lower, upper)

    # Cross-sectional rank normalization to [-1, 1], one day at a time.
    for t in range(T):
        row = suv_array[t, :]
        if N < 2 or np.isnan(row).any() or np.all(row == row[0]):
            # Constant / NaN / single-stock rows carry no ranking information.
            suv_array[t, :] = 0.0
        else:
            ranks = rankdata(row, method='ordinal')  # 1 .. N
            suv_array[t, :] = 2 * (ranks - 1) / (N - 1) - 1

    np.save(save_path, suv_array.astype(np.float32))
    print(f"Saved rank-normalized SUV (shape {suv_array.shape}) to {save_path}")

    return suv_array  # rank-normalized version

def compute_beta(excess_return, market_return_proxy, target_dates, save_path="beta_data.npy"):
    """
    Compute rank-normalized market betas for every (day, stock).

    Beta is estimated once per calendar month and assigned to every day of
    that month:

    - The estimation window is the (up to) 252 rows immediately preceding the
      first day of the month, so no data from the month itself is used.
    - ``beta = cov(stock, market) / var(market)``, with covariance and
      variance using the same normalization (this is the OLS slope).
    - Months whose first day has fewer than 60 preceding rows get a default
      beta of 1.0 for every stock.

    Outputs:
      1. The cross-sectional median of the raw betas, shape (T, 1), is saved
         next to ``save_path`` with a ``_median`` suffix.
      2. Each day is rank-normalized (ordinal ranks, ties broken by column
         order) to [-1, 1]; rows that are constant or contain NaN become 0.
      3. The rank-normalized panel is saved to ``save_path`` as float32.

    Parameters
    ----------
    excess_return : np.ndarray, shape (T, N)
        Excess returns for each stock (returns minus risk-free rate).
    market_return_proxy : np.ndarray, shape (T,)
        Market return proxy (e.g., from Fama-French factors).
    target_dates : np.ndarray, shape (T,)
        Dates corresponding to each timestep in YYYYMMDD format.
    save_path : str or path-like
        Path to save the (T, N) float32 array of rank-normalized betas.

    Returns
    -------
    beta_array : np.ndarray, shape (T, N), float32
        Rank-normalized beta for each day and stock.

    Raises
    ------
    AssertionError
        If the market variance in an estimation window is not above 1e-8.
    """
    T, N = excess_return.shape

    beta_array = np.zeros((T, N), dtype=np.float32)

    # YYYYMMDD // 100 == YYYYMM.
    yyyymm = (np.asarray(target_dates) // 100).astype(np.int32)

    # np.unique returns the months already sorted chronologically.
    for month in np.unique(yyyymm):
        print(f"Processing beta for month {month}")

        month_mask = (yyyymm == month)
        first_day_idx = int(np.flatnonzero(month_mask)[0])

        # With fewer than 60 preceding rows there is too little history:
        # default beta to 1.
        if first_day_idx < 60:
            beta_array[month_mask, :] = 1.0
            continue

        # Look-back window: up to 252 rows, ending just before the month starts.
        lookback_start = max(0, first_day_idx - 252)
        market = np.asarray(market_return_proxy[lookback_start:first_day_idx],
                            dtype=np.float64)
        stocks = np.asarray(excess_return[lookback_start:first_day_idx, :],
                            dtype=np.float64)

        market_dev = market - market.mean()
        stock_dev = stocks - stocks.mean(axis=0)
        n_obs = market.shape[0]

        # Covariance and variance share the 1/n normalization.  Mixing
        # np.cov (ddof=1) with np.var (ddof=0) would inflate every beta by
        # n / (n - 1).
        var_market = (market_dev @ market_dev) / n_obs
        assert var_market > 1e-8, "market variance too small to estimate beta"
        cov_stock_market = (market_dev @ stock_dev) / n_obs  # shape (N,)

        # Same beta for all days in the current month, all stocks at once.
        beta_array[month_mask, :] = cov_stock_market / var_market

    # Cross-sectional median of the raw betas, saved before normalization.
    median_beta = np.median(beta_array, axis=1, keepdims=True)  # shape (T, 1)
    # Only touch the file extension: str.replace(".npy", ...) would also rewrite
    # directory names containing ".npy" and do nothing when the suffix is missing.
    root, ext = os.path.splitext(str(save_path))
    median_save_path = f"{root}_median{ext or '.npy'}"
    np.save(median_save_path, median_beta.astype(np.float32))
    print(f"Saved median Beta (shape {median_beta.shape}) to {median_save_path}")

    # Cross-sectional rank normalization to [-1, 1], one day at a time.
    for t in range(T):
        row = beta_array[t, :]
        if N < 2 or np.isnan(row).any() or np.all(row == row[0]):
            # Constant / NaN / single-stock rows carry no ranking information.
            beta_array[t, :] = 0.0
        else:
            ranks = rankdata(row, method='ordinal')  # 1 .. N
            beta_array[t, :] = 2 * (ranks - 1) / (N - 1) - 1

    np.save(save_path, beta_array.astype(np.float32))
    print(f"Saved rank-normalized Beta (shape {beta_array.shape}) to {save_path}")

    return beta_array  # rank-normalized version


def compute_residual_volatility(excess_return,
                                ff_3f_daily,
                                target_dates,
                                save_path="residual_vol_data.npy"):
    """
    Idiosyncratic (residual) volatility a la Ang, Hodrick, Xing & Zhang (2006).

    For each stock and each month M (from the third month present onward):
        1.  Use *only* the two months preceding M in the sample (M-2, M-1)
            to fit  r_t = b0 + b_mkt*(MKT-RF)_t + b_smb*SMB_t + b_hml*HML_t .
        2.  Compute the std-dev of those in-sample residuals.
        3.  Assign that single std value to every day in month M.
    The feature at day t therefore never sees data from day t or later.
    Days in the first two months keep a raw value of 0.

    The look-back design matrix is the same for every stock, so all stocks
    are fitted with one multi-right-hand-side least-squares solve per month.

    Outputs:
      * The daily cross-sectional median of the raw panel, multiplied by 10,
        shape (T, 1), is saved next to ``save_path`` with a ``_median``
        suffix.  NOTE(review): the factor 10 is kept from the existing
        pipeline; its rationale is not visible here -- confirm with consumers.
      * The panel is cross-sectionally rank-normalised (ordinal ranks) to
        [-1, 1] on every day; constant rows become 0.  No percentile clipping
        is applied before ranking.
      * The rank-normalised panel is saved to ``save_path`` as float32.

    Parameters
    ----------
    excess_return : np.ndarray, shape (T, N)
        Excess returns for each stock.
    ff_3f_daily : np.ndarray, shape (T, >=3)
        Daily factor returns; only the first three columns are used.
    target_dates : array-like of int, length T
        Dates in YYYYMMDD format.
    save_path : str or path-like
        Destination of the rank-normalised (T, N) float32 array.

    Returns
    -------
    resid_vol : np.ndarray, shape (T, N), float32
        Rank-normalised residual volatility.

    Raises
    ------
    AssertionError
        If factors and returns differ in length, a two-month look-back window
        has 20 or fewer rows, or a NaN appears in the estimates.
    """
    import os
    import numpy as np
    from scipy.stats import rankdata

    T, N = excess_return.shape
    assert ff_3f_daily.shape[0] == T, "Factors and returns length mismatch"
    factors = ff_3f_daily[:, :3]         # use MKT-RF, SMB, HML

    # YYYYMMDD // 100 == YYYYMM; np.unique returns the months sorted.
    yyyymm = (np.asarray(target_dates) // 100).astype(np.int32)
    months = np.unique(yyyymm)

    resid_vol = np.zeros((T, N), dtype=np.float32)

    for i in range(2, len(months)):          # start once we HAVE two look-back months
        lb_mask = (yyyymm == months[i - 1]) | (yyyymm == months[i - 2])
        n_obs = int(lb_mask.sum())
        assert n_obs > 20                    # need more than ~one trading month of data

        # Intercept + three factors; identical for every stock this month.
        X_lb = np.column_stack([np.ones(n_obs), factors[lb_mask]])
        Y_lb = excess_return[lb_mask]        # (n_obs, N): one column per stock

        # One solve for all N right-hand sides.
        coef, *_ = np.linalg.lstsq(X_lb, Y_lb, rcond=None)
        sigma = (Y_lb - X_lb @ coef).std(axis=0)   # (N,)
        assert not np.isnan(sigma).any()

        # assign to current month (out-of-sample); broadcasts over its days
        resid_vol[yyyymm == months[i]] = sigma

    # -------- save pre-normalisation median (scaled by 10) -----------------
    median_daily = 10*np.median(resid_vol, axis=1, keepdims=True).astype(np.float32)

    # Only touch the file extension: str.replace(".npy", ...) would also rewrite
    # directory names containing ".npy" and do nothing when the suffix is missing.
    root, ext = os.path.splitext(str(save_path))
    median_path = f"{root}_median{ext or '.npy'}"
    np.save(median_path, median_daily)
    print(f"Saved median residual‑volatility to {median_path}")

    # -------- cross-sectional rank normalisation to [-1, 1] ----------------
    for t in range(T):
        row = resid_vol[t]
        assert not np.isnan(row).any()
        if N < 2 or np.all(row == row[0]):
            resid_vol[t] = 0.0
        else:
            ranks = rankdata(row, method="ordinal")
            resid_vol[t] = 2 * (ranks - 1) / (N - 1) - 1

    # -------- save & return ----------------------------------------------
    np.save(save_path, resid_vol.astype(np.float32))
    print(f"Saved rank‑normalised residual volatility to {save_path}")

    return resid_vol



def main():
    """
    Build and save SUV, beta and residual-volatility features for every split.

    For each validation year from 1998 to 2021 (inclusive) this:
      1. derives the train / val / test date boundaries and builds an EquityEnv;
      2. loads the raw data and slices it to [train_beg_date, test_end_date];
      3. keeps the KEEP_N_STOCKS largest stocks by market cap among those with
         no NaNs anywhere in the sliced window;
      4. calls compute_suv, compute_beta and compute_residual_volatility, which
         write ``suv_<year>.npy``, ``beta_<year>.npy`` and
         ``residual_vol_<year>.npy`` (plus their ``_median`` companions) into
         DATA_DIR.

    Nothing is returned; the results are the files on disk.
    """
    start_year = 1998
    end_year = 2021
    # Validation years 1998 … 2021  (inclusive)
    for val_year in range(start_year, end_year+1):
        print(f"\n=== Running split with validation year {val_year} ===")

        # ------------------------------------------------------------------
        # 1. Derive train/val/test dates for this split
        #    train: [val_year-8, val_year-1]   (eight calendar years)
        #    val  : [val_year,   val_year]
        #    test : [val_year+1, val_year+1]   (the year after validation)
        # ------------------------------------------------------------------
        # Encode a calendar date as a YYYYMMDD integer.
        def ymd(y, m, d): return y*10000 + m*100 + d

        train_beg_date = ymd(val_year - 8, 1, 1)
        train_end_date = ymd(val_year - 1, 12, 31)
        val_beg_date   = ymd(val_year, 1, 1)
        val_end_date   = ymd(val_year, 12, 31)
        test_beg_date  = ymd(val_year+1, 1, 1)
        test_end_date  = ymd(val_year+1, 12, 31)

        # EquityEnv for this split
        eq_env = EquityEnv(
            daily_price_data_path = DAILY_PATH,
            train_beg_date = train_beg_date,
            train_end_date = train_end_date,
            val_beg_date   = val_beg_date,
            val_end_date   = val_end_date,
            test_beg_date  = test_beg_date,
            test_end_date  = test_end_date,
        )

        # ------------------------------------------------------------------
        # 2. Load raw data and slice to [train_beg_date : test_end_date]
        # ------------------------------------------------------------------
        raw         = eq_env.load_raw_full()
        full_dates  = raw["daily_dates"]

        # First row on/after train_beg_date and last row on/before test_end_date.
        start_idx   = np.where(full_dates >= train_beg_date)[0][0]
        end_idx     = np.where(full_dates <= test_end_date)[0][-1]

        data        = raw["daily_crsp_tensor"][start_idx:end_idx+1].astype(np.float32)
        target_dates = full_dates[start_idx:end_idx+1]

        # --------------------------------------------------
        # Row of the last training day, first in `full_dates` coordinates and
        # then shifted into the coordinates of the sliced `data`.
        train_end_idx_orig = np.where(full_dates <= train_end_date)[0][-1]
        train_end_idx      = train_end_idx_orig - start_idx     # 0-based in `data`

        # ------------------------------------------------------------------
        # 3. Filter stocks: NaN filter + top-N market-cap filter
        #    (no volume screen)
        # ------------------------------------------------------------------
        n_assets   = KEEP_N_STOCKS          # e.g. 500 -- must match your dataset code

        # A stock is valid only if it has no NaN on any day (axis 0) in any
        # channel (axis 2) of the sliced window -- this includes val/test days.
        nan_filter = ~np.isnan(data).any(axis=(0, 2))
        print("Valid stocks (no NaNs):", nan_filter.sum())

        # ---- last-quarter-end that falls *inside* the training window ----
        # train_end_date is always Dec 31 here, so only the final branch
        # (YYYY1231) is taken by this script.
        train_end_year   = int(str(train_end_date)[:4])
        train_end_month  = int(str(train_end_date)[4:6])

        if   train_end_month <= 3:  last_qtr_end = train_end_year*10000 +  331
        elif train_end_month <= 6:  last_qtr_end = train_end_year*10000 +  630
        elif train_end_month <= 9:  last_qtr_end = train_end_year*10000 +  930
        else:                       last_qtr_end = train_end_year*10000 + 1231

        # Last entry of raw["dates"] on/before that quarter end.
        # NOTE(review): presumably raw["dates"] is the time axis of
        # raw["compustat_tensor"] -- confirm against the loader.
        candidate_idx = np.where(raw["dates"] <= last_qtr_end)[0][-1]

        # -------- market-cap at that quarter end --------------------------
        var_names  = raw["accounting_vars"]
        var_dict   = {name: i for i, name in enumerate(var_names)}
        cshoq      = raw["compustat_tensor"][:, :, var_dict["cshoq"]]     # shares-outstanding
        # NOTE(review): assumes channel 0 of the daily tensor is price -- confirm.
        mkt_price  = data[..., 0]                                         # px @ each day
        # Price on the last training day times shares outstanding.
        mcap_last  = mkt_price[train_end_idx, nan_filter] * cshoq[candidate_idx, nan_filter]

        # -------- top-N market-cap selection -----------------------------
        valid_stocks     = nan_filter.copy()
        cand_idx         = np.where(nan_filter)[0]               # unwrap True positions
        ok_mask          = ~np.isnan(mcap_last)                  # drop stocks with NaN market cap
        mcap_filtered    = mcap_last[ok_mask]
        cand_idx_filt    = cand_idx[ok_mask]

        assert mcap_filtered.size > n_assets                     # sanity

        top_idx          = np.argsort(mcap_filtered)[-n_assets:] # largest N
        selected_indices = cand_idx_filt[top_idx]                # back to original stock indices

        big_mcap_filter  = np.zeros_like(valid_stocks, dtype=bool)
        big_mcap_filter[selected_indices] = True

        valid_stocks &= big_mcap_filter
        print(f"After NaN + mcap: keeping {valid_stocks.sum()} stocks (top {n_assets}).")
        # Identifiers of the kept stocks; not used further in this function.
        ids = raw["permnos"][valid_stocks]
        #print("ids:", ids)

        # final slice
        # NOTE(review): assumes channels 3 and 4 are daily return and daily
        # volume respectively -- confirm against the tensor layout.
        data        = data[:, valid_stocks]
        daily_ret   = data[..., 3]
        daily_vol   = data[..., 4]

        # ------------------------------------------------------------------
        # 4. Compute SUV, Beta and residual volatility and save them with
        #    split-specific filenames
        # ------------------------------------------------------------------
        suv_file   = DATA_DIR / f"suv_{val_year}.npy"
        beta_file  = DATA_DIR / f"beta_{val_year}.npy"
        residual_vol_file = DATA_DIR / f"residual_vol_{val_year}.npy"

        suv_array  = compute_suv(daily_ret,
                                 daily_vol,
                                 target_dates,
                                 save_path=suv_file)

        # NOTE(review): the subtraction relies on rf_daily broadcasting against
        # the (T, N) return panel -- confirm its shape is (T, 1), not (T,).
        rf_rate    = raw["rf_daily"][start_idx:end_idx+1]
        excess_ret = daily_ret - rf_rate
        # Column 0 of the factor array is used as the market proxy.
        mkt_proxy  = raw["ff_3f_daily"][start_idx:end_idx+1, 0]

        beta_array = compute_beta(excess_ret,
                                  mkt_proxy,
                                  target_dates,
                                  save_path=beta_file)

        # This message is printed before the residual-volatility step below runs.
        print(f"✔  Done {val_year}: suv → {suv_file.name}, beta → {beta_file.name}")

        ff_3f_daily = raw["ff_3f_daily"][start_idx:end_idx+1]
        resid_vol_array = compute_residual_volatility(excess_ret,
                                    ff_3f_daily,
                                    target_dates,
                                    save_path=residual_vol_file)

if __name__ == "__main__":
    main()
