from typing import Any

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate
from sklearn.pipeline import Pipeline
from torch import Tensor

# Scorer names passed to ``sklearn.model_selection.cross_validate``. The
# "precision"/"recall"/"f1" scorers use binary averaging, so these metrics
# presumably target binary labels.
SCORING = ["accuracy", "precision", "recall", "f1", "roc_auc", "average_precision"]
DEFAULT_CV = 5  # number of cross-validation folds
DEFAULT_MAX_ITER = 1_000  # solver iteration cap for logistic regression
DEFAULT_RANDOM_SEED = 42  # seed for estimator random_state / global RNGs


def _to_numpy(tensor: Tensor):
    """Return a host-side NumPy array for ``tensor`` (grad-tracking and bfloat16 safe)."""
    import torch

    # ``.numpy()`` refuses tensors that require grad, so drop the autograd graph first.
    tensor = tensor.detach()
    # NumPy has no bfloat16 dtype and ``.numpy()`` raises on it; upcast to float32.
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    return tensor.cpu().numpy()


def cross_validate_lr(
    features: Tensor,
    labels: Tensor,
    undersample: bool,
    pca_dim: int | None,
    max_iter: int = DEFAULT_MAX_ITER,
    cv: int = DEFAULT_CV,
    random_seed: int = DEFAULT_RANDOM_SEED,
    use_cuda: bool = False,
) -> dict[str, Any]:
    """Cross-validate a class-balanced logistic-regression model on ``features``.

    Builds a ``Pipeline`` of an optional PCA step followed by a
    ``LogisticRegression(class_weight="balanced")``. The estimators come from
    cuML when ``use_cuda`` is set and from scikit-learn otherwise. The pipeline
    is evaluated with ``sklearn.model_selection.cross_validate`` on the
    ``SCORING`` metrics.

    Args:
        features: Feature tensor, converted to a host NumPy array. Presumably
            shaped (n_samples, n_features); TODO confirm with callers.
        labels: Label tensor, converted to a host NumPy array. The
            binary-averaged scorers in ``SCORING`` presumably expect binary
            labels.
        undersample: Only recorded in the returned metadata. This function
            does not resample the data itself; presumably the caller has
            already done so. Verify against callers.
        pca_dim: Number of PCA components to keep before the classifier.
            ``None`` skips PCA entirely.
        max_iter: Maximum solver iterations for logistic regression.
        cv: Number of folds passed to ``cross_validate``.
        random_seed: ``random_state`` for the scikit-learn PCA and classifier.
            On the cuML path it seeds the global RNGs via ``seed_everything``.
        use_cuda: Use cuML estimators instead of scikit-learn ones.

    Returns:
        A dict with two keys:

        - ``"metadata"``: the run configuration plus ``"num_params"``, which
          is the number of entries in the first fold's fitted ``coef_``
          (the intercept is not counted).
        - ``"metrics"``: the ``cross_validate`` result dict, with the fitted
          estimators removed.
    """
    feats_np = _to_numpy(features)
    labels_np = _to_numpy(labels)

    if use_cuda:
        # Imported lazily so the CPU path does not require cuML / Lightning.
        from cuml.decomposition import PCA
        from cuml.linear_model import LogisticRegression
        from lightning_fabric import seed_everything

        # No random_state is passed to the estimators on this path; seed the
        # global RNGs instead.
        seed_everything(random_seed, verbose=False)
        lr_kwargs = dict(
            max_iter=max_iter,
            class_weight="balanced",
        )
        pca_kwargs = dict(
            n_components=pca_dim,
        )
    else:
        from sklearn.decomposition import PCA
        from sklearn.linear_model import LogisticRegression

        lr_kwargs = dict(
            max_iter=max_iter,
            class_weight="balanced",
            random_state=random_seed,
        )
        pca_kwargs = dict(
            n_components=pca_dim,
            random_state=random_seed,
        )

    # Optional dimensionality reduction, then the classifier, which is always
    # named "lr" so its coefficients can be looked up below.
    steps: list[tuple[str, Any]] = []
    if pca_dim is not None:
        steps.append(("pca", PCA(**pca_kwargs)))
    steps.append(("lr", LogisticRegression(**lr_kwargs)))
    model = Pipeline(steps)

    cv_res = cross_validate(
        model,
        feats_np,
        labels_np,
        cv=cv,
        scoring=SCORING,
        return_estimator=True,
    )
    # The fitted estimators are only needed to count coefficients; keep them
    # out of the returned metrics.
    estimators = cv_res.pop("estimator")
    num_params = estimators[0].named_steps["lr"].coef_.size

    metadata = {
        "undersample": undersample,
        "pca_dim": pca_dim,
        "cv": cv,
        "random_state": random_seed,
        "num_params": num_params,
    }

    return {
        "metadata": metadata,
        "metrics": cv_res,
    }
