import hydra
import torch
from einops import rearrange
from torch.utils.data import Dataset
from torchvision.transforms import v2

from experiments.neural_datasets.inr_utils import make_image_grid


class ImageFitting(Dataset):
    """Per-pixel dataset for fitting implicit neural representations to images.

    Every pixel of every image in an underlying (Hydra-instantiated) image
    dataset is exposed as one sample. ``__getitem__`` returns
    ``(coord, pixel_value, image_index)``.

    Args:
        train_set: Hydra config of the dataset used when ``train`` is True.
        test_set: Hydra config of the dataset used when ``train`` is False.
        img_size: Two-element image size; its product is the number of pixels
            per image. It sizes the coordinate grid and the fine-tune crop.
        train: Select ``train_set`` (True) or ``test_set`` (False).
        debug: If True, ``__len__`` only covers the pixels of the first
            ``_num_debug`` images (or fewer if the dataset is smaller).
        use_test_transform: If True, images are only converted to scaled
            float32 tensors. If False, ``finetune_transform`` (random flip +
            resized crop) is used instead.
    """

    def __init__(
        self,
        train_set,
        test_set,
        img_size,
        train=True,
        debug=False,
        use_test_transform=True,
    ):
        super().__init__()
        self.im_size = list(img_size)
        self.num_pixels = self.im_size[0] * self.im_size[1]

        # Deterministic conversion only; it does NOT resize, so source images
        # are expected to already match ``img_size`` (checked in
        # transform_dataset).
        transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

        # Augmenting transform. The images are materialized once per call to
        # transform_dataset(), so the random augmentation stays fixed until
        # that method is called again.
        self.finetune_transform = v2.Compose(
            [
                v2.RandomHorizontalFlip(),
                v2.RandomResizedCrop(self.im_size, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
            ]
        )

        self.use_test_transform = use_test_transform

        self._dset = hydra.utils.instantiate(
            train_set if train else test_set,
            transform=transform if use_test_transform else self.finetune_transform,
        )
        # Datasets expose labels under different attribute names
        # (``targets`` or ``labels``).
        try:
            dset_labels = self._dset.targets
        except AttributeError:
            dset_labels = self._dset.labels
        self.num_images = len(dset_labels)

        self.transform_dataset()
        # Save labels for later; as_tensor accepts lists, numpy arrays and
        # tensors alike.
        self.labels = torch.as_tensor(dset_labels, dtype=torch.long)

        # Coordinate grid of one image, tiled once per image, plus the image
        # index of every pixel row.
        # NOTE(review): assumes make_image_grid enumerates pixels in the same
        # (h, w) row-major order as the rearrange in transform_dataset.
        self.coords = make_image_grid(self.im_size, 1)[0].repeat(self.num_images, 1)
        self.indices = torch.arange(self.num_images).repeat_interleave(self.num_pixels)

        self.debug = debug
        self._num_debug = 128

    def transform_dataset(self, batch_size=64, num_workers=16):
        """Load all images through the dataset transform and flatten to pixels.

        Stores ``self.dset`` with shape ``(num_images * num_pixels, C)``.
        This can be called again (e.g. every epoch) to re-sample the random
        fine-tune augmentation.

        Args:
            batch_size: Loader batch size used while materializing images.
            num_workers: Number of loader worker processes.

        Raises:
            ValueError: If the number of flattened pixels does not match
                ``num_images * num_pixels``, for example when the images are
                not of size ``img_size``.
        """
        # pin_memory is not used: the batches are concatenated into a new
        # CPU tensor, so pinning them would only add overhead.
        loader = torch.utils.data.DataLoader(
            self._dset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
        )
        images = torch.cat([batch[0] for batch in loader], 0)
        pixels = rearrange(images, "b c h w -> (b h w) c")

        expected = self.num_images * self.num_pixels
        if pixels.shape[0] != expected:
            raise ValueError(
                f"Expected {expected} pixels ({self.num_images} images of size "
                f"{self.im_size}), got {pixels.shape[0]}; check that the "
                f"dataset images match img_size."
            )
        self.dset = pixels

    def __len__(self):
        """Number of pixel samples (restricted to a few images in debug mode)."""
        if self.debug:
            # Clamp so debug mode never reports more pixels than exist.
            return min(self._num_debug, self.num_images) * self.num_pixels
        return self.dset.shape[0]

    def __getitem__(self, idx):
        """Return ``(coord, pixel_value, image_index)`` for flat pixel ``idx``."""
        image_index = idx // self.num_pixels
        return self.coords[idx], self.dset[idx], image_index
