"""Oracle evaluation utilities for comparing surrogate predictions to ground truth.

This module provides functions for:
- Computing oracle (RDKit-based) properties for generated/optimized molecules
- Comparing surrogate model predictions against oracle values
- Generating comparison plots and reports
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Sequence

import numpy as np
import matplotlib.pyplot as plt

from moltenflow.data.properties import compute_properties_batch
from moltenflow.eval.prediction_metrics import compute_regression_metrics, RegressionMetrics
from moltenflow.utils.logging import get_logger

# Module-level logger, obtained from the project's logging helper with this
# module's name; the functions below use it for info and warning messages.
logger = get_logger(__name__)


@dataclass
class OracleEvalResult:
    """Container for the outcome of an oracle property evaluation.

    Attributes:
        smiles: The SMILES strings that were evaluated.
        oracle_properties: Oracle-computed property values, documented as
            having shape (N, n_props).
        valid_mask: Boolean mask flagging which molecules are valid.
        property_names: Names of the evaluated properties.
        n_total: How many molecules were evaluated in total.
        n_valid: How many of those molecules are valid.
        valid_frac: Valid molecules as a fraction of the total.
    """

    smiles: List[str]
    oracle_properties: np.ndarray
    valid_mask: np.ndarray
    property_names: List[str]
    n_total: int
    n_valid: int
    valid_frac: float

    def to_dict(self) -> Dict:
        """Return the JSON-friendly summary fields as a dictionary.

        Only the counts, the valid fraction and the property names are
        included; the SMILES list and the two arrays are left out.
        """
        summary_fields = ("n_total", "n_valid", "valid_frac", "property_names")
        return {field_name: getattr(self, field_name) for field_name in summary_fields}


@dataclass
class SurrogateOracleComparison:
    """Outcome of comparing surrogate predictions with oracle values.

    Attributes:
        regression_metrics: Regression metrics object for the comparison;
            it must provide a ``to_dict()`` method.
        n_compared: Number of samples that entered the comparison.
        property_names: Names of the compared properties.
    """

    regression_metrics: RegressionMetrics
    n_compared: int
    property_names: List[str]

    def to_dict(self) -> Dict:
        """Return a dictionary form suitable for JSON serialization.

        The nested metrics are serialized through their own ``to_dict()``.
        """
        payload: Dict = {}
        payload["n_compared"] = self.n_compared
        payload["property_names"] = self.property_names
        payload["regression_metrics"] = self.regression_metrics.to_dict()
        return payload


def compute_oracle_properties(
    smiles_list: Sequence[str],
    property_names: Sequence[str],
) -> OracleEvalResult:
    """Evaluate the oracle properties of a collection of SMILES strings.

    Args:
        smiles_list: SMILES strings to evaluate.
        property_names: Names of the properties to compute
            (e.g., ["qed", "sas", "plogp"]).

    Returns:
        An OracleEvalResult holding the property values, the validity mask
        and the valid/total counts. For an empty ``smiles_list`` the result
        carries a (0, n_props) property array and an empty boolean mask.
    """
    names = list(property_names)
    molecules = list(smiles_list)

    # Nothing to evaluate: hand back an empty, correctly shaped result
    # without calling the oracle at all.
    if not molecules:
        return OracleEvalResult(
            smiles=[],
            oracle_properties=np.empty((0, len(names))),
            valid_mask=np.zeros(0, dtype=bool),
            property_names=names,
            n_total=0,
            n_valid=0,
            valid_frac=0.0,
        )

    total = len(molecules)

    # The batch helper returns the property values together with a
    # per-molecule validity mask when asked for it.
    oracle_values, is_valid = compute_properties_batch(
        molecules, names, return_valid_mask=True
    )

    valid_count = int(is_valid.sum())
    # total is at least 1 here, so the division is always defined.
    fraction = valid_count / total

    logger.info(f"Oracle evaluation: {valid_count}/{total} valid molecules ({fraction:.1%})")

    return OracleEvalResult(
        smiles=molecules,
        oracle_properties=oracle_values,
        valid_mask=is_valid,
        property_names=names,
        n_total=total,
        n_valid=valid_count,
        valid_frac=fraction,
    )


def compare_surrogate_to_oracle(
    surrogate_predictions: np.ndarray,
    oracle_properties: np.ndarray,
    property_names: Sequence[str],
    valid_mask: np.ndarray | None = None,
) -> SurrogateOracleComparison:
    """Score surrogate predictions against the oracle's property values.

    Rows whose oracle values contain any NaN are always dropped. When
    ``valid_mask`` is given, rows it marks as invalid are dropped as well.

    Args:
        surrogate_predictions: Surrogate-predicted properties (N, n_props).
        oracle_properties: Oracle-computed properties (N, n_props).
        property_names: Names of the properties.
        valid_mask: Optional per-row validity mask.

    Returns:
        A SurrogateOracleComparison with the regression metrics. When no
        row survives the filtering, every metric is NaN and ``n_compared``
        is 0.
    """
    predictions = np.asarray(surrogate_predictions, dtype=np.float64)
    oracle = np.asarray(oracle_properties, dtype=np.float64)
    names = list(property_names)

    # Combine the optional caller-supplied mask with the NaN-free oracle rows.
    caller_mask = None if valid_mask is None else np.asarray(valid_mask, dtype=bool)
    nan_free_rows = ~np.isnan(oracle).any(axis=1)
    keep = nan_free_rows if caller_mask is None else caller_mask & nan_free_rows

    n_kept = int(keep.sum())

    if n_kept > 0:
        # Main path: compute the metrics on the surviving rows only.
        result_metrics = compute_regression_metrics(
            y_true=oracle[keep],
            y_pred=predictions[keep],
            property_names=names,
        )

        logger.info(f"Surrogate vs Oracle comparison ({n_kept} samples):")
        for idx, name in enumerate(names):
            logger.info(
                f"  {name}: R2={result_metrics.r2[idx]:.3f}, MAE={result_metrics.mae[idx]:.4f}, RMSE={result_metrics.rmse[idx]:.4f}"
            )

        return SurrogateOracleComparison(
            regression_metrics=result_metrics,
            n_compared=n_kept,
            property_names=names,
        )

    # Fallback: nothing to compare, so report NaN for every metric.
    logger.warning("No valid samples for surrogate vs oracle comparison")

    def _nan_per_property() -> list:
        # A fresh list on every call so the metric fields do not alias.
        return [float("nan")] * len(names)

    placeholder = RegressionMetrics(
        property_names=names,
        n_samples=0,
        mse=_nan_per_property(),
        rmse=_nan_per_property(),
        mae=_nan_per_property(),
        r2=_nan_per_property(),
        mse_overall=float("nan"),
        rmse_overall=float("nan"),
        mae_overall=float("nan"),
        r2_overall=float("nan"),
    )
    return SurrogateOracleComparison(
        regression_metrics=placeholder,
        n_compared=0,
        property_names=names,
    )


def plot_surrogate_vs_oracle(
    surrogate_predictions: np.ndarray,
    oracle_properties: np.ndarray,
    property_names: Sequence[str],
    save_path: str,
    valid_mask: np.ndarray | None = None,
    figsize: tuple[int, int] | None = None,
    title: str = "Surrogate vs Oracle Properties",
) -> None:
    """Generate scatter plots comparing surrogate predictions to oracle values.

    Creates a grid of scatter plots (one per property, at most three per
    row) with:
    - Semi-transparent points (oracle on x, surrogate on y)
    - Diagonal reference line (y=x)
    - R^2 and MAE shown in each subplot title

    Rows with any NaN oracle value are dropped, as are rows rejected by
    ``valid_mask``. Within each property, points with a non-finite
    prediction or oracle value are skipped so that a stray NaN/inf cannot
    break the axis limits.

    Nothing is drawn or saved (only a warning is logged) when there are no
    properties or no valid rows.

    Args:
        surrogate_predictions: Surrogate-predicted properties (N, n_props)
        oracle_properties: Oracle-computed properties (N, n_props)
        property_names: Names of the properties
        save_path: Path to save the figure
        valid_mask: Optional mask for valid molecules
        figsize: Optional figure size. Default: auto-calculated based on n_props
        title: Figure title
    """
    surrogate_predictions = np.asarray(surrogate_predictions, dtype=np.float64)
    oracle_properties = np.asarray(oracle_properties, dtype=np.float64)
    property_names = list(property_names)
    n_props = len(property_names)

    if n_props == 0:
        # With zero properties the grid below would have zero columns and
        # the row computation would divide by zero.
        logger.warning("No properties given for plotting surrogate vs oracle")
        return

    # Determine valid samples: NaN-free oracle rows, optionally intersected
    # with the caller's mask.
    row_ok = ~np.any(np.isnan(oracle_properties), axis=1)
    if valid_mask is not None:
        row_ok = np.asarray(valid_mask, dtype=bool) & row_ok

    pred_valid = surrogate_predictions[row_ok]
    oracle_valid = oracle_properties[row_ok]

    if len(pred_valid) == 0:
        logger.warning("No valid samples for plotting surrogate vs oracle")
        return

    # Grid layout: up to three columns, as many rows as needed.
    n_cols = min(n_props, 3)
    n_rows = (n_props + n_cols - 1) // n_cols
    if figsize is None:
        figsize = (5 * n_cols, 5 * n_rows)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    try:
        axes = axes.flatten()

        for i, prop_name in enumerate(property_names):
            ax = axes[i]

            pred_i = pred_valid[:, i]
            oracle_i = oracle_valid[:, i]

            # Keep only finite pairs; infinities would make the axis limits
            # invalid and NaNs would poison the statistics.
            finite_i = np.isfinite(pred_i) & np.isfinite(oracle_i)
            pred_i = pred_i[finite_i]
            oracle_i = oracle_i[finite_i]

            if len(pred_i) == 0:
                ax.text(
                    0.5, 0.5, "No valid data",
                    ha="center", va="center", transform=ax.transAxes,
                )
                ax.set_title(prop_name)
                continue

            # Scatter plot
            ax.scatter(oracle_i, pred_i, alpha=0.5, s=10, edgecolors="none")

            # Shared square limits covering both axes, padded by 5%.
            low = min(oracle_i.min(), pred_i.min())
            high = max(oracle_i.max(), pred_i.max())
            span = high - low
            if span > 0:
                margin = span * 0.05
            else:
                # All points coincide: pad by a nonzero amount so the lower
                # and upper limits are not identical.
                margin = abs(low) * 0.05 if low != 0 else 1.0
            lims = [low - margin, high + margin]
            ax.plot(lims, lims, "k--", alpha=0.5, linewidth=1, label="y=x")
            ax.set_xlim(lims)
            ax.set_ylim(lims)

            # R^2, reported as 0.0 when the oracle values have (near) zero
            # variance and the ratio would be undefined.
            ss_res = np.sum((pred_i - oracle_i) ** 2)
            ss_tot = np.sum((oracle_i - oracle_i.mean()) ** 2)
            r2 = 1 - (ss_res / ss_tot) if ss_tot > 1e-10 else 0.0

            # Mean absolute error
            mae = np.mean(np.abs(pred_i - oracle_i))

            ax.set_xlabel(f"Oracle {prop_name}")
            ax.set_ylabel(f"Surrogate {prop_name}")
            ax.set_title(f"{prop_name}\nR$^2$={r2:.3f}, MAE={mae:.3f}")
            ax.set_aspect("equal", adjustable="box")

        # Hide the grid cells that have no property assigned.
        for unused_ax in axes[n_props:]:
            unused_ax.set_visible(False)

        fig.suptitle(title, fontsize=14, y=1.02)
        # Operate on this figure explicitly instead of the implicit
        # "current figure" tracked by pyplot.
        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    logger.info(f"Saved surrogate vs oracle plot to {save_path}")


def evaluate_optimization_with_oracle(
    input_smiles: Sequence[str],
    output_smiles: Sequence[str],
    surrogate_predictions_before: np.ndarray,
    surrogate_predictions_after: np.ndarray,
    property_names: Sequence[str],
    output_dir: str | None = None,
) -> Dict:
    """Comprehensive oracle evaluation for optimization results.

    Computes:
    - Oracle properties for input and output molecules
    - Surrogate vs oracle comparison for both before and after
    - Property deltas (oracle-based improvement) over the pairs where both
      the input and the output molecule are valid and have NaN-free oracle
      values

    Args:
        input_smiles: Input SMILES strings
        output_smiles: Optimized output SMILES strings; element i is paired
            with element i of ``input_smiles``
        surrogate_predictions_before: Surrogate predictions for inputs (N, n_props)
        surrogate_predictions_after: Surrogate predictions for outputs (N, n_props)
        property_names: Names of the properties
        output_dir: Optional directory to save plots; created if missing

    Returns:
        Dictionary with the keys ``input_oracle``, ``input_oracle_properties``,
        ``output_oracle``, ``output_oracle_properties``,
        ``surrogate_oracle_comparison_before``,
        ``surrogate_oracle_comparison_after`` and ``oracle_deltas``.

    Raises:
        ValueError: If ``input_smiles`` and ``output_smiles`` differ in length,
            since the deltas are computed pair by pair.
    """
    import os

    property_names = list(property_names)
    input_smiles = list(input_smiles)
    output_smiles = list(output_smiles)

    # Fail before the oracle work: unpaired inputs/outputs would otherwise
    # surface later as an opaque numpy broadcasting or indexing error.
    if len(input_smiles) != len(output_smiles):
        raise ValueError(
            "input_smiles and output_smiles must have the same length "
            f"(got {len(input_smiles)} and {len(output_smiles)})"
        )

    results = {}

    # Compute oracle properties for inputs
    logger.info("Computing oracle properties for input molecules...")
    input_oracle = compute_oracle_properties(input_smiles, property_names)
    results["input_oracle"] = input_oracle.to_dict()
    results["input_oracle_properties"] = input_oracle.oracle_properties.tolist()

    # Compute oracle properties for outputs
    logger.info("Computing oracle properties for output molecules...")
    output_oracle = compute_oracle_properties(output_smiles, property_names)
    results["output_oracle"] = output_oracle.to_dict()
    results["output_oracle_properties"] = output_oracle.oracle_properties.tolist()

    # Compare surrogate to oracle for inputs (before optimization)
    logger.info("Comparing surrogate predictions (before) to oracle...")
    comparison_before = compare_surrogate_to_oracle(
        surrogate_predictions_before,
        input_oracle.oracle_properties,
        property_names,
        valid_mask=input_oracle.valid_mask,
    )
    results["surrogate_oracle_comparison_before"] = comparison_before.to_dict()

    # Compare surrogate to oracle for outputs (after optimization)
    logger.info("Comparing surrogate predictions (after) to oracle...")
    comparison_after = compare_surrogate_to_oracle(
        surrogate_predictions_after,
        output_oracle.oracle_properties,
        property_names,
        valid_mask=output_oracle.valid_mask,
    )
    results["surrogate_oracle_comparison_after"] = comparison_after.to_dict()

    # Oracle-based deltas over pairs that are usable on both sides. Besides
    # the validity masks, rows containing NaN oracle values are excluded
    # (the same rule compare_surrogate_to_oracle applies), so a single NaN
    # cannot turn the mean/std into NaN.
    input_values = np.asarray(input_oracle.oracle_properties, dtype=np.float64)
    output_values = np.asarray(output_oracle.oracle_properties, dtype=np.float64)
    valid_both = (
        np.asarray(input_oracle.valid_mask, dtype=bool)
        & np.asarray(output_oracle.valid_mask, dtype=bool)
        & ~np.any(np.isnan(input_values), axis=1)
        & ~np.any(np.isnan(output_values), axis=1)
    )
    n_valid_both = int(valid_both.sum())

    if n_valid_both > 0:
        deltas = output_values[valid_both] - input_values[valid_both]
        # Per-property statistics, computed once and reused for logging.
        mean_delta = deltas.mean(axis=0)
        std_delta = deltas.std(axis=0)

        results["oracle_deltas"] = {
            "n_samples": n_valid_both,
            "mean_delta": mean_delta.tolist(),
            "std_delta": std_delta.tolist(),
            "property_names": property_names,
        }

        logger.info(f"Oracle-based property changes ({n_valid_both} valid pairs):")
        for i, prop in enumerate(property_names):
            logger.info(f"  {prop}: {mean_delta[i]:+.4f} +/- {std_delta[i]:.4f}")
    else:
        results["oracle_deltas"] = {
            "n_samples": 0,
            "mean_delta": [float("nan")] * len(property_names),
            "std_delta": [float("nan")] * len(property_names),
            "property_names": property_names,
        }

    # Generate plots if output directory provided
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

        # Plot for before optimization
        plot_surrogate_vs_oracle(
            surrogate_predictions_before,
            input_oracle.oracle_properties,
            property_names,
            os.path.join(output_dir, "surrogate_vs_oracle_before.png"),
            valid_mask=input_oracle.valid_mask,
            title="Surrogate vs Oracle (Input Molecules)",
        )

        # Plot for after optimization
        plot_surrogate_vs_oracle(
            surrogate_predictions_after,
            output_oracle.oracle_properties,
            property_names,
            os.path.join(output_dir, "surrogate_vs_oracle_after.png"),
            valid_mask=output_oracle.valid_mask,
            title="Surrogate vs Oracle (Optimized Molecules)",
        )

    return results
