from joblib import Parallel, delayed
import re
from pathlib import Path

import numpy as np
import pandas as pd
from pyriemann.tangentspace import TangentSpace
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegressionCV
from sklearn.svm import SVC
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
from sklearn.pipeline import make_pipeline

from EvaGeM.distribution_based import alpha_precision, beta_recall
from prdc import compute_prdc



def compute_quality_metrics(X_real, y_real, X_fake, y_fake):
    """Score generated samples against real ones with sample-quality metrics.

    Each sample is flattened into a single feature row before scoring. When
    the two sets differ in size, both are cut down to their leading
    ``min(len(X_real), len(X_fake))`` rows so the metrics see equal counts.

    Args:
        X_real: Reference samples; the first axis indexes samples.
        y_real: Labels for ``X_real``. They are only truncated alongside the
            data and are not used by any metric.
        X_fake: Samples to evaluate; the first axis indexes samples.
        y_fake: Labels for ``X_fake``. They are only truncated and are not
            used by any metric.

    Returns:
        dict mapping a display name to a score. "Precision", "Recall",
        "Density" and "Coverage" come from ``compute_prdc`` with
        ``nearest_k=10``. The alpha-precision and beta-recall entries come
        from ``alpha_precision`` / ``beta_recall`` (curve plotting disabled).
    """
    real = X_real.reshape(len(X_real), -1)
    fake = X_fake.reshape(len(X_fake), -1)

    n_real, n_fake = len(real), len(fake)
    if n_real != n_fake:
        # NOTE(review): this keeps the *leading* rows, not a random subset;
        # if the inputs are ordered (e.g. grouped by class) the kept subset
        # may be biased.
        n_keep = min(n_real, n_fake)
        real, fake = real[:n_keep], fake[:n_keep]
        y_real, y_fake = y_real[:n_keep], y_fake[:n_keep]

    prdc = compute_prdc(real_features=real, fake_features=fake, nearest_k=10)
    metrics = {
        label: prdc[key]
        for label, key in (
            ("Precision", "precision"),
            ("Recall", "recall"),
            ("Density", "density"),
            ("Coverage", "coverage"),
        )
    }
    metrics[r"$\alpha$-precision"] = alpha_precision(real, fake, plot_curve=False)
    metrics[r"$\beta$-recall"] = beta_recall(real, fake, plot_curve=False)
    return metrics


def compute_classification_metric(X_real, y_real, X_fake, y_fake, clf):
    """Fit a tangent-space classifier on one sample set and score it on another.

    The classifier is wrapped in ``make_pipeline(TangentSpace(metric="riemann"), ...)``,
    fitted on ``(X_real, y_real)`` and evaluated on ``(X_fake, y_fake)``.
    Inputs are presumably SPD matrices of shape (N, n, n) — confirm with caller.

    Args:
        X_real: Training matrices.
        y_real: Training labels (binary classification is assumed).
        X_fake: Evaluation matrices.
        y_fake: Evaluation labels.
        clf: One of the presets ``"SVC"``, ``"LR"`` or ``"dummy"``, or a
            scikit-learn estimator instance, which is used as-is as the final
            pipeline step.

    Returns:
        dict with "ROC-AUC", "Precision", "Recall" and "F1" on the evaluation
        set. All four treat ``classes_[1]`` of the fitted pipeline as the
        positive class.

    Raises:
        ValueError: if ``clf`` is a string that is not one of the presets.
    """
    if clf == "SVC":
        estimator = SVC(
            kernel="rbf",
            C=1,
            probability=True,
            class_weight="balanced",
            gamma="scale",
            random_state=42,
            max_iter=5000,
        )
    elif clf == "LR":
        estimator = LogisticRegressionCV(
            cv=5,
            penalty="l2",
            solver="liblinear",
            class_weight="balanced",
            random_state=42,
            max_iter=5000,
        )
    elif clf == "dummy":
        estimator = DummyClassifier()
    elif isinstance(clf, str):
        # Fail fast: an unknown string would otherwise be handed to
        # make_pipeline and only break later with an unrelated sklearn error.
        raise ValueError(
            f"Unknown classifier {clf!r}; expected 'SVC', 'LR', 'dummy' "
            "or an estimator instance."
        )
    else:
        estimator = clf

    pipeline = make_pipeline(TangentSpace(metric="riemann"), estimator)
    pipeline.fit(X_real, y_real)

    # Column 1 of predict_proba corresponds to classes_[1]. Use the same label
    # as pos_label for the thresholded metrics so all four scores agree on
    # which class is "positive". This is identical to the default for 0/1
    # labels, and also works for other label sets (e.g. strings).
    positive_label = pipeline.classes_[1]
    y_score = pipeline.predict_proba(X_fake)[:, 1]
    y_pred = pipeline.predict(X_fake)

    return {
        "ROC-AUC": roc_auc_score(y_fake, y_score),
        "Precision": precision_score(y_fake, y_pred, pos_label=positive_label),
        "Recall": recall_score(y_fake, y_pred, pos_label=positive_label),
        "F1": f1_score(y_fake, y_pred, pos_label=positive_label),
    }

def _final_samples(samples):
    """Return the last step of a (T, N, n, n) trajectory, or (N, n, n) input unchanged.

    Raises:
        ValueError: if ``samples`` is neither 3- nor 4-dimensional.
    """
    ndim = np.ndim(samples)
    if ndim == 4:
        return samples[-1]
    if ndim == 3:
        return samples
    raise ValueError(
        "expected generated samples of shape (N, n, n) or (T, N, n, n), "
        f"got ndim={ndim}"
    )


def evaluate_metrics(
    cov_train,
    y_train,
    cov_val,
    y_val,
    generated_train,
    y_generated_train,
    generated_val,
    y_generated_val,
    training_time=None,
    sampling_time=None,
    *,
    classifier="LR",
):
    """
    Simplified evaluator.

    Assumptions:
      - All inputs are SPD covariance matrices (NORMALIZE == False).
      - No SPD/correlation checks, no projection, no file I/O.
      - `generated_train` / `generated_val` can be either:
          * shape (N, n, n)  -> already final samples
          * shape (T, N, n, n) -> trajectory; uses the last step ([-1])
        Any other ndim raises ValueError before any metric is computed.

    Args:
      classifier: classifier passed to `compute_classification_metric` for
        every classification comparison ("LR" by default; "SVC", "dummy" or
        an estimator instance are also accepted there).

    Returns:
      dict with:
        - "quality": quality metrics (Train vs Train/Val/Gen, Val vs Gen).
          "Train vs Train" compares the training set with itself, presumably
          as a reference ceiling.
        - "classification": classification metrics (baseline Train->Val,
          Gen->Val, Train->Gen)
        - "training_time_s" / "sampling_time_s" as floats, only if provided
    """
    # Validate/select the final samples up front so a malformed input fails
    # before any expensive metric is computed.
    gen_train_final = _final_samples(generated_train)
    gen_val_final = _final_samples(generated_val)

    out = {
        "quality": {
            "Train vs Train": compute_quality_metrics(cov_train, y_train, cov_train, y_train),
            "Train vs Val": compute_quality_metrics(cov_train, y_train, cov_val, y_val),
            "Train vs Gen": compute_quality_metrics(cov_train, y_train, gen_train_final, y_generated_train),
            "Val vs Gen": compute_quality_metrics(cov_val, y_val, gen_val_final, y_generated_val),
        },
        "classification": {
            # Real->Real baseline
            "Train vs Val": compute_classification_metric(
                cov_train, y_train, cov_val, y_val, classifier
            ),
            # CAS-style: train on generated, test on real val
            "Gen vs Val": compute_classification_metric(
                gen_train_final, y_generated_train, cov_val, y_val, classifier
            ),
            # Reverse direction: train on real, test on generated
            "Train vs Gen": compute_classification_metric(
                cov_train, y_train, gen_train_final, y_generated_train, classifier
            ),
        },
    }

    if training_time is not None:
        out["training_time_s"] = float(training_time)
    if sampling_time is not None:
        out["sampling_time_s"] = float(sampling_time)

    return out