import argparse
import random

import pickle
import torch


def generate_data(n, train_size, val_size, test_size):
    """Draw random 0/1 matrices for the train, validation and test splits.

    Args:
        n: Number of bits (columns) in every sample.
        train_size: Number of rows in the ``"train"`` split.
        val_size: Number of rows in the ``"val"`` split.
        test_size: Number of rows in the ``"test"`` split.

    Returns:
        A dict with keys ``"train"``, ``"val"`` and ``"test"``; each value is
        a ``float32`` tensor of shape ``(split_size, n)`` holding 0.0 / 1.0.
    """
    # Splits are drawn in train -> val -> test order so a seeded RNG yields
    # the same tensors on every run.
    split_sizes = (("train", train_size), ("val", val_size), ("test", test_size))
    return {
        split: torch.randint(0, 2, (rows, n), dtype=torch.float32)
        for split, rows in split_sizes
    }


class StandardDataGenerator:
    """
    Iterable dataset of bit sequences labelled with their parity.

    The whole dataset is materialised once at construction time. Each input
    row holds the bits (0/1), then an end-of-sequence marker ``2``, then ``-1``
    padding up to ``max_len + 1``. Labels depend on the mode:

    * ``hint=False``: a single parity bit per sequence.
    * ``hint=True, is_causal=True``: running (prefix) parities followed by the
      final parity at the marker position.
    * ``hint=True, is_causal=False``: a threshold code (``sum(bits)`` ones
      followed by zeros) followed by the final parity at the marker position.

    Iterating yields ``(sequences, labels, lengths)`` batches; a trailing
    batch smaller than ``batch_size`` is dropped.

    Args:
        data: Optional collection of bit sequences (lists or tensor rows). When
            given, it defines the dataset and ``d_size`` is taken from it.
        n_bits: Length of each randomly generated sequence (``data is None``).
        d_size: Number of random sequences to generate (``data is None``).
        batch_size: Number of rows per yielded batch.
        shuffle: Re-permute the dataset at the start of every iteration.
        is_causal: With ``hint=True``, selects prefix-parity hints (True) or
            threshold hints (False).
        hint: Whether labels carry per-position hints.
    """
    def __init__(self,
                 data=None,
                 n_bits=50,
                 d_size=10000,
                 batch_size=32,
                 shuffle=False,
                 is_causal=True,
                 hint=False):
        self.data = data
        self.n_bits = n_bits
        # When explicit data is supplied its row count is the dataset size;
        # otherwise shuffling and batch counting would use a stale ``d_size``.
        self.d_size = len(data) if data is not None else d_size
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.is_causal = is_causal
        self.hint = hint
        # Lets ``next(gen)`` work even before the first ``iter(gen)`` call.
        self.step = 0

        if data is not None:
            seq_lengths = torch.tensor([len(seq) for seq in data], dtype=torch.long)
        else:
            seq_lengths = torch.full((self.d_size,), n_bits, dtype=torch.long)

        self.X, self.y, self.lens = self._create_dataset(seq_lengths, data)

    def __iter__(self):
        """Reset the batch cursor and, if enabled, reshuffle all rows together."""
        self.step = 0
        if self.shuffle:
            perm = torch.randperm(self.X.shape[0])
            self.X, self.y, self.lens = self.X[perm], self.y[perm], self.lens[perm]
        return self

    def _create_dataset(self, lengths, data=None):
        """Build ``(X, y, lengths)`` using the label mode chosen at construction."""
        if self.is_causal and self.hint:
            X, y = self._generate_batch_serial_hint(lengths, data)
        elif self.hint:
            X, y = self._generate_batch_parallel_hint(lengths, data)
        else:
            X, y = self._generate_batch(lengths, data)

        return X, y, lengths

    def __next__(self):
        """Return the next full batch as ``(sequences, labels, lengths)``.

        Raises:
            StopIteration: Once ``d_size // batch_size`` batches were produced.
        """
        if self.step >= self.d_size // self.batch_size:
            raise StopIteration

        start = self.step * self.batch_size
        end = min(start + self.batch_size, self.X.shape[0])
        self.step += 1
        return self.X[start:end], self.y[start:end], self.lens[start:end]

    def _threshold_terms(self, x):
        """Return a 0/1 vector whose first ``sum(x)`` entries are 1."""
        thresh_terms = x.sum()
        range_tensor = torch.arange(1, x.shape[0] + 1, dtype=torch.float32)
        return (range_tensor <= thresh_terms).long()

    def _subparities(self, x):
        """Return the parity of every prefix ``x[:i]`` for a 1-D bit tensor."""
        # A single cumulative sum replaces re-summing each prefix from scratch.
        return (torch.cumsum(x, dim=0) % 2).long()

    def _generate_batch(self, lengths, data=None):
        """Build inputs with one parity label per sequence.

        Returns:
            ``(sequences, labels)`` with shapes ``(B, max_len + 1, 1)`` and
            ``(B, 1, 1)``.
        """
        return self._build_batch(lengths, data, hint_fn=None)

    def _generate_batch_serial_hint(self, lengths, data=None):
        """Build inputs whose labels are prefix parities plus the final parity.

        Returns:
            ``(sequences, labels)``, both of shape ``(B, max_len + 1, 1)``.
        """
        return self._build_batch(lengths, data, hint_fn=self._subparities)

    def _generate_batch_parallel_hint(self, lengths, data=None):
        """Build inputs whose labels are threshold terms plus the final parity.

        Returns:
            ``(sequences, labels)``, both of shape ``(B, max_len + 1, 1)``.
        """
        return self._build_batch(lengths, data, hint_fn=self._threshold_terms)

    def _build_batch(self, lengths, data, hint_fn):
        """Shared builder behind the three ``_generate_batch*`` variants.

        Args:
            lengths: 1-D integer tensor with one sequence length per row.
            data: Optional source sequences; random bits are drawn when None.
            hint_fn: Maps a bit tensor to a same-length hint tensor, or None
                for parity-only labels.
        """
        sizes = lengths.tolist()
        max_len = max(sizes, default=0)
        batch_size = len(sizes)

        # Inputs are padded with -1 to ``max_len + 1`` so that even the longest
        # sequence has room for its end marker.
        sequences = torch.full((batch_size, max_len + 1), -1, dtype=torch.long)
        label_width = 1 if hint_fn is None else max_len + 1
        labels = torch.zeros((batch_size, label_width), dtype=torch.long)

        for i, length in enumerate(sizes):
            if data is None:
                seq = torch.randint(0, 2, (length,), dtype=torch.long)
            else:
                # ``as_tensor`` accepts lists and tensor rows alike, without the
                # copy-construct warning ``torch.tensor`` emits for tensors.
                seq = torch.as_tensor(data[i], dtype=torch.long)
            parity = int(seq.sum()) % 2
            sequences[i, :length] = seq
            sequences[i, length] = 2  # end-of-sequence marker
            if hint_fn is None:
                labels[i, 0] = parity
            else:
                labels[i, :length] = hint_fn(seq)
                labels[i, length] = parity

        return sequences.unsqueeze(-1), labels.unsqueeze(-1)


class PauseDataGenerator(StandardDataGenerator):
    """
    Parity dataset whose inputs carry a chain-of-thought segment.

    Each input row is laid out as ``bits | 2 | cot | -1 padding`` where ``cot``
    is the threshold code of the bits (``sum(bits)`` ones followed by zeros).
    Labels are zero outside the range ``[length, 2 * length]``: the ``cot``
    values occupy positions ``length .. 2 * length - 1`` and the final parity
    sits at position ``2 * length``.

    Args:
        mask_percentage: Probability, per row and per batch, of replacing the
            chain-of-thought inputs with ``-1`` when a batch is served.
        is_train: When True the chain-of-thought tokens are written into the
            inputs; when False those positions are filled with ``-1``. If left
            as None, a class-level ``is_train`` attribute is used when one
            exists, otherwise True.
        **kwargs: Forwarded unchanged to ``StandardDataGenerator``.
    """
    def __init__(self, mask_percentage=0.0, *, is_train=None, **kwargs):
        if is_train is None:
            is_train = getattr(self, "is_train", True)
        # Must be assigned before super().__init__: the base constructor builds
        # the dataset right away via _create_dataset -> _generate_batch_cot,
        # which reads self.is_train.
        self.is_train = is_train
        self.mask_percentage = mask_percentage
        super().__init__(**kwargs)

    def _create_dataset(self, lengths, data=None):
        """Build ``(X, y, lengths)`` using the chain-of-thought layout."""
        X, y = self._generate_batch_cot(lengths, data)

        return X, y, lengths

    def __next__(self):
        """Return the next full batch, randomly masking chain-of-thought inputs.

        Raises:
            StopIteration: Once ``d_size // batch_size`` batches were produced.
        """
        if self.step >= self.d_size // self.batch_size:
            raise StopIteration

        start = self.step * self.batch_size
        end = min(start + self.batch_size, self.X.shape[0])
        # Work on a copy: slicing returns a view, so masking in place would
        # overwrite the stored dataset and masks would pile up across epochs.
        sequences = self.X[start:end].clone()
        labels = self.y[start:end]
        lengths = self.lens[start:end]

        for i, length in enumerate(lengths.tolist()):
            if random.random() < self.mask_percentage:
                sequences[i, length + 1:2 * length + 1] = -1

        self.step += 1
        return sequences, labels, lengths

    def _generate_batch_cot(self, lengths, data=None):
        """Build inputs and labels that include the chain-of-thought segment.

        Args:
            lengths: 1-D integer tensor with one sequence length per row.
            data: Optional source sequences; random bits are drawn when None.

        Returns:
            ``(sequences, labels)``, both of shape ``(B, 2 * max_len + 1, 1)``.
        """
        sizes = lengths.tolist()
        max_len = max(sizes, default=0)
        batch_size = len(sizes)

        # Width 2 * max_len + 1 fits bits, the end marker and the
        # chain-of-thought segment; unused positions stay at -1 (inputs) or
        # 0 (labels).
        sequences = torch.full((batch_size, 2 * max_len + 1), -1, dtype=torch.long)
        labels = torch.zeros((batch_size, 2 * max_len + 1), dtype=torch.long)

        for i, length in enumerate(sizes):
            if data is None:
                seq = torch.randint(0, 2, (length,), dtype=torch.long)
            else:
                # ``as_tensor`` accepts lists and tensor rows alike, without the
                # copy-construct warning ``torch.tensor`` emits for tensors.
                seq = torch.as_tensor(data[i], dtype=torch.long)
            cot = self._threshold_terms(seq)
            parity = int(seq.sum()) % 2
            sequences[i, :length] = seq
            sequences[i, length] = 2  # end-of-sequence marker
            if self.is_train:
                sequences[i, length + 1:2 * length + 1] = cot
            else:
                sequences[i, length + 1:2 * length + 1] = -1
            # Labels sit one position ahead of the matching input token.
            labels[i, length:2 * length] = cot
            labels[i, 2 * length] = parity

        return sequences.unsqueeze(-1), labels.unsqueeze(-1)
