import os
import pickle
from collections import OrderedDict, defaultdict
from torch.utils.data import Dataset
from PIL import Image
import random

from .base_dataset import Datum, DatasetBase
from utils.train_util import listdir_nohidden, mkdir_if_missing
import math



class ImageNet(DatasetBase):
    """ImageNet train/val item lists with pickle caching, few-shot and class subsampling.

    The dataset is expected under ``<cfg.DATASET.ROOT>/imagenet`` and must
    contain ``classnames.txt`` plus an ``images`` directory with ``train`` and
    ``val`` sub-directories (one folder per class). Parsed item lists are cached
    in ``preprocessed.pkl``; few-shot subsets are cached per
    ``(num_shots, seed)`` under ``split_fewshot``.
    """

    dataset_dir = "imagenet"

    def __init__(self, cfg):
        """Build the train/test item lists and pass them to the base class.

        Args:
            cfg: configuration object. Reads ``DATASET.ROOT``,
                ``DATASET.NUM_SHOTS``, ``DATASET.SUBSAMPLE_CLASSES`` and, when
                ``NUM_SHOTS >= 1``, ``SEED``.
        """
        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, "images")
        self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl")
        self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
        mkdir_if_missing(self.split_fewshot_dir)

        # SECURITY NOTE: pickle.load can execute arbitrary code. The cache files
        # read below must only ever come from a trusted dataset directory.
        if os.path.exists(self.preprocessed):
            with open(self.preprocessed, "rb") as f:
                cached = pickle.load(f)
            train = cached["train"]
            test = cached["test"]
        else:
            text_file = os.path.join(self.dataset_dir, "classnames.txt")
            classnames = self.read_classnames(text_file)
            train = self.read_data(classnames, "train")
            # Follow standard practice to perform evaluation on the val set
            # Also used as the val set (so evaluate the last-step model)
            test = self.read_data(classnames, "val")

            cached = {"train": train, "test": test}
            with open(self.preprocessed, "wb") as f:
                pickle.dump(cached, f, protocol=pickle.HIGHEST_PROTOCOL)

        num_shots = cfg.DATASET.NUM_SHOTS
        if num_shots >= 1:
            seed = cfg.SEED
            # Separate name for the few-shot cache path so it cannot be confused
            # with the cache dict handled above.
            fewshot_path = os.path.join(
                self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl"
            )

            if os.path.exists(fewshot_path):
                print(f"Loading preprocessed few-shot data from {fewshot_path}")
                with open(fewshot_path, "rb") as file:
                    data = pickle.load(file)
                train = data["train"]
            else:
                # generate_fewshot_dataset is not defined in this class;
                # presumably inherited from DatasetBase.
                train = self.generate_fewshot_dataset(train, num_shots=num_shots)
                data = {"train": train}
                print(f"Saving preprocessed few-shot data to {fewshot_path}")
                with open(fewshot_path, "wb") as file:
                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)

        subsample = cfg.DATASET.SUBSAMPLE_CLASSES
        train, test = ImageNet.subsample_classes(train, test, subsample=subsample)

        # The same list is used for both validation and test.
        super().__init__(train_x=train, val=test, test=test)

    @staticmethod
    def read_classnames(text_file):
        """Return a dictionary containing
        key-value pairs of <folder name>: <class name>.

        Each line is split on single spaces: the first token is the folder
        name and the remaining tokens, re-joined with spaces, are the class
        name.
        """
        classnames = OrderedDict()
        with open(text_file, "r") as f:
            for line in f:
                folder, *name_parts = line.strip().split(" ")
                classnames[folder] = " ".join(name_parts)
        return classnames

    @staticmethod
    def subsample_classes(*args, subsample="all"):
        """Divide classes into two groups. The first group
        represents base classes while the second group represents
        new classes.

        The split is computed from the labels found in the first dataset:
        the sorted labels are cut at ``ceil(n / 2)``. Kept items are re-created
        with labels renumbered from 0 in sorted order.

        Args:
            args: a list of datasets, e.g. train, val and test.
            subsample (str): what classes to subsample.

        Returns:
            ``args`` unchanged (a tuple) when ``subsample == "all"``;
            otherwise a list with one filtered, relabelled list per dataset.
        """
        assert subsample in ["all", "base", "new"]

        if subsample == "all":
            return args

        labels = sorted({item.label for item in args[0]})
        n = len(labels)
        # Divide classes into two halves
        m = math.ceil(n / 2)

        print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
        if subsample == "base":
            selected = labels[:m]  # take the first half
        else:
            selected = labels[m:]  # take the second half
        relabeler = {y: y_new for y_new, y in enumerate(selected)}

        output = []
        for dataset in args:
            # Membership is tested against the dict (O(1)) rather than the
            # ``selected`` list, which would be scanned once per item.
            dataset_new = [
                Datum(
                    impath=item.impath,
                    label=relabeler[item.label],
                    classname=item.classname,
                )
                for item in dataset
                if item.label in relabeler
            ]
            output.append(dataset_new)

        return output

    def read_data(self, classnames, split_dir):
        """Collect one ``Datum`` per image found under ``<image_dir>/<split_dir>``.

        Class folders are sorted by name and labelled 0..N-1 in that order.

        Args:
            classnames: mapping of folder name to class name.
            split_dir (str): sub-directory of ``self.image_dir``, e.g. "train".

        Returns:
            list of ``Datum`` items.

        Raises:
            KeyError: if a class folder has no entry in ``classnames``.
        """
        split_dir = os.path.join(self.image_dir, split_dir)
        with os.scandir(split_dir) as entries:
            folders = sorted(entry.name for entry in entries if entry.is_dir())
        items = []

        for label, folder in enumerate(folders):
            imnames = listdir_nohidden(os.path.join(split_dir, folder))
            classname = classnames[folder]
            for imname in imnames:
                impath = os.path.join(split_dir, folder, imname)
                items.append(Datum(impath=impath, label=label, classname=classname))

        return items


class Datum:
    """Plain record describing one image sample.

    NOTE: this definition rebinds the module-level name ``Datum`` that is also
    imported from ``.base_dataset`` at the top of the file, so code in this
    module that looks up ``Datum`` at call time gets this class.

    Attributes are set directly from the constructor arguments:
    ``impath`` (image path), ``label`` (class label) and ``classname``.
    """

    def __init__(self, impath, label, classname):
        self.impath, self.label, self.classname = impath, label, classname

class ImageNetDataset(Dataset):
    """Map-style dataset over an ImageNet folder tree.

    Reads ``<root>/imagenet/classnames.txt`` and the class folders under
    ``<root>/imagenet/images/<split>``. Indexing returns
    ``(image, label, classname)``.
    """

    def __init__(self, root, split="train", num_shots=None, seed=1, transform=None, subsample="all"):
        """Scan the split directory and optionally reduce it to a few-shot subset.

        Args:
            root: directory that contains the ``imagenet`` folder.
            split: sub-directory of ``images`` to read.
            num_shots: when not None and > 0, keep at most this many items per class.
            seed: seed for the few-shot sampling.
            transform: optional callable applied to each loaded PIL image.
            subsample: stored on the instance only (see note below).
        """
        self.dataset_dir = os.path.join(root, "imagenet")
        self.image_dir = os.path.join(self.dataset_dir, "images")
        self.transform = transform
        self.num_shots = num_shots
        self.seed = seed
        # NOTE(review): ``subsample`` is stored but nothing in this class reads
        # it, so no base/new class filtering happens here. Confirm whether
        # callers expect filtering before relying on it.
        self.subsample = subsample

        # Load the folder-name -> class-name mapping.
        text_file = os.path.join(self.dataset_dir, "classnames.txt")
        self._classnames_dict = self.read_classnames(text_file)

        # Load the items of the requested split.
        self.data = self.read_data(self._classnames_dict, split)

        # Few-shot support.
        if self.num_shots is not None and self.num_shots > 0:
            self.data = self.generate_fewshot_dataset(self.data, self.num_shots, self.seed)

        # Number of classes and the label -> class-name lookup.
        self._num_classes = self.get_num_classes(self.data)
        self._lab2cname, self._classnames = self.get_lab2cname(self.data)

    @staticmethod
    def read_classnames(text_file):
        """Return an ``OrderedDict`` mapping folder name to class name.

        Each line is split on single spaces: the first token is the folder
        name, the rest (re-joined with spaces) is the class name.
        """
        classnames = OrderedDict()
        with open(text_file, "r") as f:
            for line in f:
                folder, *name_parts = line.strip().split(" ")
                classnames[folder] = " ".join(name_parts)
        return classnames

    def read_data(self, classnames, split_dir):
        """Collect one ``Datum`` per non-hidden file under ``<image_dir>/<split_dir>``.

        Class folders are sorted by name and labelled 0..N-1 in that order.
        File names inside each folder are sorted as well: ``os.listdir``
        returns entries in arbitrary order, and the seeded few-shot sampling
        is only reproducible if the item order is stable.

        Raises:
            KeyError: if a class folder has no entry in ``classnames``.
        """
        split_dir = os.path.join(self.image_dir, split_dir)
        folders = sorted(
            name
            for name in os.listdir(split_dir)
            if os.path.isdir(os.path.join(split_dir, name)) and not name.startswith('.')
        )
        items = []
        for label, folder in enumerate(folders):
            folder_path = os.path.join(split_dir, folder)
            imnames = sorted(im for im in os.listdir(folder_path) if not im.startswith('.'))
            classname = classnames[folder]
            for imname in imnames:
                items.append(Datum(os.path.join(folder_path, imname), label, classname))
        return items

    def generate_fewshot_dataset(self, data, num_shots, seed):
        """Return at most ``num_shots`` randomly chosen items per label.

        Classes with ``num_shots`` items or fewer are kept whole. Sampling
        uses a private ``random.Random(seed)`` so the process-wide RNG state
        is left untouched; the draws are the same as seeding the global
        generator with ``seed`` would produce.
        """
        rng = random.Random(seed)
        class_to_items = defaultdict(list)
        for item in data:
            class_to_items[item.label].append(item)
        fewshot_data = []
        for items in class_to_items.values():
            if len(items) <= num_shots:
                fewshot_data.extend(items)
            else:
                fewshot_data.extend(rng.sample(items, num_shots))
        return fewshot_data

    @staticmethod
    def get_num_classes(data_source):
        """Return ``max(label) + 1`` over ``data_source``, or 0 when it is empty."""
        return max((item.label for item in data_source), default=-1) + 1

    @staticmethod
    def get_lab2cname(data_source):
        """Return ``(mapping, classnames)`` for the labels present in ``data_source``.

        ``mapping`` is a dict of label -> class name; ``classnames`` lists the
        class names ordered by ascending label.
        """
        mapping = {item.label: item.classname for item in data_source}
        classnames = [mapping[label] for label in sorted(mapping)]
        return mapping, classnames

    @property
    def num_classes(self):
        return self._num_classes

    @property
    def lab2cname(self):
        return self._lab2cname

    @property
    def classnames(self):
        return self._classnames

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label, classname)`` for the item at ``idx``.

        Loading is best-effort: on any error the failure is printed and the
        sample at index 0 is returned instead (with index 0's own label). If
        the failing index is not greater than 0, a gray 224x224 placeholder is
        returned with the failing item's label and class name.
        """
        item = self.data[idx]
        try:
            # The context manager closes the underlying file handle once the
            # pixel data has been copied by convert().
            with Image.open(item.impath) as img:
                image = img.convert("RGB")
            if self.transform:
                image = self.transform(image)
            return image, item.label, item.classname
        except Exception as e:
            print(f"Error loading image {item.impath}: {e}")
            # On error, fall back to the first image.
            if idx > 0:
                return self.__getitem__(0)
            # If the first image fails as well, return a dummy image.
            dummy_image = Image.new('RGB', (224, 224), color='gray')
            if self.transform:
                dummy_image = self.transform(dummy_image)
            return dummy_image, item.label, item.classname
