#!/usr/bin/env python
"""Random walk test for KeyLockEnv to verify reward function.

This script runs random walk episodes and prints information when
the agent successfully reaches the goal (receives +1 reward).
"""

import time
import numpy as np
import random
from key_lock_env import KeyLockEnv


def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``0.0`` when ``denominator`` is 0."""
    return numerator / denominator if denominator else 0.0


def _print_config(env, agent_start_pos, agent_start_dir, num_episodes, max_steps_per_episode):
    """Print a banner describing the environment layout and run parameters."""
    print("=" * 70)
    print("KeyLockEnv Random Walk Test")
    print("=" * 70)
    print("Configuration:")
    print(f"  Yellow key position: {env.yellow_key_pos}")
    print(f"  Yellow door position: {env.yellow_door_pos}")
    print(f"  Blue key position: {env.blue_key_pos}")
    print(f"  Blue door position: {env.blue_door_pos}")
    print(f"  Goal position: {env.goal_pos}")
    print(f"  Agent start: {agent_start_pos}, direction: {agent_start_dir} (east)")
    print(f"  Number of episodes: {num_episodes}")
    print(f"  Max steps per episode: {max_steps_per_episode}")
    print("  Actions: 0=up, 1=down, 2=left, 3=right, 4=pickup, 5=toggle")
    print("=" * 70)
    print()


def _print_success(episode, step, obs, episode_steps, episode_reward, reward):
    """Print a detailed report for the step on which the goal was reached."""
    print("\n" + "=" * 70)
    print(f"🎉 SUCCESS! Episode {episode}, Step {step}")
    print("=" * 70)
    print("  Final state:")
    print(f"    Position: ({obs[0]}, {obs[1]})")
    print(f"    Direction: {obs[2]} (0=east, 1=south, 2=west, 3=north)")
    print(f"    Has yellow key: {obs[3]}")
    print(f"    Has blue key: {obs[4]}")
    print(f"    Yellow door open: {obs[5]}")
    print(f"    Blue door open: {obs[6]}")
    print("  Episode statistics:")
    print(f"    Total steps: {episode_steps}")
    print(f"    Total reward: {episode_reward:.2f}")
    print(f"    Final reward: {reward:.2f} (goal reached!)")
    print("=" * 70)
    print()


def random_walk_test(
    num_episodes=100,
    max_steps_per_episode=500,
    seed=None,
    verbose=True,
):
    """
    Run random walk episodes and track when goal is reached.

    Each episode resets the environment and then takes uniformly random
    actions in ``[0, 5]``. It stops when the environment reports
    ``terminated``/``truncated`` or after ``max_steps_per_episode`` steps.
    An episode counts as successful, at most once, on the first step that
    returns a reward of exactly ``1.0``.

    Args:
        num_episodes: Number of episodes to run. A non-positive value runs
            nothing and reports zero-valued rates instead of dividing by zero.
        max_steps_per_episode: Maximum steps per episode.
        seed: If not ``None``, seeds both the ``numpy`` and ``random``
            global RNGs.
        verbose: Print progress every 10 episodes and key-pickup events.

    Returns:
        dict with keys ``total_episodes``, ``successful_episodes``,
        ``success_rate``, ``total_steps``, ``avg_reward``,
        ``episode_rewards`` and ``episode_lengths``.
    """
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

    # Kept as locals so the printed banner always matches what the env was built with.
    agent_start_pos = (1, 1)
    agent_start_dir = 0

    env = KeyLockEnv(
        size=15,
        agent_start_pos=agent_start_pos,
        agent_start_dir=agent_start_dir,
        yellow_key_pos=(3, 3),
        yellow_door_pos=(7, 7),
        blue_key_pos=(12, 3),
        blue_door_pos=(10, 10),
        goal_pos=(13, 13),
        render_mode=None,  # No rendering for batch testing
    )

    # Statistics
    total_episodes = 0
    successful_episodes = 0
    total_steps = 0
    total_reward = 0.0
    episode_rewards = []
    episode_lengths = []
    successful_episode_indices = []  # Each successful episode appears exactly once
    steps_to_success = []  # Step count at the moment the goal was first reached

    _print_config(env, agent_start_pos, agent_start_dir, num_episodes, max_steps_per_episode)

    try:
        for episode in range(num_episodes):
            # NOTE(review): reset() is not passed `seed`; if KeyLockEnv draws from
            # its own RNG, runs may not be fully reproducible — confirm.
            obs, info = env.reset()
            episode_reward = 0.0
            episode_steps = 0
            has_yellow_key = False
            has_blue_key = False
            goal_reached = False

            if verbose and episode % 10 == 0:
                print(f"Episode {episode}...", end=" ", flush=True)

            for step in range(max_steps_per_episode):
                # randint is inclusive on both ends, so all 6 actions are possible.
                action = random.randint(0, 5)
                obs, reward, terminated, truncated, info = env.step(action)

                episode_reward += reward
                episode_steps += 1
                total_steps += 1

                # obs[3] / obs[4] are the "has yellow/blue key" flags; report each pickup once.
                if obs[3] == 1 and not has_yellow_key:
                    has_yellow_key = True
                    if verbose:
                        print(f"\n  [Episode {episode}, Step {step}] ✓ Yellow key picked up!")

                if obs[4] == 1 and not has_blue_key:
                    has_blue_key = True
                    if verbose:
                        print(f"\n  [Episode {episode}, Step {step}] ✓ Blue key picked up!")

                # Guard with goal_reached so an env that does not terminate on the
                # goal cannot make one episode count as several successes.
                if reward == 1.0 and not goal_reached:
                    goal_reached = True
                    successful_episodes += 1
                    successful_episode_indices.append(episode)
                    steps_to_success.append(episode_steps)
                    _print_success(episode, step, obs, episode_steps, episode_reward, reward)

                if terminated or truncated:
                    break

            total_episodes += 1
            total_reward += episode_reward
            episode_rewards.append(episode_reward)
            episode_lengths.append(episode_steps)

            # Also close out the last, possibly partial, batch of 10 so its
            # "Episode N..." line is not left dangling.
            if verbose and (episode % 10 == 9 or episode == num_episodes - 1):
                done = episode + 1
                print(f"Done. Success rate so far: {successful_episodes}/{done} = {100*successful_episodes/done:.1f}%")
    finally:
        # Release env resources even if an episode raises; tolerate envs without close().
        close = getattr(env, "close", None)
        if callable(close):
            close()

    success_rate = _ratio(successful_episodes, total_episodes)
    avg_reward = _ratio(total_reward, total_episodes)

    print("\n" + "=" * 70)
    print("Final Statistics")
    print("=" * 70)
    print(f"Total episodes: {total_episodes}")
    print(f"Successful episodes (goal reached): {successful_episodes}")
    print(f"Success rate: {100 * success_rate:.2f}%")
    print(f"Total steps: {total_steps}")
    print(f"Average steps per episode: {_ratio(total_steps, total_episodes):.1f}")
    print(f"Average reward per episode: {avg_reward:.3f}")

    if successful_episodes > 0:
        successful_rewards = [episode_rewards[i] for i in successful_episode_indices]

        print("\nSuccessful episodes statistics:")
        print(f"  Number of successful episodes: {len(successful_rewards)}")
        print(f"  Average steps to success: {np.mean(steps_to_success):.1f}")
        print(f"  Average reward: {np.mean(successful_rewards):.3f}")
        print(f"  Min steps: {min(steps_to_success)}")
        print(f"  Max steps: {max(steps_to_success)}")
    else:
        print("\n⚠️  WARNING: No successful episodes!")
        print("   The agent did not manage to reach the goal.")
        print("   This could indicate:")
        print("   - The task is too difficult for random walk")
        print("   - Need more episodes or steps per episode")
        print("   - There might be an issue with the environment")

    print("=" * 70)

    return {
        'total_episodes': total_episodes,
        'successful_episodes': successful_episodes,
        'success_rate': success_rate,
        'total_steps': total_steps,
        'avg_reward': avg_reward,
        'episode_rewards': episode_rewards,
        'episode_lengths': episode_lengths,
    }


def main():
    """Parse command-line arguments, run the random walk test, and return an exit code.

    Returns:
        0 if at least one episode reached the goal, 1 otherwise.
    """
    import argparse

    def positive_int(text):
        """``argparse`` type converter that accepts only integers >= 1.

        Rejecting zero and negative counts here gives a clear usage error
        instead of an empty or meaningless run.
        """
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError(
                f"must be a positive integer, got {value}"
            )
        return value

    parser = argparse.ArgumentParser(description="Random walk test for KeyLockEnv")
    parser.add_argument(
        "--num_episodes",
        type=positive_int,
        default=100,
        help="Number of episodes to run (default: 100)"
    )
    parser.add_argument(
        "--max_steps",
        type=positive_int,
        default=500,
        help="Maximum steps per episode (default: 500)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed (default: None)"
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Reduce output verbosity"
    )

    args = parser.parse_args()

    stats = random_walk_test(
        num_episodes=args.num_episodes,
        max_steps_per_episode=args.max_steps,
        seed=args.seed,
        verbose=not args.quiet,
    )

    # Non-zero exit code signals that no episode reached the goal.
    return 0 if stats['successful_episodes'] > 0 else 1


if __name__ == "__main__":
    # SystemExit(code) is exactly what sys.exit raises; unlike the site-injected
    # `exit` builtin, it is always available (e.g. under `python -S`).
    raise SystemExit(main())
