import torch
import h5py
import numpy as np

from chip.utils.sinogram import Sinogram
from torch.utils.data import Dataset


class SinogramsDS(torch.utils.data.Dataset):
    """Lazily read items from the ``'images'`` entry of an HDF5 file.

    The file is opened read-only at construction and kept open; each item is
    read from the file when it is indexed (nothing is loaded up front).

    Args:
        h5_file_path: Path of the HDF5 file to open.

    Raises:
        KeyError: If the file has no ``'images'`` entry.
    """

    def __init__(self, h5_file_path):
        super().__init__()
        # Keep the handle so it can be released explicitly via ``close()``.
        self._file = h5py.File(h5_file_path, 'r')
        sinograms = self._file.get('images')
        if sinograms is None:
            # ``get`` returns None for a missing key; fail here with a clear
            # message instead of an opaque ``len(None)`` TypeError later.
            self._file.close()
            raise KeyError(f"'images' not found in HDF5 file {h5_file_path!r}")
        self.sinograms = sinograms
        # NOTE(review): this handle is opened in the constructing process; if
        # the dataset is used with a forking multi-worker DataLoader, h5py's
        # multiprocessing caveats apply — confirm usage or reopen per worker.

    def __getitem__(self, idx):
        """Read entry ``idx`` of ``'images'`` and return it as a tensor."""
        return torch.tensor(self.sinograms[idx])

    def __len__(self):
        """Return the number of entries, as reported by ``len`` on ``'images'``."""
        return len(self.sinograms)

    def close(self):
        """Close the underlying HDF5 file; items cannot be read afterwards."""
        self._file.close()


class SinogramDataset(Dataset):
    """Serve sinogram slices from an HDF5 file, fully loaded into memory.

    The ``'sinogram'`` array is read once at construction. Item ``idx`` is the
    slice ``[:, :, idx]``, so the third axis indexes samples. Every item shares
    the same flattened ``'theta'`` vector as its angles.

    Args:
        file: HDF5 file to open read-only (anything ``h5py.File`` accepts).

    Raises:
        KeyError: If the file lacks a ``'sinogram'`` or ``'theta'`` entry.
    """

    def __init__(self, file):
        super().__init__()
        # Left open so existing users of ``self.data`` keep working; call
        # ``close()`` to release it once the arrays below are loaded.
        self.data = h5py.File(file, 'r')
        sinogram_entry = self.data.get('sinogram')
        theta_entry = self.data.get('theta')
        missing = [
            key
            for key, entry in (('sinogram', sinogram_entry), ('theta', theta_entry))
            if entry is None
        ]
        if missing:
            # ``get`` returns None for missing keys; without this check the
            # failure surfaces later as an unrelated IndexError/TypeError.
            self.data.close()
            raise KeyError(f"{missing} not found in HDF5 file {file!r}")
        self.sinograms = np.array(sinogram_entry)  # load into memory
        self.theta = torch.Tensor(np.array(theta_entry)).flatten()

    def __getitem__(self, idx):
        """Return ``(Sinogram, idx)`` for sample ``idx``.

        The slice ``self.sinograms[:, :, idx]`` is converted with
        ``torch.Tensor`` (default floating dtype) and paired with the shared
        ``self.theta`` angles.
        """
        hr_sinogram = torch.Tensor(self.sinograms[:, :, idx])
        return Sinogram(sinogram=hr_sinogram, angles=self.theta), idx

    def __len__(self):
        """Return the number of samples (size of the third axis)."""
        return self.sinograms.shape[2]

    def close(self):
        """Close the HDF5 handle; items stay available since data is in memory."""
        self.data.close()

