"""
RQ5: Mitigation of Teacher Mode Collapse

Scientific Question: Does the co-evolutionary dynamic resist premature convergence 
of the critic population to a single mode?

Compares critic population diversity in DATE-GFN vs standard EA.
"""

import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))

import yaml
import torch
import numpy as np
from shared import (
    DATEGFN, HypergridEnvironment, WandbLogger, set_seed,
    save_results, create_experiment_directory
)

class RQ5Experiment:
    """Population diversity analysis experiment (RQ5).

    Runs the evolutionary phase of ``DATEGFN`` twice -- once with a non-zero
    teachability weight ("DATE-GFN") and once with a zero weight
    ("Standard-EA") -- and compares the population diversity the method
    reports after every generation.
    """

    def __init__(self, config):
        """Store the configuration and select the compute device.

        Args:
            config: Mapping loaded from the base config file. This class
                reads ``config['model']['hidden_dim']``.
        """
        self.config = config
        # Prefer CUDA when it is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.wandb_logger = WandbLogger("date_gfn_rq5")

    def run_diversity_analysis(self, num_seeds=3, num_generations=20):
        """Run the population diversity analysis and save the summary.

        Args:
            num_seeds: Number of seeds (``0 .. num_seeds - 1``) evaluated per
                method. Defaults to 3 (reduced for time).
            num_generations: Number of evolutionary generations tracked per
                seed. Defaults to 20 (reduced for time).

        Returns:
            The analysis dict produced by :meth:`analyze_diversity_results`.

        Raises:
            ValueError: If ``num_seeds`` or ``num_generations`` is below 1;
                the summary statistics need at least one measurement.
        """
        if num_seeds < 1 or num_generations < 1:
            raise ValueError("num_seeds and num_generations must both be >= 1")

        print("RQ5: Teacher Mode Collapse Mitigation")
        print("Analyzing critic population diversity...")

        # Create environment. The path is relative to the current working
        # directory. A context manager guarantees the handle is closed.
        with open('configs/hypergrid_config.yaml', 'r') as config_file:
            env_config = yaml.safe_load(config_file)['environment']
        environment = HypergridEnvironment(
            height=env_config['height'],
            ndim=env_config['ndim'],
            reward_beta=env_config['reward_beta'],
            reward_at_corners=env_config['reward_at_corners']
        )

        # The two arms differ only in the teachability weight: the
        # "Standard-EA" baseline runs with no teachability penalty.
        teachability_weights = {'DATE-GFN': 0.1, 'Standard-EA': 0.0}

        diversity_results = {}
        for method_type, teachability_weight in teachability_weights.items():
            print(f"Testing {method_type}...")
            diversity_results[method_type] = [
                self._run_single_seed(
                    seed, env_config, environment,
                    teachability_weight, num_generations
                )
                for seed in range(num_seeds)
            ]

        # Analyze results
        analysis = self.analyze_diversity_results(diversity_results)

        exp_dir = create_experiment_directory("results", "RQ5_diversity")
        save_results(analysis, exp_dir / "rq5_analysis.json")

        print("✓ RQ5 analysis completed")
        return analysis

    def _run_single_seed(self, seed, env_config, environment,
                         teachability_weight, num_generations):
        """Track diversity over the generations of one seeded run.

        Returns:
            Dict with the per-generation ``diversity_history``, its last
            entry (``final_diversity``) and its mean (``avg_diversity``).
        """
        # Seed before building the method so its initialisation is covered.
        set_seed(seed)

        method = DATEGFN(
            state_dim=env_config['state_dim'],
            action_dim=env_config['action_dim'],
            hidden_dim=self.config['model']['hidden_dim'],
            teachability_weight=teachability_weight,
            device=self.device
        )

        diversity_history = []
        for _ in range(num_generations):
            method.evolutionary_phase(environment)
            diversity_history.append(method._calculate_population_diversity())

        return {
            'diversity_history': diversity_history,
            'final_diversity': diversity_history[-1],
            'avg_diversity': np.mean(diversity_history)
        }

    def analyze_diversity_results(self, results):
        """Summarise per-seed diversity measurements for each method.

        Args:
            results: Mapping from method name to a list of per-seed dicts,
                each holding ``'avg_diversity'`` and ``'final_diversity'``.
                Must contain the keys ``'DATE-GFN'`` and ``'Standard-EA'``.

        Returns:
            Dict with per-method ``key_findings``, the relative
            ``diversity_improvement`` of DATE-GFN over Standard-EA, and
            ``hypothesis_confirmed`` (improvement above 20%). Numeric and
            boolean values are builtin ``float``/``bool`` so the dict can be
            serialised by the standard ``json`` encoder.

        Raises:
            KeyError: If either expected method name is missing.
        """
        analysis = {
            'research_question': 'RQ5 - Teacher Mode Collapse Mitigation',
            'hypothesis': 'DATE-GFN maintains higher population diversity',
            'key_findings': {},
            'status': 'completed'
        }

        for method_type, seed_results in results.items():
            # float() strips the NumPy scalar wrapper returned by np.mean.
            avg_diversity = float(
                np.mean([r['avg_diversity'] for r in seed_results])
            )
            final_diversity = float(
                np.mean([r['final_diversity'] for r in seed_results])
            )

            analysis['key_findings'][method_type] = {
                'avg_diversity': avg_diversity,
                'final_diversity': final_diversity,
                'diversity_maintenance': 'high' if avg_diversity > 0.5 else 'low'
            }

        # Compare DATE-GFN vs Standard EA
        date_gfn_diversity = analysis['key_findings']['DATE-GFN']['avg_diversity']
        standard_ea_diversity = analysis['key_findings']['Standard-EA']['avg_diversity']

        # The small epsilon avoids division by zero when the baseline
        # diversity is exactly 0.
        improvement = (date_gfn_diversity - standard_ea_diversity) / (standard_ea_diversity + 1e-8)

        analysis['diversity_improvement'] = float(improvement)
        # bool() matters: a numpy.bool_ is rejected by the json encoder.
        analysis['hypothesis_confirmed'] = bool(improvement > 0.2)  # >20% improvement

        return analysis

def main():
    """Load the base configuration and run the RQ5 diversity experiment.

    Reads ``configs/base_config.yaml`` (relative to the current working
    directory) and returns whatever ``run_diversity_analysis`` returns.
    """
    config_path = 'configs/base_config.yaml'
    with open(config_path, 'r') as config_file:
        base_config = yaml.safe_load(config_file)

    return RQ5Experiment(base_config).run_diversity_analysis()

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
