"""
Trend Functions for Predictable Trend Verification

This module contains mathematical models for fitting trends to sequences of objective scores.
Used by PredictableTrendVerifier to determine if an objective follows a predictable pattern.
"""

import numpy as np
from typing import Tuple, Dict, Any, Callable, Optional
from scipy.optimize import curve_fit
from scipy.stats import pearsonr


# Floor for the parameters that must stay positive so a trend is strictly
# increasing. It is used as the lower bound for slope/gain/rate parameters in
# TREND_FUNCTIONS, as the clamp applied to initial guesses, and as the value of
# those parameters in the fallback returned when a fit fails.
# NOTE(review): this is an absolute value (0.1), so slopes or gains smaller
# than 0.1 cannot be represented -- confirm it suits the scale of the scores.
MIN_POSITIVE = 1e-1


def linear_func(t: np.ndarray, a: float, b: float) -> np.ndarray:
    """
    Evaluate the straight-line model f(t) = a*t + b.

    Args:
        t: Time/iteration indices at which to evaluate the line
        a: Slope of the line
        b: Intercept (value of the line at t = 0)

    Returns:
        Predicted values, one per entry of t
    """
    scaled = a * t
    return scaled + b

def flat_func(t: np.ndarray, c: float) -> np.ndarray:
    """
    Evaluate the constant model f(t) = c.

    Args:
        t: Time/iteration indices; only the shape is used
        c: Level of the flat line

    Returns:
        Float array shaped like t with every entry equal to c
    """
    level = np.full_like(t, fill_value=c, dtype=float)
    return level

def exponential_sat_func(t: np.ndarray, S: float, gain: float, k: float) -> np.ndarray:
    """
    Exponential saturation model (EXP3 / converging).

    Describes fast early improvement that levels off at a ceiling:

        f(t) = S + gain * (1 - exp(-k * (t - 1)))

    Args:
        t: Time/iteration indices; the curve passes through S at t = 1
        S: Starting score (value at t = 1)
        gain: Total achievable improvement (asymptote minus start), gain >= 0
        k: Convergence rate, k > 0

    Returns:
        Array of predicted values
    """
    steps_elapsed = t - 1
    remaining_fraction = np.exp(-k * steps_elapsed)
    return S + gain * (1 - remaining_fraction)

def power_law_func(t: np.ndarray, a: float, b: float, c: float) -> np.ndarray:
    """
    Power law with asymptote (POW3).

    Models "heavy tail" learning that keeps improving for a long time
    (common in LLM scaling laws):

        f(t) = c - a * t^(-b)

    Args:
        t: Time/iteration indices (must be > 0; t = 0 yields inf)
        a: Scale factor (distance from the asymptote at t = 1)
        b: Decay rate of the error term
        c: The asymptote (limit as t -> inf)

    Constraints that make the curve increase:
        - a >= 0 (a positive, shrinking amount is subtracted from c)
        - b > 0  (the subtracted term shrinks over time)

    Returns:
        Array of predicted values (always floating point)
    """
    # Evaluate in floating point: numpy refuses to raise an integer array to a
    # negative integer power, which would otherwise make integer iteration
    # indices combined with an integer `b` fail with ValueError.
    t_float = np.asarray(t, dtype=float)
    return c - a * np.power(t_float, -b)

def logarithmic_func(t: np.ndarray, a: float, b: float) -> np.ndarray:
    """
    Evaluate the logarithmic growth model f(t) = a * log(t) + b.

    Args:
        t: Time/iteration indices (must be > 0 for the natural log to be finite)
        a: Scale applied to log(t)
        b: Offset (value of the curve at t = 1)

    Returns:
        Array of predicted values
    """
    log_t = np.log(t)
    return a * log_t + b

# NOTE: Disabled alternative model, kept for reference only. It describes the
# same curve as exponential_sat_func, re-parameterised by its asymptote
# (L = S + gain), but without a sign constraint, so it can also converge downward.
#
# def converging_func(t: np.ndarray, L: float, k: float, S: float) -> np.ndarray:
#     """
#     Converging trend function (can converge upward or downward):
#     f(t) = L + (S - L) * exp(-k * (t-1))
#
#     This unifies converging up and down into a single function:
#     - If S < L: converges upward (concave) to asymptote L
#     - If S > L: converges downward (convex) to asymptote L
#
#     Args:
#         t: Time/iteration indices (starting from 1)
#         L: Asymptotic limit
#         k: Rate of convergence (positive)
#         S: Starting value at t=1
#
#     Returns:
#         Array of predicted values
#     """
#     return L + (S - L) * np.exp(-k * (t - 1))


# Registry of the trend models that can be fitted, keyed by trend type name.
# Each entry provides:
#   'func'        - the model callable, invoked as func(t, *params)
#   'params'      - parameter names, in the positional order 'func' expects
#   'bounds'      - (lower_bounds, upper_bounds) handed to curve_fit, same order
#   'description' - human-readable summary of the model
#   'n_params'    - number of parameters (matches len(params))
# Every model except 'flat' has positive lower bounds on its shape parameters so
# that the fitted curve is strictly increasing in t.
TREND_FUNCTIONS = {
    'linear': {
        'func': linear_func,
        'params': ['a', 'b'],
        'bounds': ([MIN_POSITIVE, -np.inf], [np.inf, np.inf]),  # slope a >= MIN_POSITIVE (strictly increasing); b unbounded
        'description': 'Linear trend (strictly increasing)',
        'n_params': 2
    },
    'flat': {
        'func': flat_func,
        'params': ['c'],
        'bounds': ([-np.inf], [np.inf]),  # c is unconstrained
        'description': 'Constant/flat trend',
        'n_params': 1
    },
    'exponential_sat': {
        'func': exponential_sat_func,
        'params': ['S', 'gain', 'k'],
        'bounds': ([-np.inf, MIN_POSITIVE, MIN_POSITIVE], [np.inf, np.inf, 10]),  # gain >= MIN_POSITIVE, MIN_POSITIVE <= k <= 10 (strictly increasing); S unbounded
        'description': 'Exponential saturation (rapid initial gain, converges to ceiling)',
        'n_params': 3
    },
    'power_law': {
        'func': power_law_func,
        'params': ['a', 'b', 'c'],
        'bounds': ([MIN_POSITIVE, 3e-1, -np.inf], [np.inf, 10, np.inf]),  # a >= MIN_POSITIVE, 0.3 <= b <= 10 (strictly increasing); c unbounded
        'description': 'Power law with asymptote (heavy tail learning)',
        'n_params': 3
    },
    'logarithmic': {
        'func': logarithmic_func,
        'params': ['a', 'b'],
        'bounds': ([MIN_POSITIVE, -np.inf], [np.inf, np.inf]),  # scale a >= MIN_POSITIVE (strictly increasing); b unbounded
        'description': 'Logarithmic growth (strictly increasing)',
        'n_params': 2
    },
}


def _fallback_params(t_values: np.ndarray, y_values: np.ndarray, trend_type: str) -> np.ndarray:
    """
    Build data-aware parameters to use when curve fitting fails.

    The result is a nearly flat curve located at the data mean: increasing
    models get their shape parameters set to MIN_POSITIVE (so the bounds are
    still satisfied) and the flat model gets the mean itself.

    Args:
        t_values: Time/iteration indices
        y_values: Observed values at each time point
        trend_type: Type of trend (key in TREND_FUNCTIONS)

    Returns:
        Parameter array in the positional order the trend function expects
    """
    y_mean = np.mean(y_values)

    if trend_type == 'linear':
        # f(t) = a*t + b with minimal slope, passing through (mean(t), mean(y))
        t_mean = np.mean(t_values)
        return np.array([MIN_POSITIVE, y_mean - MIN_POSITIVE * t_mean])
    if trend_type == 'flat':
        # f(t) = c; the mean is the least-squares optimum for a constant
        return np.array([y_mean])
    if trend_type == 'exponential_sat':
        # f(t) = S + gain * (1 - exp(-k*(t-1))), minimal gain
        return np.array([y_mean, MIN_POSITIVE, 1.0])
    if trend_type == 'power_law':
        # f(t) = c - a * t^(-b), minimal a
        return np.array([MIN_POSITIVE, 1.0, y_mean])
    if trend_type == 'logarithmic':
        # f(t) = a * log(t) + b, minimal scale
        return np.array([MIN_POSITIVE, y_mean])

    # Registered trend without a dedicated fallback: use ones
    return np.ones(TREND_FUNCTIONS[trend_type]['n_params'])


def fit_trend(
    t_values: np.ndarray,
    y_values: np.ndarray,
    trend_type: str,
    maxfev: int = 5000
) -> Tuple[np.ndarray, Dict[str, float], float]:
    """
    Fit a trend function to data using scipy.optimize.curve_fit.

    The fit is constrained by the bounds registered for the trend in
    TREND_FUNCTIONS. If fitting fails for any reason, the error is printed and
    data-aware fallback parameters (a nearly flat curve at the data mean) are
    returned instead, together with the MSE that fallback achieves.

    Args:
        t_values: Time/iteration indices (any array-like of numbers)
        y_values: Observed values at each time point (any array-like of numbers)
        trend_type: Type of trend to fit (key in TREND_FUNCTIONS)
        maxfev: Maximum function evaluations for optimization

    Returns:
        Tuple of:
        - Optimal (or fallback) parameters
        - Dictionary mapping parameter names to values
        - Mean squared error of the returned parameters on the data

    Raises:
        ValueError: If trend_type is not a key of TREND_FUNCTIONS
    """
    if trend_type not in TREND_FUNCTIONS:
        raise ValueError(f"Unknown trend type: {trend_type}")

    trend_info = TREND_FUNCTIONS[trend_type]
    func = trend_info['func']
    param_names = trend_info['params']
    bounds = trend_info['bounds']

    # Work on float arrays so plain lists / integer sequences behave the same
    # as numpy arrays when the trend function is evaluated directly below
    # (e.g. `t - 1` is not defined for a Python list).
    t_values = np.asarray(t_values, dtype=float)
    y_values = np.asarray(y_values, dtype=float)

    try:
        # Get initial guess based on data
        initial_guess = _get_initial_guess(t_values, y_values, trend_type)

        # Fit the curve
        popt, _ = curve_fit(
            func,
            t_values,
            y_values,
            p0=initial_guess,
            bounds=bounds,
            maxfev=maxfev
        )
    except Exception as e:
        # Deliberately broad: any fitting failure (non-convergence, non-finite
        # data, infeasible guess, ...) degrades to the fallback parameters
        # instead of aborting the caller.
        print(f"Error fitting {trend_type} trend: {e}")
        popt = _fallback_params(t_values, y_values, trend_type)

    # Calculate predictions and MSE for whichever parameters were selected
    predictions = func(t_values, *popt)
    mse = np.mean((y_values - predictions) ** 2)

    # Create parameter dictionary
    param_dict = dict(zip(param_names, popt))

    return popt, param_dict, mse


def _get_initial_guess(t_values: np.ndarray, y_values: np.ndarray, trend_type: str) -> np.ndarray:
    """
    Generate data-driven initial guesses for curve fitting.

    Guesses for parameters that are bounded below in TREND_FUNCTIONS are
    clamped to at least MIN_POSITIVE so they start inside the feasible region.
    Degenerate time axes (first and last index equal) fall back to neutral
    values instead of dividing by zero.

    Args:
        t_values: Time indices
        y_values: Observed values
        trend_type: Type of trend

    Returns:
        Array of initial parameter guesses, in the order the trend function
        expects them

    Raises:
        KeyError: If trend_type has no dedicated guess and is not a key of
            TREND_FUNCTIONS
    """
    y_mean = np.mean(y_values)
    y_min = np.min(y_values)
    y_max = np.max(y_values)
    y_range = y_max - y_min
    n_points = len(t_values)

    if trend_type == 'linear':
        # Secant slope between the first and last point; a zero time span
        # (single point or repeated index) has no defined slope, so use 0.
        t_span = t_values[-1] - t_values[0] if n_points > 1 else 0
        slope = (y_values[-1] - y_values[0]) / t_span if t_span != 0 else 0

        # Clamp slope to be at least MIN_POSITIVE (since bounds require a > 0)
        # If data trends down, we start with minimal positive slope
        slope = max(MIN_POSITIVE, slope)

        intercept = y_values[0] - slope * t_values[0]
        return np.array([slope, intercept])

    elif trend_type == 'flat':
        return np.array([y_mean])

    elif trend_type == 'exponential_sat':
        # Estimate starting value, gain, and rate
        S = y_values[0]   # Starting value
        gain = max(MIN_POSITIVE, y_values[-1] - y_values[0])  # Total improvement (must be > 0)
        k = max(MIN_POSITIVE, 1.0 / n_points)  # Rough rate estimate (must be > 0)
        return np.array([S, gain, k])

    elif trend_type == 'quadratic':
        # Simple quadratic fit initialization
        # (no 'quadratic' entry is registered in TREND_FUNCTIONS in this file)
        return np.array([0.01, 0.1, y_values[0]])

    elif trend_type == 'logarithmic':
        # Estimate log scale from the data range over the log of the time ratio.
        # Both end points must be positive and distinct, otherwise the log is
        # undefined or zero and the ratio would be nan/inf.
        if (
            n_points > 1
            and t_values[0] > 0
            and t_values[-1] > 0
            and t_values[-1] != t_values[0]
        ):
            a = y_range / np.log(t_values[-1] / t_values[0])
        else:
            a = 1.0
        # Clamp a to be at least MIN_POSITIVE (bounds require a > 0)
        a = max(MIN_POSITIVE, a)
        b = y_mean
        return np.array([a, b])

    elif trend_type == 'sigmoid':
        # Estimate sigmoid parameters
        # (no 'sigmoid' entry is registered in TREND_FUNCTIONS in this file)
        L = y_range
        k = 4.0 / n_points
        t0 = np.median(t_values)
        S = y_min
        return np.array([L, k, t0, S])

    elif trend_type == 'power_law':
        # f(t) = c - a * t^(-b), where c is asymptote, a = c - y[0]
        c = y_values[-1]  # Asymptote estimate
        a = max(MIN_POSITIVE, c - y_values[0])  # Scale factor (must be > 0)
        b = max(MIN_POSITIVE, 0.5)  # Decay rate (must be > 0)
        return np.array([a, b, c])

    else:
        # Default initialization
        n_params = TREND_FUNCTIONS[trend_type]['n_params']
        return np.ones(n_params)


def evaluate_trend_fit(
    t_values: np.ndarray,
    y_values: np.ndarray,
    func: Callable,
    params: np.ndarray
) -> Dict[str, float]:
    """
    Evaluate the quality of a trend fit.

    Args:
        t_values: Time indices
        y_values: Observed values
        func: Trend function, called as func(t_values, *params)
        params: Fitted parameters

    Returns:
        Dictionary of evaluation metrics with keys 'mse', 'rmse', 'mae',
        'r_squared', 'correlation', 'p_value', 'max_error', 'mean_residual'
        and 'std_residual'. 'r_squared' is 0 when the observations have no
        variance. 'correlation'/'p_value' are 0/1 when the Pearson correlation
        is undefined (fewer than two points, or constant observations or
        predictions, as is always the case for a flat trend).
    """
    y_values = np.asarray(y_values, dtype=float)
    predictions = np.asarray(func(t_values, *params), dtype=float)
    residuals = y_values - predictions

    # Calculate various metrics
    mse = np.mean(residuals ** 2)
    rmse = np.sqrt(mse)
    mae = np.mean(np.abs(residuals))

    # R-squared
    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((y_values - np.mean(y_values)) ** 2)
    r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0

    # Pearson correlation. It is undefined when either series is constant
    # (pearsonr would return NaN and emit a warning), so report the same
    # neutral values used for too-short input. NaN entries compare unequal to
    # everything, so series containing NaN are still passed to pearsonr.
    both_vary = len(y_values) > 1 and all(
        np.any(series != series.flat[0]) for series in (y_values, predictions)
    )
    if both_vary:
        correlation, p_value = pearsonr(y_values, predictions)
    else:
        correlation, p_value = 0, 1

    return {
        'mse': mse,
        'rmse': rmse,
        'mae': mae,
        'r_squared': r_squared,
        'correlation': correlation,
        'p_value': p_value,
        'max_error': np.max(np.abs(residuals)),
        'mean_residual': np.mean(residuals),
        'std_residual': np.std(residuals)
    }


def find_best_trend(
    t_values: np.ndarray,
    y_values: np.ndarray,
    trend_types: Optional[list] = None,
    metric: str = 'mse'
) -> Tuple[Optional[str], Optional[np.ndarray], Optional[Dict[str, float]], Dict[str, Any]]:
    """
    Find the best-fitting trend from a set of candidates.

    Each candidate is fitted with fit_trend and scored with
    evaluate_trend_fit. Candidates whose fit or evaluation raises are reported
    on stdout and skipped. Ties keep the earlier candidate.

    Args:
        t_values: Time indices
        y_values: Observed values
        trend_types: List of trend types to try (default: all available)
        metric: Key of the evaluate_trend_fit result used for selection.
            'r_squared' and 'correlation' are maximised; every other metric
            ('mse', 'rmse', 'mae', ...) is minimised.

    Returns:
        Tuple of:
        - Best trend type name (None if no candidate could be scored)
        - Optimal parameters for best trend (None if no candidate)
        - Parameter dictionary for best trend (None if no candidate)
        - Dictionary of all fit results, keyed by trend type, each holding
          'params', 'param_dict' and 'metrics'

    Raises:
        ValueError: If metric is not one of the computed evaluation metrics
    """
    if trend_types is None:
        trend_types = list(TREND_FUNCTIONS.keys())

    higher_is_better = metric in ('r_squared', 'correlation')

    best_trend = None
    best_params = None
    best_param_dict = None
    best_score = -float('inf') if higher_is_better else float('inf')
    all_results = {}

    for trend_type in trend_types:
        # Only fitting and evaluation are best-effort; a bad `metric` is a
        # caller error and must not be reported as a failed fit.
        try:
            params, param_dict, _ = fit_trend(t_values, y_values, trend_type)

            # Evaluate fit
            trend_func = TREND_FUNCTIONS[trend_type]['func']
            metrics = evaluate_trend_fit(t_values, y_values, trend_func, params)
        except Exception as e:
            print(f"Failed to fit {trend_type}: {e}")
            continue

        # Store results
        all_results[trend_type] = {
            'params': params,
            'param_dict': param_dict,
            'metrics': metrics
        }

        # Validate against the metrics actually computed, so the accepted
        # names always track what the evaluation reports.
        if metric not in metrics:
            raise ValueError(
                f"Unknown metric: {metric}. Available metrics: {sorted(metrics)}"
            )

        # Check if this is the best fit (NaN scores never compare as better)
        score = metrics[metric]
        is_better = score > best_score if higher_is_better else score < best_score

        if is_better:
            best_trend = trend_type
            best_params = params
            best_param_dict = param_dict
            best_score = score

    return best_trend, best_params, best_param_dict, all_results