"""GPU-friendly ring buffer for 84×84 frame stacks and rewards.

Stores environment rewards (scalar) and auxiliary random reward vectors
used for training V and VPS. `sample` returns (s, s', a, r_env, r_rand,
done) tuples; `sample_nstep` additionally returns n-step sequences for
option-Q training.
"""

import torch, numpy as np


class Memory:
    """Shard-aware ring buffer across one or multiple CUDA devices.

    Each transition occupies one global slot ``0 .. max_size - 1``. Slots are
    laid out contiguously across the storage shards: slot ``i`` lives in shard
    ``i // size_per`` at row ``i % size_per``. Per slot the buffer keeps a
    single ``(1, 84, 84)`` uint8 frame, an int64 action, a float32 environment
    reward, a float32 random-reward vector (allocated on the first ``store``)
    and a uint8 done flag. Frame stacks of length ``F`` are assembled at
    sampling time from ``F`` consecutive slots.
    """

    # --------------------------------------------------------------------- #
    # 1. Initialization                                                     #
    # --------------------------------------------------------------------- #
    def __init__(
        self,
        max_size: int,
        storage_devices="cuda:0",
        target_device="cuda:0",
        frame_stack_num: int = 1,
    ):
        """Allocate the fixed-shape shards.

        Args:
            max_size: Number of transitions the ring can hold (>= 1).
            storage_devices: One device (str or ``torch.device``) or an
                iterable of devices; one shard is allocated per entry.
            target_device: Device on which sampled batches are returned.
            frame_stack_num: Number of consecutive frames (``F``) per state.

        Raises:
            ValueError: If ``max_size`` or ``frame_stack_num`` is < 1, or no
                storage device is given.
        """
        if max_size < 1:
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        if frame_stack_num < 1:
            raise ValueError(f"frame_stack_num must be >= 1, got {frame_stack_num}")

        self.max_size = max_size
        self.full = False
        self.curr = 0  # next write slot
        self.tgt_dev = torch.device(target_device)
        self.F = frame_stack_num

        # Device splitting
        if isinstance(storage_devices, (str, torch.device)):
            self.dev_lst = [storage_devices]
        else:
            self.dev_lst = list(storage_devices)
        if not self.dev_lst:
            raise ValueError("storage_devices must name at least one device")
        self.n_dev = len(self.dev_lst)
        # Ceil division: with floor division a max_size that is not a multiple
        # of n_dev would map the last slots to a shard that does not exist.
        self.size_per = -(-self.max_size // self.n_dev)

        # Fixed-structure buffers
        self.states, self.env_rewards, self.actions, self.dones = [], [], [], []
        # rand_rewards is created lazily because we need its dim from first write
        self.rand_rewards = None
        self.rand_dim = None

        for dev in self.dev_lst:
            self.states.append(
                torch.zeros((self.size_per, 1, 84, 84), dtype=torch.uint8, device=dev)
            )
            self.env_rewards.append(torch.zeros((self.size_per, 1), device=dev))
            self.actions.append(
                torch.zeros((self.size_per, 1), dtype=torch.int64, device=dev)
            )
            self.dones.append(
                torch.zeros((self.size_per, 1), dtype=torch.uint8, device=dev)
            )

    # -------- private: global idx ↔ shard idx ---------------------------- #
    def _dev_id(self, idx: int):
        """Return the shard number that holds global slot ``idx``."""
        return idx // self.size_per

    def _local_idx(self, idx: int):
        """Return the row of global slot ``idx`` inside its shard."""
        return idx % self.size_per

    def size(self):
        """Return the number of transitions currently stored."""
        return self.max_size if self.full else self.curr

    # -------- private: sampling helpers ----------------------------------- #
    def _gather(self, shards, idx):
        """Fetch the rows at global slots ``idx`` onto the target device.

        Only the requested rows are copied, never a whole shard. The result
        has shape ``idx.shape + shards[0].shape[1:]``. Copies are blocking:
        a non-blocking copy towards a CPU target could be read before it has
        completed.
        """
        if self.n_dev == 1:
            shard = shards[0]
            return shard[idx.to(shard.device)].to(self.tgt_dev)

        row_shape = tuple(shards[0].shape[1:])
        flat = idx.reshape(-1)
        shard_of = torch.div(flat, self.size_per, rounding_mode="floor")
        local = flat - shard_of * self.size_per
        out = torch.empty(
            (flat.numel(),) + row_shape, dtype=shards[0].dtype, device=self.tgt_dev
        )
        for k, shard in enumerate(shards):
            pos = (shard_of == k).nonzero(as_tuple=True)[0]
            if pos.numel():
                out[pos] = shard[local[pos].to(shard.device)].to(self.tgt_dev)
        return out.reshape(tuple(idx.shape) + row_shape)

    def _spans_write_head(self, first, last):
        """Flag slot windows ``[first, last]`` that straddle the write head.

        Once the ring is full, slot ``curr - 1`` holds the newest transition
        and slot ``curr`` the oldest one, so the two are not consecutive in
        time. A window containing both would mix unrelated transitions.
        """
        if not self.full or self.curr == 0:
            return torch.zeros_like(first, dtype=torch.bool)
        return (first <= self.curr - 1) & (last >= self.curr)

    def _draw_valid(self, low: int, high: int, count: int, is_valid):
        """Draw ``count`` slots uniformly from the valid ones in ``[low, high)``.

        Uses rejection sampling. If a retry round produces no valid slot the
        candidates are enumerated once, which guarantees termination.

        Raises:
            RuntimeError: If no slot in ``[low, high)`` passes ``is_valid``.
        """
        picked = torch.randint(low, high, (2 * count,), device=self.tgt_dev)
        picked = picked[is_valid(picked)][:count]
        while picked.numel() < count:
            # Draw at least one candidate so that count == 1 makes progress.
            extra = torch.randint(low, high, (max(1, count),), device=self.tgt_dev)
            extra = extra[is_valid(extra)]
            if extra.numel() == 0:
                pool = torch.arange(low, high, device=self.tgt_dev)
                pool = pool[is_valid(pool)]
                if pool.numel() == 0:
                    raise RuntimeError("no valid start index available for sampling")
                need = count - picked.numel()
                choice = torch.randint(0, pool.numel(), (need,), device=self.tgt_dev)
                extra = pool[choice]
            picked = torch.cat([picked, extra])[:count]
        return picked

    # --------------------------------------------------------------------- #
    # 2.5  Random-reward normalization                                       #
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def normalize_rand_rewards_(self, eps: float = 1e-6):
        """In-place standardization of stored random reward vectors.

        Statistics are computed per reward dimension over the filled part of
        the buffer and applied to those same rows; rows that were never
        written stay zero. Transitions stored afterwards are NOT normalized.

        Args:
            eps: Lower bound applied to the standard deviation.

        Returns:
            (mean, std): two 1-D CPU tensors of shape (k,)

        Raises:
            RuntimeError: If nothing has been stored yet.
        """
        if self.rand_rewards is None or self.rand_dim is None:
            raise RuntimeError("rand_rewards not initialized; call store() at least once.")

        size = int(self.size())
        if size <= 0:
            raise RuntimeError("buffer is empty; cannot normalize.")

        # Views onto the filled rows of every shard (slots are laid out
        # contiguously, so shard k holds slots k*size_per .. (k+1)*size_per-1).
        filled = []
        for k, shard in enumerate(self.rand_rewards):
            n_rows = min(self.size_per, max(0, size - k * self.size_per))
            if n_rows:
                filled.append(shard[:n_rows])

        rr = torch.cat([view.to(self.tgt_dev) for view in filled], 0)
        mean = rr.mean(0)  # (k,)
        std = torch.sqrt(rr.var(0, unbiased=False)).clamp_min(eps)  # (k,)

        # The views alias the shards, so this updates the storage in place.
        for view in filled:
            view.sub_(mean.to(view.device)).div_(std.to(view.device))

        return mean.detach().cpu(), std.detach().cpu()

    # --------------------------------------------------------------------- #
    # 2. Insertion                                                          #
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def store(
        self,
        frame,  # (1,84,84) uint8 / float32
        action,  # int
        env_reward,  # float / int
        rand_reward_vec,  # 1-D sequence / None
        terminal,  # 0 / 1
    ):
        """Insert one transition; lazily allocate random-reward storage.

        frame:
          - Array-like; anything that is not uint8 is multiplied by 255 and
            cast to uint8 before being written.

        rand_reward_vec:
          - Arbitrary length; the first insert fixes its dimension (k_rand).
          - All subsequent inserts must use the same length.
          - If None, a zero vector of shape (k_rand,) is used (k_rand = 1
            when the very first insert passes None).

        Raises:
            AssertionError: If ``rand_reward_vec`` has the wrong length.
        """
        # ------ First encounter of random reward: create the buffer --------
        if self.rand_dim is None:
            if rand_reward_vec is None:
                self.rand_dim = 1
            else:
                self.rand_dim = int(np.asarray(rand_reward_vec).size)
            self.rand_rewards = [
                torch.zeros((self.size_per, self.rand_dim), dtype=torch.float32, device=dev)
                for dev in self.dev_lst
            ]

        # Ensure consistent input dimension
        if rand_reward_vec is None:
            rand_vec = np.zeros(self.rand_dim, dtype=np.float32)
        else:
            rand_vec = np.asarray(rand_reward_vec, dtype=np.float32).reshape(-1)
            if rand_vec.size != self.rand_dim:
                # Raised explicitly so the check survives ``python -O``.
                raise AssertionError(
                    f"rand_reward dim mismatch: expect {self.rand_dim}, got {rand_vec.size}"
                )

        # ------ Write ------------------------------------------------------
        frame = np.asarray(frame)
        if frame.dtype != np.uint8:
            frame = (frame * 255).astype(np.uint8)

        did = self._dev_id(self.curr)
        loc = self._local_idx(self.curr)
        dev = self.dev_lst[did]

        self.states[did][loc] = torch.from_numpy(frame).to(dev, non_blocking=True)
        self.actions[did][loc] = int(action)
        self.env_rewards[did][loc] = float(env_reward)
        self.rand_rewards[did][loc] = torch.from_numpy(rand_vec).to(dev)
        self.dones[did][loc] = int(terminal)

        self.full = self.full or (self.curr + 1 >= self.max_size)
        self.curr = (self.curr + 1) % self.max_size

    # --------------------------------------------------------------------- #
    # 3. 1-step sampling                                                    #
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def sample(self, N: int):
        """Sample ``N`` one-step transitions (with replacement).

        A slot is eligible when none of the ``F - 1`` slots before it is
        terminal (the frame stack stays inside one episode) and when the
        slots it reads do not straddle the ring's write head.

        Returns:
            (s, sp, a, r_env, r_rand, d) on the target device:
            s, sp  float (N, F, 84, 84), scaled by 1/255
            a      int64 (N, 1)
            r_env  float32 (N, 1)
            r_rand float32 (N, k_rand)
            d      uint8 (N, 1)

        Raises:
            AssertionError: If fewer than ``F + 1`` transitions are stored.
            RuntimeError: If no eligible slot exists.
        """
        tail = self.F - 1  # =3 when F=4
        size = self.size()
        if size < tail + 2:
            raise AssertionError("buffer too small")

        win = torch.arange(tail, -1, -1, device=self.tgt_dev)  # tail .. 0

        def valid(cand):
            # win[:-1] covers offsets tail .. 1, i.e. the slots before cand.
            history = cand.unsqueeze(1) - win[:-1]
            clean = self._gather(self.dones, history).squeeze(-1).sum(1) == 0
            # Slots read for (s, s'): cand - tail .. cand + 1.
            return clean & ~self._spans_write_head(cand - tail, cand + 1)

        idx = self._draw_valid(tail, size - 1, N, valid)

        frame_idx = idx.unsqueeze(1) - win  # (N, F), oldest frame first
        next_idx = frame_idx + 1

        s = self._gather(self.states, frame_idx).squeeze(2).float().div_(255.0)
        sp = self._gather(self.states, next_idx).squeeze(2).float().div_(255.0)

        a = self._gather(self.actions, idx).long()
        r_env = self._gather(self.env_rewards, idx)
        r_rand = self._gather(self.rand_rewards, idx)
        d = self._gather(self.dones, idx)

        return s, sp, a, r_env, r_rand, d

    # --------------------------------------------------------------------- #
    # 4. n-step sampling                                                    #
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def sample_nstep(self, N: int, n_step: int):
        """Return n-step sequences for option-Q training.

        Each sample starts at a slot ``i`` and covers slots
        ``i .. i + n_step - 1``; ``S_n`` is the frame stack ending at slot
        ``i + n_step``.

        Returns:
            S_seq, A_seq, R_env_seq, R_rand_seq, D_seq, S_n with shapes
            (N, n_step, F, 84, 84), (N, n_step, 1), (N, n_step, 1),
            (N, n_step, k_rand), (N, n_step, 1) and (N, F, 84, 84).

        Raises:
            AssertionError: If ``n_step < 2`` or fewer than
                ``F + n_step`` transitions are stored.
            RuntimeError: If no eligible start slot exists.
        """
        if n_step < 2:
            raise AssertionError("n_step must be >= 2")
        tail = self.F - 1
        size = self.size()
        # Start slots are drawn from [tail, size - n_step), which is empty
        # unless size >= tail + n_step + 1.
        if size < tail + n_step + 1:
            raise AssertionError("buffer too small")

        win = torch.arange(tail, -1, -1, device=self.tgt_dev)
        step = torch.arange(0, n_step, device=self.tgt_dev)
        # NOTE(review): only the first n_step - 2 steps are required to be
        # non-terminal (none at all for n_step == 2); later dones are reported
        # through D_seq. Kept as in the original - confirm this is intended.
        front = torch.arange(0, n_step - 2, device=self.tgt_dev)

        def valid(cand):
            history = cand.unsqueeze(1) - win[:-1]
            ok_hist = self._gather(self.dones, history).squeeze(-1).sum(1) == 0
            ahead = cand.unsqueeze(1) + front
            ok_traj = self._gather(self.dones, ahead).squeeze(-1).sum(1) == 0
            # Slots read: cand - tail .. cand + n_step (the latter for S_n).
            seam = self._spans_write_head(cand - tail, cand + n_step)
            return ok_hist & ok_traj & ~seam

        idx = self._draw_valid(tail, size - n_step, N, valid)

        base_idx = idx.unsqueeze(1) + step  # (N, n_step)
        frame_idx = base_idx.unsqueeze(2) - win  # (N, n_step, F)
        frame_idxN = (idx + n_step).unsqueeze(1) - win  # (N, F)

        S_seq = self._gather(self.states, frame_idx).squeeze(3).float().div_(255.0)
        S_n = self._gather(self.states, frame_idxN).squeeze(2).float().div_(255.0)
        A_seq = self._gather(self.actions, base_idx).long()
        R_env_seq = self._gather(self.env_rewards, base_idx)
        R_rand_seq = self._gather(self.rand_rewards, base_idx)
        D_seq = self._gather(self.dones, base_idx)

        return (
            S_seq.contiguous(),
            A_seq,
            R_env_seq,
            R_rand_seq,
            D_seq,
            S_n.contiguous(),
        )
