from typing import Union, Optional
from functools import partial
from random import randrange

import gymnasium as gym
import torch
import numpy as np

# Small positive epsilon. Not referenced in the code shown here; presumably
# added to priorities in a prioritized replay variant so none is exactly
# zero — TODO confirm against its users.
PRIORITIZED_EPS: float = 1e-6


class VecTrajectoryBuffer:
    """Replay buffer that stores trajectory chunks collected from a vectorized env.

    Transitions are staged per sub-environment in NumPy arrays of shape
    ``(n_step, num_envs, ...)``. A sub-environment's staged chunk is flushed into
    ``self.trajectories`` as torch tensors when that sub-environment reports
    ``done`` or when ``n_step`` transitions have been staged. The oldest chunks are
    evicted so the total number of stored frames stays within ``size``.

    Each stored trajectory is the list
    ``[observations, actions, rewards, done, next_obs]``. When ``done`` is true,
    ``next_obs`` is a shared placeholder tensor built from the observation
    space's ``low``.
    """

    def __init__(
        self,
        env: "gym.Env",
        size: int,
        n_step: int,
        device: Union[torch.device, str] = "cpu",
        max_device_bytes: float = 1e10,
    ) -> None:
        """Create an empty buffer.

        Args:
            env: Vectorized environment. Only ``num_envs`` and
                ``single_observation_space`` (``shape``, ``dtype``, ``low``) are read.
            size: Maximum number of frames kept across all stored trajectories.
                A single chunk longer than ``size`` is still stored once the
                buffer has been emptied.
            n_step: Maximum length of one trajectory chunk.
            device: Device that sampled batches are moved to.
            max_device_bytes: If the estimated observation storage for ``size``
                frames reaches this many bytes, trajectories are kept on the CPU
                instead of ``device``.
        """
        self.size = size
        self.n_step = n_step
        self.num_envs = env.num_envs
        self.device = device
        self.total_frames = 0

        self.trajectories = []

        obs_space = env.single_observation_space
        self.observations = np.zeros((n_step, self.num_envs) + obs_space.shape, dtype=obs_space.dtype)
        # np.long does not exist in NumPy 1.24-1.26 and is platform-sized in 2.x;
        # int64 is portable and matches torch.long.
        self.actions = np.zeros((n_step, self.num_envs), dtype=np.int64)
        self.rewards = np.zeros((n_step, self.num_envs), dtype=np.float32)

        # Size the decision on the *stored* data (``size`` frames), not on the
        # small per-step staging buffer allocated above.
        frame_nbytes = int(np.prod(obs_space.shape)) * self.observations.itemsize
        estimated_nbytes = size * frame_nbytes
        self.storage_device = device if estimated_nbytes < max_device_bytes else "cpu"

        # Placeholder next observation for terminal chunks, so ``sample`` can
        # always torch.stack one next observation per trajectory.
        self.dummy = torch.as_tensor(obs_space.low, device=self.storage_device)

        # Per-sub-environment write position into the staging arrays.
        self.idx = [0] * self.num_envs

    def add(self, obs, act, rew, done, next_obs):
        """Stage one vectorized step and flush any sub-environment chunk that ended.

        Args:
            obs, act, rew, done, next_obs: Per-sub-environment values indexed by
                ``0..num_envs-1``. ``next_obs[i]`` is only kept when sub-env
                ``i``'s chunk is flushed without ``done``.
        """
        for i in range(self.num_envs):
            pos = self.idx[i]
            self.observations[pos, i] = obs[i]
            self.actions[pos, i] = act[i]
            self.rewards[pos, i] = rew[i]
            self.idx[i] = pos + 1

            if done[i] or self.idx[i] == self.n_step:
                self.finalize(i, done[i], next_obs[i])
                self.idx[i] = 0

    def finalize(self, env_idx: int, done: bool, next_obs: Optional[np.ndarray]):
        """Move sub-env ``env_idx``'s staged transitions into ``self.trajectories``.

        Evicts the oldest stored trajectories until the new chunk fits in ``size``
        frames. Eviction stops once the buffer is empty, so an oversized chunk is
        stored instead of failing on ``pop`` from an empty list.
        """
        end_idx = self.idx[env_idx]

        while self.trajectories and self.total_frames + end_idx > self.size:
            evicted = self.trajectories.pop(0)
            self.total_frames -= len(evicted[0])

        dev = self.storage_device
        self.trajectories.append([
            # torch.tensor copies, so reusing the staging arrays cannot alias stored data.
            torch.tensor(self.observations[:end_idx, env_idx], device=dev),
            torch.tensor(self.actions[:end_idx, env_idx], device=dev),
            torch.tensor(self.rewards[:end_idx, env_idx], device=dev),
            done,
            self.dummy if done else torch.tensor(next_obs, device=dev),
        ])
        self.total_frames += end_idx

    def sample(self, batch_size: int):
        """Sample whole trajectories uniformly without replacement.

        Draws ``batch_size // n_step`` distinct trajectories, clamped to at least
        one and at most the number stored.

        Returns:
            ``([obs, actions, rewards], dones, next_obs, splits)``. The first three
            are the sampled trajectories concatenated along dim 0 on ``device``.
            ``dones`` is the list of per-trajectory done flags. ``next_obs`` is the
            stacked per-trajectory next observations on ``device``. ``splits``
            lists each trajectory's length.

        Raises:
            RuntimeError: If the buffer holds no trajectories.
        """
        n_trajs = len(self.trajectories)
        if n_trajs == 0:
            raise RuntimeError("cannot sample from an empty VecTrajectoryBuffer")

        # At least one trajectory: batch_size < n_step would otherwise request 0
        # and torch.cat / torch.stack reject empty inputs.
        n_sample = min(max(batch_size // self.n_step, 1), n_trajs)
        picked = [
            self.trajectories[t_id]
            for t_id in np.random.choice(n_trajs, size=n_sample, replace=False)
        ]

        obs = [traj[0] for traj in picked]
        act = [traj[1] for traj in picked]
        rew = [traj[2] for traj in picked]
        dones = [traj[3] for traj in picked]
        next_states = [traj[4] for traj in picked]
        splits = [len(o) for o in obs]

        batch = [torch.cat(parts).to(device=self.device) for parts in (obs, act, rew)]
        return (batch,
                dones, torch.stack(next_states).to(device=self.device),
                splits)
