#!/usr/bin/env python3
"""
Hypergrid Comparison for DATE-GFN

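This script compares DATE-GFN against the TB-GFN and EGFN GFlowNet baselines on
several Hypergrid configurations, reporting mode coverage, diversity, L1 error,
reward statistics and training time for each method.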
"""

import os
import random
import sys
import time
from pathlib import Path

import numpy as np
import torch

# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))

from date_gfn.core.date_gfn import DATEGFN
from date_gfn.baselines.baselines import GFNBaseline, EGFNBaseline
from date_gfn.environments.environments import HypergridEnvironment


def set_seed(seed: int) -> None:
    """Seed every global RNG the compared methods may draw from.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU
    and, when CUDA is available, every visible GPU) so that repeated runs
    with the same ``seed`` are reproducible.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Python's own RNG was previously left unseeded; any component drawing
    # from ``random`` would make runs non-reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Explicitly seed all GPUs, not only the current device.
        torch.cuda.manual_seed_all(seed)


def evaluate_method(method, env, num_samples=200):
    """Sample trajectories from ``method`` and score them on ``env``.

    Args:
        method: Object exposing ``sample(env, num_samples=...)`` that returns
            a sequence of trajectories.
        env: Environment providing ``calculate_mode_coverage``,
            ``calculate_diversity``, ``calculate_l1_error`` and ``get_reward``.
        num_samples: Number of trajectories requested from ``method``.

    Returns:
        dict with keys ``mode_coverage``, ``num_modes_found``, ``diversity``,
        ``l1_error``, ``mean_reward``, ``max_reward`` and
        ``num_trajectories``.
    """
    samples = method.sample(env, num_samples=num_samples)

    # calculate_mode_coverage yields a (coverage, modes-found) pair.
    mode_coverage, modes_found = env.calculate_mode_coverage(samples)
    metrics = {
        'mode_coverage': mode_coverage,
        'num_modes_found': modes_found,
        'diversity': env.calculate_diversity(samples),
        'l1_error': env.calculate_l1_error(samples),
    }

    # Each trajectory is scored by the reward of its last element.
    final_rewards = [env.get_reward(path[-1]) for path in samples]
    metrics['mean_reward'] = np.mean(final_rewards)
    metrics['max_reward'] = np.max(final_rewards)
    metrics['num_trajectories'] = len(samples)
    return metrics


def _make_train_step(method, env, lr):
    """Return a zero-argument callable that performs one training step of ``method``."""
    if isinstance(method, GFNBaseline):
        params = method.parameters()
    elif isinstance(method, EGFNBaseline):
        # For EGFN the optimizer covers only the student network's parameters.
        params = method.student_gfn.parameters()
    elif hasattr(method, 'evolutionary_phase') and hasattr(method, 'distillation_phase'):
        # DATE-GFN path: an evolutionary phase followed by a distillation phase.
        def date_gfn_step():
            method.evolutionary_phase(env)
            method.distillation_phase(env)
        return date_gfn_step
    else:
        raise TypeError(
            f"Don't know how to train {type(method).__name__}: expected a "
            "GFNBaseline, an EGFNBaseline, or an object with "
            "evolutionary_phase/distillation_phase methods"
        )

    # Build the optimizer ONCE so Adam's moment estimates persist across
    # steps; recreating it every step resets its state and degrades the
    # update to roughly lr * sign(grad).
    optimizer = torch.optim.Adam(params, lr=lr)

    def baseline_step():
        method.train_step(env, optimizer)
    return baseline_step


def train_method(method, env, num_steps=500, lr=5e-4, log_every=100):
    """Train ``method`` on ``env`` for ``num_steps`` iterations.

    Baselines (``GFNBaseline`` / ``EGFNBaseline``) are trained through
    ``method.train_step(env, optimizer)`` with a single Adam optimizer that
    is created once and reused for every step. Other methods exposing
    ``evolutionary_phase`` and ``distillation_phase`` (DATE-GFN) run both
    phases once per step.

    Args:
        method: Method instance to train.
        env: Environment passed to every training call.
        num_steps: Number of training iterations.
        lr: Adam learning rate used for the baselines.
        log_every: Print a progress line every ``log_every`` steps; a value
            of 0 or less disables progress lines.

    Raises:
        TypeError: If ``method`` is neither a known baseline nor exposes the
            DATE-GFN phase methods.
    """
    print(f"Training {method.__class__.__name__} for {num_steps} steps...")

    step_fn = _make_train_step(method, env, lr)

    for step in range(num_steps):
        step_fn()

        if log_every > 0 and (step + 1) % log_every == 0:
            print(f"  Step {step + 1}/{num_steps}")


def _build_method(method_config, state_dim, action_dim):
    """Instantiate the method described by one ``methods_config`` entry.

    The entry's ``"class"`` is called with ``state_dim``, ``action_dim`` and
    the entry's ``"kwargs"``. If ``"init_population"`` is true, the new
    instance's ``initialize_population()`` is called before it is returned.
    """
    method = method_config["class"](
        state_dim=state_dim,
        action_dim=action_dim,
        **method_config["kwargs"],
    )
    if method_config.get("init_population", False):
        method.initialize_population()
    return method


def _print_method_results(results, total_modes):
    """Print the metrics of one method right after it has been evaluated."""
    print("Results:")
    print(f"  Mode Coverage: {results['mode_coverage']:.3f}")
    print(f"  Modes Found: {results['num_modes_found']}/{total_modes}")
    print(f"  Diversity: {results['diversity']:.3f}")
    print(f"  L1 Error: {results['l1_error']:.3f}")
    print(f"  Mean Reward: {results['mean_reward']:.6f}")
    print(f"  Max Reward: {results['max_reward']:.6f}")
    print(f"  Training Time: {results['training_time']:.1f}s")


def _print_summary(all_results):
    """Print one fixed-width comparison table per environment."""
    print("\n" + "=" * 60)
    print("📊 SUMMARY COMPARISON")
    print("=" * 60)

    header = (f"{'Method':<12} {'Coverage':<10} {'Diversity':<10} "
              f"{'L1_Error':<10} {'Time(s)':<8}")
    for env_name, env_results in all_results.items():
        print(f"\n{env_name}:")
        print(header)
        # Underline spans the full header width.
        print("-" * len(header))

        for method_name, results in env_results.items():
            print(f"{method_name:<12} {results['mode_coverage']:<10.3f} "
                  f"{results['diversity']:<10.3f} {results['l1_error']:<10.3f} "
                  f"{results['training_time']:<8.1f}")


def _print_ranking(all_results, method_names):
    """Rank methods by mean mode coverage across environments and print them."""
    print(f"\n{'=' * 60}")
    print("🏆 OVERALL PERFORMANCE RANKING")
    print("=" * 60)

    method_averages = {}
    for method_name in method_names:
        per_env = [env_results[method_name]
                   for env_results in all_results.values()
                   if method_name in env_results]
        if not per_env:
            # Nothing to average; np.mean([]) would give nan plus a warning.
            continue
        method_averages[method_name] = {
            'avg_coverage': np.mean([r['mode_coverage'] for r in per_env]),
            'avg_diversity': np.mean([r['diversity'] for r in per_env]),
            # 1 - L1 error turns the error into a score where higher is better.
            'avg_l1_score': np.mean([1 - r['l1_error'] for r in per_env]),
        }

    # Mode coverage is the primary ranking metric.
    ranked = sorted(method_averages.items(),
                    key=lambda item: item[1]['avg_coverage'], reverse=True)
    for rank, (method_name, scores) in enumerate(ranked, 1):
        print(f"{rank}. {method_name}: Coverage={scores['avg_coverage']:.3f}, "
              f"Diversity={scores['avg_diversity']:.3f}, "
              f"L1_Score={scores['avg_l1_score']:.3f}")


def run_comparison(seed=42, num_steps=300, num_samples=200):
    """Train and evaluate every method on every Hypergrid configuration.

    For each environment configuration, each method is built, trained with
    ``train_method`` (wall-clock timed), evaluated with ``evaluate_method``
    and its metrics are printed. A per-environment summary table and an
    overall ranking by mean mode coverage are printed at the end.

    Args:
        seed: Seed passed to ``set_seed`` at start-up and again before each
            method is built, so every method starts from the same RNG state
            regardless of the order in which methods are run.
        num_steps: Training steps per method.
        num_samples: Trajectories sampled per method for evaluation.

    Returns:
        dict mapping environment name -> {method name -> metrics dict}.
    """
    print("=" * 60)
    print("🔬 Hypergrid Method Comparison")
    print("=" * 60)

    # Set seed for reproducibility
    set_seed(seed)

    # Environment configurations to benchmark.
    configs = [
        {"height": 8, "ndim": 2, "name": "8x8^2"},
        {"height": 6, "ndim": 3, "name": "6x6^3"},
        {"height": 5, "ndim": 4, "name": "5x5^4"},
    ]

    # "kwargs" are passed to "class" alongside state_dim/action_dim;
    # "init_population" requests an initialize_population() call before
    # training. "color" is not used by this script.
    methods_config = {
        "TB-GFN": {"class": GFNBaseline, "color": "blue",
                   "kwargs": {"hidden_dim": 256}},
        "EGFN": {"class": EGFNBaseline, "color": "green",
                 "kwargs": {"hidden_dim": 256, "population_size": 50}},
        "DATE-GFN": {"class": DATEGFN, "color": "red",
                     "kwargs": {"hidden_dim": 256, "population_size": 50,
                                "teachability_weight": 0.1},
                     "init_population": True},
    }

    all_results = {}

    for config in configs:
        env_name = config["name"]
        print(f"\n{'=' * 20} Environment: {env_name} {'=' * 20}")

        env = HypergridEnvironment(
            height=config["height"],
            ndim=config["ndim"],
            reward_beta=10.0,
            reward_at_corners=1e-4,
        )
        total_modes = len(env.corner_states)
        print(f"Total modes: {total_modes}")

        state_dim = env.ndim
        # NOTE(review): presumably +/- one step along each axis plus a stop
        # action — confirm against HypergridEnvironment.
        action_dim = 2 * env.ndim + 1

        env_results = {}
        for method_name, method_config in methods_config.items():
            print(f"\n--- {method_name} ---")

            # Reseed per method so its initialisation, training and sampling
            # do not depend on RNG draws consumed by the methods run before it.
            set_seed(seed)
            method = _build_method(method_config, state_dim, action_dim)

            # perf_counter is monotonic, unlike time.time().
            start_time = time.perf_counter()
            train_method(method, env, num_steps=num_steps)
            training_time = time.perf_counter() - start_time

            results = evaluate_method(method, env, num_samples=num_samples)
            results['training_time'] = training_time
            env_results[method_name] = results

            _print_method_results(results, total_modes)

        all_results[env_name] = env_results

    _print_summary(all_results)
    _print_ranking(all_results, methods_config.keys())

    print("\n✅ Comparison completed successfully!")
    return all_results


if __name__ == "__main__":
    # Run the full comparison only when executed as a script, not on import.
    run_comparison()
