#%%
import torch
from torch.utils.data import Dataset
import h5py
import os
from torchvision.datasets import ImageNet
import torchvision.transforms as transforms


class ImagenetResults(Dataset):
    """Dataset pairing stored visualization maps with ImageNet validation samples.

    The maps are read from the ``vis`` dataset of ``<path>/results.hdf5``.
    The image and target are fetched from the torchvision ``ImageNet``
    validation split (resized to 224x224 and converted to a tensor).

    If the HDF5 file has a ``/path`` dataset, each entry is matched against
    the ImageNet file list to find the corresponding sample; otherwise the
    HDF5 row index is used directly as the ImageNet index.
    """

    def __init__(self, path, imagenet_path):
        """Open the ImageNet validation split and read the result count.

        Args:
            path: Directory containing ``results.hdf5``.
            imagenet_path: Root directory handed to ``torchvision``'s
                ``ImageNet`` dataset.
        """
        super().__init__()

        self.path = os.path.join(path, 'results.hdf5')
        print("DB path: ", self.path)
        self.imagenet_path = imagenet_path
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        self.imagenet = ImageNet(self.imagenet_path, split='val', transform=transform)
        # The HDF5 handle is opened lazily on first __getitem__ call
        # (presumably so each DataLoader worker opens its own handle -- confirm).
        self.data = None
        # Cache of stored path -> ImageNet index, filled on demand so the
        # linear scan over the ImageNet file list happens once per path.
        self._path_to_index = {}

        print('Reading dataset length...')
        with h5py.File(self.path, 'r') as f:
            self.data_length = len(f['/image'])
            # Without explicit paths the rows are matched by position, so a
            # length mismatch means the positional pairing is suspect.
            if '/path' not in f and self.data_length != len(self.imagenet):
                print("Warning: Length of dataset does not match length of imagenet validation set")

    def __len__(self):
        """Return the number of rows in the HDF5 ``/image`` dataset."""
        return self.data_length

    def _imagenet_index(self, raw_path):
        """Return the ImageNet index whose file path contains ``raw_path``.

        Args:
            raw_path: Path entry read from the HDF5 ``/path`` dataset, either
                ``bytes`` (decoded as ASCII) or already a ``str``.

        Raises:
            LookupError: If no ImageNet file path contains the stored path.
                A bare ``next()`` would raise ``StopIteration`` here, which a
                surrounding iterator (e.g. a ``DataLoader``) can mistake for
                the normal end of iteration.
        """
        path = raw_path.decode("ascii") if isinstance(raw_path, bytes) else raw_path

        cached = self._path_to_index.get(path)
        if cached is not None:
            return cached

        index = next(
            (i for i, entry in enumerate(self.imagenet.imgs) if path in entry[0]),
            None,
        )
        if index is None:
            raise LookupError(
                "No ImageNet validation image matches stored path {!r}".format(path)
            )
        self._path_to_index[path] = index
        return index

    def __getitem__(self, item):
        """Return ``(image, vis, target)`` for row ``item``.

        ``vis`` is the stored map as a tensor (noted by the original author
        as shape ``(1, 14, 14)``); ``image`` and ``target`` come from the
        ImageNet validation dataset.
        """
        if self.data is None:
            self.data = h5py.File(self.path, 'r')

        vis = torch.tensor(self.data['vis'][item])
        if '/path' in self.data:
            imagenet_idx = self._imagenet_index(self.data['/path'][item])
        else:
            imagenet_idx = item
        image, target = self.imagenet[imagenet_idx]

        return image, vis, target

    def cleanup(self):
        """Close the HDF5 handle, if open; the dataset stays usable afterwards."""
        if self.data is not None:
            self.data.close()
            # Drop the closed handle so the next __getitem__ reopens the file
            # instead of reading from a closed one.
            self.data = None


if __name__ == '__main__':
    # Smoke test: build the dataset, pull one batch through a DataLoader and
    # print the dataset length.
    import argparse

    # Only needed by the optional heat-map dump that is commented out below.
    from utils import render
    import imageio
    import numpy as np

    parser = argparse.ArgumentParser(
        description='Load one batch from ImagenetResults as a smoke test.',
        allow_abbrev=False)
    parser.add_argument('--results-path', default='../visualizations/fullgrad',
                        help='directory containing results.hdf5')
    parser.add_argument('--imagenet-path', required=True,
                        help='root directory of the ImageNet dataset')
    # parse_known_args tolerates extra arguments injected by interactive
    # front-ends when this file is run as a cell.
    args, _ = parser.parse_known_args()

    # ImagenetResults requires both the results directory and the ImageNet root.
    ds = ImagenetResults(args.results_path, args.imagenet_path)
    try:
        sample_loader = torch.utils.data.DataLoader(
            ds,
            batch_size=5,
            shuffle=False)

        iterator = iter(sample_loader)
        image, vis, target = next(iterator)

        # maps = (render.hm_to_rgb(vis[0].data.cpu().numpy(), scaling=3, sigma=1, cmap='seismic') * 255).astype(np.uint8)

        # imageio.imsave('../delete_hm.jpg', maps)

        print(len(ds))
    finally:
        # Release the HDF5 handle even if loading the batch fails.
        ds.cleanup()


# %%
