#!/usr/bin/env python3
"""
Generate a JSON configuration file for cluster execution.

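This script writes a JSON configuration file listing every dataset × model
combination to run, ordered by dataset size (n×p, smallest first) and then by
model name. Each combination is executed as a separate SLURM job.

Datasets can be given explicitly with --datasets (inline JSON objects, JSON file
paths, or OpenML dataset IDs) and/or taken from an OpenML benchmark suite with
--suite (optionally restricted to a 1-based index range, e.g. "353[1-10]").

Usage:
    python cluster_scripts/generate_cluster_config.py --datasets dataset1 dataset2 --models model1 model2 --output config.json
    python cluster_scripts/generate_cluster_config.py --suite "353[1-10]" --models model1 model2 --output config.json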
"""

import argparse
import json
import sys
from itertools import product
from pathlib import Path


def _combo_sort_key(combo):
    """Return the ordering key of one run: (dataset n×p, lower-cased model name).

    Runs whose dataset has no usable numeric ``n_times_p`` sort to the end. This
    covers a missing key, an explicit ``null``, a non-numeric value, or a dataset
    entry that is not a dict.
    """
    dataset = combo["dataset"]
    n_times_p = dataset.get("n_times_p") if isinstance(dataset, dict) else None
    # Anything non-numeric (e.g. None from a JSON null) would make the tuple
    # comparison raise TypeError, so map it to +inf instead.
    if not isinstance(n_times_p, (int, float)):
        n_times_p = float("inf")
    return (n_times_p, combo["model"].lower())


def _to_json_param_spec(param_spec):
    """Convert one default hyperparameter spec into a JSON-serializable list."""
    if isinstance(param_spec, tuple):
        # ("randint", 1, 5) -> ["randint", 1, 5]
        return list(param_spec)
    if isinstance(param_spec, list):
        # Already a list (categorical choices or a fixed value).
        return param_spec
    # Single scalar value: wrap it so every spec is a list.
    return [param_spec]


def _default_hyperparameters(models):
    """Build {model: {param: spec}} from ``experiment_runner.models.default_models``.

    Models that are not present in ``default_models`` are skipped with a warning.
    """
    project_root = str(Path(__file__).parent.parent)
    # Make the repository root importable (highest precedence) without stacking
    # a duplicate sys.path entry every time this function runs.
    if not sys.path or sys.path[0] != project_root:
        sys.path.insert(0, project_root)
    from experiment_runner.models import default_models

    final_hyperparameters = {}
    for model_name in models:
        if model_name not in default_models:
            print(
                f"⚠️  Warning: Model {model_name} not found in default_models, skipping hyperparameters"
            )
            continue
        _estimator, param_distributions = default_models[model_name]
        final_hyperparameters[model_name] = {
            param_name: _to_json_param_spec(param_spec)
            for param_name, param_spec in param_distributions.items()
        }
        print(f"  Using default hyperparameters for {model_name}")
    return final_hyperparameters


def _filter_hyperparameters(hyperparameters, models):
    """Keep user-supplied ranges for the requested models; warn about mismatches."""
    final_hyperparameters = {
        model_name: hyperparameters[model_name]
        for model_name in models
        if model_name in hyperparameters
    }

    # Ranges that were supplied for models that will not be run.
    unused_models = set(hyperparameters) - set(models)
    if unused_models:
        print(
            f"⚠️  Warning: Hyperparameters provided for models not in the run list: {sorted(unused_models)}"
        )

    # Models that will be run without explicitly supplied ranges.
    missing_models = set(models) - set(hyperparameters)
    if missing_models:
        print(
            f"⚠️  Warning: No hyperparameters provided for models: {sorted(missing_models)}. They will use defaults at runtime."
        )
    return final_hyperparameters


def generate_config(
    datasets,
    models,
    optimization_config,
    cv_config,
    resources_config,
    hyperparameters,
    output_path,
):
    """
    Generate cluster configuration file.

    Every (dataset, model) pair becomes one entry in ``runs``. Runs are ordered
    by the dataset's ``n_times_p`` (ascending) and then by model name
    (case-insensitive). Datasets with no numeric ``n_times_p`` go last.

    Args:
        datasets: List of dataset configurations (dicts with type, id, etc.)
        models: List of model names to run
        optimization_config: Optimization settings (n_trials, method, etc.)
        cv_config: Cross-validation settings
        resources_config: Resource settings (n_jobs, random_seed)
        hyperparameters: Optional dict of custom hyperparameter ranges per model.
                         If None, defaults are taken from
                         experiment_runner.models.default_models.
        output_path: Path to save the JSON config file; parent directories are
                     created if needed.

    Returns:
        Path to the generated config file
    """
    # One run per (dataset, model) pair. The shared settings dicts are
    # referenced by every run rather than copied.
    combinations = [
        {
            "dataset": dataset,
            "model": model,
            "optimization": optimization_config,
            "cv": cv_config,
            "resources": resources_config,
        }
        for dataset, model in product(datasets, models)
    ]

    # Smallest datasets first, then by model name, for a consistent job order.
    combinations.sort(key=_combo_sort_key)

    if hyperparameters is None:
        final_hyperparameters = _default_hyperparameters(models)
    else:
        final_hyperparameters = _filter_hyperparameters(hyperparameters, models)

    config = {
        "metadata": {
            "total_runs": len(combinations),
            "n_datasets": len(datasets),
            "n_models": len(models),
            "description": "Cluster execution configuration for hyperparameter tuning",
        },
        "global_config": {
            "optimization": optimization_config,
            "cv": cv_config,
            "resources": resources_config,
        },
        "runs": combinations,
        "hyperparameters": final_hyperparameters,
    }

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    print(f"✅ Generated configuration file: {output_path}")
    print(f"   Total runs: {len(combinations)}")
    print(f"   Datasets: {len(datasets)}")
    print(f"   Models: {len(models)}")

    return output_path


def parse_suite_spec(suite_spec):
    """
    Parse suite specification with optional 1-based, inclusive indexing.

    Args:
        suite_spec: String like "269" (whole suite), "353[1-10]" (items 1 to 10),
            or "353[23-]" (item 23 to the end). Surrounding whitespace is ignored,
            and a plain int suite ID is accepted as well.

    Returns:
        Tuple of (suite_id, start_idx, end_idx). Indices are 1-based. Both are
        None when no range is given, and end_idx is None for an open-ended range.

    Raises:
        ValueError: If the specification is malformed, or if the range end is
            smaller than its start (which would otherwise silently select nothing).
    """
    import re

    spec = str(suite_spec).strip()

    # "353[1-10]" or "353[23-]" (end index optional)
    match = re.fullmatch(r"(\d+)\[(\d+)-(\d+)?\]", spec)
    if match:
        suite_id = int(match.group(1))
        start_idx = int(match.group(2))
        end_idx = int(match.group(3)) if match.group(3) else None
        if end_idx is not None and end_idx < start_idx:
            raise ValueError(
                f"Invalid suite specification: {suite_spec}. End index {end_idx} is smaller than start index {start_idx}"
            )
        return suite_id, start_idx, end_idx

    # Just a suite ID
    try:
        return int(spec), None, None
    except ValueError:
        # Suppress the int() error context; the message below says everything.
        raise ValueError(
            f"Invalid suite specification: {suite_spec}. Use format '269', '353[1-10]', or '353[23-]'"
        ) from None


def fetch_suite_tasks(suite_spec):
    """
    Fetch tasks from an OpenML benchmark suite, sorted by n*p (ascending).

    Some tasks have an unknown size because the instance or feature count is
    missing. These are placed after every task of known size. Their dataset
    config has no "n_times_p" key, so they also sort last in the generated run
    list.

    Args:
        suite_spec: Suite specification (e.g., "269", "353[1-10]", or "353[23-]").
            Indices are 1-based positions in the n*p-sorted task list.

    Returns:
        List of dataset configurations for the selected tasks. Each is a dict with
        "type", "task_id", "name" and, when the size is known, "n_times_p".
        Returns an empty list if the suite cannot be fetched or the requested
        range selects nothing.
    """
    try:
        import openml
    except ImportError:
        print("❌ Error: openml package is required to fetch benchmark suites")
        print("   Install with: pip install openml")
        sys.exit(1)

    try:
        suite_id, start_idx, end_idx = parse_suite_spec(suite_spec)

        print(f"Fetching tasks from OpenML suite {suite_id}...")
        suite = openml.study.get_suite(suite_id)

        # Fetch all tasks with their size metadata; failures skip only that task.
        task_metadata = []
        for task_id in suite.tasks:
            try:
                task = openml.tasks.get_task(task_id)
                dataset = task.get_dataset()

                # NOTE(review): ``qualities`` may be None when OpenML has no
                # qualities for a dataset. Treat the size as unknown instead of
                # letting an AttributeError drop the task.
                qualities = dataset.qualities or {}
                n = qualities.get("NumberOfInstances", 0)
                p = qualities.get("NumberOfFeatures", 0)
                n_times_p = n * p if (n and p) else None

                task_metadata.append(
                    {
                        "task_id": int(task_id),
                        "dataset_name": dataset.name,
                        "n": n,
                        "p": p,
                        "n_times_p": n_times_p,
                    }
                )

                size = "unknown" if n_times_p is None else n_times_p
                print(
                    f"  Fetched Task {task_id}: {dataset.name} (n={n}, p={p}, n*p={size})"
                )

            except Exception as e:
                print(f"  ⚠️  Failed to fetch task {task_id}: {e}")

        # Known sizes ascending; unknown sizes (None) after all of them.
        task_metadata.sort(
            key=lambda info: (info["n_times_p"] is None, info["n_times_p"] or 0)
        )

        print(f"\n✅ Fetched {len(task_metadata)} tasks from suite {suite_id}")
        print("   Sorted by n*p (ascending)")

        # Apply indexing if specified (1-based, inclusive).
        if start_idx is not None:
            total = len(task_metadata)
            if start_idx < 1:
                print(f"⚠️  Warning: start index {start_idx} < 1, using 1")
                start_idx = 1
            if start_idx > total:
                print(
                    f"⚠️  Warning: start index {start_idx} > {total}, no datasets selected"
                )
                return []

            if end_idx is not None:
                if end_idx > total:
                    print(
                        f"⚠️  Warning: end index {end_idx} > {total}, using {total}"
                    )
                    end_idx = total
                task_metadata = task_metadata[start_idx - 1 : end_idx]
                print(
                    f"   Selected datasets [{start_idx}-{end_idx}] (total: {len(task_metadata)})"
                )
            else:
                # No end index specified - slice to the end.
                task_metadata = task_metadata[start_idx - 1 :]
                print(
                    f"   Selected datasets [{start_idx}-end] (total: {len(task_metadata)})"
                )

        # Number tasks by their position in the full sorted suite. The printed
        # indices can then be reused directly in a "SUITE[start-end]" spec.
        first_position = 1 if start_idx is None else start_idx
        datasets = []
        for position, task_info in enumerate(task_metadata, first_position):
            dataset_config = {
                "type": "openml_task",
                "task_id": task_info["task_id"],
                "name": f"task_{task_info['task_id']}_{task_info['dataset_name']}",
            }
            if task_info["n_times_p"] is not None:
                # Stored for ordering runs; omitted when unknown so it sorts last.
                dataset_config["n_times_p"] = task_info["n_times_p"]
            datasets.append(dataset_config)
            size = "unknown" if task_info["n_times_p"] is None else task_info["n_times_p"]
            print(
                f"  {position}. Task {task_info['task_id']}: {task_info['dataset_name']} (n*p={size})"
            )

        return datasets

    except Exception as e:
        # Best effort: report the failure and let the caller continue with
        # whatever other datasets it has.
        print(f"❌ Failed to fetch suite {suite_spec}: {e}")
        import traceback

        traceback.print_exc()
        return []


def _parse_dataset_arg(dataset_arg):
    """
    Turn one --datasets argument into a dataset configuration.

    Accepted forms, tried in order:
      1. an inline JSON object, e.g. '{"type": "openml", "dataset_id": 31}';
      2. a path to a JSON file holding the configuration;
      3. an integer OpenML dataset ID. n×p is looked up on a best-effort basis.
         When it is unknown, "n_times_p" is omitted so the dataset sorts last.
         A float inf is not stored, because json.dump would write non-standard
         `Infinity`.

    Raises:
        ValueError: If the argument matches none of the accepted forms, or a
            dataset file does not contain valid JSON.
    """
    try:
        parsed = json.loads(dataset_arg)
    except json.JSONDecodeError:
        parsed = None
    # Only a JSON object is a dataset config; a bare number such as "31" falls
    # through to the file / OpenML-ID handling below.
    if isinstance(parsed, dict):
        return parsed

    dataset_path = Path(dataset_arg)
    if dataset_path.exists():
        with open(dataset_path, "r", encoding="utf-8") as f:
            try:
                return json.load(f)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON in dataset file {dataset_path}: {e}") from e

    try:
        dataset_id = int(dataset_arg)
    except ValueError:
        raise ValueError(
            f"Could not interpret dataset argument {dataset_arg!r}: it is not a JSON "
            "object, an existing file, or an integer OpenML dataset ID"
        ) from None

    dataset = {
        "type": "openml",
        "dataset_id": dataset_id,
        "name": f"openml_dataset_{dataset_arg}",
    }

    # Best effort: if openml is missing or the lookup fails, keep the dataset
    # without n×p (it will sort to the end).
    try:
        import openml

        dataset_obj = openml.datasets.get_dataset(dataset_id)
        qualities = dataset_obj.qualities or {}
        n = qualities.get("NumberOfInstances", 0)
        p = qualities.get("NumberOfFeatures", 0)
    except Exception:
        return dataset

    if n and p:
        dataset["n_times_p"] = n * p
    print(
        f"  Fetched dataset {dataset_arg}: n={n}, p={p}, n*p={dataset.get('n_times_p', 'unknown')}"
    )
    return dataset


def _load_hyperparameters(spec):
    """
    Load custom hyperparameter ranges from a JSON string or a JSON file path.

    Returns:
        A dict mapping model names to parameter ranges. Returns None (meaning
        "use defaults") when the spec can be neither parsed nor found, or when
        it is not a JSON object.
    """
    try:
        hyperparameters = json.loads(spec)
    except json.JSONDecodeError:
        hyperparams_path = Path(spec)
        if not hyperparams_path.exists():
            print(f"Warning: Could not parse hyperparameters from: {spec}")
            print("Proceeding with default hyperparameters")
            return None
        with open(hyperparams_path, "r", encoding="utf-8") as f:
            hyperparameters = json.load(f)

    # The ranges must be keyed by model name. Any other JSON value would make
    # generate_config fail when it calls .keys().
    if not isinstance(hyperparameters, dict):
        print(
            "Warning: Hyperparameters must be a JSON object mapping model names to "
            f"ranges, got {type(hyperparameters).__name__}"
        )
        print("Proceeding with default hyperparameters")
        return None
    return hyperparameters


def main():
    """Parse command-line arguments, collect datasets, and write the cluster config."""
    parser = argparse.ArgumentParser(
        description="Generate cluster configuration file for hyperparameter tuning"
    )
    parser.add_argument(
        "--datasets",
        nargs="+",
        help="Dataset configurations as JSON strings, file paths, or OpenML dataset IDs",
    )
    parser.add_argument(
        "--suite",
        type=str,
        help="OpenML benchmark suite ID with optional indexing (e.g., '269', '353[1-10]', or '353[23-]' for datasets 23 to end)",
    )
    parser.add_argument(
        "--models",
        nargs="+",
        required=True,
        help="Model names to run (e.g., MPFRegressor XGBRegressor)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="cluster_config.json",
        help="Output path for configuration file",
    )
    parser.add_argument(
        "--n-trials",
        type=int,
        default=200,
        help="Number of optimization trials per run",
    )
    parser.add_argument(
        "--optimization-method",
        type=str,
        default="optuna",
        choices=["optuna", "random", "grid"],
        help="Optimization method",
    )
    parser.add_argument(
        "--cv-strategy",
        type=str,
        default="simple",
        choices=["simple", "nested"],
        help="Cross-validation strategy",
    )
    parser.add_argument(
        "--cv-outer-folds",
        type=int,
        default=5,
        help="Number of outer CV folds (for nested CV)",
    )
    parser.add_argument(
        "--cv-inner-folds",
        type=int,
        default=10,
        help="Number of inner CV folds",
    )
    parser.add_argument(
        "--cv-train-split",
        type=float,
        default=0.8,
        help="Train split percentage (for simple CV)",
    )
    parser.add_argument(
        "--cv-simple-folds",
        type=int,
        default=10,
        help="Number of CV folds for simple strategy",
    )
    parser.add_argument(
        "--random-seed",
        type=int,
        default=42,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--n-jobs",
        type=int,
        default=-1,
        help="Number of parallel jobs for cross-validation",
    )
    parser.add_argument(
        "--hyperparams",
        type=str,
        default=None,
        help="Path to JSON file with custom hyperparameter ranges, or JSON string",
    )

    args = parser.parse_args()

    # Validate that at least one of datasets or suite is provided
    if not args.datasets and not args.suite:
        parser.error("At least one of --datasets or --suite must be provided")

    # Parse explicit dataset arguments; report bad ones as a usage error rather
    # than an uncaught traceback.
    datasets = []
    for dataset_arg in args.datasets or []:
        try:
            datasets.append(_parse_dataset_arg(dataset_arg))
        except ValueError as e:
            parser.error(str(e))

    # Fetch tasks from OpenML suite if provided
    if args.suite:
        datasets.extend(fetch_suite_tasks(args.suite))

    # Build configuration
    optimization_config = {
        "method": args.optimization_method,
        "n_trials": args.n_trials,
        "random_seed": args.random_seed,
    }
    cv_config = {
        "strategy": args.cv_strategy,
        "train_split": args.cv_train_split,
        "simple_cv_folds": args.cv_simple_folds,
    }
    if args.cv_strategy == "nested":
        # Fold counts are recorded in both sections, as before.
        optimization_config["outer_cv_folds"] = args.cv_outer_folds
        optimization_config["inner_cv_folds"] = args.cv_inner_folds
        cv_config["outer_folds"] = args.cv_outer_folds
        cv_config["inner_folds"] = args.cv_inner_folds

    resources_config = {
        "n_jobs": args.n_jobs,
        "random_seed": args.random_seed,
    }

    # Validate we have datasets
    if not datasets:
        print("❌ No datasets were successfully parsed or fetched")
        sys.exit(1)

    print(f"\nTotal datasets: {len(datasets)}")

    hyperparameters = (
        _load_hyperparameters(args.hyperparams) if args.hyperparams else None
    )

    generate_config(
        datasets=datasets,
        models=args.models,
        optimization_config=optimization_config,
        cv_config=cv_config,
        resources_config=resources_config,
        hyperparameters=hyperparameters,
        output_path=args.output,
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
