#!/usr/bin/env python3

import argparse
import warnings
warnings.filterwarnings('ignore')

from config import DEFAULT_MODEL_CONFIGS, TrainingConfig, FeedbackConfig, HybridRLConfig
from experiments import ExperimentRunner
from visualization import (
    plot_training_curves, plot_performance_comparison, plot_reward_accumulation,
    plot_ablation_study, plot_alpha_evolution, print_performance_table,
    plot_ppo_training_stats, print_ppo_performance_table, generate_comprehensive_report
)
from utils import set_random_seeds, save_results, load_results, check_system_requirements, ExperimentLogger

def setup_experiment():
    """Print the banner, seed RNGs, check the environment and build the runner.

    Seeds are fixed to 42 before ``check_system_requirements`` runs, and the
    runner is built from ``DEFAULT_MODEL_CONFIGS``.

    Returns:
        tuple: ``(runner, True)`` when ``ExperimentRunner`` was constructed,
        or ``(None, False)`` if its construction raised (the error is printed).
    """
    print("HYBRID REINFORCEMENT LEARNING FRAMEWORK")
    print("Implementation of: 'Mitigating Hallucinations in Large Language Models via HRL'")
    print("=" * 80)

    set_random_seeds(42)
    check_system_requirements()

    try:
        experiment_runner = ExperimentRunner(DEFAULT_MODEL_CONFIGS)
        print("Experiment runner initialized successfully")
    except Exception as init_error:
        print(f"Error initializing experiment runner: {init_error}")
        return None, False

    return experiment_runner, True

def run_main_experiments(runner: ExperimentRunner, epochs: int = 20):
    """Run the main training experiments and persist their results.

    Args:
        runner: An initialized experiment runner.
        epochs: Number of training epochs forwarded to
            ``runner.run_experiments``.

    Returns:
        Whatever ``runner.run_experiments`` returned, or ``None`` if running
        the experiments raised. A failure while *saving* the results to
        ``results/main_experiments.json`` is logged as an error but does not
        discard the (already computed) results.
    """
    logger = ExperimentLogger()
    logger.log("Starting main experiments")

    try:
        results = runner.run_experiments(epochs=epochs)
    except Exception as e:
        logger.log(f"Error in main experiments: {e}", "ERROR")
        return None
    logger.log("Main experiments completed successfully")

    # Saving is guarded separately so an I/O problem does not throw away
    # results that may have taken a long time to compute.
    try:
        save_results(results, "results/main_experiments.json")
    except Exception as e:
        logger.log(f"Error saving main experiment results: {e}", "ERROR")
    else:
        logger.log("Results saved to results/main_experiments.json")

    return results

def run_ablation_experiments(runner: ExperimentRunner):
    """Run an ablation study for every model configured on ``runner``.

    An error for one model is logged and that model is skipped, so the
    remaining models still run. If at least one study succeeded, the
    collected results are saved to ``results/ablation_studies.json``. A save
    failure is logged rather than raised, so the caller still receives the
    results.

    Args:
        runner: An initialized experiment runner; each entry of
            ``runner.model_configs`` must expose ``model_name``.

    Returns:
        dict: model name -> result of ``runner.run_ablation_study`` for each
        model that succeeded (empty if none did).
    """
    logger = ExperimentLogger()
    logger.log("Starting ablation studies")

    ablation_results = {}

    for model_config in runner.model_configs:
        model_name = model_config.model_name
        logger.log(f"Running ablation study for {model_name}")

        try:
            ablation_results[model_name] = runner.run_ablation_study(model_config)
        except Exception as e:
            logger.log(f"Error in ablation study for {model_name}: {e}", "ERROR")
            continue
        logger.log(f"Ablation study completed for {model_name}")

    if ablation_results:
        # Previously an exception here escaped and crashed the whole run
        # after every ablation had already been computed.
        try:
            save_results(ablation_results, "results/ablation_studies.json")
        except Exception as e:
            logger.log(f"Error saving ablation results: {e}", "ERROR")
        else:
            logger.log("Ablation results saved")

    return ablation_results

def run_domain_experiments(runner: ExperimentRunner):
    """Run the domain experiments for every model configured on ``runner``.

    An error for one model is logged and that model is skipped, so the
    remaining models still run. If at least one model succeeded, the
    collected results are saved to ``results/domain_experiments.json``. A
    save failure is logged rather than raised, so the caller still receives
    the results.

    Args:
        runner: An initialized experiment runner; each entry of
            ``runner.model_configs`` must expose ``model_name``.

    Returns:
        dict: model name -> result of ``runner.run_domain_experiments`` for
        each model that succeeded (empty if none did).
    """
    logger = ExperimentLogger()
    logger.log("Starting domain experiments")

    domain_results = {}

    for model_config in runner.model_configs:
        model_name = model_config.model_name
        logger.log(f"Running domain experiments for {model_name}")

        try:
            domain_results[model_name] = runner.run_domain_experiments(model_config)
        except Exception as e:
            logger.log(f"Error in domain experiments for {model_name}: {e}", "ERROR")
            continue
        logger.log(f"Domain experiments completed for {model_name}")

    if domain_results:
        # Previously an exception here escaped and crashed the whole run
        # after every domain experiment had already been computed.
        try:
            save_results(domain_results, "results/domain_experiments.json")
        except Exception as e:
            logger.log(f"Error saving domain results: {e}", "ERROR")
        else:
            logger.log("Domain results saved")

    return domain_results

def generate_all_visualizations(results, ablation_results=None, domain_results=None):
    """Render every per-model plot and table, then the comprehensive report.

    For each model key in ``results`` this draws the training curves, the
    performance comparison, the reward accumulation and the alpha evolution,
    and prints the performance table. It then plots PPO statistics for each
    RL method present in that model's results, and draws the ablation plot
    when ``ablation_results`` has an entry for the model. Finally it calls
    ``generate_comprehensive_report`` with all three result sets.

    Args:
        results: Main experiment results keyed by model name.
        ablation_results: Optional ablation results keyed by model name.
        domain_results: Optional domain results; only forwarded to the report.
    """
    print("\nGenerating visualizations...")

    per_model_renderers = (
        plot_training_curves,
        plot_performance_comparison,
        plot_reward_accumulation,
        plot_alpha_evolution,
        print_performance_table,
    )
    # Only these methods carry PPO training statistics.
    ppo_methods = ('RLHF', 'RLAIF', 'Static_Hybrid', 'HRL')

    for name in results:
        print(f"\nVisualizations for {name}:")

        for render in per_model_renderers:
            render(results, name)

        available = results[name]
        for method in ppo_methods:
            if method in available:
                plot_ppo_training_stats(results, name, method)

        has_ablation = bool(ablation_results) and name in ablation_results
        if has_ablation:
            plot_ablation_study(ablation_results[name], name)

    generate_comprehensive_report(results, ablation_results, domain_results)

def _load_optional_results(path):
    """Load an auxiliary results file, returning ``None`` if it cannot be read.

    Ablation/domain result files only exist if those experiments were run, so
    failing to read one must not invalidate already-loaded main results.
    """
    try:
        return load_results(path)
    except Exception as e:
        print(f"Could not load optional results from {path}: {e}")
        return None


def main():
    """Command-line entry point for the HRL experiment pipeline.

    Modes:
      * ``--load-results PATH``: load main results from PATH and regenerate
        visualizations (unless ``--skip-viz``). Ablation/domain results are
        read from ``results/ablation_studies.json`` and
        ``results/domain_experiments.json`` when available; a missing or
        unreadable auxiliary file is reported and treated as absent. Only a
        failure to load PATH itself falls back to running new experiments.
      * otherwise: initialize the runner, run the main experiments, then the
        ablation and domain experiments (unless skipped), then visualizations.

    ``--quick-test`` forces 5 epochs, overriding ``--epochs``.

    Returns:
        int: exit status, 0 on success and 1 if runner initialization or the
        main experiments failed. Invalid arguments exit via ``parser.error``.
    """
    parser = argparse.ArgumentParser(description="Run HRL experiments")
    parser.add_argument('--epochs', type=int, default=20, help='Number of training epochs')
    parser.add_argument('--skip-ablation', action='store_true', help='Skip ablation studies')
    parser.add_argument('--skip-domain', action='store_true', help='Skip domain experiments')
    parser.add_argument('--skip-viz', action='store_true', help='Skip visualization generation')
    parser.add_argument('--load-results', type=str, help='Load existing results file')
    parser.add_argument('--quick-test', action='store_true', help='Run quick test with reduced epochs')

    args = parser.parse_args()

    if args.quick_test:
        args.epochs = 5
        print("Quick test mode: using 5 epochs")

    if args.epochs < 1:
        parser.error("--epochs must be a positive integer")

    if args.load_results:
        print(f"Loading existing results from {args.load_results}")
        try:
            results = load_results(args.load_results)
        except Exception as e:
            # Only a failure to read the main results triggers a fresh run;
            # visualization errors are no longer misreported as load errors.
            print(f"Error loading results: {e}")
            print("Proceeding with new experiments...")
        else:
            ablation_results = (
                None if args.skip_ablation
                else _load_optional_results("results/ablation_studies.json")
            )
            domain_results = (
                None if args.skip_domain
                else _load_optional_results("results/domain_experiments.json")
            )

            if not args.skip_viz:
                generate_all_visualizations(results, ablation_results, domain_results)

            return 0

    runner, success = setup_experiment()
    if not success:
        print("Failed to initialize experiment. Exiting.")
        return 1

    results = run_main_experiments(runner, args.epochs)
    if not results:
        print("Main experiments failed. Exiting.")
        return 1

    ablation_results = None
    if not args.skip_ablation:
        ablation_results = run_ablation_experiments(runner)

    domain_results = None
    if not args.skip_domain:
        domain_results = run_domain_experiments(runner)

    if not args.skip_viz:
        generate_all_visualizations(results, ablation_results, domain_results)

    print("\n" + "="*80)
    print("ALL EXPERIMENTS COMPLETED SUCCESSFULLY")
    print("="*80)
    print("Results saved in the 'results/' directory")
    print("Use --load-results to reload and re-visualize existing results")
    return 0

if __name__ == "__main__":
    # Use main()'s return value as the process exit status; SystemExit(None)
    # exits with status 0, so a main() that returns nothing still succeeds.
    raise SystemExit(main())