import itertools

import crafter
import gym
import numpy as np
from PIL import Image

from gym import spaces
# from balrog.environments import Strings

class Strings(spaces.Space):
    """Gym space whose members are a fixed list of strings.

    Each string is paired with its position in ``values`` so a text action can
    be converted back to its integer index with ``map``.
    """

    def __init__(self, values, seed=None):
        super().__init__((len(values),), str, seed)
        # string -> position; for duplicated strings the last position wins.
        self._dict = dict(zip(values, range(len(values))))
        self._values = values

    def sample(self):
        """Pick one of the strings with the space's random generator."""
        return self.np_random.choice(self._values)

    def map(self, action):
        """Return the index of string ``action``; raises KeyError if unknown."""
        return self._dict[action]

    def contains(self, value):
        """Return True when ``value`` is one of the strings of this space."""
        return value in self._dict

    def __iter__(self):
        return iter(self._values)

# Text names of the wrapper's actions; the list index is the integer action id
# (the wrapper's get_text_action looks names up by index).
ACTIONS = [
    "Noop",
    "Move West", "Move East", "Move North", "Move South",
    "Do", "Sleep",
    "Place Stone", "Place Table", "Place Furnace", "Place Plant",
    "Make Wood Pickaxe", "Make Stone Pickaxe", "Make Iron Pickaxe",
    "Make Wood Sword", "Make Stone Sword", "Make Iron Sword",
]

def _crafter_item_name(key):
    """Return the display name used in ``id_to_item`` for a semantic-map key.

    Keys whose ``str()`` contains "objects." (presumably Crafter object
    classes, e.g. "<class 'crafter.objects.Player'>") are reduced to the
    lowercased text between "objects." and the trailing two characters
    ("player"). Any other key is returned as ``str(key)`` unchanged.
    """
    text = str(key)
    marker = "objects."
    if marker not in text:
        return text
    return text[text.find(marker) + len(marker) : -2].lower()


# Build the semantic-id -> item-name table from a throwaway environment.
dummyenv = crafter.Env()
_id_entries = list(
    itertools.chain(dummyenv._world._mat_ids.items(), dummyenv._sem_view._obj_ids.items())
)
# Size the table from the ids actually present (never smaller than the
# previous fixed size of 19) so extra material/object types in the installed
# crafter version don't raise IndexError at import. Unused slots stay 0.
id_to_item = [0] * max(19, 1 + max((ind for _, ind in _id_entries), default=-1))
for name, ind in _id_entries:
    id_to_item[ind] = _crafter_item_name(name)
player_idx = id_to_item.index("player")

del dummyenv, _id_entries

# Inventory keys shown as "<name>: n/9" status lines by
# describe_status_inventory instead of being listed as inventory items.
vitals = "health food drink energy".split()

# 90-degree rotation matrix and relative-direction names (not referenced by
# the describers visible in this module).
rot = np.array(((0, -1), (1, 0)))
directions = "front right back left".split()


def describe_status_inventory(info):
    """Render the special status, vitals and non-empty inventory as text.

    Args:
        info: step info dict; ``info["inventory"]`` maps item name -> count,
            and the flags read by ``describe_special_status`` must be present.

    Returns:
        "Your status:" followed by the special-status line (if any), one
        "- <vital>: <n>/9" line per vital, a blank line, and either
        "Your inventory:" with one "- <item>: <count>" line per non-vital item
        with a non-zero count, or "Your inventory: nothing".
    """
    special = describe_special_status(info)
    inventory = info["inventory"]

    vital_lines = ["- {}: {}/9".format(name, inventory[name]) for name in vitals]
    item_lines = [
        "- {}: {}".format(item, count)
        for item, count in inventory.items()
        if item not in vitals and count != 0
    ]

    if item_lines:
        inventory_section = "Your inventory:\n" + "\n".join(item_lines)
    else:
        inventory_section = "Your inventory: nothing"

    text = "Your status:\n" + special + "\n".join(vital_lines) + "\n\n" + inventory_section
    return text.strip()


# Reference vector (0, 1); not used by the functions visible in this module.
REF = np.array((0, 1))


def rotation_matrix(v1, v2):
    """Return the 2x2 matrix [[dot, -cross], [cross, dot]] for 2-D vectors.

    ``dot`` is the dot product and ``cross`` the scalar (z) cross product of
    ``v1`` and ``v2``. For unit vectors these are cos/sin of the angle from
    ``v1`` to ``v2``, so the matrix rotates ``v1`` onto ``v2``; for other
    vectors it is scaled by ``|v1| * |v2|``.
    """
    dot = np.dot(v1, v2)
    # Explicit 2-D cross product: np.cross on 2-element vectors is deprecated
    # in NumPy 2.0.
    cross = v1[0] * v2[1] - v1[1] * v2[0]
    return np.array([[dot, -cross], [cross, dot]])


def describe_loc(ref, P):
    """Return the compass direction of point ``P`` as seen from ``ref``.

    Points are (x, y): a smaller y is "north" and a smaller x is "west". The
    vertical word comes first and the words are hyphen-joined (e.g.
    "north-west"); an axis on which the points agree contributes nothing, so
    identical points give "".
    """

    def side(origin, target, lower, higher):
        # Word for where ``target`` lies relative to ``origin`` on one axis.
        if origin > target:
            return lower
        if origin < target:
            return higher
        return None

    words = (
        side(ref[1], P[1], "north", "south"),
        side(ref[0], P[0], "west", "east"),
    )
    return "-".join(word for word in words if word)


# Facing vector (dx, dy) -> compass name. A negative dy is north, matching
# describe_loc, where a smaller y is north.
_FACING_DIRECTIONS = {
    (-1, 0): "west",
    (1, 0): "east",
    (0, -1): "north",
    (0, 1): "south",
}


def get_facing_direction(facing):
    """Convert a facing vector to a direction name.

    Args:
        facing: (dx, dy) pair; only the first two components are read.
            dx = -1 -> west, +1 -> east; dy = -1 -> north, +1 -> south.

    Returns:
        "west", "east", "north" or "south" for the four unit vectors above,
        otherwise "unknown".
    """
    return _FACING_DIRECTIONS.get((facing[0], facing[1]), "unknown")


def describe_env(info):
    """Describe the items around the player and the cell the player faces.

    Args:
        info: step info dict. Reads ``semantic`` (2-D array of item ids indexed
            ``[x, y]``), ``player_pos`` (x, y), ``view`` (view size in cells,
            x first) and ``player_facing`` (dx, dy).

    Returns:
        "You see:" followed by one "- <item> <n> steps to your <direction>"
        line per distinct visible item (its nearest cell, Manhattan distance,
        direction relative to the player), or "You see nothing away from you.";
        then a blank line and a sentence about the cell in front of the player.

    Raises:
        AssertionError: if ``semantic`` does not hold the player id at
            ``player_pos``.
    """
    world = info["semantic"]
    px, py = info["player_pos"][0], info["player_pos"][1]
    assert world[px, py] == player_idx  # always 13
    view_w, view_h = info["view"][0], info["view"][1]

    # Observe up to 4 steps away: view_w cells in x and 2 * (view_h // 2) - 1
    # in y (9 x 7 for a 9 x 9 view). Lower bounds are clamped at 0 because
    # negative indices would wrap; upper bounds at the map size.
    semantic = world[
        max(0, px - view_w // 2) : min(world.shape[0], px + view_w // 2 + 1),
        max(0, py - view_h // 2 + 1) : min(world.shape[1], py + view_h // 2),
    ]
    # Player cell in an unclipped window; used only if the player isn't found.
    center = np.array([view_w // 2, view_h // 2 - 1])  # 4,3

    facing = info["player_facing"]
    player_cells = np.where(semantic == player_idx)
    if len(player_cells[0]) > 0:
        player_pos = np.array([player_cells[0][0], player_cells[1][0]])
    else:
        player_pos = center.copy()

    facing_direction = get_facing_direction(facing)
    target_x = player_pos[0] + facing[0]
    target_y = player_pos[1] + facing[1]
    max_x, max_y = semantic.shape
    if 0 <= target_x < max_x and 0 <= target_y < max_y:
        target_item = id_to_item[semantic[int(target_x), int(target_y)]]
        obs = "You are facing {} at your front ({} direction).".format(target_item, facing_direction)
    else:
        obs = "You are facing the edge of the map ({} direction).".format(facing_direction)

    # Manhattan distance from the player to every cell of the window.
    rows, cols = np.indices(semantic.shape)
    dist = np.abs(rows - player_pos[0]) + np.abs(cols - player_pos[1])

    obj_info_list = []
    for idx in np.unique(semantic):
        if idx == player_idx:
            continue
        # Nearest cell holding this item (other cells masked with +inf).
        nearest = np.unravel_index(np.argmin(np.where(semantic == idx, dist, np.inf)), semantic.shape)
        # Direction relative to the player's actual cell, not the nominal
        # window center, so it stays correct when the window is clipped.
        offset = np.array(nearest) - player_pos
        obj_info_list.append((id_to_item[idx], dist[nearest], describe_loc(np.array([0, 0]), offset)))

    if obj_info_list:
        status_str = "You see:\n{}".format(
            "\n".join("- {} {} steps to your {}".format(name, d, loc) for name, d, loc in obj_info_list)
        )
    else:
        status_str = "You see nothing away from you."

    result = status_str + "\n\n" + obs.strip()
    return result.strip()

def describe_env_grid(info):
    """Render the player's local window as a symbol grid plus facing info.

    Args:
        info: step info dict. Reads ``semantic`` (2-D array of item ids indexed
            ``[x, y]``), ``player_pos``, ``view`` and ``player_facing``.

    Returns:
        "Facing info:" with two bullet lines, then "Grid:" and one line per
        y-row (smallest y first) of space-separated item symbols ("@" is the
        player; unmapped items render as ".").

    Raises:
        AssertionError: if ``semantic`` does not hold the player id at
            ``player_pos``.
    """
    world = info["semantic"]
    px, py = info["player_pos"][0], info["player_pos"][1]
    assert world[px, py] == player_idx
    view_w, view_h = info["view"][0], info["view"][1]

    # Same window as describe_env; bounds clamped to the map.
    semantic = world[
        max(0, px - view_w // 2) : min(world.shape[0], px + view_w // 2 + 1),
        max(0, py - view_h // 2 + 1) : min(world.shape[1], py + view_h // 2),
    ]

    # Item name -> grid symbol.
    # NOTE(review): 'sand' and 'skeleton' both map to 'S', so they are
    # indistinguishable in the grid. The empty material appears to be named
    # "None" (str(None)), which misses the 'none' key but still renders as '.'.
    item_to_char = {
        'none': '.',
        'water': '~',
        'grass': '.',
        'stone': 'R',
        'path': 'G',
        'sand': 'S',
        'tree': 'T',
        'lava': 'L',
        'coal': 'C',
        'iron': 'I',
        'diamond': 'D',
        'table': 'W',
        'furnace': 'F',
        'player': '@',
        'cow': 'B',
        'zombie': 'Z',
        'skeleton': 'S',
        'arrow': 'A',
        'plant': 'P'
    }

    # Player cell in an unclipped window; used only if the player isn't found.
    center = np.array([view_w // 2, view_h // 2 - 1])  # 4,3

    facing = info["player_facing"]
    player_cells = np.where(semantic == player_idx)
    if len(player_cells[0]) > 0:
        player_pos = np.array([player_cells[0][0], player_cells[1][0]])
    else:
        player_pos = center.copy()

    facing_direction = get_facing_direction(facing)
    target_x = player_pos[0] + facing[0]
    target_y = player_pos[1] + facing[1]
    max_x, max_y = semantic.shape
    if 0 <= target_x < max_x and 0 <= target_y < max_y:
        target_item = id_to_item[semantic[int(target_x), int(target_y)]]
        obs = "Facing info:\n- You are facing {} at your front.\n- You are facing {} direction.\n".format(target_item, facing_direction)
    else:
        obs = "Facing info:\n- You are facing the edge of the map.\n- You are facing {} direction.\n".format(facing_direction)

    # semantic is indexed [x, y]; iterate its transpose so each output line is
    # one y-row.
    lines = []
    for y_row in semantic.T:
        symbols = []
        for item_id in y_row:
            if item_id == player_idx:
                symbols.append('@')
            else:
                symbols.append(item_to_char.get(id_to_item[item_id], '.'))
        lines.append(" ".join(symbols))

    grid_str = obs + "\n" + "Grid:\n" + "".join(line + "\n" for line in lines)
    return grid_str.strip()

def describe_act(action):
    """Return "You took action <name>." with Crafter-style names made readable.

    Substrings are rewritten in order: "do_" -> "interact_", then move_up /
    move_down / move_left / move_right -> move_north / move_south / move_west /
    move_east.
    """
    renames = (
        ("do_", "interact_"),
        ("move_up", "move_north"),
        ("move_down", "move_south"),
        ("move_left", "move_west"),
        ("move_right", "move_east"),
    )
    for old, new in renames:
        action = action.replace(old, new)
    return "You took action {}.".format(action).strip()


def describe_special_status(info):
    """Return a one-line note for sleeping or dead players, else "".

    Sleeping takes precedence; ``info["dead"]`` is read only when the player
    is not sleeping.
    """
    if info["sleeping"]:
        return "You are sleeping\n"
    return "You died.\n" if info["dead"] else ""


def describe_frame(info):
    """Build the (environment text, status/inventory text) pair for a step.

    Returns:
        A 2-tuple. The first item is the special-status line (if any) followed
        by ``describe_env(info)``; if ``describe_env`` raises, its traceback is
        printed and a fixed failure sentence is used instead. The second item
        is ``describe_status_inventory(info)``. If anything else fails, the
        traceback is printed and ``("Error: <exc>. Function: describe_frame", "")``
        is returned, so callers that unpack two values still work.
    """
    import traceback

    try:
        result = describe_special_status(info) + "\n\n"
        try:
            result += describe_env(info)
        except Exception as e:
            print(f"Error in describe_env: {e}\n{traceback.format_exc()}")
            result += "Failed to describe environment. Error encountered."
        result += "\n\n"
        return result.strip(), describe_status_inventory(info)
    except Exception as e:
        print(f"Error in describe_frame: {e}\n{traceback.format_exc()}")
        # Keep the 2-tuple shape; a bare string would break tuple unpacking.
        return f"Error: {e}. Function: describe_frame", ""

def describe_frame_grid(info):
    """Build the (grid text, status/inventory text) pair for a step.

    Returns:
        ``(describe_env_grid(info).strip(), describe_status_inventory(info))``.
        On any exception the traceback is printed and
        ``("Error: <exc>. Function: describe_frame_grid", "")`` is returned, so
        callers that unpack two values still work.
    """
    import traceback

    try:
        return describe_env_grid(info).strip(), describe_status_inventory(info)
    except Exception as e:
        print(f"Error in describe_frame_grid: {e}\n{traceback.format_exc()}")
        # Keep the 2-tuple shape; a bare string would break tuple unpacking.
        return f"Error: {e}. Function: describe_frame_grid", ""


class CrafterLanguageWrapper(gym.Wrapper):
    """Gym wrapper that adds text descriptions to Crafter observations.

    ``reset`` and ``step`` return the observation as a dict holding the
    rendered RGB image, the raw observation, and text from ``describe_frame``.
    The wrapped env must expose Crafter internals directly (``_player``,
    ``_unlocked``, ``_view``, ``render``).
    """

    default_iter = 10
    default_steps = 10000

    def __init__(
        self,
        env,
        task="",
        max_episode_steps=2,
    ):
        """
        Args:
            env: environment to wrap (expected to be a Crafter env).
            task: accepted for interface compatibility; not used here.
            max_episode_steps: stored as ``self.max_steps``; not read by this class.
        """
        super().__init__(env)

        self.score_tracker = 0
        self.language_action_space = Strings(ACTIONS)
        self.default_action = "Noop"
        self.max_steps = max_episode_steps
        self.achievements = None

    def get_text_action(self, action):
        """Return the text name of integer action ``action``."""
        return self.language_action_space._values[action]

    def _step_impl(self, action):
        """Step the wrapped env and return (obs, reward, done, augmented info).

        The info copy gains ``sleeping``, ``player_facing``, ``dead``,
        ``unlocked`` (achievements first unlocked by this step) and ``view``.
        """
        # crafter.Env.step merges newly unlocked achievements into
        # ``_unlocked`` before returning, so the comparison must be against a
        # snapshot taken before the step; otherwise the set is always empty.
        unlocked_before = set(self.env._unlocked or ())
        obs, reward, done, info = super().step(action)
        player = self.env._player
        aug_info = info.copy()
        aug_info["sleeping"] = player.sleeping
        aug_info["player_facing"] = player.facing
        aug_info["dead"] = player.health <= 0
        aug_info["unlocked"] = {
            name
            for name, count in player.achievements.items()
            if count > 0 and name not in unlocked_before
        }
        aug_info["view"] = self.env._view
        return obs, reward, done, aug_info

    def reset(self):
        """Reset the env, take one action-0 ("Noop") step to obtain info, and
        return (processed obs, reward, done, info)."""
        self.env.reset()
        obs, reward, done, info = self._step_impl(0)
        self.score_tracker = 0
        self.achievements = None
        return self.process_obs(obs, info), reward, done, info

    def step(self, action):
        """Take integer ``action`` and return (processed obs, reward, done, info)."""
        obs, reward, done, info = self._step_impl(action)
        self.update_progress(info)  # refreshes score_tracker and achievements
        return self.process_obs(obs, info), reward, done, info

    def process_obs(self, obs, info):
        """Bundle the rendered image, raw obs and ``describe_frame`` text."""
        img = Image.fromarray(self.env.render()).convert("RGB")
        long_term_context, short_term_context = describe_frame(info)

        return {
            "text": {
                "long_term_context": long_term_context,  # env info
                "short_term_context": short_term_context,  # status & inventory info
            },
            "image": img,
            "obs": obs,
        }

    def process_obs_grid(self, obs, info):
        """Like ``process_obs`` but with the ``describe_frame_grid`` text."""
        img = Image.fromarray(self.env.render()).convert("RGB")
        grid_info, status_inventory_info = describe_frame_grid(info)

        return {
            "text": {
                "grid_info": grid_info,
                "status_inventory_info": status_inventory_info,
            },
            "image": img,
            "obs": obs,
        }

    def update_progress(self, info):
        """Set the score to the number of achievements with a positive count,
        remember ``info["achievements"]``, and return the score."""
        self.score_tracker = 0 + sum([1.0 for k, v in info["achievements"].items() if v > 0])
        self.achievements = info["achievements"]
        return self.score_tracker

    def get_stats(self):
        """Return score, progression and the last seen achievements dict."""
        return {
            "score": self.score_tracker,
            # 22.0 is presumably the total number of Crafter achievements.
            "progression": float(self.score_tracker) / 22.0,
            "achievements": self.achievements,
        }

    def is_sleeping(self):
        """Return the player's current sleeping flag, or False if no player."""
        # Access the latest player state directly through env._player
        if hasattr(self.env, '_player') and self.env._player is not None:
            return self.env._player.sleeping
        elif hasattr(self, '_player') and self._player is not None:
            return self._player.sleeping
        else:
            print("Player is None in both self and self.env")
            return False
