import pandas as pd
import numpy as np
import os
import sys
import time
import xgboost as xgb
import argparse

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from models.autoxgb import AutoXGB
from models.nn import NN
from explainers.banzhaf_explainers import (
    ExactBanzhaf,
    KernelBanzhaf,
    MCBanzhaf,
    MSRBanzhaf,
    KernelPairedBanzhaf,
    # KernelPairedWORBanzhaf,
)
from imputer import RawImputer, NoiseImputer, TreeImputer, NoiseTreeImputer
from utils import (
    get_tables_path,
    get_data_and_explicand,
    update_json_file,
)
from experiments.plots_banz_est import (
    plot_by_sample_size_median,
    plot_by_noise,
    plot_by_feature,
    plot_all_datasets,
    plot_by_set_function,
    plot_all_time,
)


def _run_timed(compute):
    """Call ``compute()`` and return ``(result, elapsed_seconds)``.

    ``time.perf_counter`` is used instead of ``time.time`` so the measured
    duration is not distorted by system clock adjustments.
    """
    start = time.perf_counter()
    result = compute()
    return result, time.perf_counter() - start


def raw_pipeline(config, data, explicand, model, target, random_state, result_dir):
    """Run every Banzhaf estimator for one seed and record values and timings.

    With at most 20 features a ``RawImputer`` is used and the exact Banzhaf
    values are computed and stored together with the ``S`` and ``b`` arrays
    returned by ``ExactBanzhaf``. With more features a ``TreeImputer`` is used,
    the exact computation is skipped and its time is recorded as ``0``.

    Then, for every sample size in ``config["S_size"]``, the Kernel, Kernel
    Paired, Monte Carlo and MSR estimators are run in that order; each result
    and the per-estimator run times are written with ``update_json_file``
    keyed by ``random_state``.

    Args:
        config: Mapping providing at least ``"dataset"`` (used in file names)
            and ``"S_size"`` (iterable of sample sizes).
        data: Table exposing ``.columns`` and ``.drop(columns=...)``
            (presumably a pandas DataFrame -- confirm against caller).
        explicand: Instance to explain, forwarded to the imputer.
        model: Fitted model, forwarded to the imputer.
        target: Name of the target column to exclude from the features, or
            ``None`` when ``data`` holds features only.
        random_state: Seed identifier used as the key of the stored results.
        result_dir: Existing directory that receives all output files.
    """
    if target is None:
        features = data.columns.tolist()
    else:
        features = data.drop(columns=[target]).columns.tolist()
    dataset = config["dataset"]

    if len(features) <= 20:
        imputer = RawImputer(data, explicand, features, model)
        # The timed region covers estimator construction and evaluation,
        # but not the imputer setup.
        (exact_banzhaf, S, b), exact_time = _run_timed(
            lambda: ExactBanzhaf(features, imputer)()
        )

        exact_path = os.path.join(result_dir, f"banzhaf_exact_{dataset}.json")
        update_json_file(exact_path, exact_banzhaf, random_state)

        np.save(os.path.join(result_dir, f"S_{dataset}_{random_state}.npy"), S)
        np.save(os.path.join(result_dir, f"b_{dataset}_{random_state}.npy"), b)
    else:
        # Exact enumeration is skipped for more than 20 features.
        imputer = TreeImputer(explicand, features, model)
        exact_time = 0

    for S_size in config["S_size"]:
        print(
            f"Running Kernel Banzhaf for S_size: {S_size}, random_state: {random_state}"
        )
        kernel_banzhaf, kernel_time = _run_timed(
            lambda: KernelBanzhaf(features, S_size, imputer)()
        )
        kernel_path = os.path.join(
            result_dir, f"banzhaf_kernel_{dataset}_{S_size}.json"
        )
        update_json_file(kernel_path, kernel_banzhaf, random_state)

        print(
            f"Running Kernel Paired Banzhaf for S_size: {S_size}, random_state: {random_state}"
        )
        # KernelPairedBanzhaf returns a pair; only the first element is stored.
        (kernel_paired_banzhaf, _), kernel_paired_time = _run_timed(
            lambda: KernelPairedBanzhaf(features, S_size, imputer)()
        )
        kernel_paired_path = os.path.join(
            result_dir, f"banzhaf_paired_sampling_{dataset}_{S_size}.json"
        )
        update_json_file(kernel_paired_path, kernel_paired_banzhaf, random_state)

        # NOTE: the KernelPairedWORBanzhaf ("swor") estimator is currently
        # disabled, so no banzhaf_swor_* file or "swor" timing is produced.

        print(
            f"Running Monte Carlo Banzhaf for S_size: {S_size}, random_state: {random_state}"
        )
        mc_banzhaf, mc_time = _run_timed(
            lambda: MCBanzhaf(features, S_size, imputer)()
        )
        mc_path = os.path.join(result_dir, f"banzhaf_mc_{dataset}_{S_size}.json")
        update_json_file(mc_path, mc_banzhaf, random_state)

        print(f"Running MSR Banzhaf for S_size: {S_size}, random_state: {random_state}")
        msr_banzhaf, msr_time = _run_timed(
            lambda: MSRBanzhaf(features, S_size, imputer)()
        )
        msr_path = os.path.join(result_dir, f"banzhaf_msr_{dataset}_{S_size}.json")
        update_json_file(msr_path, msr_banzhaf, random_state)

        # The exact time is repeated in every per-size timing record.
        time_est = {
            "exact": exact_time,
            "kernel": kernel_time,
            "kernel_paired": kernel_paired_time,
            "mc": mc_time,
            "msr": msr_time,
        }
        time_path = os.path.join(result_dir, f"time_{dataset}_{S_size}.json")

        update_json_file(time_path, time_est, random_state)


def noise_pipeline(config, data, explicand, model, target, random_state, result_dir, n):
    """Run the sampling-based Banzhaf estimators for every configured noise level.

    Results are written below ``<result_dir>/noise`` (created if missing), one
    JSON file per estimator and noise level, keyed by ``random_state``.

    Args:
        config: Mapping providing at least ``"dataset"`` (used in file names)
            and ``"noise"`` (iterable of noise levels).
        data: Table exposing ``.columns`` and ``.drop(columns=...)``
            (presumably a pandas DataFrame -- confirm against caller).
        explicand: Instance to explain, forwarded to the imputer.
        model: Fitted model, forwarded to the imputer.
        target: Name of the target column to exclude from the features, or
            ``None`` when ``data`` holds features only.
        random_state: Seed identifier used as the key of the stored results.
        result_dir: Parent directory; outputs go to its ``noise`` subdirectory.
        n: Number of features. Selects the sample budget (``80 * n`` when
            ``n > 10``, otherwise ``20 * n``) and the imputer
            (``NoiseImputer`` when ``n < 50``, otherwise ``NoiseTreeImputer``).
    """
    result_dir = os.path.join(result_dir, "noise")
    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(result_dir, exist_ok=True)

    S_size = 80 * n if n > 10 else 20 * n

    # The feature list does not depend on the noise level; compute it once.
    if target is None:
        features = data.columns.tolist()
    else:
        features = data.drop(columns=[target]).columns.tolist()
    dataset = config["dataset"]

    for noise in config["noise"]:
        if n < 50:
            imputer = NoiseImputer(data, explicand, features, model, noise)
        else:
            imputer = NoiseTreeImputer(explicand, features, model, noise)

        kernel_banzhaf = KernelBanzhaf(features, S_size, imputer)()
        kernel_path = os.path.join(
            result_dir, f"banzhaf_kernel_{dataset}_{noise}.json"
        )
        update_json_file(kernel_path, kernel_banzhaf, random_state)

        # KernelPairedBanzhaf returns a pair; only the first element is stored.
        kernel_paired_banzhaf, _ = KernelPairedBanzhaf(features, S_size, imputer)()
        kernel_paired_path = os.path.join(
            result_dir, f"banzhaf_paired_sampling_{dataset}_{noise}.json"
        )
        update_json_file(kernel_paired_path, kernel_paired_banzhaf, random_state)

        # NOTE: the KernelPairedWORBanzhaf ("swor") estimator is currently
        # disabled, so no banzhaf_swor_* file is produced.

        mc_banzhaf = MCBanzhaf(features, S_size, imputer)()
        print(mc_banzhaf)
        mc_path = os.path.join(result_dir, f"banzhaf_mc_{dataset}_{noise}.json")
        update_json_file(mc_path, mc_banzhaf, random_state)

        msr_banzhaf = MSRBanzhaf(features, S_size, imputer)()
        print(msr_banzhaf)
        msr_path = os.path.join(result_dir, f"banzhaf_msr_{dataset}_{noise}.json")
        update_json_file(msr_path, msr_banzhaf, random_state)


def main():
    """Parse command-line options, run the Banzhaf estimation experiments, and plot.

    Unless ``--plot_only`` is given, the selected model is trained or loaded
    and, for each of 50 fixed seeds, ``raw_pipeline`` (default) or
    ``noise_pipeline`` (with ``--noise``) is executed. Afterwards the error
    plots for the result directory are produced.

    The ``plot_all_datasets``, ``set_function`` and ``plot_by_feature`` entries
    of the local ``config`` dict are hard-coded switches; ``--swor``,
    ``--feature`` and ``--time`` are only consulted when ``plot_all_datasets``
    is enabled.

    Exits through ``parser.error`` (status 2) when ``--model_type`` is not one
    of ``xgb``/``autoxgb``/``nn``, or when an estimation run is requested
    without ``--dataset``.
    """
    parser = argparse.ArgumentParser(description="Process some datasets.")
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Dataset to use: diabetes, adult, bank, german_credit, nhanes, brca, communitiesandcrime, tuandromd",
    )
    parser.add_argument(
        "--model_type", type=str, default=None, help="Model to use: xgb, autoxgb, nn"
    )
    parser.add_argument("--noise", action="store_true", help="Add noise")
    parser.add_argument("--time", action="store_true", help="Plot time")
    parser.add_argument("--feature", action="store_true", help="Plot by feature")
    parser.add_argument("--swor", action="store_true", help="Use SWOR")
    parser.add_argument("--plot_only", action="store_true", help="Plot only")

    args = parser.parse_args()
    dataset = args.dataset
    model_type = args.model_type

    # Reject unsupported model types up front; otherwise result_dir and model
    # would stay unbound and the run would crash much later.
    if model_type is not None and model_type not in ("xgb", "autoxgb", "nn"):
        parser.error(
            f"unsupported --model_type {model_type!r}; choose from xgb, autoxgb, nn"
        )

    if dataset is None:
        n = 0
    else:
        tables_path, target = get_tables_path(dataset)
        base_table = (
            pd.read_csv(tables_path, dtype=str)
            if model_type != "nn"
            else pd.read_csv(tables_path)
        )
        n = len(base_table.columns) if target is None else len(base_table.columns) - 1
        print(f"Dataset: {dataset}, Number of features: {n}")

    if model_type is None:
        model_type = "autoxgb" if n < 50 else "xgb"

    # Sample budgets are multiples of the feature count; the largest budgets
    # are skipped for the listed datasets.
    S_size = [5 * n, 10 * n, 20 * n, 40 * n]
    if dataset not in ["diabetes"]:
        S_size += [80 * n, 160 * n]
    if dataset not in ["diabetes", "tuandromd"]:
        S_size += [320 * n]
    if dataset not in ["diabetes", "brca", "communitiesandcrime"]:
        S_size += [640 * n]
    if dataset not in ["diabetes", "brca", "communitiesandcrime", "tuandromd"]:
        S_size += [1280 * n]
    if dataset in ["bank", "german_credit"]:
        S_size += [2560 * n, 5120 * n]
    print(f"S sizes: {S_size}")

    estimators = ["paired_sampling", "kernel", "mc", "msr"]
    plot_error = ["l2", "objective"] if n < 50 else ["l2"]

    # Fixed generator seed so every invocation uses the same 50 seeds.
    random_states = np.random.RandomState(42).choice(1000, 50, replace=False)
    random_states = [int(i) for i in random_states]
    print(f"Random states: {random_states}")

    noise = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05] if args.noise else None

    config = {
        "model_type": model_type,
        "dataset": dataset,
        "estimators": estimators,
        "data_size": 50,
        "S_size": S_size,
        "random_states": random_states,
        "noise": noise,
        "plot_error": plot_error,
        "plot_all_datasets": False,
        "set_function": False,
        "plot_by_feature": False,
    }

    # Neural-network results live in their own directory; both XGBoost
    # variants share one.
    dir_prefix = "nn_results" if config["model_type"] == "nn" else "results"
    result_dir = f"{dir_prefix}_{config['dataset']}"
    if dataset is not None:
        os.makedirs(result_dir, exist_ok=True)

    if config["plot_all_datasets"]:
        dataset_dict = {
            "l2": (
                [
                    "diabetes",
                    "adult",
                    "bank",
                    "german_credit",
                    "nhanes",
                    "brca",
                    "communitiesandcrime",
                    "tuandromd",
                ]
                if config["model_type"] != "nn"
                else ["diabetes", "adult", "bank", "german_credit"]
            ),
            "objective": ["diabetes", "adult", "bank", "german_credit"],
        }
        if args.swor:
            estimators = ["paired_sampling", "swor"]
            plot_all_datasets(
                dataset_dict["l2"],
                estimators,
                "l2",
                config["noise"],
                config["model_type"],
            )
            return
        if args.feature:
            print("Plotting by feature")
            estimators = ["mc", "msr", "paired_sampling"]
            # NOTE(review): plot_by_feature is called here with (datasets,
            # estimators) but at the end of this function with (result_dir,
            # estimators, dataset) -- confirm which signature is correct.
            plot_by_feature(dataset_dict["l2"], estimators)
            return
        if args.time:
            plot_all_time(dataset_dict["l2"], estimators)
            return
        for error in config["plot_error"]:
            plot_all_datasets(
                dataset_dict[error],
                estimators,
                error,
                config["noise"],
                config["model_type"],
            )
        return

    if config["set_function"]:
        print("Plotting by set function")
        for error in config["plot_error"]:
            plot_by_set_function(result_dir, estimators, error, config["dataset"])
        return

    if not args.plot_only:
        if dataset is None:
            # Without a dataset neither target nor base_table exists.
            parser.error("--dataset is required unless --plot_only is given")

        if config["model_type"] == "autoxgb":
            model = AutoXGB(target)
            train_data, test_data = model.load_data(base_table)
            model.train_model(train_data)
            accuracy = model.evaluate_model(test_data)
            print(f"Accuracy: {accuracy}")

        elif config["model_type"] == "xgb":
            model = (
                xgb.XGBClassifier()
                if dataset in ["brca", "nhanes", "communitiesandcrime"]
                else xgb.XGBRegressor()
            )
            model.load_model(f"xgb_{config['dataset']}.json")
            # NOTE(review): the return value is discarded and the first
            # positional argument of Booster.get_dump is the feature-map path,
            # not an output file; Booster.dump_model may have been intended.
            model.get_booster().get_dump(
                f"bst_{config['dataset']}.json", with_stats=True
            )

        elif config["model_type"] == "nn":
            model = NN(target)
            model_path = f"nn_{config['dataset']}.pth"
            train_data, val_data, test_data = model.load_data(base_table)
            if os.path.exists(model_path):
                model.load_model(model_path)
            else:
                model.train_model(train_data, val_data, model_path)
                accuracy = model.evaluate_model(test_data)
                print(f"Accuracy: {accuracy}")
                model.save_model(model_path)

        for random_state in config["random_states"]:
            data, explicand = get_data_and_explicand(
                config["data_size"], base_table, random_state=random_state
            )

            if config["noise"] is None:
                raw_pipeline(
                    config, data, explicand, model, target, random_state, result_dir,
                )
            else:
                noise_pipeline(
                    config, data, explicand, model, target, random_state, result_dir, n
                )

    if config["plot_error"] is not None:
        for error in config["plot_error"]:
            if config["noise"] is not None:
                print("Plotting by sample size with noise")
                plot_by_noise(result_dir, estimators, error, config["dataset"])
            else:
                print("Plotting by sample size")
                plot_by_sample_size_median(
                    result_dir, estimators, error, config["dataset"],
                )

    if config["plot_by_feature"]:
        print("Plotting by feature")
        plot_by_feature(result_dir, estimators, config["dataset"])


# Run the experiment driver only when this file is executed as a script.
if __name__ == "__main__":
    main()
