"""
Main script to run the experiment with specified parameters.
This script handles argument parsing, dataset loading, model initialization,
and running the experiment.
"""

import sys
import argparse
import logging

from module.utils import Hparams
from module.experiment import Experiment
from module.metrics import Metrics
from module.datasets import DatasetLoader

# Extend the module search path with a site-specific directory (presumably the
# Gurobi Python bindings, judging by the path -- TODO confirm).
# NOTE(review): this runs after the `module.*` imports above, so it cannot
# affect any import those modules perform at load time -- confirm the ordering
# is intentional.
sys.path.append("/usr/pkg/gurobi/lib/python3.10_utf32/")


def main(args):
    """Run one experiment end to end from parsed command-line arguments.

    Builds the hyperparameter container, the metrics pipeline and the dataset,
    calls the encoder- and model-specific ``*_args_init`` hook on ``Hparams``
    for the selected configuration, then runs cross-validation through
    ``Experiment``.

    Args:
        args (argparse.Namespace): Parsed arguments. This function reads
            ``metrics_config``, ``dataset``, ``encoder``, ``model_parser`` and
            ``fold``; the whole namespace is also forwarded to ``Hparams``.

    Raises:
        ValueError: If ``args.model_parser`` is not a supported model name.
    """
    hparams = Hparams(args)
    metrics = Metrics(args.metrics_config, hparams.rng)

    ## Dataset
    # TODO: Hard coded, assumes the sensitive feature is at column 1.
    # Fairness mode is currently disabled; the intended switch was:
    #     fairness_mode = 1 if args.metrics_fairness else None
    fairness_mode = None
    dataset = DatasetLoader(args.dataset, hparams.rs, fairness_mode=fairness_mode)

    ## Encoder
    if args.encoder == "threshold_guess":
        hparams.thres_guess_args_init(args)

    # region Models
    # Some initialisers additionally receive the second dimension of
    # ``dataset.X`` (presumably the number of features -- verify in Hparams).
    model_name = args.model_parser
    if model_name == "RSET":
        hparams.rset_params_args_init(args)
    elif model_name == "CART":
        hparams.cart_tree_params_args_init(args)
    elif model_name == "GROOT":
        hparams.groot_params_args_init(args, dataset.X.shape[1])
    elif model_name == "ROCT-N":
        hparams.roctn_params_args_init(args)
    elif model_name == "ROCT-V":
        hparams.roctv_params_args_init(args, dataset.X.shape[1])
    elif model_name == "DPLDT":
        hparams.dpldt_params_args_init(args)
    elif model_name == "FPRDT":
        hparams.fprdt_params_args_init(args)
    elif model_name == "PRIVA":
        hparams.priva_params_args_init(args, dataset.X.shape[1])
    elif model_name == "BDPT":
        hparams.bdpt_params_args_init(args, dataset.X.shape[1])
    else:
        # Raise explicitly instead of ``assert False``: asserts are removed
        # under ``python -O``, which would let an unconfigured model through.
        raise ValueError(f"Invalid model type: {model_name}")
    # endregion

    experiment = Experiment(dataset, hparams, metrics)
    experiment.cross_validate(fold=args.fold)


def confirm_arguments(args):
    """Print the parsed arguments and ask the user to confirm them.

    Arguments are shown in two sections: global ones, and (when a model
    sub-command was selected) those whose name starts with the model prefix,
    e.g. ``roctn_`` for ``ROCT-N``. The process exits via ``sys.exit()`` unless
    the user answers ``yes`` or ``y`` (case-insensitive).

    Args:
        args (argparse.Namespace): Parsed arguments. ``model_parser`` may be
            missing or empty, in which case every argument is listed as global
            and the model-specific section is skipped.
    """
    print("\nArguments:")

    # Resolve the sub-command once; a namespace without one must not crash here.
    model_name = getattr(args, "model_parser", None)
    # Model options are prefixed with the lower-cased name minus dashes,
    # e.g. "ROCT-N" -> "roctn_".
    model_prefix = (model_name.replace("-", "").lower() + "_") if model_name else None

    settings = vars(args)
    bookkeeping_keys = ("model_parser", "command", "main_command")

    def belongs_to_model(key):
        return model_prefix is not None and key.startswith(model_prefix)

    # Global arguments (excluding model-specific ones)
    print("Global Arguments:")
    for key, value in settings.items():
        if key in bookkeeping_keys or belongs_to_model(key):
            continue
        print(f"  {key}: {value}")

    # Model-specific arguments, only when a sub-command was actually used
    if model_name:
        print(f"\nModel-Specific Arguments ({model_name}):")
        for key, value in settings.items():
            if belongs_to_model(key):
                print(f"  {key}: {value}")

    confirm = input("\nDo you want to proceed with these settings? (yes/no): ")
    if confirm.lower() not in ["yes", "y"]:
        print("Exiting...")
        sys.exit()


def build_parser():
    """Build the command-line parser for the experiment script.

    Global options (I/O, experiment setup, encoder, metrics, logging) live on
    the top-level parser; each supported model is a required sub-command whose
    options are prefixed with the lower-cased model name (e.g. ``--cart_*``).
    Global options must be given before the model sub-command.

    Returns:
        argparse.ArgumentParser: The fully configured parser.
    """
    parser = argparse.ArgumentParser(description="Train model with specified parameters.")
    parser.add_argument("--skip_confirm", action="store_true", help="Skip confirmation for args")

    #region Input/Output Section
    io_group = parser.add_argument_group("Input/Output", "Arguments related to file paths")
    io_group.add_argument("--output_dir", type=str, default="out/test",
            help="Directory for saved models and results")
    io_group.add_argument("--model_dir", type=str, default="model",
            help="Directory for saved models")
    io_group.add_argument("--param_dir", type=str, default="param", help="Directory for parameters")
    io_group.add_argument("--result_dir", type=str,
            default="result", help="Directory for saved results")
    #endregion

    #region Experiment Setup Section
    exp_group = parser.add_argument_group("Experiment", "Arguments related to experiment setup")
    exp_group.add_argument("--dataset", type=str, default="fico", help="Dataset name")
    exp_group.add_argument("--random_state", type=int, default=42, help="Random state")
    exp_group.add_argument("--tune", action="store_true", help="Tune the model hyperparameters")
    exp_group.add_argument("--selection", action="store_true", help="Use selection set")
    exp_group.add_argument("--k_folds", type=int, default=5, help="Number of folds for cross-val")
    exp_group.add_argument("--fold", type=int,
            help="Fold index for parallelization. "
                 "If this is not provided, then all folds will run.")
    exp_group.add_argument("--retrain", action="store_true",
            help="Retrain the model even if it exists in the model directory")
    #endregion

    #region Model Parameters Section
    model_parser = parser.add_subparsers(dest="model_parser", title="Model Parameters",
            description="Arguments related to the model", required=True)

    #region RSET (TreeFARMS)
    rset_group = model_parser.add_parser("RSET", help="TreeFARMS Parameters")
    rset_group.add_argument("--rset_lamb", type=float, default=0.01,
            help="Regularization parameter (TreeFARMS)")
    rset_group.add_argument("--rset_depth_budget", type=int, default=5,
            help="Depth budget of the tree (Max depth + 1)")
    rset_group.add_argument("--rset_eps", type=float, default=0.05,
            help="Epsilon parameter (Rashomon Set Adder Bound)")
    #endregion

    #region CART (Sklearn Decision Tree)
    cart_group = model_parser.add_parser("CART", help="CART Parameters")
    cart_group.add_argument("--cart_max_depth", type=int, default=4, help="Max depth")
    cart_group.add_argument("--cart_min_samples_split", type=int, default=10,
            help="Minimum samples on each split")
    cart_group.add_argument("--cart_min_samples_leaf", type=int, default=5,
            help="Minimum samples on each leaf")
    #endregion

    #region ROBUST TREES
    #region GROOT
    groot_group = model_parser.add_parser("GROOT", help="GROOT Parameters")
    groot_group.add_argument("--groot_epsilon", type=float, default=0.2,
            help="Expected perturbation of each features. (Attack capabilities)")
    groot_group.add_argument("--groot_max_depth", type=int, default=4, help="Max depth of the tree")
    groot_group.add_argument("--groot_min_samples_split", type=int, default=10,
            help="Minimum samples on each split")
    groot_group.add_argument("--groot_min_samples_leaf", type=int, default=5,
            help="Minimum samples on each leaf")
    #endregion

    #region ROCT-N (MILP for Stability)
    roctn_group = model_parser.add_parser("ROCT-N", help="RobustOCT-Nathan Parameters")
    roctn_group.add_argument("--roctn_lambda", type=float, default=0.9,
            help="Lambda value of robust OCT (1 is non-robust)")
    # Depth is a whole number, like every other depth option in this parser.
    roctn_group.add_argument("--roctn_depth", type=int, default=2,
            help="Depth of the robust tree (not max depth).")
    roctn_group.add_argument("--roctn_time_limit", type=int, default=1800,
            help="Time limit for the robust OCT model")
    #endregion

    #region ROCT-V
    roctv_group = model_parser.add_parser("ROCT-V", help="RobustOCT-Vos Parameters")
    roctv_group.add_argument("--roctv_max_depth", type=int, default=2, help="Max depth")
    roctv_group.add_argument("--roctv_epsilon", type=float, default=0.1,
            help="Epsilon value for the robust tree")
    roctv_group.add_argument("--roctv_time_limit", type=int, default=1800,
            help="Time limit for the robust OCT model")
    #endregion

    #region FPRDT (Fast Provably Robust Decision Tree)
    fprdt_group = model_parser.add_parser("FPRDT", help="FPRDT Parameters")
    fprdt_group.add_argument("--fprdt_max_depth", type=int, default=4, help="Max depth")
    fprdt_group.add_argument("--fprdt_min_samples_split", type=int, default=10,
            help="Minimum samples on each split")
    fprdt_group.add_argument("--fprdt_min_samples_leaf", type=int, default=5,
            help="Minimum samples on each leaf")
    fprdt_group.add_argument("--fprdt_epsilon", type=float, default=0.1, help="Epsilon value")
    #endregion

    #endregion

    #region PRIVATE TREES

    #region DPLDT (DiffPrivLib Decision Tree)
    dpldt_group = model_parser.add_parser("DPLDT", help="DPLDT Parameters")
    dpldt_group.add_argument("--dpldt_max_depth", type=int, default=4, help="Max depth")
    dpldt_group.add_argument("--dpldt_epsilon", type=float, default=0.1,
            help="Epsilon value for the DP tree")
    #endregion

    #region PRIVA (PrivaTree)
    priva_group = model_parser.add_parser("PRIVA", help="PrivaTree Parameters")
    priva_group.add_argument("--priva_max_depth", type=int, default=4, help="Max depth")
    priva_group.add_argument("--priva_epsilon", type=float, default=0.1,
            help="Epsilon value for the DP tree")
    priva_group.add_argument("--priva_min_samples_split", type=int, default=10,
            help="Minimum samples on each split")
    priva_group.add_argument("--priva_min_samples_leaf", type=int, default=5,
            help="Minimum samples on each leaf")
    #endregion

    #region BDPT
    bdpt_group = model_parser.add_parser("BDPT", help="BDPT Parameters")
    bdpt_group.add_argument("--bdpt_max_depth", type=int, default=4, help="Max depth")
    bdpt_group.add_argument("--bdpt_min_samples_split", type=int, default=10,
            help="Minimum samples on each split")
    bdpt_group.add_argument("--bdpt_epsilon", type=float, default=0.1,
            help="Epsilon value for the DP tree")
    #endregion

    #endregion

    #endregion

    #region Encoder Parameters Section
    encoder_group = parser.add_argument_group("Threshold Guess Parameters",
            "Arguments related to the threshold guessing configuration")
    encoder_group.add_argument("--encoder", type=str, default=None,
            help="Encoder for binarizing dataset")
    encoder_group.add_argument("--enc_max_depth", type=int, default=2,
            help="Max depth for the threshold guessing encoder model")
    encoder_group.add_argument("--enc_n_estimators", type=int, default=30,
            help="Number of estimators for the threshold guessing encoder model")
    encoder_group.add_argument("--enc_lr", type=float, default=0.1,
            help="Learning rate for the threshold guessing encoder model")
    #endregion

    #region Metrics Pipeline Section
    metrics_group = parser.add_argument_group("Metrics Pipeline",
            "Arguments related to the metrics")
    metrics_group.add_argument("--reset_results", action="store_true",
            help="Reset the results directory before running the experiment")
    metrics_group.add_argument("--metrics_config", type=str,
            default="config/metrics.yaml", help="Metrics configuration file")
    #endregion

    #region Logging Parameters Section
    logging_group = parser.add_argument_group("Logging Parameters", "Logging Arguments")
    # Normalise case and reject unknown names at parse time, so a typo is
    # reported by argparse instead of failing later inside logging.basicConfig.
    logging_group.add_argument("--log_level", type=str.upper, default="INFO",
            choices=["CRITICAL", "FATAL", "ERROR", "WARNING", "WARN",
                     "INFO", "DEBUG", "NOTSET"],
            help="Logging level")
    logging_group.add_argument("--log_file", type=str, default="result.log", help="Log file path")
    #endregion

    return parser


if __name__ == "__main__":
    parse_args = build_parser().parse_args()

    # Confirm arguments before running
    if not parse_args.skip_confirm:
        confirm_arguments(parse_args)

    # Log to both the chosen file and the console; the level name has already
    # been validated by the parser, and logging accepts it as a string.
    logging.basicConfig(
        level=parse_args.log_level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[
            logging.FileHandler(parse_args.log_file),
            logging.StreamHandler(),
        ],
    )

    main(parse_args)
