"""
Pygame visualization for Cat-Mouse Chase environment.
Watch the game play out step by step to debug the mouse behavior.
"""

import pygame
import numpy as np
import time
from env.open_grid.open_grid_chase import (
    OpenGridChase, GRID_SIZE, ind2coord,
    DIR_UP, DIR_RIGHT, DIR_DOWN, DIR_LEFT,
    ACT_UP, ACT_DOWN, ACT_LEFT, ACT_RIGHT,
    ACT_UP_LEFT, ACT_UP_RIGHT, ACT_DOWN_LEFT, ACT_DOWN_RIGHT
)

# Colors: RGB tuples handed to pygame drawing and text-rendering calls.
WHITE = (255, 255, 255)    # Background fill, info-panel fill and mouse arrow glyph
BLACK = (0, 0, 0)          # Text color for the cat label and info lines
RED = (255, 100, 100)      # Cat body
BLUE = (100, 100, 255)     # Mouse body
GREEN = (100, 255, 100)    # Line showing the mouse's facing direction
GRAY = (200, 200, 200)     # Grid lines

# Side length of one grid cell, in pixels.
CELL_SIZE = 60
# Side length of the square grid area, in pixels. The window created in
# run_visualization is 100 px taller than this to leave room for the info panel.
WINDOW_SIZE = GRID_SIZE * CELL_SIZE

# Glyph rendered on top of the mouse for each facing direction.
DIR_ARROWS = {
    DIR_UP: "↑",
    DIR_RIGHT: "→",
    DIR_DOWN: "↓",
    DIR_LEFT: "←",
}


def draw_grid(screen):
    """Render the cell boundaries of the board onto ``screen``.

    Draws ``GRID_SIZE + 1`` vertical and horizontal gray lines, spaced
    ``CELL_SIZE`` pixels apart and spanning the full ``WINDOW_SIZE``.
    """
    for offset in (k * CELL_SIZE for k in range(GRID_SIZE + 1)):
        # Vertical line at x == offset, then horizontal line at y == offset.
        pygame.draw.line(screen, GRAY, (offset, 0), (offset, WINDOW_SIZE))
        pygame.draw.line(screen, GRAY, (0, offset), (WINDOW_SIZE, offset))


def draw_cat(screen, state, font):
    """Paint the cat as a red disc labelled "C" in the cell for ``state``.

    Nothing is drawn when ``state`` is at or beyond ``GRID_SIZE * GRID_SIZE``
    (the absorbing state, which has no cell on the board).
    """
    if state >= GRID_SIZE ** 2:
        return  # Absorbing state: no on-board position to draw

    row, col = ind2coord(state)
    half_cell = CELL_SIZE // 2
    # Pixel centre of the cell: columns map to x, rows map to y.
    center = (col * CELL_SIZE + half_cell, row * CELL_SIZE + half_cell)

    pygame.draw.circle(screen, RED, center, CELL_SIZE // 3)
    label = font.render("C", True, BLACK)
    screen.blit(label, label.get_rect(center=center))


def draw_mouse(screen, state, direction, font):
    """Draw the mouse (blue circle with direction arrow and heading line).

    Args:
        screen: pygame surface to draw on.
        state: Grid index of the mouse, converted to (row, col) by ind2coord.
        direction: One of DIR_UP / DIR_RIGHT / DIR_DOWN / DIR_LEFT. Any other
            value is tolerated: a "?" glyph is drawn and the heading line is
            omitted.
        font: pygame font used to render the arrow glyph.
    """
    row, col = ind2coord(state)
    half_cell = CELL_SIZE // 2
    center_x = col * CELL_SIZE + half_cell
    center_y = row * CELL_SIZE + half_cell

    pygame.draw.circle(screen, BLUE, (center_x, center_y), CELL_SIZE // 3)

    # Draw direction arrow ("?" when the direction is not recognised)
    arrow = DIR_ARROWS.get(direction, "?")
    text = font.render(arrow, True, WHITE)
    text_rect = text.get_rect(center=(center_x, center_y))
    screen.blit(text, text_rect)

    # Draw direction line. Look the offset up with .get() so an unknown
    # direction degrades the same way the glyph lookup above does instead of
    # raising KeyError halfway through the frame.
    heading_offsets = {
        DIR_UP: (-1, 0),
        DIR_RIGHT: (0, 1),
        DIR_DOWN: (1, 0),
        DIR_LEFT: (0, -1),
    }
    offset = heading_offsets.get(direction)
    if offset is None:
        return
    dr, dc = offset
    # Multiply by the already-halved cell size so the line has the same length
    # in every direction even if CELL_SIZE is odd (floor division of a negative
    # product would otherwise round away from zero).
    end_x = center_x + dc * half_cell
    end_y = center_y + dr * half_cell
    pygame.draw.line(screen, GREEN, (center_x, center_y), (end_x, end_y), 3)


def draw_info(screen, font, info):
    """Blit each string in ``info`` as a black text line near the top-left.

    Lines start at pixel (10, 10) and are stacked 25 pixels apart.
    """
    for index, message in enumerate(info):
        rendered = font.render(message, True, BLACK)
        screen.blit(rendered, (10, 10 + 25 * index))


def get_greedy_action(cat_state, mouse_state):
    """Pick the cat action that heads straight for the mouse.

    The row and column offsets from cat to mouse are reduced to their signs;
    a diagonal action is chosen when both differ, a straight action when only
    one differs, and ACT_UP when the cat is already on the mouse's cell.
    """
    def _sign(delta):
        # Written with comparisons only so it works for any numeric type.
        if delta < 0:
            return -1
        if delta > 0:
            return 1
        return 0

    cat_row, cat_col = ind2coord(cat_state)
    mouse_row, mouse_col = ind2coord(mouse_state)

    # Keyed by (sign of row offset, sign of column offset).
    action_for_heading = {
        (-1, -1): ACT_UP_LEFT,
        (-1, 1): ACT_UP_RIGHT,
        (1, -1): ACT_DOWN_LEFT,
        (1, 1): ACT_DOWN_RIGHT,
        (-1, 0): ACT_UP,
        (1, 0): ACT_DOWN,
        (0, -1): ACT_LEFT,
        (0, 1): ACT_RIGHT,
        (0, 0): ACT_UP,  # Already on mouse
    }
    heading = (_sign(mouse_row - cat_row), _sign(mouse_col - cat_col))
    return action_for_heading[heading]


# Text shown for each mouse heading in console output and the info panel.
_DIR_NAMES = {DIR_UP: "UP", DIR_RIGHT: "RIGHT", DIR_DOWN: "DOWN", DIR_LEFT: "LEFT"}

# Height in pixels of the text strip below the grid.
_INFO_PANEL_HEIGHT = 100


def _render_frame(screen, env, font, small_font, lines):
    """Redraw the board, both agents and the bottom info panel, then flip."""
    screen.fill(WHITE)
    draw_grid(screen)
    draw_mouse(screen, env.mouse_state, env.mouse_direction, font)
    draw_cat(screen, env.state, font)

    # Info panel at the bottom of the window
    pygame.draw.rect(screen, WHITE, (0, WINDOW_SIZE, WINDOW_SIZE, _INFO_PANEL_HEIGHT))
    y = WINDOW_SIZE + 10
    for line in lines:
        screen.blit(small_font.render(line, True, BLACK), (10, y))
        y += 20

    pygame.display.flip()


def run_visualization(p_turn=1.0, delay=0.3, max_steps=100):
    """Run the visualization.

    Opens a pygame window, steps an OpenGridChase environment with the greedy
    cat policy from get_greedy_action, and redraws after every step. Each step
    is also logged to stdout.

    Controls: SPACE toggles pause, Q (or closing the window) quits, R resets
    the episode, +/= shortens and - lengthens the per-step delay (kept within
    0.05s..1.0s).

    Args:
        p_turn: Passed through to OpenGridChase(p_turn=...).
        delay: Seconds to sleep after each step.
        max_steps: The loop ends once the current episode reaches this many
            steps. The step counter restarts on every capture and on R, so
            this bounds one episode, not the whole session.

    The "Final reward" printed on exit is the running total of the episode in
    progress when the loop ended (it is zeroed after each capture).
    """
    pygame.init()
    try:
        screen = pygame.display.set_mode((WINDOW_SIZE, WINDOW_SIZE + _INFO_PANEL_HEIGHT))
        pygame.display.set_caption(f"Cat-Mouse Chase (p_turn={p_turn})")
        font = pygame.font.Font(None, 36)
        small_font = pygame.font.Font(None, 24)
        clock = pygame.time.Clock()

        env = OpenGridChase(p_turn=p_turn, step_reward=-0.1, terminal_reward=10)
        env._reset()

        step = 0
        total_reward = 0
        running = True
        paused = False

        print(f"\nStarting visualization with p_turn={p_turn}")
        print("Controls: SPACE=pause/resume, Q=quit, R=reset, +/-=speed")
        print("-" * 50)

        while running and step < max_steps:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_q:
                        running = False
                    elif event.key == pygame.K_SPACE:
                        paused = not paused
                    elif event.key == pygame.K_r:
                        env._reset()
                        step = 0
                        total_reward = 0
                    elif event.key in (pygame.K_PLUS, pygame.K_EQUALS):
                        delay = max(0.05, delay - 0.05)
                    elif event.key == pygame.K_MINUS:
                        delay = min(1.0, delay + 0.05)

            if paused:
                # Keep redrawing the frozen state so the window stays
                # responsive and reflects a reset done while paused.
                _render_frame(screen, env, font, small_font, [
                    f"Step: {step}  Reward: {total_reward:.1f}  PAUSED",
                    f"p_turn: {p_turn}  delay: {delay:.2f}s",
                    "SPACE=resume, Q=quit, R=reset",
                ])
                clock.tick(30)
                continue

            # Get action (simple greedy toward mouse)
            action = get_greedy_action(env.state, env.mouse_state)

            # Print state before step
            cat_row, cat_col = ind2coord(env.state)
            mouse_row, mouse_col = ind2coord(env.mouse_state)
            dir_before = _DIR_NAMES.get(env.mouse_direction, "?")
            print(f"Step {step}: Cat({cat_row},{cat_col}) Mouse({mouse_row},{mouse_col}) dir={dir_before}")

            # Take step (next state and info dict are not needed here; the
            # frame is drawn from the environment's own attributes)
            _, reward, done, _ = env._step(action)
            total_reward += reward
            step += 1

            # Read the heading again after the step so the panel text agrees
            # with the arrow drawn on the mouse in this same frame.
            dir_after = _DIR_NAMES.get(env.mouse_direction, "?")
            _render_frame(screen, env, font, small_font, [
                f"Step: {step}  Reward: {total_reward:.1f}  Done: {done}",
                f"p_turn: {p_turn}  delay: {delay:.2f}s",
                f"Mouse dir: {dir_after}",
            ])

            if done:
                print(f"CAUGHT! Total reward: {total_reward:.1f} in {step} steps")
                time.sleep(1)
                # Reset
                env._reset()
                step = 0
                total_reward = 0

            time.sleep(delay)
            clock.tick(60)
    finally:
        # Always release the window and pygame subsystems, even if the
        # environment or a draw call raised.
        pygame.quit()

    print(f"\nVisualization ended. Final reward: {total_reward:.1f}")


if __name__ == "__main__":
    # Script entry point: optional positional arguments override the defaults.
    #   argv[1] -> p_turn (float), argv[2] -> delay in seconds (float)
    import sys

    usage = "Usage: python visualize_chase.py [p_turn] [delay]"

    p_turn = 1.0
    delay = 0.3

    try:
        if len(sys.argv) > 1:
            p_turn = float(sys.argv[1])
        if len(sys.argv) > 2:
            delay = float(sys.argv[2])
    except ValueError:
        # Non-numeric argument: show the usage text instead of a traceback.
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f"{usage}\nError: p_turn and delay must be numbers")

    if delay < 0:
        # time.sleep() rejects negative durations, so fail early and clearly.
        sys.exit(f"{usage}\nError: delay must be non-negative")

    print(f"Running visualization with p_turn={p_turn}, delay={delay}s")
    print(usage)

    run_visualization(p_turn=p_turn, delay=delay)


