"""
Common Atari wrappers shared by training and evaluation scripts.
"""

from __future__ import annotations

import gymnasium as gym
import numpy as np


class AtariMaxPoolWrapper(gym.Wrapper):
    """
    Repeat each action and max-pool the two most recent raw frames to reduce
    Atari sprite flicker.

    Notes:
    - Wrap an ALE env that was created with frameskip=1, so that this wrapper is
      the only place where frames are skipped.
    - Each `step` forwards the action up to `frame_skip` times (stopping early on
      termination/truncation), sums the rewards, and returns the element-wise
      maximum of the last two observations seen.
    """

    def __init__(self, env: gym.Env, *, frame_skip: int = 5):
        super().__init__(env)
        skip = int(frame_skip)
        self.frame_skip = skip
        if skip < 1:
            raise ValueError(f"frame_skip must be >= 1, got {self.frame_skip}")

    def step(self, action):
        """
        Apply `action` for up to `frame_skip` raw frames.

        Returns:
            (obs, reward, terminated, truncated, info) where `obs` is the
            element-wise max of the last two non-None frames (a single frame if
            only one was seen, otherwise the last raw observation), `reward` is
            the float sum over the repeated frames, and terminated/truncated/info
            are those returned by the final inner step.
        """
        reward_sum = 0.0
        terminated = truncated = False
        step_info = {}
        frame = None
        # Sliding window over the two most recent non-None observations.
        older = newer = None

        for _ in range(self.frame_skip):
            frame, reward, terminated, truncated, step_info = self.env.step(action)
            reward_sum += float(reward)
            if frame is not None:
                older, newer = newer, frame
            if terminated or truncated:
                break

        if newer is None:
            # No usable frame was observed; hand back whatever the env returned last.
            pooled = frame
        elif older is None:
            pooled = newer
        else:
            pooled = np.maximum(older, newer)

        return pooled, reward_sum, terminated, truncated, step_info


class NoopResetEnv(gym.Wrapper):
    """
    On every reset, perform a random number of NOOP actions in [0, noop_max].

    This is the common DQN-style random-start protocol (random no-ops on reset) to
    inject stochasticity and avoid overfitting to a fixed initial state.

    Args:
        env: Env to wrap. Its unwrapped env must expose `get_action_meanings()`,
            and action 0 must be "NOOP".
        noop_max: Inclusive upper bound on the number of no-ops per reset. A value
            of 0 disables the no-op phase.

    Raises:
        ValueError: If `noop_max` is negative or action 0 is not "NOOP".
        TypeError: If the unwrapped env has no `get_action_meanings()`.
    """

    def __init__(self, env: gym.Env, *, noop_max: int = 30):
        super().__init__(env)
        self.noop_max = int(noop_max)
        if self.noop_max < 0:
            raise ValueError(f"noop_max must be >= 0, got {self.noop_max}")

        # In full action space, action 0 is expected to be NOOP.
        if not hasattr(self.env.unwrapped, "get_action_meanings"):
            raise TypeError(
                "Env does not expose `get_action_meanings()`. Cannot strictly verify the NOOP action."
            )
        meanings = list(self.env.unwrapped.get_action_meanings())
        if len(meanings) == 0 or meanings[0] != "NOOP":
            raise ValueError(
                f"Expected action 0 to be 'NOOP', got meanings[0]={meanings[0] if meanings else None}. "
                f"Action meanings={meanings}"
            )

        self.noop_action = 0

    def reset(self, **kwargs):
        """
        Reset the env, then step it with a random number of NOOP actions.

        If the episode ends during the no-op phase, the env is reset again and the
        remaining no-ops continue from the fresh episode.

        Args:
            **kwargs: Forwarded to the initial `env.reset`. Inner resets triggered
                during the no-op phase receive the same kwargs minus `seed`.

        Returns:
            (obs, info) for the state reached after the no-ops. `info` comes from
            the same step or reset that produced `obs`.
        """
        obs, info = self.env.reset(**kwargs)

        if self.noop_max == 0:
            return obs, info

        # Prefer the env's RNG (seeded via env.reset(seed=...)) if available.
        rng = getattr(self.env.unwrapped, "np_random", None)
        if rng is None:
            n_noops = int(np.random.randint(0, self.noop_max + 1))
        else:
            n_noops = int(rng.integers(0, self.noop_max + 1))

        # Seed only once. Re-passing `seed` on an inner reset would re-initialise
        # the env's PRNG and rewind it to the state of the initial reset.
        inner_reset_kwargs = {key: value for key, value in kwargs.items() if key != "seed"}

        for _ in range(n_noops):
            step_result = self.env.step(self.noop_action)
            if len(step_result) == 5:
                obs, _reward, terminated, truncated, info = step_result
                done = bool(terminated or truncated)
            else:
                # Legacy 4-tuple step API: (obs, reward, done, info).
                obs, _reward, done, info = step_result
                done = bool(done)

            if done:
                obs, info = self.env.reset(**inner_reset_kwargs)

        return obs, info

