import torch
from torch.utils.data import DataLoader
import random

class AssociativeRecallDataLoader(DataLoader):
    """Generate synthetic associative-recall batches on the fly.

    The usable token ids ``[start_id, vocab_size)`` are split in half:
    keys come from ``[start_id, start_id + num_pairs)`` and values from
    ``[start_id + num_pairs, vocab_size)``. Each sample contains ``k``
    key/value pairs interleaved as ``k1 v1 k2 v2 ...``. Keys are distinct
    within a sample, and so are values. The pairs are followed by the same
    ``k`` keys in a random order as queries. The label for each query is the
    value paired with that key in the context.

    ``DataLoader.__init__`` is never called here. Only ``__iter__`` and
    ``__len__`` are provided, so the usual ``DataLoader`` attributes
    (dataset, sampler, ...) are not set on instances.
    """

    def __init__(self, batch_size, num_batch, vocab_size, start_id=0, max_seq_length=None):
        """
        Args:
            batch_size: number of sequences per yielded batch.
            num_batch: number of batches yielded per iteration; also ``len(self)``.
            vocab_size: exclusive upper bound on generated token ids.
            start_id: first usable token id; smaller ids are never generated.
            max_seq_length: if given, caps the pairs per sample at
                ``max_seq_length // 2``. Otherwise all ``num_pairs`` keys may
                be used.

        Raises:
            AssertionError: if ``vocab_size - start_id`` is odd, or if the
                derived pair-count range is empty or exceeds ``num_pairs``.
        """
        assert (vocab_size - start_id) % 2 == 0, "vocab_size - start_id should be an even number"

        self.batch_size = batch_size
        self.num_batch = num_batch
        self.start_id = start_id
        self.vocab_size = vocab_size
        # Half of the usable ids serve as keys, the other half as values.
        self.num_pairs = (vocab_size - start_id) // 2

        if max_seq_length is None:
            # Mirror the max_seq_length branch: the lower bound is half the
            # maximum, or maximum - 5, whichever is larger. The earlier form
            # max(num_pairs, num_pairs - 5) always equalled num_pairs, so the
            # sequence length never varied.
            self.min_num_pairs = max(self.num_pairs // 2, self.num_pairs - 5)
            self.max_num_pairs = self.num_pairs
        else:
            # NOTE(review): this bounds the key/value context (2 tokens per
            # pair). The emitted input_ids also append one query per pair, so
            # their length can reach 3 * max_num_pairs. Confirm this is intended.
            self.min_num_pairs = max(max_seq_length // 4, max_seq_length // 2 - 5)
            self.max_num_pairs = max_seq_length // 2

        assert self.min_num_pairs <= self.max_num_pairs
        assert self.max_num_pairs <= self.num_pairs

    def __iter__(self):
        """Yield ``num_batch`` freshly sampled batches.

        Each batch is a dict with these entries:
            ``input_ids``: int64 tensor of shape ``(batch_size, 3 * k)``:
                ``2 * k`` interleaved key/value tokens followed by ``k`` query keys.
            ``labels``: int64 tensor of shape ``(batch_size, k)``: the value
                paired with each query key.
            ``mask``: always ``None``.

        ``k`` is drawn uniformly from ``[min_num_pairs, max_num_pairs]`` once
        per batch, so every row in a batch has the same length.
        """
        for _ in range(self.num_batch):
            # determine the number of key-value pairs
            seq_num_pairs = random.randint(self.min_num_pairs, self.max_num_pairs)

            # generate distinct key tokens and distinct value tokens per row
            keys = [self.start_id + torch.randperm(self.num_pairs)[:seq_num_pairs]
                    for _ in range(self.batch_size)]
            key_tensor = torch.stack(keys, dim=0)
            values = [self.start_id + self.num_pairs + torch.randperm(self.num_pairs)[:seq_num_pairs]
                      for _ in range(self.batch_size)]
            value_tensor = torch.stack(values, dim=0)

            # interleave keys and values:
            # (batch, k, 2) -> (batch, 2k) laid out as k1 v1 k2 v2 ...
            key_value_tensor = torch.stack([key_tensor, value_tensor], dim=2)
            seq_tensor = key_value_tensor.reshape(self.batch_size, -1)

            # create query tokens by randomly permuting the keys; the same
            # permutation picks out the matching values as labels
            query_perm_idx = [torch.randperm(seq_num_pairs) for _ in range(self.batch_size)]
            query_perm_idx = torch.stack(query_perm_idx, dim=0)  # (batch_size, seq_num_pairs)
            query_tensor = torch.gather(key_tensor, 1, query_perm_idx)
            query_label = torch.gather(value_tensor, 1, query_perm_idx)

            # yield the context followed by the queries, plus the query labels
            yield {"input_ids": torch.cat([seq_tensor, query_tensor], dim=1),
                   "labels": query_label, "mask": None}

    def __len__(self):
        """Return the number of batches produced by one pass of iteration."""
        return self.num_batch

if __name__ == "__main__":
    # Small demo: 3 batches of 3 sequences over a 16-token vocabulary.
    demo_loader = AssociativeRecallDataLoader(
        batch_size=3, num_batch=3, start_id=0, vocab_size=16
    )
    for sample in demo_loader:
        # Print each field's heading, then the tensor and its shape.
        for heading, field in (("input_ids:", "input_ids"), ("labels:", "labels")):
            tensor = sample[field]
            print(heading)
            print(tensor, tensor.shape)
        