import os
import pickle
from typing import Tuple

import numpy as np
from sklearn.datasets import load_digits, load_iris, load_wine
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import config.config as c
from data.dataset import Dataset
from experiments.experiment import Experiment

def split_scale(
    X: np.ndarray,
    y: np.ndarray,
    random_state: int,
    train_size: int,
    test_size: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Make a stratified train/test split and standardize the features.

    Args:
        X: Feature array handed to ``train_test_split``.
        y: Label array; also used as the stratification target.
        random_state: Seed forwarded to ``train_test_split``.
        train_size: ``train_size`` argument forwarded to ``train_test_split``.
        test_size: ``test_size`` argument forwarded to ``train_test_split``.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` where both feature arrays have
        been transformed by a ``StandardScaler`` fitted on the training
        features; the label arrays are returned as produced by the split.
    """
    split_options = dict(
        train_size=train_size,
        test_size=test_size,
        stratify=y,
        random_state=random_state,
    )
    train_X, test_X, train_y, test_y = train_test_split(X, y, **split_options)

    # The scaler is fitted on the training features only; the test features
    # are transformed with those same training statistics.
    standardizer = StandardScaler().fit(train_X)
    return (
        standardizer.transform(train_X),
        standardizer.transform(test_X),
        train_y,
        test_y,
    )


def prepare_multiclass_dataset(dataset_name: str,
                               random_state: int,
                               train_size: int,
                               test_size: int):
    """Load a named scikit-learn dataset and return its scaled split.

    Args:
        dataset_name: One of ``"digits"``, ``"iris"`` or ``"wine"``; the
            comparison is case-insensitive.
        random_state: Seed forwarded to ``split_scale``.
        train_size: Training-set size forwarded to ``split_scale``.
        test_size: Test-set size forwarded to ``split_scale``.

    Returns:
        Whatever ``split_scale`` returns for the loaded data and targets.

    Raises:
        ValueError: If the lower-cased name is not a supported dataset.
    """
    key = dataset_name.lower()

    if key == "digits":
        bunch = load_digits()
    elif key == "iris":
        bunch = load_iris()
    elif key == "wine":
        bunch = load_wine()
    else:
        raise ValueError(f"Unknown multiclass dataset: {key}")

    # All three loaders expose the features and labels the same way.
    features, labels = bunch.data, bunch.target
    return split_scale(features, labels, random_state, train_size, test_size)

def run_experiment_for_dataset(X_train, X_test, y_train, y_test,
                               random_state=c.DEFAULT_RANDOM_STATE):
    """Configure and run the multiclass valuation experiment on one split.

    The arrays are wrapped in ``Dataset`` objects and handed to an
    ``Experiment`` together with a fixed configuration (an ``"mlp_class"``
    model with ``hidden_dim=64``, ``"CrossEntropy"`` criterion, L2
    regularization of 1.0, ``"LBFGS"`` optimizer with ``lr=1.0``). The
    experiment varies the utility over accuracy, F1 and recall and requests
    Shapley, Beta-Shapley (alpha=4, beta=1) and Banzhaf valuations.

    Args:
        X_train: Training features.
        X_test: Test features.
        y_train: Training labels; the number of distinct values determines
            the ``num_class`` passed to the model.
        y_test: Test labels.
        random_state: Seed passed to both ``Dataset`` objects and to the
            experiment configuration.

    Returns:
        The tuple ``(exp.values, exp.marg_contrib_dict)`` read from the
        ``Experiment`` after ``run_experiment()`` has completed.
    """
    # The model's output size is the number of distinct training labels.
    num_classes = int(np.unique(y_train).size)

    trainset = Dataset(X_train, y_train, random_state=random_state)
    testset = Dataset(X_test, y_test, random_state=random_state)

    # Settings that stay the same for every run of the experiment.
    fixed_params = {
        "model": "mlp_class",
        "model_kwargs": {
            "hidden_dim": 64,
            "num_class": num_classes,
        },
        "weight_init": "normal",
        "random_state": random_state,
        "criterion": "CrossEntropy",
        "regularization": {"l2": 1.0},
        "optimizer": "LBFGS",
        "lr": 1.0,
        "optimizer_kwargs": {},
        "trainset": trainset,
        "testset": testset,
    }

    # Settings the experiment iterates over: one entry per utility metric.
    changing_params = {
        "utility": [
            {"utility_name": "accuracy"},
            {"utility_name": "f1"},
            {"utility_name": "recall"},
        ]
    }

    exp = Experiment(
        fixed_params=fixed_params,
        changing_params=changing_params,
        n_runs=5,
        valuation_task_name="multiclass",
        valuation_func=[
            {"func_name": "shapley"},
            {"func_name": "beta_shapley", "alpha": 4, "beta": 1},
            {"func_name": "banzhaf"},
        ],
        # NOTE(review): presumably None means the Experiment itself persists
        # nothing; the caller pickles the returned results. Confirm against
        # the Experiment implementation.
        save_path=None,
    )
    exp.run_experiment()
    return exp.values, exp.marg_contrib_dict

if __name__ == "__main__":
    # Script entry point: compute data values for each benchmark dataset and
    # pickle the collected results under ./results.
    seed = c.DEFAULT_RANDOM_STATE
    n_train, n_test = 100, 50

    all_values = {}
    all_marg_contrib = {}

    for name in ("digits", "wine", "iris"):
        print(f"[multiclass] Computing values for {name}")
        splits = prepare_multiclass_dataset(name, seed, n_train, n_test)
        all_values[name], all_marg_contrib[name] = run_experiment_for_dataset(
            *splits, seed
        )

    os.makedirs("results", exist_ok=True)
    # Results are keyed by dataset name; marginal contributions are written
    # first, then the values.
    outputs = (
        ("results/all_marg_contrib_multiclass.pkl", all_marg_contrib),
        ("results/all_values_multiclass.pkl", all_values),
    )
    for path, payload in outputs:
        with open(path, "wb") as f:
            pickle.dump(payload, f)
