import argparse
from argparse import RawTextHelpFormatter

from rich.console import Console

from . import benchmark, compare
from .utils import load_config


def get_CLI_parser() -> argparse.ArgumentParser:
    """
    Get the argument parser for the CLI. There are two commands available:

    - benchmark: Benchmark model performance
    - compare: Compare model performance

    Benchmark command arguments (either ``data`` or ``--config`` is required, not both):

    - data: Path to the dataset to train and evaluate models on. This should be a CSV, pickle, or parquet file.
    - config: Path to a YAML config file. Must at least contain a 'data' key.
    - name: Name of the experiment to save the results to. Will be used to load cached results if they exist. Default: `data` file name without extension.
    - features: Name of the column containing the features. Default: Features.
    - target: Name of the column containing the target. Default: Target.
    - run_nested_CV: Whether to run nested CV with hyperparameter tuning for the best models. Default: False.
    - use_optuna: Whether to use Optuna for hyperparameter optimization. If not set, GridSearchCV from scikit-learn will be used. Default: False.
    - n_trials: Number of trials for Optuna hyperparameter search. Default: 100.
    - timeout: Time limit (in seconds) for hyperparameter search. Default: 3600.
    - fold_col: Name(s) of the column(s) containing the CV fold number(s). If a list is provided, models will be benchmarked in an nxk-fold CV, where n is the number of repeats and k is the number of folds. If a single string is provided, it will be treated as a single fold column. nxk-fold CV does not currently support nested CV and final hyperparameter tuning. Default: Fold.
    - main_metric: Main metric to use for model selection. This will be used to infer the prediction task (classification or regression). Default: R2.
    - sec_metrics: Secondary metrics to use for model selection. Default: MSE MAE.
    - parametric: Whether to use parametric statistical tests for model comparison. One of 'True', 'False', or 'auto'. Default: auto.
    - impute: Method to use for imputing missing values. If None, no imputation will be performed. Valid choices are 'mean', 'median', 'knn', or a float or int value for constant imputation.
    - remove_constant: If specified, features with variance below this threshold will be removed. If None, no features are removed.
    - remove_correlated: If specified, features with correlation above this threshold will be removed. If None, no features are removed.
    - scaler: Type of scaler to use, if the data is to be scaled first. Valid choices are 'Standard' and 'MinMax'. Default: None.
    - n_jobs: Number of jobs to run in parallel for hyperparameter tuning. Default: 1.

    Compare command arguments:

    - CV_results: Path to a single directory containing CV results, or a list of directories containing CV results.
    - main_metric: The main metric to use for comparison (required).
    - sec_metrics: Secondary metrics to use for comparison. Default: none (empty list).
    - parametric: Whether to use parametric statistical tests for model comparison. One of 'True', 'False', or 'auto'. Default: auto.

    Returns
    -------
    argparse.ArgumentParser
        Argument parser for the CLI.
    """
    parser = argparse.ArgumentParser(
        description="ASTRA - Automated model selection using statistical testing"
    )
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    benchmark_parser = subparsers.add_parser(
        "benchmark",
        help="Benchmark model performance",
        formatter_class=RawTextHelpFormatter,
    )
    # Exactly one of the positional dataset path or --config must be given.
    group = benchmark_parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "data",
        nargs="?",
        type=str,
        help="Path to the dataset to train and evaluate models on.",
    )
    group.add_argument(
        "--config",
        type=str,
        help="Path to a YAML config file. Must at least contain a 'data' key.",
    )
    benchmark_parser.add_argument(
        "--name",
        type=str,
        help="Name of the experiment. Results will be saved in a folder with this name\n"
        "in the 'results' directory. Will be used to load cached results if they exist. \n"
        "Default: `data` file name without extension.",
        default=None,
    )
    benchmark_parser.add_argument(
        "--features",
        type=str,
        default="Features",
        help="Name of the column containing the features. Default: Features.",
    )
    benchmark_parser.add_argument(
        "--target",
        type=str,
        default="Target",
        help="Name of the column containing the target. Default: Target.",
    )
    benchmark_parser.add_argument(
        "--run_nested_CV",
        action="store_true",
        default=False,
        help="Whether to run nested CV with hyperparameter tuning for the best\n"
        "models. Default: False.",
    )
    benchmark_parser.add_argument(
        "--use_optuna",
        action="store_true",
        default=False,
        help="Whether to use Optuna for hyperparameter optimization.\n"
        "If not set, GridSearchCV from scikit-learn will be used.\n"
        "Default: False.",
    )
    benchmark_parser.add_argument(
        "--n_trials",
        type=int,
        default=100,
        help="Number of trials for Optuna hyperparameter search. Default: 100.",
    )
    benchmark_parser.add_argument(
        "--timeout",
        type=int,
        default=3600,
        help="Time limit (in seconds) for Optuna hyperparameter search. Default: 3600.",
    )
    # The default is a plain string, while a user-supplied value is a list;
    # main() collapses a one-element list back to a string.
    benchmark_parser.add_argument(
        "--fold_col",
        type=str,
        nargs="+",
        default="Fold",
        help="Name(s) of the column(s) containing the CV fold number(s).\n"
        "If a list is provided, models will be benchmarked in an nxk-fold CV,\n"
        "where n is the number of repeats and k is the number of folds.\n"
        "If a single string is provided, it will be treated as a single fold column.\n"
        "Default: Fold.",
    )
    benchmark_parser.add_argument(
        "--main_metric",
        type=str,
        default="R2",
        help="Main metric to use for model selection. This will be used to infer the\n"
        "prediction task (classification or regression). Default: R2.",
    )
    benchmark_parser.add_argument(
        "--sec_metrics",
        type=str,
        nargs="+",
        default=["MSE", "MAE"],
        help="Secondary metrics to use for model selection. Default: MSE MAE.",
    )
    benchmark_parser.add_argument(
        "--parametric",
        type=str,
        choices=["True", "False", "auto"],
        default="auto",
        help="Whether to use parametric statistical tests for model comparison.\n"
        "If 'auto' (default), the assumptions of parametric tests will be checked,\n"
        "and parametric tests will be used if the assumptions are met.",
    )
    benchmark_parser.add_argument(
        "--impute",
        type=str,
        default=None,
        help="Method to use for imputing missing values. If None, no imputation will be performed.\n"
        "Valid choices are 'mean', 'median', 'knn', or a float or int value for constant imputation.",
    )
    benchmark_parser.add_argument(
        "--remove_constant",
        type=float,
        default=None,
        help="If specified, features with variance below this threshold will be removed.\n"
        "If None, no features are removed.",
    )
    benchmark_parser.add_argument(
        "--remove_correlated",
        type=float,
        default=None,
        help="If specified, features with correlation above this threshold will be removed.\n"
        "If None, no features are removed.",
    )
    benchmark_parser.add_argument(
        "--scaler",
        type=str,
        default=None,
        help="Type of scaler to use, if the data is to be scaled first.\n"
        "Valid choices are 'Standard' and 'MinMax'. Default: None.",
    )
    benchmark_parser.add_argument(
        "--n_jobs",
        type=int,
        default=1,
        help="Number of jobs to run in parallel for hyperparameter tuning. Default: 1.",
    )

    # RawTextHelpFormatter so the explicit line breaks in the help texts below
    # are preserved, matching the benchmark subcommand.
    compare_parser = subparsers.add_parser(
        "compare",
        help="Compare model performance",
        formatter_class=RawTextHelpFormatter,
    )
    compare_parser.add_argument(
        "CV_results",
        type=str,
        nargs="+",
        help="Path to a single directory containing CV results, or a list of directories\n"
        "containing CV results. CV results should be pickled dictionaries with metrics as\n"
        "keys and lists of scores as values, for example, `final_CV.pkl` returned by\n"
        "astra benchmark, and ending with `final_CV.pkl`. The model name will be the parent\n"
        "directory if passing a list of paths, or the file name (minus the `final_CV.pkl`)\n"
        "if passing a single directory.",
    )
    compare_parser.add_argument(
        "--main_metric",
        type=str,
        required=True,
        help="The main metric to use for comparison",
    )
    # Default to an empty list (not None) so callers can always iterate it.
    compare_parser.add_argument(
        "--sec_metrics",
        type=str,
        nargs="+",
        default=[],
        help="Secondary metrics to use for comparison. Default: none.",
    )
    compare_parser.add_argument(
        "--parametric",
        type=str,
        choices=["True", "False", "auto"],
        default="auto",
        help="Whether to use parametric statistical tests for model comparison.\n"
        "If 'auto' (default), the assumptions of parametric tests will be checked,\n"
        "and parametric tests will be used if the assumptions are met.",
    )

    return parser


def _print_banner(console: Console) -> None:
    """Print the ASTRA welcome banner and usage hints to ``console``."""
    console.print(
        "\n[bold blue]:wave: Welcome to ASTRA - Automated model selection using statistical testing[/bold blue]",
        justify="center",
    )
    # Raw string: the ASCII art contains backslashes (e.g. ``\_``, ``\|``) that
    # are invalid escape sequences in a normal literal and raise a SyntaxWarning
    # on Python 3.12+. The printed text is identical.
    astra_string = r"""
    ----------------------------------------------
    _
      __ _  ___ | |_  _ __   __ _
     / _` |/ __|| __|| '__| / _` |
    | (_| |\__ \| |_ | |   | (_| |
     \__,_||___/ \__||_|    \__,_|

    ----------------------------------------------
        """
    console.print(
        astra_string,
        style="bold blue",
        justify="center",
    )
    for hint in (
        "[bold cyan]:thinking_face: For help, run: astra --help[/bold cyan]",
        "[bold cyan]:test_tube: To benchmark models, run: astra benchmark --help[/bold cyan]",
        "[bold cyan]:trophy: To compare models, run: astra compare --help[/bold cyan]",
    ):
        console.print(hint, justify="center")


def _parse_parametric(value):
    """Map the CLI strings 'True'/'False' to booleans; return anything else (e.g. 'auto') unchanged."""
    if value == "True":
        return True
    if value == "False":
        return False
    return value


def _lowercase_metrics(metrics) -> list:
    """Return ``metrics`` as a list of lowercased names.

    ``None`` (option not given) becomes an empty list, and a bare string
    (e.g. a single metric written as a scalar in a YAML config) is wrapped in
    a list instead of being iterated character by character.
    """
    if metrics is None:
        return []
    if isinstance(metrics, str):
        return [metrics.lower()]
    return [metric.lower() for metric in metrics]


def _parse_custom_models(models) -> dict:
    """Convert the config's list of model entries into ``{name: {"params", "hparam_grid"}}``.

    Each entry must have a ``name`` key; ``params`` and ``hparam_grid`` default to None.
    Later entries with a duplicate name overwrite earlier ones.
    """
    return {
        model["name"]: {
            "params": model.get("params", None),
            "hparam_grid": model.get("hparam_grid", None),
        }
        for model in models
    }


def main() -> int:
    """
    Main function for the CLI. Parses the arguments and runs the appropriate command.

    For ``benchmark`` with ``--config``, every key in the config file overrides
    the corresponding CLI argument; a ``models`` list is converted into the
    ``custom_models`` mapping passed to ``benchmark.run``.

    Returns
    -------
    int
        Exit code. 0 after a command has run, 1 if no command was given
        (the help message is printed).

    Raises
    ------
    ValueError
        If a config file is used and it does not provide a 'data' field.
    """
    console = Console()
    _print_banner(console)

    parser = get_CLI_parser()
    args = parser.parse_args()

    if args.command == "benchmark":
        if args.config:
            # Treat an empty config file as "no overrides" rather than crashing.
            config = load_config(args.config) or {}

            # Only keys that correspond to CLI arguments (plus "models") are
            # forwarded to benchmark.run; warn about anything else (likely typos).
            known_keys = set(vars(args)) | {"models"}
            unknown_keys = [key for key in config if key not in known_keys]
            if unknown_keys:
                console.print(
                    "Warning: ignoring unknown config key(s): "
                    + ", ".join(map(str, unknown_keys)),
                    style="bold yellow",
                    markup=False,
                )

            # Override CLI arguments with config values
            for key, value in config.items():
                setattr(args, key, value)

            if args.data is None:
                raise ValueError("The config file must include a 'data' field.")

            # Custom model settings (a null "models" entry means no custom models)
            if config.get("models") is not None:
                args.models = _parse_custom_models(config["models"])

        # Normalize after applying the config so a one-element list coming from
        # either the CLI or the config file collapses to a single column name.
        if isinstance(args.fold_col, list) and len(args.fold_col) == 1:
            args.fold_col = args.fold_col[0]

        args.main_metric = args.main_metric.lower()
        args.sec_metrics = _lowercase_metrics(args.sec_metrics)
        args.parametric = _parse_parametric(args.parametric)

        benchmark.run(
            data=args.data,
            name=args.name,
            features=args.features,
            target=args.target,
            run_nested_CV=args.run_nested_CV,
            use_optuna=args.use_optuna,
            n_trials=args.n_trials,
            timeout=args.timeout,
            fold_col=args.fold_col,
            main_metric=args.main_metric,
            sec_metrics=args.sec_metrics,
            parametric=args.parametric,
            impute=args.impute,
            remove_constant=args.remove_constant,
            remove_correlated=args.remove_correlated,
            scaler=args.scaler,
            custom_models=getattr(args, "models", None),
            n_jobs=args.n_jobs,
        )

    elif args.command == "compare":
        args.main_metric = args.main_metric.lower()
        # --sec_metrics is optional for compare; None becomes [] instead of crashing.
        args.sec_metrics = _lowercase_metrics(args.sec_metrics)
        args.parametric = _parse_parametric(args.parametric)

        compare.run(
            CV_results=args.CV_results,
            main_metric=args.main_metric,
            sec_metrics=args.sec_metrics,
            parametric=args.parametric,
        )

    else:
        parser.print_help()
        return 1

    return 0


if __name__ == "__main__":
    # The bare ``exit`` builtin is injected by the ``site`` module for the
    # interactive shell and is not guaranteed to exist (e.g. under ``python -S``);
    # raising SystemExit directly is equivalent to sys.exit() and always available.
    raise SystemExit(main())
