"""GPU-friendly ring buffer for 84×84 frame stacks and rewards.

Stores environment rewards (scalar) and auxiliary random reward vectors
used for training V and VPS. `sample` returns (s, s', a, r_env, r_rand,
done) tuples; `sample_nstep` additionally returns n-step sequences for
option-Q training.
"""

import torch, numpy as np


class Memory:
    """Shard-aware ring buffer across one or multiple CUDA devices.

    Each stored transition consists of one (1, 84, 84) uint8 frame, an integer
    action, a scalar environment reward, a fixed-length random-reward vector
    and a terminal flag. Rows are split evenly over ``storage_devices``;
    sampled batches are returned on ``target_device``.

    Sampling gathers only the rows it needs from each shard, so the cost of a
    call scales with the batch size rather than with the buffer capacity.
    """

    # Rejection-sampling rounds to try before enumerating every candidate.
    _MAX_REJECTION_ROUNDS = 64

    # --------------------------------------------------------------------- #
    # 1. Initialization                                                     #
    # --------------------------------------------------------------------- #
    def __init__(
        self,
        max_size: int,
        storage_devices="cuda:0",
        target_device="cuda:0",
        frame_stack_num: int = 1,
    ):
        """Allocate the fixed-shape shards.

        Args:
            max_size: Requested capacity in transitions. It is rounded down to
                a multiple of the number of storage devices so that every
                global index maps to an existing shard row.
            storage_devices: One device (str / torch.device) or an iterable of
                devices that hold the shards.
            target_device: Device on which sampled batches are returned.
            frame_stack_num: Number of consecutive frames forming one state.

        Raises:
            ValueError: If no storage device is given, if ``max_size`` is
                smaller than the number of devices, or if
                ``frame_stack_num`` is smaller than 1.
        """
        if frame_stack_num < 1:
            raise ValueError("frame_stack_num must be >= 1")

        # Device splitting (a single torch.device is not iterable, so it is
        # treated like a single string).
        if isinstance(storage_devices, (str, torch.device)):
            self.dev_lst = [storage_devices]
        else:
            self.dev_lst = list(storage_devices)
        self.n_dev = len(self.dev_lst)
        if self.n_dev == 0:
            raise ValueError("at least one storage device is required")

        self.size_per = max_size // self.n_dev
        if self.size_per <= 0:
            raise ValueError("max_size must be >= number of storage devices")
        # Effective capacity: without rounding, writes past
        # size_per * n_dev would address a shard that does not exist.
        self.max_size = self.size_per * self.n_dev

        self.full = False
        self.curr = 0  # next write slot
        self.tgt_dev = torch.device(target_device)
        self.F = frame_stack_num

        # Fixed-structure buffers
        self.states, self.env_rewards, self.actions, self.dones = [], [], [], []
        # rand_rewards is created lazily because we need its dim from first write
        self.rand_rewards = None
        self.rand_dim = None

        for dev in self.dev_lst:
            self.states.append(
                torch.zeros((self.size_per, 1, 84, 84), dtype=torch.uint8, device=dev)
            )
            self.env_rewards.append(torch.zeros((self.size_per, 1), device=dev))
            self.actions.append(
                torch.zeros((self.size_per, 1), dtype=torch.int64, device=dev)
            )
            self.dones.append(
                torch.zeros((self.size_per, 1), dtype=torch.uint8, device=dev)
            )

    # -------- private: global idx ↔ shard idx ---------------------------- #
    def _dev_id(self, idx: int):
        """Return the shard number that holds global row ``idx``."""
        return idx // self.size_per

    def _local_idx(self, idx: int):
        """Return the row of global index ``idx`` inside its shard."""
        return idx % self.size_per

    def size(self):
        """Return the number of transitions currently stored."""
        return self.max_size if self.full else self.curr

    # -------- private: sampling helpers ---------------------------------- #
    def _gather(self, shards, idx):
        """Fetch rows ``idx`` (global indices, any shape) from sharded storage.

        Only the requested rows are moved to the target device. The result has
        shape ``idx.shape + row_shape`` and is always a fresh tensor.
        """
        row_shape = tuple(shards[0].shape[1:])
        if self.n_dev == 1:
            shard = shards[0]
            return shard[idx.to(shard.device)].to(self.tgt_dev)

        flat = idx.reshape(-1)
        shard_id = torch.div(flat, self.size_per, rounding_mode="floor")
        local = flat % self.size_per
        out = torch.empty(
            (flat.numel(),) + row_shape, dtype=shards[0].dtype, device=self.tgt_dev
        )
        for k, shard in enumerate(shards):
            mask = shard_id == k
            if mask.any():
                out[mask] = shard[local[mask].to(shard.device)].to(self.tgt_dev)
        return out.reshape(tuple(idx.shape) + row_shape)

    def _is_valid(self, idx, tail: int, horizon: int, n_front: int):
        """Return a bool mask marking which start indices may be sampled.

        An index is rejected when
          * any row in ``idx - tail .. idx + n_front - 1`` is terminal, or
          * the rows ``idx - tail .. idx + horizon`` straddle the write head,
            i.e. they would stitch the newest data to the oldest data.
        """
        offsets = torch.arange(-tail, n_front, device=idx.device)
        if offsets.numel() > 0:
            flags = self._gather(self.dones, idx.unsqueeze(1) + offsets).squeeze(-1)
            ok = flags.sum(1) == 0
        else:
            ok = torch.ones_like(idx, dtype=torch.bool)

        # Row curr - 1 is the newest transition and row curr the oldest one.
        # While the buffer is still filling this never triggers because
        # idx + horizon < curr for every candidate.
        crosses = (idx >= self.curr - horizon) & (idx <= self.curr - 1 + tail)
        return ok & ~crosses

    def _draw_indices(self, N: int, low: int, high: int, tail: int, horizon: int,
                      n_front: int):
        """Draw ``N`` valid start indices uniformly from ``[low, high)``.

        Uses rejection sampling; if that does not yield enough indices within
        ``_MAX_REJECTION_ROUNDS`` rounds, all candidates are enumerated.

        Raises:
            RuntimeError: If the buffer contains no valid start index at all.
        """
        idx = torch.randint(low, high, (N * 2,), device=self.tgt_dev)
        idx = idx[self._is_valid(idx, tail, horizon, n_front)][:N]

        rounds = 0
        # Draw at least one candidate per round; N // 2 is 0 for N == 1,
        # which would otherwise never make progress.
        n_extra = max(1, N // 2)
        while idx.numel() < N and rounds < self._MAX_REJECTION_ROUNDS:
            extra = torch.randint(low, high, (n_extra,), device=self.tgt_dev)
            keep = self._is_valid(extra, tail, horizon, n_front)
            idx = torch.cat([idx, extra[keep]])[:N]
            rounds += 1

        if idx.numel() < N:
            # Valid indices are rare (or absent): enumerate them explicitly
            # instead of spinning forever.
            cand = torch.arange(low, high, device=self.tgt_dev)
            cand = cand[self._is_valid(cand, tail, horizon, n_front)]
            if cand.numel() == 0:
                raise RuntimeError("no valid transition available for sampling")
            pick = torch.randint(
                0, cand.numel(), (N - idx.numel(),), device=self.tgt_dev
            )
            idx = torch.cat([idx, cand[pick]])
        return idx

    # --------------------------------------------------------------------- #
    # 2. Insertion                                                          #
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def store(
        self,
        frame,  # (1,84,84) uint8 / float32
        action,  # int
        env_reward,  # float / int
        rand_reward_vec,  # 1-D sequence / None
        terminal,  # 0 / 1
    ):
        """Insert one transition; lazily allocate random-reward storage.

        frame:
          - uint8 frames are stored as-is; any other dtype is multiplied by
            255 and cast to uint8.

        rand_reward_vec:
          - Arbitrary length; the first insert fixes its dimension (k_rand).
          - All subsequent inserts must use the same length.
          - If None, a zero vector of shape (k_rand,) is used
            (k_rand becomes 1 when the very first insert passes None).
        """
        # ------ First encounter of random reward: create the buffer --------
        if self.rand_dim is None:
            if rand_reward_vec is None:
                self.rand_dim = 1
            else:
                self.rand_dim = int(np.asarray(rand_reward_vec).size)

            # Allocate rand_rewards for each device shard
            self.rand_rewards = [
                torch.zeros(
                    (self.size_per, self.rand_dim), dtype=torch.float32, device=dev
                )
                for dev in self.dev_lst
            ]

        # Ensure consistent input dimension
        if rand_reward_vec is None:
            rand_reward_vec = np.zeros(self.rand_dim, dtype=np.float32)
        else:
            rand_reward_vec = np.asarray(rand_reward_vec, dtype=np.float32).reshape(-1)
            assert (
                rand_reward_vec.size == self.rand_dim
            ), f"rand_reward dim mismatch: expect {self.rand_dim}, got {rand_reward_vec.size}"

        # ------ Write ------------------------------------------------------
        frame = np.asarray(frame)
        if frame.dtype != np.uint8:
            frame = (frame * 255).astype(np.uint8)

        did = self._dev_id(self.curr)
        loc = self._local_idx(self.curr)
        dev = self.dev_lst[did]

        self.states[did][loc] = torch.from_numpy(frame).to(dev, non_blocking=True)
        self.actions[did][loc] = int(action)
        self.env_rewards[did][loc] = float(env_reward)
        self.rand_rewards[did][loc] = torch.from_numpy(rand_reward_vec).to(dev)
        self.dones[did][loc] = int(terminal)

        self.full = self.full or (self.curr + 1 >= self.max_size)
        self.curr = (self.curr + 1) % self.max_size

    # --------------------------------------------------------------------- #
    # 3. 1-step sampling                                                    #
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def sample(self, N: int):
        """Sample ``N`` one-step transitions.

        Returns:
            s      (N, F, 84, 84) float32 in [0, 1]: stacked frames ending at t
            sp     (N, F, 84, 84) float32 in [0, 1]: stacked frames ending at t+1
            a      (N, 1) int64
            r_env  (N, 1) float32
            r_rand (N, k_rand) float32
            d      (N, 1) uint8 terminal flag of step t

        The frame history of ``s`` never contains a terminal row and never
        straddles the ring buffer's write head.

        Raises:
            AssertionError: If fewer than ``F + 1`` transitions are stored.
            RuntimeError: If no valid transition exists in the buffer.
        """
        tail = self.F - 1  # =3 when F=4
        size = self.size()
        assert size >= tail + 2, "buffer too small"

        idx = self._draw_indices(N, tail, size - 1, tail, horizon=1, n_front=0)

        win = torch.arange(tail, -1, -1, device=self.tgt_dev)
        frame_idx = idx.unsqueeze(1) - win  # (N, F), oldest frame first
        next_idx = frame_idx + 1

        s = self._gather(self.states, frame_idx).squeeze(2).float().div_(255.0)
        sp = self._gather(self.states, next_idx).squeeze(2).float().div_(255.0)

        a = self._gather(self.actions, idx).long()
        r_env = self._gather(self.env_rewards, idx)
        r_rand = self._gather(self.rand_rewards, idx)
        d = self._gather(self.dones, idx)

        return s, sp, a, r_env, r_rand, d

    # --------------------------------------------------------------------- #
    # 4. n-step sampling                                                    #
    # --------------------------------------------------------------------- #
    @torch.no_grad()
    def sample_nstep(self, N: int, n_step: int):
        """Return n-step sequences for option-Q training.

        Returns:
            S_seq      (N, n_step, F, 84, 84) float32 in [0, 1]
            A_seq      (N, n_step, 1) int64
            R_env_seq  (N, n_step, 1) float32
            R_rand_seq (N, n_step, k_rand) float32
            D_seq      (N, n_step, 1) uint8
            S_n        (N, F, 84, 84) float32 in [0, 1]: state after the last step

        Start indices are rejected when the frame history of the first state
        contains a terminal row, when any of steps ``0 .. n_step - 3`` is
        terminal, or when the sequence straddles the write head. The last two
        steps may be terminal, so callers must mask with ``D_seq``.

        Raises:
            AssertionError: If ``n_step < 2`` or fewer than ``F + n_step``
                transitions are stored.
            RuntimeError: If no valid sequence exists in the buffer.
        """
        assert n_step >= 2
        tail = self.F - 1
        size = self.size()
        # Rows idx - tail .. idx + n_step must all exist, hence the "+ 1".
        assert size >= tail + n_step + 1, "buffer too small"

        # NOTE(review): the terminal check spans steps 0 .. n_step - 3 exactly
        # as before; confirm that n_step - 2 (rather than n_step - 1) is the
        # intended bound.
        idx = self._draw_indices(
            N, tail, size - n_step, tail, horizon=n_step, n_front=n_step - 2
        )

        win = torch.arange(tail, -1, -1, device=self.tgt_dev)
        step = torch.arange(0, n_step, device=self.tgt_dev)

        base_idx = idx.unsqueeze(1) + step  # (N, n_step)
        frame_idx = base_idx.unsqueeze(2) - win  # (N, n_step, F)
        frame_idxN = (idx + n_step).unsqueeze(1) - win  # (N, F)

        S_seq = self._gather(self.states, frame_idx).squeeze(3).float().div_(255.0)
        S_n = self._gather(self.states, frame_idxN).squeeze(2).float().div_(255.0)
        A_seq = self._gather(self.actions, base_idx).long()
        R_env_seq = self._gather(self.env_rewards, base_idx)
        R_rand_seq = self._gather(self.rand_rewards, base_idx)
        D_seq = self._gather(self.dones, base_idx)

        return (
            S_seq.contiguous(),
            A_seq,
            R_env_seq,
            R_rand_seq,
            D_seq,
            S_n.contiguous(),
        )
