import torch
import numpy as np


class QueryBudgetDataset(torch.utils.data.Dataset):
    """Dataset of query embeddings, optionally paired with normalized budgets."""

    def __init__(self, embeds, budgets=None, max_budget=None) -> None:
        """Store query embeddings and, optionally, normalized budget vectors.

        Args:
            embeds: Indexable collection of embedding vectors, one per query
                (presumably each of shape ``(dim,)`` — confirm with callers).
            budgets (List[List[int]] | None): For each query, a list of
                per-client ICE budget numbers. When given, every entry is
                divided by ``max_budget`` and stored as a ``float32`` array.
            max_budget: Divisor used to normalize ``budgets``. When None and
                ``budgets`` is given, it defaults to the sum of the first
                query's budgets (this assumes every query shares the same
                total budget — verify against callers). Stored as
                ``self.max_budget`` (None when not given and no budgets).

        Raises:
            IndexError: If ``budgets`` is empty and ``max_budget`` is None,
                since the default divisor is read from ``budgets[0]``.
        """
        self.embeds = embeds

        if budgets is None:
            # Define max_budget on this branch too, so the attribute always
            # exists instead of raising AttributeError when read.
            self.max_budget = max_budget
            self.budgets = None
            return

        if max_budget is None:
            # Infer the normalization divisor from the first query's total.
            max_budget = sum(budgets[0])

        self.max_budget = max_budget
        self.budgets = (np.array(budgets) / max_budget).astype(np.float32)

    def __len__(self):
        """Return the number of queries (length of ``embeds``)."""
        return len(self.embeds)

    def __getitem__(self, index):
        """Return ``(embed, budget)`` when budgets exist, else just ``embed``."""
        embed = self.embeds[index]
        if self.budgets is None:
            return embed
        return embed, self.budgets[index]


class QueryBudgetDatasetNew(torch.utils.data.Dataset):
    """Dataset of query embeddings with optional, already-prepared targets.

    Unlike ``QueryBudgetDataset``, no normalization is applied: ``targets``
    is stored and indexed exactly as given.
    """

    def __init__(self, embeds, targets=None):
        """Keep references to the embeddings and (optional) targets.

        Args:
            embeds: Indexable collection with one embedding per query.
            targets: Optional indexable collection aligned with ``embeds``;
                None means items are returned without a target.
        """
        self.embeds, self.targets = embeds, targets

    def __len__(self):
        """Number of queries, taken from ``embeds``."""
        return len(self.embeds)

    def __getitem__(self, index):
        """Return ``embed`` alone, or ``(embed, target)`` when targets exist."""
        item = self.embeds[index]
        if self.targets is None:
            return item
        return item, self.targets[index]
