from gridworld import *
import torch
import numpy as np
from collections import deque
import pandas as pd
from minigrid.core.constants import OBJECT_TO_IDX

# Seed torch's global RNG at import time so code drawing from it is repeatable.
torch.manual_seed(0)

# Action name -> integer code. The position of each name in the tuple is its
# code (presumably mirrors MiniGrid's action numbering -- confirm).
action_name_to_index = {
    name: code
    for code, name in enumerate(
        ('left', 'right', 'forward', 'pickup', 'drop', 'toggle', 'done')
    )
}

# Heading names, indexed by heading code (0 = right, 1 = down, 2 = left, 3 = up).
dir_names = ['right', 'down', 'left', 'up']
# One heading code per entry of dir_names.
agent_dirs = list(range(len(dir_names)))

def possible_blue_nodes():
    """Return the candidate blue-door cells: column 6, rows 1 through 3."""
    blue_column = 6
    return [(blue_column, row) for row in range(1, 4)]



def possible_red_location():
    """Return the candidate red-door cells: column 2, rows 1 through 3."""
    red_column = 2
    return [(red_column, row) for row in range(1, 4)]


def replay_reference(start_pos, start_dir, actions):
    """Replay an action sequence and record the state each action was taken in.

    Args:
        start_pos: Initial ``(x, y)`` cell.
        start_dir: Initial heading code (0 right, 1 down, 2 left, 3 up).
        actions: Iterable of integer action codes.

    Returns:
        A list of ``((x, y), heading, action)`` tuples, one per action, holding
        the state *before* that action is applied. Only left, right and forward
        change the state; forward is applied without any bounds or wall check.
    """
    # Unit step for each heading code, indexed by heading.
    step_by_heading = ((1, 0), (0, 1), (-1, 0), (0, -1))

    col, row = start_pos
    heading = start_dir
    steps = []

    for action in actions:
        steps.append(((col, row), heading, action))

        if action == action_name_to_index['left']:
            heading = (heading - 1) % 4
        elif action == action_name_to_index['right']:
            heading = (heading + 1) % 4
        elif action == action_name_to_index['forward']:
            step_x, step_y = step_by_heading[heading]
            col, row = col + step_x, row + step_y

    return steps


def compute_per_step_utility_score(agent_trajectory, reference_trajectory):
    """Grade the tail of an agent trajectory against a reference, step by step.

    The last ``len(reference_trajectory)`` agent steps are paired in order with
    the reference steps. Each pair of ``(pos, heading, action)`` tuples scores:

    * 1.0 -- position, heading and action all match
    * 0.7 -- position and action match
    * 0.4 -- only the action matches
    * 0.0 -- the action differs

    Returns:
        A float32 NumPy array with one score per compared pair.
    """
    window = agent_trajectory[-len(reference_trajectory):]

    def grade(agent_step, ref_step):
        agent_pos, agent_dir, agent_action = agent_step
        ref_pos, ref_dir, ref_action = ref_step
        # Every non-zero score requires the same action.
        if not (agent_action == ref_action):
            return 0.0
        if not (agent_pos == ref_pos):
            return 0.4
        return 1.0 if agent_dir == ref_dir else 0.7

    scores = [grade(agent_step, ref_step)
              for agent_step, ref_step in zip(window, reference_trajectory)]
    return np.array(scores, dtype=np.float32)



def get_adjacent_facing_states(door_pos, color = 'blue'):
    """Return the single ``((x, y), heading)`` state from which a door is toggled.

    A red door is approached from the cell to its right while facing left
    (heading 2); any other colour is approached from the cell to its left while
    facing right (heading 0).

    Returns:
        A one-element set containing that state.
    """
    door_x, door_y = door_pos
    # (x offset from the door, heading that faces the door)
    x_offset, facing = (1, 2) if color == 'red' else (-1, 0)
    return {((door_x + x_offset, door_y), facing)}


def bfs_trajectory(start_pos, start_dir, valid_target, grid_width=5, grid_height=5):
    """Breadth-first search for a shortest action list that ends with a toggle.

    States are ``((x, y), heading)`` pairs. From each state the search tries
    turning left, turning right and stepping forward, in that order. The search
    knows nothing about walls or objects: a forward step is allowed whenever the
    destination cell lies inside the ``grid_width`` x ``grid_height`` rectangle.

    Args:
        start_pos: Starting ``(x, y)`` cell.
        start_dir: Starting heading code (0 right, 1 down, 2 left, 3 up).
        valid_target: Container of ``((x, y), heading)`` states that end the
            search.
        grid_width: Exclusive upper bound on x for forward moves. Defaults to 5,
            the bound that used to be hard-coded.
        grid_height: Exclusive upper bound on y for forward moves. Defaults to 5.

    Returns:
        A list of action names (``'left'``, ``'right'``, ``'forward'``) leading
        to the first target state found, followed by ``'toggle'``; or ``None``
        when no target state is reachable within the bounds.
    """
    # Unit step for each heading code, indexed by heading.
    step_by_heading = ((1, 0), (0, 1), (-1, 0), (0, -1))

    visited = set()
    queue = deque()
    queue.append((start_pos, start_dir, []))

    while queue:
        (x, y), heading, path = queue.popleft()
        state = ((x, y), heading)

        if state in valid_target:
            return path + ['toggle']
        if state in visited:
            continue
        visited.add(state)

        queue.append(((x, y), (heading - 1) % 4, path + ['left']))
        queue.append(((x, y), (heading + 1) % 4, path + ['right']))

        dx, dy = step_by_heading[heading]
        nx, ny = x + dx, y + dy
        if 0 <= nx < grid_width and 0 <= ny < grid_height:
            queue.append(((nx, ny), heading, path + ['forward']))

    return None
    

def compute_utility_sequence(agent_trajectory, red_pos, blue_pos):
    """Build a per-step utility vector for a red-door-then-blue-door trajectory.

    Three contributions are summed into the result:

    1. Red-door imitation: a BFS reference path from the agent's first recorded
       state to the cell right of the red door (facing left) is replayed and
       compared, tail-aligned, with the agent's steps up to and including its
       first toggle from that state.
    2. Proximity shaping: +0.1 for every step within Manhattan distance 2 of the
       red door while ``red_open == 0``; +0.15 for every step within distance 2
       of the blue door while ``red_open == 1`` and ``blue_open == 0``.
    3. Blue-door imitation: the steps recorded with red open and blue closed, up
       to the first toggle from the cell left of the blue door (facing right),
       are compared tail-aligned with a BFS reference path to that state.

    Args:
        agent_trajectory: Sequence of
            ``((x, y), heading, action, red_open, blue_open)`` records.
        red_pos: ``(x, y)`` of the red door.
        blue_pos: ``(x, y)`` of the blue door.

    Returns:
        A float32 NumPy array with one utility value per trajectory step (empty
        for an empty trajectory).
    """
    U = np.zeros(len(agent_trajectory), dtype=np.float32)
    if len(agent_trajectory) == 0:
        # Nothing to score; also avoids indexing agent_trajectory[0] below.
        return U

    toggle = action_name_to_index['toggle']
    start_pos, start_dir = agent_trajectory[0][0], agent_trajectory[0][1]

    # Phase 1: imitation score for reaching and toggling the red door.
    red_target_states = get_adjacent_facing_states(red_pos, color='red')
    red_path = bfs_trajectory(start_pos, start_dir, red_target_states)
    if red_path:
        red_ref = replay_reference(
            start_pos, start_dir, [action_name_to_index[a] for a in red_path]
        )
        red_toggle_idx = None
        for i, ((x, y), d, a, _, _) in enumerate(agent_trajectory):
            if (x, y) == (red_pos[0] + 1, red_pos[1]) and d == 2 and a == toggle:
                red_toggle_idx = i
                break
        if red_toggle_idx is not None:
            agent_segment = [
                (pos, d, a) for (pos, d, a, _, _) in agent_trajectory[:red_toggle_idx + 1]
            ]
            red_scores = compute_per_step_utility_score(
                agent_segment[-len(red_ref):], red_ref
            )
            # Scores cover the len(red_scores) steps ending at the toggle step.
            U[red_toggle_idx - len(red_scores) + 1 : red_toggle_idx + 1] += red_scores

    # Proximity shaping toward whichever door is the current goal.
    for i, ((x, y), d, a, red_open, blue_open) in enumerate(agent_trajectory):
        if red_open == 0:
            if abs(x - red_pos[0]) + abs(y - red_pos[1]) <= 2:
                U[i] += 0.1
        # Red is open but blue has not been toggled yet.
        if red_open == 1 and blue_open == 0:
            if abs(x - blue_pos[0]) + abs(y - blue_pos[1]) <= 2:
                U[i] += 0.15

    # Phase 2: collect the steps taken while red is open and blue is closed, up
    # to and including the first toggle from the cell left of the blue door.
    segment = []
    blue_toggle_idx = None
    for i, ((x, y), d, a, red_open, blue_open) in enumerate(agent_trajectory):
        if red_open == 1 and blue_open == 0:
            segment.append(((x, y), d, a))
            if (x, y) == (blue_pos[0] - 1, blue_pos[1]) and d == 0 and a == toggle:
                blue_toggle_idx = i
                break

    if segment and blue_toggle_idx is not None:
        blue_target_states = get_adjacent_facing_states(blue_pos, color='blue')
        # The reference starts at the red door facing left (heading 2), as the
        # original code assumed.
        # NOTE(review): red_pos is the door cell, not the cell the agent stands
        # on when toggling -- confirm this is the intended start state.
        blue_path = bfs_trajectory(red_pos, 2, blue_target_states)
        if blue_path:
            blue_ref = replay_reference(
                red_pos, 2, [action_name_to_index[a] for a in blue_path]
            )
            blue_scores = compute_per_step_utility_score(
                segment[-len(blue_ref):], blue_ref
            )
            # Scores cover the len(blue_scores) steps ending at the toggle step.
            U[blue_toggle_idx - len(blue_scores) + 1 : blue_toggle_idx + 1] += blue_scores

    return U