#!/usr/bin/env python3
"""
Integration test for the GFlowNet implementation.
"""

import torch
import numpy as np
import sys
import os

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from date_gfn.core.gfn_base import GFlowNetBase, GFlowNetPolicy
from date_gfn.core.date_gfn import DATEGFN
from date_gfn.baselines.baselines import GFNBaseline, EGFNBaseline
from date_gfn.environments.environments import HypergridEnvironment


def test_gfn_base():
    """Smoke-test the core GFlowNetBase API on a 2-D hypergrid of height 5.

    Exercises, in order: model construction, trajectory sampling, the
    trajectory-balance loss, and masked action probabilities for the
    environment's reset state. Progress is reported with ``print``; any
    exception raised by the code under test propagates to the caller.
    """
    print("Testing GFlowNetBase...")

    # Create environment
    env = HypergridEnvironment(height=5, ndim=2)
    state_dim = env.ndim
    action_dim = 2 * env.ndim + 1  # movements + terminate

    # Create GFlowNet
    gfn = GFlowNetBase(state_dim, action_dim, hidden_dim=64)
    print(f"✓ Created GFlowNetBase with state_dim={state_dim}, action_dim={action_dim}")

    # Test sampling
    trajectories = gfn.sample(env, num_samples=5)
    print(f"✓ Sampled {len(trajectories)} trajectories")

    # Test trajectory balance loss
    loss = gfn.calculate_trajectory_balance_loss(trajectories, env)
    print(f"✓ Calculated TB loss: {loss.item():.4f}")

    # Test action probabilities.
    # Reset exactly once and derive both the state tensor and the action mask
    # from that single observation, so the mask is guaranteed to describe the
    # same state that is fed to the policy (and the env is not reset twice).
    initial_state = env.reset()
    state = torch.tensor(initial_state, dtype=torch.float32)
    valid_actions = env.get_valid_actions(initial_state)

    # Boolean mask over the full action space: True marks a valid action.
    action_mask = torch.zeros(action_dim, dtype=torch.bool)
    action_mask[valid_actions] = True

    action_probs = gfn.get_action_probs(state, action_mask)
    print(f"✓ Got action probabilities: {action_probs.shape}")

    print("GFlowNetBase test passed!\n")


def test_date_gfn():
    """Drive DATE-GFN through each of its phases once on a 2-D hypergrid.

    Builds the model, initializes its critic population, runs one
    evolutionary phase and one distillation phase, then samples a few
    trajectories. A progress line is printed after every step; exceptions
    from the code under test are not caught here.
    """
    print("Testing DATE-GFN...")

    # Environment and the dimensions derived from it
    grid = HypergridEnvironment(height=5, ndim=2)
    n_state = grid.ndim
    n_actions = 2 * grid.ndim + 1

    # Model construction
    model_config = {
        "state_dim": n_state,
        "action_dim": n_actions,
        "hidden_dim": 64,
        "population_size": 50,
        "teachability_weight": 0.1,
    }
    model = DATEGFN(**model_config)
    print("✓ Created DATE-GFN")

    # Population setup
    model.initialize_population()
    population_size = len(model.critic_population)
    print(f"✓ Initialized population of size {population_size}")

    # Evolutionary phase
    model.evolutionary_phase(grid)
    print("✓ Completed evolutionary phase")

    # Distillation phase
    model.distillation_phase(grid)
    print("✓ Completed distillation phase")

    # Sampling
    sampled = model.sample(grid, num_samples=3)
    print(f"✓ Sampled {len(sampled)} trajectories")

    print("DATE-GFN test passed!\n")


def test_baselines():
    """Run one training step and one sampling call for each baseline.

    Covers ``GFNBaseline`` (optimizer over the baseline's own parameters)
    and ``EGFNBaseline`` (optimizer over its ``student_gfn`` parameters),
    both on a 2-D hypergrid of height 5. Results are printed; exceptions
    from the code under test propagate to the caller.
    """
    print("Testing Baselines...")

    # Shared environment and dimensions
    grid = HypergridEnvironment(height=5, ndim=2)
    n_state = grid.ndim
    n_actions = 2 * grid.ndim + 1

    # --- GFNBaseline ---
    plain_gfn = GFNBaseline(n_state, n_actions, hidden_dim=64)
    plain_opt = torch.optim.Adam(plain_gfn.parameters(), lr=1e-3)

    plain_metrics = plain_gfn.train_step(grid, plain_opt, batch_size=4)
    print(f"✓ GFNBaseline training step: {plain_metrics}")

    plain_samples = plain_gfn.sample(grid, num_samples=3)
    print(f"✓ GFNBaseline sampled {len(plain_samples)} trajectories")

    # --- EGFNBaseline ---
    evo_gfn = EGFNBaseline(n_state, n_actions, hidden_dim=64, population_size=50)
    # Only the student network's parameters are handed to the optimizer.
    evo_opt = torch.optim.Adam(evo_gfn.student_gfn.parameters(), lr=1e-3)

    evo_metrics = evo_gfn.train_step(grid, evo_opt)
    print(f"✓ EGFNBaseline training step: {evo_metrics}")

    evo_samples = evo_gfn.sample(grid, num_samples=3)
    print(f"✓ EGFNBaseline sampled {len(evo_samples)} trajectories")

    print("Baselines test passed!\n")


def test_environment_metrics():
    """Exercise the hypergrid environment's evaluation metrics.

    Samples a batch of trajectories directly from the environment and
    passes it to the mode-coverage, diversity, and L1-error calculations,
    printing each result. Exceptions from the environment propagate.
    """
    print("Testing Environment Metrics...")

    grid = HypergridEnvironment(height=5, ndim=2)

    # Trajectories to feed into every metric below
    batch = grid.sample_batch(batch_size=10, max_length=20)
    print(f"✓ Generated {len(batch)} test trajectories")

    # Mode coverage returns a (coverage, number-of-modes) pair
    mode_fraction, modes_found = grid.calculate_mode_coverage(batch)
    print(f"✓ Mode coverage: {mode_fraction:.3f}, Modes found: {modes_found}")

    # Diversity
    diversity_score = grid.calculate_diversity(batch)
    print(f"✓ Diversity: {diversity_score:.3f}")

    # L1 error
    l1 = grid.calculate_l1_error(batch)
    print(f"✓ L1 error: {l1:.3f}")

    print("Environment metrics test passed!\n")


def main():
    """Run every integration test in order and report the overall outcome.

    The tests run sequentially and the first exception aborts the run.

    Returns:
        0 when all tests (and the success banner) complete without raising,
        1 otherwise, after printing the error and its traceback.
    """
    banner = "=" * 50
    print(banner)
    print("DATE-GFN Integration Tests")
    print(banner)

    try:
        for run_test in (
            test_gfn_base,
            test_date_gfn,
            test_baselines,
            test_environment_metrics,
        ):
            run_test()

        # Kept inside the try block so a failure while printing the summary
        # is reported through the same error path as a failing test.
        print(banner)
        print("🎉 ALL TESTS PASSED!")
        print("The GFlowNet implementation is working correctly.")
        print(banner)
    except Exception as err:
        print(f" Test failed with error: {err}")
        import traceback
        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    # Use sys.exit rather than the bare `exit` builtin: `exit` is injected by
    # the `site` module for interactive use and is missing under `python -S`
    # and in some embedded/frozen interpreters. main()'s return value becomes
    # the process exit status (0 on success, 1 on failure).
    sys.exit(main())
