#!/usr/bin/env python3
"""
Configuration loader utility for automatically loading best hyperparameters.
"""

import json
import os
from pathlib import Path
from typing import Dict, Any, Optional


def load_best_hyperparameters(project_root: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """
    Load best hyperparameters from the optimization results.

    Reads ``<project_root>/data/results/best_hyperparameters.json``. This is a
    best-effort loader: every failure is reported on stdout and signalled by
    returning None, so callers can fall back to their default configuration.

    Args:
        project_root: Path to project root directory. When None, the directory
            two levels above the directory containing this file is used.

    Returns:
        Dictionary containing best hyperparameters, or None if the file is
        missing, cannot be read or parsed, or does not hold a JSON object.
    """
    if project_root is None:
        # Try to find project root automatically
        root = Path(__file__).parent.parent.parent
    else:
        root = Path(project_root)

    config_path = root / "data" / "results" / "best_hyperparameters.json"

    if not config_path.exists():
        print(f"⚠️  Best hyperparameters file not found: {config_path}")
        print("   Using default configuration instead.")
        return None

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except Exception as e:
        # Deliberately broad: any read/decode/parse failure must degrade to
        # "use defaults" rather than abort the caller.
        print(f"❌ Error loading best hyperparameters: {e}")
        print("   Using default configuration instead.")
        return None

    # Valid JSON is not necessarily an object (it may be a list, string,
    # number, ...). Only a dict satisfies the documented return type.
    if not isinstance(config, dict):
        print(
            "❌ Error loading best hyperparameters: "
            f"expected a JSON object, got {type(config).__name__}"
        )
        print("   Using default configuration instead.")
        return None

    print(f"✅ Loaded best hyperparameters from: {config_path}")
    print(f"   Configuration: {config}")
    return config


def apply_best_hyperparameters_to_config(config_obj, best_params: Dict[str, Any], method: str = "rpit"):
    """
    Copy tuned hyperparameters onto a configuration object, in place.

    Only keys present in ``best_params`` are written; attributes for absent
    keys are left untouched. Nothing happens (and nothing is printed) when
    ``best_params`` is None.

    Args:
        config_obj: Configuration object whose attributes are overwritten
        best_params: Best hyperparameters dictionary, or None to skip
        method: Method type ('rpit', 'bayesian', 'standard'); any value other
            than 'rpit' or 'bayesian' receives only the common parameters
    """
    if best_params is None:
        return

    # Pick the method-specific keys first; they are applied before the
    # common ones.
    if method == "rpit":
        specific_keys = ("lambda_sens", "lambda_var", "noise_std")
    elif method == "bayesian":
        specific_keys = ("n_ensemble", "dropout_rate", "weight_decay")
    else:
        specific_keys = ()

    common_keys = ("learning_rate", "hidden_layers")

    for key in specific_keys + common_keys:
        if key in best_params:
            setattr(config_obj, key, best_params[key])

    print(f"✅ Applied best hyperparameters for {method} method")


def get_optimized_config(problem_type: str, method: str = "rpit", **kwargs):
    """
    Build a problem configuration and overlay the best hyperparameters on it.

    Args:
        problem_type: Type of problem ('lorenz', 'burgers', 'inverse_poisson')
        method: Method to use ('rpit', 'bayesian', 'standard')
        **kwargs: Forwarded unchanged to the configuration class constructor

    Returns:
        Configuration object with best hyperparameters applied

    Raises:
        ValueError: If ``problem_type`` is not one of the supported names
    """
    # Reject unsupported names up front, before any import is attempted.
    if problem_type not in ("lorenz", "burgers", "inverse_poisson"):
        raise ValueError(f"Unknown problem type: {problem_type}")

    # Config classes are imported lazily to avoid circular imports.
    if problem_type == "lorenz":
        from config.lorenz_config import LorenzConfig as config_cls
    elif problem_type == "burgers":
        from config.burgers_config import BurgersConfig as config_cls
    else:
        from config.inverse_config import InverseConfig as config_cls

    optimized = config_cls(**kwargs)

    # Load the stored hyperparameters and overlay them on the fresh config.
    tuned = load_best_hyperparameters()
    apply_best_hyperparameters_to_config(optimized, tuned, method)

    return optimized


def print_config_summary(config_obj, method: str):
    """
    Print a summary of the configuration being used.

    Each row shows a label and the matching attribute of ``config_obj``;
    attributes the object does not have are shown as 'N/A'. The common rows
    are always printed, followed by the rows specific to 'rpit' or
    'bayesian' when ``method`` is one of those.
    """
    separator = "=" * 50

    # (label, attribute name) pairs in display order: common rows first.
    rows = [
        ("Learning Rate", "learning_rate"),
        ("Hidden Layers", "hidden_layers"),
        ("Device", "device"),
        ("Seed", "seed"),
    ]
    if method == "rpit":
        rows += [
            ("Lambda Sens", "lambda_sens"),
            ("Lambda Var", "lambda_var"),
            ("Noise Std", "noise_std"),
        ]
    elif method == "bayesian":
        rows += [
            ("N Ensemble", "n_ensemble"),
            ("Dropout Rate", "dropout_rate"),
            ("Weight Decay", "weight_decay"),
        ]

    print(f"\n📋 Configuration Summary for {method.upper()} method:")
    print(separator)
    for label, attr_name in rows:
        print(f"{label}: {getattr(config_obj, attr_name, 'N/A')}")
    print(separator)
