import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import random

class RegressionDataloader(DataLoader):
    """Synthetic batch generator for a running-maximum regression task.

    Each vocabulary id ``i`` is assigned the real value ``i / vocab_size``.
    A batch holds random token ids, and the label at every position is the
    largest word value seen so far in that sequence (squared when
    ``non_linear`` is true).

    Notes:
        * ``DataLoader.__init__`` is not invoked here, so base-class state
          such as ``dataset`` is never set up; instances are meant to be
          used only for iteration and ``len()``.
        * The sequence length is drawn with Python's ``random`` module while
          token ids come from torch's global RNG, so both generators must be
          seeded to reproduce a run.

    Args:
        batch_size: Number of sequences in every yielded batch.
        num_batch: Number of batches produced by one pass of ``__iter__``.
        vocab_size: Token ids are sampled from ``[0, vocab_size)``.
        max_seq_length: Upper bound (inclusive) on the per-batch sequence
            length.
        non_linear: When true, labels are squared after the running maximum.
    """

    def __init__(self, batch_size: int, num_batch: int, vocab_size: int,
                 max_seq_length: int, non_linear: bool = True):
        self.batch_size = batch_size
        self.num_batch = num_batch
        self.vocab_size = vocab_size
        self.max_seq_length = max_seq_length

        # Shortest allowed length: the larger of "half the maximum" and
        # "ten below the maximum". It is floored at 1 so that
        # max_seq_length == 1 can no longer produce zero-length sequences,
        # and capped at max_seq_length so the sampling range stays the same
        # as before for every other input.
        shortest = max(1, max_seq_length // 2, max_seq_length - 10)
        self.min_seq_length = min(shortest, max_seq_length)
        self.non_linear = non_linear

        # Real value attached to each token id: id / vocab_size, in [0, 1).
        self.word_values = torch.arange(vocab_size, dtype=torch.float32) / vocab_size

    def __iter__(self):
        """Yield ``num_batch`` dicts with ``input_ids``, ``labels``, ``masks``.

        ``input_ids`` is an integer tensor of shape
        ``(batch_size, seq_length)``; ``labels`` is a float tensor of the
        same shape; ``masks`` is always ``None``. One sequence length is
        sampled per batch, so all rows of a batch share it.
        """
        for _ in range(self.num_batch):
            # One length per batch, inclusive on both ends.
            seq_length = random.randint(self.min_seq_length, self.max_seq_length)

            # Random token ids in [0, vocab_size).
            input_seq = torch.randint(0, self.vocab_size,
                                      (self.batch_size, seq_length))

            # Label: running maximum (along the sequence axis) of the word
            # values looked up for each token.
            seq_word_values = self.word_values[input_seq]
            labels = torch.cummax(seq_word_values, dim=-1).values

            # Optional non-linearity: square the running maximum.
            if self.non_linear:
                labels = labels ** 2

            yield {"input_ids": input_seq, "labels": labels, "masks": None}

    def __len__(self) -> int:
        """Return the number of batches produced per iteration."""
        return self.num_batch

if __name__ == "__main__":
    # Smoke test: dump a few generated batches to stdout.
    loader = RegressionDataloader(
        batch_size=2,
        num_batch=3,
        vocab_size=8,
        max_seq_length=20,
    )
    for sample in loader:
        # Header and tensor are printed on separate lines.
        print("input_ids:", sample["input_ids"], sep="\n")
        print("labels:", sample["labels"], sep="\n")

    # Shape check: a Linear layer applied to a 3-D input acts on the last axis.
    projection = nn.Linear(8, 16)
    features = torch.randn(2, 3, 8)
    projected = projection(features)
    print(projected.shape)