import numpy as np
import os
import pandas as pd
import argparse

def int_yyyymmdd_to_datetime(date_int: int) -> pd.Timestamp:
    """
    Build a pandas Timestamp from a date encoded as the integer YYYYMMDD.

    The four leading digits are read as the year, the next two as the month
    and the last two as the day (e.g. 20200131 -> 2020-01-31).
    """
    # Peel off the year first, then split the remaining MMDD into month and day.
    year, month_and_day = divmod(date_int, 10000)
    month, day = divmod(month_and_day, 100)
    return pd.Timestamp(year=year, month=month, day=day)

# Fields written to the truncated archive, in output order. The second item is
# the axis that holds the equity (permno) dimension, or None when the field is
# copied through unchanged.
_SAVED_FIELDS = (
    ("compustat_tensor", 1),
    ("accounting_vars", None),
    ("crsp_tensor", 1),
    ("monthly_price_vars", None),
    ("dates", None),
    ("ff_3f_daily", None),
    ("rf_daily", None),
    ("daily_crsp_tensor", 1),
    ("daily_price_vars", None),
    ("daily_dates", None),
    ("compustat_yr_tensor_filled", 1),
    ("compustat_yr_tensor", 1),
    ("yr_accounting_vars", None),
    ("ff_monthly_data", None),
    ("rf_monthly_rate", None),
    ("returns", 1),
    ("permnos", 0),
)


def _complete_data_mask(returns, dates, lookback_horizon):
    """Return a boolean mask of equities with no NaN return in the lookback window."""
    if len(dates) == 0:
        # No dates means no observations to disqualify anyone.
        return np.ones(returns.shape[1], dtype=bool)
    dates_dt = pd.DatetimeIndex([int_yyyymmdd_to_datetime(int(d)) for d in dates])
    # The window covers the last `lookback_horizon` years, ending at the latest date.
    window_start = dates_dt.max() - pd.DateOffset(years=lookback_horizon)
    in_window = np.asarray(dates_dt > window_start)
    return ~np.isnan(returns[in_window, :]).any(axis=0)


def truncate_and_save_data(
    n_equities: int,
    file_path: str,
    lookback_horizon: int = 16,
    *,
    require_complete_data: bool = False,
) -> None:
    """
    Keep the first n_equities equities of an NPZ dataset and save the result to a new NPZ file.

    By default every equity is eligible and the first n_equities (in stored order) are kept.
    When require_complete_data is True, only equities whose "returns" column has no NaN in the
    last lookback_horizon years (measured back from the latest entry of "dates") are eligible.

    The following fields are truncated along the equity dimension:
        - compustat_tensor: (dates, permnos, accounting_vars)
        - crsp_tensor: (dates, permnos, monthly_price_vars)
        - daily_crsp_tensor: (daily_dates, permnos, daily_price_vars)
        - compustat_yr_tensor_filled: (dates, permnos, yr_accounting_vars)
        - compustat_yr_tensor: (dates, permnos, yr_accounting_vars)
        - returns: (dates, permnos)
        - permnos: (permnos,)

    These fields are copied unchanged: accounting_vars, monthly_price_vars, dates, ff_3f_daily,
    rf_daily, daily_price_vars, daily_dates, yr_accounting_vars, ff_monthly_data, rf_monthly_rate.

    Any other field of the input archive (including pmno_merge_mat) is NOT written to the output.

    The new file is saved in the same directory as the original file with the same name,
    except that "_n_equities_{n_equities}" is appended before the file extension.

    Parameters:
        n_equities (int): Number of equities (permnos) to keep. If fewer are eligible,
            all eligible equities are kept and a warning is printed.
        file_path (str): Full path to the original NPZ file.
        lookback_horizon (int): Lookback horizon (in years) checked for missing returns.
            Only used when require_complete_data is True.
        require_complete_data (bool): If True, drop equities with any NaN return inside
            the lookback window before picking the first n_equities. Defaults to False,
            which treats every equity as valid.

    Raises:
        ValueError: If n_equities is negative.
    """
    if n_equities < 0:
        # A negative count would otherwise silently slice equities off the end.
        raise ValueError(f"n_equities must be non-negative, got {n_equities}")

    # NOTE(security): allow_pickle=True unpickles object arrays, which can run
    # arbitrary code. Only use this on archives from a trusted source.
    # The context manager closes the underlying zip file once all arrays are read.
    with np.load(file_path, allow_pickle=True) as data:
        # Each data[key] access re-reads the array from disk, so load shared ones once.
        preloaded = {"returns": data["returns"], "permnos": data["permnos"]}
        returns = preloaded["returns"]
        total_equities = preloaded["permnos"].shape[0]

        if require_complete_data:
            valid_equity_mask = _complete_data_mask(returns, data["dates"], lookback_horizon)
            valid_equity_indices = np.flatnonzero(valid_equity_mask)
            removed = total_equities - len(valid_equity_indices)
            print(f"Out of {total_equities} equities, {removed} equities have missing data in the last {lookback_horizon} years; {len(valid_equity_indices)} equities remain.")
        else:
            valid_equity_mask = np.ones(returns.shape[1], dtype=bool)
            valid_equity_indices = np.flatnonzero(valid_equity_mask)
            print(f"Missing-data filter disabled: all {total_equities} equities are eligible.")

        # Pick the first n_equities from the eligible ones.
        if len(valid_equity_indices) < n_equities:
            print(f"Warning: Only {len(valid_equity_indices)} equities are available with complete data. Using all of them.")
            chosen_indices = valid_equity_indices
        else:
            chosen_indices = valid_equity_indices[:n_equities]

        # Materialise every output field while the archive is still open.
        truncated_data = {}
        for name, equity_axis in _SAVED_FIELDS:
            values = preloaded[name] if name in preloaded else data[name]
            if equity_axis is not None:
                values = np.take(values, chosen_indices, axis=equity_axis)
            truncated_data[name] = values

    # Build the new file name by appending _n_equities_{n_equities} before the extension.
    base, ext = os.path.splitext(file_path)
    new_file_path = f"{base}_n_equities_{n_equities}{ext}"

    # Save the truncated data using compression.
    np.savez_compressed(new_file_path, **truncated_data)
    print(f"Truncated data saved to {new_file_path}")

# Example usage:
if __name__ == "__main__":
    # Parse the command line first so that --help works even when the
    # environment is not configured yet.
    parser = argparse.ArgumentParser(description='Process and truncate equities data.')
    # NOTE(review): the original author left "37000" next to this default —
    # presumably the size of the full universe; confirm before relying on it.
    parser.add_argument('--n_equities', type=int, default=4000,
                        help='Number of equities to keep after filtering for nans')
    args = parser.parse_args()

    # The input archive lives in the directory named by EQUITIES_DATA_PATH.
    # Report a missing variable clearly instead of letting os.path.join fail
    # with a TypeError on None.
    data_dir = os.environ.get("EQUITIES_DATA_PATH")
    if data_dir is None:
        parser.error("the EQUITIES_DATA_PATH environment variable must be set to the "
                     "directory containing daily_price_data.npz")
    file_path = os.path.join(data_dir, "daily_price_data.npz")

    # Run the function with the provided arguments
    truncate_and_save_data(n_equities=args.n_equities, file_path=file_path)