# import optuna
# import json
# import os
# import logging
# from datetime import datetime
# import torch
# from transformers import TrainingArguments, Trainer
# from lora import (  # Import from your enhanced script
#     set_reproducible_seed, get_system_info, compute_metrics,
#     EnhancedTrainMetricsCallback, get_exp_data_hf
# )

# class LoRAHyperparameterTuner:
#     def __init__(self, base_args, n_trials=50, study_name=None):
#         self.base_args = base_args
#         self.n_trials = n_trials
#         self.study_name = study_name or f"lora_tuning_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        
#         # Setup logging
#         self.setup_logging()
        
#         # Load dataset once
#         self.dataset = get_exp_data_hf("ag_news", val_size=0.1, seed=base_args.seed)
        
#         # Setup tokenizer
#         from transformers import BertTokenizerFast
#         self.tokenizer = BertTokenizerFast.from_pretrained(base_args.model_name)
        
#         # Tokenize dataset
#         def tokenize_fn(ex):
#             return self.tokenizer(ex["text"], truncation=True, max_length=base_args.max_length)
        
#         self.tokenized_dataset = self.dataset.map(tokenize_fn, batched=True)
        
#         # Results tracking
#         self.results = []
        
#     def setup_logging(self):
#         """Setup logging for hyperparameter tuning"""
#         os.makedirs(f"hyperparameter_tuning/{self.study_name}", exist_ok=True)
        
#         logging.basicConfig(
#             level=logging.INFO,
#             format='%(asctime)s - %(levelname)s - %(message)s',
#             handlers=[
#                 logging.FileHandler(f"hyperparameter_tuning/{self.study_name}/tuning.log"),
#                 logging.StreamHandler()
#             ]
#         )
        
#     def objective(self, trial):
#         """Objective function for Optuna optimization"""
        
#         # Suggest hyperparameters
#         r = trial.suggest_categorical('r', [4, 8, 16, 32, 64])
#         alpha = trial.suggest_categorical('alpha', [8, 16, 32, 64, 128])
        
#         # Optional: suggest other hyperparameters
#         learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
#         dropout = trial.suggest_float('lora_dropout', 0.0, 0.3)
        
#         # Log trial info
#         logging.info(f"Trial {trial.number}: r={r}, alpha={alpha}, lr={learning_rate:.2e}, dropout={dropout:.3f}")
        
#         try:
#             # Train model with these hyperparameters
#             val_accuracy = self.train_model(r, alpha, learning_rate, dropout, trial.number)
            
#             # Store results
#             self.results.append({
#                 'trial': trial.number,
#                 'r': r,
#                 'alpha': alpha,
#                 'learning_rate': learning_rate,
#                 'dropout': dropout,
#                 'val_accuracy': val_accuracy,
#                 'alpha_r_ratio': alpha / r
#             })
            
#             return val_accuracy
            
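#         except Exception as e:
#             # logging.exception keeps the traceback, so a systematic failure
#             # (e.g. a bad Trainer argument) is visible instead of silent.
#             logging.exception(f"Trial {trial.number} failed: {str(e)}")
#             # Returning NaN makes Optuna mark the trial FAILED and continue;
#             # returning 0.0 would feed a fake score to the sampler/summary.
#             return float("nan")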
    
#     def train_model(self, r, alpha, learning_rate, dropout, trial_number):
#         """Train model with given hyperparameters"""
        
#         # Set seed for reproducibility
#         set_reproducible_seed(self.base_args.seed + trial_number)
        
#         # Import here to avoid circular imports
#         from transformers import BertForSequenceClassification, DataCollatorWithPadding
#         from peft import get_peft_model, LoraConfig, TaskType
        
#         # Create model
#         base_model = BertForSequenceClassification.from_pretrained(
#             self.base_args.model_name,
#             num_labels=4
#         )
        
#         # Apply LoRA with trial hyperparameters
#         peft_config = LoraConfig(
#             task_type=TaskType.SEQ_CLS,
#             inference_mode=False,
#             r=r,
#             lora_alpha=alpha,
#             lora_dropout=dropout,
#             target_modules=['query', 'value'],
#         )
#         model = get_peft_model(base_model, peft_config)
        
#         # Data collator
#         data_collator = DataCollatorWithPadding(self.tokenizer)
        
#         # Training arguments (shorter training for hyperparameter search)
#         training_args = TrainingArguments(
#             output_dir=f"hyperparameter_tuning/{self.study_name}/trial_{trial_number}",
#             num_train_epochs=2,  # Shorter training for efficiency
#             per_device_train_batch_size=self.base_args.batch_size,
#             per_device_eval_batch_size=self.base_args.batch_size,
#             learning_rate=learning_rate,
#             eval_strategy="epoch",
#             save_strategy="no",  # Don't save checkpoints during tuning
#             logging_steps=100,
#             seed=self.base_args.seed + trial_number,
#             fp16=torch.cuda.is_available(),  # fp16 AMP raises on CPU-only machines
#             # The string "none" disables integrations; None means "all" in transformers v4.
#             report_to="none",
#             disable_tqdm=True,  # Reduce output clutter
#             # label_names belongs on TrainingArguments (Trainer() rejects it).
#             # DataCollatorWithPadding renames "label" -> "labels", and PeftModel
#             # hides the base model's forward signature, so name it explicitly.
#             label_names=["labels"],
#         )
        
#         # Trainer
#         trainer = Trainer(
#             model=model,
#             args=training_args,
#             train_dataset=self.tokenized_dataset["train"],
#             eval_dataset=self.tokenized_dataset["val"],
#             tokenizer=self.tokenizer,
#             data_collator=data_collator,
#             compute_metrics=compute_metrics,
#         )
        
#         # Train
#         trainer.train()
        
#         # Get validation accuracy
#         val_metrics = trainer.evaluate()
#         val_accuracy = val_metrics['eval_accuracy']
        
#         logging.info(f"Trial {trial_number} completed: val_accuracy = {val_accuracy:.4f}")
        
#         # Cleanup
#         del model, trainer, base_model
#         torch.cuda.empty_cache()
        
#         return val_accuracy
    
#     def run_optimization(self):
#         """Run hyperparameter optimization"""
        
#         logging.info(f"Starting hyperparameter optimization with {self.n_trials} trials")
        
#         # Create Optuna study
#         study = optuna.create_study(
#             direction='maximize',
#             study_name=self.study_name,
#             storage=f'sqlite:///hyperparameter_tuning/{self.study_name}/study.db',
#             load_if_exists=True
#         )
        
#         # Run optimization
#         study.optimize(self.objective, n_trials=self.n_trials)
        
#         # Save results
#         self.save_results(study)
        
#         return study
    
#     def save_results(self, study):
#         """Save optimization results"""
        
#         # Best parameters
#         best_params = study.best_params
#         best_value = study.best_value
        
#         logging.info(f"Best parameters: {best_params}")
#         logging.info(f"Best validation accuracy: {best_value:.4f}")
        
#         # Save detailed results
#         results_summary = {
#             'best_params': best_params,
#             'best_value': best_value,
#             'n_trials': len(study.trials),
#             'study_name': self.study_name,
#             'timestamp': datetime.now().isoformat(),
#             'system_info': get_system_info(),
#             'all_results': self.results
#         }
        
#         with open(f"hyperparameter_tuning/{self.study_name}/results.json", 'w') as f:
#             json.dump(results_summary, f, indent=2)
        
#         # Create analysis plots
#         self.create_analysis_plots(study)
        
#         # Print summary
#         self.print_summary(study)
    
#     def create_analysis_plots(self, study):
#         """Create visualization plots"""
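#         try:
#             # Only the plotly-backed optuna.visualization is used; an unused
#             # matplotlib import here would skip all plots when matplotlib is absent.
#             import optuna.visualization as vis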
            
#             # Parameter importance
#             fig1 = vis.plot_param_importances(study)
#             fig1.write_html(f"hyperparameter_tuning/{self.study_name}/param_importance.html")
            
#             # Optimization history
#             fig2 = vis.plot_optimization_history(study)
#             fig2.write_html(f"hyperparameter_tuning/{self.study_name}/optimization_history.html")
            
#             # Parameter relationships
#             fig3 = vis.plot_parallel_coordinate(study)
#             fig3.write_html(f"hyperparameter_tuning/{self.study_name}/parallel_coordinate.html")
            
#             # Slice plot
#             fig4 = vis.plot_slice(study)
#             fig4.write_html(f"hyperparameter_tuning/{self.study_name}/slice_plot.html")
            
#             logging.info("Analysis plots saved as HTML files")
            
#         except ImportError:
#             logging.warning("plotly not available, skipping visualization plots")
    
#     def print_summary(self, study):
#         """Print optimization summary"""
#         print("\n" + "="*50)
#         print("HYPERPARAMETER TUNING SUMMARY")
#         print("="*50)
#         print(f"Study: {self.study_name}")
#         print(f"Total trials: {len(study.trials)}")
#         print(f"Best validation accuracy: {study.best_value:.4f}")
#         print("\nBest parameters:")
#         for key, value in study.best_params.items():
#             print(f"  {key}: {value}")
        
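#         # Show top 5 completed trials. Failed/pruned/running trials have
#         # value None, which would crash the ":.4f" format below.
#         print("\nTop 5 trials:")
#         completed = [t for t in study.trials if t.value is not None]
#         sorted_trials = sorted(completed, key=lambda t: t.value, reverse=True)[:5]
#         for i, trial in enumerate(sorted_trials, start=1):
#             print(f"  {i}. Trial {trial.number}: {trial.value:.4f} - {trial.params}")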
        
#         print("="*50)


# # Usage example
# def main():
#     import argparse
    
#     # Parse base arguments (similar to your main script)
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--model_name", default="bert-base-uncased")
#     parser.add_argument("--batch_size", type=int, default=16)
#     parser.add_argument("--max_length", type=int, default=128)
#     parser.add_argument("--seed", type=int, default=42)
#     parser.add_argument("--n_trials", type=int, default=50)
#     parser.add_argument("--study_name", type=str, default=None)
    
#     args = parser.parse_args()
    
#     # Run hyperparameter tuning
#     tuner = LoRAHyperparameterTuner(args, n_trials=args.n_trials, study_name=args.study_name)
#     study = tuner.run_optimization()
    
#     # Report the best parameters; the final model is trained separately (see below)
#     best_params = study.best_params
#     print(f"\nBest parameters for final training: {best_params}")
    
#     # You can now use these parameters in your main training script
#     # python your_main_script.py --lora_r {best_params['r']} --lora_alpha {best_params['alpha']} --learning_rate {best_params['learning_rate']}


# if __name__ == "__main__":
#     main()