#!/usr/bin/env python3
"""
Science World Environment Test Script

Tests the complete Science World environment integration including:
- Basic imports and initialization
- Ray remote actors
- Environment operations (reset, step)
- Projection functions
- Environment manager
- Multi-process environments
"""

import sys
import os
import traceback
import json
import random
from typing import List, Dict, Any

# Test results tracking
# Shared outcome log for the whole run. Test functions append their own name
# to "passed" or a "name: error" string to "failed"; print_test_summary()
# reads all three lists back when the run is over.
test_results = {bucket: [] for bucket in ("passed", "failed", "warnings")}

def test_basic_imports():
    """Test 1: Basic imports.

    Attempts to import ScienceWorld, Ray, the projection function and the
    environment manager. The imported names are intentionally unused: the
    only thing being checked is that each import statement succeeds.

    Returns:
        bool: True when every import worked, False otherwise. The outcome is
        also appended to ``test_results``.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print("Test 1: Basic Imports")
    print(rule)

    try:
        # ScienceWorld itself
        print("  Testing ScienceWorld import...")
        from agent_system.environments.env_package.sciworld.ScienceWorld.scienceworld import ScienceWorldEnv  # noqa: F401

        # Ray
        print("  Testing Ray import...")
        import ray  # noqa: F401

        # Projection function
        print("  Testing projection function import...")
        from agent_system.environments.env_package.sciworld import sciworld_projection  # noqa: F401

        # Environment manager
        print("  Testing environment manager import...")
        from agent_system.environments.env_manager import SciWorldEnvironmentManager  # noqa: F401

        test_results["passed"].append("test_basic_imports")
        print("✓ Test 1: Basic imports - PASSED")
        return True
    except Exception as import_error:
        reason = str(import_error)
        test_results["failed"].append(f"test_basic_imports: {reason}")
        print(f"✗ Test 1: Basic imports - FAILED: {reason}")
        traceback.print_exc()
        return False


def test_single_environment():
    """Test 2: Single environment operations.

    Builds one ``ScienceWorldEnv`` directly from the bundled JAR, loads the
    first task it reports, and checks the types returned by ``reset`` and
    ``step``. The environment is closed on every exit path, including when a
    check fails part-way through.

    Returns:
        bool: True when every check passed, False otherwise. The outcome is
        also appended to ``test_results``.
    """
    print("\n" + "="*60)
    print("Test 2: Single Environment Operations")
    print("="*60)

    # Stays None until the environment exists; `finally` uses it to decide
    # whether there is anything left to close.
    env = None
    try:
        # Add ScienceWorld to path so the vendored package imports as `scienceworld`
        sciworld_path = os.path.join(
            os.path.dirname(__file__),
            'agent_system/environments/env_package/sciworld/ScienceWorld'
        )
        sys.path.insert(0, sciworld_path)

        from scienceworld import ScienceWorldEnv

        # Get JAR path
        jar_path = os.path.join(
            os.path.dirname(__file__),
            'agent_system/environments/env_package/sciworld/ScienceWorld/scienceworld/scienceworld.jar'
        )

        print(f"  JAR path: {jar_path}")
        print(f"  JAR exists: {os.path.exists(jar_path)}")

        # Create environment
        print("  Creating ScienceWorld environment...")
        env = ScienceWorldEnv("", jar_path, envStepLimit=30)

        # Get task names
        print("  Getting task names...")
        task_names = env.get_task_names()
        print(f"  Found {len(task_names)} tasks")
        assert len(task_names) > 0, "No task names found"

        # Load a task
        print(f"  Loading task: {task_names[0]}")
        env.load(task_names[0], 0, "easy")

        # Reset
        print("  Testing reset...")
        obs, info = env.reset()
        assert isinstance(obs, str), "Observation should be string"
        assert isinstance(info, dict), "Info should be dict"
        print(f"  Reset successful. Observation length: {len(obs)}")

        # Step
        print("  Testing step with action 'look around'...")
        obs, reward, done, info = env.step("look around")
        assert isinstance(obs, str), "Observation should be string"
        assert isinstance(reward, (int, float)), "Reward should be numeric"
        assert isinstance(done, bool), "Done should be boolean"
        print(f"  Step successful. Reward: {reward}, Done: {done}")

        # Close. Clear `env` first so `finally` never closes a second time;
        # a failure here still fails the test, as it did before.
        print("  Closing environment...")
        open_env, env = env, None
        open_env.close()

        test_results["passed"].append("test_single_environment")
        print("✓ Test 2: Single environment - PASSED")
        return True
    except Exception as e:
        test_results["failed"].append(f"test_single_environment: {str(e)}")
        print(f"✗ Test 2: Single environment - FAILED: {str(e)}")
        traceback.print_exc()
        return False
    finally:
        # Only reached with a live env when the test failed before its own
        # close; release it so the environment is not leaked into later tests.
        if env is not None:
            try:
                env.close()
            except Exception as close_error:
                print(f"  Warning: failed to close environment: {close_error}")


def test_ray_remote_actor():
    """Test 3: Ray remote actor functionality.

    Wraps ``SciWorldWorker`` as a Ray remote actor, then calls ``reset`` and
    ``step`` on it through ``ray.get``. The actor is killed on every exit
    path, including when a check fails part-way through.

    Returns:
        bool: True when every check passed, False otherwise. The outcome is
        also appended to ``test_results``.
    """
    print("\n" + "="*60)
    print("Test 3: Ray Remote Actor")
    print("="*60)

    # Stays None until the actor handle exists; `finally` uses it to decide
    # whether there is anything left to kill.
    worker = None
    try:
        import ray
        from agent_system.environments.env_package.sciworld.envs import SciWorldWorker

        # Initialize Ray
        print("  Initializing Ray...")
        if not ray.is_initialized():
            ray.init(ignore_reinit_error=True)
        print(f"  Ray initialized: {ray.is_initialized()}")

        # Prepare env_kwargs
        jar_path = os.path.join(
            os.path.dirname(__file__),
            'agent_system/environments/env_package/sciworld/ScienceWorld/scienceworld/scienceworld.jar'
        )

        variations_idx_path = os.path.join(
            os.path.dirname(__file__),
            'agent_system/environments/env_package/sciworld/variations_idx/L0_idx.json'
        )

        print(f"  Loading variations from: {variations_idx_path}")
        with open(variations_idx_path, 'r') as f:
            variations_idx = json.load(f)
        print(f"  Loaded {len(variations_idx)} variations")

        env_kwargs = {
            'jar_path': jar_path,
            'env_step_limit': 30,
            'simplifications_preset': 'easy',
            'variations_idx': variations_idx
        }

        # Create remote actor
        print("  Creating Ray remote actor...")
        worker_class = ray.remote(num_cpus=0.1)(SciWorldWorker)
        worker = worker_class.remote(seed=0, env_kwargs=env_kwargs)

        # Test reset
        print("  Testing remote reset...")
        obs, info = ray.get(worker.reset.remote(0))
        assert isinstance(obs, str), "Observation should be string"
        assert 'task_description' in info, "Info should contain task_description"
        print(f"  Reset successful. Task: {info['task_description'][:50]}...")

        # Test step
        print("  Testing remote step...")
        obs, reward, done, info = ray.get(worker.step.remote("look around"))
        assert isinstance(obs, str), "Observation should be string"
        print(f"  Step successful. Reward: {reward}, Done: {done}")

        # Cleanup. Clear `worker` first so `finally` never kills a second
        # time; a failure here still fails the test, as it did before.
        print("  Cleaning up Ray actor...")
        live_worker, worker = worker, None
        ray.kill(live_worker)

        test_results["passed"].append("test_ray_remote_actor")
        print("✓ Test 3: Ray remote actor - PASSED")
        return True
    except Exception as e:
        test_results["failed"].append(f"test_ray_remote_actor: {str(e)}")
        print(f"✗ Test 3: Ray remote actor - FAILED: {str(e)}")
        traceback.print_exc()
        return False
    finally:
        # `worker` can only be set after `import ray` succeeded, so `ray` is
        # always bound here. Kill the actor so a failed check does not leave
        # it running for the rest of the script.
        if worker is not None:
            try:
                ray.kill(worker)
            except Exception as cleanup_error:
                print(f"  Warning: failed to kill Ray actor: {cleanup_error}")


def test_multi_process_env():
    """Test 4: Multi-process environment.

    Builds a small vectorized environment with ``build_sciworld_envs`` and
    checks that ``reset`` and ``step`` return one entry per worker. The
    environments are closed on every exit path, including when a check fails
    part-way through.

    Returns:
        bool: True when every check passed, False otherwise. The outcome is
        also appended to ``test_results``.
    """
    print("\n" + "="*60)
    print("Test 4: Multi-Process Environment")
    print("="*60)

    # Small batch for testing; every length check below derives from these.
    env_num = 2
    group_n = 2
    expected_workers = env_num * group_n

    # Stays None until the environments exist; `finally` uses it to decide
    # whether there is anything left to close.
    envs = None
    try:
        from agent_system.environments.env_package.sciworld import build_sciworld_envs

        # Prepare env_kwargs
        jar_path = os.path.join(
            os.path.dirname(__file__),
            'agent_system/environments/env_package/sciworld/ScienceWorld/scienceworld/scienceworld.jar'
        )

        variations_idx_path = os.path.join(
            os.path.dirname(__file__),
            'agent_system/environments/env_package/sciworld/variations_idx/L0_idx.json'
        )

        with open(variations_idx_path, 'r') as f:
            variations_idx = json.load(f)

        env_kwargs = {
            'jar_path': jar_path,
            'env_step_limit': 30,
            'simplifications_preset': 'easy',
            'variations_idx': variations_idx
        }

        # Create multi-process environment
        print(
            f"  Creating multi-process environment ({env_num} envs x {group_n} "
            f"group_n = {expected_workers} workers)..."
        )
        envs = build_sciworld_envs(
            seed=0,
            env_num=env_num,
            group_n=group_n,
            resources_per_worker={'num_cpus': 0.1},
            is_train=True,
            env_kwargs=env_kwargs
        )

        # Test reset
        print("  Testing vectorized reset...")
        obs_list, info_list = envs.reset()
        assert len(obs_list) == expected_workers, \
            f"Expected {expected_workers} observations, got {len(obs_list)}"
        assert len(info_list) == expected_workers, \
            f"Expected {expected_workers} infos, got {len(info_list)}"
        print(f"  Reset successful. Got {len(obs_list)} observations")

        # Test step
        print("  Testing vectorized step...")
        actions = ["look around"] * expected_workers
        obs_list, reward_list, done_list, info_list = envs.step(actions)
        assert len(obs_list) == expected_workers, f"Expected {expected_workers} observations"
        assert len(reward_list) == expected_workers, f"Expected {expected_workers} rewards"
        print(f"  Step successful. Rewards: {reward_list}")

        # Cleanup. Clear `envs` first so `finally` never closes a second
        # time; a failure here still fails the test, as it did before.
        print("  Cleaning up environments...")
        open_envs, envs = envs, None
        open_envs.close()

        test_results["passed"].append("test_multi_process_env")
        print("✓ Test 4: Multi-process environment - PASSED")
        return True
    except Exception as e:
        test_results["failed"].append(f"test_multi_process_env: {str(e)}")
        print(f"✗ Test 4: Multi-process environment - FAILED: {str(e)}")
        traceback.print_exc()
        return False
    finally:
        # Only reached with live envs when the test failed before its own
        # close; release them so the workers do not outlive this test.
        if envs is not None:
            try:
                envs.close()
            except Exception as close_error:
                print(f"  Warning: failed to close environments: {close_error}")


def test_projection_function():
    """Test 5: Projection function.

    Feeds ``sciworld_projection`` one well-formed ``<think>…</think>
    <action>…</action>`` string and three strings this test expects to be
    rejected (missing think tags, missing action tags, Chinese text), and
    checks the returned action and validity flag for each.

    Returns:
        bool: True when every check passed, False otherwise. The outcome is
        also appended to ``test_results``.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print("Test 5: Projection Function")
    print(rule)

    # (progress message, raw model output, assertion message, label used in
    # the success line) for every input that must come back with valid == 0.
    rejected_cases = (
        (
            "  Testing invalid action (no think tags)...",
            "<action>look around</action>",
            "Invalid action should have valid=0",
            "Invalid action",
        ),
        (
            "  Testing invalid action (no action tags)...",
            "<think>I should look around</think>",
            "Invalid action should have valid=0",
            "Invalid action",
        ),
        (
            "  Testing Chinese character detection...",
            "<think>我应该看看周围</think><action>look around</action>",
            "Action with Chinese should have valid=0",
            "Chinese character",
        ),
    )

    try:
        from agent_system.environments.env_package.sciworld import sciworld_projection

        # The single well-formed input: action text is extracted and flagged valid.
        print("  Testing valid action format...")
        well_formed = "<think>I should look around</think><action>look around</action>"
        actions, valids = sciworld_projection([well_formed])
        assert actions[0] == "look around", f"Expected 'look around', got '{actions[0]}'"
        assert valids[0] == 1, "Valid action should have valid=1"
        print(f"  Valid action test passed. Action: '{actions[0]}', Valid: {valids[0]}")

        # Every remaining input must be flagged invalid.
        for progress_msg, raw_text, failure_msg, label in rejected_cases:
            print(progress_msg)
            actions, valids = sciworld_projection([raw_text])
            assert valids[0] == 0, failure_msg
            print(f"  {label} test passed. Valid: {valids[0]}")

        test_results["passed"].append("test_projection_function")
        print("✓ Test 5: Projection function - PASSED")
        return True
    except Exception as err:
        reason = str(err)
        test_results["failed"].append(f"test_projection_function: {reason}")
        print(f"✗ Test 5: Projection function - FAILED: {reason}")
        traceback.print_exc()
        return False


def test_environment_manager():
    """Test 6: Environment manager.

    Wraps two environments from ``build_sciworld_envs`` in a
    ``SciWorldEnvironmentManager`` and checks the structure of what the
    manager's ``reset`` and ``step`` return. The underlying environments are
    closed on every exit path, including when a check fails part-way through.

    Returns:
        bool: True when every check passed, False otherwise. The outcome is
        also appended to ``test_results``.
    """
    print("\n" + "="*60)
    print("Test 6: Environment Manager")
    print("="*60)

    # Stays None until the environments exist; `finally` uses it to decide
    # whether there is anything left to close.
    _envs = None
    try:
        from agent_system.environments.env_manager import SciWorldEnvironmentManager
        from agent_system.environments.env_package.sciworld import build_sciworld_envs, sciworld_projection
        from functools import partial
        from omegaconf import OmegaConf

        # Prepare config
        config = OmegaConf.create({
            'env': {
                'history_length': 2,
                'max_steps': 30,
                'sciworld': {
                    'simplifications_preset': 'easy'
                }
            }
        })

        # Prepare env_kwargs
        jar_path = os.path.join(
            os.path.dirname(__file__),
            'agent_system/environments/env_package/sciworld/ScienceWorld/scienceworld/scienceworld.jar'
        )

        variations_idx_path = os.path.join(
            os.path.dirname(__file__),
            'agent_system/environments/env_package/sciworld/variations_idx/L0_idx.json'
        )

        with open(variations_idx_path, 'r') as f:
            variations_idx = json.load(f)

        env_kwargs = {
            'jar_path': jar_path,
            'env_step_limit': 30,
            'simplifications_preset': 'easy',
            'variations_idx': variations_idx
        }

        # Create environments
        print("  Creating environments...")
        _envs = build_sciworld_envs(
            seed=0,
            env_num=2,
            group_n=1,
            resources_per_worker={'num_cpus': 0.1},
            is_train=True,
            env_kwargs=env_kwargs
        )

        # Create manager
        print("  Creating environment manager...")
        projection_f = partial(sciworld_projection)
        env_manager = SciWorldEnvironmentManager(_envs, projection_f, config)

        # Test reset
        print("  Testing manager reset...")
        observations, infos = env_manager.reset(kwargs={})
        assert 'text' in observations, "Observations should contain 'text'"
        assert 'image' in observations, "Observations should contain 'image'"
        assert 'anchor' in observations, "Observations should contain 'anchor'"
        assert len(observations['text']) == 2, "Expected 2 text observations"
        print(f"  Reset successful. Got {len(observations['text'])} text observations")

        # Test step
        print("  Testing manager step...")
        text_actions = [
            "<think>I should look around</think><action>look around</action>",
            "<think>I should examine something</think><action>examine table</action>"
        ]
        next_observations, rewards, dones, infos = env_manager.step(text_actions)
        assert len(rewards) == 2, "Expected 2 rewards"
        assert len(dones) == 2, "Expected 2 dones"
        print(f"  Step successful. Rewards: {rewards}, Dones: {dones}")

        # Cleanup. Clear `_envs` first so `finally` never closes a second
        # time; a failure here still fails the test, as it did before.
        print("  Cleaning up...")
        open_envs, _envs = _envs, None
        open_envs.close()

        test_results["passed"].append("test_environment_manager")
        print("✓ Test 6: Environment manager - PASSED")
        return True
    except Exception as e:
        test_results["failed"].append(f"test_environment_manager: {str(e)}")
        print(f"✗ Test 6: Environment manager - FAILED: {str(e)}")
        traceback.print_exc()
        return False
    finally:
        # Only reached with live envs when the test failed before its own
        # close; release them so the workers do not outlive this test.
        if _envs is not None:
            try:
                _envs.close()
            except Exception as close_error:
                print(f"  Warning: failed to close environments: {close_error}")


def test_integration():
    """Test 7: Integration test with mini training loop.

    Builds two environments behind a ``SciWorldEnvironmentManager`` and
    drives them for 10 steps with randomly chosen, well-formatted actions,
    checking the shape of every step result. The underlying environments are
    closed on every exit path, including when a check fails part-way through.

    Returns:
        bool: True when every check passed, False otherwise. The outcome is
        also appended to ``test_results``.
    """
    print("\n" + "="*60)
    print("Test 7: Integration Test (Mini Training Loop)")
    print("="*60)

    # Number of parallel environments; one action is generated per env.
    num_envs = 2

    # Stays None until the environments exist; `finally` uses it to decide
    # whether there is anything left to close.
    _envs = None
    try:
        from agent_system.environments.env_manager import SciWorldEnvironmentManager
        from agent_system.environments.env_package.sciworld import build_sciworld_envs, sciworld_projection
        from functools import partial
        from omegaconf import OmegaConf

        # Prepare config
        config = OmegaConf.create({
            'env': {
                'history_length': 2,
                'max_steps': 30,
                'sciworld': {
                    'simplifications_preset': 'easy'
                }
            }
        })

        # Prepare env_kwargs
        jar_path = os.path.join(
            os.path.dirname(__file__),
            'agent_system/environments/env_package/sciworld/ScienceWorld/scienceworld/scienceworld.jar'
        )

        variations_idx_path = os.path.join(
            os.path.dirname(__file__),
            'agent_system/environments/env_package/sciworld/variations_idx/L0_idx.json'
        )

        with open(variations_idx_path, 'r') as f:
            variations_idx = json.load(f)

        env_kwargs = {
            'jar_path': jar_path,
            'env_step_limit': 30,
            'simplifications_preset': 'easy',
            'variations_idx': variations_idx
        }

        # Create environments
        print("  Creating environments...")
        _envs = build_sciworld_envs(
            seed=0,
            env_num=num_envs,
            group_n=1,
            resources_per_worker={'num_cpus': 0.1},
            is_train=True,
            env_kwargs=env_kwargs
        )

        # Create manager
        projection_f = partial(sciworld_projection)
        env_manager = SciWorldEnvironmentManager(_envs, projection_f, config)

        # Reset
        print("  Resetting environments...")
        observations, infos = env_manager.reset(kwargs={})

        # Run mini training loop (10 steps)
        print("  Running 10-step training loop...")
        possible_actions = [
            "look around",
            "inventory",
            "examine table",
            "open door",
            "go north"
        ]

        for step in range(10):
            # Generate one random action per environment, in the
            # <think>/<action> format the other tests in this file use
            text_actions = []
            for _ in range(num_envs):
                action = random.choice(possible_actions)
                text_action = f"<think>Step {step}: I should try {action}</think><action>{action}</action>"
                text_actions.append(text_action)

            # Step
            next_observations, rewards, dones, infos = env_manager.step(text_actions)

            # Verify structure
            assert len(rewards) == num_envs, f"Step {step}: Expected {num_envs} rewards"
            assert len(dones) == num_envs, f"Step {step}: Expected {num_envs} dones"
            assert 'text' in next_observations, f"Step {step}: Missing 'text' in observations"

            print(f"  Step {step}: rewards={rewards}, dones={dones}")

        # Cleanup. Clear `_envs` first so `finally` never closes a second
        # time; a failure here still fails the test, as it did before.
        print("  Cleaning up...")
        open_envs, _envs = _envs, None
        open_envs.close()

        test_results["passed"].append("test_integration")
        print("✓ Test 7: Integration test - PASSED")
        return True
    except Exception as e:
        test_results["failed"].append(f"test_integration: {str(e)}")
        print(f"✗ Test 7: Integration test - FAILED: {str(e)}")
        traceback.print_exc()
        return False
    finally:
        # Only reached with live envs when the test failed before its own
        # close; release them so the workers do not outlive this test.
        if _envs is not None:
            try:
                _envs.close()
            except Exception as close_error:
                print(f"  Warning: failed to close environments: {close_error}")


def print_test_summary():
    """Print the end-of-run report built from ``test_results``.

    Shows the pass/fail counts (warnings are listed but not counted as
    tests), then each non-empty list of passed tests, failed tests and
    warnings, and finally a one-line overall verdict.
    """
    rule = "=" * 60
    n_passed = len(test_results["passed"])
    n_failed = len(test_results["failed"])
    n_total = n_passed + n_failed

    print(f"\n{rule}")
    print("TEST SUMMARY")
    print(rule)

    print(f"\nTotal Tests: {n_total}")
    print(f"Passed: {n_passed}/{n_total}")
    print(f"Failed: {n_failed}/{n_total}")

    # (heading, key into test_results); empty lists are skipped entirely.
    sections = (
        ("\n✓ Passed Tests:", "passed"),
        ("\n✗ Failed Tests:", "failed"),
        ("\n⚠ Warnings:", "warnings"),
    )
    for heading, key in sections:
        entries = test_results[key]
        if not entries:
            continue
        print(heading)
        for entry in entries:
            print(f"  - {entry}")

    print(f"\n{rule}")
    if n_failed:
        print(f"✗ {n_failed} test(s) failed. Please review the errors above.")
    else:
        print("✓ ALL TESTS PASSED! Science World environment is ready for training.")
    print(rule)


# Script entry point: run every test in order, print the summary, and exit
# with status 1 if anything was recorded as failed (0 otherwise).
if __name__ == "__main__":
    print("="*60)
    print("Science World Environment Test Script")
    print("="*60)
    print("\nThis script will test the complete Science World environment integration.")
    print("Tests include: imports, single env, Ray actors, multi-process, projection,")
    print("environment manager, and integration with mini training loop.\n")

    # Add Gym warning to warnings list
    test_results["warnings"].append("Gym version v0.24.0 warning (can be safely ignored)")

    # Run all tests
    tests = [
        ("Basic Imports", test_basic_imports),
        ("Single Environment", test_single_environment),
        ("Ray Remote Actor", test_ray_remote_actor),
        ("Multi-Process Environment", test_multi_process_env),
        ("Projection Function", test_projection_function),
        ("Environment Manager", test_environment_manager),
        ("Integration Test", test_integration)
    ]

    for test_name, test_func in tests:
        try:
            success = test_func()
            if not success:
                print(f"\n⚠ Warning: {test_name} failed, but continuing with remaining tests...")
        except Exception as e:
            # Each test catches its own errors, so getting here means the
            # test's own reporting broke. Record the failure (unless the test
            # already did) so the summary and the exit code cannot claim success.
            failure_prefix = f"{test_func.__name__}:"
            already_recorded = any(
                entry.startswith(failure_prefix) for entry in test_results["failed"]
            )
            if not already_recorded:
                test_results["failed"].append(f"{failure_prefix} {str(e)}")
            print(f"\n⚠ Warning: {test_name} raised an exception, but continuing with remaining tests...")
            print(f"Exception: {str(e)}")
            traceback.print_exc()

    # Print summary
    print_test_summary()

    # Exit with appropriate code
    sys.exit(1 if test_results["failed"] else 0)
