import re
import numpy as np
import inspect
from inspect import getsource

LIB_FLAG = "normal"  # Factor-library preset: "simple", "normal" or "advanced"; selects init_expression_lib

# Basic factor formulations
def mean_return(returns, window=5):
    """Average of the last ``window`` entries of ``returns``."""
    trailing = returns[-window:]
    return np.mean(trailing)


# def std_return(returns, window=5):
#     """Calculate standard deviation of returns over recent window"""
#     return np.std(returns[-window:])


def momentum(prices, window=10):
    """
    Simple-return momentum over the trailing window.

    Compares the latest price with the first price of the last ``window``
    prices (i.e. the price ``window - 1`` periods earlier) and returns the
    fractional change between them.
    """
    start_price = prices[-window]
    end_price = prices[-1]
    return end_price / start_price - 1


def reward_to_drawdown(prices, window=20):
    """
    Total return over the trailing window divided by its maximum drawdown.

    The drawdown at each step is measured against the running peak of the
    window. A higher value means more return per unit of drawdown risk.
    """
    segment = prices[-window:]
    running_peak = np.maximum.accumulate(segment)
    worst_drawdown = np.min((segment - running_peak) / running_peak)
    # Epsilon keeps the ratio finite when the window never drew down.
    denominator = abs(worst_drawdown) + 1e-6
    gain = segment[-1] / segment[0] - 1
    return gain / denominator


def sharpe_ratio(returns, window=20):
    """
    Sharpe ratio of the last ``window`` returns, with a zero risk-free rate.

    Returns 0 when the returns have zero standard deviation.
    """
    window_returns = returns[-window:]
    sigma = np.std(window_returns)
    if sigma == 0:
        return 0
    return np.mean(window_returns) / sigma


def volatility_score(prices, window=20):
    """
    Inverse-volatility stability score for the trailing window.

    Computes the standard deviation of log returns over the last ``window``
    prices and returns its reciprocal (plus a small epsilon), so lower
    volatility yields a higher score.
    """
    log_changes = np.diff(np.log(prices[-window:]))
    return 1 / (np.std(log_changes) + 1e-6)


def price_position(prices, window=20):
    """
    Where the latest price sits within the trailing min-max range.

    Returns 0 at the window low and 1 at the window high, or 0.5 when the
    window is flat.
    """
    window_prices = prices[-window:]
    low, high = np.min(window_prices), np.max(window_prices)
    spread = high - low
    if spread == 0:
        return 0.5
    return (window_prices[-1] - low) / spread


def log_return(prices, window=1):
    """Log of the ratio between the latest price and the price ``window`` periods earlier."""
    start = prices[-window - 1]
    return np.log(prices[-1] / start)


# Additional technical-analysis (TA) factors
def rsi(prices, window=14):
    """
    Relative Strength Index from simple averages of the last ``window`` price changes.

    Returns 100 when there are no losing steps in the window.
    """
    changes = np.diff(prices[-window - 1 :])
    up_moves = np.where(changes > 0, changes, 0)
    down_moves = np.where(changes < 0, -changes, 0)
    mean_up = np.mean(up_moves)
    mean_down = np.mean(down_moves)
    if mean_down == 0:
        strength = np.inf
    else:
        strength = mean_up / mean_down
    return 100 - (100 / (1 + strength))


def moving_average(prices, window=20):
    """Arithmetic mean of the last ``window`` prices (simple moving average)."""
    window_slice = prices[-window:]
    return np.mean(window_slice)


def bollinger_band_score(prices, window=20, std_dev=2):
    """
    Stability score from the relative Bollinger Band width.

    The width is the distance from the middle band (the window mean) to one
    outer band (``std_dev`` standard deviations away), divided by the mean:
    ``width = std_dev * std / mean``. The score is ``1 / (1 + width)``, so a
    narrower band gives a score closer to 1 (a flat window scores exactly 1).

    Args:
        prices: Price series; only the last ``window`` values are used.
        window: Number of trailing prices in the band.
        std_dev: Band multiplier in standard deviations. The previous
            implementation hard-coded 2 and ignored this argument; the default
            reproduces that result.

    Returns:
        The stability score.
    """
    recent = prices[-window:]
    ma = np.mean(recent)
    std = np.std(recent)
    width = (std_dev * std) / (ma + 1e-6)  # epsilon prevents division by zero
    return 1 / (1 + width)


def price_ema_ratio(prices, window=10):
    """
    Ratio of the latest price to an exponentially weighted average of recent prices.

    Weights are ``exp(linspace(-1, 0, window))``, normalised to sum to 1, with
    the largest weight on the most recent price. A plain dot product keeps
    ``weights[i]`` aligned with ``recent[i]``; the earlier ``np.convolve``
    reversed the kernel and gave the oldest price the largest weight.

    When fewer than ``window`` prices are available, the newest weights are
    kept (same decay shape) and renormalised.

    Args:
        prices: Price series; the last ``window`` values are averaged.
        window: Number of trailing prices in the weighted average.

    Returns:
        ``prices[-1]`` divided by the weighted average.
    """
    recent = np.asarray(prices[-window:], dtype=float)
    weights = np.exp(np.linspace(-1.0, 0.0, window))[-len(recent) :]
    weights = weights / weights.sum()
    ema = float(np.dot(weights, recent))
    return prices[-1] / ema


def return_skewness_score(prices, window=10):
    """
    Symmetry score of the log-return distribution over the trailing window.

    Computes the (population) skewness of the log returns of the last
    ``window`` prices and maps it to ``exp(-|skew|)``, which lies in (0, 1]
    and equals 1 for a perfectly symmetric distribution. A window with zero
    return dispersion is treated as stable and scores 1.
    """
    log_changes = np.diff(np.log(prices[-window:]))
    sigma = np.std(log_changes)
    if sigma == 0:
        return 1
    standardized = (log_changes - np.mean(log_changes)) / sigma
    skew = np.mean(standardized ** 3)
    return np.exp(-abs(skew))


# Additional factors adapted from Alpha101 / Alpha191 (close-price only)
def price_zscore(prices, window=14):
    """
    Z-score of the latest price relative to the trailing window.

    Returns 0 when the window has zero standard deviation.
    """
    recent = prices[-window:]
    center, spread = np.mean(recent), np.std(recent)
    if spread == 0:
        return 0
    return (prices[-1] - center) / spread


def pct_rank_position(prices, window=14):
    """
    Percentile rank (0 to 1) of the latest price within the trailing window.

    The rank is the number of other window prices at or below the latest
    price, divided by ``n - 1``, where ``n`` is the number of prices actually
    available. This is the ordinal rank a stable sort assigns to the last
    element. Using ``n`` instead of ``window`` keeps the score on the full
    [0, 1] scale when the history is shorter than ``window``. Counting also
    makes tie handling deterministic, independent of the sort algorithm.

    Args:
        prices: Price series; the last ``window`` values are ranked.
        window: Number of trailing prices to rank against.

    Returns:
        The rank in [0, 1], or 0.5 (neutral) when fewer than two prices are
        available and a rank is undefined.
    """
    recent = np.asarray(prices[-window:])
    n = len(recent)
    if n < 2:
        return 0.5
    rank = np.count_nonzero(recent <= recent[-1]) - 1
    return rank / (n - 1)


def price_diff_mean(prices, window=7):
    """Latest price minus the mean of the last ``window`` prices."""
    latest = prices[-1]
    return latest - np.mean(prices[-window:])


def mean_reversion_strength(prices, window=7):
    """
    Deviation of the latest price from the window mean, in units of window std.

    Returns 0 when the window has zero standard deviation.
    """
    window_prices = prices[-window:]
    sigma = np.std(window_prices)
    deviation = prices[-1] - np.mean(window_prices)
    return 0 if sigma == 0 else deviation / sigma


def local_trend_slope(prices, window=5):
    """
    Least-squares slope of the last ``window`` prices against their index.

    The x-axis is sized from the prices actually available, so a history
    shorter than ``window`` fits a line to what is there. Previously
    ``np.arange(window)`` mismatched ``y`` in that case and raised.

    Args:
        prices: Price series.
        window: Number of trailing prices to fit.

    Returns:
        The fitted slope per step, or 0.0 when fewer than two prices are
        available (no trend can be estimated).
    """
    y = np.asarray(prices[-window:], dtype=float)
    n = len(y)
    if n < 2:
        return 0.0
    x = np.arange(n)
    design = np.vstack([x, np.ones(n)]).T
    slope, _intercept = np.linalg.lstsq(design, y, rcond=None)[0]
    return slope


def trend_acceleration(prices, window=5):
    """
    Second difference of the last three prices: p[-1] - 2*p[-2] + p[-3].

    Positive values indicate the latest move accelerated relative to the
    previous one. ``window`` is accepted for signature parity with the other
    factors but is not used.
    """
    latest, previous, before_previous = prices[-1], prices[-2], prices[-3]
    return latest - 2 * previous + before_previous


def rolling_max_decay(prices, window=10):
    """
    Fractional distance of the latest price below the trailing window's maximum.

    Returns 0 when the window maximum is 0.
    """
    peak = np.max(prices[-window:])
    if peak == 0:
        return 0
    return (prices[-1] - peak) / peak


def recent_jump(prices, window=5):
    """Largest one-step simple return among the last ``window`` steps."""
    tail = prices[-(window + 1) :]
    step_returns = np.diff(tail) / tail[:-1]
    return np.max(step_returns)


def return_skewness(prices, window=10):
    """
    Population skewness of the log returns of the last ``window`` prices.

    Returns 0 when the log returns have zero dispersion.
    """
    log_changes = np.diff(np.log(prices[-window:]))
    sigma = np.std(log_changes)
    if sigma == 0:
        return 0
    z = (log_changes - np.mean(log_changes)) / sigma
    return np.mean(z ** 3)


# Build the factor expression library for the preset selected by LIB_FLAG.
# Every entry maps a factor name to a callable f(prices, returns); each lambda
# simply ignores whichever of the two arguments its factor does not need.
if LIB_FLAG == "simple":
    init_expression_lib = {
        "mean_return_5": lambda prices, returns: mean_return(returns, 5),
        "mean_return_14": lambda prices, returns: mean_return(returns, 14),
        "momentum_5": lambda prices, returns: momentum(prices, 5),
        "momentum_14": lambda prices, returns: momentum(prices, 14),
        "reward_to_drawdown_5": lambda prices, returns: reward_to_drawdown(prices, 5),
        "reward_to_drawdown_14": lambda prices, returns: reward_to_drawdown(prices, 14),
        "volatility_5": lambda prices, returns: volatility_score(prices, 5),
        "volatility_14": lambda prices, returns: volatility_score(prices, 14),
        "sharpe_ratio_5": lambda prices, returns: sharpe_ratio(returns, 5),
        "sharpe_ratio_14": lambda prices, returns: sharpe_ratio(returns, 14),
        "log_return_1": lambda prices, returns: log_return(prices, 1),
    }

elif LIB_FLAG == "normal":
    # Expression library (dictionary format)
    init_expression_lib = {
        "mean_return_5": lambda prices, returns: mean_return(returns, 5),
        "mean_return_14": lambda prices, returns: mean_return(returns, 14),
        "momentum_5": lambda prices, returns: momentum(prices, 5),
        "momentum_14": lambda prices, returns: momentum(prices, 14),
        "sharpe_ratio_5": lambda prices, returns: sharpe_ratio(returns, 5),
        "sharpe_ratio_14": lambda prices, returns: sharpe_ratio(returns, 14),
        "ma_5": lambda prices, returns: moving_average(prices, 5),
        "ma_14": lambda prices, returns: moving_average(prices, 14),
        "ema_ratio_5": lambda prices, returns: price_ema_ratio(prices, 5),
        "ema_ratio_14": lambda prices, returns: price_ema_ratio(prices, 14),
        "reward_to_drawdown_5": lambda prices, returns: reward_to_drawdown(prices, 5),
        "reward_to_drawdown_14": lambda prices, returns: reward_to_drawdown(prices, 14),
        "reward_to_drawdown_21": lambda prices, returns: reward_to_drawdown(prices, 21),
        "volatility_5": lambda prices, returns: volatility_score(prices, 5),
        "volatility_14": lambda prices, returns: volatility_score(prices, 14),
        "price_position_5": lambda prices, returns: price_position(prices, 5),
        "price_position_14": lambda prices, returns: price_position(prices, 14),
        "bb_width_14": lambda prices, returns: bollinger_band_score(prices, 14),
        "bb_width_21": lambda prices, returns: bollinger_band_score(prices, 21),
        "rsi_21": lambda prices, returns: rsi(prices, 21),
        "log_return_1": lambda prices, returns: log_return(prices, 1),
        "return_skewness_score_21": lambda prices, returns: return_skewness_score(
            prices, 21
        ),
    }

elif LIB_FLAG == "advanced":
    init_expression_lib = {
        "mean_return_5": lambda prices, returns: mean_return(returns, 5),
        "momentum_5": lambda prices, returns: momentum(prices, 5),
        "reward_to_drawdown_5": lambda prices, returns: reward_to_drawdown(prices, 5),
        "sharpe_ratio_5": lambda prices, returns: sharpe_ratio(returns, 5),
        "volatility_5": lambda prices, returns: volatility_score(prices, 5),
        "price_position_5": lambda prices, returns: price_position(prices, 5),
        "ma_5": lambda prices, returns: moving_average(prices, 5),
        "bb_width_5": lambda prices, returns: bollinger_band_score(prices, 5),
        "ema_ratio_5": lambda prices, returns: price_ema_ratio(prices, 5),
        "mean_return_14": lambda prices, returns: mean_return(returns, 14),
        "momentum_14": lambda prices, returns: momentum(prices, 14),
        "reward_to_drawdown_14": lambda prices, returns: reward_to_drawdown(prices, 14),
        "sharpe_ratio_14": lambda prices, returns: sharpe_ratio(returns, 14),
        "volatility_14": lambda prices, returns: volatility_score(prices, 14),
        "price_position_14": lambda prices, returns: price_position(prices, 14),
        "ma_14": lambda prices, returns: moving_average(prices, 14),
        "bb_width_14": lambda prices, returns: bollinger_band_score(prices, 14),
        "ema_ratio_14": lambda prices, returns: price_ema_ratio(prices, 14),
        "rsi_14": lambda prices, returns: rsi(prices, 14),
        "log_return_1": lambda prices, returns: log_return(prices, 1),
        # + advanced
        "price_zscore_5": lambda prices, returns: price_zscore(prices, 5),
        "price_zscore_14": lambda prices, returns: price_zscore(prices, 14),
        "pct_rank_pos_5": lambda prices, returns: pct_rank_position(prices, 5),
        "pct_rank_pos_14": lambda prices, returns: pct_rank_position(prices, 14),
        "price_diff_mean_5": lambda prices, returns: price_diff_mean(prices, 5),
        "price_diff_mean_14": lambda prices, returns: price_diff_mean(prices, 14),
        "mean_reversion_strength_5": lambda prices, returns: mean_reversion_strength(
            prices, 5
        ),
        "mean_reversion_strength_14": lambda prices, returns: mean_reversion_strength(
            prices, 14
        ),
        "local_trend_slope_5": lambda prices, returns: local_trend_slope(prices, 5),
        "local_trend_slope_14": lambda prices, returns: local_trend_slope(prices, 14),
        # trend_acceleration ignores its window argument (always uses the last 3 prices).
        "trend_acceleration": lambda prices, returns: trend_acceleration(prices, 5),
        "rolling_max_decay_5": lambda prices, returns: rolling_max_decay(prices, 5),
        "rolling_max_decay_14": lambda prices, returns: rolling_max_decay(prices, 14),
        "recent_jump_5": lambda prices, returns: recent_jump(prices, 5),
        "recent_jump_14": lambda prices, returns: recent_jump(prices, 14),
        "return_skewness_5": lambda prices, returns: return_skewness(prices, 5),
        "return_skewness_14": lambda prices, returns: return_skewness(prices, 14),
    }
else:
    # Fail loudly at import time with the offending value, instead of a bare error.
    raise NotImplementedError(
        f"Unknown LIB_FLAG {LIB_FLAG!r}; expected 'simple', 'normal' or 'advanced'"
    )

# Assumes the factor functions are defined in this module (not imported), since the lookup below goes through globals()
def get_implementation_source(func_name, init_expression_lib):
    """
    Return the source code of the factor function behind a library entry.

    Entries of ``init_expression_lib`` are thin wrappers such as
    ``lambda prices, returns: mean_return(returns, 5)``. This looks up the
    module-level function the wrapper calls (``mean_return`` here) and returns
    that function's source text.

    Args:
        func_name: Key into ``init_expression_lib`` (e.g. ``"mean_return_5"``).
        init_expression_lib: Mapping from factor name to wrapper callable.

    Returns:
        The wrapped function's source. If ``func_name`` is not a key, prints a
        message and returns ``None``. If the wrapped function cannot be
        resolved in this module's globals, or its source cannot be read,
        returns a short explanatory string instead of raising.
    """
    if func_name not in init_expression_lib:
        print(f"Key '{func_name}' not found in expression library")
        return None

    target_name = _wrapped_function_name(init_expression_lib[func_name])
    if target_name is None:
        # No recognisable call inside the wrapper; report the library key.
        return f"Function {func_name} not found in global scope"

    target = globals().get(target_name)
    if not target:
        return f"Function {target_name} not found in global scope"
    try:
        return getsource(target)
    except (TypeError, OSError):
        return f"Source not available for {target_name}"


def _wrapped_function_name(wrapper):
    """Return the name of the module-level function ``wrapper`` calls, or ``None``."""
    # Prefer the compiled code object: the global names a lambda references are
    # listed in co_names, so no (possibly multi-line) source text has to be parsed.
    code = getattr(wrapper, "__code__", None)
    if code is not None:
        for name in code.co_names:
            if inspect.isfunction(globals().get(name)):
                return name
    # Fallback: scan the wrapper's source for the first ``: name(`` call.
    try:
        source = getsource(wrapper)
    except (TypeError, OSError):
        return None
    match = re.search(r":\s*([A-Za-z_]\w*)\(", source)
    return match.group(1) if match else None


if __name__ == "__main__":
    # Demo: print the source of the factor behind one library entry.
    # "mean_return_5" is defined in every LIB_FLAG preset.
    demo_source = get_implementation_source("mean_return_5", init_expression_lib)
    print(demo_source)
