from pathlib import Path
import json
import torch
from typing import List, Dict, Union, Optional
from collections import OrderedDict


class DatasetMapper:
    """Map labels and class names for ImageNet (optionally a class subset) or CIFAR-10.

    For an ImageNet subset, a class-list file (one class id per line) selects
    the classes; ``label_mapping`` maps each selected original index to its
    position in that file. For full ImageNet and for CIFAR-10 the mapping is
    the identity.
    """

    # Class-list paths are relative, so they are resolved against the current
    # working directory (not this module's location). Subset keys are matched
    # case-insensitively by ``_subset_class_file``.
    DATASET_CONFIGS = {
        'imagenet': {
            'type': 'imagenet',
            'subsets': {
                'nette': './misc/class_nette.txt',
                'woof': './misc/class_woof.txt',
                'woofre': './misc/class_woof_re.txt',
                'imagenet-A': './misc/imagenet-a.txt',
                'imagenet-B': './misc/imagenet-b.txt',
                'imagenet-C': './misc/imagenet-c.txt',
                'imagenet-D': './misc/imagenet-d.txt',
                'imagenet-E': './misc/imagenet-e.txt',
                'imagenet100': './misc/class100.txt',
                'imagenet1k': None,
                'class0': './misc/imagenet_class0.txt',
                'class1': './misc/imagenet_class1.txt',
                'class2': './misc/imagenet_class2.txt',
                'class3': './misc/imagenet_class3.txt',
                'class4': './misc/imagenet_class4.txt',
                'class5': './misc/imagenet_class5.txt',
                'class6': './misc/imagenet_class6.txt',
                'class7': './misc/imagenet_class7.txt',
                'class8': './misc/imagenet_class8.txt',
                'class9': './misc/imagenet_class9.txt',
                'class9_re': './misc/imagenet_class9_re.txt',
                'class_idc': './misc/class_idc.txt',
            },
        },
        'cifar10': {'type': 'cifar10', 'class_file': './misc/cifar10_class_names.txt'},
    }

    def __init__(self, dataset_name: str, subset_name: Optional[str] = None):
        """Validate the dataset name and build the label/class-name mappings.

        Args:
            dataset_name: ``'imagenet'`` or ``'cifar10'`` (case-insensitive).
            subset_name: Optional ImageNet subset key from ``DATASET_CONFIGS``
                (case-insensitive). ``None``, ``''`` or ``'imagenet1k'``
                selects all ImageNet classes. Ignored for CIFAR-10.

        Raises:
            ValueError: If the dataset or the ImageNet subset is unsupported.
        """
        self.dataset_name = dataset_name.lower()
        self.subset_name = subset_name.lower() if subset_name else None
        if self.dataset_name not in self.DATASET_CONFIGS:
            raise ValueError(f'Unsupported dataset: {dataset_name}')
        self._initialize_mappings()

    def _initialize_mappings(self):
        """Dispatch to the dataset-specific mapping initializer."""
        if self.dataset_name == 'imagenet':
            self._init_imagenet_mappings()
        elif self.dataset_name == 'cifar10':
            self._init_cifar10_mappings()

    @staticmethod
    def _read_class_file(path: str) -> List[str]:
        """Return the stripped, non-empty lines of a class-list file."""
        with open(path, 'r', encoding='utf-8') as f:
            # Blank lines (e.g. an extra trailing newline) would otherwise become
            # bogus class ids / class names.
            return [line.strip() for line in f if line.strip()]

    @classmethod
    def _subset_class_file(cls, subset_name: str) -> str:
        """Return the class-list path for an ImageNet subset, matching keys case-insensitively.

        Raises:
            ValueError: If the subset is unknown or has no class file.
        """
        # __init__ lowercases subset_name, so mixed-case keys such as
        # 'imagenet-A' must be compared in lowercase to be reachable.
        subsets = {key.lower(): path for key, path in cls.DATASET_CONFIGS['imagenet']['subsets'].items()}
        class_file = subsets.get(subset_name.lower())
        if not class_file:
            raise ValueError(f'Unsupported ImageNet subset: {subset_name}')
        return class_file

    def _init_imagenet_mappings(self):
        """Populate mappings for full ImageNet-1k or for the selected class subset."""
        from .imagenet_class_names import IMAGENET2012_CLASSES
        self.full_classes = IMAGENET2012_CLASSES
        if not self.subset_name or self.subset_name == 'imagenet1k':
            self.selected_indices = list(range(1000))
            self.class_names = list(self.full_classes.values())
            self.label_mapping = {i: i for i in range(1000)}
        else:
            class_ids = self._read_class_file(self._subset_class_file(self.subset_name))
            self.class_names = [self.full_classes[class_id] for class_id in class_ids]
            # Original indices follow sorted class-id order. NOTE(review): the
            # full-dataset branch above uses the dict's insertion order instead;
            # the two agree only if IMAGENET2012_CLASSES is stored sorted by
            # class id — TODO confirm.
            id_to_idx = {class_id: idx for idx, class_id in enumerate(sorted(self.full_classes.keys()))}
            self.selected_indices = [id_to_idx[class_id] for class_id in class_ids]
            # Original index -> position within the subset file.
            self.label_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(self.selected_indices)}

    def _init_cifar10_mappings(self):
        """Load CIFAR-10 class names; labels map to themselves."""
        self.class_names = self._read_class_file(self.DATASET_CONFIGS['cifar10']['class_file'])
        self.label_mapping = {i: i for i in range(10)}
        self.selected_indices = list(range(10))

    def convert_labels(self, labels: Union[torch.Tensor, List[int]]) -> Union[torch.Tensor, List[int]]:
        """Map original dataset labels into this mapper's label space.

        Args:
            labels: A tensor of any shape (including 0-d) or a list of labels.

        Returns:
            The mapped labels in the same container type. A tensor result keeps
            the input's shape, dtype and device.

        Raises:
            KeyError: If a label is not in ``label_mapping`` (e.g. a class
                outside the selected subset).
        """
        try:
            if isinstance(labels, torch.Tensor):
                # A single tolist() avoids a per-element .item() (one device sync
                # each); flattening makes 0-d and N-d tensors work as well.
                mapped = [self.label_mapping[label] for label in labels.reshape(-1).tolist()]
                return torch.tensor(mapped, device=labels.device, dtype=labels.dtype).reshape(labels.shape)
            return [self.label_mapping[label] for label in labels]
        except KeyError as e:
            raise KeyError(f"Invalid label found. Make sure your labels match the {self.dataset_name} {('subset ' + self.subset_name if self.subset_name else '')} format. Error on label: {e}") from e

    def get_all_class_names(self) -> List[str]:
        """Return the list of class names in mapped-label order.

        Raises:
            ValueError: If the class names were never initialized.
        """
        if not hasattr(self, 'class_names'):
            raise ValueError('class_names not initialized')
        return self.class_names

    def get_class_name(self, idx: int, simplified: bool = True) -> str:
        """Return the class name for mapped label ``idx``.

        Args:
            idx: Mapped label in ``[0, get_num_classes())``.
            simplified: For ImageNet, keep only the text before the first comma.

        Raises:
            ValueError: If ``idx`` is out of range (negative indices included).
        """
        # Reject negative indices too; Python indexing would otherwise silently
        # return a class from the end of the list.
        if not 0 <= idx < len(self.class_names):
            raise ValueError(f'Index {idx} out of range for dataset with {len(self.class_names)} classes')
        name = self.class_names[idx]
        if self.dataset_name == 'imagenet' and simplified:
            return name.split(',')[0]
        return name

    def get_num_classes(self) -> int:
        """Return the number of classes this mapper exposes."""
        return len(self.class_names)

    def get_prompt(self, idx: int) -> str:
        """Return a text prompt of the form ``'a photo of a <class name>'``."""
        class_name = self.get_class_name(idx)
        return f'a photo of a {class_name}'