# real_best_hparams.py
import pandas as pd
from typing import Dict, Any, Optional

def get_best_hparams(dataset, penalty, series, metric="AUROC", weight_decay=1e-5):
    """Look up the tuned hyperparameters stored for one dataset/series pair.

    The values are read from ``./server_results/real_data_{penalty}.csv``
    (path relative to the current working directory). Rows are matched on
    the ``Dataset`` and ``series`` columns; when several rows match, the
    first one is used.

    Args:
        dataset: Value compared against the ``Dataset`` column.
        penalty: Penalty name; selects which results CSV is read.
        series: Value compared against the ``series`` column.
        metric: ``"AUROC"`` or ``"AUPRC"`` (case-insensitive). Selects which
            set of ``*_<METRIC>`` columns is returned.
        weight_decay: Not read from the CSV; passed through as a float.

    Returns:
        A dict with the keys ``lag``, ``lr``, ``hidden_dim``, ``dropout``,
        ``layers``, ``ind_lambda``, ``int_lambda`` and ``weight_decay``.

    Raises:
        ValueError: If ``metric`` is not supported, or if no row matches
            the requested dataset and series.
    """
    metric = metric.upper()
    if metric != "AUROC" and metric != "AUPRC":
        raise ValueError(f"metric must be 'AUROC' or 'AUPRC', got {metric}")

    results = pd.read_csv(f'./server_results/real_data_{penalty}.csv')
    selector = (results["Dataset"] == dataset) & (results["series"] == series)
    matches = results[selector]

    if matches.empty:
        raise ValueError(
            f"No tuned hyperparameters for (dataset={dataset}, series={series})"
        )

    best = matches.iloc[0]

    # (hyperparameter name, cast) pairs; the CSV column for each one is
    # "<name>_<METRIC>". The tuple order fixes the key order of the result.
    column_casts = (
        ("lag", int),
        ("lr", float),
        ("hidden_dim", int),
        ("dropout", float),
        ("layers", int),
        ("ind_lambda", float),
        ("int_lambda", float),
    )
    hparams = {
        name: cast(best[f"{name}_{metric}"]) for name, cast in column_casts
    }
    hparams["weight_decay"] = float(weight_decay)
    return hparams

# utils/configs.py
def build_training_config(
    dataset: str,
    penalty_type: str,
    data_dim: int,
    use_best: bool = False,  
    best_hparams: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Assemble the training configuration for a dataset/penalty combination.

    Args:
        dataset: Dataset name. ``"DREAM3"``/``"DREAM4"`` and ``"CausalTime"``
            receive special handling; every other name takes the defaults.
        penalty_type: One of ``"Fast_Shap"``, ``"Shapley"``, ``"Jacob_F"``,
            ``"Jacob_L1"`` or ``"Layer_Weight"``.
        data_dim: Base size from which the ``hidden_dim`` candidates are
            derived as multiples.
        use_best: When true *and* ``best_hparams`` is given, the search grid
            collapses to that single point.
        best_hparams: Previously tuned values. Must contain ``lr``,
            ``hidden_dim``, ``layers``, ``dropout``, ``ind_lambda`` and
            ``int_lambda``; ``weight_decay`` is optional (default ``1e-5``).

    Returns:
        A dict with the keys ``importance_type``, ``ignore_diagonal``,
        ``batch_size`` and ``param_grid``.

    Raises:
        ValueError: If ``penalty_type`` is not one of the known penalties.
    """
    # --- penalty -> importance_type -----------------------------------
    importance_by_penalty = {
        "Fast_Shap": "Shapley",
        "Shapley": "Shapley",
        "Jacob_F": "Jacobian",
        "Jacob_L1": "Jacobian",
        "Layer_Weight": "Layer_Weight",
    }
    try:
        importance_type = importance_by_penalty[penalty_type]
    except KeyError:
        raise ValueError(f"Unknown penalty_type: {penalty_type}") from None

    # --- dataset-specific switches ------------------------------------
    is_dream = dataset in ("DREAM3", "DREAM4")
    ignore_diagonal = is_dream  # consumed by the metrics

    if dataset == "CausalTime":
        batch_size = 512
    else:
        batch_size = -1

    # --- hidden_dim candidates ----------------------------------------
    # DREAM datasets start at half the data dimension and stop at 4x;
    # all other datasets range from 1x to 5x.
    if is_dream:
        hidden_dim_grid = [
            int(0.5 * data_dim),
            data_dim,
            *(mult * data_dim for mult in (2, 3, 4)),
        ]
    else:
        hidden_dim_grid = [
            data_dim,
            *(mult * data_dim for mult in (2, 3, 4, 5)),
        ]

    # --- int_lambda candidates ----------------------------------------
    # Only Fast_Shap sweeps int_lambda; other penalties pin it to zero.
    int_lambda_grid = [0.0]
    if penalty_type == "Fast_Shap":
        int_lambda_grid = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]

    # --- param_grid ---------------------------------------------------
    # With use_best=True and best_hparams supplied, each entry becomes a
    # one-element list so the grid holds exactly one configuration.
    if not use_best or best_hparams is None:
        param_grid = {
            "lr": [5e-4, 1e-4, 1e-3, 5e-3],
            "hidden_dim": hidden_dim_grid,
            "layers": [0, 1, 2, 3, 4, 5],
            "dropout": [0, 0.1, 0.2],
            "ind_lambda": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1],
            "int_lambda": int_lambda_grid,
            "weight_decay": [1e-5],
        }
    else:
        required = (
            "lr",
            "hidden_dim",
            "layers",
            "dropout",
            "ind_lambda",
            "int_lambda",
        )
        param_grid = {key: [best_hparams[key]] for key in required}
        param_grid["weight_decay"] = [best_hparams.get("weight_decay", 1e-5)]

    config = {}
    config["importance_type"] = importance_type
    config["ignore_diagonal"] = ignore_diagonal
    config["batch_size"] = batch_size
    config["param_grid"] = param_grid
    return config
