#!/usr/bin/env python3
"""
DATE-GFN Hypergrid Experiment

This script trains DATE-GFN and the TB-GFN and EGFN baselines on the Hypergrid
environment and compares their mode coverage, diversity, L1 error, and best reward.
"""

import sys
import os
import argparse
import torch
import numpy as np
import time
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))

from date_gfn.core.date_gfn import DATEGFN
from date_gfn.baselines.baselines import GFNBaseline, EGFNBaseline
from date_gfn.environments.environments import HypergridEnvironment


def set_seed(seed: int):
    """Seed the NumPy and torch RNGs (plus CUDA when present) for repeatable runs.

    Args:
        seed: Integer seed applied to every generator.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)


def _make_train_fn(method, lr):
    """Build a callable ``train(env)`` that runs one training step of *method*.

    The optimizer is created here exactly once rather than inside the training
    loop. Re-creating ``torch.optim.Adam`` every step discards its running
    moment estimates, so every update would behave like Adam's first step.

    Args:
        method: A ``GFNBaseline``, an ``EGFNBaseline``, or a DATE-GFN-style
            object exposing ``evolutionary_phase`` and ``distillation_phase``.
        lr: Adam learning rate used for the gradient-trained baselines.

    Returns:
        A function that takes the environment and performs one training step.

    Raises:
        TypeError: If *method* matches none of the supported interfaces.
    """
    if isinstance(method, GFNBaseline):
        gfn_optimizer = torch.optim.Adam(method.parameters(), lr=lr)

        def _train_gfn(env):
            method.train_step(env, gfn_optimizer)

        return _train_gfn

    if isinstance(method, EGFNBaseline):
        # For EGFN the optimizer covers only the student GFlowNet's parameters.
        egfn_optimizer = torch.optim.Adam(method.student_gfn.parameters(), lr=lr)

        def _train_egfn(env):
            method.train_step(env, egfn_optimizer)

        return _train_egfn

    if hasattr(method, 'evolutionary_phase') and hasattr(method, 'distillation_phase'):
        # DATE-GFN: evolutionary search followed by distillation each step.
        def _train_date_gfn(env):
            method.evolutionary_phase(env)
            method.distillation_phase(env)

        return _train_date_gfn

    raise TypeError(
        f"Unsupported method type {type(method).__name__}: expected GFNBaseline, "
        "EGFNBaseline, or an object with evolutionary_phase/distillation_phase"
    )


def _evaluate(method, env, num_samples):
    """Sample trajectories from *method* and compute the evaluation metrics.

    Args:
        method: Trained model exposing ``sample(env, num_samples=...)``.
        env: Environment providing the metric helpers used below.
        num_samples: Number of trajectories to sample.

    Returns:
        Tuple ``(coverage, diversity, l1_error, best_reward)``. ``best_reward``
        is the max reward over the last element of each trajectory, or 0.0
        when no trajectories were returned.
    """
    trajectories = method.sample(env, num_samples=num_samples)

    coverage, _num_modes = env.calculate_mode_coverage(trajectories)
    diversity = env.calculate_diversity(trajectories)
    l1_error = env.calculate_l1_error(trajectories)

    # NOTE(review): assumes traj[-1] is the terminal state of each trajectory.
    rewards = [env.get_reward(traj[-1]) for traj in trajectories]
    best_reward = max(rewards) if rewards else 0.0

    return coverage, diversity, l1_error, best_reward


def run_experiment(method, env, num_steps=1000, eval_interval=100, lr=5e-4,
                   num_eval_samples=100):
    """Train *method* on *env*, evaluating it periodically.

    Args:
        method: Model to train. See ``_make_train_fn`` for the supported types.
        env: Environment providing ``calculate_mode_coverage``,
            ``calculate_diversity``, ``calculate_l1_error`` and ``get_reward``.
        num_steps: Number of training steps.
        eval_interval: Evaluate every this many steps; must be positive. The
            final step is always evaluated too, so the last recorded metrics
            reflect the fully trained model.
        lr: Adam learning rate for the gradient-trained baselines.
        num_eval_samples: Number of trajectories sampled per evaluation.

    Returns:
        Dict of lists keyed by 'steps', 'mode_coverage', 'diversity',
        'l1_error' and 'best_reward', with one entry per evaluation.

    Raises:
        ValueError: If ``eval_interval`` is not positive.
        TypeError: If *method* is not a supported model type.
    """
    if eval_interval <= 0:
        raise ValueError(f"eval_interval must be positive, got {eval_interval}")

    print(f"\n🔥 Training {method.__class__.__name__}...")

    results = {
        'steps': [],
        'mode_coverage': [],
        'diversity': [],
        'l1_error': [],
        'best_reward': []
    }

    train = _make_train_fn(method, lr)
    start_time = time.perf_counter()

    for step in range(num_steps):
        train(env)

        # Also evaluate the last step so "final" metrics are post-training.
        if step % eval_interval == 0 or step == num_steps - 1:
            coverage, diversity, l1_error, best_reward = _evaluate(
                method, env, num_eval_samples
            )

            results['steps'].append(step)
            results['mode_coverage'].append(coverage)
            results['diversity'].append(diversity)
            results['l1_error'].append(l1_error)
            results['best_reward'].append(best_reward)

            print(f"Step {step:4d}: Coverage={coverage:.3f}, Diversity={diversity:.3f}, "
                  f"L1_Error={l1_error:.3f}, Best_Reward={best_reward:.6f}")

    elapsed = time.perf_counter() - start_time
    print(f"✓ Completed in {elapsed:.1f}s")

    return results


def _build_method(method_name, state_dim, action_dim):
    """Instantiate the model selected by *method_name* with this experiment's settings.

    Raises:
        ValueError: If *method_name* is not one of the supported methods.
    """
    if method_name == "DATE-GFN":
        method = DATEGFN(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dim=256,
            population_size=50,
            teachability_weight=0.1
        )
        method.initialize_population()
        return method

    if method_name == "TB-GFN":
        return GFNBaseline(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dim=256
        )

    if method_name == "EGFN":
        return EGFNBaseline(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dim=256,
            population_size=50
        )

    raise ValueError(f"Unknown method: {method_name}")


def _print_summary(all_results):
    """Print the last recorded metrics for every method that was evaluated."""
    print("\n" + "="*60)
    print("📊 FINAL RESULTS COMPARISON")
    print("="*60)

    for method_name, results in all_results.items():
        if not results['steps']:
            continue  # never evaluated (e.g. --steps 0)

        final_coverage = results['mode_coverage'][-1]
        final_diversity = results['diversity'][-1]
        final_l1_error = results['l1_error'][-1]
        final_reward = results['best_reward'][-1]

        print(f"{method_name:12s}: Coverage={final_coverage:.3f}, "
              f"Diversity={final_diversity:.3f}, L1_Error={final_l1_error:.3f}, "
              f"Best_Reward={final_reward:.6f}")


def main():
    """Parse CLI arguments, train each requested method on a Hypergrid,
    and print a comparison of their final metrics."""
    parser = argparse.ArgumentParser(description="DATE-GFN Hypergrid Experiment")
    parser.add_argument("--steps", type=int, default=1000, help="Number of training steps")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--height", type=int, default=10, help="Hypergrid height")
    parser.add_argument("--ndim", type=int, default=4, help="Hypergrid dimensions")
    parser.add_argument("--eval-interval", type=int, default=100,
                        help="Evaluate every N training steps")
    parser.add_argument("--methods", nargs="+", default=["DATE-GFN", "TB-GFN", "EGFN"],
                        choices=["DATE-GFN", "TB-GFN", "EGFN"],
                        help="Methods to compare")

    args = parser.parse_args()
    if args.eval_interval <= 0:
        parser.error("--eval-interval must be a positive integer")

    # De-duplicate while preserving order so a repeated name isn't retrained
    # (its results would overwrite the earlier run in all_results anyway).
    methods = list(dict.fromkeys(args.methods))

    # Seed before building the environment in case its construction is random.
    set_seed(args.seed)

    print("=" * 60)
    print("DATE-GFN Hypergrid Experiment")
    print("=" * 60)
    print(f"Environment: Hypergrid {args.height}^{args.ndim}")
    print(f"Training steps: {args.steps}")
    print(f"Methods: {', '.join(methods)}")
    print(f"Random seed: {args.seed}")
    print("=" * 60)

    env = HypergridEnvironment(
        height=args.height,
        ndim=args.ndim,
        reward_beta=10.0,
        reward_at_corners=1e-4
    )

    state_dim = env.ndim
    action_dim = 2 * env.ndim + 1  # movements + terminate

    all_results = {}

    for method_name in methods:
        print(f"\n{'='*20} {method_name} {'='*20}")

        # Re-seed per method so each one starts from the same RNG state,
        # regardless of which methods ran before it or the --methods order.
        set_seed(args.seed)
        method = _build_method(method_name, state_dim, action_dim)

        all_results[method_name] = run_experiment(
            method, env, args.steps, eval_interval=args.eval_interval
        )

    _print_summary(all_results)

    print("\n✅ Experiment completed successfully!")


if __name__ == "__main__":
    main()
