import numpy as np
import torch
import argparse
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score

@dataclass(frozen=True)
class XGBPropensityConfig:
    """Immutable bundle of settings for the XGBoost propensity classifier.

    The hyperparameter fields are forwarded unchanged to ``XGBClassifier``;
    the remaining fields control the hold-out split, the decision threshold
    used when scoring that split, and runtime options.
    """

    # --- Booster hyperparameters (forwarded to XGBClassifier) ---
    max_depth: int = 3
    min_child_weight: float = 1.0
    gamma: float = 0.0
    subsample: float = 0.8
    colsample_bytree: float = 0.8
    learning_rate: float = 0.05
    n_estimators: int = 400
    reg_lambda: float = 1.0
    reg_alpha: float = 0.0

    # --- Hold-out evaluation ---
    # Fraction of rows reserved for the test split.
    test_size: float = 0.2
    # Probability cut-off used to turn scores into hard 0/1 predictions.
    threshold: float = 0.5
    # Shared seed for both the data split and the classifier.
    seed: int = 42

    # --- Runtime ---
    n_jobs: int = -1
    # "hist" is fast and reproducible on CPU.
    tree_method: str = "hist"

def build_xgb_cfg(args: argparse.Namespace) -> XGBPropensityConfig:
    """Translate parsed command-line arguments into an ``XGBPropensityConfig``.

    Attributes in the first group must exist on ``args`` (a missing one raises
    ``AttributeError``); attributes in the second group fall back to the listed
    default when absent. Every value is coerced to the field's declared type.
    """
    # Mandatory settings: plain attribute access fails loudly when one is missing.
    mandatory = {
        "max_depth": int(args.xgb_max_depth),
        "min_child_weight": float(args.xgb_min_child_weight),
        "gamma": float(args.xgb_gamma),
        "subsample": float(args.xgb_subsample),
        "colsample_bytree": float(args.xgb_colsample_bytree),
        "learning_rate": float(args.xgb_learning_rate),
        "n_estimators": int(args.xgb_n_estimators),
        "reg_lambda": float(args.xgb_reg_lambda),
    }

    # Optional settings: use the fallback when the namespace lacks the attribute.
    optional = {
        "reg_alpha": float(getattr(args, "xgb_reg_alpha", 0.0)),
        "test_size": float(getattr(args, "xgb_test_size", 0.2)),
        "threshold": float(getattr(args, "xgb_threshold", 0.5)),
        "seed": int(getattr(args, "seed", 42)),
        "n_jobs": int(getattr(args, "xgb_n_jobs", -1)),
        "tree_method": str(getattr(args, "xgb_tree_method", "hist")),
    }

    return XGBPropensityConfig(**mandatory, **optional)

def _optional_xgb_setting(source: Any, name: str, default: Any) -> Any:
    """Look up an optional setting on an args-like object.

    The ``xgb_``-prefixed attribute (the spelling ``build_xgb_cfg`` reads) wins
    when it is present and not ``None``; otherwise the bare legacy attribute
    is used, and finally ``default``.
    """
    prefixed = getattr(source, f"xgb_{name}", None)
    if prefixed is not None:
        return prefixed
    return getattr(source, name, default)


def xgb_propensity_model(
        X: torch.Tensor,
        A: torch.Tensor,
        config: Union[XGBPropensityConfig, Any],
        verbose: bool = True,
    ) -> XGBClassifier:
    """Fit an XGBoost binary classifier predicting ``A`` from ``X``.

    The data is split into stratified train/test parts; the classifier is
    fitted on the train part only and scored (accuracy and ROC AUC) on the
    held-out part. The scores are printed when ``verbose`` is true; they are
    not returned.

    Args:
        X: Feature tensor; converted to a float32 NumPy array on the CPU.
        A: Label tensor; converted to an int64 NumPy array on the CPU. The
            model uses the ``binary:logistic`` objective, so labels are
            presumably 0/1 (e.g. treatment assignment) -- confirm with callers.
        config: Either an ``XGBPropensityConfig`` or an args-like object
            exposing ``xgb_*`` attributes (e.g. an ``argparse.Namespace``).
        verbose: When true, print the held-out accuracy and AUC.

    Returns:
        The fitted ``XGBClassifier``.

    Raises:
        AttributeError: If an args-like ``config`` lacks one of the required
            ``xgb_*`` hyperparameter attributes.
    """
    # XGBoost and scikit-learn work on host NumPy arrays: drop autograd
    # tracking and move to the CPU before converting.
    X_np = X.detach().cpu().to(torch.float32).numpy()
    A_np = A.detach().cpu().to(torch.int64).numpy()

    if isinstance(config, XGBPropensityConfig):
        cfg = config
    else:
        # Args-like object support. Optional settings are read under their
        # ``xgb_``-prefixed names first so values parsed for ``build_xgb_cfg``
        # are honoured here too; the bare names remain as a legacy fallback.
        cfg = XGBPropensityConfig(
            max_depth=int(getattr(config, "xgb_max_depth")),
            min_child_weight=float(getattr(config, "xgb_min_child_weight")),
            gamma=float(getattr(config, "xgb_gamma")),
            subsample=float(getattr(config, "xgb_subsample")),
            colsample_bytree=float(getattr(config, "xgb_colsample_bytree")),
            learning_rate=float(getattr(config, "xgb_learning_rate")),
            n_estimators=int(getattr(config, "xgb_n_estimators")),
            reg_lambda=float(getattr(config, "xgb_reg_lambda")),
            reg_alpha=float(getattr(config, "xgb_reg_alpha", 0.0)),
            test_size=float(_optional_xgb_setting(config, "test_size", 0.2)),
            threshold=float(_optional_xgb_setting(config, "threshold", 0.5)),
            seed=int(getattr(config, "seed", 42)),
            n_jobs=int(_optional_xgb_setting(config, "n_jobs", -1)),
            tree_method=str(_optional_xgb_setting(config, "tree_method", "hist")),
        )

    # Stratify on the labels so both splits keep the same class balance.
    X_train, X_test, A_train, A_test = train_test_split(
        X_np,
        A_np,
        test_size=cfg.test_size,
        random_state=cfg.seed,
        stratify=A_np,
    )

    model = XGBClassifier(
        max_depth=cfg.max_depth,
        min_child_weight=cfg.min_child_weight,
        gamma=cfg.gamma,
        subsample=cfg.subsample,
        colsample_bytree=cfg.colsample_bytree,
        learning_rate=cfg.learning_rate,
        n_estimators=cfg.n_estimators,
        reg_lambda=cfg.reg_lambda,
        reg_alpha=cfg.reg_alpha,
        objective="binary:logistic",
        eval_metric="logloss",
        tree_method=cfg.tree_method,
        random_state=cfg.seed,
        n_jobs=cfg.n_jobs,
    )

    model.fit(X_train, A_train)

    # Evaluate on the held-out test split: column 1 of predict_proba is the
    # probability of the positive class.
    p_test = model.predict_proba(X_test)[:, 1]
    A_hat = (p_test >= cfg.threshold).astype(np.int64)

    metrics: Dict[str, float] = {
        "test_accuracy": float(accuracy_score(A_test, A_hat)),
        "test_auc": float(roc_auc_score(A_test, p_test)),
    }

    if verbose:
        print(f"XGB Test Accuracy: {metrics['test_accuracy']:.4f}")
        print(f"XGB Test AUC:      {metrics['test_auc']:.4f}")

    return model


