import numpy as np
import os
import argparse
import subprocess
import sys

# Experiment grid shared by every SLURM array task: map_job_id enumerates
# all (method, dataset, seed, hyperparameter) combinations built from these.
METHODS = [
    "cg_boost",
    "erlp_boost",
    "lp_boost",
    "md_boost",
    "mlp_boost",
    "neg_margins",
    # Currently disabled baselines: "adaboost", "xgboost", "lightgbm"
]

DATASETS = [
    "adult",
    "banana",
    "breast_cancer",
    "compas_propublica",
    "diabetes",
    "employment_CA2018",
    "employment_TX2018",
    "german_credit",
    "heart",
    "image",
    "mushroom_secondary",
    "public_coverage_CA2018",
    "public_coverage_TX2018",
    "ringnorm",
    "solar_flare",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

# Random seeds 1..5 (inclusive), one run per seed.
SEEDS = list(range(1, 6))


# Define per-method hyperparameter ranges
def get_hyperparam_ranges(N):
    """Return the ``(min_val, max_val)`` hyperparameter range of every method.

    Args:
        N: Total number of samples in the dataset. It only affects the
            erlp/mlp/qrlp ranges, whose upper bound is ``0.06 * N``.

    Returns:
        Dict mapping method name to a ``(min_val, max_val)`` tuple.
    """
    # 0.06 = 0.1 * 0.6: 10% of the training split, which is 60% of the data.
    nu_max = 0.06 * N
    return {
        "cg_boost": (1e-4, pow(10, -1 / 3)),
        "erlp_boost": (1.0, nu_max),
        "lp_boost": (1e-4, pow(10, -1 / 3)),
        "md_boost": (1.0, 120.0),
        # "mlp_boost" is listed in METHODS, so it needs an entry here or
        # map_job_id cannot enumerate its jobs. Assumed to share the
        # erlp_boost range -- TODO confirm against main.py.
        "mlp_boost": (1.0, nu_max),
        "qrlp_boost": (1.0, nu_max),
        "neg_margins": (1e-4, pow(10, -1 / 3)),
        "adaboost": (1e-3, 1.0),
        "xgboost": (1e-3, 1.0),
        "lightgbm": (1e-3, 1.0),
    }


def get_dataset_size(dataset):
    """Return the total sample count of ``dataset``.

    Names missing from the table fall back to 1000 samples.
    """
    known_sizes = {
        "adult": 48842,
        "banana": 5300,
        "breast_cancer": 263,
        "compas_propublica": 7206,
        "diabetes": 768,
        "employment_CA2018": 75660,
        "employment_TX2018": 52089,
        "german_credit": 1000,
        "heart": 270,
        "image": 2086,
        "mushroom_secondary": 61069,
        "public_coverage_CA2018": 34638,
        "public_coverage_TX2018": 24732,
        "ringnorm": 7400,
        "solar_flare": 144,
        "splice": 2991,
        "thyroid": 215,
        "titanic": 887,
        "twonorm": 7400,
        "waveform": 5000,
    }
    try:
        return known_sizes[dataset]
    except KeyError:
        return 1000


NUM_HYPERPARAMS = 10  # Grid size: evenly spaced values per (method, dataset)

# HYPERPARAM_VALUES[method][dataset] is a list of NUM_HYPERPARAMS floats
# spanning get_hyperparam_ranges(N)[method] (endpoints included), where N is
# the size of that dataset.
HYPERPARAM_VALUES = {}

for dataset in DATASETS:
    N = get_dataset_size(dataset)
    hyperparam_ranges = get_hyperparam_ranges(N)
    for method, (min_val, max_val) in hyperparam_ranges.items():
        HYPERPARAM_VALUES.setdefault(method, {})[dataset] = np.linspace(
            min_val, max_val, NUM_HYPERPARAMS
        ).tolist()


def map_job_id(
    job_id, methods=None, datasets=None, seeds=None, hyperparam_values=None
):
    """Map a SLURM array ``job_id`` to a (method, dataset, seed, hyperparam).

    Jobs are enumerated method-major. For each method, then each dataset,
    every seed is paired with every value of that (method, dataset)
    hyperparameter grid. The hyperparameter index varies fastest.

    Args:
        job_id: Zero-based index into the flattened experiment grid.
        methods, datasets, seeds, hyperparam_values: Optional overrides for
            the module-level METHODS, DATASETS, SEEDS and HYPERPARAM_VALUES
            (e.g. for testing). ``None`` means use the module-level value.

    Returns:
        Tuple ``(method, dataset, seed, hyperparam)``.

    Raises:
        KeyError: if a listed method or dataset has no hyperparameter grid.
        ValueError: if ``job_id`` is negative or >= the number of experiments.
    """
    methods = METHODS if methods is None else methods
    datasets = DATASETS if datasets is None else datasets
    seeds = SEEDS if seeds is None else seeds
    if hyperparam_values is None:
        hyperparam_values = HYPERPARAM_VALUES

    # Build the ordered list of grids once, so counting and lookup walk the
    # exact same sequence. A missing grid is a configuration error; fail
    # loudly instead of silently shifting every later job ID.
    grids = []
    for method in methods:
        if method not in hyperparam_values:
            raise KeyError(
                f"No hyperparameter grid defined for method {method!r}."
            )
        per_dataset = hyperparam_values[method]
        for dataset in datasets:
            if dataset not in per_dataset:
                raise KeyError(
                    f"No hyperparameter grid for method {method!r} "
                    f"on dataset {dataset!r}."
                )
            grids.append((method, dataset, per_dataset[dataset]))

    num_seeds = len(seeds)
    total_experiments = sum(len(values) * num_seeds for _, _, values in grids)

    if not 0 <= job_id < total_experiments:
        raise ValueError(
            f"Job ID {job_id} out of range (0-{total_experiments - 1})."
        )

    remaining = job_id
    for method, dataset, values in grids:
        block_size = len(values) * num_seeds
        if remaining < block_size:
            # Within a block: seed-major, hyperparameter index fastest.
            seed_idx, hyperparam_idx = divmod(remaining, len(values))
            return method, dataset, seeds[seed_idx], values[hyperparam_idx]
        remaining -= block_size

    # Unreachable: the range check guarantees some block contains job_id.
    raise RuntimeError(f"Job ID {job_id} could not be mapped.")


def str2bool(text):
    """Parse a case-insensitive yes/no style string into a bool.

    Accepts true/t/1/yes and false/f/0/no. Anything else raises
    ``argparse.ArgumentTypeError`` so argparse reports a usage error.
    """
    normalized = text.lower()
    truthy = ("true", "t", "1", "yes")
    falsy = ("false", "f", "0", "no")
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError(
        "Boolean value expected (True/False, 1/0, Yes/No)."
    )


def build_arg_parser():
    """Return the command-line parser for this SLURM array launcher."""
    parser = argparse.ArgumentParser(
        description="Run ML experiments with SLURM array."
    )
    parser.add_argument(
        "--expe_id",
        default=0,
        type=int,
        required=False,
        help="Experiment ID (used for SLURM array jobs)",
    )
    parser.add_argument(
        "--tree_type",
        default="CART",
        type=str,
        help="The type of tree to use",
        choices=["CART", "blossom"],
    )
    parser.add_argument(
        "--max_depth",
        type=int,
        default=1,
        help="Maximum depth of the decision trees",
    )
    parser.add_argument(
        "--crb",
        default="False",
        type=str2bool,
        help="To use confidence rated boosted trees",
    )
    parser.add_argument(
        "--run_benchmark",
        default="False",
        type=str2bool,
        help="To use benchmark",
    )
    parser.add_argument(
        "--run_single_shot",
        default="True",
        type=str2bool,
        help="To run a single-shot experiment",
    )
    return parser


def build_command(args, solver, dataset, seed, hyperparam):
    """Build the argv list used to invoke ``main.py`` for one experiment.

    Args:
        args: Parsed launcher arguments; reads ``tree_type``, ``max_depth``,
            ``crb``, ``run_benchmark`` and ``run_single_shot``.
        solver, dataset, seed, hyperparam: The experiment from map_job_id.

    Returns:
        List of strings, starting with the current interpreter. Benchmark
        runs take precedence over single-shot runs when both flags are set.
    """
    # Benchmark mode passes the method via --benchmark instead of --solver.
    method_flag = "--benchmark" if args.run_benchmark else "--solver"
    command = [
        sys.executable,
        "main.py",
        "--dataset",
        dataset,
        method_flag,
        solver,
        "--seed",
        str(seed),
        "--hyperparam",
        str(hyperparam),
        "--tree_type",
        str(args.tree_type),
        "--max_depth",
        str(args.max_depth),
        "--crb",
        str(args.crb),
    ]
    if args.run_benchmark:
        command += ["--run_benchmark", "True"]
    elif args.run_single_shot:
        command += ["--run_benchmark", "False", "--run_single_shot", "True"]
    return command


def main():
    """Resolve this array task's experiment and run ``main.py`` for it."""
    args = build_arg_parser().parse_args()

    # Retrieve experiment parameters from job ID
    solver, dataset, seed, hyperparam = map_job_id(args.expe_id)

    # flush=True: when stdout is redirected to a SLURM log file it is block
    # buffered, so without it this line could appear after main.py's output.
    print(
        f"Running experiment {args.expe_id}: solver={solver}, dataset={dataset}, seed={seed}, hyperparam={hyperparam}, tree type {args.tree_type}",
        flush=True,
    )

    command = build_command(args, solver, dataset, seed, hyperparam)

    # List argv with no shell, so nothing is re-interpreted by a shell.
    try:
        subprocess.run(command, check=True, env=os.environ)
    except subprocess.CalledProcessError as e:
        print(
            f"Error: main.py failed with exit code {e.returncode}",
            file=sys.stderr,
        )
        # Propagate the failure so SLURM records this array task as failed.
        # A negative return code means the child was killed by a signal.
        sys.exit(e.returncode if e.returncode > 0 else 1)


if __name__ == "__main__":
    main()
