import os
from tqdm import tqdm
from spaghettini import quick_register
import torch
import random

import numpy as np
from scipy.stats import bernoulli
from scipy.signal import convolve2d
from torch.utils.data import Dataset

import matplotlib.pyplot as plt

WHITE_PLUS = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
BLACK_PLUS = np.array([[1, 0, 1], [0, 0, 0], [1, 0, 1]])


@quick_register
class BinaryErasureChannelIO(Dataset):
    """Synthetic dataset of i.i.d. random bits paired with their one-hot encoding.

    Each item is ``(one_hot, bit)``:

    * ``bit`` is a NumPy integer scalar that is 0 with probability
      ``zero_prob`` and 1 otherwise.
    * ``one_hot`` is a float32 array of shape ``(2,)`` with a 1 at index ``bit``.

    A fresh sample is drawn on every access; ``idx`` does not influence the
    result, so ``len()`` only defines how many items make up one "epoch".

    Args:
        zero_prob: Probability that a sampled bit is 0. Must lie in [0, 1].
        epoch_len: Number of items reported by ``len()``.

    Raises:
        ValueError: If ``zero_prob`` is outside [0, 1] (or is NaN).
    """

    def __init__(self, zero_prob=0.5, epoch_len=50000):
        super().__init__()
        # Validate eagerly; otherwise numpy would only complain on first access.
        if not 0.0 <= zero_prob <= 1.0:
            raise ValueError(f"zero_prob must be in [0, 1], got {zero_prob!r}.")
        self.zero_prob = zero_prob
        self.epoch_len = epoch_len

    def __len__(self):
        """Return the nominal number of items per epoch."""
        return self.epoch_len

    def __getitem__(self, idx):
        """Draw one random bit and return ``(one_hot, bit)``; ``idx`` is ignored."""
        # ``binomial(n=1, p)`` yields 1 with probability ``p``; pass the
        # probability of a 1 so that ``zero_prob`` really is P(bit == 0).
        # NOTE(review): this uses NumPy's global RNG; with multi-worker
        # DataLoaders on older PyTorch versions, workers may share a stream
        # unless reseeded via ``worker_init_fn`` — confirm for the version used.
        bit = np.random.binomial(n=1, p=1.0 - self.zero_prob, size=(1,))
        # One-hot encode by selecting row ``bit`` of the 2x2 identity -> (1, 2).
        oh_bit = np.eye(2, dtype=np.float32)[bit]
        return oh_bit[0], bit[0]


if __name__ == "__main__":
    # Manual self-tests for this module. Run from the repository root:
    #     python -m src.data.datasets.binary_erasure
    test_num = 0

    if test_num == 0:
        # Test if samples are generated properly.
        num_samples = 1000
        bec = BinaryErasureChannelIO(zero_prob=0.5)
        oh_bits = np.zeros((num_samples, 2))
        bits = np.zeros(num_samples, dtype=np.int64)
        for i in range(num_samples):
            # Each item is a (one_hot, bit) pair: unpack it rather than
            # assigning the whole tuple into a length-2 row, which numpy rejects.
            oh_bits[i, :], bits[i] = bec[i]

        # Every row must be a valid one-hot vector that agrees with its label.
        assert np.all(oh_bits.sum(axis=1) == 1), "Rows are not one-hot."
        assert np.all(oh_bits.argmax(axis=1) == bits), "One-hot/label mismatch."
        print(
            f"Empirical fraction of zeros: {np.mean(bits == 0):.3f} "
            f"(zero_prob={bec.zero_prob})"
        )
