import pandas as pd
import numpy as np
import os
import sys
import xgboost as xgb
import argparse

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import (
    get_tables_path,
    update_json_file,
    compute_condition_number,
    get_data_and_explicand,
)
from explainers.opt_shap_explainers import (
    leverage_shap,
    optimized_kernel_shap,
    official_tree_shap,
    encode_categorical_columns,
    load_input,
    NoisyModel,
)
from explainers.banzhaf_explainers import (
    ExactBanzhaf,
    KernelPairedBanzhaf,
)
from imputer import ShapRawImputer, TreeImputer, ShapNoiseTreeImputer

from experiments.plots_banz_shap import (
    plot_by_sample_size,
    plot_condition_by_sample_size,
    plot_by_noise,
    plot_all_datasets,
)


def shap_raw_pipeline(config, data, explicand, xgb_model, random_state, result_dir):
    """Run the exact, leverage and optimized-kernel SHAP explainers for one seed.

    The first element of the ``official_tree_shap`` result is stored once as
    the exact reference.  ``leverage_shap`` and ``optimized_kernel_shap`` are
    then run for every entry of ``config["S_size"]``.  Every result is
    converted to a list and handed to ``update_json_file`` together with
    ``random_state``.

    Args:
        config: Mapping providing ``"dataset"`` (used in the output file
            names) and ``"S_size"`` (iterable of sample sizes).
        data: Data set forwarded unchanged to the explainers.
        explicand: Instance to explain, forwarded unchanged to the explainers.
        xgb_model: Model forwarded unchanged to the explainers.
        random_state: Seed identifier passed to ``update_json_file``.
        result_dir: Directory prefix of the JSON output paths.
    """
    print("Running exact SHAP")
    reference_values = official_tree_shap(data, explicand, xgb_model)[0]
    dataset = config["dataset"]
    update_json_file(
        f"{result_dir}/shap_exact_{dataset}.json",
        list(reference_values),
        random_state,
    )

    for sample_size in config["S_size"]:
        print(f"Running leverage SHAP for S_size: {sample_size}")
        leverage_values, _ = leverage_shap(data, explicand, xgb_model, sample_size)
        update_json_file(
            f"{result_dir}/shap_leverage_{dataset}_{sample_size}.json",
            list(leverage_values),
            random_state,
        )

        # Disabled: condition-number export for leverage SHAP.
        # leverage_shap_cond = compute_condition_number(leverage_shap_assa)
        # leverage_cond_path = (
        #     f"{result_dir}/shap_leverage_cond_{config['dataset']}_{S_size}.json"
        # )
        # update_json_file(leverage_cond_path, leverage_shap_cond, random_state)

        print(f"Running optimized kernel SHAP for S_size: {sample_size}")
        optimized_values, _ = optimized_kernel_shap(
            data, explicand, xgb_model, sample_size
        )
        update_json_file(
            f"{result_dir}/shap_optimized_{dataset}_{sample_size}.json",
            list(optimized_values),
            random_state,
        )

        # Drop both estimates before the next (larger) sample size is run.
        del leverage_values, optimized_values

        # Disabled: condition-number export for optimized kernel SHAP.
        # optimized_shap_cond = compute_condition_number(optimized_shap_assa)
        # optimized_cond_path = (
        #     f"{result_dir}/shap_optimized_cond_{config['dataset']}_{S_size}.json"
        # )
        # update_json_file(optimized_cond_path, optimized_shap_cond, random_state)


def banzhaf_raw_pipeline(
    config, data, explicand, xgb_model, target, random_state, result_dir
):
    """Run the Banzhaf estimators for one seed and store the results.

    The feature list is every column of ``data`` except ``target`` (all
    columns when ``target`` is ``None``).  For the datasets ``nhanes``,
    ``brca``, ``communitiesandcrime`` and ``tuandromd`` a ``TreeImputer`` is
    used and no exact Banzhaf values are computed.  Every other dataset uses a
    ``ShapRawImputer`` and additionally stores the ``ExactBanzhaf`` result.
    ``KernelPairedBanzhaf`` is then run for each entry of ``config["S_size"]``.

    Args:
        config: Mapping providing ``"dataset"`` and ``"S_size"``.
        data: Object exposing ``columns`` and ``drop(columns=...)``
            (presumably a pandas DataFrame -- confirm against the caller).
        explicand: Instance to explain, forwarded to the imputer.
        xgb_model: Model forwarded to the imputer.
        target: Name of the target column to exclude, or ``None``.
        random_state: Seed identifier passed to ``update_json_file``.
        result_dir: Directory prefix of the JSON output paths.
    """
    feature_frame = data if target is None else data.drop(columns=[target])
    features = feature_frame.columns.tolist()

    dataset = config["dataset"]
    tree_datasets = ["nhanes", "brca", "communitiesandcrime", "tuandromd"]
    if dataset in tree_datasets:
        imputer = TreeImputer(explicand, features, xgb_model)
    else:
        imputer = ShapRawImputer(data, explicand, features, xgb_model)

        # The exact reference is only produced in this (non-tree) branch.
        print("Running Exact Banzhaf")
        exact_values, _, _ = ExactBanzhaf(features, imputer)()
        update_json_file(
            f"{result_dir}/banzhaf_exact_{dataset}.json", exact_values, random_state
        )

    for sample_size in config["S_size"]:
        print(f"Running Kernel Banzhaf for S_size: {sample_size}")
        kernel_values, _ = KernelPairedBanzhaf(features, sample_size, imputer)()
        update_json_file(
            f"{result_dir}/banzhaf_paired_sampling_{dataset}_{sample_size}.json",
            kernel_values,
            random_state,
        )

        # Release the estimate before the next sample size is processed.
        del kernel_values

        # Disabled: condition-number export for Kernel Banzhaf.
        # kernel_banzhaf_cond = compute_condition_number(kernel_banzhaf_assa)
        # kernel_cond_path = f"{result_dir}/banzhaf_paired_sampling_cond_{config['dataset']}_{S_size}.json"
        # update_json_file(kernel_cond_path, kernel_banzhaf_cond, random_state)


def shap_noise_pipeline(
    config, data, explicand, xgb_model, random_state, result_dir, n
):
    """Run leverage and optimized-kernel SHAP against noisy models for one seed.

    A single sample size is used for all noise levels: ``20 * n`` for the
    ``diabetes`` dataset and ``80 * n`` otherwise.  For each entry of
    ``config["noise"]`` the model is wrapped in ``NoisyModel`` and both
    estimators are run; results are stored in files named after the noise
    level.

    Args:
        config: Mapping providing ``"dataset"`` and ``"noise"`` (iterable of
            noise levels).
        data: Data set forwarded unchanged to the explainers.
        explicand: Instance to explain, forwarded unchanged to the explainers.
        xgb_model: Model wrapped by ``NoisyModel`` for every noise level.
        random_state: Seed identifier passed to ``update_json_file``.
        result_dir: Directory prefix of the JSON output paths.
        n: Multiplier for the sample size (``main`` passes the feature count).
    """
    dataset = config["dataset"]
    sample_size = (20 if dataset in ["diabetes"] else 80) * n

    for noise_level in config["noise"]:
        perturbed_model = NoisyModel(xgb_model, noise_level)

        print(f"Running leverage SHAP for S_size: {sample_size}")
        leverage_values, _ = leverage_shap(
            data, explicand, perturbed_model, sample_size
        )
        update_json_file(
            f"{result_dir}/shap_leverage_{dataset}_{noise_level}.json",
            list(leverage_values),
            random_state,
        )

        print(f"Running optimized kernel SHAP for S_size: {sample_size}")
        optimized_values, _ = optimized_kernel_shap(
            data, explicand, perturbed_model, sample_size
        )
        update_json_file(
            f"{result_dir}/shap_optimized_{dataset}_{noise_level}.json",
            list(optimized_values),
            random_state,
        )


def banzhaf_noise_pipeline(
    config, data, explicand, xgb_model, target, random_state, result_dir, n
):
    """Run Kernel Banzhaf under every configured noise level for one seed.

    The sample size is ``20 * n`` for ``diabetes`` and ``80 * n`` otherwise.
    For ``nhanes``, ``brca``, ``communitiesandcrime`` and ``tuandromd`` the
    noise level is handed to ``ShapNoiseTreeImputer`` together with the
    unwrapped model.  For every other dataset the model is wrapped in
    ``NoisyModel`` and imputed with ``ShapRawImputer``.

    Args:
        config: Mapping providing ``"dataset"`` and ``"noise"``.
        data: Object exposing ``columns`` and ``drop(columns=...)``
            (presumably a pandas DataFrame -- confirm against the caller).
        explicand: Instance to explain, forwarded to the imputer.
        xgb_model: Model forwarded to the imputer (wrapped or unwrapped, see
            above).
        target: Name of the target column to exclude, or ``None``.
        random_state: Seed identifier passed to ``update_json_file``.
        result_dir: Directory prefix of the JSON output paths.
        n: Multiplier for the sample size (``main`` passes the feature count).
    """
    dataset = config["dataset"]
    sample_size = (20 if dataset in ["diabetes"] else 80) * n

    feature_frame = data if target is None else data.drop(columns=[target])
    features = feature_frame.columns.tolist()

    def estimate_and_store(imputer, noise_level):
        # Shared tail of both branches: estimate, then persist under the noise level.
        estimate, _ = KernelPairedBanzhaf(features, sample_size, imputer)()
        update_json_file(
            f"{result_dir}/banzhaf_paired_sampling_{dataset}_{noise_level}.json",
            estimate,
            random_state,
        )

    if dataset in ["nhanes", "brca", "communitiesandcrime", "tuandromd"]:
        for noise_level in config["noise"]:
            print(f"Running Kernel Banzhaf for S_size: {sample_size}")
            tree_imputer = ShapNoiseTreeImputer(
                explicand, features, xgb_model, noise_level
            )
            estimate_and_store(tree_imputer, noise_level)
    else:
        for noise_level in config["noise"]:
            perturbed_model = NoisyModel(xgb_model, noise_level)
            raw_imputer = ShapRawImputer(data, explicand, features, perturbed_model)
            print(f"Running Kernel Banzhaf for S_size: {sample_size}")
            estimate_and_store(raw_imputer, noise_level)


def _load_pretrained_model(dataset):
    """Return the stored XGBoost model for ``dataset``, or ``None`` if there is none.

    Stored models exist for ``nhanes``/``brca`` (classifier, ``shap_xgb_*``
    file), ``tuandromd`` (regressor, ``xgb_*`` file) and
    ``communitiesandcrime`` (classifier, ``xgb_*`` file).  File names are
    relative to the current working directory.
    """
    if dataset in ["nhanes", "brca"]:
        model = xgb.XGBClassifier()
        model_file = f"shap_xgb_{dataset}.json"
        dump_arg = f"shap_bst_{dataset}.json"
    elif dataset in ["tuandromd"]:
        model = xgb.XGBRegressor()
        model_file = f"xgb_{dataset}.json"
        dump_arg = f"bst_{dataset}.json"
    elif dataset in ["communitiesandcrime"]:
        model = xgb.XGBClassifier()
        model_file = f"xgb_{dataset}.json"
        dump_arg = f"bst_{dataset}.json"
    else:
        return None

    model.load_model(model_file)
    # NOTE(review): Booster.get_dump returns the dump as a list of strings and
    # its first positional argument is a feature-map path; the result is
    # discarded here. Confirm whether Booster.dump_model (which writes a file)
    # was intended. Call kept unchanged to preserve behavior.
    model.get_booster().get_dump(dump_arg, with_stats=True)
    return model


def main():
    """Command-line entry point: run the SHAP/Banzhaf experiments and plot them.

    Command-line flags:
        --dataset: Name of the dataset to process.
        --model_type: ``xgb`` or ``autoxgb``; when omitted it is ``autoxgb``
            for fewer than 63 features and ``xgb`` otherwise.
        --noise: Run the noise experiments instead of the sample-size ones.
        --plot_only: Skip the experiments and only plot stored results.

    Raises:
        SystemExit: Via ``parser.error`` when the single-dataset path is
            reached without ``--dataset``.
    """
    parser = argparse.ArgumentParser(description="Process some datasets.")
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Dataset to use: diabetes, adult, bank, german_credit, nhanes, brca, communitiesandcrime, tuandromd",
    )
    parser.add_argument(
        "--model_type", type=str, default=None, help="Model to use: xgb, autoxgb"
    )
    parser.add_argument("--noise", action="store_true", help="Add noise")
    parser.add_argument("--plot_only", action="store_true", help="Plot results")

    args = parser.parse_args()
    dataset = args.dataset
    model_type = args.model_type

    if dataset is None:
        n = 0
    else:
        tables_path, target = get_tables_path(dataset)
        base_table = pd.read_csv(tables_path)
        base_table = encode_categorical_columns(base_table)
        # n is the number of feature columns (the target column is excluded).
        n = len(base_table.columns) if target is None else len(base_table.columns) - 1
        print(f"Dataset: {dataset}, Number of features: {n}")

    if model_type is None:
        model_type = "autoxgb" if n < 63 else "xgb"

    # Sample sizes are multiples of the feature count; the larger multiples
    # are skipped for the datasets listed in each condition.
    S_size = [5 * n, 10 * n, 20 * n, 40 * n]
    if dataset not in ["diabetes"]:
        S_size += [80 * n, 160 * n]
    if dataset not in ["diabetes", "tuandromd"]:
        S_size += [320 * n]
    if dataset not in ["diabetes", "brca", "communitiesandcrime"]:
        S_size += [640 * n]
    if dataset not in ["diabetes", "brca", "communitiesandcrime", "tuandromd"]:
        S_size += [1280 * n]
    if dataset in ["bank", "german_credit"]:
        S_size += [2560 * n, 5120 * n]
    print(f"S sizes: {S_size}")

    estimators = ["leverage", "optimized"]
    plot_error = ["l2"]

    # 50 distinct seeds drawn reproducibly from [0, 1000).
    random_states = np.random.RandomState(42).choice(1000, 50, replace=False)
    random_states = [int(i) for i in random_states]
    print(f"Random states: {random_states}")

    noise = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05] if args.noise else None

    config = {
        "model_type": model_type,
        "dataset": dataset,
        "estimators": estimators,
        "data_size": 50,
        "S_size": S_size,
        "random_states": random_states,
        "noise": noise,
        "plot_error": plot_error,
        # NOTE(review): hard-coded to True, so main() always takes the
        # "plot all datasets" branch below and returns before any experiment
        # runs. Set to False to reach the single-dataset code path.
        "plot_all_datasets": True,
    }

    if config["dataset"] is not None:
        result_dir = f"shap_results_{config['dataset']}"
        if config["noise"] is not None:
            result_dir = f"{result_dir}/noise"
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(result_dir, exist_ok=True)

    if config["plot_all_datasets"]:
        datasets = [
            "diabetes",
            "adult",
            "bank",
            "german_credit",
            "nhanes",
            "brca",
            "communitiesandcrime",
            "tuandromd",
        ]
        error = config["plot_error"][0] if config["plot_error"] is not None else None
        print("Plotting all datasets")
        plot_all_datasets(
            datasets, estimators, error, config["noise"],
        )
        return

    # Everything below needs base_table, target and result_dir, which only
    # exist when a dataset was given.
    if dataset is None:
        parser.error("--dataset is required unless all datasets are being plotted")

    if not args.plot_only:
        if dataset == "nhanes":
            X = base_table
            y = np.load(f"dataset/y_{dataset}.npy")
        else:
            X = base_table.drop(columns=[target])
            y = base_table[target]
        xgb_model = xgb.XGBRegressor(n_estimators=100, max_depth=4)
        xgb_model.fit(X, y)

        # SHAP experiments, one run per seed.
        for random_state in config["random_states"]:
            X = base_table.drop(columns=[target]) if target is not None else base_table
            baseline, explicand = load_input(X, random_state)
            if noise is None:
                shap_raw_pipeline(
                    config, baseline, explicand, xgb_model, random_state, result_dir
                )
            else:
                shap_noise_pipeline(
                    config, baseline, explicand, xgb_model, random_state, result_dir, n
                )

        # The stored model does not depend on the seed, so it is loaded once
        # here instead of once per iteration. Datasets without a stored model
        # keep the regressor fitted above.
        pretrained_model = _load_pretrained_model(dataset)
        if pretrained_model is not None:
            xgb_model = pretrained_model

        # Banzhaf experiments, one run per seed.
        for random_state in config["random_states"]:
            data, explicand = get_data_and_explicand(
                config["data_size"], base_table, random_state
            )
            if noise is None:
                banzhaf_raw_pipeline(
                    config, data, explicand, xgb_model, target, random_state, result_dir
                )
            else:
                banzhaf_noise_pipeline(
                    config,
                    data,
                    explicand,
                    xgb_model,
                    target,
                    random_state,
                    result_dir,
                    n,
                )

    if config["plot_error"] is not None:
        if noise is not None:
            # plot_by_noise is given the dataset directory without "/noise".
            result_dir = result_dir.replace("/noise", "")
            print("Plotting by sample size with noise")
            plot_by_noise(result_dir, estimators, config["dataset"])
        else:
            print("Plotting by sample size")
            plot_by_sample_size(
                result_dir, estimators, config["dataset"], n,
            )
    else:
        print("Plotting by condition number")
        plot_condition_by_sample_size(
            result_dir, estimators, config["dataset"], n,
        )


# Run the experiment driver only when this file is executed as a script.
if __name__ == "__main__":
    main()
