#!/usr/bin/env python3
"""
Script to save an MPFRegressor model from an Optuna log file.

This script loads the best parameters from an Optuna study log file,
fits the model with those parameters, and saves it.

Usage:
    python cluster_scripts/save_model_from_log.py \
        --log-file data/cluster_results/blackbox/mpf/california_housing/MPFRegressor/california_housing_MPFRegressor.log \
        --dataset-config '{"type": "openml_task", "task_id": 361255}' \
        --output-path data/cluster_results/blackbox/mpf/california_housing/MPFRegressor/california_housing_MPFRegressor_best
"""

import json
import argparse
import sys
from pathlib import Path
from datetime import datetime
import time
import numpy as np
import optuna
from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.base import clone

# Add parent directory to path to import dashboard modules
sys.path.insert(0, str(Path(__file__).parent.parent))

from experiment_runner import default_models
from services.datasets import load_dataset
from cluster_scripts.run_cluster_experiment import save_fitted_model


def load_best_params_from_log(log_file_path):
    """
    Load the best trial's parameters from an Optuna journal log file.

    The study is selected by name. The conventional names
    "MPFRegressor_simple" and "MPFRegressor" are preferred, in that order;
    if neither is stored in the journal, the first study name reported by
    the storage is used instead.

    Args:
        log_file_path: Path to the Optuna journal log file

    Returns:
        Tuple of (best_params dict, study_name, best_value)

    Raises:
        FileNotFoundError: If the log file does not exist.
        ValueError: If the journal holds no study, the selected study has no
            trials, or Optuna cannot determine a best trial.
    """
    log_path = Path(log_file_path)
    if not log_path.exists():
        raise FileNotFoundError(f"Log file not found: {log_file_path}")

    # Create storage from log file
    storage = JournalStorage(JournalFileBackend(str(log_path)))

    # Ask the storage which studies it holds instead of probing names with a
    # blanket try/except or hand-parsing the journal's first line; genuine
    # storage errors now propagate to the caller instead of being swallowed.
    available_names = optuna.get_all_study_names(storage=storage)
    if not available_names:
        raise ValueError(f"No studies found in log file: {log_file_path}")

    # Prefer the conventional study names; otherwise fall back to whatever
    # study the journal actually contains.
    preferred_names = ("MPFRegressor_simple", "MPFRegressor")
    study_name = next(
        (name for name in preferred_names if name in available_names),
        available_names[0],
    )
    study = optuna.load_study(study_name=study_name, storage=storage)

    if not study.trials:
        raise ValueError(f"No trials found in study: {study_name}")

    try:
        best_params = study.best_params
        best_value = study.best_value
    except ValueError as e:
        # Chain the original error so the root cause stays in the traceback.
        raise ValueError(f"Could not get best parameters from study: {e}") from e

    print(f"✅ Found best trial with value: {best_value}")
    print(f"   Best parameters: {best_params}")
    return best_params, study_name, best_value


def _build_arg_parser():
    """Build the argument parser describing this script's command-line options."""
    parser = argparse.ArgumentParser(
        description="Save MPFRegressor model from Optuna log file"
    )
    parser.add_argument(
        "--log-file",
        type=str,
        required=True,
        help="Path to Optuna log file (.log)",
    )
    parser.add_argument(
        "--dataset-config",
        type=str,
        required=True,
        help="JSON string with dataset configuration (e.g., '{\"type\": \"openml_task\", \"task_id\": 361255}')",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        required=True,
        help="Output path for saved model (without extension, e.g., 'model_best')",
    )
    parser.add_argument(
        "--train-split",
        type=float,
        default=0.8,
        help="Train/test split ratio (default: 0.8)",
    )
    parser.add_argument(
        "--random-seed",
        type=int,
        default=42,
        help="Random seed for train/test split (default: 42)",
    )
    parser.add_argument(
        "--run-id",
        type=int,
        default=0,
        help="Run ID for the results JSON filename (default: 0)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory for results JSON (default: same as log file directory)",
    )
    return parser


def _utc_now_iso():
    """Return the current UTC time as a naive ISO-8601 string.

    Produces the same text format as ``datetime.utcnow().isoformat()`` without
    calling ``utcnow``, which is deprecated since Python 3.12.
    """
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


def _fit_and_save_model(
    model_name,
    best_params,
    fixed_params,
    output_path,
    X_train,
    y_train,
    X_test,
    y_test,
):
    """Fit the tuned model, score it on the test split, and save it to disk.

    Args:
        model_name: Key into ``default_models`` selecting the base estimator.
        best_params: Tuned hyperparameters applied after ``fixed_params``.
        fixed_params: Hyperparameters applied before ``best_params``.
        output_path: Path prefix (no extension); ".bin" and ".sqlite" are
            appended for the model file and the visual DB respectively.
        X_train, y_train: Training split passed to ``model.fit``.
        X_test, y_test: Held-out split used to compute the test MSE/RMSE.

    Returns:
        Tuple of (test_mse, test_rmse, fit_time, fitted_model_path), where the
        metrics are plain Python floats and ``fitted_model_path`` is either the
        model path or "model_path; visualdb_path" when the visual DB exists.
    """
    fit_start_time = time.time()

    # Create model instance
    base_estimator, _ = default_models[model_name]
    model = clone(base_estimator)

    # Apply fixed params first, then best params
    if fixed_params:
        model = model.set_params(**fixed_params)
    if best_params:
        model = model.set_params(**best_params)

    # For MPFRegressor, also create visualdb_path and set it before fitting
    visualdb_path = None
    if model_name == "MPFRegressor":
        visualdb_path = Path(f"{output_path}.sqlite")
        # Set visualdb parameter before fitting (use absolute path)
        model.set_params(visualdb=str(visualdb_path.absolute()))
        print(f"Set visualdb_path for {model_name}: {visualdb_path.absolute()}")

    print(f"Fitting {model_name}...")
    model.fit(X_train, y_train)
    fit_time = time.time() - fit_start_time

    print("Calculating test RMSE...")
    y_pred = model.predict(X_test)
    # Coerce to built-in floats: the results JSON is written with
    # ``default=str``, which would otherwise serialise a non-float64 NumPy
    # scalar as a string instead of a number.
    test_mse = float(mean_squared_error(y_test, y_pred))
    test_rmse = float(np.sqrt(test_mse))
    print(f"✅ Test RMSE: {test_rmse:.4f}")

    # Save the model
    model_path = Path(f"{output_path}.bin")
    model.save(str(model_path))
    print(f"✅ Model saved to: {model_path}")

    # Return both paths in the format used by the dashboard
    if visualdb_path and visualdb_path.exists():
        fitted_model_path = f"{model_path}; {visualdb_path}"
        print(f"✅ VisualDB saved to: {visualdb_path}")
    else:
        fitted_model_path = str(model_path)

    return test_mse, test_rmse, fit_time, fitted_model_path


def main():
    """Refit MPFRegressor with the best Optuna parameters and save the results.

    Steps: parse the command line, load the best parameters from the Optuna
    log, load and split the dataset, fit and save the model, then write a
    results JSON file named ``run_<run_id>_<dataset>_<model>.json``.

    Exits with status 1 when the dataset config is not valid JSON or when
    loading parameters, loading the dataset, or fitting/saving the model
    fails. An out-of-range ``--train-split`` is rejected up front via
    ``parser.error`` (argparse usage error, exit status 2).
    """
    import traceback

    parser = _build_arg_parser()
    args = parser.parse_args()

    # Reject an unusable split ratio before any expensive work: the test
    # fraction (1 - train_split) must lie strictly between 0 and 1.
    if not 0.0 < args.train_split < 1.0:
        parser.error(
            f"--train-split must be strictly between 0 and 1 (got {args.train_split})"
        )

    print("=" * 80)
    print("Save MPFRegressor from Optuna Log")
    print("=" * 80)
    print(f"Log file: {args.log_file}")
    print(f"Output path: {args.output_path}")
    print()

    # Parse dataset config
    try:
        dataset_config = json.loads(args.dataset_config)
    except json.JSONDecodeError as e:
        print(f"❌ Error parsing dataset config: {e}")
        sys.exit(1)

    # Record start time
    start_timestamp = _utc_now_iso()

    # Load best parameters from log
    print("Loading best parameters from Optuna log...")
    try:
        best_params, study_name, best_cv_score = load_best_params_from_log(args.log_file)
        # The stored study value is treated as negative MSE, so negate it.
        # NOTE(review): confirm this matches the objective used for tuning.
        best_cv_score = -best_cv_score
    except Exception as e:
        print(f"❌ Failed to load best parameters: {e}")
        sys.exit(1)

    # Load dataset
    print("\nLoading dataset...")
    try:
        dataset_info = load_dataset(dataset_config)
        X = dataset_info["X"]
        y = dataset_info["y"]
        dataset_name = dataset_info["name"]
        print(f"✅ Dataset loaded: {dataset_name}, shape: X{X.shape}, y{y.shape}")
    except Exception as e:
        print(f"❌ Failed to load dataset: {e}")
        traceback.print_exc()
        sys.exit(1)

    # Split data (intended to mirror run_cluster_experiment.py)
    print(f"\nSplitting data (train_split={args.train_split})...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=1.0 - args.train_split, random_state=args.random_seed
    )
    print(f"✅ Train set: X{X_train.shape}, y{y_train.shape}")
    print(f"✅ Test set: X{X_test.shape}, y{y_test.shape}")

    # Fit and save the model
    print("\nFitting and saving model...")
    model_name = "MPFRegressor"
    fixed_params = {}  # No fixed params for cluster execution

    output_path = Path(args.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        test_mse, test_rmse, fit_time, fitted_model_path = _fit_and_save_model(
            model_name,
            best_params,
            fixed_params,
            output_path,
            X_train,
            y_train,
            X_test,
            y_test,
        )
    except Exception as e:
        print(f"❌ Failed to fit/save model: {e}")
        traceback.print_exc()
        sys.exit(1)

    # Create results JSON file
    print("\nCreating results JSON file...")

    # Default the JSON location to the directory holding the log file
    if args.output_dir:
        json_output_dir = Path(args.output_dir)
    else:
        json_output_dir = Path(args.log_file).parent

    json_output_dir.mkdir(parents=True, exist_ok=True)

    # Create safe names for filename
    safe_dataset_name = dataset_name.replace(" ", "_").replace("/", "_")
    safe_model_name = model_name.replace(" ", "_").replace("/", "_")

    result_filename = (
        f"run_{args.run_id:04d}_{safe_dataset_name}_{safe_model_name}.json"
    )
    result_path = json_output_dir / result_filename

    # Create results structure matching the benchmark format
    result_entry = {
        "model": model_name,
        "best_params": best_params,
        "fixed_params": fixed_params,
        "best_cv_score": best_cv_score,
        "test_mse": test_mse,
        "test_rmse": test_rmse,
        "fit_time": fit_time,
        "dataset": dataset_name,
    }

    # Single train/test split, so the aggregate statistics collapse to the
    # one measured RMSE.
    results = {
        "success": True,
        "dataset": dataset_name,
        "model": model_name,
        "dataset_config": dataset_config,
        "n_folds": 1,
        "mean_test_rmse": test_rmse,
        "std_test_rmse": 0.0,
        "min_test_rmse": test_rmse,
        "max_test_rmse": test_rmse,
        "results": [result_entry],
        "start_timestamp": start_timestamp,
        "end_timestamp": _utc_now_iso(),
        "fitted_model_path": fitted_model_path,
    }

    # Save results JSON
    with open(result_path, "w") as f:
        json.dump(results, f, indent=2, default=str)

    print(f"\n{'=' * 80}")
    print("✅ Model saved successfully!")
    print(f"   Model path: {fitted_model_path}")
    print(f"✅ Results JSON saved to: {result_path}")
    print(f"   Test RMSE: {test_rmse:.4f}")
    print(f"{'=' * 80}\n")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
