#!/usr/bin/env python3
"""
Hyperparameter optimization utilities for PINN experiments.
"""

import numpy as np
import pandas as pd
from typing import Dict, List, Any, Optional, Tuple, Callable
from pathlib import Path
import json
import time
import itertools
from dataclasses import dataclass, asdict
import warnings

# Try to import optimization libraries
try:
    from sklearn.model_selection import ParameterGrid, ParameterSampler
    from sklearn.metrics import make_scorer
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False
    warnings.warn("scikit-learn not available. Using basic optimization methods.")

try:
    import optuna
    OPTUNA_AVAILABLE = True
except ImportError:
    OPTUNA_AVAILABLE = False
    warnings.warn("Optuna not available. Install with 'pip install optuna' for advanced optimization. Using basic optimization methods.")


@dataclass
class OptimizationResult:
    """Summary of a single hyperparameter search run.

    Built by the grid, random and Optuna search routines; the fields are
    positional in the order listed below.
    """

    # Configuration with the lowest mean score (grid/random search leave this
    # as None when no configuration scored below inf).
    best_params: Dict[str, Any]
    # Lowest mean objective value across seeds (lower is better).
    best_score: float
    # One record per evaluated configuration, with the keys 'params', 'score',
    # 'score_std', 'scores' and 'n_seeds'.
    all_results: List[Dict[str, Any]]
    # Wall-clock seconds spent on the whole search; the search routines pass 0
    # and the caller overwrites it afterwards.
    optimization_time: float
    # Number of configurations that were evaluated.
    n_trials: int
    # Label of the search strategy: "grid_search", "random_search" or "optuna".
    method: str


class HyperparameterOptimizer:
    """Hyperparameter optimization for PINN methods.

    Candidate configurations come from grid search, random search or Optuna.
    Each configuration is trained for ``n_epochs`` epochs once per seed and is
    scored by the mean objective value across seeds (lower is better). Runs
    that raise, or that produce a non-finite value, score ``inf``.
    """

    # Objective metrics accepted by ``optimize_method`` / ``_evaluate_parameters``.
    _OBJECTIVE_METRICS = ("final_train_loss", "training_time")

    # Parameters routed to ``PINNTrainer``; every other key goes to the model factory.
    _TRAINING_PARAM_KEYS = ("learning_rate", "batch_size")

    def __init__(self, device: str = "auto"):
        """
        Initialize hyperparameter optimizer.

        Args:
            device: Device to use for optimization: "cpu", "cuda", or "auto"
                (CUDA when available, otherwise CPU).
        """
        self.device = device
        self.results = {}

        # Initialize CUDA context early if using GPU. torch is not imported at
        # module level, so it must be imported *before* torch.cuda is queried.
        if device in ("cuda", "auto"):
            import torch
            if torch.cuda.is_available():
                # Create a dummy tensor to initialize CUDA context
                dummy_tensor = torch.tensor([1.0], device='cuda')
                del dummy_tensor
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

        # Default parameter grids for each method
        # Note: hidden_layers are stored as strings for Optuna compatibility
        self.default_param_grids = {
            'standard': {
                'learning_rate': [1e-4, 5e-4, 1e-3, 5e-3, 1e-2],
                'batch_size': [500, 1000, 2000],
                'hidden_layers': ['32,32', '50,50', '64,64', '50,50,50', '64,64,64']
            },
            'rpit': {
                'learning_rate': [1e-4, 5e-4, 1e-3, 5e-3, 1e-2],
                'batch_size': [500, 1000, 2000],
                'hidden_layers': ['32,32', '50,50', '64,64', '50,50,50', '64,64,64'],
                'lambda_sens': [0.01, 0.05, 0.1, 0.2, 0.5],
                'lambda_var': [0.1, 0.5, 1.0, 2.0, 5.0],
                'noise_std': [0.01, 0.05, 0.1, 0.2, 0.5]
            },
            'bayesian': {
                'learning_rate': [1e-4, 5e-4, 1e-3],  # More conservative learning rates
                'batch_size': [1000, 2000],  # Larger batch sizes for stability
                'hidden_layers': ['32,32', '50,50'],  # Smaller networks for stability
                'n_ensemble': [3, 5],  # Smaller ensemble for stability
                'dropout_rate': [0.05, 0.1]  # Very conservative dropout rates
            }
        }

    def optimize_method(
        self,
        problem_type: str,
        method: str,
        optimization_method: str = "grid",
        n_trials: int = 50,
        n_epochs: int = 500,
        n_seeds: int = 3,
        custom_param_grid: Optional[Dict[str, List]] = None,
        objective_metric: str = "final_train_loss"
    ) -> "OptimizationResult":
        """
        Optimize hyperparameters for a specific method and problem.

        Args:
            problem_type: Type of problem ('lorenz', 'burgers', 'inverse_poisson')
            method: Method to optimize ('standard', 'rpit', 'bayesian')
            optimization_method: Optimization method ('grid', 'random', 'optuna');
                'optuna' falls back to random search when Optuna is not installed.
            n_trials: Number of trials for random/optuna optimization
            n_epochs: Number of epochs for each trial
            n_seeds: Number of random seeds for each parameter combination
            custom_param_grid: Custom parameter grid (overrides default)
            objective_metric: Metric to optimize ('final_train_loss', 'training_time')

        Returns:
            OptimizationResult object with ``optimization_time`` filled in.

        Raises:
            ValueError: If the objective metric or optimization method is
                unsupported, or no parameter grid exists for ``method``.
        """
        # Validate up front: an unknown metric would otherwise only surface
        # after every trial had trained a model.
        if objective_metric not in self._OBJECTIVE_METRICS:
            raise ValueError(f"Unknown objective metric: {objective_metric}")

        print(f"🔧 Optimizing {method.upper()} for {problem_type} problem...")
        print(f"   Method: {optimization_method}")
        print(f"   Trials: {n_trials}")
        print(f"   Epochs per trial: {n_epochs}")

        start_time = time.time()

        # Get parameter grid
        param_grid = custom_param_grid or self.default_param_grids.get(method, {})

        if not param_grid:
            raise ValueError(f"No parameter grid defined for method: {method}")

        # Choose optimization method
        if optimization_method == "grid":
            result = self._grid_search(
                problem_type, method, param_grid, n_epochs, n_seeds, objective_metric
            )
        elif optimization_method == "random":
            result = self._random_search(
                problem_type, method, param_grid, n_trials, n_epochs, n_seeds, objective_metric
            )
        elif optimization_method == "optuna":
            if OPTUNA_AVAILABLE:
                result = self._optuna_optimization(
                    problem_type, method, param_grid, n_trials, n_epochs, n_seeds, objective_metric
                )
            else:
                print("⚠️  Optuna not available, falling back to random search...")
                result = self._random_search(
                    problem_type, method, param_grid, n_trials, n_epochs, n_seeds, objective_metric
                )
        else:
            raise ValueError(f"Unsupported optimization method: {optimization_method}")

        optimization_time = time.time() - start_time
        result.optimization_time = optimization_time

        print(f"✅ Optimization completed in {optimization_time:.2f} seconds")
        print(f"   Best score: {result.best_score:.6f}")
        print(f"   Best parameters: {result.best_params}")

        return result

    @staticmethod
    def _summarize_scores(scores: List[float]) -> Tuple[float, float]:
        """Return ``(mean, std)`` of per-seed scores as Python floats.

        If any score is non-finite (or ``scores`` is empty), the result is
        ``(inf, nan)``. For inf scores these are the values NumPy would produce,
        computed here without its invalid-value RuntimeWarning.
        """
        values = np.asarray(scores, dtype=float)
        if values.size == 0 or not np.all(np.isfinite(values)):
            return float('inf'), float('nan')
        return float(values.mean()), float(values.std())

    def _score_parameters(
        self,
        problem_type: str,
        method: str,
        params: Dict[str, Any],
        n_epochs: int,
        n_seeds: int,
        objective_metric: str
    ) -> List[float]:
        """Evaluate ``params`` once per seed in ``range(n_seeds)``.

        A seed whose evaluation raises, or returns NaN/±Inf, scores ``inf``.
        A -inf score would otherwise win the search.
        """
        scores = []
        for seed in range(n_seeds):
            try:
                score = float(self._evaluate_parameters(
                    problem_type, method, params, n_epochs, seed, objective_metric
                ))
            except Exception as e:
                print(f"     Seed {seed} failed: {e}")
                score = float('inf')
            scores.append(score if np.isfinite(score) else float('inf'))
        return scores

    def _run_trials(
        self,
        problem_type: str,
        method: str,
        param_combinations: List[Dict[str, Any]],
        n_epochs: int,
        n_seeds: int,
        objective_metric: str
    ) -> Tuple[List[Dict[str, Any]], float, Optional[Dict[str, Any]]]:
        """Evaluate every configuration and track the best one.

        Returns:
            ``(all_results, best_score, best_params)``. ``best_params`` is None
            when no configuration scored below ``inf``.
        """
        all_results = []
        best_score = float('inf')
        best_params = None
        total = len(param_combinations)

        for i, params in enumerate(param_combinations):
            param_dict = dict(params)
            print(f"   Trial {i+1}/{total}: {param_dict}")

            scores = self._score_parameters(
                problem_type, method, param_dict, n_epochs, n_seeds, objective_metric
            )
            avg_score, std_score = self._summarize_scores(scores)

            all_results.append({
                'params': dict(param_dict),
                'score': avg_score,
                'score_std': std_score,
                'scores': scores,
                'n_seeds': n_seeds
            })

            if avg_score < best_score:
                best_score = avg_score
                best_params = dict(param_dict)

            print(f"     Score: {avg_score:.6f} ± {std_score:.6f}")

        return all_results, best_score, best_params

    def _grid_search(
        self,
        problem_type: str,
        method: str,
        param_grid: Dict[str, List],
        n_epochs: int,
        n_seeds: int,
        objective_metric: str
    ) -> "OptimizationResult":
        """Perform grid search over every combination in ``param_grid``."""
        if SKLEARN_AVAILABLE:
            param_combinations = list(ParameterGrid(param_grid))
        else:
            # Fallback to basic grid search
            names = list(param_grid)
            param_combinations = [
                dict(zip(names, values))
                for values in itertools.product(*param_grid.values())
            ]

        print(f"   Grid search: {len(param_combinations)} combinations")
        all_results, best_score, best_params = self._run_trials(
            problem_type, method, param_combinations, n_epochs, n_seeds, objective_metric
        )

        return OptimizationResult(
            best_params=best_params,
            best_score=best_score,
            all_results=all_results,
            optimization_time=0,  # Will be set by caller
            n_trials=len(param_combinations),
            method="grid_search"
        )

    def _random_search(
        self,
        problem_type: str,
        method: str,
        param_grid: Dict[str, List],
        n_trials: int,
        n_epochs: int,
        n_seeds: int,
        objective_metric: str
    ) -> "OptimizationResult":
        """Perform random search with up to ``n_trials`` sampled configurations."""
        if SKLEARN_AVAILABLE:
            # When every grid value is a list, ParameterSampler samples without
            # replacement and caps the draws at the grid size, so fewer than
            # n_trials configurations may come back.
            param_combinations = list(
                ParameterSampler(param_grid, n_iter=n_trials, random_state=42)
            )
        else:
            # Fallback to basic random sampling (with replacement). Indexing the
            # list keeps native Python types (np.random.choice returns NumPy
            # scalars that json cannot dump); the seed mirrors random_state=42.
            rng = np.random.RandomState(42)
            param_combinations = [
                {name: values[rng.randint(len(values))] for name, values in param_grid.items()}
                for _ in range(n_trials)
            ]

        print(f"   Random search: {len(param_combinations)} trials")
        all_results, best_score, best_params = self._run_trials(
            problem_type, method, param_combinations, n_epochs, n_seeds, objective_metric
        )

        return OptimizationResult(
            best_params=best_params,
            best_score=best_score,
            all_results=all_results,
            optimization_time=0,  # Will be set by caller
            n_trials=len(param_combinations),
            method="random_search"
        )

    def _optuna_optimization(
        self,
        problem_type: str,
        method: str,
        param_grid: Dict[str, List],
        n_trials: int,
        n_epochs: int,
        n_seeds: int,
        objective_metric: str
    ) -> "OptimizationResult":
        """Perform Optuna-based optimization.

        The type of each grid's first value selects how it is searched:
        int -> ``suggest_int(min, max)``, float -> ``suggest_float(min, max)``,
        anything else (including bool) -> ``suggest_categorical(values)``.
        """
        def suggest(trial, name, values):
            first = values[0]
            # bool is a subclass of int; keep flags categorical instead of 0..1 ints.
            if isinstance(first, bool):
                return trial.suggest_categorical(name, values)
            if isinstance(first, int):
                return trial.suggest_int(name, min(values), max(values))
            if isinstance(first, float):
                return trial.suggest_float(name, min(values), max(values))
            # Optuna categorical choices should be None/bool/int/float/str, which
            # is why list-like options such as hidden_layers are stored as strings.
            return trial.suggest_categorical(name, values)

        def objective(trial):
            params = {name: suggest(trial, name, values) for name, values in param_grid.items()}
            scores = self._score_parameters(
                problem_type, method, params, n_epochs, n_seeds, objective_metric
            )
            avg_score, std_score = self._summarize_scores(scores)
            # Keep per-seed detail so the summary below reflects what actually ran.
            trial.set_user_attr('scores', scores)
            trial.set_user_attr('score_std', std_score)

            if not np.isfinite(avg_score):
                print(f"⚠️  Trial failed: Final score is NaN/Inf (score={avg_score})")
                return float('inf')
            return avg_score

        # Create study
        study = optuna.create_study(direction='minimize')
        study.optimize(objective, n_trials=n_trials)

        # Convert results
        all_results = []
        for trial in study.trials:
            scores = trial.user_attrs.get('scores', [trial.value])
            all_results.append({
                'params': trial.params,
                'score': trial.value,
                'score_std': trial.user_attrs.get('score_std', 0),
                'scores': scores,
                'n_seeds': len(scores)
            })

        return OptimizationResult(
            best_params=study.best_params,
            best_score=study.best_value,
            all_results=all_results,
            optimization_time=0,  # Will be set by caller
            n_trials=n_trials,
            method="optuna"
        )

    def _evaluate_parameters(
        self,
        problem_type: str,
        method: str,
        params: Dict[str, Any],
        n_epochs: int,
        seed: int,
        objective_metric: str
    ) -> float:
        """
        Evaluate a set of parameters by running a single experiment.

        Args:
            problem_type: Type of problem
            method: Method to evaluate
            params: Parameter dictionary ('hidden_layers' may be a "32,32" string)
            n_epochs: Number of training epochs
            seed: Random seed for torch and NumPy
            objective_metric: 'final_train_loss' or 'training_time'

        Returns:
            The last training loss ('final_train_loss') or the wall-clock seconds
            spent in ``trainer.train`` ('training_time'). Returns ``inf`` if the
            last loss is NaN/Inf, and for 'final_train_loss' if no loss was recorded.

        Raises:
            ValueError: If ``objective_metric`` is not supported.
        """
        # Fail before paying for model construction and training.
        if objective_metric not in self._OBJECTIVE_METRICS:
            raise ValueError(f"Unknown objective metric: {objective_metric}")

        # Import here to avoid circular imports
        import torch
        import sys
        import os

        # Add the project root to the path for absolute imports
        current_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.dirname(os.path.dirname(current_dir))
        if project_root not in sys.path:
            sys.path.insert(0, project_root)

        from src.models.problem_models import create_problem_model
        from src.utils.data_generator import generate_problem_data
        from src.training.trainer import PINNTrainer

        # Set random seed
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Get device
        device = self.device
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize CUDA context early if using GPU
        if device == "cuda" and torch.cuda.is_available():
            dummy_tensor = torch.tensor([1.0], device='cuda')
            del dummy_tensor
            torch.cuda.empty_cache()  # Clear cache
            torch.cuda.synchronize()  # Ensure CUDA context is initialized

        # Convert string parameters back to appropriate types
        processed_params = {}
        for key, value in params.items():
            if key == 'hidden_layers' and isinstance(value, str):
                # Convert string like "32,32" to list [32, 32]
                processed_params[key] = [int(x) for x in value.split(',')]
            else:
                processed_params[key] = value

        # Trainer gets learning_rate/batch_size; everything else (including
        # unknown keys) goes to the model factory.
        training_params = {
            k: v for k, v in processed_params.items() if k in self._TRAINING_PARAM_KEYS
        }
        model_params = {
            k: v for k, v in processed_params.items() if k not in self._TRAINING_PARAM_KEYS
        }

        # Problem-specific defaults; explicitly supplied params take precedence.
        problem_defaults = {
            'lorenz': {
                'input_dim': 1,
                'output_dim': 3,
                'sigma': 10.0,
                'rho': 28.0,
                'beta': 8.0/3.0
            },
            'burgers': {
                'input_dim': 2,
                'output_dim': 1
            },
            'inverse_poisson': {
                'input_dim': 1,
                'output_dim': 2
            }
        }
        for key, value in problem_defaults.get(problem_type, {}).items():
            model_params.setdefault(key, value)

        # Create model with model parameters only
        model = create_problem_model(
            problem_type=problem_type,
            method=method,
            device=device,
            **model_params
        )

        # Generate data
        data = generate_problem_data(problem_type, n_points=1000, device=device)

        # Create trainer with training parameters
        trainer = PINNTrainer(
            model=model,
            device=device,
            epochs=n_epochs,
            lr=training_params.get('learning_rate', 1e-3),
            batch_size=training_params.get('batch_size', 1000)
        )

        # Train model, timing it for the 'training_time' objective.
        train_start = time.perf_counter()
        train_losses = trainer.train(data)
        training_time = time.perf_counter() - train_start

        if not train_losses:
            # No recorded loss: nothing to rank by for the loss objective.
            return float('inf') if objective_metric == "final_train_loss" else training_time

        final_loss = float(train_losses[-1])
        # A diverged run is penalised regardless of the objective metric.
        if not np.isfinite(final_loss):
            print(f"⚠️  Training failed: Final loss is NaN/Inf (loss={final_loss})")
            return float('inf')

        if objective_metric == "training_time":
            return training_time
        return final_loss

    def optimize_all_methods(
        self,
        problem_types: Optional[List[str]] = None,
        methods: Optional[List[str]] = None,
        optimization_method: str = "random",
        n_trials: int = 20,
        n_epochs: int = 300,
        n_seeds: int = 2
    ) -> Dict[str, Dict[str, Optional["OptimizationResult"]]]:
        """
        Optimize hyperparameters for all methods and problems.

        Args:
            problem_types: Problem types to optimize (default: all three).
            methods: Methods to optimize (default: 'standard', 'rpit', 'bayesian').
            optimization_method: Optimization method to use
            n_trials: Number of trials per optimization
            n_epochs: Number of epochs per trial
            n_seeds: Number of seeds per parameter combination

        Returns:
            Mapping problem_type -> method -> result; a failed optimization is
            reported and stored as None.
        """
        if problem_types is None:
            problem_types = ['lorenz', 'burgers', 'inverse_poisson']

        if methods is None:
            methods = ['standard', 'rpit', 'bayesian']

        all_results = {}

        for problem_type in problem_types:
            all_results[problem_type] = {}

            for method in methods:
                try:
                    result = self.optimize_method(
                        problem_type=problem_type,
                        method=method,
                        optimization_method=optimization_method,
                        n_trials=n_trials,
                        n_epochs=n_epochs,
                        n_seeds=n_seeds
                    )
                    all_results[problem_type][method] = result

                except Exception as e:
                    # Top-level boundary: report and continue with the next method.
                    print(f"❌ Optimization failed for {method} on {problem_type}: {e}")
                    all_results[problem_type][method] = None

        return all_results

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """``json.dump`` fallback: NumPy scalars/arrays become native values, others str."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return str(obj)

    def save_optimization_results(
        self,
        results: Dict[str, Dict[str, Optional["OptimizationResult"]]],
        save_path: Path
    ) -> None:
        """Save optimization results to files.

        Writes ``optimization_results.json`` (full detail) and
        ``best_parameters.json`` (best params only) into ``save_path``, which is
        created if needed (a str path is accepted). Entries whose result is None
        are omitted from both files.
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        detailed_results = {}
        best_params_summary = {}
        for problem_type, problem_results in results.items():
            detailed_results[problem_type] = {}
            best_params_summary[problem_type] = {}
            for method, result in problem_results.items():
                if result is None:
                    continue
                detailed_results[problem_type][method] = {
                    'best_params': result.best_params,
                    'best_score': result.best_score,
                    'optimization_time': result.optimization_time,
                    'n_trials': result.n_trials,
                    'method': result.method,
                    'all_results': result.all_results
                }
                best_params_summary[problem_type][method] = result.best_params

        with open(save_path / "optimization_results.json", 'w', encoding='utf-8') as f:
            json.dump(detailed_results, f, indent=2, default=self._json_default)

        with open(save_path / "best_parameters.json", 'w', encoding='utf-8') as f:
            json.dump(best_params_summary, f, indent=2, default=self._json_default)

        print(f"💾 Optimization results saved to {save_path}")

    def load_optimization_results(
        self, load_path: Path
    ) -> Dict[str, Dict[str, Optional["OptimizationResult"]]]:
        """Load results written by ``save_optimization_results`` from ``load_path``.

        Reads ``optimization_results.json``; a str path is accepted too.
        """
        load_path = Path(load_path)
        with open(load_path / "optimization_results.json", 'r', encoding='utf-8') as f:
            data = json.load(f)

        results = {}
        for problem_type, problem_results in data.items():
            results[problem_type] = {}
            for method, result_data in problem_results.items():
                if result_data is None:
                    results[problem_type][method] = None
                    continue
                results[problem_type][method] = OptimizationResult(
                    best_params=result_data['best_params'],
                    best_score=result_data['best_score'],
                    all_results=result_data['all_results'],
                    optimization_time=result_data['optimization_time'],
                    n_trials=result_data['n_trials'],
                    method=result_data['method']
                )

        return results
