import torch.multiprocessing as mp
import gym
import numpy as np

# -----------------------------------------------------------------------------
# Single worker process --------------------------------------------------------
# -----------------------------------------------------------------------------

def _worker(remote, seed, env_kwargs):
    """Core loop for a subprocess that hosts a *WebAgentTextEnv* instance.

    Commands sent from the main process are *(cmd, data)* tuples:

    - **'step'** *(str)*  → returns ``(obs, reward, done, info)`` where
      ``info['available_actions']`` has already been populated *after* the step.
    - **'reset'** *(int | None)* → returns ``(obs, info)`` with the same
      ``available_actions`` field (obtained immediately after reset).
    - **'render'** *(str)* → returns the value of ``env.render(mode)``.
    - **'available_actions'** *(None)* → returns the list from
      ``env.get_available_actions()``.
    - **'close'** → terminates the subprocess.
    """
    # Lazy import avoids CUDA initialisation issues under ``spawn``.
    import sys
    import os
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), 'webshop'))
    sys.path.append(project_root)
    from web_agent_site.envs import WebAgentTextEnv  # noqa: WPS433 (runtime import)
    # env_kwargs['seed'] = seed
    env = gym.make('WebAgentTextEnv-v0', **env_kwargs)

    try:
        while True:
            cmd, data = remote.recv()

            # -----------------------------------------------------------------
            # Environment interaction commands
            # -----------------------------------------------------------------
            if cmd == 'step':
                action = data
                obs, reward, done, info = env.step(action)
                info = dict(info or {})  # make a *copy* so we can mutate safely
                info['available_actions'] = env.get_available_actions()
                info['task_score'] = reward

                # Redefine reward. We only use rule-based reward - win for 10, lose for 0.
                if done and reward == 1.0:
                    info['won'] = True
                    reward = 10.0
                else:
                    info['won'] = False
                    reward = 0

                remote.send((obs, reward, done, info))

            elif cmd == 'reset':
                idx = data
                obs, info = env.reset(session=idx)
                info = dict(info or {})
                info['available_actions'] = env.get_available_actions()
                info['won'] = False
                remote.send((obs, info))

            elif cmd == 'render':
                mode_for_render = data
                rendered = env.render(mode=mode_for_render)
                remote.send(rendered)

            elif cmd == 'available_actions':
                remote.send(env.get_available_actions())
            elif cmd == 'get_goals':
                remote.send(env.server.goals)
            # -----------------------------------------------------------------
            # Book‑keeping
            # -----------------------------------------------------------------
            elif cmd == 'close':
                remote.close()
                break

            else:  # pragma: no cover – helps catch typos early
                raise NotImplementedError(f"Unknown command sent to worker: {cmd}")

    finally:  # Ensure the underlying environment *always* shuts down cleanly
        env.close()


# -----------------------------------------------------------------------------
# Vectorised multi‑process environment -----------------------------------------
# -----------------------------------------------------------------------------

class WebshopMultiProcessEnv(gym.Env):
    """A vectorised, multi‑process wrapper around *WebAgentTextEnv*.

    ``info`` dictionaries returned by :py:meth:`step` **and** :py:meth:`reset`
    automatically contain the key ``'available_actions'`` so downstream RL code
    can obtain the *legal* action set without extra IPC overhead.
    """
    def __init__(
        self,
        seed: int = 0,
        env_num: int = 1,
        group_n: int = 1,
        is_train: bool = True,
        env_kwargs: dict = None,
    ) -> None:
        super().__init__()

        self.group_n = group_n
        self.env_num = env_num
        self.num_processes = env_num * group_n
        self.is_train = is_train

        self._rng = np.random.RandomState(seed)

        self._env_kwargs = env_kwargs if env_kwargs is not None else {'observation_mode': 'text', 'num_products': None}

        # -------------------------- Multiprocessing setup --------------------
        self._parent_remotes: list[mp.connection.Connection] = []
        self._workers: list[mp.Process] = []

        ctx = mp.get_context('spawn')

        for i in range(self.num_processes):
            parent_remote, child_remote = mp.Pipe()
            worker = ctx.Process(
                target=_worker,
                args=(child_remote, seed + (i // self.group_n), self._env_kwargs),
            )
            worker.daemon = True  # auto‑kill if the main process crashes
            worker.start()
            child_remote.close()

            self._parent_remotes.append(parent_remote)
            self._workers.append(worker)


        self._parent_remotes[0].send(('get_goals', None))
        goals = self._parent_remotes[0].recv()

        # ------- original ----------#
        # if args.num is None:
        #     if split == 'test':
        #         self.goal_idxs = range(500)
        #     elif split == 'eval':
        #         self.goal_idxs = range(500, 1500)
        #     elif split == 'train':
        #         self.goal_idxs = range(1500, len(self.env.server.goals))
        # else:
        #     self.goal_idxs = range(len(self.env.server.goals))

        if not self.is_train:
            self.goal_idxs = range(500)
        else:
            self.goal_idxs = range(500, len(goals))
            
        print(self.goal_idxs)

    # ------------------------------------------------------------------
    # Base API ----------------------------------------------------------
    # ------------------------------------------------------------------

    def step(self, actions: list[str]):
        if len(actions) != self.num_processes:
            raise ValueError(
                f'Expected {self.num_processes} actions, got {len(actions)}',
            )

        for remote, action in zip(self._parent_remotes, actions):
            remote.send(('step', action))

        obs_list, reward_list, done_list, info_list = [], [], [], []
        for remote in self._parent_remotes:
            obs, reward, done, info = remote.recv()
            obs_list.append(obs)
            reward_list.append(reward)
            done_list.append(done)
            info_list.append(info)

        return obs_list, reward_list, done_list, info_list

    def reset(self):
        idx = self._rng.choice(self.goal_idxs, size=self.env_num, replace=False)
        idx = np.repeat(idx, self.group_n).tolist()

        for remote, i in zip(self._parent_remotes, idx):
            remote.send(('reset', i))

        obs_list, info_list = [], []
        for remote in self._parent_remotes:
            obs, info = remote.recv()
            obs_list.append(obs)
            info_list.append(info)

        return obs_list, info_list

    # ------------------------------------------------------------------
    # Convenience helpers ----------------------------------------------
    # ------------------------------------------------------------------

    def render(self, mode: str = 'text', env_idx: int = None):
        if env_idx is not None:
            self._parent_remotes[env_idx].send(('render', mode))
            return self._parent_remotes[env_idx].recv()

        for remote in self._parent_remotes:
            remote.send(('render', mode))
        return [remote.recv() for remote in self._parent_remotes]

    # ------------------------------------------------------------------
    # Clean‑up ----------------------------------------------------------
    # ------------------------------------------------------------------

    def close(self):
        if getattr(self, '_closed', False):
            return

        for remote in self._parent_remotes:
            remote.send(('close', None))
        for worker in self._workers:
            worker.join()
        self._closed = True

    def __del__(self):  # noqa: D401
        self.close()


# -----------------------------------------------------------------------------
# Factory helper --------------------------------------------------------------
# -----------------------------------------------------------------------------

def build_webshop_envs(
    seed: int = 0,
    env_num: int = 1,
    group_n: int = 1,
    is_train: bool = True,
    env_kwargs: dict = None,
):
    """Mirror *build_sokoban_envs* so higher‑level code can swap seamlessly."""
    return WebshopMultiProcessEnv(
        seed=seed,
        env_num=env_num,
        group_n=group_n,
        is_train=is_train,
        env_kwargs=env_kwargs,
    )

