import os
import torch
from chip.datasets.superres_dataset import SuperresolutionDS
from chip.models.iterative_model import TomographicReconstruction
import hdf5plugin
import lovely_tensors as lt
lt.monkey_patch()

import h5py
from tqdm import tqdm

def projection_angles(num_angles):
    """Return ``num_angles`` evenly spaced projection angles in degrees.

    The angles start at 0 and step by ``180 / num_angles``, so they cover the
    half-open interval [0, 180) -- 180 itself is not included.
    """
    return torch.linspace(0, 180 - 180 / num_angles, num_angles)


sinogram_angles_256 = projection_angles(256)


def select_device():
    """Pick the compute device: CUDA if available, then Apple MPS, else CPU.

    Checking MPS availability explicitly avoids selecting 'mps' on machines
    that have neither CUDA nor MPS, where moving tensors to it would fail.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def main(data_dir="data", out_file="data/synthetic_sinograms.h5"):
    """Compute a sinogram for every synthetic image and store them in HDF5.

    Each image yielded by ``SuperresolutionDS`` is written into a
    ``TomographicReconstruction`` model, whose forward pass is evaluated at
    256 evenly spaced angles over [0, 180). The result is stored as row ``i``
    of the ``images`` dataset (shape ``(len(ds), 256, 512)``), chunked one
    row per chunk and compressed with Bitshuffle + LZ4.

    Args:
        data_dir: Root data directory; images are listed from
            ``<data_dir>/imgs_synthetic`` and ``data_dir`` is passed to the
            dataset as ``data_path``.
        out_file: Path of the HDF5 file to create (overwritten if present).
    """
    # Sort for a reproducible row order: os.listdir order is arbitrary and the
    # HDF5 file does not record which source file each row came from.
    files = sorted(os.listdir(os.path.join(data_dir, 'imgs_synthetic')))
    tomogram_ds = SuperresolutionDS(files, data_path=data_dir)

    # One chunk per stored row: (1, num_angles, img_size).
    chunk_size = (1, 256, 512)
    num_angles, img_size = chunk_size[1], chunk_size[2]

    device = select_device()
    # Loop-invariant: compute the projection angles once, already on device.
    angles = projection_angles(num_angles).to(device)

    with torch.no_grad(), h5py.File(out_file, 'w') as h5f:
        dataset = h5f.create_dataset(
            name='images',
            shape=(len(tomogram_ds), num_angles, img_size),
            chunks=chunk_size,
            **hdf5plugin.Bitshuffle(nelems=0, cname='lz4'),
        )
        img_model = TomographicReconstruction(
            torch.zeros(img_size, img_size, device=device), False
        ).to(device)

        # Wrap the dataset itself (not enumerate) so tqdm can use len() for a total.
        for i, (_, img, _file) in enumerate(tqdm(tomogram_ds)):
            # Reset the model image, then load the current sample into it.
            img_model.img *= 0
            img_model.img[0] += img.to(device)
            sinogram = img_model.forward(angles)
            dataset[i] = sinogram.cpu().numpy()


if __name__ == '__main__':
    main()

