import os
import pickle
from typing import Tuple

import numpy as np
import pandas as pd
from sklearn.datasets import load_diabetes, fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import config.config as c
from data.dataset import Dataset
from experiments.experiment import Experiment

def split_scale_reg(
    X: np.ndarray,
    y: np.ndarray,
    random_state: int,
    train_size: int,
    test_size: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Split a regression dataset and standardize features and target.

    The data is split with ``train_test_split``; features and target are then
    each passed through their own ``StandardScaler``. Both scalers are fitted
    on the training portion only and merely applied to the test portion.

    Args:
        X: Feature matrix.
        y: Target values.
        random_state: Seed forwarded to ``train_test_split``.
        train_size: Forwarded to ``train_test_split`` as ``train_size``.
        test_size: Forwarded to ``train_test_split`` as ``test_size``.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` after scaling. The two target
        arrays are flattened back to 1-D with ``ravel()``.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        train_size=train_size,
        test_size=test_size,
        random_state=random_state,
    )

    x_scaler = StandardScaler()
    X_train = x_scaler.fit_transform(X_train)
    X_test = x_scaler.transform(X_test)

    # StandardScaler expects 2-D input, so the targets are viewed as a single
    # column for scaling and flattened again afterwards.
    y_scaler = StandardScaler()
    y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).ravel()
    y_test = y_scaler.transform(y_test.reshape(-1, 1)).ravel()
    return X_train, X_test, y_train, y_test


def prepare_regression_dataset(dataset_name: str,
                               random_state: int,
                               train_size: int,
                               test_size: int):
    """Load a named regression dataset and return scaled train/test splits.

    Supported names (case-insensitive): ``"diabetes"``, ``"california"`` /
    ``"california_housing"``, and ``"ames"`` / ``"ames_housing"``.

    Args:
        dataset_name: Name of the dataset to load.
        random_state: Passed through to ``split_scale_reg``.
        train_size: Passed through to ``split_scale_reg``.
        test_size: Passed through to ``split_scale_reg``.

    Returns:
        Whatever ``split_scale_reg`` returns for the loaded data, i.e.
        ``(X_train, X_test, y_train, y_test)``.

    Raises:
        RuntimeError: If loading or preprocessing the Ames dataset fails
            (for example ``openml`` is not installed or the download fails).
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    dataset_name = dataset_name.lower()
    if dataset_name in {"diabetes"}:
        data = load_diabetes()
        X, y = data.data, data.target.astype(float)
        return split_scale_reg(X, y, random_state, train_size, test_size)

    if dataset_name in {"california", "california_housing"}:
        data = fetch_california_housing()
        X, y = data.data, data.target.astype(float)
        return split_scale_reg(X, y, random_state, train_size, test_size)

    if dataset_name in {"ames", "ames_housing"}:
        try:
            # openml is an optional dependency, imported only when needed.
            import openml
            df = openml.datasets.get_dataset(42165).get_data(dataset_format="dataframe")[0]
            y = df["SalePrice"].values.astype(float)
            # Keep only numeric columns *before* computing medians: calling
            # median() on a frame that still contains non-numeric columns
            # raises on recent pandas versions.
            X = df.drop(columns=["SalePrice"]).select_dtypes(include=[np.number])
            X = X.fillna(X.median())
        except Exception as e:
            raise RuntimeError(
                "Ames dataset requires OpenML access or a local CSV. "
                "Provide a preprocessed numeric CSV with 'SalePrice' as target."
            ) from e
        # Split outside the try so that split/scale errors (e.g. invalid
        # sizes) are not misreported as an OpenML access problem.
        return split_scale_reg(X.values, y, random_state, train_size, test_size)

    raise ValueError(f"Unknown regression dataset: {dataset_name}")

def run_experiment_for_dataset(X_train, X_test, y_train, y_test,
                               random_state=c.DEFAULT_RANDOM_STATE):
    """Run the valuation experiment on one pre-split dataset.

    Wraps the splits in ``Dataset`` objects, builds an ``Experiment`` with a
    fixed training configuration (model ``"linreg"``, criterion ``"MSE"``,
    ``l2`` regularization of 1.0, ``"LBFGS"`` optimizer) and three utilities
    (``mse``, ``mae``, ``r2``) as the changing parameter, then runs it.

    Args:
        X_train, y_train: Training features and targets.
        X_test, y_test: Test features and targets.
        random_state: Seed passed to both ``Dataset`` objects and included in
            the fixed parameters.

    Returns:
        The tuple ``(exp.values, exp.marg_contrib_dict)`` read from the
        experiment after ``run_experiment()`` has completed.
    """
    trainset = Dataset(X_train, y_train, random_state=random_state)
    testset = Dataset(X_test, y_test, random_state=random_state)

    # Settings shared by every configuration of the experiment.
    fixed_params = dict(
        model="linreg",
        model_kwargs={},
        weight_init="normal",
        random_state=random_state,
        criterion="MSE",
        regularization={"l2": 1.0},
        optimizer="LBFGS",
        lr=1.0,
        optimizer_kwargs={},
        trainset=trainset,
        testset=testset,
    )

    # The utility is the only parameter that changes between configurations.
    utilities = [{"utility_name": metric} for metric in ("mse", "mae", "r2")]
    changing_params = {"utility": utilities}

    valuation_funcs = [
        {"func_name": "shapley"},
        {"func_name": "beta_shapley", "alpha": 4, "beta": 1},
        {"func_name": "banzhaf"},
    ]

    exp = Experiment(
        fixed_params=fixed_params,
        changing_params=changing_params,
        n_runs=5,
        valuation_task_name="regression",
        valuation_func=valuation_funcs,
        save_path=None,  # results are returned to the caller instead
    )
    exp.run_experiment()

    values, marg_contrib = exp.values, exp.marg_contrib_dict
    return values, marg_contrib

if __name__ == "__main__":
    # Settings shared by every dataset in this run.
    seed = c.DEFAULT_RANDOM_STATE
    n_train, n_test = 300, 100

    # Per-dataset results, keyed by dataset name.
    all_values, all_marg_contrib = {}, {}

    for name in ("diabetes", "california", "ames"):
        print(f"[regression] Computing values for {name}")
        splits = prepare_regression_dataset(name, seed, n_train, n_test)
        values, mc = run_experiment_for_dataset(*splits, seed)
        all_values[name] = values
        all_marg_contrib[name] = mc

    # Persist both result dictionaries once all datasets have been processed.
    os.makedirs("results", exist_ok=True)
    outputs = (
        ("results/all_marg_contrib_regression.pkl", all_marg_contrib),
        ("results/all_values_regression.pkl", all_values),
    )
    for path, payload in outputs:
        with open(path, "wb") as f:
            pickle.dump(payload, f)
