#!/usr/bin/env python3
"""
Standalone experiment runner for cluster execution.

This script runs a single dataset x model combination based on a configuration file
and a run ID. It extracts the necessary functionality from the hyperparameter
tuning dashboard to run without Flask/RQ dependencies.

Usage:
    python cluster_scripts/run_cluster_experiment.py --config config.json --run-id 0 --output results/
"""

import json
import argparse
import sys
import traceback
import shutil
from pathlib import Path
from datetime import datetime
import numpy as np

# Add parent directory to path to import dashboard modules
sys.path.insert(0, str(Path(__file__).parent.parent))

from experiment_runner import default_models
from experiment_runner.execution import run_optuna_benchmark, SafeEstimatorWrapper, Task
from services.datasets import load_dataset
from sklearn.model_selection import KFold, train_test_split
from sklearn.base import clone
import joblib


def save_fitted_model(
    model_name, best_params, fixed_params, X_train, y_train, output_path
):
    """
    Refit a model with the given hyperparameters and persist it to disk.

    A fresh clone of the estimator registered under ``model_name`` in
    ``default_models`` is configured with ``fixed_params`` and then
    ``best_params`` (so tuned values win on conflict), fitted on the supplied
    training data and written using the format matching the model type:

    * ``MPFRegressor``  -> ``<output_path>.bin`` via ``model.save()`` (plus a
      ``<output_path>.sqlite`` visual DB configured before fitting)
    * ``XGBRegressor``  -> ``<output_path>.json`` via ``model.save_model()``
    * ``LGBMRegressor`` -> ``<output_path>.txt`` via ``model.booster_.save_model()``
    * anything else     -> ``<output_path>.pkl`` via ``joblib.dump()``

    Args:
        model_name: Name of the model (key of ``default_models``)
        best_params: Best hyperparameters found during optimization
        fixed_params: Fixed hyperparameters
        X_train: Training features
        y_train: Training targets
        output_path: Path where the model should be saved (without extension)

    Returns:
        Path to the saved model file(s), or None if the model is unknown or if
        configuring, fitting or saving it failed (the error is printed, never
        raised). For MPFRegressor, returns "model_path; visualdb_path" when the
        visual DB file exists. For other models, returns the model path as a
        string.
    """
    if model_name not in default_models:
        print(f"⚠️  Warning: Unknown model {model_name}, skipping model save")
        return None

    base_estimator, _ = default_models[model_name]

    # Everything from here on is best-effort: a failure to refit or serialize
    # the model must not abort the caller, which still has results to report.
    try:
        model = clone(base_estimator)

        # Apply fixed params first, then best params
        if fixed_params:
            model = model.set_params(**fixed_params)
        if best_params:
            model = model.set_params(**best_params)

        # For MPFRegressor, the visual DB location has to be set before fitting
        visualdb_path = None
        if model_name == "MPFRegressor":
            visualdb_path = Path(f"{output_path}.sqlite")
            # Use an absolute path for the visualdb parameter
            model.set_params(visualdb=str(visualdb_path.absolute()))
            print(f"Set visualdb_path for {model_name}: {visualdb_path.absolute()}")

        print(f"Fitting {model_name} for saving...")
        model.fit(X_train, y_train)

        # Pick file extension and save method based on the model type
        if model_name == "MPFRegressor":
            model_path = Path(f"{output_path}.bin")
            model.save(str(model_path))
            print(f"✅ Model saved to: {model_path}")
            if visualdb_path and visualdb_path.exists():
                print(f"✅ VisualDB saved to: {visualdb_path}")
                # Both paths, in the "model; visualdb" format used by the dashboard
                return f"{model_path}; {visualdb_path}"
            return str(model_path)

        if model_name == "XGBRegressor":
            model_path = Path(f"{output_path}.json")
            model.save_model(str(model_path))
        elif model_name == "LGBMRegressor":
            model_path = Path(f"{output_path}.txt")
            model.booster_.save_model(str(model_path))
        else:  # RandomForestRegressor and others
            model_path = Path(f"{output_path}.pkl")
            joblib.dump(model, model_path)

        print(f"✅ Model saved to: {model_path}")
        return str(model_path)
    except Exception as e:
        print(f"❌ Failed to fit or save {model_name}: {e}")
        print(traceback.format_exc())
        return None


# Distribution identifiers that may appear as the first element of a spec.
_DISTRIBUTION_TYPES = frozenset({"randint", "uniform", "loguniform", "1-loguniform"})


def _normalize_param_spec(param_spec):
    """Convert a JSON distribution list to a tuple; return anything else unchanged."""
    if not isinstance(param_spec, list) or len(param_spec) < 2:
        return param_spec

    kind, *bounds = param_spec
    # The isinstance check must come first: categorical choices may be
    # unhashable (nested lists/dicts) and would make the set lookup raise.
    if not isinstance(kind, str) or kind not in _DISTRIBUTION_TYPES:
        return param_spec

    # A real distribution has numeric arguments. This keeps categorical lists
    # that merely start with a distribution name (e.g. ["uniform", "distance"])
    # from being misread as a distribution.
    if all(
        isinstance(bound, (int, float)) and not isinstance(bound, bool)
        for bound in bounds
    ):
        return tuple(param_spec)
    return param_spec


def prepare_model(model_name, custom_hyperparams=None):
    """
    Prepare a model for execution.

    Looks the model up in ``default_models`` and selects its search space:
    the entry for ``model_name`` in ``custom_hyperparams`` if there is one,
    otherwise the registered default.

    Custom spaces typically come from JSON, where tuples do not exist, so
    lists such as ``["uniform", 0.1, 0.9]`` (a known distribution name followed
    by numeric arguments) are converted to tuples. All other values
    (categorical lists, single-value lists, existing tuples, scalars) are
    passed through unchanged.

    Args:
        model_name: Name of the model to prepare
        custom_hyperparams: Optional dict mapping model names to custom
            hyperparameter ranges

    Returns:
        Tuple of (estimator, fixed_params, param_distributions). ``fixed_params``
        is always an empty dict: for cluster execution all params are tunable.

    Raises:
        ValueError: If ``model_name`` is not a key of ``default_models``.
    """
    if model_name not in default_models:
        raise ValueError(
            f"Unknown model: {model_name}. Available: {list(default_models.keys())}"
        )

    estimator, default_param_distributions = default_models[model_name]

    if custom_hyperparams and model_name in custom_hyperparams:
        print(f"Using custom hyperparameters for {model_name}")
        param_distributions = {
            param_name: _normalize_param_spec(param_spec)
            for param_name, param_spec in custom_hyperparams[model_name].items()
        }
    else:
        param_distributions = default_param_distributions
        print(f"Using default hyperparameters for {model_name}")

    # No fixed params for cluster execution (all params are tunable)
    fixed_params = {}

    return estimator, fixed_params, param_distributions


def _format_rmse(value):
    """Format an RMSE for log output; return "N/A" for missing/non-numeric values."""
    try:
        return f"{float(value):.4f}"
    except (TypeError, ValueError):
        return "N/A"


def run_single_experiment(run_config, output_dir, custom_hyperparams=None):
    """
    Run a single dataset x model experiment.

    The dataset is loaded, the model and its search space are prepared, and
    hyperparameter optimization is run either once on a single train/test
    split ("simple", the default) or once per outer fold (``cv.strategy ==
    "nested"``). Afterwards the per-split results are aggregated and the
    configuration with the lowest test RMSE is refitted on its training split
    and saved under ``<output_dir>/<dataset>/<model>/``.

    Args:
        run_config: Configuration dictionary for this run. ``"dataset"`` and
            ``"model"`` are required; ``"optimization"``, ``"cv"`` and
            ``"resources"`` are optional sections.
        output_dir: Directory to save results
        custom_hyperparams: Optional dict of custom hyperparameter ranges per model

    Returns:
        Dictionary with results. ``"success"`` tells whether at least one
        split produced a ``test_rmse``. On success the dict holds the RMSE
        summary statistics (plain floats), the raw per-split ``"results"`` and,
        if the best model could be saved, ``"fitted_model_path"`` (relative to
        ``output_dir`` where possible). On failure it holds ``"error"``.

    Raises:
        KeyError: If ``run_config`` lacks ``"dataset"`` or ``"model"``.
    """
    # Record start time
    start_timestamp = datetime.utcnow().isoformat()

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    dataset_config = run_config["dataset"]
    model_name = run_config["model"]
    optimization_config = run_config.get("optimization", {})
    cv_config = run_config.get("cv", {})
    resources_config = run_config.get("resources", {})

    def failure_summary(dataset_label, error, **extra):
        # One shape for every failure so result consumers always find the
        # same identifying keys and timestamps.
        return {
            "success": False,
            "dataset": dataset_label,
            "model": model_name,
            "dataset_config": dataset_config,
            "error": error,
            **extra,
            "start_timestamp": start_timestamp,
            "end_timestamp": datetime.utcnow().isoformat(),
        }

    print(f"\n{'=' * 80}")
    print(f"Running experiment: {dataset_config.get('name', 'unknown')} x {model_name}")
    print(f"{'=' * 80}\n")

    # Load dataset
    print("Loading dataset...")
    try:
        dataset_info = load_dataset(dataset_config)
        X = dataset_info["X"]
        y = dataset_info["y"]
        dataset_name = dataset_info["name"]
        print(f"✅ Dataset loaded: {dataset_name}, shape: X{X.shape}, y{y.shape}")
    except Exception as e:
        error_msg = f"Failed to load dataset: {str(e)}\n{traceback.format_exc()}"
        print(f"❌ {error_msg}")
        return failure_summary(dataset_config.get("name", "unknown"), error_msg)

    # Create subdirectory for this dataset/model combination
    # (names are sanitized for use as filesystem path components)
    safe_dataset_name = dataset_name.replace(" ", "_").replace("/", "_")
    safe_model_name = model_name.replace(" ", "_").replace("/", "_")
    run_output_dir = output_dir / safe_dataset_name / safe_model_name
    run_output_dir.mkdir(parents=True, exist_ok=True)

    # Prepare model
    print(f"Preparing model: {model_name}...")
    try:
        estimator, fixed_params, param_distributions = prepare_model(
            model_name, custom_hyperparams
        )
        estimator._name = model_name
        print(f"✅ Model prepared: {model_name}")
        print(f"   Parameters to tune: {list(param_distributions.keys())}")
    except Exception as e:
        error_msg = f"Failed to prepare model: {str(e)}\n{traceback.format_exc()}"
        print(f"❌ {error_msg}")
        return failure_summary(dataset_name, error_msg)

    cv_strategy = cv_config.get("strategy", "simple")
    random_seed = optimization_config.get("random_seed", 42)
    n_trials = optimization_config.get("n_trials", 100)
    optimization_method = optimization_config.get("method", "optuna")
    n_jobs = resources_config.get("n_jobs", 1)

    # Task arguments that are identical for every split
    shared_task_kwargs = {
        "param_distributions": param_distributions,
        "fixed_params": fixed_params,
        "n_iter": n_trials,
        "n_jobs": n_jobs,
        "experiment_id": None,  # Not needed for standalone execution
        "optimization_method": optimization_method,
    }

    all_results = []
    # Training data is kept outside the result dicts (can't be JSON serialized)
    training_data_by_fold = {}

    if cv_strategy == "nested":
        # Nested CV: run optimization for each outer fold
        outer_cv_folds = optimization_config.get("outer_cv_folds", 5)
        inner_cv_folds = optimization_config.get("inner_cv_folds", 3)

        kf = KFold(n_splits=outer_cv_folds, shuffle=True, random_state=random_seed)

        for fold_idx, (train_idx, test_idx) in enumerate(kf.split(X)):
            print(f"\n--- Fold {fold_idx + 1}/{outer_cv_folds} ---")
            X_train, X_test = X[train_idx], X[test_idx]
            y_train, y_test = y[train_idx], y[test_idx]

            # Each fold gets its own Optuna storage file and study
            storage_path = (
                run_output_dir
                / f"{safe_dataset_name}_{safe_model_name}_fold{fold_idx}.log"
            )
            task = Task(
                estimator=SafeEstimatorWrapper(estimator),
                X_train=X_train,
                y_train=y_train,
                X_test=X_test,
                y_test=y_test,
                cv=inner_cv_folds,
                # Different seed per fold
                random_state=random_seed + fold_idx,
                study_name=f"{model_name}_fold{fold_idx}",
                storage_name=str(storage_path),
                **shared_task_kwargs,
            )

            try:
                result = run_optuna_benchmark(task)
                if result:
                    result["fold"] = fold_idx
                    result["dataset"] = dataset_name
                    result["model"] = model_name
                    training_data_by_fold[fold_idx] = {
                        "X_train": X_train,
                        "y_train": y_train,
                    }
                    all_results.append(result)
                    print(
                        f"✅ Fold {fold_idx + 1} completed: "
                        f"RMSE={_format_rmse(result.get('test_rmse'))}"
                    )
                else:
                    print(f"⚠️  Fold {fold_idx + 1} returned no results")
            except Exception as e:
                error_msg = (
                    f"Failed on fold {fold_idx + 1}: {str(e)}\n{traceback.format_exc()}"
                )
                print(f"❌ {error_msg}")
                all_results.append(
                    {
                        "fold": fold_idx,
                        "success": False,
                        "error": error_msg,
                    }
                )

    else:  # simple CV
        # Simple CV: single train/test split, then CV for optimization
        train_split = cv_config.get("train_split", 0.8)
        simple_cv_folds = cv_config.get("simple_cv_folds", 3)
        test_size = 1.0 - train_split

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_seed
        )

        storage_path = run_output_dir / f"{safe_dataset_name}_{safe_model_name}.log"
        task = Task(
            estimator=SafeEstimatorWrapper(estimator),
            X_train=X_train,
            y_train=y_train,
            X_test=X_test,
            y_test=y_test,
            cv=simple_cv_folds,
            random_state=random_seed,
            study_name=f"{model_name}_simple",
            storage_name=str(storage_path),
            **shared_task_kwargs,
        )

        try:
            result = run_optuna_benchmark(task)
            if result:
                result["dataset"] = dataset_name
                result["model"] = model_name
                training_data_by_fold["simple"] = {
                    "X_train": X_train,
                    "y_train": y_train,
                }
                all_results.append(result)
                print(
                    f"✅ Experiment completed: "
                    f"RMSE={_format_rmse(result.get('test_rmse'))}"
                )
            else:
                print("⚠️  Experiment returned no results")
        except Exception as e:
            error_msg = f"Failed to run experiment: {str(e)}\n{traceback.format_exc()}"
            print(f"❌ {error_msg}")
            all_results.append(
                {
                    "success": False,
                    "error": error_msg,
                }
            )

    # Aggregate results
    if not all_results:
        return failure_summary(dataset_name, "No results generated")

    # A split counts as successful only if it reported a test RMSE
    successful_results = [r for r in all_results if r.get("test_rmse") is not None]
    if not successful_results:
        return failure_summary(dataset_name, "All folds failed", results=all_results)

    test_rmses = [r["test_rmse"] for r in successful_results]
    summary = {
        "success": True,
        "dataset": dataset_name,
        "model": model_name,
        "dataset_config": dataset_config,
        "n_folds": len(successful_results),
        # Plain floats so the statistics stay numeric when serialized to JSON
        "mean_test_rmse": float(np.mean(test_rmses)),
        "std_test_rmse": float(np.std(test_rmses)),
        "min_test_rmse": float(np.min(test_rmses)),
        "max_test_rmse": float(np.max(test_rmses)),
        "results": all_results,
        "start_timestamp": start_timestamp,
        "end_timestamp": datetime.utcnow().isoformat(),
    }

    # Save the best fitted model
    # For nested CV: use the fold with best test_rmse
    # For simple CV: use the single result
    best_result = min(successful_results, key=lambda r: r["test_rmse"])
    best_key = best_result.get("fold") if cv_strategy == "nested" else "simple"
    training_data = training_data_by_fold.get(best_key)
    if training_data is None:
        return summary

    model_output_path = run_output_dir / f"{safe_dataset_name}_{safe_model_name}_best"
    try:
        model_path = save_fitted_model(
            model_name,
            best_result.get("best_params", {}),
            best_result.get("fixed_params", {}),
            training_data["X_train"],
            training_data["y_train"],
            str(model_output_path),
        )
    except Exception as e:
        # Saving the model is a bonus; never lose the optimization results
        # because the final refit failed.
        print(f"⚠️  Warning: Failed to save best model for {model_name}: {e}")
        print(traceback.format_exc())
        model_path = None

    if model_path:
        # Store path(s) relative to output_dir for anonymity (no host paths in JSON).
        # MPFRegressor returns two paths joined by ";".
        rel_parts = []
        for part in (p.strip() for p in str(model_path).split(";")):
            try:
                rel_parts.append(str(Path(part).relative_to(output_dir)))
            except (ValueError, TypeError):
                # Not located under output_dir: keep the path as returned
                rel_parts.append(part)
        summary["fitted_model_path"] = "; ".join(rel_parts)

    return summary


def _safe_path_component(name):
    """Return *name* as a string usable as a single directory/file name component."""
    text = "unknown" if name is None else str(name)
    return text.replace(" ", "_").replace("/", "_")


def main():
    """
    Command-line entry point: run one experiment of a cluster configuration.

    Parses ``--config``, ``--run-id`` and ``--output``, selects the run with
    the given 0-based index from the configuration's ``"runs"`` list, executes
    it and writes the returned summary to
    ``<output>/<dataset>/<model>/run_<id>_<dataset>_<model>.json``.

    Exits with status 1 if the configuration is missing, unreadable, not a
    JSON object, contains no runs, or the run ID is out of range. If running
    the experiment or writing its results raises, a ``run_<id>_ERROR.json``
    file is written and the process also exits with status 1. An experiment
    that completes but reports ``"success": False`` is recorded in its result
    file and does not change the exit status.
    """
    parser = argparse.ArgumentParser(
        description="Run a single experiment from cluster configuration"
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to cluster configuration JSON file",
    )
    parser.add_argument(
        "--run-id",
        type=int,
        required=True,
        help="Run ID (0-indexed) to execute",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="cluster_results",
        help="Output directory for results",
    )

    args = parser.parse_args()

    # Load configuration
    config_path = Path(args.config)
    if not config_path.exists():
        print(f"❌ Configuration file not found: {config_path}")
        sys.exit(1)

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError
        print(f"❌ Failed to read configuration file {config_path}: {e}")
        sys.exit(1)

    if not isinstance(config, dict):
        print(f"❌ Configuration must be a JSON object: {config_path}")
        sys.exit(1)

    runs = config.get("runs", [])
    total_runs = len(runs)

    if total_runs == 0:
        print(f"❌ Configuration contains no runs: {config_path}")
        sys.exit(1)

    if args.run_id < 0 or args.run_id >= total_runs:
        print(
            f"❌ Invalid run ID: {args.run_id}. Must be between 0 and {total_runs - 1}"
        )
        sys.exit(1)

    # Get the specific run configuration
    run_config = runs[args.run_id]

    # Get custom hyperparameters if provided
    custom_hyperparams = config.get("hyperparameters")

    print("\n" + "=" * 80)
    print("Cluster Experiment Runner")
    print(f"{'=' * 80}")
    print(f"Configuration: {config_path}")
    print(f"Run ID: {args.run_id + 1}/{total_runs}")
    print(f"Output directory: {args.output}")
    if custom_hyperparams:
        print(f"Using custom hyperparameters for: {list(custom_hyperparams.keys())}")
    print(f"{'=' * 80}\n")

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Copy config file to output directory for record-keeping (only if it doesn't exist)
    # Note: This is typically done once in submit_cluster_jobs.sh, but we check here
    # as a safety measure in case the script is run directly
    config_copy_path = output_dir / config_path.name
    if config_copy_path.exists():
        print(f"📋 Configuration already exists at: {config_copy_path}")
    else:
        try:
            shutil.copy2(config_path, config_copy_path)
            print(f"📋 Configuration copied to: {config_copy_path}")
        except Exception as e:
            # Best effort only: a missing copy must not prevent the run
            print(f"⚠️  Warning: Failed to copy config file: {e}")

    # Run the experiment
    try:
        results = run_single_experiment(run_config, args.output, custom_hyperparams)

        # Results go into the subdirectory of this dataset/model combination
        safe_dataset_name = _safe_path_component(results.get("dataset", "unknown"))
        safe_model_name = _safe_path_component(results.get("model", "unknown"))
        run_output_dir = output_dir / safe_dataset_name / safe_model_name
        run_output_dir.mkdir(parents=True, exist_ok=True)

        result_filename = (
            f"run_{args.run_id:04d}_{safe_dataset_name}_{safe_model_name}.json"
        )
        result_path = run_output_dir / result_filename
        with open(result_path, "w") as f:
            # default=str stringifies values JSON cannot represent natively
            json.dump(results, f, indent=2, default=str)

        print(f"\n{'=' * 80}")
        print(f"✅ Results saved to: {result_path}")
        if results.get("success"):
            if "mean_test_rmse" in results:
                print(
                    f"   Mean Test RMSE: {results['mean_test_rmse']:.4f} ± {results['std_test_rmse']:.4f}"
                )
            else:
                test_rmse = results.get("test_rmse", "N/A")
                print(f"   Test RMSE: {test_rmse}")
        else:
            error_msg = results.get("error", "Unknown error")
            print(f"   ⚠️  Experiment failed: {error_msg}")
        print(f"{'=' * 80}\n")

    except Exception as e:
        error_msg = f"Fatal error: {str(e)}\n{traceback.format_exc()}"
        print(f"❌ {error_msg}")

        # Try to determine dataset/model from run_config for error file location
        try:
            dataset_name = run_config.get("dataset", {}).get("name", "unknown")
            model_name = run_config.get("model", "unknown")
            error_output_dir = (
                output_dir
                / _safe_path_component(dataset_name)
                / _safe_path_component(model_name)
            )
        except Exception:
            # Fallback to root if we can't determine dataset/model
            error_output_dir = output_dir

        error_output_dir.mkdir(parents=True, exist_ok=True)
        error_path = error_output_dir / f"run_{args.run_id:04d}_ERROR.json"
        with open(error_path, "w") as f:
            json.dump(
                {
                    "success": False,
                    "error": error_msg,
                    "run_id": args.run_id,
                    "timestamp": datetime.utcnow().isoformat(),
                },
                f,
                indent=2,
            )

        sys.exit(1)


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
