from gridworld import *
import torch
import numpy as np
from collections import deque
import pandas as pd
from minigrid.core.constants import OBJECT_TO_IDX
from util import trajectory_m, path_to_goal

torch.manual_seed(0)

# Discrete action name -> index table (presumably mirrors minigrid's Actions
# enum ordering — confirm if the minigrid version changes).
action_name_to_index = {
    name: index
    for index, name in enumerate(
        ['left', 'right', 'forward', 'pickup', 'drop', 'toggle', 'done']
    )
}

# Heading names indexed by direction id (0 = right, 1 = down, 2 = left, 3 = up).
dir_names = ['right', 'down', 'left', 'up']
# Start cells for which reference plans are generated: column x = 1, rows 1..3.
agent_positions = [(1, row) for row in (1, 2, 3)]
# Every direction id, in the same order as dir_names.
agent_dirs = list(range(len(dir_names)))


def get_adjacent_facing_states(door_pos):
    """Return the set of (cell, direction) states from which the agent faces the door.

    Only one state is considered: the cell immediately left of the door,
    ``(x - 1, y)``, with direction 0 (facing right).
    """
    door_x, door_y = door_pos
    west_of_door = (door_x - 1, door_y)
    facing_right = 0
    return {(west_of_door, facing_right)}
      

def generate_reference_action_dict(door_pos):
    """Build reference action plans that lead to the state facing ``door_pos``.

    For every start ``(pos, dir)`` drawn from ``agent_positions`` x
    ``agent_dirs``, ask ``trajectory_m`` for a sequence of action names that
    reaches one of ``get_adjacent_facing_states(door_pos)``. Starts for which
    ``trajectory_m`` returns None are omitted.

    Returns:
        dict mapping ``(pos, dir)`` to a list of action indices
        (via ``action_name_to_index``).
    """
    targets = get_adjacent_facing_states(door_pos)
    plans = {}
    starts = ((start_pos, start_dir)
              for start_pos in agent_positions
              for start_dir in agent_dirs)
    for start in starts:
        action_names = trajectory_m(start[0], start[1], targets)
        if action_names is None:
            continue  # no plan from this start
        plans[start] = [action_name_to_index[name] for name in action_names]
    return plans


def door_nodes(env_size=5):
    """Return the candidate door positions for the given ``env_size``.

    Only ``env_size == 5`` has a known layout, with door positions
    ``(2, 1)`` and ``(2, 2)``.

    Raises:
        ValueError: if ``env_size`` has no known door layout. Raising an
            explicit error avoids an implicit ``None`` that would fail later
            and far from the cause.
    """
    if env_size == 5:
        return [(2, 1), (2, 2)]
    raise ValueError(
        f"door_nodes: unsupported env_size {env_size!r}; only 5 is supported"
    )
    
    
# Reference "plans" (action-index sequences) from each start (position,
# direction) to the state facing the door, one dict per candidate door
# position. Built once at import time; compute_utility_sequence adds entries
# for other door positions on demand. (Whether a plan ends with 'toggle'
# depends on util.trajectory_m — confirm there.)
nodes = {door: generate_reference_action_dict(door) for door in door_nodes()}



    
def replay_reference(start_pos, start_dir, actions):
    """Replay ``actions`` from ``(start_pos, start_dir)`` without an environment.

    Each emitted tuple ``((x, y), dir, action)`` records the state *before*
    the action is applied. Only 'left', 'right' and 'forward' change the
    state; no walls or obstacles are checked.
    """
    # Unit step per direction id: right, down, left, up.
    step_offsets = [(1, 0), (0, 1), (-1, 0), (0, -1)]
    turn_left = action_name_to_index['left']
    turn_right = action_name_to_index['right']
    move_forward = action_name_to_index['forward']

    (x, y), heading = start_pos, start_dir
    steps = []
    for action in actions:
        steps.append(((x, y), heading, action))
        if action == turn_left:
            heading = (heading - 1) % 4
        elif action == turn_right:
            heading = (heading + 1) % 4
        elif action == move_forward:
            dx, dy = step_offsets[heading]
            x += dx
            y += dy
    return steps
    
  
def get_best_matching_reference(agent_segment, reference_action_dict):
    """Pick the replayed reference trajectory that best matches the end of ``agent_segment``.

    Every plan in ``reference_action_dict`` (keyed by ``(pos, dir)``) is
    replayed into ``((x, y), dir, action)`` steps. Each replay is scored by
    summing ``compute_per_step_utility_score`` over the aligned tail of
    ``agent_segment``. Ties keep the earliest plan. An empty dict yields ``[]``.
    """
    replays = (
        replay_reference(origin, heading, plan)
        for (origin, heading), plan in reference_action_dict.items()
    )

    def _match_score(ref_traj):
        tail = agent_segment[-len(ref_traj):]
        return compute_per_step_utility_score(tail, ref_traj).sum()

    return max(replays, key=_match_score, default=[])
  
  

def compute_per_step_utility_score(agent_trajectory, reference_trajectory):
    """Score the tail of ``agent_trajectory`` step-by-step against a reference.

    The last ``len(reference_trajectory)`` agent steps are paired, in order,
    with the reference steps (``zip`` stops at the shorter side). Each pair of
    ``(pos, dir, action)`` steps scores:
      1.0 if position, direction and action all match,
      0.7 if position and action match but direction differs,
      0.4 if only the action matches,
      0.0 otherwise.

    Returns:
        np.ndarray of float32, one entry per paired step.
    """
    ref_len = len(reference_trajectory)
    agent_tail = agent_trajectory[-ref_len:]
    scores = []
    for agent_step, ref_step in zip(agent_tail, reference_trajectory):
        a_pos, a_dir, a_act = agent_step
        r_pos, r_dir, r_act = ref_step
        if a_act != r_act:
            score = 0.0
        elif a_pos != r_pos:
            score = 0.4
        elif a_dir != r_dir:
            score = 0.7
        else:
            score = 1.0
        scores.append(score)
    return np.array(scores, dtype=np.float32)



def compute_keyseg(agent_trajectory):
    """Return the key phase: every step up to and including the first one carrying the key.

    ``agent_trajectory`` holds ``(pos, dir, action, carrying_key)`` tuples.
    The result drops the carrying flag: a list of ``(pos, dir, action)``.
    If the key is never carried, the whole trajectory is returned.
    """
    key_phase = []
    for step in agent_trajectory:
        pos, heading, action, holding_key = step
        key_phase.append((pos, heading, action))
        if holding_key:
            break
    return key_phase

def compute_doorseg(agent_trajectory, door_pos):
    """Extract the door phase of an agent trajectory.

    The door phase starts on the step after the agent is first seen carrying
    the key. It ends at (and includes) the first 'toggle' issued from a state
    in ``get_adjacent_facing_states(door_pos)``. If no such toggle occurs, the
    segment runs to the end of the trajectory.

    Args:
        agent_trajectory: sequence of ``(pos, dir, action, carrying_key)`` tuples.
        door_pos: ``(x, y)`` of the door.

    Returns:
        list of ``(pos, dir, action)`` tuples; empty if the key is never
        picked up.
    """
    key_index = next((i for i, (_, _, _, k) in enumerate(agent_trajectory) if k), None)
    if key_index is None:
        # No key means no door phase. Return an empty list rather than a
        # scalar so callers can always slice / iterate / len() the result.
        return []
    door_segment = agent_trajectory[key_index + 1:]
    valid_target = get_adjacent_facing_states(door_pos)
    toggle = action_name_to_index['toggle']
    for j, (pos, heading, action, _) in enumerate(door_segment):
        if action == toggle and (pos, heading) in valid_target:
            door_segment = door_segment[:j + 1]
            break
    return [(pos, heading, action) for (pos, heading, action, _) in door_segment]


def compute_goalseg(agent_trajectory, door_pos, goal_pos=(3, 3)):
    """Extract the goal phase: from the first visit to the door cell to the goal.

    The segment starts at the first step whose position equals ``door_pos``.
    It runs up to and including the first later step at ``goal_pos``, or to
    the end of the trajectory if the goal is never reached.

    Args:
        agent_trajectory: sequence of ``(pos, dir, action, ...)`` tuples.
        door_pos: ``(x, y)`` of the door cell.
        goal_pos: ``(x, y)`` of the goal cell.

    Returns:
        ``(segment, entrance_dir)``. ``segment`` is a list of
        ``(pos, dir, action)`` tuples. ``entrance_dir`` is the agent's
        direction on its first step in the door cell. If the door cell is
        never visited, returns ``([], None)`` instead of failing on unbound
        locals.
    """
    start = next(
        (i for i, step in enumerate(agent_trajectory) if step[0] == door_pos),
        None,
    )
    if start is None:
        return [], None

    entrance_dir = agent_trajectory[start][1]
    segment = []
    for step in agent_trajectory[start:]:
        pos, heading, action = step[0], step[1], step[2]
        segment.append((pos, heading, action))
        if pos == goal_pos:
            break
    return segment, entrance_dir
    
    
def compute_utility_sequence(agent_trajectory, door_pos, obs_seq):
    """Return a per-step utility vector aligned with ``agent_trajectory``.

    Only the door phase is scored for now. The steps after the key is first
    carried, up to and including the door toggle (see ``compute_doorseg``),
    are compared against the best-matching reference plan for ``door_pos``.
    Every other step gets utility 0.

    Args:
        agent_trajectory: list of ``(pos, dir, action, carrying_key)`` tuples.
        door_pos: ``(x, y)`` of the door; reference plans are cached in the
            module-level ``nodes`` dict.
        obs_seq: currently unused; kept for interface compatibility.

    Returns:
        np.ndarray of float32 with length ``len(agent_trajectory)``.
    """
    U = np.zeros(len(agent_trajectory), dtype=np.float32)

    # Normalise a falsy non-list "no key" result to [] so nothing downstream
    # ever tries to slice a scalar.
    door_segment = compute_doorseg(agent_trajectory, door_pos) or []
    key_segment = compute_keyseg(agent_trajectory)
    print("Door segment:", door_segment)
    print("key segment:", key_segment)

    # TODO: key-phase and goal-phase utilities (compute_keyseg /
    # compute_goalseg + path_to_goal) are not folded into U yet. Goal scoring
    # also needs an entrance direction, which does not exist when the door
    # cell is never reached.
    if not door_segment:
        return U

    key_index = next(
        (i for i, (_, _, _, carrying) in enumerate(agent_trajectory) if carrying),
        None,
    )
    if key_index is None:
        return U

    reference_action_dict = nodes.get(door_pos)
    if reference_action_dict is None:
        reference_action_dict = generate_reference_action_dict(door_pos)
        nodes[door_pos] = reference_action_dict

    reference = get_best_matching_reference(door_segment, reference_action_dict)
    U_door = compute_per_step_utility_score(door_segment, reference)

    # compute_per_step_utility_score scores the *last* len(U_door) steps of
    # door_segment, so U_door is right-aligned with the end of the door phase;
    # writing it at the start would credit the wrong steps whenever the door
    # segment is longer than the reference plan.
    door_end = key_index + 1 + len(door_segment)
    U[door_end - len(U_door):door_end] += U_door
    return U


