import math
import random
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import (
    precision_recall_curve,
    auc,
    classification_report,
    f1_score,
    roc_auc_score,
    root_mean_squared_error,
)
from sklearn.model_selection import ParameterGrid, ShuffleSplit, StratifiedShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.preprocessing import label_binarize
from sklearn.utils.multiclass import type_of_target
from tabpfn import TabPFNClassifier, TabPFNRegressor
from tabpfn.config import ModelInterfaceConfig
from xgboost import XGBClassifier, XGBRegressor


class LinearModel(BaseEstimator):
    """Linear baseline with automatic per-dtype column preprocessing.

    Numeric and boolean columns are scaled with ``StandardScaler(with_mean=False)``
    (scaled, not centred); ``category``/``object`` columns are one-hot encoded
    with unknown categories ignored at predict time; all other columns are
    dropped. Classification uses ``LogisticRegression`` (liblinear solver by
    default), regression uses ``LinearRegression``.

    Parameters
    ----------
    task : str
        ``"auto"`` infers the task from ``y`` via ``type_of_target``;
        ``"classification"`` or ``"categorical"`` force classification; any
        other value selects regression.
    random_state : int or None
        Forwarded to ``LogisticRegression`` (unused for regression).
    **kwargs
        Extra estimator arguments. The same dict is forwarded to whichever
        estimator the task selects, so pass only arguments valid for it. For
        classification they override the built-in defaults.
    """

    def __init__(self,
                 task: str = "auto",
                 random_state: Optional[int] = None,
                 **kwargs):
        self.task = task
        # Both names alias the same dict; _build_pipeline() uses only the one
        # matching the selected estimator.
        # NOTE(review): scikit-learn's get_params()/clone() only see named
        # __init__ parameters, so **kwargs may be lost on clone — verify if
        # this model is ever cloned.
        self.clf_kwargs = kwargs
        self.reg_kwargs = kwargs
        self.random_state = random_state

        # will be filled in fit()
        self.pipeline_: Pipeline | None = None
        self.is_classification_: bool | None = None
        self.num_cols_: list[str] = []
        self.cat_cols_: list[str] = []

    # ──────────────────────────────────────────────────────────
    # core helpers
    # ──────────────────────────────────────────────────────────
    def _check_fitted(self):
        """Raise RuntimeError if fit() has not been called yet."""
        if self.pipeline_ is None:
            raise RuntimeError("fit() must be called first.")

    def _infer_task(self, y) -> bool:
        """Return True when a classifier should be fitted for target ``y``."""
        if self.task != "auto":
            # "categorical" is accepted as well, matching the task naming used
            # by GridSearch in this module.
            return self.task in ("classification", "categorical")

        return type_of_target(y) in {"binary", "multiclass"}

    def _build_pipeline(self):
        """Create the ColumnTransformer and the downstream estimator."""
        preproc = ColumnTransformer(
            transformers=[
                ("num", StandardScaler(with_mean=False), self.num_cols_),
                ("cat",
                 OneHotEncoder(handle_unknown="ignore", dtype=np.float32),
                 self.cat_cols_),
            ],
            remainder="drop",
        )

        if self.is_classification_:
            # Defaults first and user kwargs last, so a caller can override
            # any default without a "multiple values for keyword argument"
            # TypeError. n_jobs is not set by default: liblinear ignores it
            # and scikit-learn warns when n_jobs != 1 with that solver.
            clf_params = {
                "random_state": self.random_state,
                "max_iter": 1000,
                "solver": "liblinear",
                **self.clf_kwargs,
            }
            estimator = LogisticRegression(**clf_params)
        else:
            estimator = LinearRegression(**self.reg_kwargs)

        return Pipeline([("prep", preproc), ("est", estimator)])

    # ──────────────────────────────────────────────────────────
    # sklearn API
    # ──────────────────────────────────────────────────────────
    def fit(self, X: pd.DataFrame, y):
        """Detect column types, build the pipeline and fit it on (X, y).

        Non-DataFrame ``X`` is wrapped in a DataFrame first. Returns ``self``.
        """
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)

        self.is_classification_ = self._infer_task(y)

        # bool columns are not matched by "number"; include them explicitly so
        # they are scaled like numeric features instead of silently dropped.
        self.num_cols_ = X.select_dtypes(
            include=["number", "bool"]).columns.tolist()
        self.cat_cols_ = X.select_dtypes(
            include=["category", "object"]).columns.tolist()

        self.pipeline_ = self._build_pipeline()
        self.pipeline_.fit(X, y)
        return self

    def predict(self, X: pd.DataFrame):
        """Predict the target.

        Raises
        ------
        RuntimeError
            If called before fit().
        """
        self._check_fitted()
        return self.pipeline_.predict(X)

    def predict_proba(self, X: pd.DataFrame):
        """Predict class probabilities (classification only).

        Raises
        ------
        RuntimeError
            If called before fit().
        AttributeError
            If the fitted task is regression.
        """
        # Check fitted-ness first: before fit() is_classification_ is None,
        # which would otherwise produce a misleading "regression" error.
        self._check_fitted()
        if not self.is_classification_:
            raise AttributeError("predict_proba is only available "
                                 "for classification tasks.")
        return self.pipeline_.predict_proba(X)

    # ──────────────────────────────────────────────────────────
    # convenience attributes
    # ──────────────────────────────────────────────────────────
    @property
    def coef_(self):
        """Access the underlying linear coefficients after fit()."""
        self._check_fitted()
        return self.pipeline_["est"].coef_

    @property
    def classes_(self):
        """Return the class labels (classification only).

        Raises AttributeError for regression or before fit(), so
        ``hasattr(model, "classes_")`` is False in those cases.
        """
        if self.is_classification_:
            return self.pipeline_["est"].classes_
        raise AttributeError("classes_ attribute exists only "
                             "for classification tasks.")


# -----------------------------------------------------------------------------
# 1.  Metric helpers
# -----------------------------------------------------------------------------

def roc_auc_macro(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Macro-averaged one-vs-rest ROC-AUC for binary and multi-class targets.

    Parameters
    ----------
    y_true : array-like of shape (N,)
        True labels.
    y_pred : array-like of shape (N,), (N, 2) or (N, K)
        Scores or probabilities. A 1-D array is used as the positive-class
        score. For an ``(N, 2)`` matrix, column 1 is taken as the
        positive-class score. For ``K > 2`` columns, column ``k`` must
        correspond to label ``k``.

    Returns
    -------
    float
        ``roc_auc_score(..., average="macro", multi_class="ovr")``.
    """
    y_pred = np.asarray(y_pred)
    # Binary path default: let roc_auc_score infer the labels from y_true.
    labels = None
    if y_pred.ndim == 2 and y_pred.shape[1] == 2:
        # Binary probabilities: keep the second column. roc_auc_score treats
        # the larger label as positive, which matches the sorted-class column
        # order of scikit-learn style predict_proba, whatever the label values.
        y_pred = y_pred[:, 1]
    elif y_pred.ndim == 2:
        # Multi-class: pin labels to the column positions, so a validation set
        # that lacks some classes is still matched to the right columns.
        labels = list(range(y_pred.shape[1]))
    return roc_auc_score(y_true, y_pred, average="macro", multi_class="ovr", labels=labels)


def f1_macro(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the unweighted mean of the per-class F1 scores."""
    score = f1_score(y_true=y_true, y_pred=y_pred, average="macro")
    return score


def f1_weighted(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the per-class F1 scores averaged with class-support weights."""
    score = f1_score(y_true=y_true, y_pred=y_pred, average="weighted")
    return score


def auprc_macro(y_true, y_score, classes=None):
    """Macro-averaged area under the precision-recall curve (one-vs-rest).

    Parameters
    ----------
    y_true : array-like of shape (N,)
        True labels.
    y_score : array-like of shape (N,) or (N, K)
        Per-class scores or probabilities. A 1-D array is treated as the
        positive-class score of a binary problem and expanded to two columns.
    classes : sequence, optional
        Labels in the same order as the columns of ``y_score``. When omitted,
        the sorted unique values of ``y_true`` are used. If ``y_true`` lacks
        some of the classes that ``y_score`` has columns for (e.g. a small
        validation split), integer-valued labels are used as column indices,
        so each observed class is still scored against its own column.

    Returns
    -------
    float
        Mean of the per-class ``auc(recall, precision)`` values.

    Raises
    ------
    ValueError
        If ``classes`` is omitted and the labels in ``y_true`` cannot be
        matched to the columns of ``y_score``.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    if y_score.ndim == 1:
        # Only the positive-class score is given. A PR curve depends only on
        # the ranking, so 1 - s is a valid score for the negative class.
        y_score = np.column_stack([1 - y_score, y_score])
    n_cols = y_score.shape[1]

    if classes is not None:
        classes = list(classes)
        columns = range(len(classes))
    else:
        classes = np.unique(y_true)
        if len(classes) == n_cols:
            columns = range(n_cols)
        elif (np.issubdtype(classes.dtype, np.number)
              and np.all(classes == np.floor(classes))
              and classes.min() >= 0
              and classes.max() < n_cols):
            # Some classes are absent from y_true: align by label value
            # instead of by position, which would pair the wrong columns.
            columns = classes.astype(int)
        else:
            raise ValueError(
                f"y_true has {len(classes)} classes but y_score has {n_cols} "
                "columns; pass `classes` to map labels to score columns."
            )

    auprcs = []
    for label, col in zip(classes, columns):
        is_label = (y_true == label).astype(int)
        precision, recall, _ = precision_recall_curve(is_label, y_score[:, col])
        auprcs.append(auc(recall, precision))
    return float(np.mean(auprcs))


# -----------------------------------------------------------------------------
# 2.  Model & hyper‑parameter grids
# -----------------------------------------------------------------------------
# TabPFN inference configuration with the ``FINGERPRINT_FEATURE`` option turned off.
TABPFN_CONFIG = ModelInterfaceConfig.from_user_input(
    inference_config={"FINGERPRINT_FEATURE": False},
)

def _xgb_grid(objective: str) -> dict[str, list[Any]]:
    """Return a fresh XGBoost grid; only the ``objective`` differs between tasks."""
    return {
        "n_estimators": [10, 50, 100],
        "min_child_weight": [1, 10],
        "max_depth": [5, 10, 20],
        "gamma": [0.0, 1.0],
        "objective": [objective],
        "nthread": [-1],
        "tree_method": ["hist"],
        "device": ["cpu"],
        "enable_categorical": [True],
    }


# Candidate estimators per task type. Each entry pairs an estimator class with
# a grid in ``ParameterGrid`` format (every value is a list of candidates).
_MODELS: Dict[str, list[dict[str, Any]]] = {
    # --- Classification ------------------------------------------------------
    # NOTE(review): XGBClassifier is expected to switch to a multi-class
    # objective on its own for >2 classes — confirm for the installed version.
    "categorical": [
        {"class": XGBClassifier, "kwargs": _xgb_grid("binary:logistic")},
        {
            "class": TabPFNClassifier,
            "kwargs": {
                "n_estimators": [4, 8, 16],
                "softmax_temperature": [0.8, 0.9, 1.0],
                "balance_probabilities": [True],
            },
        },
    ],
    # --- Regression ----------------------------------------------------------
    "numerical": [
        {"class": XGBRegressor, "kwargs": _xgb_grid("reg:squarederror")},
        {
            "class": TabPFNRegressor,
            "kwargs": {
                "n_estimators": [4, 8, 16],
                "softmax_temperature": [0.8, 0.9, 1.0],
            },
        },
    ],
}

# -----------------------------------------------------------------------------
# 3.  Metric configs (what to optimise during search)
# -----------------------------------------------------------------------------
# Per task type: the scoring function, the estimator method whose output it
# scores, and whether higher ("upper") or lower ("lower") values are better.
_METRIC = {
    "categorical": {
        "function": roc_auc_macro,
        "method": "predict_proba",
        "direction": "upper",
    },
    "numerical": {
        "function": root_mean_squared_error,
        "method": "predict",
        "direction": "lower",
    },
}

# Currently an independent copy with the same contents as _METRIC.
# NOTE(review): the name suggests an imbalance-aware variant — confirm intent.
_METRIC_IMB = {task: dict(cfg) for task, cfg in _METRIC.items()}


# -----------------------------------------------------------------------------
# 4.  GridSearch core class
# -----------------------------------------------------------------------------


class GridSearch:
    """Lightweight grid-search wrapper supporting XGB & TabPFN.

    ``fit`` holds out one validation split (1/9 of the rows, stratified for
    ``"categorical"``). It evaluates every hyper-parameter combination of every
    model spec in ``_MODELS[task_type]`` with the metric in
    ``metric[task_type]``, then refits the winner on all rows. After ``fit``,
    ``best_model`` holds the refitted estimator. ``best_kwargs`` holds its grid
    parameters plus a ``"Model"`` entry naming the class.
    """

    # TabPFN candidates are skipped when the data has more rows than this.
    _PFN_MAX_ROWS = 10000

    def __init__(self, task_type: str, metric=_METRIC):
        if task_type not in ("categorical", "numerical"):
            raise ValueError("task_type must be 'categorical' or 'numerical'")
        self.task_type = task_type
        self.model_specs = _MODELS[task_type]
        self.metric_cfg = metric[task_type]
        # Scores are multiplied by this sign so that "bigger is better" always holds.
        self.metric_sign = +1 if self.metric_cfg["direction"] == "upper" else -1
        self.best_model: Optional[Any] = None
        self.best_kwargs: Optional[dict[str, Any]] = None

    # ────────────────────────────────────────────────────────────────────── #
    # internal helpers
    # ────────────────────────────────────────────────────────────────────── #

    def _scorer(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Apply the configured metric function (raw value, sign not applied)."""
        return self.metric_cfg["function"](y_true, y_pred)

    def _split(self, x: np.ndarray, y: np.ndarray, random_state=None):
        """Return ``(train_idx, val_idx)`` for a single 8:1 train/validation split."""
        if self.task_type == "categorical":
            splitter = StratifiedShuffleSplit(n_splits=1, test_size=1 / 9,
                                              random_state=random_state)
        else:
            splitter = ShuffleSplit(n_splits=1, test_size=1 / 9,
                                    random_state=random_state)
        train_idx, val_idx = next(splitter.split(x, y))
        return train_idx, val_idx

    @staticmethod
    def _take_rows(x, idx):
        """Select rows by position for both DataFrames and arrays."""
        return x.iloc[idx] if isinstance(x, pd.DataFrame) else x[idx]

    def _drop_empty_rows(self, x, y):
        """Drop rows whose features are all missing; ``y`` is returned as ndarray."""
        if not isinstance(x, pd.DataFrame):
            x = np.asarray(x)
        # pd.isna also handles categorical/object columns, where np.isnan raises.
        keep_idx = np.flatnonzero(~np.asarray(pd.isna(x)).all(axis=1))
        return self._take_rows(x, keep_idx), np.asarray(y)[keep_idx]

    @staticmethod
    def _make_model(cls, random_state, **kwargs):
        """Instantiate ``cls`` the same way for search, refit and explicit params.

        Every model gets ``random_state``. TabPFN models also get
        ``inference_config=TABPFN_CONFIG``. Values in ``kwargs`` override these
        defaults.
        """
        init_kwargs = {"random_state": random_state}
        if cls in (TabPFNClassifier, TabPFNRegressor):
            init_kwargs["inference_config"] = TABPFN_CONFIG
        init_kwargs.update(kwargs)
        return cls(**init_kwargs)

    def _resolve_class(self, model_name=None):
        """Map a ``best_kwargs["Model"]`` name to its class (default: first spec)."""
        if model_name is None:
            # Convention: explicit params without a model name target the first spec.
            return self.model_specs[0]["class"]
        for spec in self.model_specs:
            if spec["class"].__name__ == model_name:
                return spec["class"]
        raise ValueError(
            f"Unknown model {model_name!r} for task_type {self.task_type!r}"
        )

    # ────────────────────────────────────────────────────────────────────── #
    # fit/score API
    # ────────────────────────────────────────────────────────────────────── #

    def fit(self, x: pd.DataFrame, y: np.ndarray, params: Optional[dict[str, Any]] = None):
        """Search for (or directly fit, if ``params`` is given) the best model."""
        self._fit(x, y, params=params)
        return self

    def _fit(self, x: pd.DataFrame, y: np.ndarray, params: Optional[dict[str, Any]] = None, **fit_kwargs):
        """Run the grid search, or fit ``params`` directly, and store the result.

        Parameters
        ----------
        x, y
            Features (DataFrame or array) and target. Rows with all features
            missing are dropped first.
        params : dict, optional
            Skip the search and fit one model with these constructor
            arguments. An optional ``"Model"`` entry, as stored in
            ``best_kwargs`` after a search, selects the class by name.
            Otherwise the first model spec is used.
        **fit_kwargs
            Forwarded to every estimator's ``fit``.
        """
        # One seed per call, shared by the split, the candidates and the refit.
        random_state = random.randint(0, 2 ** 32 - 1)
        x, y = self._drop_empty_rows(x, y)

        if params is not None:
            init_kwargs = dict(params)
            cls = self._resolve_class(init_kwargs.pop("Model", None))
            # random_state goes to the constructor; estimator fit() methods
            # such as XGB's do not accept it.
            self.best_model = self._make_model(cls, random_state, **init_kwargs)
            self.best_model.fit(x, y, **fit_kwargs)
            self.best_kwargs = params
            return

        # 1. create a single train/val split
        train_idx, val_idx = self._split(x, y, random_state=random_state)
        x_train, y_train = self._take_rows(x, train_idx), y[train_idx]
        x_val, y_val = self._take_rows(x, val_idx), y[val_idx]
        skip_pfn = len(x) > self._PFN_MAX_ROWS

        # 2. grid search across all model specs
        best_score = -math.inf
        best_cls: Optional[type] = None
        best_kwargs: dict[str, Any] = {}

        for spec in self.model_specs:
            cls = spec["class"]
            if skip_pfn and cls in (TabPFNClassifier, TabPFNRegressor):
                continue
            param_grid = spec.get("kwargs", {})
            # one candidate with default settings when the grid is empty
            for kwargs in (ParameterGrid(param_grid) if param_grid else [{}]):
                model = self._make_model(cls, random_state, **kwargs)
                model.fit(x_train, y_train, **fit_kwargs)
                preds = getattr(model, self.metric_cfg["method"])(x_val)
                score = self.metric_sign * self._scorer(y_val, preds)
                if score > best_score:
                    best_score = score
                    best_cls = cls
                    best_kwargs = dict(kwargs)

                del model  # free memory before the next candidate

        # 3. refit on full data with the same constructor defaults (seed,
        #    TabPFN config) used during the search, so the final model
        #    matches the configuration that was scored.
        if best_cls is None:
            raise RuntimeError("GridSearch failed to find a best model – check configuration.")
        self.best_model = self._make_model(best_cls, random_state, **best_kwargs)
        self.best_model.fit(x, y, **fit_kwargs)
        self.best_kwargs = {**best_kwargs, "Model": best_cls.__name__}

    def predict(self, x):
        """Predict with the refitted best model."""
        if self.best_model is None:
            raise RuntimeError("Call fit() before predict().")
        return self.best_model.predict(x)

    def predict_proba(self, x):
        """Predict class probabilities with the refitted best model."""
        if self.best_model is None:
            raise RuntimeError("Call fit() before predict_proba().")
        return self.best_model.predict_proba(x)

    def score(self, x: np.ndarray, y: np.ndarray) -> float:
        """Return the raw configured metric (sign not applied) on (x, y)."""
        if self.best_model is None:
            raise RuntimeError("Call fit() before score().")
        preds = getattr(self.best_model, self.metric_cfg["method"])(x)
        return self._scorer(y, preds)

    def classification_report(self, x: np.ndarray, y: np.ndarray):
        """Return scikit-learn's text classification report (4 digits)."""
        if self.best_model is None:
            raise RuntimeError("Call fit() before classification_report().")
        preds = self.best_model.predict(x)
        return classification_report(y, preds, digits=4)
