"""
Script to run Optuna hyperparameter tuning with a custom dataset.

This script uses the existing Optuna tuning infrastructure to run
hyperparameter optimization on a custom dataset with MPFRegressor.

Usage:
    python cluster_scripts/run_custom_dataset_experiment.py --output results/
"""

import json
import argparse
import sys
from pathlib import Path
from datetime import datetime
from typing import Optional
import numpy as np
from sklearn.model_selection import train_test_split

# Add parent directory to path to import dashboard modules
sys.path.insert(0, str(Path(__file__).parent.parent))

from experiment_runner import default_models
from experiment_runner.execution import run_optuna_benchmark, SafeEstimatorWrapper, Task


def make_dataset(
    n: int,
    seed: int,
    noise_std: float = 0.25,
    constant_offset: float = 0.0,
    std_x1: float = 1.0,
    std_x2: float = 2.0,
    std_x3: float = 3.0,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Generate a seeded synthetic regression dataset with three Gaussian features.

    The noise-free target is ``x1**2 * x2 + x1**2 * x2 * x3 + constant_offset``;
    independent Gaussian noise scaled by ``noise_std`` is added on top. The
    original author describes this construction as having a "PD cancellation"
    property.

    Parameters
    ----------
    n : int
        Number of samples
    seed : int
        Seed passed to ``np.random.default_rng``
    noise_std : float, default=0.25
        Standard deviation of the additive noise term
    constant_offset : float, default=0.0
        Constant added to every target value
    std_x1 : float, default=1.0
        Standard deviation of the zero-mean feature X1
    std_x2 : float, default=2.0
        Standard deviation of the zero-mean feature X2
    std_x3 : float, default=3.0
        Standard deviation of the zero-mean feature X3

    Returns
    -------
    X : np.ndarray of shape (n, 3)
        Feature matrix with columns (x1, x2, x3), dtype float64
    y : np.ndarray of shape (n,)
        Noisy target vector, dtype float64
    """
    generator = np.random.default_rng(seed)

    # The draw order (x1, x2, x3, then noise) determines the values produced
    # for a given seed, so the features are sampled strictly in that order.
    columns = [
        generator.normal(loc=0.0, scale=sigma, size=n).astype(np.float64)
        for sigma in (std_x1, std_x2, std_x3)
    ]
    first, second, third = columns
    features = np.column_stack(columns)

    signal = first**2 * second + first**2 * second * third + constant_offset
    noise = noise_std * generator.standard_normal(size=n)
    target = signal + noise
    return features, target.astype(np.float64)


def load_hyperparameters(hyperparams_path: Path) -> dict:
    """Load hyperparameters from a JSON file.

    Args:
        hyperparams_path: Path of the JSON file to read.

    Returns:
        The decoded JSON document, returned exactly as ``json.load`` produces it.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # locale encoding, which can mis-decode non-ASCII text on some systems.
    with open(hyperparams_path, "r", encoding="utf-8") as f:
        return json.load(f)


def prepare_model(model_name: str, custom_hyperparams: Optional[dict] = None):
    """
    Prepare a model for execution.

    Looks the model up in ``default_models`` and selects its search space:
    the entry for ``model_name`` in ``custom_hyperparams`` when one exists,
    otherwise the defaults registered alongside the estimator.

    Custom specs come from JSON, which has no tuple type. A list of at least
    two items whose first item is a known distribution name
    (``"randint"``, ``"uniform"``, ``"loguniform"``, ``"1-loguniform"``) is
    converted to a tuple, e.g. ``["uniform", 0.1, 0.9]`` becomes
    ``("uniform", 0.1, 0.9)``. Every other spec (categorical lists,
    single-value lists, scalars, ...) is passed through unchanged.

    Args:
        model_name: Name of the model to prepare
        custom_hyperparams: Optional dict mapping model names to custom
            hyperparameter ranges

    Returns:
        Tuple of (estimator, fixed_params, param_distributions).
        ``fixed_params`` is always an empty dict (all params are tunable).

    Raises:
        ValueError: If ``model_name`` is not a key of ``default_models``.
    """
    if model_name not in default_models:
        raise ValueError(
            f"Unknown model: {model_name}. Available: {list(default_models.keys())}"
        )

    estimator, default_param_distributions = default_models[model_name]

    if custom_hyperparams and model_name in custom_hyperparams:
        raw_specs = custom_hyperparams[model_name]
        print(f"Using custom hyperparameters for {model_name}")

        distribution_types = {"randint", "uniform", "loguniform", "1-loguniform"}

        def is_distribution_spec(spec) -> bool:
            # The isinstance(str) guard matters: a categorical list whose first
            # choice is itself a list or dict is unhashable, and testing it for
            # set membership directly would raise TypeError.
            return (
                isinstance(spec, list)
                and len(spec) >= 2
                and isinstance(spec[0], str)
                and spec[0] in distribution_types
            )

        param_distributions = {
            param_name: tuple(spec) if is_distribution_spec(spec) else spec
            for param_name, spec in raw_specs.items()
        }
    else:
        param_distributions = default_param_distributions
        print(f"Using default hyperparameters for {model_name}")

    # No fixed params (all params are tunable)
    fixed_params = {}

    return estimator, fixed_params, param_distributions


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the custom-dataset experiment."""
    parser = argparse.ArgumentParser(
        description="Run Optuna hyperparameter tuning with custom dataset"
    )
    parser.add_argument(
        "--hyperparams",
        type=str,
        default="cluster_scripts/hyperparams/interpretable.json",
        help="Path to hyperparameters JSON file",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="custom_dataset_results",
        help="Output directory for results",
    )
    parser.add_argument(
        "--n-trials",
        type=int,
        default=200,
        help="Number of optimization trials",
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=5000,
        help="Number of samples in dataset",
    )
    parser.add_argument(
        "--noise-std",
        type=float,
        default=0.25,
        help="Standard deviation of noise",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )
    parser.add_argument(
        "--train-split",
        type=float,
        default=0.8,
        help="Train/test split ratio",
    )
    parser.add_argument(
        "--cv-folds",
        type=int,
        default=3,
        help="Number of CV folds for optimization",
    )
    parser.add_argument(
        "--n-jobs",
        type=int,
        default=1,
        help="Number of parallel jobs",
    )
    parser.add_argument(
        "--constant-offset",
        type=float,
        default=0.0,
        help="Constant offset added to y",
    )
    parser.add_argument(
        "--std-x1",
        type=float,
        default=1.0,
        help="Standard deviation for X1",
    )
    parser.add_argument(
        "--std-x2",
        type=float,
        default=2.0,
        help="Standard deviation for X2",
    )
    parser.add_argument(
        "--std-x3",
        type=float,
        default=3.0,
        help="Standard deviation for X3",
    )
    return parser


def _utc_now() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Replaces the deprecated ``datetime.utcnow()``. The tzinfo is stripped so
    that ``isoformat()`` output keeps the same offset-free format as before.
    """
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


def main(argv: Optional[list[str]] = None):
    """
    Run the end-to-end custom-dataset tuning experiment from the command line.

    Steps: parse arguments, create the output directory, load the
    hyperparameter JSON, generate and split the synthetic dataset, prepare the
    ``MPFRegressor`` model, run ``run_optuna_benchmark`` and write
    ``results.json`` into the output directory. If the benchmark or the result
    write raises, ``ERROR.json`` is written instead.

    Args:
        argv: Optional list of command-line arguments. When ``None`` (the
            default) ``sys.argv[1:]`` is used, as before.

    Raises:
        SystemExit: With code 2 on invalid arguments; with code 1 when the
            hyperparameter file is missing, the model cannot be prepared, the
            benchmark returns a falsy result, or the benchmark raises.
    """
    import traceback

    parser = _build_arg_parser()
    args = parser.parse_args(argv)

    # train_test_split only accepts a float test_size strictly inside (0, 1);
    # reject bad ratios here with a usage error instead of a late traceback.
    if not 0.0 < args.train_split < 1.0:
        parser.error("--train-split must be strictly between 0 and 1")

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    rule = "=" * 80

    print("\n" + rule)
    print("Custom Dataset Experiment Runner")
    print(rule)
    print(f"Hyperparameters: {args.hyperparams}")
    print(f"Output directory: {args.output}")
    print(f"Number of trials: {args.n_trials}")
    print(f"Dataset samples: {args.n_samples}")
    print(f"Noise std: {args.noise_std}")
    print(f"Random seed: {args.seed}")
    print(f"{rule}\n")

    # Load hyperparameters
    hyperparams_path = Path(args.hyperparams)
    if not hyperparams_path.exists():
        print(f"❌ Hyperparameters file not found: {hyperparams_path}")
        sys.exit(1)

    custom_hyperparams = load_hyperparameters(hyperparams_path)
    print(f"✅ Loaded hyperparameters from: {hyperparams_path}")

    # Generate dataset
    print("\nGenerating dataset...")
    X, y = make_dataset(
        n=args.n_samples,
        seed=args.seed,
        noise_std=args.noise_std,
        constant_offset=args.constant_offset,
        std_x1=args.std_x1,
        std_x2=args.std_x2,
        std_x3=args.std_x3,
    )
    print(f"✅ Dataset generated: X{X.shape}, y{y.shape}")

    # Split into train/test
    print(f"\nSplitting dataset (train_split={args.train_split})...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=1.0 - args.train_split, random_state=args.seed
    )
    print(f"✅ Train set: X{X_train.shape}, y{y_train.shape}")
    print(f"✅ Test set: X{X_test.shape}, y{y_test.shape}")

    # Prepare model
    model_name = "MPFRegressor"
    print(f"\nPreparing model: {model_name}...")
    try:
        estimator, fixed_params, param_distributions = prepare_model(
            model_name, custom_hyperparams
        )
        estimator._name = model_name
        print(f"✅ Model prepared: {model_name}")
        print(f"   Parameters to tune: {list(param_distributions.keys())}")
    except Exception as e:
        print(f"❌ Failed to prepare model: {e}")
        traceback.print_exc()
        sys.exit(1)

    # Create Optuna storage path
    storage_path = output_dir / "optuna_study.log"
    study_name = f"{model_name}_custom_dataset"

    # Create task
    task = Task(
        estimator=SafeEstimatorWrapper(estimator),
        param_distributions=param_distributions,
        fixed_params=fixed_params,
        X_train=X_train,
        y_train=y_train,
        X_test=X_test,
        y_test=y_test,
        n_iter=args.n_trials,
        cv=args.cv_folds,
        random_state=args.seed,
        n_jobs=args.n_jobs,
        study_name=study_name,
        storage_name=str(storage_path),
        experiment_id=None,
        optimization_method="optuna",
    )

    # Run optimization
    print(f"\n{rule}")
    print(f"Starting optimization with {args.n_trials} trials...")
    print(f"{rule}\n")

    def describe_metric(label: str, value) -> str:
        # Only real numbers get fixed-precision formatting; the 'N/A'
        # placeholder (or any other type) is printed as-is.
        if isinstance(value, (int, float)):
            return f"{label}: {value:.4f}"
        return f"{label}: {value}"

    start_time = _utc_now()
    try:
        result = run_optuna_benchmark(task)
        end_time = _utc_now()

        if result:
            print(f"\n{rule}")
            print("✅ Optimization completed!")
            print(rule)
            print(describe_metric("Test RMSE", result.get("test_rmse", "N/A")))
            print(describe_metric("Train RMSE", result.get("train_rmse", "N/A")))
            print(f"Best parameters: {result.get('best_params', {})}")
            print(f"Duration: {end_time - start_time}")
            print(f"{rule}\n")

            # Save results
            result_filename = output_dir / "results.json"
            result_data = {
                "success": True,
                "dataset": "custom_dataset",
                "model": model_name,
                "dataset_config": {
                    "type": "custom",
                    "n_samples": args.n_samples,
                    "noise_std": args.noise_std,
                    "seed": args.seed,
                    "constant_offset": args.constant_offset,
                    "std_x1": args.std_x1,
                    "std_x2": args.std_x2,
                    "std_x3": args.std_x3,
                },
                "optimization_config": {
                    "n_trials": args.n_trials,
                    "cv_folds": args.cv_folds,
                    "random_seed": args.seed,
                },
                "start_timestamp": start_time.isoformat(),
                "end_timestamp": end_time.isoformat(),
                # Unpacked last: keys returned by the benchmark override the
                # metadata above when the names collide.
                **result,
            }

            with open(result_filename, "w", encoding="utf-8") as f:
                # default=str stringifies any value json cannot encode natively.
                json.dump(result_data, f, indent=2, default=str)

            print(f"✅ Results saved to: {result_filename}")
        else:
            print("⚠️  Optimization returned no results")
            # SystemExit is not an Exception subclass, so the handler below
            # does not intercept this exit.
            sys.exit(1)

    except Exception as e:
        error_msg = f"Fatal error: {str(e)}"
        print(f"❌ {error_msg}")
        traceback.print_exc()

        # Save error to file
        error_path = output_dir / "ERROR.json"
        with open(error_path, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "success": False,
                    "error": error_msg,
                    "timestamp": _utc_now().isoformat(),
                },
                f,
                indent=2,
            )

        sys.exit(1)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
