import torch as t


# Token id reserved for the special "RECALL" token. Ordinary tokens are
# sampled from 1..tokens, so id 0 never collides with a regular token.
RECALL = 0


class Dataset():
    """Random sequence generator for a token "recall" task.

    Every generated row has ``sequence_length + 2`` positions:

    * position 0 holds a random token (the value to be recalled),
    * positions ``1 .. sequence_length`` hold either a random token or the
      RECALL token (id 0),
    * the last position is always the RECALL token.

    Regular tokens are drawn uniformly from ``1 .. tokens``.

    Args:
        sequence_length: Number of positions between the first and the last.
        tokens: Number of distinct regular tokens.
        recall_probability: Probability that a middle position is replaced by
            the RECALL token. Either a number (same probability everywhere),
            a ``(start, p)`` tuple (probability ``p`` from middle position
            ``start`` onwards, 0 before it), or a tensor broadcastable to
            ``(batch_size, sequence_length)``.
        reminder_probability: Probability that the *target* at a middle
            position is the recalled value even though the input shows an
            ordinary token. Accepts the same forms as ``recall_probability``.
        device: Torch device on which all tensors are created.

    Raises:
        TypeError: If a probability is not a number, a tuple or a tensor.
    """

    def __init__(self, sequence_length=10, tokens=10, recall_probability=.3, reminder_probability=.0, device="cuda"):
        self.sequence_length = sequence_length
        self.tokens = tokens
        self.device = device
        self.recall_probability = self._as_probability_tensor(
            recall_probability, sequence_length).to(device)
        self.reminder_probability = self._as_probability_tensor(
            reminder_probability, sequence_length).to(device)

    @staticmethod
    def _as_probability_tensor(value, sequence_length):
        """Normalise a probability specification to a float32 tensor."""
        if isinstance(value, t.Tensor):
            return value
        if isinstance(value, tuple):
            start, prob = value[0], value[1]
            probs = t.full((sequence_length,), float(prob), dtype=t.float32)
            # Slicing clamps to the tensor length, so a start index beyond the
            # sequence simply disables the event everywhere; a negative start
            # zeroes nothing.
            probs[:max(start, 0)] = 0
            return probs
        if isinstance(value, (int, float)):
            # Shape (1,) broadcasts over both batch and sequence dimensions.
            return t.tensor([float(value)], dtype=t.float32)
        raise TypeError(
            "probability must be a number, a (start, p) tuple or a tensor, "
            f"got {type(value).__name__}"
        )

    def get_data(self, batch_size):
        """Sample a batch of input sequences and their targets.

        Args:
            batch_size: Number of rows to generate.

        Returns:
            A ``(seq, target)`` pair of integer tensors, both of shape
            ``(batch_size, sequence_length + 2)``. ``target`` equals ``seq``
            except that every RECALL position (and every reminder position)
            holds the token found at position 0 of the same row.
        """
        seq = 1 + t.randint(self.tokens, size=(batch_size, 2 + self.sequence_length), device=self.device)
        # True where the ordinary token is kept, False where it becomes RECALL.
        keep_mask = t.rand(batch_size, self.sequence_length, device=self.device) > self.recall_probability

        # Positions where noisy (unpredictable) reminders occur.
        reminder_mask = t.rand(batch_size, self.sequence_length, device=self.device) < self.reminder_probability

        seq[:, 1:-1] *= keep_mask
        seq[:, -1] = 0  # the final position always asks for a recall

        target = seq.detach().clone()
        recall_value = target[:, 0].unsqueeze(1)

        # Wherever the input shows RECALL (id 0) the target is the first token.
        tail = target[:, 1:]
        target[:, 1:] = t.where(tail == 0, recall_value, tail)

        # Reminder positions also expect the first token, although the input
        # gives no hint of it there.
        middle = target[:, 1:-1]
        target[:, 1:-1] = t.where(reminder_mask, recall_value, middle)

        return seq, target


class SeparateDataset(Dataset):
    """Variant that returns the instruction tokens and the query flags separately.

    ``get_data`` yields three integer tensors of shape
    ``(batch_size, sequence_length + 2)``:

    * ``instruction`` - random tokens in ``1 .. tokens`` at every position,
    * ``query`` - 1 at middle positions where the instruction token itself is
      the target, 0 elsewhere (including the first and last positions),
    * ``target`` - the instruction token where ``query`` is 1; at every other
      position except the first, the token from position 0 of the row. The
      first position of ``target`` is 0.
    """

    def get_data(self, batch_size):
        """Sample ``batch_size`` rows and return ``(instruction, query, target)``."""
        width = self.sequence_length + 2
        instruction = t.randint(self.tokens, size=(batch_size, width), device=self.device) + 1

        # Middle positions whose random draw exceeds the recall probability
        # keep their own token as the target.
        keep = t.rand(batch_size, self.sequence_length, device=self.device) > self.recall_probability
        query = t.zeros_like(instruction)
        query[:, 1:-1] = keep

        target = instruction.detach().clone() * query

        # Every zeroed position after the first is filled with the row's
        # first instruction token.
        first_token = instruction[:, 0].unsqueeze(1)
        rest = target[:, 1:]
        target[:, 1:] = t.where(rest == 0, first_token, rest)

        return instruction, query, target

