"""
utils/skill_storage.py
======================
Each item returned is a tuple

    (state, action, next_state, intrinsic_reward, skill_onehot)

which matches the ordering expected by
`vae_mixture.VaribadVAEMixture.pretrain_with_skill_data`.
"""

import pickle

import numpy as np
import torch
from torch.utils.data import Dataset


class SkillReplayBuffer(Dataset):
    """
    Wrap a pickled list of 7-tuples:

        [(s_t, a_t, s_{t+1}, r_int, y_onehot, skill_mean, skill_logvar), ...]

    Each field is stacked into a ``float32`` tensor whose leading dimension
    indexes the transition. The intrinsic reward gets an extra trailing
    dimension (``unsqueeze(-1)``), so scalar rewards become shape ``(N, 1)``.
    ``__getitem__`` returns the seven per-transition tensors in that order.
    """

    # Field order of each pickled tuple (and of the tuple __getitem__ returns).
    _FIELDS = (
        "state",
        "action",
        "next_state",
        "intrinsic_reward",
        "skill_onehot",
        "skill_mean",
        "skill_logvar",
    )

    def __init__(self, pkl_path):
        """
        Parameters
        ----------
        pkl_path : str or Path
            Path to the `.pkl` file saved by `PretrainerSDVT`.

        Raises
        ------
        ValueError
            If the pickled buffer is empty, or if any tuple does not have
            exactly seven fields.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # buffers produced by a trusted run.
        with open(pkl_path, "rb") as f:
            data = pickle.load(f)

        n_fields = len(self._FIELDS)
        if len(data) == 0:
            raise ValueError(f"{pkl_path}: skill buffer is empty")
        for i, item in enumerate(data):
            # zip(*data) silently truncates every tuple to the shortest one,
            # so a malformed entry would otherwise drop fields unnoticed.
            if len(item) != n_fields:
                raise ValueError(
                    f"{pkl_path}: item {i} has {len(item)} fields, "
                    f"expected {n_fields} {self._FIELDS}"
                )

        # unzip into per-field columns, then stack each column into a tensor
        (states, actions, next_states, rewards,
         skills, skill_means, skill_logvars) = zip(*data)

        self.states        = self._to_float_tensor(states)
        self.actions       = self._to_float_tensor(actions)
        self.next_states   = self._to_float_tensor(next_states)
        # rewards need shape (N, 1) for consistency
        self.rewards       = self._to_float_tensor(rewards).unsqueeze(-1)
        self.skills        = self._to_float_tensor(skills)
        self.skill_means   = self._to_float_tensor(skill_means)
        self.skill_logvars = self._to_float_tensor(skill_logvars)

    @staticmethod
    def _to_float_tensor(column):
        """Stack one column of per-transition values into a float32 tensor."""
        try:
            # One ndarray conversion avoids torch's slow element-by-element
            # path (and its warning) for lists of ndarrays.
            array = np.asarray(column, dtype=np.float32)
        except (TypeError, ValueError, RuntimeError):
            # Elements numpy cannot convert (e.g. device or grad-tracking
            # tensors): fall back to torch's own conversion, as before.
            return torch.tensor(column, dtype=torch.float32)
        return torch.from_numpy(array)

    # -------- PyTorch Dataset interface ---------------------------------
    def __len__(self):
        """Return the number of stored transitions."""
        return self.states.size(0)

    def __getitem__(self, idx):
        """Return the seven tensors of transition ``idx`` in ``_FIELDS`` order."""
        return (
            self.states[idx],
            self.actions[idx],
            self.next_states[idx],
            self.rewards[idx],
            self.skills[idx],
            self.skill_means[idx],
            self.skill_logvars[idx],
        )
