import os
# Default MuJoCo and PyOpenGL to the EGL backend unless the caller already set these
# variables (setdefault never overrides an existing value). This runs before metaworld /
# gymnasium are imported below, presumably because the rendering backend is chosen at
# import time (typical for headless GPU rendering) — confirm for your setup.
os.environ.setdefault("MUJOCO_GL", "egl")
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")


from pathlib import Path
from typing import Iterable, Tuple, List
import numpy as np
from PIL import Image

import metaworld              # ensure registry exists
import gymnasium as gym

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch

from utils.data_utils import TPDataset  # your implementation

# -------- FS helpers --------
def _ensure_dir(p: Path):
    """Make sure directory ``p`` exists, creating any missing parent directories.

    Idempotent: calling it on an existing directory is a no-op.
    """
    p.mkdir(exist_ok=True, parents=True)

def _save_episode(base_dir: Path, ep_idx: int, frames_np: Iterable[np.ndarray], desc: str):
    """Write one episode to ``base_dir / f"traj_{ep_idx}"``.

    Layout produced:
        traj_<ep_idx>/lang.txt             -- ``desc`` with surrounding whitespace stripped, plus "\\n"
        traj_<ep_idx>/images0/im_00000.jpg -- one JPEG (quality 90, optimized) per frame, numbered from 0

    Args:
        base_dir: parent directory for all trajectories.
        ep_idx: episode index used in the ``traj_<ep_idx>`` directory name.
        frames_np: frames handed directly to ``PIL.Image.fromarray``; presumably uint8
            HxW or HxWx3 arrays — confirm against the renderer.
        desc: natural-language task description.
    """
    traj_dir = base_dir / f"traj_{ep_idx}"
    img_dir = traj_dir / "images0"
    _ensure_dir(img_dir)
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows) can fail
    # on non-ASCII task descriptions.
    (traj_dir / "lang.txt").write_text(desc.strip() + "\n", encoding="utf-8")
    for t, frame in enumerate(frames_np):
        Image.fromarray(frame).save(img_dir / f"im_{t:05d}.jpg", format="JPEG", quality=90, optimize=True)

# -------- Expert policy map --------

# Env name -> dotted "module.Class" path of its scripted Meta-World expert (MT10 benchmark
# tasks). To support another task, add a (task, policy module, policy class) row below.
_POLICY_MAP = {
    task: f"metaworld.policies.{module}.{cls}"
    for task, module, cls in (
        ("reach-v3", "sawyer_reach_v3_policy", "SawyerReachV3Policy"),
        ("push-v3", "sawyer_push_v3_policy", "SawyerPushV3Policy"),
        ("pick-place-v3", "sawyer_pick_place_v3_policy", "SawyerPickPlaceV3Policy"),
        ("door-open-v3", "sawyer_door_open_v3_policy", "SawyerDoorOpenV3Policy"),
        ("drawer-open-v3", "sawyer_drawer_open_v3_policy", "SawyerDrawerOpenV3Policy"),
        ("drawer-close-v3", "sawyer_drawer_close_v3_policy", "SawyerDrawerCloseV3Policy"),
        ("button-press-topdown-v3", "sawyer_button_press_topdown_v3_policy", "SawyerButtonPressTopdownV3Policy"),
        ("peg-insert-side-v3", "sawyer_peg_insertion_side_v3_policy", "SawyerPegInsertionSideV3Policy"),
        ("window-open-v3", "sawyer_window_open_v3_policy", "SawyerWindowOpenV3Policy"),
        ("window-close-v3", "sawyer_window_close_v3_policy", "SawyerWindowCloseV3Policy"),
    )
}

def _make_expert(env_name: str):
    """Instantiate the scripted expert policy registered for ``env_name``.

    Looks up the dotted ``module.Class`` path in ``_POLICY_MAP``, imports that module
    and returns a new instance of the class (constructed with no arguments).

    Raises:
        NotImplementedError: if ``env_name`` has no entry in ``_POLICY_MAP``.
    """
    import importlib  # local import keeps this helper self-contained

    try:
        dotted_path = _POLICY_MAP[env_name]
    except KeyError:
        raise NotImplementedError(f"No expert policy mapped for env '{env_name}'. Add it to _POLICY_MAP.") from None
    module_path, cls_name = dotted_path.rsplit(".", 1)
    # import_module is the documented replacement for calling __import__ directly.
    policy_cls = getattr(importlib.import_module(module_path), cls_name)
    return policy_cls()

# -------- Collection --------
def _minmax_normalize(values: List[np.float32]) -> List[np.float32]:
    """Min-max scale ``values`` into [0, 1], returning float32 scalars.

    A (near-)constant sequence (range below 1e-8) maps to all zeros instead of dividing
    by ~0; an empty input yields an empty list.
    """
    if not values:
        return []
    lo, hi = min(values), max(values)
    if abs(hi - lo) < 1e-8:
        return [np.float32(0.0) for _ in values]
    span = hi - lo
    return [np.float32((v - lo) / span) for v in values]


def collect_expert_dataset(env_name: str, desc: str, success_episodes: int, max_steps: int, seed: int, camera_name: str, render_mode: str, out_root: str | Path = "exp_metaworld/data", max_episodes: int | None = None):
    """Roll out the scripted expert for ``env_name`` until ``success_episodes`` episodes succeed.

    Every episode, successful or not, is written under
    ``<out_root>/<env_name>_<success_episodes>/traj_<i>/`` (frames + ``desc``) and its
    transitions are appended to the returned arrays. An episode ends when
    ``info["success"]`` is truthy, the env reports terminated/truncated, or ``max_steps``
    steps have run. Rewards are min-max normalized per episode. The longest *successful*
    episode length is written to ``max_traj_len.txt`` in the same directory.

    Args:
        env_name: Meta-World task name; must have an entry in ``_POLICY_MAP``.
        desc: language description saved with each episode.
        success_episodes: number of successful episodes to collect.
        max_steps: per-episode step cap; must be positive.
        seed, camera_name, render_mode: forwarded to ``gym.make("Meta-World/MT1", ...)``.
        out_root: root output directory.
        max_episodes: optional cap on total episodes attempted. ``None`` (default) means
            unlimited, which loops forever if the expert never succeeds.

    Returns:
        ``(observations, actions, rewards, terminals, base_dir, max_traj_len)``. The arrays
        hold one row per transition (float32 / float32 / float32 / bool). ``terminals`` is
        False on the last step of an episode cut off only by ``max_steps``.

    Raises:
        ValueError: if ``max_steps`` is not positive (no step would run, so no episode
            could ever succeed).
        RuntimeError: if ``max_episodes`` episodes run before enough successes.
        NotImplementedError: from ``_make_expert`` for unmapped ``env_name``.
    """
    if max_steps <= 0:
        raise ValueError(f"max_steps must be positive, got {max_steps}")

    out_root = Path(out_root)
    base_dir = out_root / f"{env_name}_{success_episodes}"
    _ensure_dir(base_dir)

    env = gym.make("Meta-World/MT1", env_name=env_name, seed=seed, render_mode=render_mode, camera_name=camera_name)

    obs_buf, act_buf, rew_buf, done_buf = [], [], [], []
    successes, episodes, max_traj_len = 0, 0, 0

    # try/finally so the env (and its renderer) is released even if the expert or a step raises.
    try:
        expert = _make_expert(env_name)
        while successes < success_episodes:
            if max_episodes is not None and episodes >= max_episodes:
                raise RuntimeError(
                    f"[collect:{env_name}] only {successes}/{success_episodes} successes "
                    f"after {episodes} episodes (max_episodes={max_episodes})"
                )
            obs, _ = env.reset()
            done, success, steps = False, False, 0
            frames_ep, traj_rews = [], []
            while not done and steps < max_steps:
                act = expert.get_action(obs)
                next_obs, rew, term, trunc, info = env.step(act)
                success = bool(info.get("success", False))
                done = success or term or trunc
                # Vertical flip of the rendered frame; presumably the renderer returns
                # bottom-up images — confirm for your MuJoCo/camera setup.
                frames_ep.append(np.flipud(env.render()))
                obs_buf.append(obs.astype(np.float32))
                act_buf.append(np.asarray(act, dtype=np.float32))
                traj_rews.append(np.float32(rew))
                done_buf.append(bool(done))
                obs = next_obs
                steps += 1

            rew_buf.extend(_minmax_normalize(traj_rews))
            _save_episode(base_dir, episodes, frames_ep, desc)
            episodes += 1
            if success:
                successes += 1
                max_traj_len = max(max_traj_len, len(frames_ep))
                print(f"[collect:{env_name}] success {successes}/{success_episodes} (steps={steps})")
    finally:
        env.close()

    (base_dir / "max_traj_len.txt").write_text(str(max_traj_len) + "\n")

    observations = np.asarray(obs_buf, dtype=np.float32)
    actions = np.asarray(act_buf, dtype=np.float32)
    rewards = np.asarray(rew_buf, dtype=np.float32)
    terminals = np.asarray(done_buf, dtype=np.bool_)

    print(f"[collect:{env_name}] wrote episodes under: {base_dir.resolve()}")
    print("[collect] transitions:", len(observations))
    return observations, actions, rewards, terminals, base_dir, max_traj_len

# -------- Loader for reward computation --------
def create_loader(root_dir: Path, split: str, num_workers: int):
    root_dir = Path( "exp_metaworld/data")
    split_path = root_dir / split               # FIX: look under root_dir/split
    max_file = split_path / "max_traj_len.txt"
    if not max_file.exists():
        raise FileNotFoundError(f"{max_file} missing. Did collection finish?")
    max_traj_len = int(max_file.read_text().strip())

    dataset = TPDataset(root_dir=str(root_dir), split=split, traj_len=max_traj_len)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
    return loader, max_traj_len
