import os
import torch.nn as nn
import torch
import sys

from MTIL_datasets.caltech101 import Caltech101
from MTIL_datasets.cifar100 import CIFAR100
from MTIL_datasets.dtd import DescribableTextures as DTD
from MTIL_datasets.eurosat import EuroSAT
from MTIL_datasets.fgvc_aircraft import FGVCAircraft as Aircraft
from MTIL_datasets.food101 import Food101 as Food
from MTIL_datasets.mnist import MNIST
from MTIL_datasets.oxford_flowers import OxfordFlowers as Flowers
from MTIL_datasets.oxford_pets import OxfordPets as OxfordPet
from MTIL_datasets.stanford_cars import StanfordCars
from MTIL_datasets.sun397 import SUN397
from MTIL_datasets.ucf101 import UCF101
from MTIL_datasets.country211 import Country211
from MTIL_datasets.sst2 import SST2
from MTIL_datasets.hatefulmemes import HatefulMemes
from MTIL_datasets.gtsrb import GTSRB
from MTIL_datasets.resisc import RESISC45
from MTIL_datasets.fer2013 import FER2013
from MTIL_datasets.cifar10 import CIFAR10
from MTIL_datasets.stl10 import STL10
from MTIL_datasets.voc2007 import VOC2007
from MTIL_datasets.imagenet_r import ImageNetR
from MTIL_datasets.kitti_distance import KittiDistance
from MTIL_datasets.pcam import PCam
from MTIL_datasets.clevr_count import CLEVRCount
from MTIL_datasets.utils import DatasetWrapper


def get_dataset(cfg, split, transforms=None):
    """Build the dataset(s) for one split of the configured benchmark.

    Args:
        cfg: Config object. Reads ``use_validation``, ``dataset`` and
            ``dataset_root``; optionally ``MTIL_order_2``, ``train_one_dataset``
            (default -1 = all datasets) and ``seed`` (default 1).
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        transforms: Transform forwarded to every ``DatasetWrapper``.

    Returns:
        A 4-tuple ``(dataset, classes_names, templates, dataset_names)`` of
        parallel lists with one entry per selected sub-dataset. When ``split``
        is ``'val'`` and ``cfg.use_validation`` is falsy, all four items are
        ``None``.

    Raises:
        ValueError: If ``split`` is not recognised, ``cfg.dataset`` is not a
            supported benchmark, or ``cfg.train_one_dataset`` is an index
            outside the MTIL dataset family.
    """
    if split == 'val' and (not cfg.use_validation):
        # Same arity as the normal return so callers can always unpack 4 values.
        return None, None, None, None

    # Attribute of the base dataset object that holds the samples of each split.
    split_attr = {'train': 'train_x', 'val': 'val', 'test': 'test'}
    if split not in split_attr:
        # Rejecting early keeps the returned parallel lists aligned.
        raise ValueError(f"Unknown split '{split}'; expected one of {sorted(split_attr)}.")

    if cfg.dataset != "MTIL":
        raise ValueError(f"'{cfg.dataset}' is a invalid dataset.")

    is_train = (split == 'train')

    # Build the full MTIL dataset family (indices 0..24) in a fixed order.
    # Names reported in `dataset_names` are the class __name__ (aliases such as
    # DTD report the underlying class name), so zero-shot filtering by name works.
    all_sets = [
        Aircraft,          # 0
        Caltech101,        # 1
        CIFAR100,          # 2
        DTD,               # 3
        EuroSAT,           # 4
        Flowers,           # 5 (OxfordFlowers)
        Food,              # 6 (Food101)
        MNIST,             # 7
        OxfordPet,         # 8 (OxfordPets)
        StanfordCars,      # 9
        SUN397,            # 10
        Country211,        # 11
        SST2,              # 12
        HatefulMemes,      # 13
        GTSRB,             # 14
        RESISC45,          # 15
        FER2013,           # 16
        UCF101,            # 17
        CIFAR10,           # 18
        STL10,             # 19
        VOC2007,           # 20
        ImageNetR,         # 21
        KittiDistance,     # 22
        PCam,              # 23
        CLEVRCount,        # 24
    ]
    # Optional alternate order kept for compatibility: permute the first 11 as in
    # the legacy order_2, then append the extended datasets unchanged.
    if getattr(cfg, 'MTIL_order_2', False):
        legacy_first11 = [StanfordCars, Food, MNIST, OxfordPet, Flowers, SUN397,
                          Aircraft, Caltech101, DTD, EuroSAT, CIFAR100]
        all_sets = legacy_first11 + all_sets[11:]

    single_mode = int(getattr(cfg, 'train_one_dataset', -1))
    if single_mode >= 0:
        if single_mode >= len(all_sets):
            raise ValueError(f"MTIL dataset index {single_mode} is not available (module missing).")
        selected = [all_sets[single_mode]]
    else:
        selected = all_sets

    dataset, classes_names, templates, dataset_names = [], [], [], []
    seed = getattr(cfg, 'seed', 1)
    for base_ctor in selected:
        base = base_ctor(cfg.dataset_root, seed=seed)
        classes_names.append(base.classnames)
        # `templates` is optional on the base dataset; None when absent.
        templates.append(getattr(base, 'templates', None))
        dataset_names.append(base_ctor.__name__)
        samples = getattr(base, split_attr[split])
        dataset.append(DatasetWrapper(samples, transform=transforms, is_train=is_train))
    return dataset, classes_names, templates, dataset_names



def parse_sample(sample, is_train, task_id, cfg):
    """Split a batch into inputs, targets and a per-sample task-id tensor.

    ``is_train`` and ``cfg`` are accepted for interface compatibility but unused.

    Returns:
        ``(inputs, targets, task_ids)``, where ``task_ids`` is an int32 tensor of
        length ``inputs.size(0)`` filled with ``task_id``.
    """
    inputs, targets = sample[0], sample[1]
    batch_size = inputs.size(0)
    task_ids = torch.IntTensor([task_id]).repeat(batch_size)
    return inputs, targets, task_ids