from functools import partial
import os
import io
from pathlib import Path

import logging
import pickle


import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from PIL import Image  # Only used for Pathfinder
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange
import torchtext
from omegaconf import DictConfig
from datasets import load_dataset, DatasetDict, Value
from pytorch_lightning import LightningDataModule
from sklearn.utils import shuffle
from src.utils import permutations, is_list
from .cifar_augmentations import Cutout, RandomErasing
from sklearn.model_selection import train_test_split
import src.tasks.defaults as tasks

# Default data root: the DATA_PATH environment variable when it is set,
# otherwise the "data" directory under this file's third parent (hippo/data).
default_data_path = os.getenv('DATA_PATH')
if default_data_path is not None:
    default_data_path = Path(default_data_path).absolute()
else:
    default_data_path = Path(__file__).parent.parent.parent.absolute() / "data"


class SequenceDataset:
    """Base class for the sequence datasets in this module.

    Subclasses declare a unique ``_name_`` (under which they are stored in
    ``registry``), optionally an ``init_defaults`` dict, and implement
    ``setup`` to populate ``self.train``, ``self.val`` and ``self.test``.
    """

    registry = {}
    _name_ = NotImplementedError("Dataset must have shorthand name")

    # Subclasses do not define __init__ (it is handled by this class).
    # Instead they list default arguments here; each entry is set as an
    # instance attribute, overridden by any matching keyword argument.
    init_defaults = {}

    # https://www.python.org/dev/peps/pep-0487/#subclass-registration
    def __init_subclass__(cls, **kwargs):
        """Register every subclass in ``registry`` under its ``_name_``."""
        super().__init_subclass__(**kwargs)
        cls.registry[cls._name_] = cls

    def __init__(self, _name_, data_dir=None, **dataset_cfg):
        """Store the configuration as attributes and run the ``init`` hook.

        Args:
            _name_: must equal the class-level ``_name_``.
            data_dir: optional dataset directory; stored as an absolute
                ``Path``, or ``None`` when not given.
            **dataset_cfg: overrides for the entries of ``init_defaults``
                (extra keys are set as attributes as well).
        """
        assert _name_ == self._name_
        self.data_dir = Path(data_dir).absolute() if data_dir is not None else None

        # Merge into a fresh dict. Updating ``init_defaults`` in place would
        # mutate the class-level dict, so one instance's overrides would
        # silently become the defaults of every later instance.
        init_args = {**self.init_defaults, **dataset_cfg}
        for k, v in init_args.items():
            setattr(self, k, v)

        # Construct mandatory shape parameters
        self.init()  # TODO replace this with @property for each important parameter

        # train, val, test datasets must be set by class instantiation
        self.train = None
        self.val = None
        self.test = None

    def init(self):
        """ Hook for other __init__ actions. Currently used mainly for setting data_dir """
        pass

    def setup(self):
        """ This method should set self.train, self.val, and self.test """
        raise NotImplementedError

    def split_train_val(self, val_split):
        """Randomly split ``self.train`` into ``self.train`` and ``self.val``.

        Args:
            val_split: fraction of the current training set moved to the
                validation set.

        The split is seeded with ``self.seed`` when that attribute exists,
        and with 42 otherwise.
        """
        train_len = int(len(self.train) * (1.0 - val_split))
        self.train, self.val = torch.utils.data.random_split(
            self.train,
            (train_len, len(self.train) - train_len),
            generator=torch.Generator().manual_seed(
                getattr(self, 'seed', 42)
            ),  # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us
        )

    @staticmethod
    def collate_fn(resolution, batch):
        """Stack a list of (x, y) pairs, keeping every ``resolution``-th step of x.

        Args:
            resolution: stride applied along dimension 1 of the stacked inputs.
            batch: list of (x, y) pairs.

        Returns:
            Tuple ``(x, y)`` with the stacked, strided inputs and the labels
            as a ``LongTensor``.
        """
        x, y = zip(*batch)
        x = torch.stack(x, dim=0)[:, ::resolution]
        y = torch.LongTensor(y)
        return x, y

    def train_dataloader(self, train_resolution, eval_resolutions, **kwargs):
        """Return the shuffled training ``DataLoader`` (``None`` if no train set).

        ``train_resolution`` may be ``None`` (treated as 1), a scalar or a
        one-element list. ``eval_resolutions`` is accepted but unused here.
        Remaining keyword arguments are forwarded to ``DataLoader``.
        """
        if train_resolution is None:
            train_resolution = [1]
        if not is_list(train_resolution):
            train_resolution = [train_resolution]
        assert len(train_resolution) == 1, "Only one train resolution supported for now"

        loaders = self._dataloader(
            self.train,
            resolutions=train_resolution,
            shuffle=True,
            **kwargs,
        )
        # Guard against a missing training set instead of indexing ``None``.
        return loaders[0] if loaders is not None else None

    def val_dataloader(self, **kwargs):
        """Return the evaluation dataloaders for ``self.val``."""
        return self._eval_dataloader(self.val, **kwargs)

    def test_dataloader(self, **kwargs):
        """Return the evaluation dataloaders for ``self.test``."""
        return self._eval_dataloader(self.test, **kwargs)

    def _eval_dataloader(self, dataset, train_resolution, eval_resolutions, **kwargs):
        """Build one unshuffled ``DataLoader`` per evaluation resolution.

        Returns:
            ``None`` when ``dataset`` is ``None``; otherwise a dict mapping
            ``str(resolution)`` (or ``None`` for resolutions <= 1) to the
            loader. ``train_resolution`` is accepted but unused here.
        """
        if eval_resolutions is None:
            eval_resolutions = [1]
        if not is_list(eval_resolutions):
            eval_resolutions = [eval_resolutions]

        dataloaders = self._dataloader(
            dataset,
            resolutions=eval_resolutions,
            shuffle=False,
            **kwargs,
        )
        if dataloaders is None:
            return None

        return {
            str(res) if res > 1 else None: dl
            for res, dl in zip(eval_resolutions, dataloaders)
        }

    def _dataloader(self, dataset, resolutions, **loader_args):
        """Create one ``DataLoader`` per resolution, or ``None`` if there is no dataset."""
        if dataset is None:
            return None

        return [
            torch.utils.data.DataLoader(
                dataset,
                # Bind the resolution as the first argument of the collate function.
                collate_fn=partial(self.collate_fn, resolution) if self.collate_fn is not None else None,
                **loader_args,
            )
            for resolution in resolutions
        ]

    def __str__(self):
        return self._name_  # if hasattr(self, "_name_") else self.__name__


class MNIST(SequenceDataset):
    """MNIST images flattened to sequences of length ``L`` with one feature.

    With ``permute`` enabled the sequence is reordered by the permutation
    returned from ``permutations.bitreversal_permutation``.
    """

    _name_ = "mnist"
    # task = MulticlassClassification()
    default_task = tasks.multiclass_classification
    d_input = 1
    d_output = 10
    l_output = 0
    L = 784

    init_defaults = {
        'permute': True,
        'val_split': 0.1,
        'seed': 42, # For train/val split
    }

    def setup(self):
        """Download MNIST if needed, build the transforms and set train/val/test."""
        if not self.data_dir:
            self.data_dir = default_data_path / self._name_

        def flatten(img):
            # (d_input, L) -> (L, d_input)
            return img.view(self.d_input, self.L).t()

        steps = [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Lambda(flatten),
        ]
        if self.permute:
            # below is another permutation that other works have used
            # permute = np.random.RandomState(92916)
            # permutation = torch.LongTensor(permute.permutation(784))
            order = permutations.bitreversal_permutation(self.L)
            steps.append(torchvision.transforms.Lambda(lambda seq: seq[order]))
        # TODO does MNIST need normalization?
        # torchvision.transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
        pipeline = torchvision.transforms.Compose(steps)

        self.train = torchvision.datasets.MNIST(
            self.data_dir, train=True, download=True, transform=pipeline
        )
        self.test = torchvision.datasets.MNIST(
            self.data_dir, train=False, transform=pipeline
        )
        # The validation set is carved out of the training set.
        self.split_train_val(self.val_split)

    def __str__(self):
        prefix = 'p' if self.permute else 's'
        return f"{prefix}{self._name_}"


class CIFAR10(SequenceDataset):
    """CIFAR-10 images flattened to sequences of 1024 pixels.

    Options (see ``init_defaults``): ``grayscale`` converts to one channel,
    ``tokenize`` (grayscale only) turns pixels into integer tokens,
    ``permute`` selects a pixel ordering ("br", "snake", "hilbert" or
    "transpose"), and ``augment`` / ``cutout`` enable training-time
    augmentation. ``random_erasing`` is accepted but currently has no effect.
    """

    _name_ = "cifar"
    d_output = 10
    l_output = 0
    default_task = tasks.multiclass_classification

    init_defaults = { # TODO handle embedding and grayscale arguments
        'permute': None,
        'grayscale': False,
        # 'd_embed': 0,
        'tokenize': False,
        'augment': False,
        'cutout': False,
        'random_erasing': False,
        'val_split': 0.1,
        'seed': 42, # For validation split
    }

    @property
    def d_input(self):
        """Input feature size: 256 (tokenized grayscale), 1 (grayscale) or 3 (RGB)."""
        if self.grayscale:
            if self.tokenize:
                return 256
            else:
                return 1
        else:
            assert not self.tokenize
            return 3

    def setup(self):
        """Build the train/eval transform pipelines and set train/val/test."""
        # NOTE: the reshaping transforms are kept in ``reshape_transforms``.
        # Naming this list ``permutations`` would shadow the imported
        # ``permutations`` module inside this method and break the
        # ``permutations.*_permutation`` calls below.
        if self.grayscale:
            preprocessors = [
                torchvision.transforms.Grayscale(),
                torchvision.transforms.ToTensor(),
            ]
            reshape_transforms = [
                torchvision.transforms.Lambda(
                    lambda x: x.view(1, 1024).t()
                )  # (L, d_input)
            ]

            if self.tokenize:
                preprocessors.append(
                    torchvision.transforms.Lambda(lambda x: (x * 255).long())
                )
                reshape_transforms.append(Rearrange("l 1 -> l"))
            else:
                preprocessors.append(
                    torchvision.transforms.Normalize(
                        mean=122.6 / 255.0, std=61.0 / 255.0
                    )
                )
        else:
            preprocessors = [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)
                ),
            ]
            reshape_transforms = [
                torchvision.transforms.Lambda(
                    Rearrange("z h w -> (h w) z", z=3, h=32, w=32)
                )  # (L, d_input)
            ]

        # Permutations and reshaping
        if self.permute == "br":
            permutation = permutations.bitreversal_permutation(1024)
            print("bit reversal", permutation)
            reshape_transforms.append(
                torchvision.transforms.Lambda(lambda x: x[permutation])
            )
        elif self.permute == "snake":
            permutation = permutations.snake_permutation(32, 32)
            print("snake", permutation)
            reshape_transforms.append(
                torchvision.transforms.Lambda(lambda x: x[permutation])
            )
        elif self.permute == "hilbert":
            permutation = permutations.hilbert_permutation(32)
            print("hilbert", permutation)
            reshape_transforms.append(
                torchvision.transforms.Lambda(lambda x: x[permutation])
            )
        elif self.permute == "transpose":
            permutation = permutations.transpose_permutation(32, 32)
            # Concatenate the original and the permuted sequence along the
            # last (feature) dimension.
            reshape_transforms.append(
                torchvision.transforms.Lambda(
                    lambda x: torch.cat([x, x[permutation]], dim=-1)
                )
            )

        # Augmentation
        if self.augment:
            augmentations = [
                torchvision.transforms.RandomCrop(32, padding=4, padding_mode='symmetric'),
                torchvision.transforms.RandomHorizontalFlip(),
            ]

            post_augmentations = []
            if self.cutout:
                post_augmentations.append(Cutout(1, 16))
            if self.random_erasing:
                # Not wired up yet: RandomErasing is intentionally not added.
                # augmentations.append(RandomErasing())
                pass
        else:
            augmentations, post_augmentations = [], []

        # Keep the pipelines in local variables; they must not be stored as
        # attributes on the ``torchvision`` module.
        transforms_train = (
            augmentations + preprocessors + post_augmentations + reshape_transforms
        )
        transforms_eval = preprocessors + reshape_transforms

        transform_train = torchvision.transforms.Compose(transforms_train)
        transform_eval = torchvision.transforms.Compose(transforms_eval)
        self.train = torchvision.datasets.CIFAR10(
            f"{default_data_path}/{self._name_}",
            train=True,
            download=True,
            transform=transform_train,
        )
        self.test = torchvision.datasets.CIFAR10(
            f"{default_data_path}/{self._name_}", train=False, transform=transform_eval
        )
        # The validation set is carved out of the training set.
        self.split_train_val(self.val_split)

    def __str__(self):
        return f"{'p' if self.permute else 's'}{self._name_}"

class CIFAR10Generation(SequenceDataset):
    """CIFAR-10 as a generation task over the flattened pixel/channel sequence.

    TODO there should be a way to combine this with main CIFAR class. the issue is making sure the torchvision.transforms are applied to output in the same way.
    """

    _name_ = "cifargen"

    init_defaults = {
        'transpose': False,
        'tokenize': True,
        'mixture': 0,
        'val_split': 0.02,
        'seed': 42,
    }

    @property
    def d_input(self):
        """1 for un-tokenized input, ``None`` when inputs are tokens."""
        return None if self.tokenize else 1

    @property
    def d_output(self):
        """256 outputs without a mixture, ``3 * mixture`` otherwise."""
        if self.mixture == 0:
            return 256
        return 3 * self.mixture

    @property
    def n_tokens(self):
        """Vocabulary size when tokenizing (``3*256+1``), else ``None``."""
        return 3*256+1 if self.tokenize else None

    @property
    def n_classes(self):
        return 10

    @property
    def default_task(self): # TODO
        """Task config; encoder and loss depend on ``tokenize`` and ``mixture``."""
        return {
            '_target_': 'tasks.tasks.GeneralTask',
            'encoder': {'_name_': 'embedding' if self.tokenize else 'linear'},
            'decoder': {'_name_': 'sequence'},
            'loss': 'cross_entropy' if self.mixture == 0 else 'mixture',
            'metrics': ['bpb'],
        }

    @property
    def permute(self):
        """Callable flattening ``(..., h, w, c)`` into ``(..., L, 1)``."""
        if self.transpose:
            # R R ... G G ... B B ...
            pattern = "... h w c -> ... (c h w) 1"
        else:
            # R G B R G B ...
            pattern = "... h w c -> ... (h w c) 1"
        return lambda x: rearrange(x, pattern)

    @property
    def transforms0(self):
        """ Transforms applied before permutation """
        if not self.tokenize:
            # return torchvision.transforms.Normalize(mean=127.5, std=127.5)
            return torchvision.transforms.Lambda(lambda x: (x.float() - 127.5) / 127.5)
        # Shift by one and offset each of the 3 channels by a multiple of 256.
        return torchvision.transforms.Lambda(lambda x: x + 1 + torch.arange(3) * 256)

    @property
    def transforms1(self):
        """ Transforms applied after permutation """
        if not self.tokenize:
            return torchvision.transforms.Compose([])
        return torchvision.transforms.Lambda(lambda x: x.squeeze(-1))

    def setup(self):
        """Load CIFAR-10 as integer (h, w, c) images and install the collate function."""
        to_int_image = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            Rearrange("c h w -> h w c"),
            torchvision.transforms.Lambda(lambda x: (x * 255).long()), # Convert back to ints
        ])
        root = f"{default_data_path}/cifar"

        self.train = torchvision.datasets.CIFAR10(
            root, train=True, download=True, transform=to_int_image
        )
        self.test = torchvision.datasets.CIFAR10(
            root, train=False, transform=to_int_image
        )
        self.split_train_val(self.val_split)

        def collate_batch(resolution, batch):
            """ batch: list of (x, y) pairs """
            images, classes = zip(*batch)
            stacked = torch.stack(images, dim=0)
            z = torch.LongTensor(classes)
            # Targets are the raw integer pixels in sequence order.
            y = self.permute(stacked)
            x = self.permute(self.transforms0(stacked))
            # Drop the last step and prepend a zero step along dimension 1.
            x = F.pad(x[:, :-1, :], (0, 0, 1, 0))
            x = self.transforms1(x)
            return x, y, z

        self.collate_fn = collate_batch

    def __str__(self):  # TODO not updated
        return f"{self._name_}"


class CIFAR10GenerationFactored(CIFAR10Generation):
    """ Version of CIFAR-10 Density Estimation that keeps the sequence of length 1024 and factors the distribution over the 3 channels """

    _name_ = "cifargenf"
    # Leaving this out or setting to None also works, to indicate that the entire length dimension is kept
    l_output = 1024

    init_defaults = {
        'mixture': 0,
        'val_split': 0.02,
        'seed': 42,
    }

    @property
    def d_input(self):
        return 3

    @property
    def d_output(self):
        """``3*256`` outputs without a mixture, ``10 * mixture`` otherwise."""
        if self.mixture == 0:
            return 3*256
        return 10 * self.mixture

    @property
    def default_task(self): # TODO
        """Task config with a linear encoder; the loss depends on ``mixture``."""
        return {
            '_target_': 'tasks.tasks.GeneralTask',
            'encoder': {'_name_': 'linear'},
            'decoder': {'_name_': 'sequence'},
            'loss': 'cross_entropy' if self.mixture == 0 else 'mixture_kd',
            'metrics': ['bpb'],
        }

    @property
    def permute(self):
        """Callable flattening ``(..., h, w, c)`` into ``(..., h*w, c)``."""
        pattern = "... h w c -> ... (h w) c"
        return lambda x: rearrange(x, pattern)

    @property
    def transforms0(self):
        """Scale integer pixel values with ``(x - 127.5) / 127.5``."""
        # return torchvision.transforms.Normalize(mean=0.5, std=0.5)
        return torchvision.transforms.Lambda(lambda x: (x.float() - 127.5) / 127.5)

    @property
    def transforms1(self):
        """No-op transform applied after the permutation."""
        return torchvision.transforms.Compose([])


class Copying(SequenceDataset):
    """Synthetic copying task built by ``copying_static_dataset``.

    Input and output dimensions equal the alphabet size ``n_tokens`` and
    the output length equals ``l_memorize``.
    """

    _name_ = "copying"
    # task = MulticlassClassification()
    # Build a new dict instead of calling ``.update`` on the shared
    # ``tasks.multiclass_classification`` object: an in-place update would
    # also change the default task of every other dataset that uses it.
    # NOTE(review): assumes ``tasks.multiclass_classification`` is a plain
    # mapping - confirm against src.tasks.defaults.
    default_task = {
        **tasks.multiclass_classification,
        'decoder': {'_name_': 'sequence', 'mode': 'last'},
    }

    init_defaults = {
        'l_noise': 100, # number of padding tokens
        'l_memorize': 10,  # number of tokens to memorize
        'n_tokens': 10,  # alphabet size
        'variable': False, # Randomly distribute memorization tokens throughout sequence instead of frontloading them
        'n_samples': 50000,
        'val_split': 0.1,
    }

    @property
    def d_input(self):
        return self.n_tokens

    @property
    def d_output(self):
        return self.n_tokens

    @property
    def l_output(self):
        return self.l_memorize

    def setup(self):
        """Generate the dataset and split off a validation set; there is no test set."""
        from .copying import copying_static_dataset
        self.train = copying_static_dataset(
            self.l_noise,
            self.l_memorize,
            self.n_tokens,
            self.variable,
            self.n_samples,
        )
        self.test = None
        self.split_train_val(self.val_split)

    def __str__(self):
        return f"{self._name_}{self.l_noise}{'v' if self.variable else ''}"


class Adding(SequenceDataset):
    """Synthetic adding task built by ``adding_static_dataset``."""

    _name_ = "adding"
    d_input = 2
    d_output = 1
    l_output = 0
    # L = self.l_max

    default_task = tasks.mse_regression

    init_defaults = {
        'l_max': 1000,
        'n_samples': 50000,
        'val_split': 0.1,
    }

    def setup(self):
        """Generate the dataset and split off a validation set; there is no test set."""
        from .adding import adding_static_dataset

        self.test = None
        self.train = adding_static_dataset(self.l_max, self.n_samples)
        self.split_train_val(self.val_split)

    def __str__(self):
        return "{}{}".format(self._name_, self.l_max)


# Wrap the data loader with callback function
class IMDB(SequenceDataset):
    """IMDB classification at character or word level.

    The raw dataset is loaded with ``datasets.load_dataset``, tokenized,
    truncated to ``l_max`` tokens (including optional <bos>/<eos>), mapped
    to vocabulary ids, and cached on disk under ``data_dir / "cache"``.
    """

    _name_ = "imdb"
    # d_input = 0  # Gets overridden by actual vocab size
    d_output = 2
    l_output = 0

    init_defaults = {
        'l_max': 4096,
        'level': 'char',
        'min_freq': 15,
        'seed': 42,
        'val_split': 0.0,
        'append_bos': False,
        'append_eos': True,
        'd_embed': 128,
        'n_workers': 4, # Only used for tokenizing dataset before caching
    }

    @property
    def n_tokens(self):
        """Vocabulary size; only available after ``setup`` has set ``self.vocab``."""
        return len(self.vocab)

    def init(self):
        """Resolve the data and cache directories and validate ``level``."""
        self.data_dir = self.data_dir or default_data_path / self._name_
        self.cache_dir = self.data_dir / "cache"
        assert self.level in [
            "word",
            "char",
        ], f"level {self.level} not supported"

    def prepare_data(self):
        """Download the dataset, or process and cache it when caching is enabled."""
        if self.cache_dir is None:  # Just download the dataset
            load_dataset(self._name_, cache_dir=self.data_dir)
        else:  # Process the dataset and save it
            self.process_dataset()

    def setup(self, stage=None):
        """Load the processed dataset, create the splits and the collate function.

        Args:
            stage: when "test" and the test split already exists, nothing
                is done.
        """
        if stage == "test" and hasattr(self, "dataset_test"):
            return
        dataset, self.vocab = self.process_dataset()
        print(
            f"IMDB {self.level} level | min_freq {self.min_freq} | vocab size {len(self.vocab)}"
        )
        dataset.set_format(type="torch", columns=["input_ids", "label"])

        # Create all splits
        dataset_train, dataset_test = dataset["train"], dataset["test"]
        if self.val_split == 0.0:
            # No validation split is carved out of the training set; the
            # validation dataset is left as None here.
            # NOTE(review): the original comment said the test set is used
            # as the val set (as in the LRA paper) - presumably the caller
            # handles that; confirm.
            self.dataset_train, self.dataset_val = dataset_train, None
        else:
            train_val = dataset_train.train_test_split(
                test_size=self.val_split, seed=self.seed
            )
            self.dataset_train, self.dataset_val = (
                train_val["train"],
                train_val["test"],
            )
        self.dataset_test = dataset_test

        # Rename for this repo's naming conventions
        self.train = self.dataset_train
        self.val = self.dataset_val
        self.test = self.dataset_test

        def collate_batch(resolution, batch):
            """Pad the id sequences of a batch; ``resolution`` is unused.

            Returns ``(xs, ys, lengths)``: padded inputs, labels and the
            unpadded length of every input.
            """
            xs, ys = zip(
                *[(data["input_ids"], data["label"]) for data in batch]
            )
            lengths = torch.tensor([len(x) for x in xs])
            xs = nn.utils.rnn.pad_sequence(
                xs, padding_value=self.vocab["<pad>"], batch_first=True
            )
            ys = torch.tensor(ys)
            return xs, ys, lengths

        self.collate_fn = collate_batch

    def process_dataset(self):
        """Return ``(dataset, vocab)``, loading from the cache when present.

        Otherwise the dataset is tokenized and numericalized, and the
        result is written to the cache (when caching is enabled).
        """
        cache_dir = (
            None
            if self.cache_dir is None
            else self.cache_dir / self._cache_dir_name
        )
        if cache_dir is not None and cache_dir.is_dir():
            return self._load_from_cache(cache_dir)

        dataset = load_dataset(self._name_, cache_dir=self.data_dir)
        dataset = DatasetDict(train=dataset["train"], test=dataset["test"])
        if self.level == "word":
            self.tokenizer = torchtext.data.utils.get_tokenizer(
                "spacy", language="en_core_web_sm"
            )
        else: # self.level == 'char'
            self.tokenizer = list  # Just convert a string to a list of chars
        # Account for <bos> and <eos> tokens
        l_max = (
            self.l_max - int(self.append_bos) - int(self.append_eos)
        )
        tokenize = lambda example: {
            "tokens": self.tokenizer(example["text"])[:l_max]
        }
        dataset = dataset.map(
            tokenize,
            remove_columns=["text"],
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=max(self.n_workers, 1),
        )
        # The vocabulary is built from the training split only.
        vocab = torchtext.vocab.build_vocab_from_iterator(
            dataset["train"]["tokens"],
            min_freq=self.min_freq,
            specials=(
                ["<pad>", "<unk>"]
                + (["<bos>"] if self.append_bos else [])
                + (["<eos>"] if self.append_eos else [])
            ),
        )
        # Tokens missing from the vocabulary map to <unk>.
        vocab.set_default_index(vocab["<unk>"])

        numericalize = lambda example: {
            "input_ids": vocab(
                (["<bos>"] if self.append_bos else [])
                + example["tokens"]
                + (["<eos>"] if self.append_eos else [])
            )
        }
        dataset = dataset.map(
            numericalize,
            remove_columns=["tokens"],
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=max(self.n_workers, 1),
        )

        if cache_dir is not None:
            self._save_to_cache(dataset, vocab, cache_dir)
        return dataset, vocab

    def _save_to_cache(self, dataset, vocab, cache_dir):
        """Write the processed dataset and the pickled vocab to ``cache_dir``."""
        # Use the directory passed by the caller rather than recomputing it.
        logger = logging.getLogger(__name__)
        logger.info(f"Saving to cache at {str(cache_dir)}")
        dataset.save_to_disk(str(cache_dir))
        with open(cache_dir / "vocab.pkl", "wb") as f:
            pickle.dump(vocab, f)

    def _load_from_cache(self, cache_dir):
        """Load ``(dataset, vocab)`` previously written by ``_save_to_cache``."""
        assert cache_dir.is_dir()
        logger = logging.getLogger(__name__)
        logger.info(f"Load from cache at {str(cache_dir)}")
        dataset = DatasetDict.load_from_disk(str(cache_dir))
        # SECURITY: pickle.load executes arbitrary code from the file. Only
        # point the cache directory at locations written by this code.
        with open(cache_dir / "vocab.pkl", "rb") as f:
            vocab = pickle.load(f)
        return dataset, vocab

    @property
    def _cache_dir_name(self):
        """Cache sub-directory name encoding every option that affects processing."""
        return f"l_max-{self.l_max}-level-{self.level}-min_freq-{self.min_freq}-append_bos-{self.append_bos}-append_eos-{self.append_eos}"


class SpeechCommands(SequenceDataset):
    """Speech Commands dataset wrapper around ``_SpeechCommands``."""

    _name_ = "sc"
    default_task = tasks.multiclass_classification

    init_defaults = {
        'mfcc': False,
        'dropped_rate': 0.,
        'length': 16000,
    }

    def init(self):
        """Derive the shape attributes from ``mfcc``, ``length`` and ``dropped_rate``."""
        # MFCC features: 20 channels over 161 steps; otherwise 1 channel over ``length`` steps.
        # self.L = 16000
        # self.L = 16384
        self.d_input, self.L = (20, 161) if self.mfcc else (1, self.length)

        # One additional input channel is used when samples are dropped.
        if self.dropped_rate > 0.0:
            self.d_input += 1

        self.d_output = 10
        self.l_output = 0

    def setup(self):
        """Create the train, val and test partitions."""
        from src.dataloaders.sc import _SpeechCommands

        # TODO refactor with data_dir argument
        def build(partition):
            # All three partitions share every argument except the partition name.
            return _SpeechCommands(
                partition=partition,
                length=16000,# self.L,
                mfcc=self.mfcc,
                sr=1,
                dropped_rate=self.dropped_rate,
                path=default_data_path,
            )

        self.train = build("train")
        self.val = build("val")
        self.test = build("test")

    #     if self.val_resampled:
    #         self.val_2 = _SpeechCommands(
    #             partition="val",
    #             length=self.L,
    #             mfcc=self.mfcc,
    #             sr=2,
    #             dropped_rate=self.dropped_rate,
    #             path=default_data_path,
    #         )

    #         self.val_4 = _SpeechCommands(
    #             partition="val",
    #             length=self.L,
    #             mfcc=self.mfcc,
    #             sr=4,
    #             dropped_rate=self.dropped_rate,
    #             path=default_data_path,
    #         )

    #         self.val_8 = _SpeechCommands(
    #             partition="val",
    #             length=self.L,
    #             mfcc=self.mfcc,
    #             sr=8,
    #             dropped_rate=self.dropped_rate,
    #             path=default_data_path,
    #         )

    #         self.val_16 = _SpeechCommands(
    #             partition="val",
    #             length=self.L,
    #             mfcc=self.mfcc,
    #             sr=16,
    #             dropped_rate=self.dropped_rate,
    #             path=default_data_path,
    #         )

    #         self.test_2 = _SpeechCommands(
    #             partition="test",
    #             length=self.L,
    #             mfcc=self.mfcc,
    #             sr=2,
    #             dropped_rate=self.dropped_rate,
    #             path=default_data_path,
    #         )

    #         self.test_4 = _SpeechCommands(
    #             partition="test",
    #             length=self.L,
    #             mfcc=self.mfcc,
    #             sr=4,
    #             dropped_rate=self.dropped_rate,
    #             path=default_data_path,
    #         )

    #         self.test_8 = _SpeechCommands(
    #             partition="test",
    #             length=self.L,
    #             mfcc=self.mfcc,
    #             sr=8,
    #             dropped_rate=self.dropped_rate,
    #             path=default_data_path,
    #         )

    #         self.test_16 = _SpeechCommands(
    #             partition="test",
    #             length=self.L,
    #             mfcc=self.mfcc,
    #             sr=16,
    #             dropped_rate=self.dropped_rate,
    #             path=default_data_path,
    #         )

    # def val_dataloader(self, batch_size, **kwargs):
    #     if self.val_resampled:
    #         return {
    #             "val": torch.utils.data.DataLoader(
    #                 self.val, batch_size=batch_size, shuffle=False, **kwargs
    #             ),
    #             "val_2": torch.utils.data.DataLoader(
    #                 self.val_2, batch_size=batch_size, shuffle=False, **kwargs
    #             ),
    #             "val_4": torch.utils.data.DataLoader(
    #                 self.val_4, batch_size=batch_size, shuffle=False, **kwargs
    #             ),
    #             "val_8": torch.utils.data.DataLoader(
    #                 self.val_8, batch_size=batch_size, shuffle=False, **kwargs
    #             ),
    #             "val_16": torch.utils.data.DataLoader(
    #                 self.val_16, batch_size=batch_size, shuffle=False, **kwargs
    #             ),
    #         }
    #     else:
    #         return {
    #             "val": torch.utils.data.DataLoader(
    #                 self.val, batch_size=batch_size, shuffle=False, **kwargs
    #             )
    #         }

    # def test_dataloader(self, batch_size, **kwargs):
    #     if self.val_resampled:
    #         return {
    #             "test": torch.utils.data.DataLoader(
    #                 self.test, batch_size=batch_size, shuffle=False, **kwargs
    #             ),
    #             "test_2": torch.utils.data.DataLoader(
    #                 self.test_2, batch_size=batch_size, shuffle=False, **kwargs
    #             ),
    #             "test_4": torch.utils.data.DataLoader(
    #                 self.test_4, batch_size=batch_size, shuffle=False, **kwargs
    #             ),
    #             "test_8": torch.utils.data.DataLoader(
    #                 self.test_8, batch_size=batch_size, shuffle=False, **kwargs
    #             ),
    #             "test_16": torch.utils.data.DataLoader(
    #                 self.test_16,
    #                 batch_size=batch_size,
    #                 shuffle=False,
    #                 **kwargs,
    #             ),
    #         }
    #     else:
    #         return {
    #             "test": torch.utils.data.DataLoader(
    #                 self.test, batch_size=batch_size, shuffle=False, **kwargs
    #             )
    #         }


class TabularDataset(torch.utils.data.Dataset):
    """In-memory dataset holding the rows of a CSV or TSV file."""

    def __init__(
        self,
        path,
        format,
        col_idx=None,
        skip_header=False,
        csv_reader_params=None,
    ):
        """
        path: file to read (``~`` is expanded).
        format: "csv" or "tsv" (case-insensitive).
        col_idx: the indices of the columns.
        skip_header: whether to discard the first row.
        csv_reader_params: extra keyword arguments for the CSV reader.
        """
        reader_kwargs = {} if csv_reader_params is None else csv_reader_params
        format = format.lower()
        assert format in ["tsv", "csv"]
        with io.open(os.path.expanduser(path), encoding="utf8") as handle:
            if format == "csv":
                rows = torchtext.utils.unicode_csv_reader(
                    handle, **reader_kwargs
                )
            elif format == "tsv":
                rows = torchtext.utils.unicode_csv_reader(
                    handle, delimiter="\t", **reader_kwargs
                )
            else:
                rows = handle
            if skip_header:
                next(rows)
            # Rows must be materialized while the file is still open.
            if col_idx is None:
                self._data = list(rows)
            else:
                self._data = [[row[c] for c in col_idx] for row in rows]

    def __len__(self):
        """Number of rows read from the file."""
        return len(self._data)

    def __getitem__(self, idx):
        """Return the row (or selected columns of the row) at ``idx``."""
        return self._data[idx]


class ListOps(SequenceDataset):
    """ListOps sequence-classification dataset loaded from tab-separated files.

    ``setup`` reads ``basic_{train,val,test}.tsv`` from ``data_dir``, tokenizes
    the ``Source`` column, builds a vocabulary from the training split, and
    exposes the splits as ``self.train`` / ``self.val`` / ``self.test`` with the
    columns ``input_ids`` and ``Target``.
    """

    _name_ = "listops"
    d_output = 10
    default_task = tasks.text_classification
    l_output = 0

    init_defaults = {
        'l_max': 2048,  # Maximum sequence length, including <bos>/<eos> if enabled
        'append_bos': False,
        'append_eos': True,
        # 'max_vocab': 20, # Actual size 18
        'n_workers': 4, # Only used for tokenizing dataset
    }

    @property
    def n_tokens(self):
        """Vocabulary size; only available after ``setup`` has built ``self.vocab``."""
        return len(self.vocab)

    def setup(self, stage=None):
        """Load, tokenize and numericalize the three splits.

        Args:
            stage: optional stage name. When it is ``"test"`` and ``self.test``
                already exists, the method returns without reloading.
        """
        if self.data_dir is None:
            self.data_dir = default_data_path / self._name_

        if stage == "test" and hasattr(self, "test"):
            return

        # `datasets.map` rejects num_proc <= 0; clamp so n_workers=0 means
        # "no extra workers" (same guard as the AAN dataset in this file).
        num_proc = max(self.n_workers, 1)

        dataset = load_dataset(
            "csv",
            data_files={
                "train": str(self.data_dir / "basic_train.tsv"),
                "val": str(self.data_dir / "basic_val.tsv"),
                "test": str(self.data_dir / "basic_test.tsv"),
            },
            delimiter="\t",
            keep_in_memory=True,
        )

        # LRA tokenizer renames ']' to 'X' and delete parentheses as their tokenizer removes
        # non-alphanumeric characters.
        # https://github.com/google-research/long-range-arena/blob/264227cbf9591e39dd596d2dc935297a2070bdfe/lra_benchmarks/listops/input_pipeline.py#L46
        self.tokenizer = lambda s: s.translate(
            {ord("]"): ord("X"), ord("("): None, ord(")"): None}
        ).split()

        # Account for <bos> and <eos> tokens
        l_max = self.l_max - int(self.append_bos) - int(self.append_eos)
        tokenize = lambda example: {
            "tokens": self.tokenizer(example["Source"])[:l_max]
        }
        dataset = dataset.map(
            tokenize,
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=num_proc,
        )

        # The vocabulary is built from the training split only; tokens unseen
        # there fall back to <unk> through the default index.
        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            dataset["train"]["tokens"],
            specials=(
                ["<pad>", "<unk>"]
                + (["<bos>"] if self.append_bos else [])
                + (["<eos>"] if self.append_eos else [])
            ),
        )
        self.vocab.set_default_index(self.vocab["<unk>"])
        self.d_input = len(self.vocab)
        print(f"ListOps vocab size {self.d_input}")

        def _vocab(
            tokens,
        ):  # TODO duplicated from other vocab datasets like IMDB, AAN
            """Map a token list to vocabulary indices, adding <bos>/<eos> if enabled."""
            bos = ["<bos>"] if self.append_bos else []
            eos = ["<eos>"] if self.append_eos else []
            return [self.vocab[token] for token in bos + tokens + eos]

        numericalize = lambda example: {"input_ids": _vocab(example["tokens"])}
        dataset = dataset.map(
            numericalize,
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=num_proc,
        )

        dataset.set_format(type="torch", columns=["input_ids", "Target"])
        self.train, self.val, self.test = (
            dataset["train"],
            dataset["val"],
            dataset["test"],
        )

        def collate_batch(resolution, batch):
            """Pad a batch to a common length.

            `resolution` is accepted but not used here.
            Returns (padded input ids, targets, original lengths).
            """
            xs, ys = zip(
                *[(data["input_ids"], data["Target"]) for data in batch]
            )
            # Lengths are recorded before padding.
            lengths = torch.tensor([len(x) for x in xs])
            xs = nn.utils.rnn.pad_sequence(
                xs, padding_value=self.vocab["<pad>"], batch_first=True
            )
            ys = torch.tensor(ys)
            return xs, ys, lengths

        self.collate_fn = collate_batch


class PathFinderDataset(torch.utils.data.Dataset):
    """Path Finder dataset: grayscale images paired with integer labels."""

    # There's an empty file in the dataset
    blacklist = {"pathfinder32/curv_baseline/imgs/0/sample_172.png"}

    def __init__(self, data_dir, transform=None):
        """
        Args:
            data_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.data_dir = Path(data_dir).expanduser()
        assert (
            self.data_dir.is_dir()
        ), f"data_dir {str(self.data_dir)} does not exist"
        self.transform = transform

        # Blacklist entries are written relative to the parent of data_dir.
        root_name = Path(self.data_dir.stem)
        collected = []
        # for diff_level in ['curv_baseline', 'curv_contour_length_9', 'curv_contour_length_14']:
        for diff_level in ["curv_contour_length_14"]:
            metadata_files = list(
                (self.data_dir / diff_level / "metadata").glob("*.npy")
            )
            # Metadata files have numeric stems; order them numerically.
            metadata_files.sort(key=lambda p: int(p.stem))
            assert metadata_files, "No metadata found"
            for metadata_file in metadata_files:
                # Despite the extension, the files are read as text, one sample per line.
                with open(metadata_file, "r") as handle:
                    rows = handle.read().splitlines()
                for row in rows:
                    fields = row.split()
                    # Fields 0 and 1 are the image sub-directory and file name.
                    relative = Path(diff_level) / fields[0] / fields[1]
                    if str(root_name / relative) in self.blacklist:
                        continue
                    # Field 3 holds the label.
                    collected.append((relative, int(fields[3])))
        self.samples = collected

    def __len__(self):
        """Number of (image path, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load image ``idx`` in grayscale and return ``(sample, label)``."""
        relative, label = self.samples[idx]
        # https://github.com/pytorch/vision/blob/9b29f3f22783112406d9c1a6db47165a297c3942/torchvision/datasets/folder.py#L247
        with open(self.data_dir / relative, "rb") as handle:
            image = Image.open(handle).convert("L")  # Open in grayscale
        if self.transform is None:
            return image, label
        return self.transform(image), label


class PathFinder(SequenceDataset):
    """Pathfinder binary-classification dataset built on ``PathFinderDataset``.

    ``setup`` randomly partitions the images into ``self.train`` / ``self.val``
    / ``self.test`` using the ``val_split`` / ``test_split`` fractions and a
    generator seeded with ``seed``.
    """

    _name_ = "pathfinder"
    d_input = 1
    d_output = 2
    l_output = 0
    default_task = tasks.binary_classification

    @property
    def n_tokens(self):
        """256 when pixels are tokenized to integers, otherwise None."""
        return 256 if self.tokenize else None

    init_defaults = {
        'resolution': 32,
        'sequential': True,
        # 'd_embed': None,
        'tokenize': False,
        'pool': 1,
        'val_split': 0.1,
        'test_split': 0.1,
        'seed': 42, # Controls the train/val/test split
    }


    def init(self):
        """Fill in the default ``data_dir`` for the configured resolution."""
        if self.data_dir is None:
            self.data_dir = default_data_path / self._name_ / f"pathfinder{self.resolution}"


    def default_transforms(self):
        """Build the per-image transform pipeline from the configured options."""
        transform_list = [torchvision.transforms.ToTensor()]
        if self.pool > 1:
            # Average-pool non-overlapping pool x pool patches.
            transform_list.append(
                Reduce(
                    "1 (h h2) (w w2) -> 1 h w",
                    "mean",
                    h2=self.pool,
                    w2=self.pool,
                )
            )
        if self.tokenize:
            # Scale by 255 and cast to integer token ids.
            transform_list.append(
                torchvision.transforms.Lambda(lambda x: (x * 255).long())
            )
        else:
            transform_list.append(
                torchvision.transforms.Normalize(mean=0.5, std=0.5)
            )
        if self.sequential:
            # If tokenize, it makes more sense to get rid of the channel dimension
            transform_list.append(
                Rearrange("1 h w -> (h w)")
                if self.tokenize
                else Rearrange("1 h w -> (h w) 1")
            )
        return torchvision.transforms.Compose(transform_list)

    def prepare_data(self):
        """Raise FileNotFoundError with download instructions if ``data_dir`` is missing."""
        if not self.data_dir.is_dir():
            raise FileNotFoundError(
                f"""
            Directory {str(self.data_dir)} not found.
            To get the dataset, download lra_release.gz from
            https://github.com/google-research/long-range-arena,
            then unzip it with tar -xvf lra_release.gz.
            Then point data_dir to the pathfinderX directory, where X is either 32, 64, 128, or 256.
            """
            )

    def setup(self, stage=None):
        """Scan the image directory and create the train/val/test splits.

        Args:
            stage: optional stage name. When it is ``"test"`` and the test
                split has already been created, the method returns early.
        """
        # This class stores its splits as self.train/val/test (it never sets
        # `dataset_test`), so the guard must look at `self.test`.
        if stage == "test" and getattr(self, "test", None) is not None:
            return
        # https://github.com/pytorch/pytorch/issues/11201
        torch.multiprocessing.set_sharing_strategy("file_system")
        dataset = PathFinderDataset(
            self.data_dir, transform=self.default_transforms()
        )
        len_dataset = len(dataset)
        val_len = int(self.val_split * len_dataset)
        test_len = int(self.test_split * len_dataset)
        # Train takes the remainder so the three lengths always sum to the total.
        train_len = len_dataset - val_len - test_len
        self.train, self.val, self.test = torch.utils.data.random_split(
            dataset,
            [train_len, val_len, test_len],
            generator=torch.Generator().manual_seed(self.seed),
        )


class AAN(SequenceDataset):
    """AAN paired-text classification dataset loaded from tab-separated files.

    Each example holds two texts and a label. Texts are tokenized character by
    character, numericalized with a vocabulary built from the training split,
    and the processed dataset is cached on disk next to the raw files.
    """

    _name_ = "aan"
    d_output = 2 # Use accuracy instead of binary_accuracy
    l_output = 0


    @property
    def n_tokens(self):
        """Vocabulary size; only available after ``setup`` has set ``self.vocab``."""
        return len(self.vocab)

    init_defaults = {
        'l_max': 4000,
        # 'max_vocab': 100, # Full size 98
        'append_bos': False,
        'append_eos': True,
        'n_workers': 4, # For tokenizing only
    }

    @property
    def _cache_dir_name(self):
        """Directory name encoding every option that affects the processed data."""
        return f'l_max-{self.l_max}-append_bos-{self.append_bos}-append_eos-{self.append_eos}'


    def init(self):
        """Fill in the default ``data_dir`` and derive ``cache_dir`` from it."""
        if self.data_dir is None: self.data_dir = default_data_path / self._name_
        self.cache_dir = self.data_dir / self._cache_dir_name

    def prepare_data(self):
        """Verify the raw files exist (no cache configured) or build the cache."""
        # NOTE: init() always assigns cache_dir, so the file check below only
        # runs if cache_dir has been reset to None afterwards.
        if self.cache_dir is None:
            for split in ['train', 'eval', 'test']:
                split_path = self.data_dir / f'new_aan_pairs.{split}.tsv'
                if not split_path.is_file():
                    raise FileNotFoundError(f"""
                    File {str(split_path)} not found.
                    To get the dataset, download lra_release.gz from
                    https://github.com/google-research/long-range-arena,
                    then unzip it with tar -xvf lra_release.gz.
                    Then point data_dir to the tsv_data directory.
                    """)
        else:  # Process the dataset and save it
            self.process_dataset()

    def setup(self, stage=None):
        """Load the processed dataset and install the collate function.

        Args:
            stage: optional stage name. When it is ``'test'`` and
                ``dataset_test`` already exists, the method returns early.
        """
        if stage == 'test' and hasattr(self, 'dataset_test'):
            return

        # https://github.com/pytorch/pytorch/issues/11201
        torch.multiprocessing.set_sharing_strategy('file_system')

        dataset, self.tokenizer, self.vocab = self.process_dataset()
        # self.vocab_size = len(self.vocab)
        print("AAN vocab size:", len(self.vocab))

        dataset.set_format(type='torch', columns=['input_ids1', 'input_ids2', 'label'])
        self.dataset_train, self.dataset_val, self.dataset_test = (
            dataset['train'], dataset['val'], dataset['test']
        )

        def collate_batch(resolution, batch):
            """Pad both texts of every pair and stack them along the batch axis.

            `resolution` is accepted but not used here. The returned inputs
            and lengths have twice as many rows as there are labels: all first
            texts, followed by all second texts.
            """
            xs1, xs2, ys = zip(*[(data['input_ids1'], data['input_ids2'], data['label'])
                                 for data in batch])
            lengths1 = torch.tensor([len(x) for x in xs1])
            lengths2 = torch.tensor([len(x) for x in xs2])
            xs1 = nn.utils.rnn.pad_sequence(xs1, padding_value=self.vocab['<pad>'], batch_first=True)
            xs2 = nn.utils.rnn.pad_sequence(xs2, padding_value=self.vocab['<pad>'], batch_first=True)
            ys = torch.tensor(ys)
            # return xs1, xs2, ys, lengths1, lengths2

            # Concatenate two batches
            xs = torch.cat([xs1, xs2], dim=0)
            lengths = torch.cat([lengths1, lengths2], dim=0)
            return xs, ys, lengths

        self.collate_fn = collate_batch

        # Rename for this dataset (compared to transformers dataset)
        self.train, self.val, self.test = self.dataset_train, self.dataset_val, self.dataset_test

    def process_dataset(self):
        """Return ``(dataset, tokenizer, vocab)``, from cache when available.

        When no cache directory exists the raw TSV files are loaded, tokenized
        and numericalized, and the result is written to the cache.
        """
        # NOTE: init() already appends _cache_dir_name to cache_dir, so the
        # effective path nests that name twice. Kept as-is so caches written
        # by earlier runs are still found.
        cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name
        if cache_dir is not None and cache_dir.is_dir():
            return self._load_from_cache(cache_dir)

        dataset = load_dataset('csv',
                               data_files={'train': str(self.data_dir / 'new_aan_pairs.train.tsv'),
                                           'val': str(self.data_dir / 'new_aan_pairs.eval.tsv'),
                                           'test': str(self.data_dir / 'new_aan_pairs.test.tsv')},
                               delimiter='\t',
                               column_names=['label', 'input1_id', 'input2_id', 'text1', 'text2'],
                               keep_in_memory=False)
        dataset = dataset.remove_columns(['input1_id', 'input2_id'])
        new_features = dataset['train'].features.copy()
        new_features['label'] = Value('int32')
        dataset = dataset.cast(new_features)

        tokenizer = list  # Just convert a string to a list of chars
        # Account for <bos> and <eos> tokens
        l_max = self.l_max - int(self.append_bos) - int(self.append_eos)
        # `datasets.map` needs at least one process.
        num_proc = max(self.n_workers, 1)
        tokenize = lambda example: {'tokens1': tokenizer(example['text1'])[:l_max],
                               'tokens2': tokenizer(example['text2'])[:l_max]}
        dataset = dataset.map(tokenize, remove_columns=['text1', 'text2'], keep_in_memory=True,
                              load_from_cache_file=False, num_proc=num_proc)
        # The vocabulary covers both texts of the training split only.
        vocab = torchtext.vocab.build_vocab_from_iterator(
            dataset['train']['tokens1'] + dataset['train']['tokens2'],
            specials=(['<pad>', '<unk>']
                      + (['<bos>'] if self.append_bos else [])
                      + (['<eos>'] if self.append_eos else []))
        )
        vocab.set_default_index(vocab['<unk>'])

        encode = lambda text: vocab(
            (['<bos>'] if self.append_bos else []) + text + (['<eos>'] if self.append_eos else [])
        )
        numericalize = lambda example: {'input_ids1': encode(example['tokens1']),
                                   'input_ids2': encode(example['tokens2'])}
        dataset = dataset.map(numericalize, remove_columns=['tokens1', 'tokens2'],
                              keep_in_memory=True, load_from_cache_file=False,
                              num_proc=num_proc)

        if cache_dir is not None:
            self._save_to_cache(dataset, tokenizer, vocab, cache_dir)
        return dataset, tokenizer, vocab

    def _save_to_cache(self, dataset, tokenizer, vocab, cache_dir):
        """Write the dataset, tokenizer and vocab under ``cache_dir``.

        The ``cache_dir`` argument is used as given (it was previously
        overwritten with a recomputed path, which ignored the caller's value).
        """
        logger = logging.getLogger(__name__)
        logger.info(f'Saving to cache at {str(cache_dir)}')
        dataset.save_to_disk(str(cache_dir))
        with open(cache_dir / 'tokenizer.pkl', 'wb') as f:
            pickle.dump(tokenizer, f)
        with open(cache_dir / 'vocab.pkl', 'wb') as f:
            pickle.dump(vocab, f)

    def _load_from_cache(self, cache_dir):
        """Read back what ``_save_to_cache`` wrote and return it as a tuple."""
        assert cache_dir.is_dir()
        logger = logging.getLogger(__name__)
        logger.info(f'Load from cache at {str(cache_dir)}')
        dataset = DatasetDict.load_from_disk(str(cache_dir))
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only point cache_dir at directories this process (or a trusted
        # user) has written.
        with open(cache_dir / 'tokenizer.pkl', 'rb') as f:
            tokenizer = pickle.load(f)
        with open(cache_dir / 'vocab.pkl', 'rb') as f:
            vocab = pickle.load(f)
        return dataset, tokenizer, vocab

class Integrator(SequenceDataset):
    """Synthetic sequence dataset generated by ``integrator_data``.

    Inputs and targets are single-channel sequences (``d_input == d_output == 1``).
    """

    _name_ = "integrator"
    # task = MulticlassClassification()
    # default_task = tasks.lm # TODO get rid of these

    @property
    def d_input(self): return 1

    @property
    def d_output(self): return 1

    init_defaults = {
        'l_seq': 1024, # length of sequence
        'n_components': 10, # number of sins to mix
        'max_ampl': 10.0,
        'max_freq': 100.0,
        'n_samples': 100000,
        'val_split': 0.1,
    }

    def setup(self, stage=None):
        """Generate the data and split it into train and validation sets.

        Args:
            stage: accepted for signature compatibility with the other
                datasets in this file; it is not used.
        """
        # Imported lazily so the generator module is only needed for this dataset.
        from .integrator import integrator_data
        data, targets = integrator_data(
            self.n_samples,
            self.l_seq,
            self.n_components,
            self.max_ampl,
            self.max_freq,
        )
        # Add a trailing feature axis to both inputs and targets.
        self.train = torch.utils.data.TensorDataset(data.unsqueeze(-1), targets.unsqueeze(-1))

        # split_train_val is not defined in this class — presumably inherited
        # from SequenceDataset and carving the validation set out of self.train.
        self.split_train_val(self.val_split)
        self.test = None

        self.collate_fn = None

    def __str__(self):
        return f"{self._name_}"


