"""Regression prediction metrics for surrogate model evaluation.

This module provides metrics for evaluating property prediction performance
on held-out test sets, including R-squared, MSE, MAE, and RMSE.
"""

from __future__ import annotations

from dataclasses import dataclass, asdict
from typing import Dict, List, Sequence

import numpy as np


@dataclass
class RegressionMetrics:
    """Container for per-property and aggregate regression scores.

    Per-property lists are index-aligned with ``property_names``: entry ``k``
    of ``mse``/``rmse``/``mae``/``r2`` belongs to ``property_names[k]``.

    Attributes:
        property_names: Names of the properties being predicted
        n_samples: Number of samples (rows) evaluated
        mse: Mean squared error per property
        rmse: Root mean squared error per property
        mae: Mean absolute error per property
        r2: R-squared (coefficient of determination) per property
        mse_overall: Overall MSE averaged across properties
        rmse_overall: Overall RMSE averaged across properties
        mae_overall: Overall MAE averaged across properties
        r2_overall: Overall R^2 averaged across properties
    """

    property_names: List[str]
    n_samples: int
    mse: List[float]
    rmse: List[float]
    mae: List[float]
    r2: List[float]
    mse_overall: float
    rmse_overall: float
    mae_overall: float
    r2_overall: float

    def to_dict(self) -> Dict:
        """Return all fields as a plain dict (recursively copied via ``asdict``).

        Values are passed through unchanged, so NaN metrics stay ``float('nan')``;
        ``json.dumps`` emits these as the non-standard ``NaN`` token by default.
        """
        return asdict(self)

    def summary(self) -> str:
        """Render the metrics as a multi-line, human-readable table.

        The table has one line per property and a final ``Overall`` line. Every
        value is shown with four decimals; NaN values render as ``nan``.
        """

        def fmt(label, r2, mse, rmse, mae):
            return f"  {label}: R2={r2:.4f}, MSE={mse:.4f}, RMSE={rmse:.4f}, MAE={mae:.4f}"

        rule = "-" * 40
        header = f"Regression Metrics (n={self.n_samples})"
        per_property = [
            fmt(label, self.r2[k], self.mse[k], self.rmse[k], self.mae[k])
            for k, label in enumerate(self.property_names)
        ]
        totals = fmt(
            "Overall", self.r2_overall, self.mse_overall, self.rmse_overall, self.mae_overall
        )
        return "\n".join([header, rule, *per_property, rule, totals])


def compute_r2(y_true: np.ndarray, y_pred: np.ndarray, *, eps: float = 1e-12) -> float:
    """Compute R-squared (coefficient of determination).

    R^2 = 1 - SS_res / SS_tot, where SS_res is the sum of squared residuals and
    SS_tot is the sum of squared deviations of ``y_true`` from its mean.

    Args:
        y_true: Ground truth values (any array-like; converted to float64).
        y_pred: Predicted values; must have exactly the same shape as ``y_true``.
        eps: SS_tot values below this are treated as zero variance. This is an
            absolute threshold on a sum of squares, so data on a very small scale
            (e.g. values around 1e-7) may need a smaller ``eps``. Defaults to 1e-12.

    Returns:
        R-squared value. Returns 0.0 when SS_tot < eps, i.e. when ``y_true`` is
        (numerically) constant and R^2 is undefined. NaNs in the inputs propagate
        to the result.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` have different shapes.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    # Without this check, e.g. (N,) vs (N, 1) would silently broadcast to (N, N)
    # and produce a meaningless residual sum.
    if y_true.shape != y_pred.shape:
        raise ValueError(f"Shape mismatch: y_true={y_true.shape}, y_pred={y_pred.shape}")

    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    if ss_tot < eps:
        return 0.0
    return float(1.0 - ss_res / ss_tot)


def _mean_ignoring_nan(values: Sequence[float]) -> float:
    """Return the mean of the non-NaN entries of *values*, or NaN if none remain."""
    kept = [v for v in values if not np.isnan(v)]
    return float(np.mean(kept)) if kept else float("nan")


def compute_regression_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    property_names: Sequence[str] | None = None,
) -> RegressionMetrics:
    """Compute comprehensive regression metrics for property prediction.

    Each property (column) is scored independently after dropping rows where
    either the true or the predicted value is NaN. A property with no remaining
    rows gets NaN for every metric. Infinite values are not filtered.

    The ``*_overall`` values are unweighted means of the per-property values,
    skipping NaN entries. They are NaN if every property is NaN. Note that
    ``rmse_overall`` is the mean of the per-property RMSEs, not
    ``sqrt(mse_overall)``.

    Args:
        y_true: Ground truth values of shape (N, n_props) or (N,)
        y_pred: Predicted values of shape (N, n_props) or (N,)
        property_names: Names of properties. If None, uses "prop_0", "prop_1", etc.
            A single string is treated as one property name.

    Returns:
        RegressionMetrics dataclass with per-property and overall metrics.
        ``n_samples`` is N, the total row count, including rows dropped as NaN
        for individual properties.

    Raises:
        ValueError: If shapes don't match, inputs are not 1D/2D, inputs are
            empty, or the number of property names differs from n_props.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)

    # A 1D input represents a single property: promote it to one column.
    if y_true.ndim == 1:
        y_true = y_true[:, np.newaxis]
    if y_pred.ndim == 1:
        y_pred = y_pred[:, np.newaxis]

    if y_true.shape != y_pred.shape:
        raise ValueError(f"Shape mismatch: y_true={y_true.shape}, y_pred={y_pred.shape}")
    # Shapes are equal here, so checking y_true covers both. Without this, scalar
    # or >2D inputs fail below with an opaque tuple-unpacking error.
    if y_true.ndim != 2:
        raise ValueError(
            f"Expected 1D or 2D arrays, got shape {y_true.shape} ({y_true.ndim}D)"
        )

    n_samples, n_props = y_true.shape
    if n_samples == 0:
        raise ValueError("Cannot compute metrics on empty arrays")

    if property_names is None:
        names = [f"prop_{i}" for i in range(n_props)]
    else:
        # A str is itself a Sequence[str]; list("abc") would split it into characters.
        names = [property_names] if isinstance(property_names, str) else list(property_names)
        if len(names) != n_props:
            raise ValueError(
                f"property_names length ({len(names)}) != n_props ({n_props})"
            )

    mse_list: List[float] = []
    rmse_list: List[float] = []
    mae_list: List[float] = []
    r2_list: List[float] = []

    # Transposing yields one 1D column per property, paired true/pred.
    for col_true, col_pred in zip(y_true.T, y_pred.T):
        valid = ~(np.isnan(col_true) | np.isnan(col_pred))
        if not valid.any():
            # Every row is NaN for this property, so no metric is defined.
            for bucket in (mse_list, rmse_list, mae_list, r2_list):
                bucket.append(float("nan"))
            continue

        t = col_true[valid]
        p = col_pred[valid]
        errors = t - p
        mse = float(np.mean(errors**2))
        mse_list.append(mse)
        rmse_list.append(float(np.sqrt(mse)))
        mae_list.append(float(np.mean(np.abs(errors))))
        r2_list.append(compute_r2(t, p))

    return RegressionMetrics(
        property_names=names,
        n_samples=n_samples,
        mse=mse_list,
        rmse=rmse_list,
        mae=mae_list,
        r2=r2_list,
        mse_overall=_mean_ignoring_nan(mse_list),
        rmse_overall=_mean_ignoring_nan(rmse_list),
        mae_overall=_mean_ignoring_nan(mae_list),
        r2_overall=_mean_ignoring_nan(r2_list),
    )


def compute_regression_metrics_per_split(
    y_true_dict: Dict[str, np.ndarray],
    y_pred_dict: Dict[str, np.ndarray],
    property_names: Sequence[str] | None = None,
) -> Dict[str, RegressionMetrics]:
    """Evaluate regression metrics separately for each named data split.

    Only splits present in both dicts are evaluated. A split that appears in
    ``y_true_dict`` but not in ``y_pred_dict`` is skipped silently. Splits that
    appear only in ``y_pred_dict`` are ignored. The result follows the iteration
    order of ``y_true_dict``.

    Args:
        y_true_dict: Dict mapping split name -> ground truth array
        y_pred_dict: Dict mapping split name -> prediction array
        property_names: Names of properties, forwarded unchanged to every split

    Returns:
        Dict mapping split name -> RegressionMetrics
    """
    shared_splits = (name for name in y_true_dict if name in y_pred_dict)
    return {
        name: compute_regression_metrics(y_true_dict[name], y_pred_dict[name], property_names)
        for name in shared_splits
    }
