"""
Classification task runner for cost-sensitive learning experiments.

This module provides a unified interface for running classification experiments
with different models, weighting strategies, and evaluation metrics.

Supports:
- Multiple model types (TF-IDF, embeddings, tabular)
- Weighting strategies: unweighted, |delta|-weighted, alpha-balanced
- Train/val/test splits with reproducible seeding
- Comprehensive metrics: accuracy, weighted accuracy, expected cost
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, Any
import numpy as np
import pandas as pd

from models.base import BaseModel
from core.metrics import classification_metrics
from core.weights import make_absdelta_weights, make_alpha_balanced_weights
from core.seed import set_seed


@dataclass
class ClassifyResult:
    """
    Outcome of one classification run.

    Attributes:
        model_name: Identifier of the model that produced the result
        weighting: Name of the weighting strategy ('none', 'absdelta', 'alpha_balanced')
        seed: Random seed the run was executed with
        train_metrics: Metric name -> value, measured on the training split
        val_metrics: Metric name -> value on the validation split, or None
        test_metrics: Metric name -> value, measured on the test split
        config: Configuration dict kept alongside the metrics for reproducibility
    """
    model_name: str
    weighting: str
    seed: int
    train_metrics: Dict[str, float]
    val_metrics: Optional[Dict[str, float]]
    test_metrics: Dict[str, float]
    config: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Flatten the result into a single-level dict (one DataFrame row).

        Identification fields come first, followed by every metric keyed as
        ``<split>_<metric>`` in train, val, test order. A missing or empty
        ``val_metrics`` contributes no keys. ``config`` is not included.
        """
        flat: Dict[str, Any] = {
            'model_name': self.model_name,
            'weighting': self.weighting,
            'seed': self.seed,
        }

        per_split = (
            ('train', self.train_metrics),
            ('val', self.val_metrics or {}),  # None / empty -> nothing to add
            ('test', self.test_metrics),
        )
        for split, split_metrics in per_split:
            flat.update(
                {f'{split}_{name}': value for name, value in split_metrics.items()}
            )

        return flat


def _accepts_indices(method) -> bool:
    """Return True if *method* declares an ``indices`` parameter usable by keyword."""
    # Local import keeps this block self-contained; only the signature API is needed.
    import inspect

    try:
        params = inspect.signature(method).parameters
    except (TypeError, ValueError):
        # Signature cannot be introspected (e.g. some C-implemented callables):
        # fall back to the plain call without indices.
        return False

    # Only an explicitly declared parameter counts. Unlike scanning
    # ``__code__.co_varnames`` this ignores local variables that merely happen
    # to be called ``indices``.
    param = params.get('indices')
    return param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def _predict_split(model, X, indices):
    """Predict on one split, forwarding ``indices`` when given and supported."""
    if indices is not None and _accepts_indices(model.predict):
        return model.predict(X, indices=indices)
    return model.predict(X)


def _abs_delta_eval_weights(delta, alpha=None):
    """Return |delta| as float32 evaluation weights.

    When *alpha* is given, the weights of samples with ``delta > 0`` are
    multiplied by it.
    """
    weights = np.abs(delta).astype(np.float32)
    if alpha is not None:
        # np.asarray lets plain sequences be compared element-wise as well.
        weights[np.asarray(delta) > 0] *= alpha
    return weights


def run_classification(
    model: BaseModel,
    X_train: Union[np.ndarray, List[str], pd.DataFrame],
    y_train: np.ndarray,
    X_test: Union[np.ndarray, List[str], pd.DataFrame],
    y_test: np.ndarray,
    delta_train: np.ndarray,
    delta_test: np.ndarray,
    X_val: Optional[Union[np.ndarray, List[str], pd.DataFrame]] = None,
    y_val: Optional[np.ndarray] = None,
    delta_val: Optional[np.ndarray] = None,
    weighting: str = 'none',
    seed: int = 42,
    model_name: Optional[str] = None,
    # Indices for pre-computed embedding lookup
    train_indices: Optional[np.ndarray] = None,
    val_indices: Optional[np.ndarray] = None,
    test_indices: Optional[np.ndarray] = None,
) -> ClassifyResult:
    """
    Run a single classification experiment.

    The model is fitted on the training split (optionally with sample
    weights), then evaluated on train, test and - when fully provided -
    validation data.

    Args:
        model: A BaseModel instance (must have task='classification'); its
            ``fit``, ``predict`` and ``get_params`` methods are used
        X_train: Training features (texts, embeddings, or DataFrame)
        y_train: Training labels (binary 0/1)
        X_test: Test features
        y_test: Test labels
        delta_train: Signed delta values for training set (for weighting)
        delta_test: Signed delta values for test set (for evaluation)
        X_val: Optional validation features
        y_val: Optional validation labels
        delta_val: Optional signed delta values for validation
        weighting: Weighting strategy - 'none', 'absdelta', or 'alpha_balanced'
        seed: Random seed for reproducibility
        model_name: Optional name for the model (defaults to class name)
        train_indices: Optional indices into pre-computed embeddings for train
        val_indices: Optional indices into pre-computed embeddings for val
        test_indices: Optional indices into pre-computed embeddings for test

    Returns:
        ClassifyResult with metrics on train/val/test sets. ``val_metrics``
        is None unless X_val, y_val and delta_val are all provided.

    Raises:
        ValueError: If model task is not 'classification' or invalid weighting

    Notes:
        Index arrays are only forwarded to ``model.fit`` / ``model.predict``
        when the method explicitly declares an ``indices`` parameter.
    """
    if model.task != 'classification':
        raise ValueError(f"Expected classification model, got task='{model.task}'")

    if weighting not in ('none', 'absdelta', 'alpha_balanced'):
        raise ValueError(f"Invalid weighting: {weighting}. Must be 'none', 'absdelta', or 'alpha_balanced'")

    # Set seed for reproducibility
    set_seed(seed)

    # Training weights for the chosen strategy
    if weighting == 'absdelta':
        # |delta| as weights, floored at a small epsilon to avoid zero weights
        sample_weight = np.clip(np.abs(delta_train).astype(np.float32), 1e-6, None)
    elif weighting == 'alpha_balanced':
        # normalize=True to stabilize sklearn solvers with extreme weight ranges
        sample_weight = make_alpha_balanced_weights(
            pd.DataFrame({'delta_signed': delta_train}),
            delta_col='delta_signed',
            normalize=True,
        )
    else:  # 'none'
        sample_weight = None

    # Fit model (pass indices for pre-computed embedding lookup if supported)
    fit_kwargs = {'sample_weight': sample_weight}
    if train_indices is not None and _accepts_indices(model.fit):
        fit_kwargs['indices'] = train_indices
    model.fit(X_train, y_train, **fit_kwargs)

    y_train_pred = _predict_split(model, X_train, train_indices)
    y_test_pred = _predict_split(model, X_test, test_indices)

    # Evaluation weights, consistent with the training strategy:
    # - 'none' or 'absdelta': raw |delta| (no epsilon, no normalization)
    # - 'alpha_balanced': un-normalized alpha-balanced weights on train; the
    #   alpha obtained from the training deltas is reused for test and val
    alpha = None
    if weighting == 'alpha_balanced':
        eval_weights_train, alpha = make_alpha_balanced_weights(
            pd.DataFrame({'delta_signed': delta_train}),
            delta_col='delta_signed',
            normalize=False,
            return_alpha=True,
        )
    else:
        eval_weights_train = np.abs(delta_train).astype(np.float32)
    eval_weights_test = _abs_delta_eval_weights(delta_test, alpha)

    train_metrics = classification_metrics(y_train, y_train_pred, weights=eval_weights_train)
    test_metrics = classification_metrics(y_test, y_test_pred, weights=eval_weights_test)

    # Validation metrics need features, labels and deltas; otherwise skipped
    val_metrics = None
    if X_val is not None and y_val is not None and delta_val is not None:
        y_val_pred = _predict_split(model, X_val, val_indices)
        eval_weights_val = _abs_delta_eval_weights(delta_val, alpha)
        val_metrics = classification_metrics(y_val, y_val_pred, weights=eval_weights_val)

    # Build result
    if model_name is None:
        model_name = model.__class__.__name__

    config = {
        'model_params': model.get_params(),
        'weighting': weighting,
        'seed': seed,
        'n_train': len(y_train),
        'n_test': len(y_test),
        # Counts the labels supplied, even if validation metrics were skipped
        'n_val': len(y_val) if y_val is not None else 0,
    }

    return ClassifyResult(
        model_name=model_name,
        weighting=weighting,
        seed=seed,
        train_metrics=train_metrics,
        val_metrics=val_metrics,
        test_metrics=test_metrics,
        config=config,
    )


def run_classification_sweep(
    model_factory,
    X_train: Union[np.ndarray, List[str], pd.DataFrame],
    y_train: np.ndarray,
    X_test: Union[np.ndarray, List[str], pd.DataFrame],
    y_test: np.ndarray,
    delta_train: np.ndarray,
    delta_test: np.ndarray,
    X_val: Optional[Union[np.ndarray, List[str], pd.DataFrame]] = None,
    y_val: Optional[np.ndarray] = None,
    delta_val: Optional[np.ndarray] = None,
    weightings: Optional[List[str]] = None,
    seeds: Optional[List[int]] = None,
    model_name: Optional[str] = None,
    # Indices for pre-computed embedding lookup (forwarded to every run)
    train_indices: Optional[np.ndarray] = None,
    val_indices: Optional[np.ndarray] = None,
    test_indices: Optional[np.ndarray] = None,
) -> List[ClassifyResult]:
    """
    Run classification experiments across multiple seeds and weighting strategies.

    A fresh model is requested from ``model_factory`` for every
    (weighting, seed) pair and handed to ``run_classification``.

    Args:
        model_factory: Callable that returns a fresh BaseModel instance
        X_train, y_train, X_test, y_test: Train/test data
        delta_train, delta_test: Signed delta values
        X_val, y_val, delta_val: Optional validation data
        weightings: List of weighting strategies to try
            (default: ['none', 'absdelta', 'alpha_balanced'])
        seeds: List of random seeds to try (default: [0, 1, 2])
        model_name: Optional name for results
        train_indices, val_indices, test_indices: Optional indices into
            pre-computed embeddings, passed through unchanged to each run

    Returns:
        List of ClassifyResult, one per (weighting, seed) combination,
        ordered by weighting first and seed second
    """
    # None sentinels instead of list defaults, so no list object is shared
    # between calls.
    if weightings is None:
        weightings = ['none', 'absdelta', 'alpha_balanced']
    if seeds is None:
        seeds = [0, 1, 2]

    results = []

    for weighting in weightings:
        for seed in seeds:
            # Create fresh model instance so runs do not share fitted state
            model = model_factory()

            result = run_classification(
                model=model,
                X_train=X_train,
                y_train=y_train,
                X_test=X_test,
                y_test=y_test,
                delta_train=delta_train,
                delta_test=delta_test,
                X_val=X_val,
                y_val=y_val,
                delta_val=delta_val,
                weighting=weighting,
                seed=seed,
                model_name=model_name,
                train_indices=train_indices,
                val_indices=val_indices,
                test_indices=test_indices,
            )
            results.append(result)

    return results


def results_to_dataframe(results: List[ClassifyResult]) -> pd.DataFrame:
    """
    Build a pandas DataFrame from a list of ClassifyResult objects.

    Each result is flattened through its ``to_dict()`` method and becomes
    one row.

    Args:
        results: List of ClassifyResult objects

    Returns:
        DataFrame with one row per result and one column per flattened key
    """
    return pd.DataFrame([entry.to_dict() for entry in results])


def summarize_results(
    results: List[ClassifyResult],
    group_by: Optional[List[str]] = None,
    metrics: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Summarize results across seeds, computing mean, std and count per metric.

    Args:
        results: List of ClassifyResult objects
        group_by: Columns to group by (default: model_name, weighting)
        metrics: Metrics to summarize (default: test_accuracy,
            test_weighted_accuracy, test_expected_cost). Requested metrics
            that are not columns of the flattened results are ignored.

    Returns:
        DataFrame with the group columns plus ``<metric>_mean``,
        ``<metric>_std`` and ``<metric>_n`` for each available metric.
        With no results, an empty frame holding only the group columns is
        returned; with results but no available metric, one row per group
        holding only the group columns.
    """
    # None sentinels instead of list defaults, so no list object is shared
    # between calls.
    if group_by is None:
        group_by = ['model_name', 'weighting']
    if metrics is None:
        metrics = ['test_accuracy', 'test_weighted_accuracy', 'test_expected_cost']

    # Normalize to a list of column names (a single name is also accepted)
    keys = [group_by] if isinstance(group_by, str) else list(group_by)

    df = results_to_dataframe(results)
    if df.empty:
        # No rows means no columns to group on
        return pd.DataFrame(columns=keys)

    # Filter to requested metrics that exist
    available_metrics = [m for m in metrics if m in df.columns]

    # Named aggregation: output column -> (source column, aggregation)
    agg_dict = {}
    for m in available_metrics:
        agg_dict[f'{m}_mean'] = (m, 'mean')
        agg_dict[f'{m}_std'] = (m, 'std')
        agg_dict[f'{m}_n'] = (m, 'count')

    grouped = df.groupby(keys)
    if not agg_dict:
        # agg() with no aggregations raises TypeError; return the groups only
        return grouped.size().reset_index()[keys]

    summary = grouped.agg(**agg_dict).reset_index()
    return summary
