# SPDX-License-Identifier: MIT
from __future__ import annotations
from typing import Dict
import torch
from torch.utils.data import Dataset
from itertools import product
from .embeddings import word_token_embeddings_hf, pad_or_truncate_seq


class PromptDataset(Dataset):
    """Dataset over every ordered pair of words with cached embeddings.

    Each key of ``data`` is embedded once at construction time with
    ``word_token_embeddings_hf`` and then passed through
    ``pad_or_truncate_seq`` with ``max_tokens``.  Items are all ordered
    pairs ``(i, j)`` of word indices, self-pairs included, so the dataset
    has ``len(data) ** 2`` entries.

    Args:
        data: Mapping from word to its numeric label.
        max_tokens: Length argument forwarded to ``pad_or_truncate_seq``.
        model: Forwarded to ``word_token_embeddings_hf``.
        tokenizer: Forwarded to ``word_token_embeddings_hf``.
        device: Device on which labels and cached embeddings are stored.
        capture_mode: Stored on the instance; not otherwise used by this
            class.
        target_dtype: Dtype of the cached embeddings.  Labels are always
            stored as ``torch.float32``.
    """

    def __init__(
        self,
        data: Dict[str, float],
        max_tokens: int,
        model,
        tokenizer,
        device: str,
        capture_mode: str = "token",
        target_dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.device = device
        self.max_tokens = max_tokens
        self.capture_mode = capture_mode
        self.model = model
        self.tokenizer = tokenizer
        self.target_dtype = target_dtype

        self.words = list(data)
        self.labels = torch.tensor(
            [float(data[word]) for word in self.words],
            dtype=torch.float32,
            device=device,
        )

        # Embeddings are only cached, never trained through, so no autograd
        # graph is needed while computing them.
        with torch.no_grad():
            sequences = [
                pad_or_truncate_seq(
                    word_token_embeddings_hf(word, model=model, tokenizer=tokenizer),
                    self.max_tokens,
                )
                for word in self.words
            ]

        if sequences:
            self.seq_cache = torch.stack(sequences, dim=0).to(
                device=self.device, dtype=self.target_dtype
            )
            # One flattened row per word.  The stacked tensor is contiguous,
            # so this is a view sharing storage with ``seq_cache`` rather
            # than a second copy of every embedding.  The row width is given
            # explicitly so zero-element sequences still reshape cleanly.
            self.flat_cache = self.seq_cache.reshape(
                len(sequences), sequences[0].numel()
            )
        else:
            # ``torch.stack`` rejects an empty list, so build empty caches
            # directly.  The embedding width is unknown without any word and
            # is reported as 0.
            # NOTE(review): assumes sequences are laid out as
            # (max_tokens, dim) -- confirm against ``pad_or_truncate_seq``.
            self.seq_cache = torch.empty(
                (0, self.max_tokens, 0),
                device=self.device,
                dtype=self.target_dtype,
            )
            self.flat_cache = torch.empty(
                (0, 0), device=self.device, dtype=self.target_dtype
            )

        # NOTE: this materialises len(words) ** 2 index tuples.  It is kept
        # as a plain list because ``pairs`` is a public attribute.
        self.pairs = list(product(range(len(self.words)), repeat=2))

    def __len__(self) -> int:
        """Return the number of ordered word pairs."""
        return len(self.pairs)

    def __getitem__(self, idx: int) -> tuple:
        """Return the cached tensors and labels for pair number ``idx``.

        Returns:
            ``(seq_i, y_i, seq_j, y_j, flat_i, flat_j)`` where ``i`` and
            ``j`` are the word indices stored in ``self.pairs[idx]``.
        """
        i, j = self.pairs[idx]
        return (
            self.seq_cache[i],
            self.labels[i],
            self.seq_cache[j],
            self.labels[j],
            self.flat_cache[i],
            self.flat_cache[j],
        )