from typing import List, Dict

import gymnasium as gym
import numpy as np

import heapq
from collections import deque, defaultdict

from minigrid.wrappers import FullyObsWrapper, ImgObsWrapper

from DataGenerators.DataGenerator import DataGenerator
from DataGenerators.MiniGridDoorKeyDataGenerator import observation_to_grid_representation
from config import PERSISTENT_DATA_PATH

# Heading symbols mapped to the (row delta, column delta) of one forward step.
# Rows grow downwards, so '^' decreases the row index and 'V' increases it.
DIRECTIONS = {
    '^': (-1, 0),
    '>': (0, 1),
    'V': (1, 0),
    '<': (0, -1),
}
# Headings listed so that moving one index forward is a right turn and one
# index backward is a left turn (see rotate_right / rotate_left).
DIR_ORDER = ['^', '>', 'V', '<']


# Rotate right and left to get new directions
def rotate_right(direction):
    """
    Return the heading that follows ``direction`` in ``DIR_ORDER``.

    :param direction: one of the heading symbols stored in ``DIR_ORDER``
    :return: the next symbol in ``DIR_ORDER``, wrapping from the last entry
             back to the first
    :raises ValueError: if ``direction`` is not contained in ``DIR_ORDER``
    """
    position = DIR_ORDER.index(direction)
    return DIR_ORDER[(position + 1) % 4]


def rotate_left(direction):
    """
    Return the heading that precedes ``direction`` in ``DIR_ORDER``.

    :param direction: one of the heading symbols stored in ``DIR_ORDER``
    :return: the previous symbol in ``DIR_ORDER``, wrapping from the first
             entry back to the last
    :raises ValueError: if ``direction`` is not contained in ``DIR_ORDER``
    """
    position = DIR_ORDER.index(direction)
    return DIR_ORDER[(position - 1) % 4]


def find_all_optimal_paths(maze):
    """
    Enumerate every shortest action sequence that brings the agent to the goal.

    The maze is indexed as ``maze[row][col]``. A cell holding one of the
    heading symbols in ``DIRECTIONS`` marks the agent (position and heading),
    ``'G'`` marks the goal and ``'#'`` is impassable; every other cell value is
    treated as free. Each action costs one step: ``'F'`` moves one cell along
    the current heading, ``'L'`` / ``'R'`` turn in place.

    The search is a breadth-first search over ``(row, col, heading)`` states
    that records *all* predecessors reaching a state with the minimal step
    count, so equally short alternatives (e.g. ``L, L, F`` and ``R, R, F`` when
    the goal is directly behind the agent) are all reported.

    :param maze: 2D indexable grid of cell symbols; if several agent or goal
                 cells exist, the last one in row-major order is used
    :return: list of paths, each a list of ``'F'`` / ``'L'`` / ``'R'`` strings
             of identical, minimal length. The list is empty when the maze is
             empty, the agent or goal is missing, or the goal is unreachable.
             The path built from first-discovered predecessors is always the
             first element.

    NOTE: the number of shortest paths can grow quickly on large open grids,
    because every one of them is materialised in the result.
    """
    # len() is used instead of truthiness so array-like grids work as well.
    if len(maze) == 0 or len(maze[0]) == 0:
        return []

    rows, cols = len(maze), len(maze[0])

    # Locate the agent (position + heading) and the goal.
    start_state, goal = None, None
    for r in range(rows):
        for c in range(cols):
            cell = maze[r][c]
            if cell in DIRECTIONS:
                start_state = (r, c, cell)
            elif cell == 'G':
                goal = (r, c)

    if start_state is None or goal is None:
        return []  # No valid start or goal found

    # BFS bookkeeping: minimal step count per state and, for each state, every
    # (previous state, action) pair that reaches it with that minimal count.
    distance = {start_state: 0}
    predecessors = defaultdict(list)
    frontier = deque([start_state])
    goal_states = []
    min_steps = None

    while frontier:
        state = frontier.popleft()
        steps = distance[state]

        # BFS pops states in non-decreasing distance, so once we are past the
        # optimal depth nothing further can contribute to a shortest path.
        if min_steps is not None and steps > min_steps:
            break

        x, y, direction = state
        if (x, y) == goal:
            min_steps = steps
            goal_states.append(state)
            continue

        # Successors in the fixed order forward, left, right.
        successors = []
        dx, dy = DIRECTIONS[direction]
        new_x, new_y = x + dx, y + dy
        if 0 <= new_x < rows and 0 <= new_y < cols and maze[new_x][new_y] != '#':
            successors.append(((new_x, new_y, direction), 'F'))
        successors.append(((x, y, rotate_left(direction)), 'L'))
        successors.append(((x, y, rotate_right(direction)), 'R'))

        for next_state, action in successors:
            known = distance.get(next_state)
            if known is None:
                distance[next_state] = steps + 1
                predecessors[next_state].append((state, action))
                frontier.append(next_state)
            elif known == steps + 1:
                # Another equally short way into an already discovered state.
                predecessors[next_state].append((state, action))

    # Walk the predecessor graph backwards from each optimal goal state.
    # An explicit stack avoids recursion limits on long paths.
    optimal_paths = []
    for goal_state in goal_states:
        stack = [(goal_state, [])]
        while stack:
            state, reversed_actions = stack.pop()
            if state == start_state:
                optimal_paths.append(reversed_actions[::-1])
                continue
            # Push in reverse so the first-discovered predecessor is expanded
            # first, keeping the first-found path at the front of the result.
            for previous_state, action in reversed(predecessors[state]):
                stack.append((previous_state, reversed_actions + [action]))

    return optimal_paths


class FourRoomDataGenerator(DataGenerator):
    """
    Samples states from a MiniGrid FourRooms environment and labels them with
    feedback derived from shortest-path expert actions.

    The attributes ``self.env``, ``self.type`` and ``self.distribution`` read
    by the methods below are presumably set by ``DataGenerator.__init__`` —
    verify against the base class.
    """

    def __init__(self, env="MiniGrid-FourRooms-v0", type="binary_feedback", distribution=1):
        """
        Build the generator and wrap the environment so observations are
        fully observable image arrays.

        :param env: environment identifier, forwarded to the base class
        :param type: feedback type; ``sample_data`` handles "binary_feedback",
                     "preference" and "action_advising"
        :param distribution: probability of following the expert while
                             sampling; 1 full expert, 0 fully random
        """
        super().__init__(env, type, distribution)
        self.env = FullyObsWrapper(self.env)
        self.env = ImgObsWrapper(self.env)

    @staticmethod
    def _seed_at(seed_list, index):
        """Return ``seed_list[index]``, raising a descriptive IndexError when no seed is left."""
        available = 0 if seed_list is None else len(seed_list)
        if index >= available:
            raise IndexError(
                f"seed_list holds {available} seed(s) but episode number {index + 1} "
                "needs one; supply one seed per episode"
            )
        return seed_list[index]

    def sample_array_of_states(self, number=48, cutoff_length=20, seed_list=None) -> np.ndarray:
        """
        Roll out the environment and record one entry per visited state.

        Every episode is started with ``env.reset(seed=...)`` using the next
        unused entry of ``seed_list``. At each step the expert action (first
        entry of ``get_expert_actions``) is taken with probability
        ``self.distribution``; otherwise a uniformly random action from
        ``[0, 1, 2]`` is taken. An episode ends when the environment reports
        termination or truncation, or once it has run for more than
        ``cutoff_length`` steps.

        :param number: number of states to sample
        :param cutoff_length: step budget per episode; the episode is reset as
                              soon as its step count exceeds this value
        :param seed_list: sequence of reset seeds, one per episode; must not be
                          empty
        :return: np.ndarray (object dtype) of dicts with the keys
                 "Observation", "ObsString", "expert_actions", "expert_paths"
                 and "seed"; also stored in ``self.state_array``
        :raises IndexError: if ``seed_list`` is missing, empty, or exhausted
                            before ``number`` states were collected
        """
        expert_action_percentage = self.distribution
        state_list = []
        episode_count = 0
        seed_count = 0
        current_seed = self._seed_at(seed_list, seed_count)
        obs, _ = self.env.reset(seed=current_seed)
        for i in range(number):

            matrix_state = observation_to_grid_representation(self.env)
            expert_actions, paths = self.get_expert_actions(matrix_state, return_path=True)
            expert_action = expert_actions[0]
            # With action mask
            random_action = np.random.choice([0, 1, 2])
            if np.random.rand() < expert_action_percentage:
                action_taken = expert_action
            else:
                action_taken = random_action

            obs_dict = {"Observation": obs, "ObsString": matrix_state, "expert_actions": expert_actions,
                        "expert_paths": paths, "seed": current_seed}

            state_list.append(obs_dict)
            # Gymnasium returns (obs, reward, terminated, truncated, info);
            # both flags mean the episode is over and must be reset.
            next_obs, reward, terminated, truncated, _ = self.env.step(action_taken)
            episode_count += 1
            episode_over = terminated or truncated or episode_count > cutoff_length

            if not episode_over:
                obs = next_obs
            elif i + 1 < number:
                # Only start a new episode when more samples are needed, so
                # the last seed is not consumed (or missed) for nothing.
                seed_count += 1
                current_seed = self._seed_at(seed_list, seed_count)
                obs, _ = self.env.reset(seed=current_seed)
                episode_count = 0
        self.state_array = np.asarray(state_list)
        return self.state_array

    def sample_array_of_trajectories(self, number=48) -> List[Dict]:
        """Not implemented for this generator; currently returns None."""
        pass

    def sample_data(self, number=48, cutoff_length=100, seed_list=None) -> List[Dict]:
        """
        Sample states and attach feedback according to ``self.type``.

        - "binary_feedback": one entry per state and action in ``[0, 1, 2]``
        - "preference": one entry per state and ordered pair of distinct actions
        - "action_advising": one entry per state

        The feedback values come from ``get_expert_binary_feedback``,
        ``get_expert_preference`` and ``get_expert_action_advising``, which are
        not defined in this class (presumably inherited from DataGenerator).
        Any other ``self.type`` yields an empty list.

        :param number: number of states to sample
        :param cutoff_length: per-episode step budget, see
                              ``sample_array_of_states``
        :param seed_list: reset seeds, one per episode; defaults to
                          ``range(1919, 5000)``
        :return: list of feedback dicts
        """
        if seed_list is None:
            seed_list = list(range(1919, 5000))

        ret_list = []
        masked_action_list = [0, 1, 2]
        self.sample_array_of_states(number, cutoff_length, seed_list)
        for state in self.state_array:
            if self.type == "binary_feedback":
                for action in masked_action_list:
                    ret_list.append(
                        {"state": state, "action": action,
                         "feedback": self.get_expert_binary_feedback(state["ObsString"], action)})
            elif self.type == "preference":
                for first in masked_action_list:
                    for second in masked_action_list:
                        if first != second:
                            ret_list.append(
                                {"state": state, "action1": first, "action2": second,
                                 "feedback": self.get_expert_preference(state["ObsString"], first, second)})
            elif self.type == "action_advising":
                ret_list.append(
                    {"state": state, "feedback": self.get_expert_action_advising(state["ObsString"])})

        return ret_list

    def get_expert_actions(self, state, return_path=False) -> List:
        """
        Return the first action of every optimal path found for ``state``.

        :param state: grid representation accepted by ``find_all_optimal_paths``
        :param return_path: when True, also return the paths themselves
        :return: list of action ids (0 = turn left, 1 = turn right,
                 2 = forward), one per path and therefore possibly containing
                 duplicates; with ``return_path=True`` a tuple
                 ``(actions, paths)``
        """
        str_to_action_dict = {"L": 0, "R": 1, "F": 2}
        paths = find_all_optimal_paths(state)
        best_actions = [str_to_action_dict[path[0]] for path in paths]
        if return_path:
            return best_actions, paths
        return best_actions

    def get_expert_qvalue(self, state, action) -> float:
        """
        Return a heuristic stand-in for the expert q-value of a state-action
        pair. These are not learned q-values.

        :param state: grid representation accepted by ``get_expert_actions``
        :param action: action id to evaluate
        :return: 1 if ``action`` is among the expert actions, otherwise -1
        """
        expert_actions = self.get_expert_actions(state)
        if action in expert_actions:
            return 1
        return -1

    def get_expert_value(self, state) -> float:
        """
        Placeholder: no state value is available for this generator.

        NOTE(review): this returns the ``NotImplemented`` sentinel rather than
        raising ``NotImplementedError``; kept as is so existing callers are
        unaffected, but callers must not treat the result as a number.
        """
        return NotImplemented



if __name__ == "__main__":
    import os

    # Number of states sampled per configuration. The seed list provides one
    # seed per requested state, i.e. at least as many seeds as episodes that
    # can be started while collecting them.
    num_samples = 20000
    seeds = list(range(1919, 1919 + num_samples))

    # (distribution, feedback type) pairs to generate.
    # Other variants that were generated with this script in the past:
    #   (0, "preference"), (0.5, "preference"), (1, "preference") with 500 samples
    #   (1, "action_advising") with 100000 samples and a "_{num_samples}" file suffix
    configurations = [
        (0, "binary_feedback"),
        (0.5, "binary_feedback"),
        (1, "binary_feedback"),
    ]

    # Create the target directory up front so a missing folder is not first
    # discovered by np.save after the whole dataset has been generated.
    os.makedirs(f"{PERSISTENT_DATA_PATH}/FourRooms", exist_ok=True)

    for distribution, feedback_type in configurations:
        mgdg = FourRoomDataGenerator("MiniGrid-FourRooms-v0", feedback_type, distribution)
        data = mgdg.sample_data(num_samples, seed_list=seeds)
        np.save(f"{PERSISTENT_DATA_PATH}/FourRooms/FourRooms{feedback_type}_{distribution}.npy", data)
