import unittest, torch

from random import seed, sample

from torchvision import datasets
from torchvision.transforms import ToTensor

from tql import Database, Query, Table
from tql.sqrl import SQRL

import time

def _load_mnist(train):
    """Build the MNIST dataset for one split.

    Every split is stored under ``data`` with ``download=True`` and has its
    images converted by a freshly constructed ``ToTensor`` transform.

    Args:
        train: ``True`` for the training split, ``False`` for the test split.

    Returns:
        The ``datasets.MNIST`` instance for that split.
    """
    return datasets.MNIST(
        root='data',
        train=train,
        transform=ToTensor(),
        download=True,
    )


# Load the held-out split first, then the training split.
test_data = _load_mnist(train=False)
train_data = _load_mnist(train=True)

# Make both splits addressable by name inside the "mnist" database;
# queries select one of them via ``base=...``.
db = Database("mnist")
db.register_dataset(test_data, "test")
db.register_dataset(train_data, "train")

class row_sum_batch(torch.nn.Module):
    """Reduce each image in a batch to the sum of one pixel row.

    The only state is the row index passed to the constructor.
    """

    def __init__(self, row):
        """Store the row index that ``forward`` will sum over.

        Args:
            row: Index applied to axis 2 of the input tensor.
        """
        super().__init__()
        self.row = row

    def forward(self, image_tensor, label):
        """Sum row ``self.row`` of channel 0 for every sample.

        Args:
            image_tensor: Batched images; presumably laid out as
                (batch, channel, height, width) — TODO confirm with caller.
            label: Accepted so the module can be invoked as
                ``module(images, labels)``; it does not affect the result.

        Returns:
            ``image_tensor[:, 0, self.row]`` reduced over its second axis,
            i.e. a 1-D tensor with one entry per sample for a 4-D input.
        """
        first_channel_row = image_tensor[:, 0, self.row]
        return torch.sum(first_channel_row, dim=1)
    
# Query 'satisfies_rule_sum10' over the split registered as "train".
# Stage 1 (first project): pair every image and label with the sum of pixel
#   row 10 of the image's channel 0, computed by row_sum_batch(10); the label
#   is passed to row_sum_batch but does not influence that sum.
# Stage 2 (second project): append a boolean flag that is True exactly when
#   label == 5 AND the row-10 sum lies in the closed interval
#   [0.7984313845634461, 15.972551059722871].
# NOTE(review): the origin of those two bounds is not visible here (they look
#   like fitted thresholds) — consider naming them once it is confirmed that
#   tql does not inspect the lambdas' source or parameter names.
# NOTE(review): row_sum_batch indexes image_tensor[:, 0, row] and then sums
#   dim=1, so this assumes tql hands `project` callables batched 4-D image
#   tensors — TODO confirm against tql's Query.project.
q = Query('satisfies_rule_sum10', base='train')\
    .project(lambda imgs, labels: zip(imgs, labels, row_sum_batch(10)(imgs, labels))) \
    .project(lambda imgs, label, sum_row_10: zip(imgs, label, sum_row_10, torch.logical_and(label == 5.0, torch.logical_and(0.7984313845634461 <= sum_row_10, sum_row_10 <= 15.972551059722871))))

# Execute the query against `db`.
res = q(db)

# NOTE(review): the return value of sample() is discarded; if the intent is to
#   inspect the resulting rows it presumably needs to be printed or asserted
#   on — confirm tql's `sample` semantics.
res.sample()
