from dataclasses import dataclass
from typing import Literal

import gymnasium as gym
import numpy as np
from gymnasium import spaces
from PIL import Image, ImageDraw


@dataclass
class Node:
    """A single room in the tree."""

    depth: int
    sibling_idx: int
    """ index within its parent's children list """
    children: list[int]
    """ IDs of child nodes (ordered) """
    lateral: int | None
    """ ID of right sibling reachable laterally """
    coord: tuple[float, float]
    """ normalized (x, y) for visualization """


class TreeTeleportEnv(gym.Env):
    """Teleport-Tree environment compatible with the *gymnasium* interface.

    The world is a complete ``branching``-ary tree with ``depth`` levels and
    the agent starts at the root (node 0). Action ``i < branching`` teleports
    to the ``i``-th child; action ``branching`` follows the current node's
    one-way lateral link to its right-hand sibling, if one exists. Moves that
    are unavailable leave the agent in place. The episode terminates when a
    leaf (depth ``depth - 1``) is reached.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 4}

    # ---------------------------------------------------------------------
    # Construction --------------------------------------------------------
    # ---------------------------------------------------------------------
    def __init__(
        self,
        depth: int = 4,
        branching: int = 2,
        lateral_prob: float = 1.0,
        reward_interval: int = 2,
        reward_value: float = 1.0,
        step_penalty: float = 0.1,
        obs_mode: Literal["id", "compact", "image", "pov_image"] = "id",
        image_size: int = 64,
        pov_size: int = 7,
    ) -> None:
        """Create the environment.

        Parameters
        ----------
        depth : int
            Number of levels including the root (>= 1).
        branching : int
            Children per internal node (>= 1).
        lateral_prob : float
            Probability of inserting a one-way *right-sibling* teleport between
            each pair of consecutive siblings (re-sampled on every reset).
        reward_interval : int
            ``reward_value`` is awarded when descending into a depth ``d > 0``
            with ``d % reward_interval == 0`` (>= 1).
        reward_value : float
            Magnitude of that positive reward.
        step_penalty : float
            Penalty subtracted for every step that does not descend.
        obs_mode : Literal["id", "compact", "image", "pov_image"]
            Observation type:
            "id" - integer node index;
            "compact" - float32 ``[depth_norm, sibling_phase, reward_flag]``;
            "image" - full ``image_size`` x ``image_size`` RGB render;
            "pov_image" - ``pov_size`` x ``pov_size`` RGB crop around the agent.
        image_size : int
            Height/width of the square RGB canvas (used by "image" and as the
            canvas that "pov_image" crops from).
        pov_size : int
            Size of the partial observation window (only for obs_mode="pov_image").

        Raises
        ------
        AssertionError
            If a parameter is out of range.
        RuntimeError
            If image observations are requested but Pillow is unavailable.
        """
        assert depth >= 1, "depth must be >= 1"
        assert branching >= 1, "branching must be >= 1"
        assert reward_interval >= 1, "reward_interval must be >= 1"
        assert obs_mode in {"id", "compact", "image", "pov_image"}, "invalid obs_mode"
        # Only meaningful if the Pillow import is ever made optional.
        if obs_mode in {"image", "pov_image"} and Image is None:
            raise RuntimeError("Pillow is required for image observations. pip install pillow")

        self.depth_limit = depth
        self.branching = branching
        self.lateral_prob = lateral_prob
        self.reward_interval = reward_interval
        self.reward_value = reward_value
        self.step_penalty = step_penalty
        self.obs_mode = obs_mode
        self.image_size = image_size
        self.pov_size = pov_size

        # populated in reset()
        self.rng: np.random.Generator | None = None
        self.nodes: dict[int, Node] = {}
        self.n_nodes: int = 0
        self.state_id: int = 0  # current node index

        # background for image obs, shape (H, W, 3)
        self._bg: np.ndarray | None = None

        # The tree shape depends only on (depth, branching); only lateral links
        # are random. So every space, including the "id" one, is fixed here.
        if obs_mode == "id":
            self.observation_space: spaces.Space = spaces.Discrete(self._num_nodes())
        elif obs_mode == "compact":
            self.observation_space = spaces.Box(0.0, 1.0, (3,), dtype=np.float32)
        elif obs_mode == "image":
            H = W = image_size
            self.observation_space = spaces.Box(0, 255, (H, W, 3), np.uint8)
        else:  # pov_image
            H = W = pov_size
            self.observation_space = spaces.Box(0, 255, (H, W, 3), np.uint8)

        # action: 0..branching-1 -> choose child, branching -> lateral
        self.action_space = spaces.Discrete(branching + 1)

    # ------------------------------------------------------------------
    # Public API -------------------------------------------------------
    # ------------------------------------------------------------------
    def reset(self, *, seed: int | None = None, options=None):
        """Rebuild the tree (re-sampling lateral links) and return to the root.

        Returns
        -------
        obs, info
            The root observation and an empty info dict.
        """
        super().reset(seed=seed)
        # Use the generator managed by gym.Env: it is reseeded when ``seed`` is
        # given and continued otherwise, so reset(seed=s); reset(); ... is
        # reproducible across episodes.
        self.rng = self.np_random
        self._build_tree()
        self.state_id = 0  # root
        return self._get_obs(), {}

    def step(self, action: int):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``.

        Reward is ``reward_value`` when descending into a rewarding depth,
        ``0.0`` for other descents, and ``-step_penalty`` for any non-descending
        step (lateral move or staying in place).

        Raises
        ------
        AssertionError
            If ``action`` is not in ``action_space``.
        RuntimeError
            If called before ``reset()``.
        """
        assert self.action_space.contains(action), "invalid action"
        if not self.nodes:
            raise RuntimeError("reset() must be called before step()")

        node = self.nodes[self.state_id]
        next_id = self.state_id  # default: stay

        if action < self.branching:  # vertical child teleport
            if action < len(node.children):
                next_id = node.children[action]
        elif node.lateral is not None:  # lateral teleport
            next_id = node.lateral

        self.state_id = next_id
        new_node = self.nodes[next_id]

        if new_node.depth > node.depth:
            reward = self.reward_value if self._is_reward_depth(new_node.depth) else 0.0
        else:
            reward = -self.step_penalty

        # leaves are exactly the nodes at the last level
        terminated = new_node.depth == self.depth_limit - 1
        truncated = False
        return self._get_obs(), reward, terminated, truncated, {}

    # ------------------------------------------------------------------
    # Internal helpers -------------------------------------------------
    # ------------------------------------------------------------------
    def _num_nodes(self) -> int:
        """Return the node count of a complete tree with the configured depth/branching."""
        return sum(self.branching**d for d in range(self.depth_limit))

    def _is_reward_depth(self, depth: int) -> bool:
        """Return True if reaching ``depth`` earns the positive reward."""
        return depth != 0 and depth % self.reward_interval == 0

    def _to_px(self, coord: tuple[float, float], width: int, height: int) -> tuple[int, int]:
        """Map a node's layout ``coord`` (x, depth) to integer pixel coordinates."""
        x, y = coord
        y_span = max(self.depth_limit - 1, 1)  # avoid 0/0 for a single-level tree
        return int(x * (width - 1)), int(y / y_span * (height - 1))

    def _build_tree(self) -> None:
        """Populate ``self.nodes`` with a complete tree and stochastic lateral links.

        Node ids follow the heap layout ``child = parent * branching + i + 1``.
        Each node owns the horizontal interval ``[x - gap, x + gap]``; its
        children split it into ``branching`` equal slots, so every ``x``
        stays inside ``[0, 1]`` for any branching factor.
        """
        self.nodes.clear()
        b = self.branching
        last_depth = self.depth_limit - 1

        # Iterative pre-order DFS: same node order as a recursive build (which
        # fixes the RNG draw order below) without recursion-depth limits.
        # Stack entries: (node_id, depth, sibling_idx, x, gap).
        stack = [(0, 0, 0, 0.5, 0.5)]
        while stack:
            node_id, depth, sibling_idx, x, gap = stack.pop()
            node = Node(depth=depth, sibling_idx=sibling_idx, children=[], lateral=None, coord=(x, depth))
            self.nodes[node_id] = node
            if depth < last_depth:
                slot = 2 * gap / b  # width of each child's share of the interval
                node.children = [node_id * b + i + 1 for i in range(b)]
                # push in reverse so that child 0 is visited (and inserted) first
                for i in reversed(range(b)):
                    child_x = x - gap + (i + 0.5) * slot
                    stack.append((node.children[i], depth + 1, i, child_x, gap / b))
        self.n_nodes = len(self.nodes)

        # One-way links between *consecutive* siblings, each kept with probability p.
        for parent in self.nodes.values():
            for left, right in zip(parent.children, parent.children[1:]):
                if self.rng.random() < self.lateral_prob:
                    self.nodes[left].lateral = right

        # pre-draw background for image obs
        self._bg = self._draw_background() if self.obs_mode in {"image", "pov_image"} else None

    # ------------------------------------------------------------------
    def _get_obs(self):
        """Build the observation for the current node according to ``obs_mode``."""
        if self.obs_mode == "id":
            return self.state_id
        if self.obs_mode == "compact":
            node = self.nodes[self.state_id]
            d_norm = node.depth / max(self.depth_limit - 1, 1)
            phase = (node.sibling_idx + 0.5) / self.branching
            rflag = float(self._is_reward_depth(node.depth))
            return np.array([d_norm, phase, rflag], dtype=np.float32)

        full_render = self._draw_agent_on_bg()
        if self.obs_mode == "image":
            return full_render
        return self._crop_pov(full_render)

    def _crop_pov(self, full_render: np.ndarray) -> np.ndarray:
        """Return a ``pov_size`` x ``pov_size`` window of ``full_render`` around the agent.

        Window parts outside the canvas are zero (black) padded. For an even
        ``pov_size`` the agent sits just below/right of the window centre.
        """
        H, W, _ = full_render.shape
        px, py = self._to_px(self.nodes[self.state_id].coord, W, H)
        size = self.pov_size
        # start at agent - size//2 and take exactly `size` pixels, so the
        # result always matches observation_space, also for even sizes
        y1, x1 = py - size // 2, px - size // 2
        y2, x2 = y1 + size, x1 + size

        pad_top, pad_left = max(0, -y1), max(0, -x1)
        pad_bottom, pad_right = max(0, y2 - H), max(0, x2 - W)

        crop = full_render[max(0, y1) : min(H, y2), max(0, x1) : min(W, x2)]
        if pad_top or pad_bottom or pad_left or pad_right:
            crop = np.pad(
                crop, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode="constant", constant_values=0
            )
        return crop

    # ------------------------------------------------------------------
    # Image-obs rendering ---------------------------------------------
    # ------------------------------------------------------------------
    def _draw_background(self) -> np.ndarray:
        """Render the static tree (edges, lateral links, node dots) as an RGB uint8 array."""
        H = W = self.image_size
        img = Image.new("RGB", (W, H), (0, 0, 0))
        draw = ImageDraw.Draw(img)
        edge_colour = (32, 32, 32)

        # edges: parent -> child, plus lateral links
        for node in self.nodes.values():
            start = self._to_px(node.coord, W, H)
            for child_id in node.children:
                draw.line([start, self._to_px(self.nodes[child_id].coord, W, H)], fill=edge_colour)
            if node.lateral is not None:
                draw.line([start, self._to_px(self.nodes[node.lateral].coord, W, H)], fill=edge_colour, width=1)

        # nodes as small grey 3x3 dots, drawn over the edges
        for node in self.nodes.values():
            px, py = self._to_px(node.coord, W, H)
            draw.rectangle([px - 1, py - 1, px + 1, py + 1], fill=(64, 64, 64))

        return np.asarray(img, dtype=np.uint8)

    def _draw_agent_on_bg(self) -> np.ndarray:
        """Return a copy of the background with the agent marked as a red 3x3 block."""
        assert self._bg is not None, "background not initialised"
        H, W, _ = self._bg.shape
        img = self._bg.copy()
        px, py = self._to_px(self.nodes[self.state_id].coord, W, H)
        # red 3x3 block centred on the agent, clipped at the canvas border
        img[max(py - 1, 0) : min(py + 2, H), max(px - 1, 0) : min(px + 2, W)] = [255, 0, 0]
        return img

    def get_current_depth(self) -> tuple[int, int]:
        """Get the current node's depth as well as the maximum depth."""
        return self.nodes[self.state_id].depth, self.depth_limit - 1

    def render(self, mode="rgb_array"):
        """Return the full canvas with the agent drawn (image obs modes only).

        Raises
        ------
        NotImplementedError
            If ``obs_mode`` is not "image" or "pov_image".
        """
        if self.obs_mode in {"image", "pov_image"}:
            return self._draw_agent_on_bg()
        raise NotImplementedError("render only available for image obs")

    def close(self):
        """Release resources (nothing to release)."""
        pass


if __name__ == "__main__":
    # Smoke test: roll out one episode with a uniformly random policy and save,
    # for every step, the full render and the agent's partial (POV) view.
    import matplotlib.pyplot as plt

    env = TreeTeleportEnv(
        depth=4,
        reward_interval=2,
        reward_value=1.0,
        branching=2,
        lateral_prob=0.8,
        obs_mode="pov_image",
        image_size=64,
        pov_size=7,
    )
    obs, _ = env.reset(seed=42)
    env.action_space.seed(42)  # make the random policy reproducible as well

    # Reuse one figure and clear its axes per frame: calling plt.imshow over and
    # over on the implicit figure keeps every previous image on the same axes.
    fig, ax = plt.subplots()

    def save_frame(image, path):
        """Draw ``image`` on the shared axes and save the figure to ``path``."""
        ax.clear()
        ax.imshow(image)
        fig.savefig(path)

    total_reward = 0.0
    step_idx = 0
    done = False
    while not done:
        save_frame(env.render(), f"tree_env_test_{step_idx}_full.png")
        save_frame(obs, f"tree_env_test_{step_idx}_pov.png")
        obs, reward, term, trunc, _ = env.step(env.action_space.sample())
        step_idx += 1
        if reward > 0:
            level, max_level = env.get_current_depth()
            # report the node id rather than dumping the whole POV image array
            print(f"Got reward {reward} at state {env.state_id} (level {level}/{max_level})")
        total_reward += reward
        done = term or trunc

    save_frame(env.render(), "tree_env_test_final_full.png")
    save_frame(obs, "tree_env_test_final_pov.png")
    plt.close(fig)
    env.close()
    print("Episode finished with total reward:", total_reward)
