from typing import Dict, Any, Tuple
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, cross_validate
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, accuracy_score, f1_score, roc_auc_score, average_precision_score
from sklearn.isotonic import IsotonicRegression
from sklearn.pipeline import Pipeline
from scipy.stats import spearmanr
from sklearn.utils.class_weight import compute_sample_weight


def _fit_with_sample_weight(model, X, y, sample_weight=None):
    """Fit ``model`` on ``(X, y)``, forwarding ``sample_weight`` when given.

    When ``model`` is a ``Pipeline`` containing a step named ``"est"``, the
    weights are routed to that step as ``est__sample_weight``; any other model
    receives them as a plain ``sample_weight`` keyword.  If the weighted call
    raises ``TypeError`` the model is refit without weights.  Note that the
    retry is triggered by *any* ``TypeError`` raised during the weighted fit,
    not only by an unsupported keyword.

    Args:
        model: Object exposing ``fit(X, y, ...)``; fitted in place.
        X: Training inputs, passed through to ``fit`` unchanged.
        y: Training targets, passed through to ``fit`` unchanged.
        sample_weight: Optional per-sample weights; ``None`` means a plain
            unweighted fit.

    Returns:
        None.
    """
    if sample_weight is not None:
        try:
            routes_to_est_step = (
                isinstance(model, Pipeline)
                and hasattr(model, "named_steps")
                and "est" in model.named_steps
            )
            weight_kwarg = "est__sample_weight" if routes_to_est_step else "sample_weight"
            model.fit(X, y, **{weight_kwarg: sample_weight})
        except TypeError:
            # Weighted fit was rejected: fall back to an unweighted fit.
            model.fit(X, y)
        return
    model.fit(X, y)


def _get_continuous_scores(model, X):
    """Return one continuous score per row of ``X`` from a fitted model.

    Sources are tried in this order:

    1. ``predict_proba`` -- column 1 of its output, used only when the output
       is not ``None`` and is two-dimensional with at least two columns.
    2. ``decision_function``.
    3. ``predict`` (the fallback for regressors).

    Args:
        model: Fitted estimator exposing at least one of the methods above.
        X: Inputs to score, passed through to the chosen method unchanged.

    Returns:
        Whatever the selected method returns (column 1 for ``predict_proba``).
    """
    if hasattr(model, "predict_proba"):
        probabilities = model.predict_proba(X)
        if probabilities is not None:
            dims = probabilities.shape
            if len(dims) == 2 and dims[1] >= 2:
                return probabilities[:, 1]
    # No usable probability column: prefer decision values, else plain predictions.
    scorer = model.decision_function if hasattr(model, "decision_function") else model.predict
    return scorer(X)


def evaluate_regression(model, X, y, cv_splits=5, random_state=42, sample_weights=None, return_oof=False, oof_output_path=None) -> Dict[str, float]:
    """Cross-validate a regressor and score isotonic-calibrated predictions.

    For every fold of a shuffled ``KFold`` split the model is refit in place on
    the training fold, an isotonic regression mapping the model's training-fold
    predictions to the training targets is learned, and RMSE / MAE / R^2 are
    computed on the calibrated held-out predictions.  If calibration fails for
    a fold a ``RuntimeWarning`` is emitted and the raw predictions are scored
    instead.

    Args:
        model: Estimator exposing ``fit`` and ``predict``.  It is refit on each
            fold, so on return it holds the fit from the last fold.
        X: Inputs; rows are selected with ``X[i]`` for integer ``i`` and handed
            to the model as a list of rows.
        y: Targets, one per row of ``X``.  Converted with ``np.asarray`` and
            indexed positionally.
        cv_splits: Number of folds.
        random_state: Seed for the fold shuffle.
        sample_weights: Optional per-row weights used when fitting the model
            and the calibrator (metrics themselves are unweighted).  Converted
            with ``np.asarray`` and indexed positionally.
        return_oof: Together with ``oof_output_path``, enables the CSV dump of
            out-of-fold predictions.  The predictions are not returned.
        oof_output_path: Destination CSV path; written only when ``return_oof``
            is truthy and this path is non-empty.

    Returns:
        Dict with the fold-averaged ``"rmse"``, ``"mae"`` and ``"r2"``.  When
        the pooled out-of-fold calibrated predictions and targets both have
        non-zero spread it also contains ``"corr_rating_pearson"`` and, if it
        can be computed, ``"corr_rating_spearman"``.
    """
    import warnings

    # Index positionally through arrays.  Plain ``y[i]`` is a *label* lookup on
    # a pandas Series, which selects the wrong rows (or raises) whenever the
    # index is not 0..n-1.  This mirrors evaluate_classification.
    targets = np.asarray(y)
    weights = None if sample_weights is None else np.asarray(sample_weights)

    kf = KFold(n_splits=cv_splits, shuffle=True, random_state=random_state)
    scores = {
        "rmse": [],
        "mae": [],
        "r2": [],
    }
    oof_preds_cal = []
    oof_preds_raw = []
    oof_y = []
    oof_fold = []
    for fold_idx, (train_idx, test_idx) in enumerate(kf.split(X)):
        X_train = [X[i] for i in train_idx]
        X_test = [X[i] for i in test_idx]
        y_train = targets[train_idx]
        y_test = targets[test_idx]
        sw_train = None if weights is None else weights[train_idx]

        _fit_with_sample_weight(model, X_train, y_train, sample_weight=sw_train)

        preds_test = np.asarray(model.predict(X_test)).reshape(-1)

        # Isotonic calibration (train-fold) to better align predictions to ratings
        try:
            preds_train = np.asarray(model.predict(X_train)).reshape(-1)
            iso = IsotonicRegression(out_of_bounds="clip")
            try:
                iso.fit(preds_train, y_train, sample_weight=sw_train)
            except TypeError:
                iso.fit(preds_train, y_train)
            preds_test_cal = iso.transform(preds_test)
        except Exception as exc:
            # Calibration is best-effort, but make the fallback visible: the
            # metrics for this fold are then computed on uncalibrated output.
            warnings.warn(
                f"Isotonic calibration failed on fold {fold_idx} ({exc!r}); "
                "scoring raw predictions instead.",
                RuntimeWarning,
                stacklevel=2,
            )
            preds_test_cal = preds_test

        scores["rmse"].append(float(np.sqrt(mean_squared_error(y_test, preds_test_cal))))
        scores["mae"].append(float(mean_absolute_error(y_test, preds_test_cal)))
        scores["r2"].append(float(r2_score(y_test, preds_test_cal)))
        oof_preds_cal.extend(preds_test_cal)
        oof_preds_raw.extend(preds_test)
        oof_y.extend(y_test)
        oof_fold.extend([fold_idx] * len(y_test))

    results = {k: float(np.mean(v)) for k, v in scores.items()}

    # Correlations are computed once over the pooled out-of-fold predictions;
    # they are undefined when either side is constant.
    if len(oof_preds_cal) > 1 and np.std(oof_preds_cal) > 0 and np.std(oof_y) > 0:
        arr_preds = np.asarray(oof_preds_cal)
        arr_y = np.asarray(oof_y)
        results["corr_rating_pearson"] = float(np.corrcoef(arr_preds, arr_y)[0, 1])
        try:
            spear, _ = spearmanr(arr_preds, arr_y)
            results["corr_rating_spearman"] = float(spear)
        except Exception:
            # Best-effort extra metric: omit it rather than fail the evaluation.
            pass

    if return_oof and oof_output_path:
        import csv
        with open(oof_output_path, "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["fold", "y_true", "y_pred_raw", "y_pred_calibrated"])
            w.writerows(zip(oof_fold, oof_y, oof_preds_raw, oof_preds_cal))
    return results


def evaluate_classification(model, X, y, cv_splits=5, random_state=42, sample_weights=None, return_oof=False, oof_output_path=None, rating_threshold=3.0) -> Dict[str, float]:
    """Cross-validate a binary classifier on thresholded ratings.

    Ratings in ``y`` are binarized as ``y > rating_threshold`` and the model is
    refit in place on each fold of a shuffled ``StratifiedKFold`` split.
    Training samples are weighted by balanced class weights (skipped when the
    estimator already has a non-empty ``class_weight``) multiplied by the
    caller's ``sample_weights`` when given.

    When the model exposes ``predict_proba`` or ``decision_function`` its
    continuous held-out scores are used for ROC AUC and PR AUC, and are also
    isotonic-calibrated against the original (non-binarized) ratings to build
    the out-of-fold score/rating correlations.  Failures of the weight
    combination or of the calibration emit a ``RuntimeWarning`` and fall back
    to the provided weights / raw scores respectively.

    Args:
        model: Estimator exposing ``fit`` and ``predict``.  It is refit on each
            fold, so on return it holds the fit from the last fold.
        X: Inputs; rows are selected with ``X[i]`` for integer ``i`` and handed
            to the model as a list of rows.
        y: Numeric ratings, one per row of ``X``.
        cv_splits: Number of folds.
        random_state: Seed for the fold shuffle.
        sample_weights: Optional per-row weights.  Converted with
            ``np.asarray`` and indexed positionally.
        return_oof: Together with ``oof_output_path``, enables the CSV dump of
            out-of-fold scores.  The scores are not returned.
        oof_output_path: Destination CSV path; written only when ``return_oof``
            is truthy and this path is non-empty.
        rating_threshold: Ratings strictly above this value form the positive
            class.  Defaults to ``3.0``.

    Returns:
        Dict with fold-averaged ``"accuracy"`` and ``"f1"``; ``"roc_auc"`` and
        ``"pr_auc"`` when continuous scores are available (ROC AUC only counts
        folds whose test labels contain both classes); and
        ``"corr_rating_pearson"`` / ``"corr_rating_spearman"`` when the pooled
        calibrated scores and ratings both have non-zero spread.
    """
    import warnings

    y = np.asarray(y)
    y_bin = (y > rating_threshold).astype(int)
    # Positional indexing: ``sample_weights[i]`` would be a label lookup on a
    # pandas Series with a non-default index.
    weights = None if sample_weights is None else np.asarray(sample_weights)
    skf = StratifiedKFold(n_splits=cv_splits, shuffle=True, random_state=random_state)
    scores = {
        "accuracy": [],
        "f1": [],
        "roc_auc": [],
        "pr_auc": [],
    }
    oof_scores_cal = []
    oof_scores_raw = []
    oof_y = []
    oof_fold = []

    # Avoid double-balancing if the estimator already uses class_weight
    # internally.  class_weight is a constructor setting, so this is checked
    # once rather than on every fold.
    est = model
    if isinstance(model, Pipeline) and hasattr(model, "named_steps") and "est" in model.named_steps:
        est = model.named_steps["est"]
    has_internal_cw = hasattr(est, "class_weight") and getattr(est, "class_weight", None) not in (None, {})

    # Only the labels matter for stratification, so a dummy X is enough.
    zeros = np.zeros_like(y_bin)
    for fold_idx, (train_idx, test_idx) in enumerate(skf.split(zeros, y_bin)):
        X_train = [X[i] for i in train_idx]
        y_train_bin = y_bin[train_idx]
        y_train_original = y[train_idx]
        X_test = [X[i] for i in test_idx]
        y_test_bin = y_bin[test_idx]
        y_test_original = y[test_idx]
        sw_train = None if weights is None else weights[train_idx]

        # Combine balanced class weights with provided sample weights (e.g., rating_count-based)
        try:
            cw = None if has_internal_cw else compute_sample_weight(class_weight="balanced", y=y_train_bin)
            if sw_train is not None and cw is not None:
                sw_combined = cw * sw_train
            elif sw_train is not None:
                sw_combined = sw_train
            else:
                sw_combined = cw
        except Exception as exc:
            warnings.warn(
                f"Could not combine class and sample weights on fold {fold_idx} ({exc!r}); "
                "using the provided sample weights only.",
                RuntimeWarning,
                stacklevel=2,
            )
            sw_combined = sw_train

        _fit_with_sample_weight(model, X_train, y_train_bin, sample_weight=sw_combined)

        # Continuous held-out scores: positive-class probability when
        # available, otherwise decision values.
        scores_continuous_test_raw = None
        if hasattr(model, "predict_proba"):
            scores_continuous_test_raw = model.predict_proba(X_test)[:, 1]
        elif hasattr(model, "decision_function"):
            scores_continuous_test_raw = model.decision_function(X_test)
        if scores_continuous_test_raw is not None:
            scores_continuous_test_raw = np.asarray(scores_continuous_test_raw).reshape(-1)

        preds_test = model.predict(X_test)
        scores["accuracy"].append(float(accuracy_score(y_test_bin, preds_test)))
        scores["f1"].append(float(f1_score(y_test_bin, preds_test)))
        if scores_continuous_test_raw is not None:
            # Ranking metrics accept probabilities or non-thresholded decision
            # values, so decision_function-only models are scored too.
            if np.unique(y_test_bin).size > 1:
                # ROC AUC is undefined when the test fold holds a single class.
                scores["roc_auc"].append(float(roc_auc_score(y_test_bin, scores_continuous_test_raw)))
            scores["pr_auc"].append(float(average_precision_score(y_test_bin, scores_continuous_test_raw)))

            # Isotonic calibration (train-fold) of scores against the original ratings
            try:
                scores_continuous_train = _get_continuous_scores(model, X_train)
                scores_continuous_train = np.asarray(scores_continuous_train).reshape(-1)
                iso = IsotonicRegression(out_of_bounds="clip")
                try:
                    iso.fit(scores_continuous_train, y_train_original, sample_weight=sw_train)
                except TypeError:
                    iso.fit(scores_continuous_train, y_train_original)
                scores_continuous_test_cal = iso.transform(scores_continuous_test_raw)
            except Exception as exc:
                # Best-effort: keep the raw scores but make the fallback visible.
                warnings.warn(
                    f"Isotonic calibration failed on fold {fold_idx} ({exc!r}); "
                    "using raw scores instead.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                scores_continuous_test_cal = scores_continuous_test_raw
            oof_scores_raw.extend(scores_continuous_test_raw)
            oof_scores_cal.extend(np.asarray(scores_continuous_test_cal).reshape(-1))
            oof_y.extend(np.asarray(y_test_original).reshape(-1))
            oof_fold.extend([fold_idx] * len(y_test_original))

    # Metrics with no recorded folds (e.g. AUCs without continuous scores) are omitted.
    results = {k: float(np.mean(v)) for k, v in scores.items() if len(v) > 0}

    # Correlations are computed once over the pooled out-of-fold scores; they
    # are undefined when either side is constant.
    if len(oof_scores_cal) > 1 and np.std(oof_scores_cal) > 0 and np.std(oof_y) > 0:
        arr_scores = np.asarray(oof_scores_cal)
        arr_y = np.asarray(oof_y)
        results["corr_rating_pearson"] = float(np.corrcoef(arr_scores, arr_y)[0, 1])
        try:
            spear, _ = spearmanr(arr_scores, arr_y)
            results["corr_rating_spearman"] = float(spear)
        except Exception:
            # Best-effort extra metric: omit it rather than fail the evaluation.
            pass

    if return_oof and oof_output_path:
        import csv
        with open(oof_output_path, "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["fold", "y_true", "score_raw", "score_calibrated"])
            w.writerows(zip(oof_fold, oof_y, oof_scores_raw, oof_scores_cal))
    return results
