"""
RQ6: Scalability and Efficiency Analysis

Scientific Question: How does DATE-GFN's time-to-solution scale with increasing 
problem difficulty compared to baselines?

Evaluated on three Hypergrid configurations: Easy (H=10, D=3), Medium (H=20, D=4), and Hard (H=30, D=5), where H is the grid height and D the number of dimensions.
"""

import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))

import yaml
import torch
import numpy as np
import time
from shared import (
    DATEGFN, GFNBaseline, HypergridEnvironment, WandbLogger, set_seed,
    save_results, create_experiment_directory
)

class RQ6Experiment:
    """Scalability analysis experiment.

    For each Hypergrid difficulty level, trains DATE-GFN and the GFN-TB
    baseline from several seeds. Each run records the wall-clock time it needs
    to reach the target mode coverage, or ``float('inf')`` if it never does.
    The per-difficulty mean times are then compared (hard vs. easy) to
    estimate how each method's time-to-solution grows with problem size.
    """

    # Hypergrid settings per difficulty level.
    DIFFICULTY_CONFIGS = {
        'easy': {'height': 10, 'ndim': 3, 'state_dim': 3, 'action_dim': 7},
        'medium': {'height': 20, 'ndim': 4, 'state_dim': 4, 'action_dim': 9},
        'hard': {'height': 30, 'ndim': 5, 'state_dim': 5, 'action_dim': 11},
    }
    # Ordered from easiest to hardest; the scaling factor is hard / easy.
    DIFFICULTIES = ('easy', 'medium', 'hard')
    METHODS = ('DATE-GFN', 'GFN-TB')

    NUM_SEEDS = 3             # reduced for time
    MAX_STEPS = 1000          # reduced for time
    EVAL_INTERVAL = 50        # training steps between coverage checks
    EVAL_SAMPLES = 50         # trajectories sampled per coverage check
    COVERAGE_TARGET = 0.8     # a run is "solved" at >= 80% mode coverage
    SCALES_WELL_FACTOR = 10   # hard/easy time ratio below which a method "scales well"
    LEARNING_RATE = 1e-3

    def __init__(self, config):
        """Store the experiment config and pick the compute device.

        Args:
            config: Parsed experiment config; ``config['model']['hidden_dim']``
                is read when building the models.
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): created here but not used by any method of this class.
        self.wandb_logger = WandbLogger("date_gfn_rq6")

    def create_environment(self, difficulty):
        """Create the Hypergrid environment for ``difficulty``.

        Args:
            difficulty: One of the keys of ``DIFFICULTY_CONFIGS``.

        Returns:
            ``(environment, config)`` where ``config`` is a fresh copy of the
            difficulty's settings (height, ndim, state_dim, action_dim), so
            callers may modify it without affecting the shared table.

        Raises:
            KeyError: If ``difficulty`` is not a known level.
        """
        config = dict(self.DIFFICULTY_CONFIGS[difficulty])
        environment = HypergridEnvironment(
            height=config['height'],
            ndim=config['ndim'],
            reward_beta=10.0,
            reward_at_corners=1e-5
        )
        return environment, config

    def run_scalability_test(self):
        """Run scalability analysis across difficulty levels.

        Trains every method in ``METHODS`` on every level in ``DIFFICULTIES``
        for ``NUM_SEEDS`` seeds. It then analyses the resulting
        times-to-solution and saves the analysis to
        ``results/RQ6_scalability/rq6_analysis.json``.

        Returns:
            The analysis dict produced by ``analyze_scaling_results``.
        """
        print("RQ6: Scalability Analysis")
        print("Testing scaling across problem difficulties...")

        results = {}
        for difficulty in self.DIFFICULTIES:
            print(f"\nTesting {difficulty} difficulty...")
            environment, env_config = self.create_environment(difficulty)

            difficulty_results = {}
            for method_name in self.METHODS:
                print(f"  {method_name}...")
                times_to_solution = []
                for seed in range(self.NUM_SEEDS):
                    # Seed before building the model so weight init is reproducible.
                    set_seed(seed)
                    method, train_step = self._build_method(method_name, env_config)
                    times_to_solution.append(
                        self._measure_time_to_solution(method, train_step, environment)
                    )
                difficulty_results[method_name] = self._summarize_times(times_to_solution)

            results[difficulty] = difficulty_results

        analysis = self.analyze_scaling_results(results)

        exp_dir = create_experiment_directory("results", "RQ6_scalability")
        # NOTE(review): the analysis may contain float('inf') (unsolved runs);
        # confirm save_results serializes it as intended.
        save_results(analysis, exp_dir / "rq6_analysis.json")

        print("✓ RQ6 analysis completed")
        return analysis

    def _build_method(self, method_name, env_config):
        """Instantiate ``method_name`` together with its optimizer(s).

        Returns:
            ``(method, train_step)`` where ``train_step(environment)`` runs one
            training step using the optimizer(s) created here.

        Raises:
            ValueError: If ``method_name`` is not one of ``METHODS``.
        """
        hidden_dim = self.config['model']['hidden_dim']

        if method_name == 'DATE-GFN':
            method = DATEGFN(
                state_dim=env_config['state_dim'],
                action_dim=env_config['action_dim'],
                hidden_dim=hidden_dim,
                teachability_weight=0.1,
                device=self.device
            )
            # DATE-GFN's train_step takes separate forward/backward optimizers.
            optimizer_f = torch.optim.Adam(method.forward_policy.parameters(), lr=self.LEARNING_RATE)
            optimizer_b = torch.optim.Adam(method.backward_policy.parameters(), lr=self.LEARNING_RATE)
            return method, lambda env: method.train_step(env, optimizer_f, optimizer_b)

        if method_name == 'GFN-TB':
            method = GFNBaseline(
                state_dim=env_config['state_dim'],
                action_dim=env_config['action_dim'],
                hidden_dim=hidden_dim
            )
            # The baseline trains all three networks with a single optimizer.
            optimizer = torch.optim.Adam(
                list(method.forward_policy.parameters()) +
                list(method.backward_policy.parameters()) +
                list(method.state_value.parameters()), lr=self.LEARNING_RATE
            )
            return method, lambda env: method.train_step(env, optimizer)

        raise ValueError(f"Unknown method: {method_name}")

    def _synchronize(self):
        """Wait for queued CUDA work so wall-clock readings include it."""
        if self.device.type == 'cuda':
            torch.cuda.synchronize(self.device)

    def _measure_time_to_solution(self, method, train_step, environment):
        """Train until mode coverage reaches ``COVERAGE_TARGET``.

        Coverage is checked every ``EVAL_INTERVAL`` steps (starting after the
        first step). The measured time includes those evaluations.

        Returns:
            Elapsed seconds until the target was reached, or ``float('inf')``
            if it was not reached within ``MAX_STEPS`` steps.
        """
        self._synchronize()
        start_time = time.perf_counter()

        for step in range(self.MAX_STEPS):
            train_step(environment)

            if step % self.EVAL_INTERVAL == 0:
                trajectories = method.sample(environment, self.EVAL_SAMPLES)
                coverage, _ = environment.calculate_mode_coverage(trajectories)
                if coverage >= self.COVERAGE_TARGET:
                    self._synchronize()
                    return time.perf_counter() - start_time

        return float('inf')  # did not converge

    @staticmethod
    def _summarize_times(times_to_solution):
        """Summarize per-seed times-to-solution for one method/difficulty.

        ``mean_time`` averages only the solved (finite) runs, and is
        ``float('inf')`` when no run was solved. Previously ``np.mean([])``
        silently produced NaN in that case.
        """
        solved = [t for t in times_to_solution if np.isfinite(t)]
        return {
            'times_to_solution': times_to_solution,
            'mean_time': float(np.mean(solved)) if solved else float('inf'),
            'success_rate': len(solved) / len(times_to_solution) if times_to_solution else 0.0,
        }

    def analyze_scaling_results(self, results):
        """Analyze scaling behavior.

        Args:
            results: ``{difficulty: {method_name: {'mean_time': float, ...}}}``
                for every difficulty in ``DIFFICULTIES`` and method in
                ``METHODS``.

        Returns:
            A dict with the raw results and per-method ``scaling_analysis``:
            times by difficulty (``None`` for unsolved), ``scaling_factor``
            (hard / easy time, ``inf`` if either is missing), and
            ``scales_well``. It also contains a DATE-GFN vs. GFN-TB comparison.
        """
        analysis = {
            'research_question': 'RQ6 - Scalability Analysis',
            'hypothesis': 'DATE-GFN scales better than baselines',
            'scaling_results': results,
            'scaling_analysis': {},
            'status': 'completed'
        }

        difficulties = list(self.DIFFICULTIES)
        for method_name in self.METHODS:
            # Non-finite means (inf = never solved, or NaN) count as missing.
            times = []
            for difficulty in difficulties:
                mean_time = results[difficulty][method_name]['mean_time']
                times.append(float(mean_time) if np.isfinite(mean_time) else None)

            easy_time, hard_time = times[0], times[-1]
            if easy_time is not None and hard_time is not None and easy_time > 0:
                scaling_factor = hard_time / easy_time  # hard / easy
            else:
                scaling_factor = float('inf')

            analysis['scaling_analysis'][method_name] = {
                'times_by_difficulty': dict(zip(difficulties, times)),
                'scaling_factor': scaling_factor,
                'scales_well': scaling_factor < self.SCALES_WELL_FACTOR
            }

        date_gfn_scaling = analysis['scaling_analysis']['DATE-GFN']['scaling_factor']
        gfn_tb_scaling = analysis['scaling_analysis']['GFN-TB']['scaling_factor']

        analysis['date_gfn_better_scaling'] = date_gfn_scaling < gfn_tb_scaling
        # Relative improvement over the baseline; 0 when the baseline factor is
        # undefined (never solved) or zero.
        if np.isfinite(gfn_tb_scaling) and gfn_tb_scaling > 0:
            analysis['scaling_improvement'] = (gfn_tb_scaling - date_gfn_scaling) / gfn_tb_scaling
        else:
            analysis['scaling_improvement'] = 0

        return analysis

def main(config_path='configs/base_config.yaml'):
    """Load the experiment config and run the RQ6 scalability analysis.

    Args:
        config_path: Path to the YAML config file. The default is resolved
            relative to the current working directory, not to this script.

    Returns:
        The analysis dict returned by ``RQ6Experiment.run_scalability_test``.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    experiment = RQ6Experiment(config)
    return experiment.run_scalability_test()


if __name__ == "__main__":
    main()
