"""
Statistical metrics for experimental analysis.

Correlation measures and hypothesis testing utilities.
"""

from __future__ import annotations

from typing import Tuple, List, Dict, Optional
import numpy as np
from scipy import stats


def pearson_correlation(
    x: np.ndarray,
    y: np.ndarray
) -> Tuple[float, float]:
    """
    Measure the linear association between two samples (Pearson's r).

    Both inputs are flattened to 1-D, and every position where either
    value is NaN or infinite is dropped before the statistic is computed.

    Args:
        x: First array
        y: Second array

    Returns:
        (correlation, p_value); (nan, 1.0) when fewer than three
        finite pairs remain
    """
    xs = np.asarray(x).ravel()
    ys = np.asarray(y).ravel()

    # Keep only the positions where both observations are finite.
    keep = np.logical_and(np.isfinite(xs), np.isfinite(ys))
    xs, ys = xs[keep], ys[keep]

    if xs.size >= 3:
        r, p = stats.pearsonr(xs, ys)
        return float(r), float(p)

    # Too few usable pairs for a meaningful estimate.
    return float('nan'), 1.0


def spearman_correlation(
    x: np.ndarray,
    y: np.ndarray
) -> Tuple[float, float]:
    """
    Measure the monotonic association between two samples (Spearman's rho).

    Inputs are flattened to 1-D and pairs containing a NaN or infinite
    value are discarded before ranking.

    Args:
        x: First array
        y: Second array

    Returns:
        (correlation, p_value); (nan, 1.0) when fewer than three
        finite pairs remain
    """
    flat_x = np.asarray(x).reshape(-1)
    flat_y = np.asarray(y).reshape(-1)

    # A pair is usable only if both of its members are finite.
    usable = np.isfinite(flat_x) & np.isfinite(flat_y)
    clean_x = flat_x[usable]
    clean_y = flat_y[usable]

    too_few = len(clean_x) < 3
    if too_few:
        return float('nan'), 1.0

    rho, p_value = stats.spearmanr(clean_x, clean_y)
    return float(rho), float(p_value)


def kendall_correlation(
    x: np.ndarray,
    y: np.ndarray
) -> Tuple[float, float]:
    """
    Measure the ordinal association between two samples (Kendall's tau).

    Inputs are flattened to 1-D; any pair in which either value is NaN
    or infinite is removed first.

    Args:
        x: First array
        y: Second array

    Returns:
        (correlation, p_value); (nan, 1.0) when fewer than three
        finite pairs remain
    """
    a, b = np.asarray(x).ravel(), np.asarray(y).ravel()

    # Drop every pair that contains a non-finite entry.
    finite_pairs = np.isfinite(a) & np.isfinite(b)
    a, b = a[finite_pairs], b[finite_pairs]

    if a.shape[0] < 3:
        return float('nan'), 1.0

    tau, p_value = stats.kendalltau(a, b)
    return float(tau), float(p_value)


def compute_correlations(
    x: np.ndarray,
    y: np.ndarray
) -> Dict[str, Tuple[float, float]]:
    """
    Run every correlation measure in this module on the same pair of arrays.

    Args:
        x: First array
        y: Second array

    Returns:
        Dict with keys 'pearson', 'spearman' and 'kendall', each mapping
        to that measure's (value, p_value) tuple
    """
    # Evaluated in this order; each measure cleans the inputs itself.
    measures = (
        ('pearson', pearson_correlation),
        ('spearman', spearman_correlation),
        ('kendall', kendall_correlation),
    )
    return {label: measure(x, y) for label, measure in measures}


def confidence_interval(
    data: np.ndarray,
    confidence: float = 0.95
) -> Tuple[float, float, float]:
    """
    Compute a Student-t confidence interval for the mean.

    NaN and infinite entries are dropped before any computation; boolean
    indexing also flattens multi-dimensional input.

    Args:
        data: Array of values
        confidence: Two-sided confidence level (default 0.95)

    Returns:
        (mean, lower_bound, upper_bound). All three are nan when no finite
        value remains; the bounds are nan when only one finite value
        remains, because the standard error is undefined for n = 1.
    """
    data = np.asarray(data)
    data = data[np.isfinite(data)]
    n = len(data)

    if n == 0:
        return float('nan'), float('nan'), float('nan')

    mean = float(np.mean(data))

    # Bail out before touching stats.sem: with a single observation it
    # divides by zero degrees of freedom and emits a warning for a value
    # that would be thrown away anyway.
    if n < 2:
        return mean, float('nan'), float('nan')

    se = stats.sem(data)
    # Half-width from the two-sided t quantile with n - 1 degrees of freedom.
    h = se * stats.t.ppf((1 + confidence) / 2, n - 1)

    return mean, float(mean - h), float(mean + h)


def bootstrap_confidence_interval(
    data: np.ndarray,
    statistic: callable = np.mean,
    confidence: float = 0.95,
    n_bootstrap: int = 1000,
    random_state: Optional[int] = None
) -> Tuple[float, float, float]:
    """
    Estimate a percentile bootstrap confidence interval for a statistic.

    Non-finite entries are removed, the statistic is evaluated once on the
    cleaned data, and then once on each of ``n_bootstrap`` resamples drawn
    with replacement at the size of the cleaned data.

    Args:
        data: Array of values
        statistic: Function to compute statistic
        confidence: Confidence level
        n_bootstrap: Number of bootstrap samples
        random_state: Random seed

    Returns:
        (statistic_value, lower_bound, upper_bound); all nan when no
        finite value is available
    """
    rng = np.random.default_rng(random_state)

    values = np.asarray(data)
    values = values[np.isfinite(values)]
    n_obs = len(values)

    if n_obs == 0:
        nan = float('nan')
        return nan, nan, nan

    # Statistic on the observed (cleaned) sample.
    point_estimate = statistic(values)

    # One replicate per resample, drawn with replacement.
    replicates = np.array([
        statistic(rng.choice(values, size=n_obs, replace=True))
        for _ in range(n_bootstrap)
    ])

    # Cut an equal share of probability mass from each tail.
    tail = (1 - confidence) / 2
    lower_bound = np.percentile(replicates, 100 * tail)
    upper_bound = np.percentile(replicates, 100 * (1 - tail))

    return float(point_estimate), float(lower_bound), float(upper_bound)


def linear_regression(
    x: np.ndarray,
    y: np.ndarray
) -> Dict[str, float]:
    """
    Perform simple linear regression y = a + b*x.

    Inputs are flattened to 1-D and every pair containing a NaN or
    infinite value is dropped before fitting.

    Args:
        x: Independent variable
        y: Dependent variable

    Returns:
        Dict with slope, intercept, r_squared, r_value, p_value and
        std_error. When the fit is undefined (fewer than three finite
        pairs, or all remaining x values identical) every entry is nan
        except p_value, which is 1.0. The same keys are always present.
    """
    x = np.asarray(x).flatten()
    y = np.asarray(y).flatten()

    # Remove pairs with NaN or inf in either coordinate
    valid = np.isfinite(x) & np.isfinite(y)
    x = x[valid]
    y = y[valid]

    # A slope cannot be estimated from fewer than three points or from a
    # constant x (stats.linregress raises ValueError for the latter), so
    # report an undefined fit instead of crashing.
    if len(x) < 3 or np.all(x == x[0]):
        return {
            'slope': float('nan'),
            'intercept': float('nan'),
            'r_squared': float('nan'),
            'r_value': float('nan'),
            'p_value': 1.0,
            'std_error': float('nan'),
        }

    slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)

    return {
        'slope': float(slope),
        'intercept': float(intercept),
        'r_squared': float(r_value ** 2),
        'r_value': float(r_value),
        'p_value': float(p_value),
        'std_error': float(std_err),
    }


def log_linear_regression(
    x: np.ndarray,
    y: np.ndarray
) -> Dict[str, float]:
    """
    Fit log(y) = a + b*log(x) by ordinary linear regression.

    Suited to power-law relationships: the fitted slope b is reported as
    the exponent and exp(a) as the multiplicative coefficient. Only pairs
    where both values are finite and strictly positive are used, since
    the logarithm is undefined otherwise.

    Args:
        x: Independent variable
        y: Dependent variable

    Returns:
        Dict with exponent, coefficient, r_squared and p_value; with
        fewer than three usable pairs the first three are nan and
        p_value is 1.0
    """
    xs = np.asarray(x).flatten()
    ys = np.asarray(y).flatten()

    # Restrict to pairs that can safely be log-transformed.
    positive = (xs > 0) & (ys > 0)
    finite = np.isfinite(xs) & np.isfinite(ys)
    usable = positive & finite
    xs, ys = xs[usable], ys[usable]

    if len(xs) >= 3:
        fit = linear_regression(np.log(xs), np.log(ys))
        return {
            'exponent': fit['slope'],
            'coefficient': np.exp(fit['intercept']),
            'r_squared': fit['r_squared'],
            'p_value': fit['p_value'],
        }

    nan = float('nan')
    return {
        'exponent': nan,
        'coefficient': nan,
        'r_squared': nan,
        'p_value': 1.0,
    }


def ttest_independent(
    group1: np.ndarray,
    group2: np.ndarray,
    *,
    equal_var: bool = True
) -> Tuple[float, float]:
    """
    Perform a two-sided independent samples t-test.

    NaN and infinite entries are removed from each group independently
    before testing.

    Args:
        group1: First group
        group2: Second group
        equal_var: If True (default), run Student's t-test, which assumes
            equal population variances. If False, run Welch's t-test,
            which does not make that assumption.

    Returns:
        (t_statistic, p_value); (nan, 1.0) when either group has fewer
        than two finite values
    """
    group1 = np.asarray(group1)
    group2 = np.asarray(group2)

    group1 = group1[np.isfinite(group1)]
    group2 = group2[np.isfinite(group2)]

    # A variance estimate needs at least two observations per group.
    if len(group1) < 2 or len(group2) < 2:
        return float('nan'), 1.0

    t_stat, p_val = stats.ttest_ind(group1, group2, equal_var=equal_var)
    return float(t_stat), float(p_val)


def mann_whitney_u(
    group1: np.ndarray,
    group2: np.ndarray
) -> Tuple[float, float]:
    """
    Compare two groups with the two-sided Mann-Whitney U test.

    This is a non-parametric test. Non-finite entries are filtered out
    of each group separately before the test is run.

    Args:
        group1: First group
        group2: Second group

    Returns:
        (u_statistic, p_value); (nan, 1.0) when either group has fewer
        than two finite values
    """
    first = np.asarray(group1)
    second = np.asarray(group2)

    first = first[np.isfinite(first)]
    second = second[np.isfinite(second)]

    # Both groups must contribute at least two observations.
    if min(len(first), len(second)) < 2:
        return float('nan'), 1.0

    outcome = stats.mannwhitneyu(first, second, alternative='two-sided')
    u_value, p_value = outcome
    return float(u_value), float(p_value)


def effect_size_cohens_d(
    group1: np.ndarray,
    group2: np.ndarray
) -> float:
    """
    Compute Cohen's d: the mean difference in pooled-standard-deviation units.

    Non-finite entries are dropped from each group first. The pooled
    standard deviation combines the two sample variances (ddof=1),
    weighted by their degrees of freedom.

    Args:
        group1: First group
        group2: Second group

    Returns:
        (mean(group1) - mean(group2)) / pooled_std, or nan when either
        group has fewer than two finite values or the pooled standard
        deviation is zero
    """
    a = np.asarray(group1)
    b = np.asarray(group2)
    a = a[np.isfinite(a)]
    b = b[np.isfinite(b)]

    n_a, n_b = len(a), len(b)
    if n_a < 2 or n_b < 2:
        return float('nan')

    mean_a, mean_b = np.mean(a), np.mean(b)
    var_a = np.var(a, ddof=1)
    var_b = np.var(b, ddof=1)

    # Weight each variance by its degrees of freedom.
    weighted_var_sum = (n_a - 1) * var_a + (n_b - 1) * var_b
    pooled_sd = np.sqrt(weighted_var_sum / (n_a + n_b - 2))

    # No spread in either group: the effect size is undefined.
    if pooled_sd == 0:
        return float('nan')

    difference = mean_a - mean_b
    return float(difference / pooled_sd)


def summary_statistics(data: np.ndarray) -> Dict[str, float]:
    """
    Compute summary statistics for a dataset.

    NaN and infinite entries are excluded from every statistic and from
    the reported count.

    Args:
        data: Array of values

    Returns:
        Dict with mean, std (sample standard deviation, ddof=1), median,
        min, max, q25, q75 and n (number of finite values, an int). The
        same keys are always present: with no finite values every
        statistic is nan and n is 0; with a single finite value std is
        nan.
    """
    data = np.asarray(data)
    data = data[np.isfinite(data)]
    n = len(data)

    if n == 0:
        nan = float('nan')
        return {
            'mean': nan,
            'std': nan,
            'median': nan,
            'min': nan,
            'max': nan,
            'q25': nan,
            'q75': nan,
            'n': 0,
        }

    # The sample standard deviation is undefined for one observation;
    # report nan directly rather than letting numpy warn about ddof.
    std = float(np.std(data, ddof=1)) if n > 1 else float('nan')

    return {
        'mean': float(np.mean(data)),
        'std': std,
        'median': float(np.median(data)),
        'min': float(np.min(data)),
        'max': float(np.max(data)),
        'q25': float(np.percentile(data, 25)),
        'q75': float(np.percentile(data, 75)),
        'n': n,
    }
