import copy
from abc import ABC, abstractmethod
from typing import List, Union, Dict
import gymnasium as gym
import numpy as np

from DataGenerators.DataGenerator import DataGenerator
from config import PERSISTENT_DATA_PATH
import os
import gymnasium as gym
import stable_baselines3
from stable_baselines3 import PPO
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3 import A2C, PPO

from minigrid.wrappers import ImgObsWrapper, FullyObsWrapper, PositionBonus, ActionBonus
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


def minigrid_full_render_without_mask(env):
    """
    Renders the whole grid of a MiniGrid environment without a highlight mask.

    Args:
        env: Object exposing ``grid``, ``tile_size``, ``agent_pos`` and ``agent_dir``.
    Returns:
        The value produced by ``env.grid.render`` (presumably an RGB image array).
    """
    render = env.grid.render
    return render(env.tile_size, env.agent_pos, env.agent_dir, highlight_mask=None)

def observation_to_grid_representation(env):
    """
    Builds a character matrix describing the current MiniGrid layout.

    Empty cells become '.', every other cell is translated with ``obj_to_char``
    and the agent's cell is overwritten with its direction glyph.

    Args:
        env: The MiniGrid environment object (needs ``width``, ``height``, ``grid``,
            ``agent_pos`` and ``agent_dir``).
    Returns:
        A 2D NumPy array of single characters, transposed so that it is indexed
        as ``[y, x]``.
    """
    width, height = env.width, env.height
    cells = np.full((width, height), ' ')  # indexed [x, y] while being filled

    # np.ndindex walks x in the outer loop and y in the inner loop
    for x, y in np.ndindex(width, height):
        occupant = env.grid.get(x, y)
        cells[x, y] = '.' if occupant is None else obj_to_char(occupant)

    # The agent is drawn last so that it hides whatever it is standing on
    pos = env.agent_pos
    facing = env.agent_dir  # 0=right, 1=down, 2=left, 3=up
    cells[pos[0], pos[1]] = agent_dir_to_char(facing)

    return cells.T

def obj_to_char(obj):
    """
    Translates a MiniGrid object into its one-character grid symbol.

    Args:
        obj: The object from the MiniGrid environment (only ``obj.type`` is read).

    Returns:
        The symbol registered for ``obj.type``, or '?' for unknown object types.
    """
    symbols = {
        'wall': '#',
        'goal': 'G',
        'lava': 'L',
        'door': 'D',
        'key': 'K',
        'ball': 'B',
        'box': 'X',
        'water': 'W',
    }
    return symbols.get(obj.type, '?')  # '?' marks unknown objects

def agent_dir_to_char(agent_dir):
    """
    Translates the agent's heading into its grid symbol.

    Args:
        agent_dir: The direction the agent is facing (0=right, 1=down, 2=left, 3=up).

    Returns:
        '>', 'V', '<' or '^' for the four known headings, 'A' for anything else.
    """
    glyphs = {0: ">", 1: "V", 2: "<", 3: "^"}
    return glyphs.get(agent_dir, 'A')  # 'A' is the fallback for unknown headings


def matrix_rot_90_counterclockwise(m):
    """
    Rotates a character grid by 90 degrees counterclockwise.

    The first agent glyph found in the rotated grid ("V", "^", "<" or ">") is
    replaced so that the agent's heading is rotated together with the grid.

    Args:
        m: 2D array-like of single-character cells.

    Returns:
        A new NumPy array holding the rotated grid. The input is left untouched.
    """
    # Heading after a 90 degree counterclockwise turn
    rotated_heading = {
        "V": ">",
        ">": "^",
        "^": "<",
        "<": "V",
    }
    # np.rot90 returns a view of its input; copy first so that rewriting the
    # agent glyph does not leak back into the caller's matrix.
    rotated = np.rot90(m).copy()
    for index, cell in np.ndenumerate(rotated):
        if cell in rotated_heading:
            rotated[index] = rotated_heading[cell]
            break  # only the first agent glyph is adjusted
    return rotated


def rotate_to_egocentric(m):
    """
    Rotates a character grid so that the agent ends up facing upwards ("^").

    Args:
        m: 2D NumPy array of single-character cells containing at most one
            relevant agent glyph ("V", "^", "<" or ">"); the first one found
            in row-major order is used.

    Returns:
        The rotated grid with the agent glyph replaced by "^". When no agent
        glyph is present the input object itself is returned unchanged.
        The input array is never modified.
    """
    # Number of counterclockwise quarter turns needed for each heading
    quarter_turns = {
        "V": 2,
        ">": 1,
        "^": 0,
        "<": 3,
    }
    agent_pos = None
    for index, cell in np.ndenumerate(m):
        if cell in quarter_turns:
            agent_pos = index
            break

    if agent_pos is None:
        return m

    turns = quarter_turns[m[agent_pos]]
    # Work on a copy: the glyph rewrite below must not alter the caller's grid
    ego = np.array(m)
    ego[agent_pos] = "^"
    return np.rot90(ego, k=turns)


class MinigridFeaturesExtractor(BaseFeaturesExtractor):
    """
    Small CNN feature extractor: three 2x2 convolutions (16, 32 and 64 channels),
    each followed by a ReLU, then a flatten and a linear + ReLU projection to
    ``features_dim``.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 512,
                 normalized_image: bool = False) -> None:
        """
        :param observation_space: observation space; ``shape[0]`` is used as the
            number of input channels
        :param features_dim: size of the produced feature vector
        :param normalized_image: accepted but not used by this extractor
        """
        super().__init__(observation_space, features_dim)
        in_channels = observation_space.shape[0]
        conv_stack = [
            nn.Conv2d(in_channels, 16, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (2, 2)),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # Push one sampled observation through the CNN to learn the flattened size
        with torch.no_grad():
            probe = torch.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(probe).shape[1]

        projection = nn.Linear(n_flatten, features_dim)
        self.linear = nn.Sequential(projection, nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Returns the projected CNN features for a batch of observations."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)


class DoorkeyDataGenerator(DataGenerator):
    """
    Data generator for DoorKey that uses a pretrained PPO policy as the expert.

    Marked as abandoned by the original author because of a bug in the
    holding-key flag (the flag is inferred from the chosen action rather than
    read from the environment).
    """

    def __init__(self, env, type, distribution,
                 expert_path=PERSISTENT_DATA_PATH + "/expert/doorkey_fully_observable_feature_dim5seed2performance0.9602200000000002.pkl"):

        """

        :param env: forwarded to ``DataGenerator.__init__``
        :param type: feedback type; ``sample_data`` handles "binary_feedback",
            "preference" and "action_advising"
        :param distribution:  1: full expert, 0 fully random
        :param expert_path: path of the saved PPO expert
        Abandoned, bug in holding key flag
        """
        super().__init__(env, type, distribution)
        policy_kwargs = dict(
            features_extractor_class=MinigridFeaturesExtractor,
            features_extractor_kwargs=dict(features_dim=5)
        )

        self.expert = PPO.load(expert_path, custom_objects={"policy_kwargs": policy_kwargs, "learning_rate": 2.5e-4,
                                                            "clip_range": 0.4}, device="cpu")
        self.env = FullyObsWrapper(self.env)
        self.env = ImgObsWrapper(self.env)

    def sample_array_of_states(self, number=48, cutoff_length=20) -> np.array:
        """
        This function samples an array of states from the environment.

        Each step takes the expert's top action with probability
        ``self.distribution`` and a random action from {0, 1, 2, 3, 5} otherwise.
        The episode is restarted when it ends (terminated or truncated) or after
        more than ``cutoff_length`` steps.

        :param number: number of states to sample
        :param cutoff_length: maximum number of steps tolerated in one episode
        :return: np.array of states
        """
        expert_action_percentage = self.distribution
        state_list = []
        episode_count = 0
        episode_hold_key_flag = False
        episode_unlock_door_flag = False

        obs, _ = self.env.reset()
        for _ in range(number):

            expert_actions = self.get_expert_actions(obs)
            expert_action = expert_actions[0]
            # With action mask
            random_action = np.random.choice([0, 1, 2, 3, 5])
            if np.random.rand() < expert_action_percentage:
                action_taken = expert_action
            else:
                action_taken = random_action

            obs_string = observation_to_grid_representation(self.env)
            obs_dict = {"Observation": obs, "ObsString": obs_string, "HoldingKey": episode_hold_key_flag,
                        "DoorUnlocked": episode_unlock_door_flag, "expert_actions": expert_actions}

            # The flags describe the state *after* this step, so they are updated
            # only once the current state has been recorded above.
            if expert_action == 3 and action_taken == 3:
                # Pick up the key
                episode_hold_key_flag = True
            if expert_action == 5 and action_taken == 5:
                # Unlock the door
                episode_unlock_door_flag = True

            state_list.append(obs_dict)
            next_obs, reward, terminated, truncated, _ = self.env.step(action_taken)
            # An episode is over when it terminates OR hits the time limit
            done = terminated or truncated
            # Count the step so that cutoff_length actually takes effect
            episode_count += 1
            if done or episode_count > cutoff_length:
                obs, _ = self.env.reset()
                episode_count = 0
                episode_hold_key_flag = False
                episode_unlock_door_flag = False
            else:
                obs = next_obs
        self.state_array = np.asarray(state_list)
        return self.state_array

    def sample_array_of_trajectories(self, number=48) -> List[Dict]:
        """Not implemented for this generator; returns None."""
        pass

    def sample_data(self, number=48) -> List[Dict]:
        """
        Samples ``number`` states and attaches expert feedback of ``self.type``.

        :param number: number of states to sample
        :return: list of feedback dictionaries (one per state for
            "action_advising", one per action for "binary_feedback", one per
            ordered action pair for "preference")
        """
        ret_list = []
        masked_action_list = [0, 1, 2, 3, 5]
        self.sample_array_of_states(number)
        for i in self.state_array:
            probs = self.get_expert_probabilities(i["Observation"])
            if self.type == "binary_feedback":
                for j in masked_action_list:
                    # NOTE(review): the state dict ``i`` is passed where an action
                    # (``j``) would be expected - confirm against the signature of
                    # DataGenerator.get_expert_binary_feedback.
                    ret_list.append(
                        {"state": i, "action": j, "feedback": self.get_expert_binary_feedback(i["Observation"], i),
                         "probs": probs})
            elif self.type == "preference":
                for j in masked_action_list:
                    for k in masked_action_list:
                        if j != k:
                            ret_list.append({"state": i, "action1": j, "action2": k,
                                             "feedback": self.get_expert_preference(i["Observation"], j, k),
                                             "probs": probs})
            elif self.type == "action_advising":
                ret_list.append(
                    {"state": i, "feedback": self.get_expert_action_advising(i["Observation"]), "probs": probs})

        return ret_list

    def get_expert_actions(self, state) -> List:
        """
        This function returns the expert actions for a given state.

        The two most likely actions are returned when they are nearly tied
        (probability gap below 0.2) and together carry more than 0.8 of the
        probability mass; otherwise only the most likely action is returned.

        :param state: image observation accepted by ``get_expert_probabilities``
        :return: list with one or two action indices (0-dim tensors)
        """
        probs = self.get_expert_probabilities(state)
        sorted_values, sorted_index = torch.sort(probs, descending=True)
        if (sorted_values[0] - sorted_values[1]) < 0.2 and sorted_values[0] + sorted_values[1] > 0.8:
            return [(sorted_index[0]), (sorted_index[1])]
        else:
            return [(sorted_index[0])]

    def get_expert_probabilities(self, state) -> torch.Tensor:
        """
        This function returns the expert's action probabilities for a given state.

        :param state: image observation as a NumPy array; it is permuted from
            channel-last to channel-first before entering the policy
        :return: 1D tensor with the softmax over the policy's action logits
        """
        tmp = self.expert.policy.extract_features(torch.from_numpy(state).float().unsqueeze(0).permute(0, 3, 1, 2))

        tmp = self.expert.policy.mlp_extractor.policy_net(tmp)
        probs = self.expert.policy.action_net(tmp)
        probs = torch.nn.functional.softmax(probs, dim=-1)[0]
        return probs

    def get_expert_qvalue(self, state, action) -> float:
        """
        This function returns the expert q-value for a given state-action pair.
        Not actual learned qvalues but a heuristic.
        :param state: image observation
        :param action: action index to score
        :return: 1 if the action is one of the expert actions, -1 otherwise
        """
        expert_actions = self.get_expert_actions(state)
        if action in expert_actions:
            return 1
        else:
            return -1

    def test_expert_policy(self):
        """
        This function tests the expert policy on the environment.

        Runs one episode (at most 100 steps) with the expert's top action and
        prints every transition.
        :return:
        """
        state, _ = self.env.reset()
        done = False
        steps = 0
        while not done and steps < 100:
            actions = self.get_expert_actions(state)
            action = actions[0]
            next_state, _, terminated, truncated, _ = self.env.step(action)
            # Stop on the time limit as well as on termination
            done = terminated or truncated
            print(state, action, next_state)
            state = next_state
            steps += 1

    def get_expert_value(self, state) -> float:
        """No state value is available for this expert; returns ``NotImplemented``."""
        return NotImplemented


from collections import deque


def find_all_optimal_paths_doorkey(maze, expected_maximum_length=15, holding_key=False, door_unlocked=False):
    """
    Finds every shortest action sequence that solves a DoorKey maze.

    The search state is ``(position, heading, has_key, door_unlocked)``. The goal
    is reached when the agent stands on the goal cell while holding the key and
    with the door unlocked. Only walls ('#') block movement in this model; the
    key and door cells can be walked over.

    Args:
        maze: 2D grid (list of rows or NumPy array) of single characters, with
            the agent drawn as one of '^', 'V', '<', '>', the goal as 'G', the
            key as 'K' and the door as 'D'. Rows are the first index.
        expected_maximum_length: Paths longer than this many actions are ignored.
        holding_key: Whether the agent already holds the key.
        door_unlocked: Whether the door is already unlocked.

    Returns:
        A list of paths, each a list of action names ('pick_key', 'unlock_door',
        'turn_left', 'turn_right', 'move'). All returned paths have the same,
        minimal length. The list is empty when no solution exists within the
        length limit or when the maze contains no agent.
    """
    # Directions and their corresponding (row, column) deltas
    DIRS = {'^': (-1, 0), 'V': (1, 0), '<': (0, -1), '>': (0, 1)}
    DIR_ORDER = ['^', '>', 'V', '<']  # Clockwise rotation

    # Find initial position, goal, key, and door positions
    start_pos, start_dir = None, None
    goal_pos, key_pos, door_pos = None, None, None
    for r, row in enumerate(maze):
        for c, cell in enumerate(row):
            if cell in DIRS:
                start_pos = (r, c)
                start_dir = str(cell)
            elif cell == 'G':
                goal_pos = (r, c)
            elif cell == 'K':
                key_pos = (r, c)
            elif cell == 'D':
                door_pos = (r, c)

    # Without an agent there is nothing to plan
    if start_pos is None:
        return []

    n_rows, n_cols = len(maze), len(maze[0])

    def successors(state):
        """Yields (action, next_state) pairs in a fixed action order."""
        pos, direction, has_key, unlocked = state
        dr, dc = DIRS[direction]
        ahead = (pos[0] + dr, pos[1] + dc)

        # Pick up key (if the agent is facing it)
        if not has_key and ahead == key_pos:
            yield 'pick_key', (pos, direction, True, unlocked)
        # Unlock door (if the agent is facing it and holds the key)
        if has_key and not unlocked and ahead == door_pos:
            yield 'unlock_door', (pos, direction, has_key, True)

        heading = DIR_ORDER.index(direction)
        yield 'turn_left', (pos, DIR_ORDER[(heading - 1) % 4], has_key, unlocked)
        yield 'turn_right', (pos, DIR_ORDER[(heading + 1) % 4], has_key, unlocked)

        # Move forward unless that leaves the grid or hits a wall
        if 0 <= ahead[0] < n_rows and 0 <= ahead[1] < n_cols and maze[ahead[0]][ahead[1]] != '#':
            yield 'move', (ahead, direction, has_key, unlocked)

    start_state = (start_pos, start_dir, holding_key, door_unlocked)
    # Fewest actions known to reach each state. A state may be enqueued again
    # only when it is reached with exactly that many actions, which keeps every
    # shortest path while discarding all longer detours.
    fewest_steps = {start_state: 0}
    queue = deque([(start_state, [])])
    optimal_paths = []

    while queue:
        state, path = queue.popleft()
        steps = len(path)

        # The queue is ordered by path length, so once a path is too long (or
        # longer than a solution already found) every remaining one is as well.
        if steps > expected_maximum_length:
            break
        if optimal_paths and steps > len(optimal_paths[0]):
            break

        pos, _, has_key, unlocked = state
        # Check if we reached the goal
        if pos == goal_pos and has_key and unlocked:
            optimal_paths.append(path)
            continue

        for action, next_state in successors(state):
            if fewest_steps.setdefault(next_state, steps + 1) == steps + 1:
                queue.append((next_state, path + [action]))

    return optimal_paths


class DoorkeyDataGeneratorBFS(DataGenerator):
    """
    Data generator for DoorKey whose expert is the breadth-first planner
    ``find_all_optimal_paths_doorkey`` run on the character grid of the
    environment.
    """

    def __init__(self, env, type, distribution):

        """
        :param env: forwarded to ``DataGenerator.__init__``
        :param type: feedback type; ``sample_data`` handles "binary_feedback",
            "preference" and "action_advising"
        :param distribution:  1: full expert, 0 fully random
        """
        super().__init__(env, type, distribution)
        self.env = FullyObsWrapper(self.env)

    def sample_array_of_states(self, number=48, cutoff_length=20, seed_list=None) -> np.array:
        """
        This function samples an array of states from the environment.

        Each step takes the first expert action with probability
        ``self.distribution`` and a random action from {0, 1, 2, 3, 5} otherwise.
        The episode is restarted with the next seed when it ends (terminated or
        truncated) or after more than ``cutoff_length`` steps.

        :param number: number of states to sample
        :param cutoff_length: maximum number of steps tolerated in one episode
        :param seed_list: seeds used for consecutive episodes; when it is missing
            or exhausted the environment is reset without a seed
        :return: np.array of states
        """
        seeds = seed_list if seed_list is not None else []

        def seed_for(episode_index):
            # Fall back to an unseeded reset instead of raising IndexError
            return seeds[episode_index] if episode_index < len(seeds) else None

        expert_action_percentage = self.distribution
        state_list = []
        episode_count = 0
        seed_count = 0
        current_seed = seed_for(seed_count)
        obs, _ = self.env.reset(seed=current_seed)
        holding_key_flag = False
        door_unlocked_flag = False
        history_list = []
        for _ in range(number):

            matrix_state = observation_to_grid_representation(self.env)
            expert_actions, paths = self.get_expert_actions(matrix_state, return_path=True,
                                                            holding_key=holding_key_flag,
                                                            door_unlocked=door_unlocked_flag)

            # The planner returns nothing when no solution fits its length limit
            expert_action = expert_actions[0] if expert_actions else None

            # With action mask
            random_action = np.random.choice([0, 1, 2, 3, 5])
            follow_expert = np.random.rand() < expert_action_percentage
            if follow_expert and expert_action is not None:
                action_taken = expert_action
            else:
                action_taken = random_action

            image = minigrid_full_render_without_mask(self.env)

            obs_dict = {"Observation": obs, "ObsString": matrix_state, "HoldingKey": holding_key_flag,
                        "DoorUnlocked": door_unlocked_flag, "expert_actions": expert_actions, "expert_paths": paths,
                        "seed": current_seed, "action_taken": action_taken, "history": history_list.copy(), "image": image}

            state_list.append(obs_dict)

            # The flags describe the state after this step, so they change only
            # once the current state has been recorded.
            if expert_action == 3 and action_taken == 3:
                holding_key_flag = True
            if expert_action == 5 and action_taken == 5:
                door_unlocked_flag = True

            next_obs, reward, terminated, truncated, _ = self.env.step(action_taken)
            # An episode is over when it terminates OR hits the time limit
            done = terminated or truncated
            next_image = minigrid_full_render_without_mask(self.env)

            key_info = "You have the key." if holding_key_flag else "You don't have the key."
            door_info = "The door is unlocked." if door_unlocked_flag else "The door is locked."
            extra_info = key_info + door_info
            history_list.append({"state": matrix_state, "action": action_taken,
                                 "next_state": observation_to_grid_representation(self.env), "done": done,
                                 "extra": extra_info, "image": image, "next_image": next_image})

            # Count the step so that cutoff_length actually takes effect
            episode_count += 1
            if done or episode_count > cutoff_length:
                seed_count += 1
                current_seed = seed_for(seed_count)
                obs, _ = self.env.reset(seed=current_seed)
                episode_count = 0
                holding_key_flag = False
                door_unlocked_flag = False

                history_list = []
            else:
                obs = next_obs
        self.state_array = np.asarray(state_list)
        return self.state_array

    def sample_data(self, number=48, cutoff_length=20, seed_list=None) -> List[Dict]:
        """
        Samples ``number`` states and attaches expert feedback of ``self.type``.

        :param number: number of states to sample
        :param cutoff_length: forwarded to ``sample_array_of_states``
        :param seed_list: forwarded to ``sample_array_of_states``
        :return: list of feedback dictionaries. "binary_feedback" yields one
            entry per action (feedback is True for expert actions), "preference"
            one entry per ordered action pair (1 if only the first action is an
            expert action, -1 if only the second one is, 0 otherwise) and
            "action_advising" one entry per state with the expert actions.
        """
        ret_list = []
        masked_action_list = [0, 1, 2, 3, 5]
        self.sample_array_of_states(number, cutoff_length, seed_list)
        for i in self.state_array:
            expert_actions = i["expert_actions"]
            if self.type == "binary_feedback":
                for j in masked_action_list:
                    feedback = j in expert_actions
                    ret_list.append({"state": i, "action": j, "feedback": feedback})
            elif self.type == "preference":
                for j in masked_action_list:
                    for k in masked_action_list:
                        if j != k:
                            # +1 / -1 when exactly one of the two is an expert
                            # action, 0 when both or neither are
                            preference = int(j in expert_actions) - int(k in expert_actions)
                            ret_list.append({"state": i, "action1": j, "action2": k, "feedback": preference})
            elif self.type == "action_advising":
                ret_list.append({"state": i, "feedback": expert_actions})

        return ret_list

    def sample_array_of_trajectories(self, number) -> List[Dict]:
        """Not implemented for this generator; returns None."""
        pass

    def get_expert_actions(self, env, return_path=False, holding_key=False, door_unlocked=False) -> List:
        """
        This function returns the expert actions for a given state.

        :param env: despite its name, the character grid of the state (as built
            by ``observation_to_grid_representation``), not an environment
        :param return_path: also return the full optimal paths
        :param holding_key: whether the agent currently holds the key
        :param door_unlocked: whether the door is currently unlocked
        :return: list with the first action of every optimal path (duplicates
            are kept), or ``(actions, paths)`` when ``return_path`` is set
        """
        doorkey_str_to_action_dict = {
            "turn_left": 0,
            "turn_right": 1,
            "move": 2,
            "pick_key": 3,
            "unlock_door": 5
        }
        paths = find_all_optimal_paths_doorkey(env, holding_key=holding_key, door_unlocked=door_unlocked)
        actions = [doorkey_str_to_action_dict[path[0]] for path in paths]
        if return_path:
            return actions, paths
        else:
            return actions

    def get_expert_probabilities(self, state) -> torch.Tensor:
        """
        The BFS expert has no action probabilities; always returns None.
        :param state: ignored
        :return: None
        """
        return None

    def get_expert_qvalue(self, state, action) -> float:
        """
        This function returns the expert q-value for a given state-action pair.
        Not actual learned qvalues but a heuristic.
        :param state: character grid of the state; the planner is run with its
            default flags (no key held, door locked)
        :param action: action index to score
        :return: 1 if the action is one of the expert actions, -1 otherwise
        """
        expert_actions = self.get_expert_actions(state)
        if action in expert_actions:
            return 1
        else:
            return -1

    def test_expert_policy(self):
        """
        This function tests the expert policy on the environment.

        Runs one episode (at most 100 steps) with the first expert action and
        prints every transition as character grids.
        :return:
        """
        self.env.reset()
        holding_key_flag = False
        door_unlocked_flag = False
        done = False
        steps = 0
        while not done and steps < 100:
            # The planner works on the character grid, not on the raw observation
            state = observation_to_grid_representation(self.env)
            actions = self.get_expert_actions(state, holding_key=holding_key_flag,
                                              door_unlocked=door_unlocked_flag)
            if not actions:
                break  # no solution within the planner's length limit
            action = actions[0]
            if action == 3:
                holding_key_flag = True
            if action == 5:
                door_unlocked_flag = True
            _, _, terminated, truncated, _ = self.env.step(action)
            done = terminated or truncated
            next_state = observation_to_grid_representation(self.env)
            print(state, action, next_state)
            steps += 1

    def get_expert_value(self, state) -> float:
        """No state value is available for this expert; returns ``NotImplemented``."""
        return NotImplemented



if __name__ == "__main__":
    # Earlier runs of this script (no longer executed), all with seeds starting
    # at 1919 and the "MiniGrid-DoorKey-5x5-v0" environment:
    #   * 100 samples, cutoff 20: "preference", "action_advising" and
    #     "binary_feedback" at distributions 0 / 0.5 / 1, saved as
    #     doorkey_{type}_{distribution}.npy
    #   * 100000 samples, cutoff 40: "action_advising" at 1 / 0.5, saved with
    #     the "_{data_size}_full_image" suffix
    #   * 60000 samples, cutoff 20: "preference" at 1 / 0.5 / 0, saved with the
    #     "_{data_size}_full_image" suffix

    # Current run: 1000 samples with history and images for every feedback
    # type, from pure expert (1) down to fully random (0).
    data_size = 1000
    for feedback_type in ("action_advising", "binary_feedback", "preference"):
        for distribution in (1, 0.5, 0):
            mgdg = DoorkeyDataGeneratorBFS("MiniGrid-DoorKey-5x5-v0", feedback_type, distribution)
            data = mgdg.sample_data(data_size, 20, seed_list=list(range(1919, 1919 + data_size)))
            np.save(
                f"{PERSISTENT_DATA_PATH}/Doorkey/doorkey_{feedback_type}_{distribution}_{data_size}_with_history_image.npy",
                data)
