# utils.py
import os
from typing import Optional, Dict, Any, Tuple, List

import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
import minigrid
from minigrid.wrappers import FullyObsWrapper
from minigrid.core.world_object import Door as MGDoor
from minigrid.core.world_object import Wall as MGWall
from gymnasium.wrappers import TimeLimit

# Width of the hidden layer of BabyAI_BC's MLP head.
HIDDEN_DIM = 256
# Default episode horizon used by make_env / build_env_and_model.
MAX_STEPS_PER_EPISODE = 100

# Short game codes accepted by make_env, mapped to gymnasium environment IDs.
GAME_CODE_TO_ENV_ID = dict(
    pickup="BabyAI-Pickup-v0",
    goto="BabyAI-GoToLocal-v0",
    synthseq="BabyAI-SynthSeq-v0",
    bosslevel="BabyAI-BossLevel-v0",
)

def get_model_paths(game_code: str, root: str = "tasks/minigrid") -> Dict[str, str]:
    """
    Build the conventional checkpoint/rollout paths for ``game_code``.

    Side effect: creates the directory ``<root>/<game_code>`` if it does not
    exist. The ``rollouts`` entry is only returned as a path; it is not created.

    Example: get_model_paths("pickup")["pretrain"] is
    os.path.join("tasks/minigrid", "pickup", "pretrain_policy.pt").

    Returns:
        Dict with keys "pretrain", "finetune" and "rollouts".
    """
    game_dir = os.path.join(root, game_code)
    os.makedirs(game_dir, exist_ok=True)

    entries = (
        ("pretrain", "pretrain_policy.pt"),
        ("finetune", "ppo_finetune_policy.pt"),
        ("rollouts", "rollouts"),
    )
    return {key: os.path.join(game_dir, leaf) for key, leaf in entries}

def make_env(
    game_code: str,
    seed: Optional[int] = None,
    full_obs: bool = False,
    max_steps: Optional[int] = MAX_STEPS_PER_EPISODE,
) -> gym.Env:
    """
    Construct a BabyAI/MiniGrid env from its short code (e.g. 'pickup', 'goto').

    Args:
        game_code: a key of GAME_CODE_TO_ENV_ID.
        seed: when given, the env is reset once with this seed, and NumPy's
            global RNG and torch's RNG are seeded from it as well.
        full_obs: wrap with FullyObsWrapper (for Minari full-obs datasets).
        max_steps: when a positive int, wrap the env in gymnasium's TimeLimit
            with this horizon; ``None`` or a non-positive value adds no
            extra time limit.

    Returns:
        The constructed (possibly wrapped) environment.

    Raises:
        ValueError: if ``game_code`` is not a known short code.
    """
    if game_code not in GAME_CODE_TO_ENV_ID:
        raise ValueError(f"Unknown game_code '{game_code}'. Known: {list(GAME_CODE_TO_ENV_ID.keys())}")
    env_id = GAME_CODE_TO_ENV_ID[game_code]
    env = gym.make(env_id)

    if full_obs:
        env = FullyObsWrapper(env)

    if max_steps is not None and max_steps > 0:
        env = TimeLimit(env, max_episode_steps=max_steps)

    if seed is not None:
        try:
            env.reset(seed=seed)
        except TypeError:
            # Fallback for envs whose reset() does not accept a seed keyword.
            env.seed(seed)
        # Reduce into the 32-bit range that np.random.seed accepts.
        global_seed = seed % (2**32 - 1)
        np.random.seed(global_seed)
        torch.manual_seed(global_seed)

    return env

def get_action_dim(env: gym.Env) -> int:
    """
    Return the number of discrete actions of ``env`` as a builtin ``int``.

    Assumes ``env.action_space`` exposes ``n`` (i.e. is a Discrete space).
    gymnasium stores ``Discrete.n`` as a NumPy integer, so cast it explicitly
    to honour the declared return type for downstream layer constructors.
    """
    return int(env.action_space.n)

def has_text(obs: Dict[str, Any]) -> bool:
    """Report whether ``obs`` carries a mission, pre-tokenized or as raw text."""
    return any(key in obs for key in ("mission_tokens", "mission"))

def default_tokenize(
    mission_str: str,
    max_len: int = 32,
    max_vocab: Optional[int] = None,
) -> torch.LongTensor:
    """
    Tokenize a mission string with a whitespace splitter and a dynamic vocab.

    The vocabulary is stored on the function object (``default_tokenize._vocab``)
    and shared by every call in the process. ID 0 is ``<pad>`` and ID 1 is
    ``<unk>``. Each previously unseen lowercase token gets the next free ID, so
    IDs depend on the order in which missions are encountered.
    For production, replace with your BabyAI vocab + tokenizer.

    Args:
        mission_str: mission text; lowercased and split on whitespace.
        max_len: output length. Longer missions are truncated and shorter
            ones are right-padded with ``<pad>``.
        max_vocab: optional cap on the vocabulary size. Once the vocab holds
            ``max_vocab`` entries, unseen tokens map to ``<unk>`` instead of
            growing it. Set this to the model's embedding size (e.g.
            BabyAI_BC's ``vocab_size``) to keep IDs in range. ``None`` (the
            default) leaves the vocab unbounded.

    Returns:
        LongTensor of shape (1, max_len).
    """
    if not hasattr(default_tokenize, "_vocab"):
        default_tokenize._vocab = {"<pad>": 0, "<unk>": 1}
    vocab = default_tokenize._vocab
    unk_id = vocab["<unk>"]

    ids = []
    for tok in mission_str.lower().split()[:max_len]:
        if tok not in vocab:
            if max_vocab is not None and len(vocab) >= max_vocab:
                # Vocab is full: emit <unk> rather than an out-of-range ID.
                ids.append(unk_id)
                continue
            vocab[tok] = len(vocab)
        ids.append(vocab[tok])

    # A non-positive pad count yields an empty list, so no length guard is needed.
    ids.extend([vocab["<pad>"]] * (max_len - len(ids)))
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)  # (1, L)

def _pad_or_trunc_1d(x: torch.LongTensor, L: int) -> torch.LongTensor:
    """Cast a 1-D token tensor to long and right-pad with 0 (PAD) or cut it to length L."""
    seq = x.long()
    count = seq.numel()
    if count < L:
        padded = torch.zeros(L, dtype=torch.long, device=seq.device)  # PAD=0
        padded[:count] = seq
        return padded
    return seq[:L]


def prepare_obs(
    obs: Dict[str, Any],
    device: torch.device,
    use_text: bool = False,
    text_tokens_key: str = "mission_tokens",
    max_text_len: int = 32,
) -> Dict[str, torch.Tensor]:
    """
    Turn a single env observation dict into batched (B=1) model inputs.

    Categorical IDs are kept as long integers (no normalization). When
    ``use_text`` is set, pre-tokenized missions under ``text_tokens_key`` are
    padded/truncated row-wise to ``max_text_len``. If that key is absent, the
    tokens are built from ``obs['mission']`` with default_tokenize.
    """
    batch: Dict[str, torch.Tensor] = {
        "image": torch.as_tensor(obs["image"], dtype=torch.long, device=device).unsqueeze(0),
        "direction": torch.as_tensor(obs["direction"], dtype=torch.long, device=device).unsqueeze(0),
    }
    if not use_text:
        return batch

    if text_tokens_key not in obs:
        batch["mission_tokens"] = default_tokenize(
            obs.get("mission", ""), max_len=max_text_len
        ).to(device)
        return batch

    raw = torch.as_tensor(obs[text_tokens_key], dtype=torch.long, device=device)
    rows = raw.unsqueeze(0) if raw.ndim == 1 else raw
    batch["mission_tokens"] = torch.stack(
        [_pad_or_trunc_1d(row, max_text_len) for row in rows], dim=0
    )
    return batch

def cluster_seq_diversity(
    trajectories: List[List[Tuple[int, int]]],
    grid_size: Tuple[int, int] = (19, 19),
    centers: Optional[List[Tuple[float, float]]] = None,
    distance: str = "manhattan",
    compress_repeats: bool = True,
    return_sequences: bool = False,
    grid_dims: Tuple[int, int] = (3, 3),
) -> Any:
    """
    Compute diversity of cluster *sets* for MiniGrid trajectories.

    Each position is assigned to its nearest room center. Two trajectories are
    considered different only if the *set* of rooms they visit differs; the
    order of visitation does not matter.

    Args:
        trajectories: list of n trajectories, each a list of (x, y) positions
            (coordinates are truncated with ``int`` before assignment).
        grid_size: (width, height) in grid cells.
        centers: optional list of (x, y) centers, one per room. If None,
            centers are evenly spaced in a ``grid_dims`` layout, with room IDs
            in row-major order (id = row * nx + col).
        distance: "manhattan" or "euclidean"; metric used to assign a
            position to its nearest center (ties go to the lowest room ID).
        compress_repeats: kept for API compatibility. Collapsing consecutive
            duplicates cannot change a visited-room *set*, so it has no
            effect on the result.
        return_sequences: if True, also return the per-trajectory room sets.
        grid_dims: (nx, ny) number of room partitions.

    Returns:
        diversity: unique_set_count / n, or 0.0 when there is at most one
            unique set (or no trajectories).
        (optionally) room_sequences: list of frozensets of room IDs, one per
            trajectory.

    Raises:
        ValueError: if ``distance`` is unsupported or ``centers`` does not
            have exactly nx * ny entries. Arguments are validated even when
            ``trajectories`` is empty.
    """
    if distance not in ("manhattan", "euclidean"):
        raise ValueError("distance must be 'manhattan' or 'euclidean'")

    w, h = grid_size
    nx, ny = grid_dims
    n_rooms = nx * ny

    if centers is None:
        xs = [(2 * i + 1) * w / (2 * nx) for i in range(nx)]
        ys = [(2 * j + 1) * h / (2 * ny) for j in range(ny)]
        centers = [(x, y) for y in ys for x in xs]  # row-major room IDs
    elif len(centers) != n_rooms:
        raise ValueError(f"centers must have length {n_rooms} for grid_dims={grid_dims}, got {len(centers)}.")

    if not trajectories:
        return (0.0, []) if return_sequences else 0.0

    centers_arr = np.asarray(centers, dtype=float)  # (n_rooms, 2)

    def _room_ids(points: List[Tuple[int, int]]) -> List[int]:
        """Nearest-center room ID for each (x, y) point, in order."""
        if not points:
            return []
        pts = np.asarray(points, dtype=float).reshape(-1, 2)  # (P, 2)
        diff = pts[:, None, :] - centers_arr[None, :, :]  # (P, n_rooms, 2)
        if distance == "manhattan":
            dists = np.abs(diff).sum(axis=-1)
        else:
            dists = np.linalg.norm(diff, axis=-1)
        return np.argmin(dists, axis=1).tolist()

    room_sequences: List[frozenset] = []
    for traj in trajectories:
        # NOTE(review): the final position of each trajectory is excluded
        # (traj[:-1]), as in the original metric; confirm this is intended
        # (e.g. last entry is a post-terminal state) and not an off-by-one.
        points = [(int(x), int(y)) for (x, y) in traj[:-1]]
        room_sequences.append(frozenset(_room_ids(points)))

    unique_count = len(set(room_sequences))
    diversity = unique_count / float(len(trajectories)) if unique_count > 1 else 0.0

    if return_sequences:
        return diversity, room_sequences
    return diversity

class BabyAI_BC(nn.Module):
    """
    Behavior-cloning policy for BabyAI-style dict observations.

    Expected inputs (see ``forward``):
      * ``image``: (B, H, W, 3) integer grid whose last axis holds
        (object, color, state) IDs. IDs must be below the embedding table
        sizes (32 objects, 8 colors, 8 states).
      * ``direction``: (B,) integer in [0, 4).
      * ``mission_tokens`` (only when ``use_text=True``): (B, L) token IDs in
        [0, vocab_size); ID 0 is the padding index of the token embedding.
    Output: logits over actions, shape (B, n_actions).
    """

    def __init__(self, n_actions: int, use_text: bool = False, vocab_size: int = 200):
        super().__init__()
        # Per-channel categorical embeddings. Concatenated they give
        # 8 + 4 + 4 = 16 channels, matching the first conv layer's input.
        self.obj_emb = nn.Embedding(32, 8)
        self.col_emb = nn.Embedding(8, 4)
        self.sta_emb = nn.Embedding(8, 4)
        self.dir_emb = nn.Embedding(4, 8)

        # 3x3 convs with padding=1 keep H and W; encode_img mean-pools afterwards.
        self.conv = nn.Sequential(
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
        )

        self.use_text = use_text
        if use_text:
            self.tok_emb = nn.Embedding(vocab_size, 32, padding_idx=0)
            self.gru = nn.GRU(32, 128, batch_first=True)

        # Image features (64) + direction embedding (8) + optional GRU state (128).
        in_dim = 64 + 8 + (128 if use_text else 0)
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, HIDDEN_DIM),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.pi = nn.Linear(HIDDEN_DIM, n_actions)

    def encode_img(self, img: torch.Tensor) -> torch.Tensor:
        """
        img: (B,H,W,3) int/long
        returns: (B,64) conv features averaged over the spatial axes
        """
        o = self.obj_emb(img[..., 0].long())
        c = self.col_emb(img[..., 1].long())
        s = self.sta_emb(img[..., 2].long())
        # (B,H,W,16) -> (B,16,H,W) channels-first for Conv2d.
        x = torch.cat([o, c, s], dim=-1).permute(0, 3, 1, 2)
        return self.conv(x).mean(dim=(-2, -1))

    def _text_features(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Summarize (B, L) mission tokens as the GRU's final hidden state, (B, 128).

        Padding positions are embedded as zeros (padding_idx=0) but are still
        stepped through the GRU, since the sequence is not packed.
        """
        emb = self.tok_emb(tokens.long())
        _, h_n = self.gru(emb)  # h_n: (num_layers=1, B, 128)
        return h_n[-1]

    def features(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Run the encoder up to (and including) self.mlp. Output shape: (B, HIDDEN_DIM).
        """
        feats = [
            self.encode_img(obs["image"]),
            self.dir_emb(obs["direction"].long()),
        ]
        if self.use_text:
            feats.append(self._text_features(obs["mission_tokens"]))
        return self.mlp(torch.cat(feats, dim=-1))

    def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        obs keys: 'image' (B,H,W,3), 'direction' (B,), optional 'mission_tokens' (B,L)
        returns: logits (B, n_actions)
        """
        # Shares the encoder with features() so the two paths cannot drift apart.
        return self.pi(self.features(obs))

def build_env_and_model(
    game_code: str,
    device: Optional[torch.device] = None,
    seed: Optional[int] = None,
    full_obs: bool = False,
    max_steps: int = MAX_STEPS_PER_EPISODE,
    use_text: bool = False,
) -> Tuple[gym.Env, BabyAI_BC]:
    """
    Create the env for ``game_code`` plus a fresh BabyAI_BC sized to its action space.

    ``seed``, ``full_obs`` and ``max_steps`` are forwarded to make_env. The
    model is moved to ``device``, which defaults to CUDA when available and
    to CPU otherwise.

    Returns:
        (env, model)
    """
    env = make_env(game_code, seed=seed, full_obs=full_obs, max_steps=max_steps)
    policy = BabyAI_BC(n_actions=get_action_dim(env), use_text=use_text)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    policy.to(device)

    return env, policy