import os
import ast
import math
import random
from glob import glob
from collections import defaultdict

from PIL import Image

import librosa

from scipy.io import loadmat

import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset

import torchaudio
import torchaudio.datasets as AD

import torchvision.datasets as VD
import torchvision.transforms as T

import ignite.distributed as idist


# Dataset-name groupings by evaluation style. They are not referenced in this
# part of the file; presumably evaluation code elsewhere checks membership to
# pick a metric -- confirm against callers.
EVAL_TOKENIZED  = ['secondary-structure']  # assumed: per-token prediction tasks
EVAL_SPEARMAN   = ['stability', 'fluorescence']  # assumed: regression scored by Spearman correlation
EVAL_F1         = ['wafer-map']  # assumed: scored by F1


def get_split_df(df, train=True, train_ratio=0.9):
    """Return the train or held-out portion of ``df`` by row position.

    The first ``ceil(train_ratio * len(df))`` rows form the train split and
    every remaining row forms the held-out split. The two splits are
    therefore disjoint and together cover all of ``df``.

    Args:
        df: DataFrame to split. It is not shuffled here; callers wanting a
            random split shuffle it beforehand.
        train: If True return the train split, otherwise the held-out split.
        train_ratio: Fraction of rows (rounded up) assigned to the train split.

    Returns:
        A positional slice of ``df``.
    """
    num_train = math.ceil(train_ratio * len(df))
    # The held-out split is the exact complement of the train split. Sizing it
    # independently as floor((1 - train_ratio) * len(df)) drops a row to
    # floating-point error, e.g. (1 - 0.9) * 10 == 0.9999999999999998 -> 0.
    return df.iloc[:num_train] if train else df.iloc[num_train:]


class CUB(Dataset):
    """Caltech-UCSD Birds-200-2011 image classification dataset.

    Expects ``<root>/CUB_200_2011`` containing ``images.txt``,
    ``image_class_labels.txt``, ``train_test_split.txt`` and an ``images/``
    folder. Each sample is ``(image, label)`` with labels shifted to start at 0.

    Args:
        root: Directory that contains ``CUB_200_2011``.
        train: Use the images flagged 1 in ``train_test_split.txt`` if True,
            otherwise those flagged 0.
        transform: Optional callable applied to each RGB PIL image. When None
            the PIL image is returned unchanged.
    """

    def __init__(self, root, train=True, transform=None):
        super().__init__()
        self.train = train
        self.root = root
        self.transform = transform
        self.paths, self.labels = self.load_images()

    def _read_id_map(self, filename):
        """Parse a ``CUB_200_2011/<filename>`` file of ``<image_id> <value>`` lines into a dict."""
        info_path = os.path.join(self.root, 'CUB_200_2011', filename)
        with open(info_path, 'r') as f:
            # Skip blank lines (e.g. a trailing empty line) instead of crashing in dict().
            pairs = [line.rstrip('\r\n').split(' ', 1) for line in f if line.strip()]
        return dict(pairs)

    def load_images(self):
        """Return ``(paths, labels)`` for the selected split in ``images.txt`` order.

        Paths are relative to ``CUB_200_2011/images``; labels are the raw
        1-based class-id strings from ``image_class_labels.txt``.
        """
        image_info = self._read_id_map('images.txt')
        label_info = self._read_id_map('image_class_labels.txt')
        train_test_info = self._read_id_map('train_test_split.txt')

        wanted_split = 1 if self.train else 0
        all_paths, all_labels = [], []
        for index, image_path in image_info.items():
            if int(train_test_info[index]) == wanted_split:
                all_paths.append(image_path)
                all_labels.append(label_info[index])
        return all_paths, all_labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = os.path.join(self.root, 'CUB_200_2011', 'images', self.paths[index])
        label = int(self.labels[index]) - 1  # metadata class ids are 1-based
        with Image.open(path) as img:
            image = img.convert(mode='RGB')
        if self.transform is not None:
            image = self.transform(image)

        return image, label


class VGGFlower(Dataset):
    """Oxford 102 Flowers dataset with a fixed, per-class 80/20 split.

    Expects ``<root>/flowers102/imagelabels.mat`` and the ``jpg/`` image
    folder. Image ``image_{i:05d}.jpg`` (1-based ``i``) gets the label at
    position ``i - 1`` of the ``.mat`` ``labels`` row. Samples are
    ``(image, label)`` with labels shifted to start at 0.

    Args:
        root: Directory that contains ``flowers102``.
        train: Select the 80% train portion of each class if True, otherwise
            the remaining 20%.
        transform: Optional callable applied to each RGB PIL image. When None
            the PIL image is returned unchanged.
    """

    def __init__(self, root, train=True, transform=None):
        super().__init__()
        self.train = train
        self.transform = transform
        self.root = os.path.join(root, 'flowers102')
        self.paths, self.labels = self.load_images()

    def load_images(self):
        """Return ``(paths, labels)`` for the selected split.

        Each class's images are shuffled with a RandomState seeded 42 and the
        first ``int(0.8 * n)`` go to train. Classes are visited in order of
        first appearance in the label file; the shared RNG is consumed in that
        order, so the iteration order must stay as-is for stable splits.
        """
        rs = np.random.RandomState(42)
        imagelabels_path = os.path.join(self.root, 'imagelabels.mat')
        with open(imagelabels_path, 'rb') as f:
            image_labels = loadmat(f)['labels'][0]

        class_to_paths = defaultdict(list)
        for i, label in enumerate(image_labels):
            class_to_paths[label].append(os.path.join(self.root, 'jpg', 'image_{:05d}.jpg'.format(i + 1)))

        split_filepaths, split_labels = [], []
        for label, paths in class_to_paths.items():
            num = len(paths)
            indexer = np.arange(num)
            rs.shuffle(indexer)
            shuffled = [paths[j] for j in indexer]
            cut = int(0.8 * num)
            chosen = shuffled[:cut] if self.train else shuffled[cut:]
            split_filepaths.extend(chosen)
            split_labels.extend([label] * len(chosen))

        return split_filepaths, split_labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path, label = self.paths[index], int(self.labels[index]) - 1  # .mat labels are 1-based
        with Image.open(path) as img:
            image = img.convert(mode='RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label


class DTD(Dataset):
    """Describable Textures Dataset using the first official split (``*1.txt``).

    Split lists are read from ``<root>/imagenet/dtd/labels``: train merges
    ``train1.txt`` and ``val1.txt``, test uses ``test1.txt``. Each list line
    is an image path relative to ``dtd/images`` whose first component is the
    category name. Label ids index the sorted categories of the chosen split.

    Args:
        root: Directory that contains ``imagenet/dtd``.
        train: Use train+val lists if True, otherwise the test list.
        transform: Optional callable applied to each RGB PIL image. When None
            the PIL image is returned unchanged.
    """

    def __init__(self, root, train=True, transform=None):
        super().__init__()
        self.train = train
        self.transform = transform
        self.root = os.path.join(root, 'imagenet')
        self.paths, self.labels = self.load_images()

    def __len__(self):
        return len(self.paths)

    def _read_split(self, filename):
        """Return the non-blank lines of ``dtd/labels/<filename>`` without line endings."""
        path = os.path.join(self.root, 'dtd', 'labels', filename)
        with open(path, 'r') as f:
            return [line.rstrip('\r\n') for line in f if line.strip()]

    def load_images(self):
        """Return ``(absolute_paths, labels)`` for the selected split, in list order."""
        if self.train:
            split_info = self._read_split('train1.txt') + self._read_split('val1.txt')
        else:
            split_info = self._read_split('test1.txt')

        # Pull out categories from the paths; sort so label ids are stable.
        categories = sorted({image_path.split('/')[0] for image_path in split_info})
        category_to_label = {category: i for i, category in enumerate(categories)}

        all_paths = [os.path.join(self.root, 'dtd', 'images', image_path) for image_path in split_info]
        all_labels = [category_to_label[image_path.split('/')[0]] for image_path in split_info]
        return all_paths, all_labels

    def __getitem__(self, index):
        path, label = self.paths[index], self.labels[index]
        with Image.open(path) as img:
            image = img.convert(mode='RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label


class TrafficSign(Dataset):
    """GTSRB traffic-sign classification with a fixed per-class 80/20 split.

    Images are read from
    ``<root>/imagenet/traffic_sign/GTSRB/Final_Training/Images/<class:05d>/*.ppm``
    and converted to 32x32 RGB tensors. Each sample is ``(image, class_id)``.
    """
    NUM_CLASSES = 43

    def __init__(self, root, train=True):
        super().__init__()
        self.train = train
        self.root = os.path.join(root, 'imagenet', 'traffic_sign')
        self.transform = T.Compose([T.Resize((32, 32)),
                                    T.CenterCrop((32, 32)),
                                    T.ToTensor()])
        self.paths, self.labels = self.load_images()

    def load_images(self):
        """Return ``(paths, labels)`` for the selected split.

        Each class's files are shuffled with a RandomState seeded 42; the first
        ``int(0.8 * n)`` go to train and the rest to test.
        """
        rs = np.random.RandomState(42)
        all_filepaths, all_labels = [], []
        for class_i in range(self.NUM_CLASSES):
            class_dir_i = os.path.join(self.root, 'GTSRB', 'Final_Training', 'Images', '{:05d}'.format(class_i))
            # glob() returns files in arbitrary, filesystem-dependent order; sort
            # so the seeded shuffle yields the same split on every machine.
            image_paths = sorted(glob(os.path.join(class_dir_i, '*.ppm')))
            num = len(image_paths)
            indexer = np.arange(num)
            rs.shuffle(indexer)
            image_paths = [image_paths[j] for j in indexer]
            cut = int(0.8 * num)
            image_paths = image_paths[:cut] if self.train else image_paths[cut:]
            all_filepaths.extend(image_paths)
            all_labels.extend([class_i] * len(image_paths))

        return all_filepaths, all_labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        label = self.labels[index]
        with Image.open(path) as img:
            image = img.convert(mode='RGB')
        image = self.transform(image)
        return image, label


class Aircraft(Dataset):
    """FGVC-Aircraft (2013b) variant classification on bounding-box crops.

    Split membership comes from ``data/images_variant_{trainval,test}.txt``
    and boxes from ``data/images_box.txt`` under
    ``<root>/imagenet/fgvc-aircraft-2013b``. Label ids index the sorted
    variant names of the chosen split. Each sample is the image cropped to its
    box and resized to a 32x32 float tensor.
    """

    def __init__(self, root, train=True):
        super().__init__()
        self.train = train
        self.root = os.path.join(root, 'imagenet', 'fgvc-aircraft-2013b')
        self.transform = T.Compose([T.Resize((32,32)),
                                    T.CenterCrop((32, 32)),
                                    T.ToTensor()])
        self.paths, self.bboxes, self.labels = self.load_images()

    def load_images(self):
        """Return ``(files, bboxes, labels)``, grouped by sorted variant and sorted name."""
        split = 'trainval' if self.train else 'test'
        variant_path = os.path.join(self.root, 'data', f'images_variant_{split}.txt')
        with open(variant_path, 'r') as f:
            rows = [line.split('\n')[0].split(' ', 1) for line in f.readlines()]

        # Through dict(): a name listed twice keeps only its last variant.
        variant_of = dict(rows)
        names_by_variant = defaultdict(list)
        for name, variant in variant_of.items():
            names_by_variant[variant].append(name)

        box_of = self.get_bounding_boxes()

        split_files, split_labels, split_bboxes = [], [], []
        for variant_id, variant in enumerate(sorted(names_by_variant)):
            names = sorted(names_by_variant[variant])
            split_files.extend(os.path.join(self.root, 'data', 'images', f'{name}.jpg') for name in names)
            split_bboxes.extend(box_of[name] for name in names)
            split_labels.extend([variant_id] * len(names))

        return split_files, split_bboxes, split_labels

    def get_bounding_boxes(self):
        """Map image name -> ``[xmin, ymin, xmax, ymax]`` ints from ``images_box.txt``."""
        bboxes_path = os.path.join(self.root, 'data', 'images_box.txt')
        boxes = {}
        with open(bboxes_path, 'r') as f:
            for line in f.readlines():
                name, xmin, ymin, xmax, ymax = line.split('\n')[0].split(' ')
                boxes[name] = [int(xmin), int(ymin), int(xmax), int(ymax)]
        return boxes

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        box = tuple(self.bboxes[index])
        target = self.labels[index]

        cropped = Image.open(path).convert(mode='RGB').crop(box)
        return self.transform(cropped).float(), target


class waferMap(Dataset):
    """Wafer-map failure-pattern dataset read from pickled DataFrames.

    Rows carry a ``pixels`` array and a ``failureType`` string. Rows are
    shuffled with seed 42 and split 90/10 into train/held-out. Each sample is
    a 32x32 RGB tensor and the integer class from ``FAILURE_MAP``.
    """
    # Failure-type string -> class id; 'unlabeled' and 'none' share id 0.
    FAILURE_MAP = {'unlabeled': 0,
                   'none': 0,
                   'random': 1,
                   'donut': 2,
                   'scratch': 3,
                   'center': 4,
                   'loc': 5,
                   'edge-loc': 6,
                   'edge-ring': 7,
                   'near-full': 8}

    def __init__(self, root, pre_train=True, train=True):
        self.root = os.path.join(root, 'wafer_map')
        self.transforms = T.Compose([T.Resize((32, 32)),
                                     T.ToTensor()])

        pkl_name = 'labeled.pkl' if not pre_train else 'unlabeled.pkl'
        # NOTE(review): read_pickle executes arbitrary code from the file; only
        # load trusted local dataset files.
        shuffled = (pd.read_pickle(os.path.join(self.root, pkl_name))
                    .sample(frac=1, random_state=42)
                    .reset_index(drop=True))
        self.data = get_split_df(shuffled, train, train_ratio=0.9)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        record = self.data.iloc[index]
        pixels = record['pixels']
        target = self.FAILURE_MAP[record['failureType']]
        pil_img = Image.fromarray(pixels.astype('uint8')).convert('RGB')
        return self.transforms(pil_img), target


class HIGGS(Dataset):
    """HIGGS particle-physics table read from ``<root>/particle_physics/*.csv``.

    The first CSV column is dropped (presumably a saved index -- confirm). Of
    the remaining columns, column 0 is the label and the next 28 are
    features. Rows are shuffled with seed 42 and split 90/10 into
    train/held-out.
    """

    def __init__(self, root, pre_train=True, train=True):
        super().__init__()
        filename = 'higgsPretrain' if pre_train else 'higgsTransfer'
        self.csv_dir = os.path.join(root, 'particle_physics', f'{filename}.csv')

        higgs_df = pd.read_csv(self.csv_dir)
        higgs_df = higgs_df.sample(frac=1, random_state=42).reset_index(drop=True)
        higgs_df = higgs_df.iloc[:, 1:]
        self.higgs_df = get_split_df(higgs_df, train, train_ratio=0.9)

    def __len__(self):
        return len(self.higgs_df)

    def __getitem__(self, index):
        """Return ``(features, label)``: a (1, 28) float32 tensor and a 1-element int array."""
        row = self.higgs_df.iloc[index]
        # Use positional .iloc: row[0] on a column-labelled Series relies on a
        # positional fallback that pandas has deprecated.
        label = int(row.iloc[0])

        features = torch.tensor(row.iloc[1:].to_numpy(dtype=np.float64))
        features = features.view(1, 28)
        return features.float(), np.array([label])


class EuroSAT(Dataset):
    """EuroSAT multispectral (Sentinel-2) land-cover dataset from a pickled DataFrame.

    Rows hold a ``sentinel2`` image and a ``label``. Rows are split 90/10 into
    train/held-out in stored order (no shuffle here). Images are converted to
    tensors and normalised per band with ``MEAN``/``STD``.
    """
    # Per-band normalisation statistics (13 values, one per band).
    MEAN = [1354.3003, 1117.7579, 1042.2800, 947.6443,  1199.6334, 2001.9829, 2372.5579,
            2299.6663, 731.0175,  12.0956,   1822.4083, 1119.5759, 2598.4456]
    STD  = [244.0469,  323.4128,  385.0928,  584.1638,  566.0543,  858.5753,  1083.6704,
            1103.0342, 402.9594,  4.7207,    1002.4071, 759.6080,  1228.4104]

    def __init__(self, root, pre_train=True, train=True):
        super().__init__()
        split_name = 'transfer' if not pre_train else 'pretrain'
        self.pkl_dir = os.path.join(root, 'eurosat_all', f'{split_name}.pkl')

        # NOTE(review): read_pickle executes arbitrary code; trusted files only.
        self.eurosat = get_split_df(pd.read_pickle(self.pkl_dir), train, train_ratio=0.9)

        self.transform = T.Compose([T.ToTensor(),
                                    T.Normalize(mean=self.MEAN, std=self.STD)])

    def __len__(self):
        return len(self.eurosat)

    def __getitem__(self, index):
        sample = self.eurosat.iloc[index]
        target = sample['label']
        image = sample['sentinel2']
        return self.transform(image), target


class LibriSpeech(Dataset):
    """LibriSpeech speaker identification on fixed-length log-mel spectrograms.

    Pre-training concatenates ``train-clean-100``, ``train-clean-360`` and
    ``train-other-500``. Transfer uses ``dev-clean`` with a per-speaker 80/20
    split (seed 42). Each waveform is cropped or zero-padded to ``MAX_LENGTH``
    samples and turned into a normalised ``(1, 224, frames)`` spectrogram.
    Samples are ``(spectrum, speaker_label)``.
    """
    MAX_LENGTH = 150526

    ALL_TRAIN_NUM_CLASSES = 2338
    DEV_CLEAN_NUM_CLASSES = 40

    mean = [-22.924]
    std  = [12.587]

    def __init__(self, root, download=True, pre_train=True, train=True):
        super().__init__()
        self.root = os.path.join(root, 'speech', 'librispeech')
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.root, exist_ok=True)

        if pre_train:
            self.dataset1 = AD.LIBRISPEECH(self.root, url='train-clean-100', download=download, folder_in_archive='LibriSpeech')
            self.dataset2 = AD.LIBRISPEECH(self.root, url='train-clean-360', download=download, folder_in_archive='LibriSpeech')
            self.dataset3 = AD.LIBRISPEECH(self.root, url='train-other-500', download=download, folder_in_archive='LibriSpeech')
        else:
            self.dataset = AD.LIBRISPEECH(self.root, url='dev-clean', download=download, folder_in_archive='LibriSpeech')

        self.pre_train = pre_train
        self.all_speaker_ids = self.get_speaker_ids()
        unique_speaker_ids = sorted(set(self.all_speaker_ids))
        num_classes = self.ALL_TRAIN_NUM_CLASSES if pre_train else self.DEV_CLEAN_NUM_CLASSES
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(unique_speaker_ids) != num_classes:
            raise ValueError(
                f'expected {num_classes} speakers, found {len(unique_speaker_ids)}')
        self.speaker_id_map = dict(zip(unique_speaker_ids, range(num_classes)))

        if not self.pre_train:
            self.indices = self.train_test_split(self.all_speaker_ids, train=train)

    def get_speaker_ids(self):
        """Return the raw speaker id of every utterance, in flat-index order."""
        if self.pre_train:
            speaker_ids_1 = self._get_speaker_ids(self.dataset1)
            speaker_ids_2 = self._get_speaker_ids(self.dataset2)
            speaker_ids_3 = self._get_speaker_ids(self.dataset3)
            return np.concatenate([speaker_ids_1, speaker_ids_2, speaker_ids_3])
        else:
            return self._get_speaker_ids(self.dataset)

    def _get_speaker_ids(self, dataset):
        """Parse speaker ids from file ids without loading audio (uses private torchaudio fields)."""
        speaker_ids = []
        for i in range(len(dataset)):
            fileid = dataset._walker[i]
            speaker_id = self.load_librispeech_speaker_id(
                fileid,
                dataset._path,
                dataset._ext_audio,
                dataset._ext_txt,
            )
            speaker_ids.append(speaker_id)
        return np.array(speaker_ids)

    def train_test_split(self, speaker_ids, train=True):
        """Return flat indices of the 80% (train) or 20% (test) portion of each speaker.

        ``speaker_ids`` must be a numpy array (it is compared element-wise).
        """
        rs = np.random.RandomState(42)  # fix seed so reproducible splitting

        unique_speaker_ids = sorted(set(speaker_ids))
        unique_speaker_ids = np.array(unique_speaker_ids)

        # train test split to ensure the 80/20 splits
        train_indices, test_indices = [], []
        for speaker_id in unique_speaker_ids:
            speaker_indices = np.where(speaker_ids == speaker_id)[0]
            size = len(speaker_indices)
            rs.shuffle(speaker_indices)
            train_size = int(0.8 * size)
            train_indices.extend(speaker_indices[:train_size].tolist())
            test_indices.extend(speaker_indices[train_size:].tolist())

        return train_indices if train else test_indices

    def load_librispeech_speaker_id(self, fileid, path, ext_audio, ext_txt):
        """Return the integer speaker id from a ``<speaker>-<chapter>-<utterance>`` file id.

        ``path``, ``ext_audio`` and ``ext_txt`` are accepted but unused.
        """
        speaker_id, _, _ = fileid.split('-')
        return int(speaker_id)

    def _load_item(self, index):
        """Map a flat index to ``(waveform, sample_rate, speaker_id)`` from the right sub-dataset."""
        if self.pre_train:
            n1, n2 = len(self.dataset1), len(self.dataset2)
            if index >= n1 + n2:
                item = self.dataset3[index - n1 - n2]
            elif index >= n1:
                item = self.dataset2[index - n1]
            else:
                item = self.dataset1[index]
        else:
            item = self.dataset[self.indices[index]]
        wavform, sample_rate, _, speaker_id, _, _ = item
        return wavform, sample_rate, speaker_id

    def __getitem__(self, index):
        wavform, sample_rate, speaker_id = self._load_item(index)
        speaker_id = self.speaker_id_map[speaker_id]
        wavform = np.asarray(wavform[0])

        if len(wavform) > self.MAX_LENGTH:
            # Pre-training keeps a random end; evaluation always keeps the start.
            flip = (bool(random.getrandbits(1)) if self.pre_train else True)
            padded = (wavform[:self.MAX_LENGTH] if flip else wavform[-self.MAX_LENGTH:])
        else:
            padded = np.zeros(self.MAX_LENGTH)
            padded[:len(wavform)] = wavform  # pad w/ silence

        spectrum = librosa.feature.melspectrogram(
            y=padded,
            sr=sample_rate,
            hop_length=672,
            n_mels=224,
        )

        # log mel-spectrogram.
        # NOTE(review): melspectrogram already returns power (power=2.0 by default),
        # so squaring again doubles the dB scale. Kept as-is because the mean/std
        # constants were presumably computed with this exact pipeline.
        spectrum = librosa.power_to_db(spectrum**2)
        spectrum = torch.from_numpy(spectrum).float()
        spectrum = spectrum.unsqueeze(0)

        normalize = T.Normalize(self.mean, self.std)
        spectrum = normalize(spectrum)

        return spectrum, speaker_id

    def __len__(self):
        if self.pre_train:
            return len(self.dataset1) + len(self.dataset2) + len(self.dataset3)
        else:
            return len(self.indices)


class AudioMNIST(Dataset):
    """AudioMNIST spoken-digit classification on fixed-length log-mel spectrograms.

    Train uses the predefined train + validation speakers; test uses the test
    speakers. Files live in ``<root>/speech/AudioMNIST/data/<spk:02>/`` and are
    named ``<digit>_<speaker>_<index>.wav``; the digit is the label.
    """
    AUDIOMNIST_TRAIN_SPK = [28, 56, 7,  19, 35, 1,  6,  16, 23, 34, 46, 53,
                            36, 57, 9,  24, 37, 2,  8,  17, 29, 39, 48, 54,
                            43, 58, 14, 25, 38, 3,  10, 20, 30, 40, 49, 55]
    AUDIOMNIST_VAL_SPK   = [12, 47, 59, 15, 27, 41, 4,  11, 21, 31, 44, 50]
    AUDIOMNIST_TEST_SPK  = [26, 52, 60, 18, 32, 42, 5,  13, 22, 33, 45, 51]
    MAX_LENGTH = 150526

    def __init__(self, root, train=True):
        super().__init__()
        self.root = os.path.join(root, 'speech', 'AudioMNIST')
        self.train = train

        speakers = (self.AUDIOMNIST_TRAIN_SPK + self.AUDIOMNIST_VAL_SPK) if train else self.AUDIOMNIST_TEST_SPK
        self.wav_paths = []
        for spk in speakers:
            # Sort so the index -> file mapping does not depend on filesystem order.
            spk_paths = sorted(glob(os.path.join(self.root, 'data', '{:02}'.format(spk), '*.wav')))
            self.wav_paths.extend(spk_paths)

        self.transform = T.Normalize([-90.293], [11.799])

    @staticmethod
    def _label_from_path(wav_path):
        """Return the digit encoded as the first ``_``-separated field of the file name."""
        # os.path.splitext removes the '.wav' suffix; str.rstrip('.wav') would
        # strip any trailing '.', 'w', 'a', 'v' characters instead.
        stem = os.path.splitext(os.path.basename(wav_path))[0]
        return int(stem.split('_')[0])

    def __len__(self):
        return len(self.wav_paths)

    def __getitem__(self, index):
        wav_path = self.wav_paths[index]
        label = self._label_from_path(wav_path)

        wavform, sample_rate = torchaudio.load(wav_path)
        wavform = wavform[0].numpy()

        if len(wavform) > self.MAX_LENGTH:
            # randomly pick which side to chop off (fix if validation)
            flip = (bool(random.getrandbits(1)) if self.train else True)
            padded = (wavform[:self.MAX_LENGTH] if flip else wavform[-self.MAX_LENGTH:])
        else:
            padded = np.zeros(self.MAX_LENGTH)
            padded[:len(wavform)] = wavform  # pad w/ silence

        spectrum = librosa.feature.melspectrogram(
            y=padded,
            sr=sample_rate,
            hop_length=672,
            n_mels=224,
        )
        spectrum = librosa.power_to_db(spectrum**2)
        spectrum = torch.from_numpy(spectrum).float()
        spectrum = spectrum.unsqueeze(0)
        spectrum = self.transform(spectrum)

        return spectrum, label


class FluentSpeechCommand(Dataset):
    """Fluent Speech Commands intent-slot classification on log-mel spectrograms.

    ``label_type`` selects which CSV column ('action', 'object' or 'location')
    is the target. Label ids are positions in the matching vocabulary list.
    Train merges ``train_data.csv`` and ``valid_data.csv``; test uses
    ``test_data.csv``.
    """
    FLUENTSPEECH_ACTIONS = ['change language', 'activate', 'deactivate', 'increase', 'decrease', 'bring']
    FLUENTSPEECH_OBJECTS = ['none',  'music', 'lights', 'volume',  'heat',   'lamp',    'newspaper',
                            'juice', 'socks', 'shoes',  'Chinese', 'Korean', 'English', 'German']
    FLUENTSPEECH_LOCATIONS = ['none', 'kitchen', 'bedroom', 'washroom']
    # label_type -> vocabulary; a label's id is its position in the list.
    _VOCABS = {'action': FLUENTSPEECH_ACTIONS,
               'object': FLUENTSPEECH_OBJECTS,
               'location': FLUENTSPEECH_LOCATIONS}
    MAX_LENGTH = 150526

    def __init__(self, root, label_type, train=True):
        super().__init__()
        self.root = os.path.join(root, 'speech')
        self.label_type = label_type
        self.train = train
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.label_type not in self._VOCABS:
            raise ValueError(f'label_type must be one of {sorted(self._VOCABS)}, got {label_type!r}')

        if train:
            train_paths, train_labels = self._read_split('train_data.csv')
            val_paths, val_labels = self._read_split('valid_data.csv')
            wav_paths = train_paths + val_paths
            labels = train_labels + val_labels
        else:
            wav_paths, labels = self._read_split('test_data.csv')

        self.transform = T.Normalize([-31.809], [13.127])

        self.wav_paths = wav_paths
        self.labels = labels

    def _read_split(self, csv_name):
        """Return ``(paths, raw_labels)`` lists from one split CSV."""
        csv_path = os.path.join(self.root, 'fluent_speech_commands_dataset', 'data', csv_name)
        data = pd.read_csv(csv_path)
        return list(data['path']), list(data[self.label_type])

    def _encode_label(self, label):
        """Map a raw label string to its vocabulary index (ValueError if unknown)."""
        return self._VOCABS[self.label_type].index(label)

    def __len__(self):
        return len(self.wav_paths)

    def __getitem__(self, index):
        wav_name = self.wav_paths[index]
        wav_path = os.path.join(self.root, 'fluent_speech_commands_dataset', wav_name)
        label = self._encode_label(self.labels[index])

        wavform, sample_rate = torchaudio.load(wav_path)
        wavform = wavform[0].numpy()

        if len(wavform) > self.MAX_LENGTH:
            # Training keeps a random end; evaluation always keeps the start.
            flip = (bool(random.getrandbits(1)) if self.train else True)
            padded = (wavform[:self.MAX_LENGTH] if flip else wavform[-self.MAX_LENGTH:])
        else:
            padded = np.zeros(self.MAX_LENGTH)
            padded[:len(wavform)] = wavform  # pad w/ silence

        spectrum = librosa.feature.melspectrogram(
            y=padded,
            sr=sample_rate,
            hop_length=672,
            n_mels=224,
        )

        spectrum = librosa.power_to_db(spectrum**2)
        spectrum = torch.from_numpy(spectrum).float()
        spectrum = spectrum.unsqueeze(0)
        spectrum = self.transform(spectrum)

        return spectrum, int(label)


class GoogleSpeechCommand(Dataset):
    """Google Speech Commands keyword classification on log-mel spectrograms.

    Clips are stored as ``<label>/<file>.wav`` under
    ``<root>/speech/google_speech``. The test split is ``testing_list.txt``.
    The train split is every clip not listed in ``validation_list.txt`` or
    ``testing_list.txt``, followed by the validation clips.
    """
    LABELS = ['eight', 'right', 'happy', 'three',  'yes',  'up',    'no',     'stop',   'on',       'four',    'nine',  'zero',
              'down',  'go',    'six',   'two',    'left', 'five',  'off',    'seven',  'one',      'cat',     'bird',  'marvin',
              'wow',   'tree',  'dog',   'sheila', 'bed',  'house', 'follow', 'visual', 'backward', 'forward', 'learn', '_background_noise_']
    MAX_LENGTH = 150526

    def __init__(self, root, train=True):
        super().__init__()
        self.train = train
        self.root = os.path.join(root, 'speech', 'google_speech')
        self.wav_paths = self._list_split(self.root, train)
        self.transform = T.Normalize([-46.847], [19.151])

    @staticmethod
    def _read_list(path):
        """Return the stripped, non-blank lines of a split list file."""
        with open(path, 'r') as f:
            return [line.strip() for line in f if line.strip()]

    @classmethod
    def _list_split(cls, root, train):
        """Return ``<label>/<file>.wav`` relative paths for the requested split."""
        test_paths = cls._read_list(os.path.join(root, 'testing_list.txt'))
        if not train:
            return test_paths

        val_paths = cls._read_list(os.path.join(root, 'validation_list.txt'))
        all_paths = []
        for dirpath, _, files in os.walk(root):
            label_dir = os.path.basename(dirpath)
            for name in files:
                if name.endswith('wav'):
                    # '/' to match the separator used inside the list files.
                    all_paths.append(f'{label_dir}/{name}')
        # The list entries are stripped above; otherwise their trailing '\n'
        # would make this difference a no-op and leak val/test clips into train.
        held_out = set(val_paths) | set(test_paths)
        train_paths = sorted(set(all_paths) - held_out)  # sorted: set order varies per process
        return train_paths + val_paths

    def __len__(self):
        return len(self.wav_paths)

    def __getitem__(self, index):
        wav_name = self.wav_paths[index]
        label_name = wav_name.split('/')[0].lower()
        label = self.LABELS.index(label_name)
        wav_path = os.path.join(self.root, wav_name)

        wavform, sample_rate = torchaudio.load(wav_path)
        wavform = wavform[0].numpy()

        if len(wavform) > self.MAX_LENGTH:
            # Training keeps a random end; evaluation always keeps the start.
            flip = (bool(random.getrandbits(1)) if self.train else True)
            padded = (wavform[:self.MAX_LENGTH] if flip else wavform[-self.MAX_LENGTH:])
        else:
            padded = np.zeros(self.MAX_LENGTH)
            padded[:len(wavform)] = wavform  # pad w/ silence

        spectrum = librosa.feature.melspectrogram(
            y=padded,
            sr=sample_rate,
            hop_length=672,
            n_mels=224,
        )

        spectrum = librosa.power_to_db(spectrum**2)
        spectrum = torch.from_numpy(spectrum).float()
        spectrum = spectrum.unsqueeze(0)
        spectrum = self.transform(spectrum)

        return spectrum, int(label)


class VoxCeleb1(Dataset):
    """VoxCeleb1 speaker identification on fixed-length log-mel spectrograms.

    ``iden_split.txt`` lines are ``<split> <relative wav path>``. Splits '1'
    and '2' form the train set and '3' the test set. The speaker is the first
    path component, and label ids index the sorted speakers of the chosen
    split.
    """
    MAX_LENGTH = 150526

    def __init__(self, root, train=True):
        super().__init__()
        self.root = os.path.join(root, 'speech', 'voxceleb1')
        self.wav_paths, speakers = self.get_split(train)
        label_of = {spk: position for position, spk in enumerate(sorted(set(speakers)))}
        self.speaker_ids = [label_of[spk] for spk in speakers]
        self.train = train

        self.transform = T.Normalize([-37.075], [19.776])

    def __len__(self):
        return len(self.wav_paths)

    def get_split(self, train=True):
        """Return ``(paths, speaker_strings)`` for the train ('1'+'2') or test ('3') split."""
        split_file = os.path.join(self.root, 'iden_split.txt')
        with open(split_file, 'r') as fp:
            lines = fp.readlines()

        by_split = defaultdict(list)
        for line in lines:
            split_id, rel_path = line.strip().split(' ')
            by_split[split_id].append(rel_path)

        chosen = (by_split['1'] + by_split['2']) if train else by_split['3']
        return chosen, [rel_path.split('/')[0] for rel_path in chosen]

    def __getitem__(self, index):
        wav_path = os.path.join(self.root, 'wav', self.wav_paths[index])
        label = self.speaker_ids[index]
        waveform, sample_rate = torchaudio.load(wav_path)
        samples = waveform[0].numpy()

        limit = self.MAX_LENGTH
        if len(samples) <= limit:
            clip = np.zeros(limit)
            clip[:len(samples)] = samples  # pad w/ silence
        else:
            # Training keeps a random end; evaluation always keeps the start.
            keep_head = bool(random.getrandbits(1)) if self.train else True
            clip = samples[:limit] if keep_head else samples[-limit:]

        mel = librosa.feature.melspectrogram(y=clip, sr=sample_rate, hop_length=672, n_mels=224)
        # log mel-spectrogram
        log_mel = librosa.power_to_db(mel**2)
        features = torch.from_numpy(log_mel).float().unsqueeze(0)
        return self.transform(features), label


class PAMAP2(Dataset):
    """PAMAP2 activity recognition with randomly sampled fixed-length windows.

    Train uses subjects 101-104 and 107-109; validation uses 105-106. Items
    are sampled at random (the ``index`` argument is ignored). A random
    (subject, activity) segment is picked, then a random window of
    ``MEASUREMENTS_PER_EXAMPLE`` consecutive rows is taken from it. The
    window is returned channel-first, excluding the timestamp and activity
    columns.
    """
    TRAIN_EXAMPLES_PER_EPOCH = 50000  # examples are generated stochastically
    VAL_EXAMPLES_PER_EPOCH = 10000
    MEASUREMENTS_PER_EXAMPLE = 320
    ACTIVITY_LABELS = [1, 2, 3, 4, 5, 6, 7, 12, 13, 16, 17, 24]

    def __init__(self, root, train=True):
        super().__init__()
        self.root = os.path.join(root, 'sensor', 'pamap2')
        self.mode = 'train' if train else 'val'

        self.data = self.load_data()
        self.samples = self.get_candidates(self.data)

    def load_data(self):
        """Load one interpolated DataFrame per subject, caching each as ``<file>.p``."""
        subjects = [1, 2, 3, 4, 7, 8, 9] if self.mode == 'train' else [5, 6]
        columns = ['timestamp', 'activity_id', 'heart_rate'] + [
            f'{part}{i}' for part in ['hand', 'chest', 'ankle'] for i in range(17)
        ]
        frames = []
        for num in subjects:
            subject_filename = f'subject10{num}.dat'
            subj_path = os.path.join(self.root, 'Protocol', subject_filename)
            cache_path = subj_path + '.p'
            if os.path.isfile(cache_path):
                df = pd.read_pickle(cache_path)
            else:
                df = pd.read_csv(subj_path, names=columns, sep=' ').interpolate()
                df.to_pickle(cache_path)
            frames.append(df)
            print(f'load done: {subject_filename}', end='\r')

        return frames

    def get_candidates(self, data):
        """Return ``(rows_array, activity_index)`` segments longer than one window."""
        candidates = []
        for df in data:
            for activity_id, activity in enumerate(self.ACTIVITY_LABELS):
                segment = df[df['activity_id'] == activity].to_numpy()
                if len(segment) <= self.MEASUREMENTS_PER_EXAMPLE:
                    continue
                candidates.append((segment, activity_id))
        return candidates

    def __len__(self):
        if self.mode == 'train':
            return self.TRAIN_EXAMPLES_PER_EPOCH
        return self.VAL_EXAMPLES_PER_EPOCH

    def __getitem__(self, index):
        # `index` is intentionally ignored: windows are drawn from the global NumPy RNG.
        window = self.MEASUREMENTS_PER_EXAMPLE
        segment, activity_id = self.samples[np.random.randint(len(self.samples))]
        start = np.random.randint(len(segment) - window)
        x = torch.tensor(segment[start:start + window, 2:].T, dtype=torch.float32)
        return x, activity_id


class Genomics(Dataset):
    """Genomic-sequence classification from ``<root>/genomics/genomics*.csv``.

    Only CSV columns 2 and 3 are kept: the label, then the sequence string.
    Rows are shuffled with seed 42 and split 90/10 into train/held-out. Each
    sample is ``(tokens, label)`` where ``tokens`` is a long tensor of base ids.
    """
    # Token representation of genomic bases.
    BASES = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

    def __init__(self, root, pre_train=True, train=True, ood=False):
        super().__init__()
        self.ood = ood
        filename = 'Pretrain' if pre_train else 'TransferOOD' if ood else 'TransferID'
        self.csv_root = os.path.join(root, 'genomics', f'genomics{filename}.csv')

        genomics_df = pd.read_csv(self.csv_root)
        genomics_df = genomics_df.sample(frac=1, random_state=42).reset_index(drop=True)
        genomics_df = genomics_df.iloc[:, 2:4]
        self.genomics_df = get_split_df(genomics_df, train, train_ratio=0.9)

    def __len__(self):
        return len(self.genomics_df)

    def __getitem__(self, index):
        row = self.genomics_df.iloc[index]
        # Positional .iloc: row[0] on a column-labelled Series relies on a
        # positional fallback that pandas has deprecated.
        label = row.iloc[0]
        if self.ood:
            label = label - 10  # OOD labels are shifted down by 10 (presumably to start at 0 -- confirm)
        # Characters [2:252]: presumably strips a 2-char prefix such as "b'" and
        # keeps 250 bases -- confirm against the CSV format.
        seq = row.iloc[1][2:252]
        tokens = torch.tensor([self.BASES[base] for base in seq], dtype=torch.long)
        return tokens, label


class Pfam(Dataset):
    """Pfam protein sequences labelled by clan.

    Sequences are truncated / right-padded with ``'X'`` to 128 residues and
    tokenised. Clan ids always come from the pretraining CSV, so the same clan
    maps to the same id in both pretraining and transfer splits.
    """
    def __init__(self, root, pre_train=True, train=True):
        super().__init__()
        filename = 'pfam_pretrain_train' if pre_train else 'pfam_transfer'
        self.csv_dir = os.path.join(root, 'pfam', f'{filename}.csv')
        self.pfam_df = pd.read_csv(self.csv_dir)
        if not pre_train:
            self.pfam_df = get_split_df(self.pfam_df, train, train_ratio=0.9)
        # Drop the first column (presumably a saved index; TODO confirm).
        self.pfam_df = self.pfam_df.iloc[:, 1:]

        # vocab map; 'X' (id 0) doubles as the padding character in __getitem__.
        self.amino_acid_map = {amino_acid: i for i, amino_acid in enumerate("XARNDCQEGHILKMFPSTWYVUOBZJ")}

        # labeling map, built from the pretraining clans in sorted order.
        if pre_train:
            clan_column = self.pfam_df['clan']
        else:
            # Only 'clan' is needed; skip parsing the (large) sequence columns.
            pretrain_path = os.path.join(root, 'pfam', 'pfam_pretrain_train.csv')
            clan_column = pd.read_csv(pretrain_path, usecols=['clan'])['clan']
        self.label_map = {clan: i for i, clan in enumerate(sorted(set(clan_column)))}

    def __len__(self):
        return len(self.pfam_df)

    def __getitem__(self, index):
        """Return ``(tokens, label)``: 128 amino-acid ids and the clan id."""
        row = self.pfam_df.iloc[index]
        seq = row['primary'][:128].ljust(128, 'X')
        token_tensor = torch.tensor([self.amino_acid_map[amino_acid] for amino_acid in seq],
                                    dtype=torch.long)
        label = self.label_map[row['clan']]
        return token_tensor, label


class ProteinTransfer(Dataset):
    """Shared base for protein transfer tasks stored under ``<root>/pfam/transfer``.

    Reads ``<dataset_name>_train.csv`` or ``<dataset_name>_valid.csv``.
    Subclasses build ``__getitem__`` on top of ``__getrow__``, which tokenises
    the ``primary`` column and returns the raw row for target extraction.
    """
    def __init__(self, root, dataset_name, train=True):
        suffix = '_train' if train else '_valid'
        self.csv_dir = os.path.join(root, 'pfam', 'transfer', dataset_name + suffix + '.csv')
        self.data = pd.read_csv(self.csv_dir)

        # Vocabulary; 'X' (id 0) is also the padding character.
        alphabet = "XARNDCQEGHILKMFPSTWYVUOBZJ"
        self.amino_acid_map = dict(zip(alphabet, range(len(alphabet))))

    def __len__(self):
        return self.data.shape[0]

    def __getrow__(self, index):
        """Return ``(tokens, row)``.

        ``tokens`` holds the 128 amino-acid ids of the sequence, truncated or
        'X'-padded; ``row`` is the untouched CSV row.
        """
        row = self.data.iloc[index]
        padded = row['primary'][:128].ljust(128, 'X')
        ids = list(map(self.amino_acid_map.__getitem__, padded))
        return torch.tensor(ids, dtype=torch.long), row


class SCOP(ProteinTransfer):
    """SCOP fold classification over the ``remote_homology`` CSVs."""
    def __init__(self, root, train=True):
        super().__init__(root, 'remote_homology', train)

    def __getitem__(self, index):
        """Return ``(tokens, fold_label)`` for row ``index``."""
        encoded = super().__getrow__(index)
        tokens, row = encoded[0], encoded[1]
        return tokens, row['fold_label']


class SecondaryStructure(ProteinTransfer):
    """Per-residue 3-state secondary-structure labelling (``ss3`` column)."""
    def __init__(self, root, train=True):
        super().__init__(root, 'secondary_structure', train)

    def __getitem__(self, index):
        """Return ``(tokens, labels)``.

        ``labels`` holds the first 128 ``ss3`` entries, parsed with
        ``ast.literal_eval`` and right-padded with class 3 to length 128.
        """
        token_tensor, row = super().__getrow__(index)
        labels = ast.literal_eval(row['ss3'])[:128]
        missing = 128 - len(labels)
        if missing > 0:
            # Pad with the extra class 3 so every target has length 128.
            labels.extend([3] * missing)
        return token_tensor, torch.tensor(labels)


class Stability(ProteinTransfer):
    """Protein stability regression (``stability_score`` column)."""
    def __init__(self, root, train=True):
        super().__init__(root, 'stability', train)

    def __getitem__(self, index):
        """Return ``(tokens, score)`` with the score as ``np.float32``."""
        token_tensor, row = super().__getrow__(index)
        # The stored value carries one wrapping character on each side
        # (presumably brackets); strip them before parsing.
        raw_score = row['stability_score']
        return token_tensor, np.float32(raw_score[1:-1])


class Fluorescence(ProteinTransfer):
    """Protein fluorescence regression (``log_fluorescence`` column)."""
    def __init__(self, root, train=True):
        super().__init__(root, 'fluorescence', train)

    def __getitem__(self, index):
        """Return ``(tokens, score)`` with the log-fluorescence as ``np.float32``."""
        token_tensor, row = super().__getrow__(index)
        # The stored value carries one wrapping character on each side
        # (presumably brackets); strip them before parsing.
        raw_score = row['log_fluorescence']
        return token_tensor, np.float32(raw_score[1:-1])


def get_dataset_config(dataset):
    """Return the data/model hyper-parameters for ``dataset``.

    Returns:
        dict with keys ``num_classes``, ``input_shape``, ``patch_size`` and
        ``batch_size``. ``num_classes`` is 0 for the regression tasks
        ``stability`` and ``fluorescence``.

    Raises:
        NotImplementedError: if ``dataset`` is not a known name.
    """
    # Shared (input_shape, patch_size, batch_size) per modality group.
    image = ((3, 32, 32), (4, 4), 64)
    speech = ((1, 224, 224), (16, 16), 64)
    protein = (((26,), 128), (1,), 128)  # vocab size = 26
    genomics_shape = ((4,), 250)

    # dataset -> (num_classes, input_shape, patch_size, batch_size)
    configs = {
        ## image - imagenet
        'imagenet32':        (1000, *image),
        'cifar10':           (10, *image),
        'cub':               (200, *image),
        'vggflower':         (102, *image),
        'dtd':               (47, *image),
        'traffic-sign':      (43, *image),
        'aircraft':          (102, *image),
        ## image - wafer-map
        'wafer-map':         (9, (3, 32, 32), (4, 4), 128),
        ## multi-spectral image
        'eurosat':           (10, (13, 64, 64), (8, 8), 64),
        ## time-series
        'pamap2':            (12, (52, 320), (5,), 256),
        ## speech
        'libri-speech':      (40, *speech),
        'audio-mnist':       (10, *speech),
        'fluent-speech-loc': (4, *speech),
        'fluent-speech-obj': (14, *speech),
        'fluent-speech-act': (6, *speech),
        'google-speech':     (36, *speech),
        'voxceleb1':         (1251, *speech),
        ## tabular
        'higgs':             (2, (1, 28), (1,), 256),
        ## token - genomics
        'genomics':          (10, genomics_shape, (1,), 32),
        'genomics-id':       (10, genomics_shape, (1,), 64),
        'genomics-ood':      (60, genomics_shape, (1,), 32),
        ## token - proteins (num_classes 0 = regression target)
        'pfam':                (623, *protein),
        'scop':                (1195, *protein),
        'secondary-structure': (4, *protein),
        'stability':           (0, *protein),
        'fluorescence':        (0, *protein),
    }
    try:
        num_classes, input_shape, patch_size, batch_size = configs[dataset]
    except KeyError:
        # Without this, an unknown name would surface as an UnboundLocalError.
        raise NotImplementedError(f'Unknown dataset: {dataset!r}') from None
    return dict(num_classes=num_classes,
                input_shape=input_shape,
                patch_size=patch_size,
                batch_size=batch_size)


def get_pretrain_dataset(dataset, datadir):
    """Build the pretraining split for ``dataset`` rooted at ``datadir``.

    Only the names handled below have a pretraining split. Any other name
    raises ``NotImplementedError``, matching ``get_transfer_dataset``.

    Raises:
        NotImplementedError: if ``dataset`` has no pretraining split.
    """
    if dataset == 'imagenet32':
        transform = T.Compose([T.Resize((32, 32)),
                               T.CenterCrop((32, 32)),
                               T.ToTensor()])
        train = VD.ImageNet(datadir, 'train', transform=transform)
    elif dataset == 'wafer-map':
        train = waferMap(datadir, pre_train=True, train=True)
    elif dataset == 'eurosat':
        train = EuroSAT(datadir, pre_train=True, train=True)
    elif dataset == 'pamap2':
        train = PAMAP2(datadir, train=True)
    elif dataset == 'libri-speech':
        train = LibriSpeech(datadir, pre_train=True, download=True)
    elif dataset == 'higgs':
        train = HIGGS(datadir, pre_train=True, train=True)
    elif dataset == 'pfam':
        train = Pfam(datadir, pre_train=True, train=True)
    elif dataset == 'genomics':
        train = Genomics(datadir, pre_train=True, train=True)
    else:
        # Without this branch an unknown name surfaced as UnboundLocalError on `train`.
        raise NotImplementedError(f'No pretraining dataset for {dataset!r}')
    return train


def get_transfer_dataset(dataset, datadir):
    """Build the ``(val, test)`` pair used for transfer evaluation of ``dataset``.

    ``val`` is constructed with ``train=True`` (the ImageNet ``'train'`` split
    for ``imagenet32``) and ``test`` with ``train=False`` (``'val'`` for ImageNet).

    Raises:
        NotImplementedError: if ``dataset`` has no transfer split.
    """
    img_transform = T.Compose([T.Resize((32, 32)),
                               T.CenterCrop((32, 32)),
                               T.ToTensor()])

    # name -> factory(train_flag). Class names are resolved only when the
    # chosen factory is called, so unused datasets are never touched.
    factories = {
        ## image - imagenet
        'imagenet32': lambda train: VD.ImageNet(datadir, 'train' if train else 'val',
                                                transform=img_transform),
        'cifar10': lambda train: VD.CIFAR10(datadir, train=train, transform=img_transform),
        'cub': lambda train: CUB(datadir, train=train, transform=img_transform),
        'vggflower': lambda train: VGGFlower(datadir, train=train, transform=img_transform),
        'dtd': lambda train: DTD(datadir, train=train, transform=img_transform),
        # traffic-sign and aircraft are constructed without img_transform.
        'traffic-sign': lambda train: TrafficSign(datadir, train=train),
        'aircraft': lambda train: Aircraft(datadir, train=train),
        ## image - wafer-map
        'wafer-map': lambda train: waferMap(datadir, pre_train=False, train=train),
        ## multi-spectral image
        'eurosat': lambda train: EuroSAT(datadir, pre_train=False, train=train),
        ## time-series
        'pamap2': lambda train: PAMAP2(datadir, train=train),
        ## speech
        'libri-speech': lambda train: LibriSpeech(datadir, pre_train=False, train=train,
                                                  download=True),
        'audio-mnist': lambda train: AudioMNIST(datadir, train=train),
        'fluent-speech-loc': lambda train: FluentSpeechCommand(datadir, 'location', train=train),
        'fluent-speech-obj': lambda train: FluentSpeechCommand(datadir, 'object', train=train),
        'fluent-speech-act': lambda train: FluentSpeechCommand(datadir, 'action', train=train),
        'google-speech': lambda train: GoogleSpeechCommand(datadir, train=train),
        'voxceleb1': lambda train: VoxCeleb1(datadir, train=train),
        ## tabular
        'higgs': lambda train: HIGGS(datadir, pre_train=False, train=train),
        ## token - genomics
        'genomics-id': lambda train: Genomics(datadir, pre_train=False, train=train, ood=False),
        'genomics-ood': lambda train: Genomics(datadir, pre_train=False, train=train, ood=True),
        ## token - proteins
        'pfam': lambda train: Pfam(datadir, pre_train=False, train=train),
        'scop': lambda train: SCOP(datadir, train=train),
        'secondary-structure': lambda train: SecondaryStructure(datadir, train=train),
        'stability': lambda train: Stability(datadir, train=train),
        'fluorescence': lambda train: Fluorescence(datadir, train=train),
    }

    if dataset not in factories:
        raise NotImplementedError
    build = factories[dataset]
    val = build(True)
    test = build(False)
    return val, test


def get_dataset(dataset, datadir, mode='pretrain'):
    """Return the config dict for ``dataset`` augmented with dataset objects.

    In ``'pretrain'`` mode the dict gains a ``'train'`` entry. In any other
    mode it gains ``'val'`` and ``'test'`` from ``get_transfer_dataset``.
    """
    config = get_dataset_config(dataset)
    if mode != 'pretrain':
        config['val'], config['test'] = get_transfer_dataset(dataset, datadir)
        return config
    config['train'] = get_pretrain_dataset(dataset, datadir)
    return config


def get_loader(args, dataset, mode='pretrain'):
    """Wrap the datasets from ``get_dataset`` in ``idist.auto_dataloader``s.

    ``'pretrain'`` mode returns ``{'train': loader}``, shuffled and dropping the
    last partial batch. Any other mode returns unshuffled ``'val'`` and
    ``'test'`` loaders.
    """
    common = dict(batch_size=dataset['batch_size'],
                  num_workers=args.num_workers,
                  pin_memory=True)
    if mode == 'pretrain':
        return {'train': idist.auto_dataloader(dataset['train'],
                                               shuffle=True, drop_last=True,
                                               **common)}
    return {split: idist.auto_dataloader(dataset[split], **common)
            for split in ('val', 'test')}

