# Adapted from https://github.com/HazyResearch/state-spaces
"""Long Range Arena datasets"""
from pathlib import Path

import torch
import torchvision
from einops.layers.torch import Rearrange, Reduce
from PIL import Image  # Only used for Pathfinder
from torch.utils.data import DataLoader


class PathFinderDataset(torch.utils.data.Dataset):
    """Pathfinder image-classification dataset read from a directory on disk.

    Sample locations and labels are parsed from the whitespace-separated text
    files found under ``<data_dir>/<difficulty level>/metadata/*.npy``. Images
    themselves are only opened when an item is requested.
    """

    # There's an empty file in the dataset
    blacklist = {"pathfinder32/curv_baseline/imgs/0/sample_172.png"}

    def __init__(
        self,
        data_dir,
        transform=None,
        *,
        pool=1,
        tokenize=True,
        sequential=True,
        diff_levels=("curv_contour_length_14",),
    ):
        """
        Args:
            data_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample. When omitted, ``default_transforms()`` is used.
            pool (int): Side length of the mean-pooling window used by the
                default transform; values <= 1 disable pooling.
            tokenize (bool): If true, the default transform emits integer
                pixel values; otherwise it normalizes with mean/std 0.5.
            sequential (bool): If true, the default transform flattens the
                image into a sequence; otherwise it keeps an ``h w 1`` layout.
            diff_levels (str or iterable of str): Sub-directories of
                ``data_dir`` to load. Known levels in the upstream dataset are
                'curv_baseline', 'curv_contour_length_9' and
                'curv_contour_length_14'.

        Raises:
            AssertionError: If ``data_dir`` is not a directory or a level has
                no metadata files.
        """
        # These must be set before default_transforms() is called below.
        self.pool = pool
        self.tokenize = tokenize
        self.sequential = sequential

        self.data_dir = Path(data_dir).expanduser()
        assert self.data_dir.is_dir(), f"data_dir {str(self.data_dir)} does not exist"
        self.transform = transform if transform is not None else self.default_transforms()

        # A bare string would otherwise be iterated character by character.
        if isinstance(diff_levels, str):
            diff_levels = (diff_levels,)

        samples = []
        for diff_level in diff_levels:
            # Metadata files are named by integer index; sort numerically so
            # that "10" comes after "2".
            path_list = sorted(
                (self.data_dir / diff_level / "metadata").glob("*.npy"),
                key=lambda path: int(path.stem),
            )
            assert path_list, "No metadata found"
            for metadata_file in path_list:
                # Despite the .npy suffix these files are read as plain text.
                with open(metadata_file, "r") as f:
                    for line in f.read().splitlines():
                        fields = line.split()
                        if not fields:
                            # Blank / whitespace-only line: nothing to parse.
                            continue
                        # fields[0] is the image sub-directory, fields[1] the
                        # file name and fields[3] the label.
                        image_path = Path(diff_level) / fields[0] / fields[1]
                        # Blacklist entries use forward slashes, so compare in
                        # POSIX form regardless of the host OS.
                        blacklist_key = (Path(self.data_dir.stem) / image_path).as_posix()
                        if blacklist_key not in self.blacklist:
                            samples.append((image_path, int(fields[3])))
        self.samples = samples

    def default_transforms(self):
        """Build the transform pipeline selected by pool/tokenize/sequential.

        Returns:
            torchvision.transforms.Compose: ``ToTensor`` followed by optional
            mean pooling, then either integer tokenization or normalization,
            then a layout rearrangement.
        """
        transform_list = [torchvision.transforms.ToTensor()]
        if self.pool > 1:
            # Average each (pool x pool) patch down to a single pixel.
            transform_list.append(
                Reduce(
                    "1 (h h2) (w w2) -> 1 h w",
                    "mean",
                    h2=self.pool,
                    w2=self.pool,
                )
            )
        if self.tokenize:
            # Scale by 255 and truncate to integers so pixels act as token ids.
            transform_list.append(torchvision.transforms.Lambda(lambda x: (x * 255).long()))
        else:
            transform_list.append(torchvision.transforms.Normalize(mean=0.5, std=0.5))
        if self.sequential:
            # If tokenize, it makes more sense to get rid of the channel dimension
            transform_list.append(Rearrange("1 h w -> (h w)") if self.tokenize else Rearrange("1 h w -> (h w) 1"))
        else:
            transform_list.append(Rearrange("1 h w -> h w 1"))
        return torchvision.transforms.Compose(transform_list)

    def __len__(self):
        """Return the number of (image path, label) pairs that were loaded."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, convert and transform the sample at position ``idx``.

        Returns:
            dict: ``{"input_ids": <transformed image>, "labels": <int label>}``.
        """
        path, target = self.samples[idx]
        # https://github.com/pytorch/vision/blob/9b29f3f22783112406d9c1a6db47165a297c3942/torchvision/datasets/folder.py#L247
        # convert() is called while the file is still open, so the result is
        # usable after the handle is closed.
        with open(self.data_dir / path, "rb") as f:
            sample = Image.open(f).convert("L")  # Open in grayscale
        if self.transform is not None:
            sample = self.transform(sample)
        return {"input_ids": sample, "labels": target}


def get_pathx_dataloader(
    datasets_path=None,
    validation_split_percentage=10,
    test_split_percentage=10,
    seed=1234,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
):
    """Create train / validation / test dataloaders over a PathFinderDataset.

    The dataset at ``datasets_path`` is split randomly (seeded by ``seed``)
    into three parts. Validation and test sizes are the given percentages of
    the dataset length, truncated to integers; the training split receives
    the remainder.

    Returns:
        tuple: ``(train_dataloader, eval_dataloader, test_dataloader, None)``.
        Only the training loader shuffles. The fourth element is always
        ``None``.
    """
    full_dataset = PathFinderDataset(datasets_path)
    total = len(full_dataset)

    n_val = int(validation_split_percentage / 100 * total)
    n_test = int(test_split_percentage / 100 * total)
    # Whatever is left after truncation goes to training.
    split_sizes = [total - n_val - n_test, n_val, n_test]

    # A dedicated generator keeps the split reproducible for a given seed.
    rng = torch.Generator().manual_seed(seed)
    train_part, val_part, test_part = torch.utils.data.random_split(
        full_dataset,
        split_sizes,
        generator=rng,
    )

    train_dataloader = DataLoader(
        train_part,
        shuffle=True,
        batch_size=per_device_train_batch_size,
        pin_memory=True,
    )

    # Validation and test loaders share the same settings.
    held_out_kwargs = {"batch_size": per_device_eval_batch_size, "pin_memory": True}
    eval_dataloader = DataLoader(val_part, **held_out_kwargs)
    test_dataloader = DataLoader(test_part, **held_out_kwargs)

    return train_dataloader, eval_dataloader, test_dataloader, None
