from typing import Dict, Callable, Optional
import os
import json
import yaml
import optuna
from optuna.trial import Trial
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner
import logging
from datetime import datetime

import warnings
# Globally ignore Optuna's ExperimentalWarning category at import time.
# NOTE(review): presumably needed because the TPESampler options used by
# CocktailOptimizer.optimize() are flagged experimental by Optuna — confirm.
warnings.filterwarnings('ignore', category=optuna.exceptions.ExperimentalWarning)

class CocktailOptimizer:
    """Hyperparameter optimizer for regularization cocktail.

    Wraps an Optuna study that searches over on/off switches (and their
    conditional sub-parameters) for several regularization techniques and
    minimizes ``1 - val_balanced_accuracy`` as reported by ``train_fn``.
    """

    # URL prefix used to recognise file-backed SQLite storages.
    _SQLITE_PREFIX = 'sqlite:///'

    def __init__(
        self,
        train_fn: Callable,
        task_id: int,
        max_epochs: int,
        n_trials: int,
        storage: Optional[str] = None,
        force_single_thread: bool = True,
        patience: int = 100,
        output_dir: str = './runs'
    ):
        """Store the search settings and prepare the environment.

        Args:
            train_fn: Called positionally as ``train_fn(task_id, config,
                max_epochs)``; must return a mapping containing at least
                ``'val_balanced_accuracy'`` and ``'epochs_trained'``.
            task_id: Identifier forwarded to ``train_fn`` and used in the
                study name.
            max_epochs: Epoch budget forwarded to ``train_fn``.
            n_trials: Maximum number of trials to run.
            storage: Optional Optuna storage URL. For ``sqlite:///`` URLs the
                parent directory of the database file is created if needed.
            force_single_thread: When true, sets ``OMP_NUM_THREADS=1`` and
                lowers Optuna's log verbosity to WARNING.
            patience: Number of most recent completed trials that may pass
                without a new best value before the search is stopped.
            output_dir: Stored on the instance; not used by this class itself.
        """
        self.train_fn = train_fn
        self.task_id = task_id
        self.max_epochs = max_epochs
        self.n_trials = n_trials
        self.storage = storage
        self.patience = patience
        self.output_dir = output_dir

        # Setup logging
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

        if force_single_thread:
            os.environ["OMP_NUM_THREADS"] = "1"
            optuna.logging.set_verbosity(optuna.logging.WARNING)

        if storage and storage.startswith(self._SQLITE_PREFIX):
            db_dir = os.path.dirname(storage[len(self._SQLITE_PREFIX):])
            # A database in the current directory (or ':memory:') has no
            # directory component; os.makedirs('') would raise.
            if db_dir:
                os.makedirs(db_dir, exist_ok=True)

    def _suggest_params(self, trial: "Trial") -> Dict:
        """Define hyperparameter search space.

        Conditional parameters (e.g. ``weight_decay``) are only suggested,
        and only present in the returned dict, when their switch is enabled.
        Every key in the returned dict equals the Optuna parameter name.
        """
        # NOTE: a learning_rate search (1e-4..1e-2, log scale) is currently disabled.
        config = {
            'use_batch_norm': trial.suggest_categorical('use_batch_norm', [True, False]),
            'use_swa': trial.suggest_categorical('use_swa', [True, False]),
            'use_weight_decay': trial.suggest_categorical('use_weight_decay', [True, False])
        }

        if config['use_weight_decay']:
            config['weight_decay'] = trial.suggest_float('weight_decay', 1e-5, 1e-1, log=True)

        config['use_dropout'] = trial.suggest_categorical('use_dropout', [True, False])
        if config['use_dropout']:
            config['dropout_shape'] = trial.suggest_categorical('dropout_shape', [
                'funnel', 'long_funnel', 'diamond', 'triangle'
            ])
            config['dropout_rate'] = trial.suggest_float('dropout_rate', 0.0, 0.8)

        config['use_skip'] = trial.suggest_categorical('use_skip', [True, False])
        if config['use_skip']:
            config['skip_type'] = trial.suggest_categorical('skip_type',
                ['Standard', 'ShakeShake', 'ShakeDrop'])
            if config['skip_type'] == 'ShakeDrop':
                config['shakedrop_prob'] = trial.suggest_float('shakedrop_prob', 0.0, 1.0)

        config['augmentation'] = trial.suggest_categorical('augmentation', ['None', 'MixUp'])
        if config['augmentation'] != 'None':
            config['aug_magnitude'] = trial.suggest_float('aug_magnitude', 0.0, 1.0)

        config['use_amp'] = trial.suggest_categorical('use_amp', [True, False])
        config['max_grad_norm'] = trial.suggest_float('max_grad_norm', 0.1, 10.0, log=True)

        return config

    def _objective(self, trial: "Trial") -> float:
        """Train one sampled configuration and return its validation error.

        Returns:
            ``1.0 - val_balanced_accuracy`` (the study minimizes this value).

        Raises:
            optuna.exceptions.TrialPruned: If the pruner rejects the trial, or
                if training raised any exception (the failure is logged and
                the trial is recorded as pruned so the search continues).
        """
        try:
            config = self._suggest_params(trial)
            result = self.train_fn(self.task_id, config, self.max_epochs)

            error = 1.0 - result['val_balanced_accuracy']
            # Report the same quantity that is minimized: the pruner compares
            # intermediate values using the study direction, so reporting the
            # raw accuracy would prune the *better* trials.
            trial.report(error, step=result['epochs_trained'])
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

            return error
        except optuna.exceptions.TrialPruned:
            self.logger.info(f"Trial {trial.number} pruned.")
            raise
        except Exception as e:
            # Add more detailed logging
            import traceback
            error_msg = f"Trial {trial.number} failed:\n{str(e)}\n{traceback.format_exc()}"
            self.logger.error(error_msg)
            # Deliberate best-effort: a failed trial is recorded as pruned.
            raise optuna.exceptions.TrialPruned() from e

    @staticmethod
    def _no_recent_improvement(values: list, patience: int) -> bool:
        """Return True if the last ``patience`` values did not beat the earlier best.

        ``values`` are objective values of completed trials in chronological
        order (lower is better). With no earlier trials to compare against,
        or a non-positive ``patience``, the answer is always False.
        """
        if patience <= 0 or len(values) <= patience:
            return False
        return min(values[-patience:]) >= min(values[:-patience])

    def optimize(self) -> Dict:
        """Run hyperparameter optimization with early stopping.

        After the search, the best configuration is trained once more via
        ``train_fn`` and its result is included in the returned summary.

        Returns:
            Dict with ``'best_config'``, ``'best_result'``, ``'task_id'`` and a
            ``'study'`` sub-dict of summary statistics.

        Raises:
            ValueError: If no trial completed, so no best configuration exists.
        """
        complete_state = optuna.trial.TrialState.COMPLETE
        pruned_state = optuna.trial.TrialState.PRUNED

        sampler = TPESampler(
            seed=42,
            n_startup_trials=50,
            multivariate=True,
            constant_liar=True,
            warn_independent_sampling=False
        )

        pruner = MedianPruner(
            n_startup_trials=50,
            n_warmup_steps=50,
            interval_steps=1
        )

        study = optuna.create_study(
            study_name=f"cocktail_task_{self.task_id}",
            storage=self.storage,
            load_if_exists=True,
            direction='minimize',
            sampler=sampler,
            pruner=pruner
        )

        # Early stopping callback
        def early_stopping_callback(study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
            finished = study.get_trials(states=[complete_state])
            values = [t.value for t in finished if t.value is not None]
            # Comparing against study.best_value would always be true (it is
            # the minimum of these very values); compare recent vs. earlier.
            if self._no_recent_improvement(values, self.patience):
                self.logger.info(f"No improvement in last {self.patience} trials. Stopping optimization.")
                study.stop()

        try:
            study.optimize(
                self._objective,
                n_trials=self.n_trials,
                n_jobs=1,
                show_progress_bar=True,
                gc_after_trial=True,
                callbacks=[early_stopping_callback]
            )
        except KeyboardInterrupt:
            self.logger.info("Optimization interrupted by user.")
        except Exception as e:
            self.logger.error(f"Optimization error: {str(e)}")
            raise

        completed_trials = study.get_trials(states=[complete_state])
        if not completed_trials:
            # Failed trials are recorded as pruned, so this can happen even
            # after many trials; fail with an explicit message.
            raise ValueError(
                f"No trial completed for task {self.task_id}; "
                "cannot determine a best configuration."
            )

        # Get best results. train_fn is called positionally, exactly as in
        # _objective, so it cannot fail here on a keyword-name mismatch.
        best_config = study.best_trial.params
        best_result = self.train_fn(self.task_id, best_config, self.max_epochs)

        return {
            'best_config': best_config,
            'best_result': best_result,
            'task_id': self.task_id,
            'study': {
                'best_value': study.best_value,
                'n_trials': len(study.trials),
                'n_completed': len(completed_trials),
                'n_pruned': len(study.get_trials(states=[pruned_state])),
                'datetime_complete': datetime.now().isoformat()
            }
        }

    def save_results(self, results: Dict, filepath: str) -> None:
        """Save optimization results with YAML hyperparameter file.

        Writes ``results`` as JSON to ``filepath`` and ``results['best_config']``
        as ``task_<task_id>_hyperparams.yml`` in the same directory. If any
        step fails, the error is logged and a best-effort JSON copy is written
        to ``filepath + '.backup'``.

        Args:
            results: Summary dict containing ``'task_id'`` and ``'best_config'``.
            filepath: Destination of the JSON file; its directory is created
                when it has one.
        """
        target_dir = os.path.dirname(filepath)
        try:
            # A bare filename has no directory part; os.makedirs('') raises.
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            # Serialize before opening the file so a serialization error
            # cannot leave a truncated JSON file behind.
            payload = json.dumps(results, indent=4)
            with open(filepath, 'w') as f:
                f.write(payload)

            # Save optimized hyperparameters as YAML
            yaml_path = os.path.join(
                target_dir,
                f"task_{results['task_id']}_hyperparams.yml"
            )
            with open(yaml_path, 'w') as f:
                yaml.dump(results['best_config'], f, default_flow_style=False)

            self.logger.info(f"Results saved to {filepath}")
            self.logger.info(f"Hyperparameters saved to {yaml_path}")
        except Exception as e:
            self.logger.error(f"Error saving results: {str(e)}")
            # default=str is deliberate for the backup only: if the primary
            # dump failed on a non-JSON-serializable value, the backup must
            # not fail for the same reason.
            backup_payload = json.dumps(results, indent=4, default=str)
            with open(filepath + '.backup', 'w') as f:
                f.write(backup_payload)
            self.logger.info(f"Results saved to backup: {filepath}.backup")