from typing import Tuple

import jax
import jax.numpy as jnp
from gymnasium import Env
from tqdm import tqdm

from algorithms.utils.replay_buffer import ReplayBuffer


class BaseEnv:
    """Thin wrapper around a gymnasium ``Env`` that seeds resets from JAX PRNG keys.

    Assumes a discrete action space: ``n_actions`` is read from
    ``env.action_space.n``.
    """

    def __init__(self, env_key: jax.random.PRNGKey, env: Env) -> None:
        """Wrap ``env``.

        Args:
            env_key: PRNG key that is split to derive a reset seed whenever
                ``reset`` is called without an explicit key.
            env: the gymnasium environment to wrap.
        """
        self.reset_key = env_key
        self.env = env
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.n_actions = self.env.action_space.n
        # Defined up front so these attributes exist even before the first
        # reset(); previously step() before reset() raised AttributeError.
        self.state = None
        self.n_steps = 0

    def reset(self, key=None) -> jnp.ndarray:
        """Reset the wrapped env and return its initial observation.

        Args:
            key: optional PRNG key used to seed the reset. When omitted, a
                fresh subkey is split off ``self.reset_key``, which is advanced.

        Returns:
            The observation returned by ``env.reset``; the info dict is dropped.
        """
        if key is None:
            self.reset_key, key = jax.random.split(self.reset_key)
        # gymnasium takes a Python int seed; only the first element of the key
        # array is used. NOTE(review): a key built directly with
        # jax.random.PRNGKey(s) has key[0] == 0 for small s, so passing such
        # keys here would give every reset the same seed — confirm callers
        # only pass split keys.
        self.state, _ = self.env.reset(seed=int(key[0]))
        self.n_steps = 0

        return self.state

    def step(
        self, action: jnp.int8
    ) -> Tuple[jnp.ndarray, float, bool, bool]:
        """Advance the wrapped env by one step.

        Args:
            action: action index; converted with ``int()`` before being passed
                to ``env.step``.

        Returns:
            ``(next_state, reward, terminated, truncated)``. The info dict from
            ``env.step`` is discarded.
        """
        self.state, reward, terminated, truncated, _ = self.env.step(int(action))
        self.n_steps += 1
        return self.state, reward, terminated, truncated

    def collect_random_samples(
        self,
        sample_key: jax.random.PRNGKey,
        replay_buffer: ReplayBuffer,
        n_samples: int,
    ) -> None:
        """Fill ``replay_buffer`` with ``n_samples`` uniformly random transitions.

        The env is reset at the start, and again whenever an episode terminates
        or is truncated.

        Args:
            sample_key: PRNG key used to draw actions; it is split once per sample.
            replay_buffer: destination. Each step adds
                ``(state, action, next_state, reward, terminated)``.
            n_samples: number of transitions to add.
        """
        state = self.reset()

        for _ in tqdm(range(n_samples)):
            sample_key, key = jax.random.split(sample_key)
            action = jax.random.choice(key, jnp.arange(self.n_actions))
            next_state, reward, terminated, truncated = self.step(action)

            # Only ``terminated`` is stored; ``truncated`` only triggers a reset.
            # Presumably this is so that time-limit cut-offs are not treated as
            # true terminal states — confirm against ReplayBuffer consumers.
            replay_buffer.add(state, action, next_state, reward, terminated)

            if terminated or truncated:
                state = self.reset()
            else:
                # Advance to the new observation. Without this, every stored
                # transition would start from the episode's first state.
                state = next_state
