"""
Experiment: Cat-Mouse Chase with Termination Probability Commitment

Cat (agent) can move in 8 directions (including diagonals).
Mouse moves in its current direction EVERY step.
With probability p_turn, mouse turns to a direction AWAY from cat.

Instead of fixed duration, agent commits to TERMINATION PROBABILITY at option start.
Each step, option terminates with this probability (geometric distribution).

Key hypothesis: Higher p_turn → higher optimal termination probability
- p_turn=0: Mouse predictable → low term prob (stay in option longer)
- p_turn=1: Mouse unpredictable → high term prob (re-aim frequently)
"""

import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from trainingProcedure.TrainingProcedureChase import (
    TrainingProcedureChase
)
from env.open_grid.open_grid_chase import (
    GRID_SIZE, ind2coord, DIR_UP, DIR_RIGHT, DIR_DOWN, DIR_LEFT
)
from MetaPolicies.NNMetaPolicyChase import (
    TERM_PROBS, ACTION_NAMES
)

# ============================================================================
# EXPERIMENT CONFIGURATION
# ============================================================================

# Sweep axes: run_experiments trains every (p_turn, switch_reward) pair once per seed.
P_TURN_VALUES = [1]
SWITCH_REWARDS = [-0.2]  # one summary-table column / plot series per entry

# Number of seeds per configuration (seeds are 0 .. N_SEEDS-1)
N_SEEDS = 1

# Training settings, forwarded unchanged to TrainingProcedureChase
NUM_EPISODES = 150_000
MAX_EPISODE_STEPS = 50
HIDDEN_SIZE = 64
LR_POLICY = 1e-3
EPSILON = 0.1
GAMMA = 0.99
STEP_REWARD = -0.2
TERMINAL_REWARD = 10.0
PRINT_EVERY = 1_000

# ============================================================================
# RUN EXPERIMENTS
# ============================================================================

def run_experiments():
    """Train one policy per (seed, p_turn, switch_reward) combination.

    Returns:
        (results, trainers): ``results`` is a list with one entry per run, each
        being whatever ``trainer.train()`` returned plus a ``'seed'`` key
        (0-based). ``trainers`` maps each p_turn to the first trainer that was
        run with it, for later visualization.
    """
    collected = []
    first_trainer_for = {}  # p_turn -> first trainer trained with that p_turn

    configs = list(product(P_TURN_VALUES, SWITCH_REWARDS))
    n_runs = len(P_TURN_VALUES) * len(SWITCH_REWARDS) * N_SEEDS
    banner = '=' * 60

    # Seeds form the outer loop, configurations the inner one.
    schedule = product(range(N_SEEDS), configs)
    for run_no, (seed, (p_turn, switch_reward)) in enumerate(schedule, start=1):
        print(f"\n{banner}")
        print(f"EXPERIMENT {run_no}/{n_runs}: p_turn={p_turn}, seed={seed+1}")
        print(banner)

        # NumPy's global RNG is seeded with the 1-based seed number.
        np.random.seed(seed + 1)

        settings = dict(
            p_turn=p_turn,
            switch_reward=switch_reward,
            step_reward=STEP_REWARD,
            terminal_reward=TERMINAL_REWARD,
            gamma=GAMMA,
            epsilon=EPSILON,
            num_episodes=NUM_EPISODES,
            max_episode_steps=MAX_EPISODE_STEPS,
            hidden_size=HIDDEN_SIZE,
            lr_policy=LR_POLICY,
            verbose=True,
            print_every=PRINT_EVERY,
        )
        trainer = TrainingProcedureChase(**settings)

        outcome = trainer.train()
        outcome['seed'] = seed
        collected.append(outcome)

        # Only the first trainer seen for a given p_turn is kept.
        first_trainer_for.setdefault(p_turn, trainer)

        print(f"Final average term prob: {outcome['final_avg_term_prob']:.3f}")

    return collected, first_trainer_for


def plot_results(results):
    """Plot mean final termination probability against p_turn.

    One error-bar series is drawn per entry of SWITCH_REWARDS; error bars are
    the sample standard deviation (ddof=1) across results sharing a
    configuration, or 0 when there is a single result. A configuration with no
    results is plotted at 0. The figure is written to
    ``chase_term_prob_vs_pturn.pdf`` / ``.png`` and then shown.
    """
    # Group the final term-prob values by (p_turn, switch_reward).
    grouped = {}
    for rec in results:
        config = (rec['p_turn'], rec['switch_reward'])
        grouped.setdefault(config, []).append(rec['final_avg_term_prob'])

    def summarize(samples):
        """Return (mean, std) for one configuration's samples."""
        if not samples:
            return 0, 0
        spread = np.std(samples, ddof=1) if len(samples) > 1 else 0.0
        return np.mean(samples), spread

    fig, ax = plt.subplots(figsize=(6, 3))

    palette = ['#1f77b4', '#2ca02c', '#ff7f0e']  # blue, green, orange

    for series_idx, switch_reward in enumerate(SWITCH_REWARDS):
        stats = [
            summarize(grouped.get((p_turn, switch_reward), []))
            for p_turn in P_TURN_VALUES
        ]
        means = [m for m, _ in stats]
        stds = [s for _, s in stats]

        ax.errorbar(
            P_TURN_VALUES,
            means,
            yerr=stds,
            fmt='o-',
            color=palette[series_idx % len(palette)],
            label=rf"$\eta$ = {abs(switch_reward):.2f}",
            linewidth=2,
            markersize=8,
            capsize=4,
        )

    ax.set_xlabel(r"$p_{turn}$", fontsize=16)
    ax.set_ylabel(r"Avg Term Prob $\beta$", fontsize=16)
    ax.tick_params(axis='both', labelsize=14)
    ax.legend(fontsize=12)
    ax.grid(True)
    ax.set_ylim(0, 1)

    plt.tight_layout()
    for filename, resolution in (("chase_term_prob_vs_pturn.pdf", 300),
                                 ("chase_term_prob_vs_pturn.png", 150)):
        plt.savefig(filename, dpi=resolution, bbox_inches='tight')
    plt.show()


def print_summary(results):
    """Print summary table of average termination probability.

    Creates a 2D table: rows = p_turn, columns = switch_reward.
    A cell shows ``mean±std`` when its configuration has more than one result
    and the plain mean otherwise. The std is the sample standard deviation
    (ddof=1), the same estimator plot_results uses for its error bars.

    Args:
        results: iterable of mappings with the keys ``'p_turn'``,
            ``'switch_reward'`` and ``'final_avg_term_prob'``. May be empty,
            in which case a short notice is printed instead of a table.
    """
    if not results:
        print("\nNo results to summarize.")
        return

    # Organize results by (p_turn, switch_reward)
    table_data = {}
    for r in results:
        key = (r['p_turn'], r['switch_reward'])
        table_data.setdefault(key, []).append(r['final_avg_term_prob'])

    # Get unique values
    p_turns = sorted({p for p, _ in table_data})
    switch_rewards = sorted({sw for _, sw in table_data})

    # Check if we have multiple seeds
    has_multiple_seeds = any(len(v) > 1 for v in table_data.values())

    # One width for header and data cells so the '|' separators line up:
    # 11 fits " 0.30±0.14 ", 9 fits "  0.300  ".
    cell_width = 11 if has_multiple_seeds else 9

    print("\n" + "="*90)
    print("SUMMARY TABLE: Average Termination Probability")
    print("="*90)
    print("\nRows = p_turn (mouse evasiveness)")
    print("Columns = switch_reward (cost of switching options)")
    print("\nHigher term prob = shorter expected duration (E[dur] = 1/p)")
    if has_multiple_seeds:
        print("Values shown as: mean ± std")
    print("-"*90)

    # Print header (two decimals so nearby switch rewards stay distinguishable)
    header = "p_turn \\ sw_rew |"
    for sw in switch_rewards:
        label = f"{sw:+.2f}"
        header += f"{label:^{cell_width}}|"
    print(header)
    print("-"*90)

    # Print rows
    for p in p_turns:
        row = f"    {p:.2f}        |"
        for sw in switch_rewards:
            vals = table_data.get((p, sw))
            if not vals:
                cell = "-"
            elif len(vals) > 1:
                cell = f"{np.mean(vals):.2f}±{np.std(vals, ddof=1):.2f}"
            else:
                cell = f"{np.mean(vals):.3f}"
            row += f"{cell:^{cell_width}}|"
        print(row)

    print("-"*90)

    # Summary
    print("\n→ Expected: Higher p_turn → Higher term prob (more re-aiming)")
    print("→ Expected: Higher switch_reward → Higher term prob (switching is rewarded)")
    print("→ Expected: Lower (negative) switch_reward → Lower term prob (switching is penalized)")

    # Check hypothesis for switch_reward=0. Comparing the lowest and highest
    # p_turn only makes sense when there are at least two distinct values.
    if len(p_turns) > 1:
        low_key = (p_turns[0], 0.0)
        high_key = (p_turns[-1], 0.0)
        if low_key in table_data and high_key in table_data:
            first = np.mean(table_data[low_key])
            last = np.mean(table_data[high_key])
            if last > first + 0.05:
                print(f"\n✓ HYPOTHESIS CONFIRMED (sw=0): term_prob {first:.3f} → {last:.3f}")
            else:
                print(f"\n✗ Hypothesis not confirmed (sw=0): term_prob {first:.3f} → {last:.3f}")


# ============================================================================
# PYGAME VISUALIZATION
# ============================================================================

def visualize_policies(trainers, delay=0.15):
    """
    Pygame visualization of trained policies.

    Shows:
    - Cat (red) and Mouse (blue) positions
    - Current option direction (arrow on cat) and its trajectory to the grid edge
    - Termination probability of the current option and the implied
      expected duration (1 / term prob)
    - Mouse direction
    - Accumulated reward and step count of the current episode

    Controls:
    - 1, 2, 3: Switch between p_turn policies
    - SPACE: Pause/Resume
    - R: Reset episode
    - +/-: Change speed
    - Q: Quit

    Args:
        trainers: mapping p_turn -> trained trainer. If empty, a notice is
            printed and nothing is shown.
        delay: seconds slept after each frame; adjustable at runtime with
            +/- within [0.02, 0.5].

    Falls back to visualize_policies_text when pygame cannot be imported.
    """
    try:
        import pygame
    except ImportError:
        print("pygame not installed. Install with: pip install pygame")
        print("Falling back to text visualization...")
        visualize_policies_text(trainers)
        return

    if not trainers:
        # Nothing to index into; avoid opening a window just to crash.
        print("No trained policies to visualize.")
        return

    pygame.init()
    try:
        _run_pygame_viewer(pygame, trainers, delay)
    finally:
        # Always release the window, even if the loop raised.
        pygame.quit()


def _run_pygame_viewer(pygame, trainers, delay):
    """Run the interactive event/step/render loop until the user quits.

    The caller is expected to have called pygame.init() and to call
    pygame.quit() afterwards. ``trainers`` must be non-empty.
    """
    import time

    CELL_SIZE = 60
    WINDOW_SIZE = GRID_SIZE * CELL_SIZE
    INFO_HEIGHT = 150

    screen = pygame.display.set_mode((WINDOW_SIZE, WINDOW_SIZE + INFO_HEIGHT))
    pygame.display.set_caption("Cat-Mouse Chase Policy Visualization")
    font = pygame.font.Font(None, 36)
    small_font = pygame.font.Font(None, 28)

    # Colors
    WHITE = (255, 255, 255)
    RED = (220, 80, 80)
    BLUE = (80, 80, 220)
    GRAY = (200, 200, 200)
    DARK_GRAY = (100, 100, 100)
    YELLOW = (255, 220, 100)

    # Direction arrows / names
    DIR_ARROWS = {DIR_UP: "↑", DIR_RIGHT: "→", DIR_DOWN: "↓", DIR_LEFT: "←"}
    DIR_NAMES = {DIR_UP: "UP", DIR_RIGHT: "RIGHT", DIR_DOWN: "DOWN", DIR_LEFT: "LEFT"}
    OPTION_ARROWS = {
        0: "↑", 1: "↗", 2: "→", 3: "↘",
        4: "↓", 5: "↙", 6: "←", 7: "↖"
    }

    # (row, col) deltas used for drawing directions
    OPTION_DELTAS = {
        0: (-1, 0), 1: (-1, 1), 2: (0, 1), 3: (1, 1),
        4: (1, 0), 5: (1, -1), 6: (0, -1), 7: (-1, -1)
    }
    MOUSE_DELTAS = {DIR_UP: (-1, 0), DIR_RIGHT: (0, 1), DIR_DOWN: (1, 0), DIR_LEFT: (0, -1)}

    # Number keys select the policy at that position in the sorted p_turn list.
    POLICY_KEYS = {pygame.K_1: 0, pygame.K_2: 1, pygame.K_3: 2}

    # Viewer state
    p_turn_list = sorted(trainers.keys())
    current_p_turn = p_turn_list[0]
    trainer = trainers[current_p_turn]
    trainer.env._reset()

    clock = pygame.time.Clock()
    paused = False

    # Option tracking
    current_option = None
    current_term_prob = 0
    in_option = False
    total_reward = 0
    episode_steps = 0

    def reset_episode():
        """Reset the active trainer's env and the per-episode bookkeeping."""
        nonlocal current_option, in_option, total_reward, episode_steps
        trainer.env._reset()
        current_option = None
        in_option = False
        total_reward = 0
        episode_steps = 0

    def select_policy(idx):
        """Switch to the idx-th policy (if it exists) and start a fresh episode."""
        nonlocal current_p_turn, trainer
        if idx >= len(p_turn_list):
            return
        current_p_turn = p_turn_list[idx]
        trainer = trainers[current_p_turn]
        reset_episode()

    def draw_grid():
        screen.fill(WHITE)
        for i in range(GRID_SIZE + 1):
            pygame.draw.line(screen, GRAY, (i * CELL_SIZE, 0), (i * CELL_SIZE, WINDOW_SIZE))
            pygame.draw.line(screen, GRAY, (0, i * CELL_SIZE), (WINDOW_SIZE, i * CELL_SIZE))

    def draw_trajectory(state, option_idx):
        """Draw the cat's trajectory line extending across the grid."""
        if state >= GRID_SIZE * GRID_SIZE or option_idx is None:
            return
        row, col = ind2coord(state)
        center_x = col * CELL_SIZE + CELL_SIZE // 2
        center_y = row * CELL_SIZE + CELL_SIZE // 2

        dr, dc = OPTION_DELTAS[option_idx]

        # Walk in the option's direction until the next cell would leave the grid
        end_row, end_col = row, col
        for _ in range(GRID_SIZE):
            next_row = end_row + dr
            next_col = end_col + dc
            if 0 <= next_row < GRID_SIZE and 0 <= next_col < GRID_SIZE:
                end_row, end_col = next_row, next_col
            else:
                break

        end_x = end_col * CELL_SIZE + CELL_SIZE // 2
        end_y = end_row * CELL_SIZE + CELL_SIZE // 2

        # Solid, light-red trajectory line
        pygame.draw.line(screen, (255, 150, 150), (center_x, center_y), (end_x, end_y), 3)

    def draw_cat(state, option_idx):
        # States at or beyond GRID_SIZE**2 are not drawn.
        if state >= GRID_SIZE * GRID_SIZE:
            return
        row, col = ind2coord(state)
        center_x = col * CELL_SIZE + CELL_SIZE // 2
        center_y = row * CELL_SIZE + CELL_SIZE // 2

        # Cat body
        pygame.draw.circle(screen, RED, (center_x, center_y), CELL_SIZE // 3)
        pygame.draw.circle(screen, (180, 50, 50), (center_x, center_y), CELL_SIZE // 3, 3)

        # Option direction arrow
        if option_idx is not None:
            arrow = OPTION_ARROWS.get(option_idx, "?")
            text = font.render(arrow, True, WHITE)
            text_rect = text.get_rect(center=(center_x, center_y))
            screen.blit(text, text_rect)

            # Draw direction line extending from cat
            dr, dc = OPTION_DELTAS[option_idx]
            line_end_x = center_x + dc * CELL_SIZE * 0.6
            line_end_y = center_y + dr * CELL_SIZE * 0.6
            pygame.draw.line(screen, (255, 200, 200), (center_x, center_y),
                             (int(line_end_x), int(line_end_y)), 4)

    def draw_mouse(state, direction):
        row, col = ind2coord(state)
        center_x = col * CELL_SIZE + CELL_SIZE // 2
        center_y = row * CELL_SIZE + CELL_SIZE // 2

        # Mouse body
        pygame.draw.circle(screen, BLUE, (center_x, center_y), CELL_SIZE // 3)
        pygame.draw.circle(screen, (50, 50, 180), (center_x, center_y), CELL_SIZE // 3, 3)

        # Direction arrow
        arrow = DIR_ARROWS.get(direction, "?")
        text = font.render(arrow, True, WHITE)
        text_rect = text.get_rect(center=(center_x, center_y))
        screen.blit(text, text_rect)

        # Draw direction line extending from mouse
        dr, dc = MOUSE_DELTAS[direction]
        line_end_x = center_x + dc * CELL_SIZE * 0.6
        line_end_y = center_y + dr * CELL_SIZE * 0.6
        pygame.draw.line(screen, (150, 150, 255), (center_x, center_y),
                         (int(line_end_x), int(line_end_y)), 4)

    def draw_info():
        # Info panel background
        pygame.draw.rect(screen, DARK_GRAY, (0, WINDOW_SIZE, WINDOW_SIZE, INFO_HEIGHT))

        y = WINDOW_SIZE + 8

        # Current policy
        policy_text = f"Policy: p_turn={current_p_turn:.1f}  [Press 1/2/3 to switch]"
        text = small_font.render(policy_text, True, YELLOW)
        screen.blit(text, (10, y))
        y += 22

        # Cat option info
        if current_option is not None:
            option_name = ACTION_NAMES[current_option.idx]
            exp_dur = 1.0 / current_term_prob if current_term_prob > 0 else float('inf')
            cat_text = f"CAT: {option_name} {OPTION_ARROWS[current_option.idx]}  TermProb: {current_term_prob:.1f}  (E[dur]={exp_dur:.1f})"
        else:
            cat_text = "CAT: (selecting option...)"
        text = small_font.render(cat_text, True, (255, 180, 180))
        screen.blit(text, (10, y))
        y += 22

        # Mouse direction
        dir_name = DIR_NAMES.get(trainer.env.mouse_direction, "?")
        arrow = DIR_ARROWS.get(trainer.env.mouse_direction, "?")
        mouse_text = f"MOUSE: {dir_name} {arrow}"
        text = small_font.render(mouse_text, True, (180, 180, 255))
        screen.blit(text, (10, y))
        y += 22

        # Stats
        stats_text = f"Reward: {total_reward:.1f}  Steps: {episode_steps}"
        text = small_font.render(stats_text, True, WHITE)
        screen.blit(text, (10, y))
        y += 22

        # Controls
        controls = "SPACE=pause  R=reset  +/-=speed  Q=quit"
        text = small_font.render(controls, True, GRAY)
        screen.blit(text, (10, y))

        # Pause indicator
        if paused:
            pause_text = font.render("PAUSED", True, YELLOW)
            screen.blit(pause_text, (WINDOW_SIZE - 100, WINDOW_SIZE + 10))

    def render_frame():
        """Redraw the whole window from the current env/viewer state."""
        option_idx = current_option.idx if current_option else None
        draw_grid()
        draw_trajectory(trainer.env.state, option_idx)
        draw_mouse(trainer.env.mouse_state, trainer.env.mouse_direction)
        draw_cat(trainer.env.state, option_idx)
        draw_info()
        pygame.display.flip()

    running = True
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_q:
                    running = False
                elif event.key == pygame.K_SPACE:
                    paused = not paused
                elif event.key == pygame.K_r:
                    reset_episode()
                elif event.key in POLICY_KEYS:
                    select_policy(POLICY_KEYS[event.key])
                elif event.key in (pygame.K_PLUS, pygame.K_EQUALS):
                    delay = max(0.02, delay - 0.03)
                elif event.key == pygame.K_MINUS:
                    delay = min(0.5, delay + 0.03)

        if not paused and not trainer.env.done:
            # Need new option?
            if not in_option:
                option, term_prob, _ = trainer.meta_policy.choose_option_and_term_prob(
                    trainer.env.state, trainer.env.mouse_state, trainer.env.mouse_direction
                )
                current_option = option
                current_term_prob = term_prob
                in_option = True

            # Execute one step
            action = current_option.intra_option_policy.get_action(
                trainer.env.state, trainer.env.mouse_state
            )
            _, reward, done, _ = trainer.env._step(action)
            total_reward += reward
            episode_steps += 1

            if done:
                # Show the final frame briefly, then start a new episode
                render_frame()
                time.sleep(0.5)
                reset_episode()
            elif np.random.random() < current_term_prob:
                # Option terminates with the committed probability each step
                in_option = False

        render_frame()

        time.sleep(delay)
        clock.tick(60)


def visualize_policies_text(trainers, max_steps=30):
    """Console fallback: roll out one episode per policy and print each step.

    Policies are visited in sorted order of ``trainers.items()``. Each rollout
    stops when the env reports done or after ``max_steps`` steps, then the
    accumulated reward is printed.
    """
    import time

    for p_turn, trainer in sorted(trainers.items()):
        print(f"\n{'='*60}")
        print(f"POLICY p_turn={p_turn}")
        print(f"{'='*60}")

        heading_label = {DIR_UP: "UP", DIR_RIGHT: "RIGHT", DIR_DOWN: "DOWN", DIR_LEFT: "LEFT"}

        trainer.env._reset()
        reward_sum = 0

        active_option = None
        beta = 0            # termination probability committed with the option
        committed = False   # True while an option is running

        for step in range(max_steps):
            if trainer.env.done:
                break

            # Pick a fresh option (and its termination probability) when none is running
            if not committed:
                active_option, beta, _ = trainer.meta_policy.choose_option_and_term_prob(
                    trainer.env.state, trainer.env.mouse_state, trainer.env.mouse_direction
                )
                committed = True
                print(f"\n--- New option: {ACTION_NAMES[active_option.idx]}, term_prob={beta:.1f} ---")

            # Report positions before moving
            cat_row, cat_col = ind2coord(trainer.env.state)
            mouse_row, mouse_col = ind2coord(trainer.env.mouse_state)
            print(f"Step {step}: Cat({cat_row},{cat_col}) Mouse({mouse_row},{mouse_col}) "
                  f"dir={heading_label[trainer.env.mouse_direction]}")

            # Advance the environment by one primitive action
            action = active_option.intra_option_policy.get_action(
                trainer.env.state, trainer.env.mouse_state
            )
            _, reward, done, _ = trainer.env._step(action)
            reward_sum += reward

            # The option ends only if the episode continues and the random draw
            # falls below beta; no draw is made once the episode is done.
            committed = done or not (np.random.random() < beta)

            time.sleep(0.1)

        print(f"\nTotal reward: {reward_sum:.1f}")


# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    # Banner describing the sweep and the experimental setup
    wide_rule = "=" * 70
    intro_lines = [
        wide_rule,
        "CAT-MOUSE CHASE EXPERIMENT - TERMINATION PROBABILITY COMMITMENT",
        wide_rule,
        f"Mouse turn probabilities: {P_TURN_VALUES}",
        f"Switch rewards: {SWITCH_REWARDS}",
        f"Available term probs: {TERM_PROBS}",
        f"Episodes: {NUM_EPISODES}",
        "",
        "Setup:",
        "  - Cat can move in 8 directions (including diagonals)",
        "  - Mouse moves in its direction EVERY step",
        "  - Mouse turns with probability p_turn (away from cat)",
        "  - Agent commits to TERMINATION PROBABILITY at option start",
        "  - Each step, option terminates with that probability",
        "",
        "Hypothesis: Higher p_turn → higher optimal term prob (more re-aiming)",
        wide_rule,
    ]
    print("\n".join(intro_lines))

    # Train, then report
    results, trainers = run_experiments()
    print_summary(results)

    # Persist the raw per-run results
    np.save("chase_results.npy", results)
    print("\nResults saved to chase_results.npy")

    # Plot results
    plot_results(results)

    # Interactive viewer (skipped when no trainer was kept)
    if trainers:
        narrow_rule = "=" * 60
        viewer_lines = [
            "\n" + narrow_rule,
            "LAUNCHING POLICY VISUALIZATION",
            narrow_rule,
            "Controls:",
            "  1, 2, 3: Switch between p_turn policies",
            "  SPACE: Pause/Resume",
            "  R: Reset episode",
            "  +/-: Change speed",
            "  Q: Quit",
            narrow_rule,
        ]
        print("\n".join(viewer_lines))

        visualize_policies(trainers)

