import pickle
from pathlib import Path
from typing import Optional
import csv
import os
import numpy as np
import torch
from PIL import Image
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS, has_file_allowed_extension

import pickle


class CUBDataset(Dataset):
    """
    CUB-200-2011 fine-grained classification dataset.

    The train/test assignment comes from ``train_test_split.txt`` in ``meta_dir``;
    edit that file to customise the split.

    Args:
        data_dir (str): image root, e.g. ``CUB_200_2011/images`` in the extracted archive
        meta_dir (str): meta-data root, e.g. ``CUB_200_2011`` in the extracted archive;
                        must contain ``classes.txt``, ``image_class_labels.txt``,
                        ``train_test_split.txt`` and ``images.txt``
        split (str): ``'train'``, ``'test'`` or ``'all'``
        transform (callable, optional): applied to every loaded image
        target_transform (callable, optional): applied to every label
        loader (callable): turns an image path into an image object
    """

    def __init__(self, data_dir, meta_dir, split, transform=None,
                 target_transform=None, loader=default_loader):
        super(CUBDataset, self).__init__()
        data_dir = Path(data_dir).expanduser()
        meta_dir = Path(meta_dir).expanduser()
        assert split in ['train', 'test', 'all'], 'split not supported'

        # classes.txt rows are "<1-based class id> <class name>"
        class_table = np.loadtxt(meta_dir / 'classes.txt', dtype=str)
        class_ids, class_names = class_table[:, 0], class_table[:, 1]
        self.classes = class_names.tolist()
        self.class_to_idx = {name: int(cid) - 1 for cid, name in zip(class_ids, class_names)}

        # the value of interest is the second column of both files; labels become 0-based
        all_labels = np.loadtxt(meta_dir / 'image_class_labels.txt', dtype=np.int64)[:, 1] - 1
        is_train = np.loadtxt(meta_dir / 'train_test_split.txt', dtype=bool)[:, 1]

        # boolean row mask selecting the requested split; None keeps every row
        if split == 'train':
            keep = is_train
        elif split == 'test':
            keep = ~is_train
        else:
            keep = None

        # images.txt rows are "<image id> <relative path>"; row i pairs with is_train[i]
        with (meta_dir / 'images.txt').open() as handle:
            all_paths = [data_dir / row.strip().split(' ')[-1] for row in handle]

        if keep is None:
            self.fname_list = all_paths
            self.label_arr = all_labels
        else:
            self.fname_list = [p for i, p in enumerate(all_paths) if keep[i]]
            self.label_arr = all_labels[keep]

        self.data_dir = data_dir
        self.meta_dir = meta_dir
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index`` with optional transforms applied."""
        path = self.fname_list[index]
        target = self.label_arr[index]
        image = self.loader(path)
        image = image if self.transform is None else self.transform(image)
        target = target if self.target_transform is None else self.target_transform(target)
        return image, target

    def __len__(self):
        # len() rather than .shape: subclasses may store label_arr as a plain list
        return len(self.label_arr)


class CUBFSLDataset(CUBDataset):
    """
    Dataset for CUB few-shot learning classification

    The dataset is split into three parts: 100 train classes, 50 validation classes,
    and 50 test classes.

    The classes of a split are read from ``<split_dir>/<split>.txt`` (one class
    folder name per line; blank lines are ignored). Every file with an allowed image
    extension inside ``<data_dir>/<class>`` is a sample; class folders that do not
    exist are skipped. Files are taken in sorted order so the index -> sample
    mapping is reproducible.

    Args:
        data_dir (str): the CUB image directory
        split_dir (str): the directory containing the class split files
        split (str): train/val/trainval/test
        transform (callable, optional): applied to every loaded image
        target_transform (callable, optional): applied to every label
        loader (callable): turns an image path into an image object
    """
    def __init__(self, data_dir, split_dir='datasets/cub_split_new', split='train', transform=None, target_transform=None,
                 loader=default_loader):
        # Deliberately skip CUBDataset.__init__ (which parses the standard CUB meta
        # files) and initialise only the base Dataset; this class builds its own
        # fname_list/label_arr and reuses CUBDataset.__getitem__/__len__.
        super(CUBDataset, self).__init__()
        data_dir = Path(data_dir).expanduser()
        split_dir = Path(split_dir).expanduser()
        assert split in ['train', 'val', 'trainval', 'test'], 'split not supported'

        # load classes names
        class_file = split_dir / f'{split}.txt'
        classes = []
        class_to_idx = {}
        with class_file.open() as f:
            for line in f:
                cls = line.strip()
                if not cls:
                    # a blank line would become data_dir / '' == data_dir, i.e. a
                    # bogus class rooted at the image directory itself
                    continue
                class_to_idx[cls] = len(classes)
                classes.append(cls)

        # make dataset
        files = []
        labels = []
        for target in classes:
            d = data_dir / target
            if not d.is_dir():
                continue

            # glob order is filesystem-dependent; sort for reproducible indices
            for file in sorted(d.glob('*')):
                if has_file_allowed_extension(file.name, IMG_EXTENSIONS):
                    files.append(file)
                    labels.append(class_to_idx[target])

        self.data_dir = data_dir
        self.split_dir = split_dir
        self.split = split

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.fname_list = files
        self.label_arr = labels

        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader


class OmniglotDataset(Dataset):
    """
    Omniglot few-shot dataset stored as one folder per class.

    ``<data_dir>/<split>.txt`` lists the class folders of the split (one per line,
    relative to ``data_dir``; blank lines are ignored). Every file with an allowed
    image extension inside those folders is a sample; missing class folders are
    skipped. Files are taken in sorted order so the index -> sample mapping is
    reproducible.

    Args:
        data_dir (str): dataset root holding the split files and class folders
        split (str): train/val/trainval/test
        transform (callable, optional): applied to every loaded image
        target_transform (callable, optional): applied to every label
        loader (callable): turns an image path into an image object
    """

    def __init__(self, data_dir, split, transform=None,
                 target_transform=None, loader=default_loader):
        super(OmniglotDataset, self).__init__()
        data_dir = Path(data_dir).expanduser()
        assert split in ['train', 'val', 'trainval', 'test'], 'split not supported'

        # load classes
        classes = []
        class_to_idx = {}
        with (data_dir / f'{split}.txt').open() as f:
            for line in f:
                cls = line.strip()
                if not cls:
                    # a blank line would map to data_dir itself as a bogus class
                    continue
                class_to_idx[cls] = len(classes)
                classes.append(cls)

        # make dataset
        files = []
        labels = []
        for target in classes:
            d = data_dir / target
            if not d.is_dir():
                continue

            # glob order is filesystem-dependent; sort for reproducible indices
            for file in sorted(d.glob('*')):
                if has_file_allowed_extension(file.name, IMG_EXTENSIONS):
                    files.append(file)
                    labels.append(class_to_idx[target])

        self.data_dir = data_dir
        self.split = split

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.files = files
        self.labels = labels

        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index`` with optional transforms applied."""
        path, target = self.files[index], self.labels[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target


def load_data(file):
    """Unpickle and return the object stored in ``file``.

    ``encoding='latin1'`` tells pickle how to decode 8-bit strings written by
    Python 2 pickles.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only call this on
    trusted files.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='latin1')

class Aircraft(Dataset):
    """
    Aircraft few-shot learning dataset stored as per-class image folders.

    Expected layout: ``<data_dir>/<split>/<class folder>/<image>.png``. Every
    sub-directory of the split directory is one class; stray files next to the
    class folders are ignored. Class indices follow the sorted order of the
    class-folder paths (``np.unique``), and ``self.classes`` holds those paths.

    Args:
        data_dir (Union[str, Path]): root directory containing one folder per split
        split (str): name of the split sub-folder to load, e.g. train/val/test
        transform (callable, optional): applied to every loaded image
        target_transform (callable, optional): applied to every label
        loader (callable): turns an image path into an image object
    """

    def __init__(self, data_dir='/home/pathto/Aircraft_fewshot/', split='train', transform=None, target_transform=None, loader=default_loader):
        super(Aircraft, self).__init__()
        data_dir = os.path.join(data_dir, split)
        self.path_list = []
        raw_labels = []
        # os.listdir order is arbitrary; sort so the index -> image mapping is reproducible
        for folder in sorted(os.listdir(data_dir)):
            image_folder = os.path.join(data_dir, folder)
            if not os.path.isdir(image_folder):
                # a stray file here would make os.listdir raise NotADirectoryError
                continue
            for f in sorted(os.listdir(image_folder)):
                # suffix match, not substring: 'x.png.bak' is not an image
                if f.endswith('.png'):
                    self.path_list.append(os.path.join(image_folder, f))
                    raw_labels.append(image_folder)
        self.classes = np.unique(raw_labels)
        class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.labels = [class_to_idx[label] for label in raw_labels]
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index`` with optional transforms applied."""
        path, target = self.path_list[index], self.labels[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.labels)



class CAR(Dataset):
    """
    Stanford Car few-shot learning dataset

    The split is Proposed in [DN4](https://github.com/WenbinLee/DN4). The dataset is split
    into 130 train, 17 val and 49 test classes.

    Each split CSV has a header row followed by ``<image path>,<class>`` rows; image
    paths are relative to ``data_dir``. The ``val`` split reads ``val.csv`` when the
    split directory provides one and otherwise falls back to ``test.csv``.

    Args:
        data_dir (Union[str, Path]): the directory to CAR dataset
        split_dir (Union[str, Path]): directory holding the split CSV files
            (``train.csv``, ``test.csv`` and, optionally, ``val.csv``)
        split (str): train/val/test
        transform (callable, optional): applied to every loaded image
        target_transform (callable, optional): applied to every label
        loader (callable): turns an image path into an image object
    """

    def __init__(self, data_dir, split_dir, split='train', transform=None, target_transform=None, loader=default_loader):
        super(CAR, self).__init__()
        data_dir = Path(data_dir).expanduser()
        split_dir = Path(split_dir).expanduser()
        self.data_dir = data_dir
        self.split_dir = split_dir
        assert split in ['train', 'val', 'test'], 'split not supported'
        self.split = split

        if split == 'train':
            split_name = 'train.csv'
        elif split == 'val' and (split_dir / 'val.csv').exists():
            # validating on test.csv would leak the test classes into model selection
            split_name = 'val.csv'
        else:
            # test split, or a split dir without val.csv (previous behaviour)
            split_name = 'test.csv'
        split_file = self.split_dir / split_name

        self.fname_list = []
        raw_labels = []
        with split_file.open(newline='') as f:
            reader = csv.reader(f, delimiter=',')
            next(reader, None)  # skip the header row
            for row in reader:
                if not row:
                    continue  # tolerate blank lines
                img_name, img_class = row
                self.fname_list.append(img_name)
                raw_labels.append(img_class)

        self.classes = np.unique(raw_labels)
        class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.labels = [class_to_idx[label] for label in raw_labels]
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index`` with optional transforms applied."""
        path, target = self.data_dir / self.fname_list[index], self.labels[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.labels)


class CARvis(Dataset):
    """
    Stanford Car few-shot dataset variant whose items also carry their index.

    Split proposed in [DN4](https://github.com/WenbinLee/DN4): 130 train, 17 val and
    49 test classes. Each split CSV (``train.csv`` / ``val.csv`` / ``test.csv`` in
    ``split_dir``) has a header row followed by ``<image path>,<class>`` rows, with
    image paths relative to ``data_dir``.

    Args:
        data_dir (Union[str, Path]): the directory to CAR dataset
        split_dir (Union[str, Path]): directory holding the split CSV files
        split (str): train/val/test
        transform (callable, optional): applied to every loaded image
        target_transform (callable, optional): applied to every label
        loader (callable): turns an image path into an image object
    """

    def __init__(self, data_dir, split_dir, split='train', transform=None, target_transform=None, loader=default_loader):
        super(CARvis, self).__init__()
        self.data_dir = Path(data_dir).expanduser()
        self.split_dir = Path(split_dir).expanduser()
        assert split in ['train', 'val', 'test'], 'split not supported'
        self.split = split

        csv_name = {'train': 'train.csv', 'val': 'val.csv'}.get(split, 'test.csv')
        with (self.split_dir / csv_name).open() as handle:
            records = list(csv.reader(handle, delimiter=','))[1:]  # drop header row
        pairs = [(name, cls) for name, cls in records]
        self.fname_list = [name for name, _ in pairs]
        raw_labels = [cls for _, cls in pairs]

        # class index = position of the class name in sorted order
        self.classes = np.unique(raw_labels)
        position = {cls: i for i, cls in enumerate(self.classes)}
        self.labels = [position[cls] for cls in raw_labels]

        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label, index)`` for sample ``index``."""
        path = self.data_dir / self.fname_list[index]
        target = self.labels[index]
        image = self.loader(path)
        image = image if self.transform is None else self.transform(image)
        target = target if self.target_transform is None else self.target_transform(target)
        return image, target, index

    def __len__(self):
        return len(self.labels)

class CARvisualization(Dataset):
    """
    Unlabelled image-folder dataset used for visualisation.

    Every regular file directly inside ``data_dir`` is one sample, taken in sorted
    name order; all samples share a single dummy class (raw label ``1``, mapped to
    index ``0``). ``__getitem__`` returns ``(image, index)`` — presumably so
    outputs can be matched back to ``fname_list[index]``.

    Args:
        data_dir (Union[str, Path]): directory whose files are the images to load
        split_dir (Union[str, Path]): stored as ``self.split_dir``; not otherwise used here
        transform (callable, optional): applied to every loaded image
        target_transform (callable, optional): applied to the dummy label
        loader (callable): turns an image path into an image object
    """

    def __init__(self, data_dir, split_dir, transform=None, target_transform=None, loader=default_loader):
        super(CARvisualization, self).__init__()
        data_dir = Path(data_dir).expanduser()
        split_dir = Path(split_dir).expanduser()
        self.data_dir = data_dir
        self.split_dir = split_dir

        # os.listdir order is arbitrary and the returned index is the only handle
        # back to the file, so sort; sub-directories are not images, skip them.
        self.fname_list = sorted(name for name in os.listdir(data_dir)
                                 if (data_dir / name).is_file())
        raw_labels = [1] * len(self.fname_list)

        self.classes = np.unique(raw_labels)
        class_to_idx = {}
        for i, cls in enumerate(self.classes):
            class_to_idx[cls] = i
            print(i, 'real label is', cls)
        self.labels = [class_to_idx[label] for label in raw_labels]
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, index)`` for sample ``index``; the dummy label is not returned."""
        path, target = self.data_dir / self.fname_list[index], self.labels[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, index

    def __len__(self):
        return len(self.labels)

class DOG(Dataset):
    """
    Stanford Dogs few-shot learning dataset

    The split is Proposed in [DN4](https://github.com/WenbinLee/DN4). The dataset is split
    into 70 train, 20 val and 30 test classes.

    Each split CSV has a header row followed by ``<image file name>,<class folder>``
    rows; images are read from ``<data_dir>/<class folder>/<image file name>``. The
    ``val`` split reads ``val.csv`` when the split directory provides one and
    otherwise falls back to ``test.csv``.

    Args:
        data_dir (Union[str, Path]): the directory to the Stanford Dogs images
        split_dir (Union[str, Path]): directory holding the split CSV files
            (``train.csv``, ``test.csv`` and, optionally, ``val.csv``)
        split (str): train/val/test
        transform (callable, optional): applied to every loaded image
        target_transform (callable, optional): applied to every label
        loader (callable): turns an image path into an image object
    """

    def __init__(self, data_dir, split_dir, split='train', transform=None, target_transform=None, loader=default_loader):
        super(DOG, self).__init__()
        data_dir = Path(data_dir).expanduser()
        split_dir = Path(split_dir).expanduser()
        self.data_dir = data_dir
        self.split_dir = split_dir
        assert split in ['train', 'val', 'test'], 'split not supported'
        self.split = split

        if split == 'train':
            split_name = 'train.csv'
        elif split == 'val' and (split_dir / 'val.csv').exists():
            # validating on test.csv would leak the test classes into model selection
            split_name = 'val.csv'
        else:
            # test split, or a split dir without val.csv (previous behaviour)
            split_name = 'test.csv'
        split_file = self.split_dir / split_name

        self.fname_list = []
        self.raw_labels = []
        with split_file.open(newline='') as f:
            reader = csv.reader(f, delimiter=',')
            next(reader, None)  # skip the header row
            for row in reader:
                if not row:
                    continue  # tolerate blank lines
                img_name, img_class = row
                self.fname_list.append(img_name)
                self.raw_labels.append(img_class)

        self.classes = np.unique(self.raw_labels)
        class_to_idx = {}
        for i, cls in enumerate(self.classes):
            class_to_idx[cls] = i
            print(cls, i)
        self.labels = [class_to_idx[label] for label in self.raw_labels]
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index`` with optional transforms applied."""
        path = self.data_dir / self.raw_labels[index] / self.fname_list[index]
        target = self.labels[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.labels)


class DOGvis(Dataset):
    """
    Stanford Dogs dataset variant used for visualisation/debugging.

    Always reads ``<split_dir>/test_new.csv`` (header row, then
    ``<image file name>,<class folder>`` rows); ``split`` is stored but not used to
    pick the file. The class mapping is printed on construction and the path and
    raw label of every fetched sample are printed in ``__getitem__``.

    Args:
        data_dir (Union[str, Path]): the directory to the Stanford Dogs images
        split_dir (Union[str, Path]): directory holding ``test_new.csv``
        split (str): stored as ``self.split``
        transform (callable, optional): applied to every loaded image
        target_transform (callable, optional): applied to every label
        loader (callable): turns an image path into an image object
    """

    def __init__(self, data_dir, split_dir, split='train', transform=None, target_transform=None, loader=default_loader):
        super(DOGvis, self).__init__()
        self.data_dir = Path(data_dir).expanduser()
        self.split_dir = Path(split_dir).expanduser()
        self.split = split

        with (self.split_dir / 'test_new.csv').open() as handle:
            records = list(csv.reader(handle, delimiter=','))[1:]  # drop header row
        self.fname_list = []
        self.raw_labels = []
        for name, cls in records:
            self.fname_list.append(name)
            self.raw_labels.append(cls)

        self.classes = np.unique(self.raw_labels)
        lookup = {}
        for position, cls in enumerate(self.classes):
            lookup[cls] = position
            print(position, 'real label is', cls)
        self.labels = [lookup[cls] for cls in self.raw_labels]

        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``, logging its path and raw label."""
        raw_label = self.raw_labels[index]
        path = self.data_dir / f'{raw_label}/{self.fname_list[index]}'
        target = self.labels[index]
        print('the data is in the path', path, 'and the label is', raw_label)
        image = self.loader(path)
        image = image if self.transform is None else self.transform(image)
        target = target if self.target_transform is None else self.target_transform(target)
        return image, target

    def __len__(self):
        return len(self.labels)


class EpisodicSampler(Sampler):
    """
    Batch sampler that yields few-shot episodes as flat index tensors.

    Every episode draws ``classes_per_it`` distinct classes uniformly at random and,
    for each, ``num_sup_per_cls + num_query_per_cls`` distinct samples. The yielded
    1-D ``LongTensor`` holds all support indices (shuffled) followed by all query
    indices (shuffled), suitable for ``DataLoader(batch_sampler=...)``.

    Labels may be arbitrary integers; they need not start at 0 or be contiguous.

    Args:
        labels: one integer class label per dataset item
        classes_per_it (int): classes (ways) per episode
        num_sup_per_cls (int): support samples (shots) per class
        num_query_per_cls (int): query samples per class
        num_iter (int): episodes per pass; also ``len(sampler)``

    Raises:
        ValueError: while iterating, if ``classes_per_it`` exceeds the number of
            distinct labels, or a drawn class has fewer than
            ``num_sup_per_cls + num_query_per_cls`` samples.
    """

    def __init__(self, labels, classes_per_it, num_sup_per_cls, num_query_per_cls, num_iter):
        super(EpisodicSampler, self).__init__(None)
        self.labels = labels
        self.classes_per_it = classes_per_it
        self.num_sup_per_cls = num_sup_per_cls
        self.num_query_per_cls = num_query_per_cls
        self.sample_per_class = num_sup_per_cls + num_query_per_cls
        self.num_iter = num_iter

        label_tensor = torch.tensor(self.labels, dtype=torch.long)
        self.classes, self.counts = torch.unique(label_tensor, sorted=True, return_counts=True)
        self.classes = self.classes.to(torch.long)

        # Row r of ``indexes`` lists the dataset indices of class ``classes[r]``.
        # Rows are addressed by position in ``classes`` (not by label value), so
        # non-contiguous labels work. int64 storage avoids the precision loss of a
        # float/NaN table; unused trailing slots hold -1 and are never read.
        row_of_label = {label: row for row, label in enumerate(self.classes.tolist())}
        max_count = int(self.counts.max()) if len(self.counts) > 0 else 0
        self.indexes = torch.full((len(self.classes), max_count), -1, dtype=torch.long)
        self.numel_per_class = torch.zeros_like(self.classes)
        for idx, label in enumerate(label_tensor.tolist()):
            row = row_of_label[label]
            self.indexes[row, self.numel_per_class[row]] = idx
            self.numel_per_class[row] += 1

    def __iter__(self):
        n_sup = self.num_sup_per_cls
        n_query = self.num_query_per_cls
        spc = self.sample_per_class
        cpi = self.classes_per_it
        n_classes = len(self.classes)

        if self.num_iter > 0 and cpi > n_classes:
            # otherwise part of each batch would be left unfilled
            raise ValueError(
                f'classes_per_it={cpi} exceeds the {n_classes} available classes')

        for _ in range(self.num_iter):
            sup_batch = torch.empty(n_sup * cpi, dtype=torch.long)
            query_batch = torch.empty(n_query * cpi, dtype=torch.long)
            rows = torch.randperm(n_classes)[:cpi]
            for i, row in enumerate(rows.tolist()):
                n_avail = int(self.numel_per_class[row])
                if n_avail < spc:
                    raise ValueError(
                        f'class {int(self.classes[row])} has {n_avail} samples, '
                        f'but {spc} (support + query) are needed per episode')
                sample_idxs = torch.randperm(n_avail)[:spc]
                class_indexes = self.indexes[row]
                sup_batch[i * n_sup:(i + 1) * n_sup] = class_indexes[sample_idxs[:n_sup]]
                query_batch[i * n_query:(i + 1) * n_query] = class_indexes[sample_idxs[n_sup:]]
            # shuffle within each part so samples are not grouped by class
            sup_batch = sup_batch[torch.randperm(len(sup_batch))]
            query_batch = query_batch[torch.randperm(len(query_batch))]
            yield torch.cat([sup_batch, query_batch])

    def __len__(self):
        return self.num_iter


class CUBDataModule(LightningDataModule):
    """
    Lightning data module for standard (non-episodic) CUB classification.

    CUB-200-2011 only provides a train/test split (``CUBDataset`` accepts
    ``'train'``, ``'test'`` and ``'all'``), so the test split also serves as the
    validation set.

    Args:
        data_dir: CUB image directory, forwarded to ``CUBDataset``
        meta_dir: CUB meta-data directory, forwarded to ``CUBDataset``
        batch_size (int): mini-batch size used by every dataloader
        train_transform: image transform for the training split
        test_transform: image transform for the validation and test splits
    """

    def __init__(self, data_dir, meta_dir, batch_size, train_transform,
                 test_transform):
        super(CUBDataModule, self).__init__()
        self.data_dir = Path(data_dir).expanduser()
        self.meta_dir = Path(meta_dir).expanduser()
        self.train_transform = train_transform
        self.test_transform = test_transform
        self.batch_size = batch_size

    def prepare_data(self):
        # check the meta files that CUBDataset reads
        assert (self.meta_dir / 'classes.txt').exists()
        assert (self.meta_dir / 'image_class_labels.txt').exists()
        assert (self.meta_dir / 'train_test_split.txt').exists()
        assert (self.meta_dir / 'images.txt').exists()

    def setup(self, stage: Optional[str] = None):
        if stage == 'fit' or stage is None:
            self.cub_train = CUBDataset(self.data_dir, self.meta_dir, 'train',
                                        self.train_transform)

        # built for 'validate' too, otherwise val_dataloader has no dataset
        if stage in ('fit', 'validate') or stage is None:
            # 'val' is rejected by CUBDataset; validate on the official test split
            self.cub_val = CUBDataset(self.data_dir, self.meta_dir, 'test',
                                      self.test_transform)

        if stage == 'test':
            self.cub_test = CUBDataset(self.data_dir, self.meta_dir, 'test',
                                       self.test_transform)

    def train_dataloader(self):
        return DataLoader(self.cub_train, self.batch_size, shuffle=True, num_workers=4,
                          pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.cub_test, self.batch_size, shuffle=False, num_workers=4,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.cub_val, self.batch_size, shuffle=False, num_workers=4,
                          pin_memory=True)


class FSLDataModuleBase(LightningDataModule):
    """
    Shared configuration holder for the few-shot data modules.

    Stores the episode settings for training (``*_tr``) and for evaluation
    (``*_val``), the number of episodes per stage, the training mode and the image
    transforms. Subclasses build the datasets and samplers.

    Args:
        iterations_tr / iterations_val / iterations_test (int): episodes per stage
        classes_per_it_tr, num_support_tr, num_query_tr (int): training ways/shots/queries
        classes_per_it_val, num_support_val, num_query_val (int): evaluation ways/shots/queries
        train_mode (str): ``'batch'`` or ``'episode'``
        train_transform: image transform for training data
        test_transform: image transform for evaluation data
    """

    def __init__(self, iterations_tr, iterations_val, iterations_test, classes_per_it_tr,
                 num_support_tr, num_query_tr, classes_per_it_val, num_support_val, num_query_val,
                 train_mode, train_transform, test_transform):
        super(FSLDataModuleBase, self).__init__()
        assert train_mode in ['batch', 'episode'], 'train_mode not supported'
        self.train_mode = train_mode

        # number of episodes per stage
        self.iterations_tr = iterations_tr
        self.iterations_val = iterations_val
        self.iterations_test = iterations_test

        # training episodes: ways / shots / queries
        self.classes_per_it_tr = classes_per_it_tr
        self.num_support_tr = num_support_tr
        self.num_query_tr = num_query_tr

        # evaluation episodes: ways / shots / queries
        self.classes_per_it_val = classes_per_it_val
        self.num_support_val = num_support_val
        self.num_query_val = num_query_val

        self.train_transform = train_transform
        self.test_transform = test_transform

class CUBFSLDataModule(FSLDataModuleBase):
    """
    Lightning data module for CUB few-shot learning.

    Builds ``CUBFSLDataset`` splits and pairs each with an ``EpisodicSampler``.
    Training runs episodically, or as ordinary shuffled mini-batches of
    ``batch_size`` when ``train_mode == 'batch'``. Test episodes reuse the
    ``*_val`` ways/shots/queries with ``iterations_test`` episodes.
    """

    def __init__(self, data_dir, split_dir, iterations_tr, iterations_val, iterations_test,
                 classes_per_it_tr, num_support_tr, num_query_tr,
                 classes_per_it_val, num_support_val, num_query_val, train_mode,
                 train_transform, test_transform, batch_size=64):
        super(CUBFSLDataModule, self).__init__(
            iterations_tr, iterations_val, iterations_test,
            classes_per_it_tr, num_support_tr, num_query_tr,
            classes_per_it_val, num_support_val, num_query_val, train_mode,
            train_transform, test_transform)
        self.data_dir = Path(data_dir).expanduser()
        self.split_dir = split_dir
        self.batch_size = batch_size

    def setup(self, stage: Optional[str] = None):
        if stage == 'fit' or stage is None:
            self.cub_train = CUBFSLDataset(self.data_dir, self.split_dir, 'train',
                                           self.train_transform)
            self.sampler_train = EpisodicSampler(
                self.cub_train.label_arr, self.classes_per_it_tr, self.num_support_tr,
                self.num_query_tr, self.iterations_tr)

        # built for 'validate' too, otherwise val_dataloader has nothing to load
        if stage in ('fit', 'validate') or stage is None:
            self.cub_val = CUBFSLDataset(self.data_dir, self.split_dir, 'val',
                                         self.test_transform)
            self.sampler_val = EpisodicSampler(
                self.cub_val.label_arr, self.classes_per_it_val, self.num_support_val,
                self.num_query_val, self.iterations_val)

        if stage == 'test':
            self.cub_test = CUBFSLDataset(self.data_dir, self.split_dir, 'test',
                                          self.test_transform)
            self.sampler_test = EpisodicSampler(
                self.cub_test.label_arr, self.classes_per_it_val, self.num_support_val,
                self.num_query_val, self.iterations_test)

    def train_dataloader(self):
        if self.train_mode == 'batch':
            return DataLoader(self.cub_train, self.batch_size, shuffle=True, num_workers=4,
                              pin_memory=True)
        return DataLoader(self.cub_train, batch_sampler=self.sampler_train,
                          num_workers=4, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.cub_val, batch_sampler=self.sampler_val,
                          num_workers=4, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.cub_test, batch_sampler=self.sampler_test,
                          num_workers=4, pin_memory=True)


class CARFSLDataModule(FSLDataModuleBase):
    """
    Lightning data module for Stanford Car few-shot learning.

    Builds ``CAR`` splits and pairs each with an ``EpisodicSampler``. Training runs
    episodically, or as ordinary shuffled mini-batches of ``batch_size`` when
    ``train_mode == 'batch'``. Test episodes reuse the ``*_val``
    ways/shots/queries with ``iterations_test`` episodes.
    """

    def __init__(self, data_dir, split_dir, iterations_tr, iterations_val, iterations_test,
                 classes_per_it_tr, num_support_tr, num_query_tr,
                 classes_per_it_val, num_support_val, num_query_val, train_mode,
                 train_transform, test_transform, batch_size=64):
        super(CARFSLDataModule, self).__init__(
            iterations_tr, iterations_val, iterations_test,
            classes_per_it_tr, num_support_tr, num_query_tr,
            classes_per_it_val, num_support_val, num_query_val, train_mode,
            train_transform, test_transform
        )
        self.data_dir = Path(data_dir).expanduser()
        self.batch_size = batch_size
        self.split_dir = split_dir

    def setup(self, stage: Optional[str] = None):
        if stage == 'fit' or stage is None:
            self.CAR_train = CAR(self.data_dir, self.split_dir, 'train', self.train_transform)
            self.sampler_train = EpisodicSampler(
                self.CAR_train.labels, self.classes_per_it_tr, self.num_support_tr,
                self.num_query_tr, self.iterations_tr)

        # built for 'validate' too, otherwise val_dataloader has nothing to load
        if stage in ('fit', 'validate') or stage is None:
            self.CAR_val = CAR(self.data_dir, self.split_dir, 'val', self.test_transform)
            self.sampler_val = EpisodicSampler(
                self.CAR_val.labels, self.classes_per_it_val, self.num_support_val,
                self.num_query_val, self.iterations_val)

        if stage == 'test':
            self.CAR_test = CAR(self.data_dir, self.split_dir, 'test', self.test_transform)
            self.sampler_test = EpisodicSampler(
                self.CAR_test.labels, self.classes_per_it_val, self.num_support_val,
                self.num_query_val, self.iterations_test)

    def train_dataloader(self):
        if self.train_mode == 'batch':
            return DataLoader(self.CAR_train, self.batch_size, shuffle=True, num_workers=4,
                              pin_memory=True)
        return DataLoader(self.CAR_train, batch_sampler=self.sampler_train,
                          num_workers=4, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.CAR_val, batch_sampler=self.sampler_val,
                          num_workers=4, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.CAR_test, batch_sampler=self.sampler_test,
                          num_workers=4, pin_memory=True)

class CARFSLDataModule_vis(FSLDataModuleBase):
    """
    Stanford Car few-shot data module built on ``CARvis`` (items carry their index).

    Same episode configuration as the regular Car module: training runs
    episodically or as shuffled mini-batches (``train_mode == 'batch'``), and test
    episodes reuse the ``*_val`` ways/shots/queries with ``iterations_test`` episodes.
    """

    def __init__(self, data_dir, split_dir, iterations_tr, iterations_val, iterations_test,
                 classes_per_it_tr, num_support_tr, num_query_tr,
                 classes_per_it_val, num_support_val, num_query_val, train_mode,
                 train_transform, test_transform, batch_size=64):
        super(CARFSLDataModule_vis, self).__init__(
            iterations_tr, iterations_val, iterations_test,
            classes_per_it_tr, num_support_tr, num_query_tr,
            classes_per_it_val, num_support_val, num_query_val, train_mode,
            train_transform, test_transform
        )
        self.data_dir = Path(data_dir).expanduser()
        self.batch_size = batch_size
        self.split_dir = split_dir

    def setup(self, stage: Optional[str] = None):
        if stage == 'fit' or stage is None:
            self.CAR_train = CARvis(self.data_dir, self.split_dir, 'train', self.train_transform)
            self.sampler_train = EpisodicSampler(
                self.CAR_train.labels, self.classes_per_it_tr, self.num_support_tr,
                self.num_query_tr, self.iterations_tr)

        # built for 'validate' too, otherwise val_dataloader has nothing to load
        if stage in ('fit', 'validate') or stage is None:
            self.CAR_val = CARvis(self.data_dir, self.split_dir, 'val', self.test_transform)
            # support count is num_support_val (was mistakenly num_query_val)
            self.sampler_val = EpisodicSampler(
                self.CAR_val.labels, self.classes_per_it_val, self.num_support_val,
                self.num_query_val, self.iterations_val)

        if stage == 'test':
            self.CAR_test = CARvis(self.data_dir, self.split_dir, 'test', self.test_transform)
            self.sampler_test = EpisodicSampler(
                self.CAR_test.labels, self.classes_per_it_val, self.num_support_val,
                self.num_query_val, self.iterations_test)

    def train_dataloader(self):
        if self.train_mode == 'batch':
            return DataLoader(self.CAR_train, self.batch_size, shuffle=True, num_workers=4,
                              pin_memory=True)
        return DataLoader(self.CAR_train, batch_sampler=self.sampler_train,
                          num_workers=4, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.CAR_val, batch_sampler=self.sampler_val,
                          num_workers=4, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.CAR_test, batch_sampler=self.sampler_test,
                          num_workers=4, pin_memory=True)

class DOGFSLDataModule(FSLDataModuleBase):
    """Few-shot data module backed by the ``DOG`` dataset.

    ``setup`` builds the train/val splits for stage ``'fit'`` (or ``None``) and
    the test split for stage ``'test'``. Each split gets an ``EpisodicSampler``.
    The training loader uses shuffled mini-batches instead when
    ``train_mode == 'batch'``.

    NOTE(review): other stage values (e.g. ``'validate'``) build nothing here.
    """

    def __init__(self, data_dir, split_dir, iterations_tr, iterations_val, iterations_test,
                 classes_per_it_tr, num_support_tr, num_query_tr,
                 classes_per_it_val, num_support_val, num_query_val, train_mode,
                 train_transform, test_transform, batch_size=64):
        super().__init__(
            iterations_tr, iterations_val, iterations_test,
            classes_per_it_tr, num_support_tr, num_query_tr,
            classes_per_it_val, num_support_val, num_query_val, train_mode,
            train_transform, test_transform,
        )
        # Expand "~" so user-relative dataset paths work.
        self.data_dir = Path(data_dir).expanduser()
        self.batch_size = batch_size
        self.split_dir = split_dir

    def _eval_episode_sampler(self, dataset, iterations):
        """Build an episodic sampler for an evaluation split using the *_val settings."""
        return EpisodicSampler(
            dataset.labels, self.classes_per_it_val, self.num_support_val,
            self.num_query_val, iterations)

    def _episodic_loader(self, dataset, sampler):
        """Wrap *dataset* in a DataLoader whose batches are drawn from *sampler*."""
        return DataLoader(dataset, batch_sampler=sampler, num_workers=4, pin_memory=True)

    def setup(self, stage: Optional[str] = None):
        """Instantiate the datasets and samplers needed for *stage*."""
        if stage in ('fit', None):
            self.DOG_train = DOG(self.data_dir, self.split_dir, 'train', self.train_transform)
            self.sampler_train = EpisodicSampler(
                self.DOG_train.labels, self.classes_per_it_tr, self.num_support_tr,
                self.num_query_tr, self.iterations_tr)
            self.DOG_val = DOG(self.data_dir, self.split_dir, 'val', self.test_transform)
            self.sampler_val = self._eval_episode_sampler(self.DOG_val, self.iterations_val)

        if stage == 'test':
            self.DOG_test = DOG(self.data_dir, self.split_dir, 'test', self.test_transform)
            self.sampler_test = self._eval_episode_sampler(self.DOG_test, self.iterations_test)

    def train_dataloader(self):
        """Episodic loader by default; shuffled mini-batches in ``'batch'`` mode."""
        if self.train_mode != 'batch':
            return self._episodic_loader(self.DOG_train, self.sampler_train)
        return DataLoader(self.DOG_train, batch_size=self.batch_size, shuffle=True,
                          num_workers=4, pin_memory=True)

    def val_dataloader(self):
        """Episodic loader over the val split."""
        return self._episodic_loader(self.DOG_val, self.sampler_val)

    def test_dataloader(self):
        """Episodic loader over the test split."""
        return self._episodic_loader(self.DOG_test, self.sampler_test)

class AircraftFSLDataModule(FSLDataModuleBase):
    """Few-shot data module backed by the ``Aircraft`` dataset.

    Train/val splits are built for stage ``'fit'`` (or ``None``) and the test
    split for stage ``'test'``, each with an ``EpisodicSampler``.

    The dataset attributes are named ``CAR_train``/``CAR_val``/``CAR_test``,
    the same names the car module uses. They are kept for compatibility with
    any code reading them. ``split_dir`` is stored but is not passed to
    ``Aircraft``.
    """

    def __init__(self, data_dir, split_dir, iterations_tr, iterations_val, iterations_test,
                 classes_per_it_tr, num_support_tr, num_query_tr,
                 classes_per_it_val, num_support_val, num_query_val, train_mode,
                 train_transform, test_transform, batch_size=64):
        super().__init__(
            iterations_tr, iterations_val, iterations_test,
            classes_per_it_tr, num_support_tr, num_query_tr,
            classes_per_it_val, num_support_val, num_query_val, train_mode,
            train_transform, test_transform,
        )
        self.data_dir = Path(data_dir).expanduser()
        self.batch_size = batch_size
        self.split_dir = split_dir

    def _val_style_sampler(self, labels, iterations):
        """Episodic sampler for an evaluation split, configured by the *_val fields."""
        return EpisodicSampler(
            labels, self.classes_per_it_val, self.num_support_val,
            self.num_query_val, iterations)

    def setup(self, stage: Optional[str] = None):
        """Create the Aircraft datasets and samplers required by *stage*."""
        if stage in ('fit', None):
            self.CAR_train = Aircraft(self.data_dir, 'train', self.train_transform)
            self.sampler_train = EpisodicSampler(
                self.CAR_train.labels, self.classes_per_it_tr, self.num_support_tr,
                self.num_query_tr, self.iterations_tr)
            self.CAR_val = Aircraft(self.data_dir, 'val', self.test_transform)
            self.sampler_val = self._val_style_sampler(self.CAR_val.labels, self.iterations_val)

        if stage == 'test':
            self.CAR_test = Aircraft(self.data_dir, 'test', self.test_transform)
            self.sampler_test = self._val_style_sampler(self.CAR_test.labels,
                                                        self.iterations_test)

    def train_dataloader(self):
        """Shuffled mini-batches in ``'batch'`` mode, otherwise episodic batches."""
        opts = {'num_workers': 4, 'pin_memory': True}
        if self.train_mode == 'batch':
            return DataLoader(self.CAR_train, batch_size=self.batch_size, shuffle=True, **opts)
        return DataLoader(self.CAR_train, batch_sampler=self.sampler_train, **opts)

    def val_dataloader(self):
        """Episodic loader over the val split."""
        return DataLoader(self.CAR_val, batch_sampler=self.sampler_val,
                          num_workers=4, pin_memory=True)

    def test_dataloader(self):
        """Episodic loader over the test split."""
        return DataLoader(self.CAR_test, batch_sampler=self.sampler_test,
                          num_workers=4, pin_memory=True)

class DOGFSLDataModulevis(FSLDataModuleBase):
    """Few-shot data module backed by the ``DOGvis`` dataset.

    Mirrors the non-vis DOG module. Stage ``'fit'`` (or ``None``) prepares the
    train and val splits, and stage ``'test'`` prepares the test split. Every
    split is paired with an ``EpisodicSampler``.
    """

    def __init__(self, data_dir, split_dir, iterations_tr, iterations_val, iterations_test,
                 classes_per_it_tr, num_support_tr, num_query_tr,
                 classes_per_it_val, num_support_val, num_query_val, train_mode,
                 train_transform, test_transform, batch_size=64):
        super().__init__(
            iterations_tr, iterations_val, iterations_test,
            classes_per_it_tr, num_support_tr, num_query_tr,
            classes_per_it_val, num_support_val, num_query_val, train_mode,
            train_transform, test_transform,
        )
        self.data_dir = Path(data_dir).expanduser()
        self.batch_size = batch_size
        self.split_dir = split_dir

    def setup(self, stage: Optional[str] = None):
        """Build the DOGvis datasets and episodic samplers for *stage*."""
        wants_fit = stage == 'fit' or stage is None
        if wants_fit:
            train_set = DOGvis(self.data_dir, self.split_dir, 'train', self.train_transform)
            self.DOG_train = train_set
            self.sampler_train = EpisodicSampler(
                train_set.labels, self.classes_per_it_tr, self.num_support_tr,
                self.num_query_tr, self.iterations_tr)

            val_set = DOGvis(self.data_dir, self.split_dir, 'val', self.test_transform)
            self.DOG_val = val_set
            self.sampler_val = EpisodicSampler(
                val_set.labels, self.classes_per_it_val, self.num_support_val,
                self.num_query_val, self.iterations_val)

        if stage == 'test':
            test_set = DOGvis(self.data_dir, self.split_dir, 'test', self.test_transform)
            self.DOG_test = test_set
            self.sampler_test = EpisodicSampler(
                test_set.labels, self.classes_per_it_val, self.num_support_val,
                self.num_query_val, self.iterations_test)

    def train_dataloader(self):
        """Training loader: shuffled mini-batches in ``'batch'`` mode, else episodic."""
        if self.train_mode == 'batch':
            loader = DataLoader(self.DOG_train, batch_size=self.batch_size, shuffle=True,
                                num_workers=4, pin_memory=True)
        else:
            loader = DataLoader(self.DOG_train, batch_sampler=self.sampler_train,
                                num_workers=4, pin_memory=True)
        return loader

    def val_dataloader(self):
        """Episodic loader over the val split."""
        loader = DataLoader(self.DOG_val, batch_sampler=self.sampler_val,
                            num_workers=4, pin_memory=True)
        return loader

    def test_dataloader(self):
        """Episodic loader over the test split."""
        loader = DataLoader(self.DOG_test, batch_sampler=self.sampler_test,
                            num_workers=4, pin_memory=True)
        return loader



