import os

import pytorch_lightning as pl
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
import pickle
from tqdm import tqdm

class AttributeDataset(Dataset):
    """
    PyTorch dataset that loads images and attributes from a directory tree.

    The root directory should contain one subdirectory for each category.
    Each category subdirectory should contain one subdirectory for each attribute.
    Each attribute subdirectory (or its ``prefix`` subdirectory, see below)
    should contain one or more ``.jpg`` / ``.png`` image files, possibly nested
    in further subdirectories.

    The list of discovered images is cached as a pickle file in the current
    working directory. The cache file name is derived only from the last two
    components of ``root_dir``; it does not encode ``categories`` or
    ``prefix``, so delete the file when either of those changes.

    Args:
        root_dir (str): The root directory of the dataset.
        categories (list of str): The list of category names, in the order in which they should be indexed.
        prefix (str, optional): Subdirectory appended below every attribute
            directory before scanning for images. ``None`` scans the attribute
            directory itself.
        transform (callable, optional): A function/transform to apply to the images.

    Example:
        Suppose we have the following directory structure:
        ```
        root_dir/
            category1/
                attribute1/
                    image1.jpg
                    image2.jpg
                    ...
                attribute2/
                    image1.jpg
                    image2.jpg
                    ...
                ...
            category2/
                attribute1/
                    image1.jpg
                    image2.jpg
                    ...
                attribute2/
                    image1.jpg
                    image2.jpg
                    ...
                ...
            ...
        ```

        To create a dataset object from this directory structure, we can use:
        ```
        dataset = AttributeDataset(root_dir='root_dir',
                                    categories=['category1', 'category2'],
                                    prefix=None,
                                    transform=T.ToTensor())
        ```
        This will create a dataset object that indexes the categories in the order ['category1', 'category2'] and applies the ToTensor transform to each image.
    """

    def __init__(self, root_dir, categories, prefix='0/', transform=None):
        """
        Initializes the dataset object and loads the images and attributes.

        The image list is read from the label cache when it exists; otherwise
        the directory tree is scanned and the cache is written.

        Args:
            root_dir (str): The root directory of the dataset.
            categories (list of str): The list of category names, in the order in which they should be indexed.
            prefix (str, optional): Subdirectory joined below each attribute
                directory before it is scanned. Pass ``None`` to scan the
                attribute directory directly.
            transform (callable, optional): A function/transform to apply to the images.

        Raises:
            OSError: If a category directory cannot be listed or the cache
                file cannot be read or written.
        """
        self.root_dir = root_dir
        self.categories = categories
        self.transform = transform
        self.category_to_index = {cat: i for i, cat in enumerate(categories)}
        self.y = []
        self.images = []
        self.attributes = []

        file_name = self._label_cache_name(root_dir)

        # check if the dataset file exists
        if os.path.exists(file_name):
            print('-'*20)
            print('Loading the labels from', file_name)
            self.load_label(file_name)
        else:
            self._scan_directory(prefix)

            # save the dataset to a file
            with open(file_name, 'wb') as f:
                dataset = {'images': self.images, 'y': self.y, 'attributes': self.attributes}
                pickle.dump(dataset, f)

        print('-'*20)
        print('Total image number:', len(self.images))
        print("Number of attribute:", len(set(self.attributes)))
        print('-'*20)

    @staticmethod
    def _label_cache_name(root_dir):
        """Return the cache file name built from the last two components of ``root_dir``."""
        # Normalise the platform separator so the split also works when the
        # path was produced by os.path.join on a non-'/' platform.
        parts = os.fspath(root_dir).replace(os.sep, '/').split('/')
        # A root without any separator has no parent component; use an empty
        # one instead of raising IndexError.
        parent = parts[-2] if len(parts) > 1 else ''
        return '{}_{}_label.pkl'.format(parent, parts[-1])

    @staticmethod
    def _relative_attribute(category_dir, subdir):
        """Return the part of ``subdir`` that follows ``category_dir`` and one separator."""
        # Every walked directory starts with category_dir, so slicing by its
        # length is exact. Searching for the category name instead would match
        # the first occurrence anywhere in the path (e.g. inside root_dir).
        return subdir[len(category_dir) + 1:]

    def _scan_directory(self, prefix):
        """Walk the directory tree and fill ``images``, ``y`` and ``attributes``."""
        for category in tqdm(self.categories):
            category_dir = os.path.join(self.root_dir, category)
            for attribute in os.listdir(category_dir):
                if prefix is not None:
                    attribute_dir = os.path.join(
                        category_dir, attribute, prefix)
                else:
                    attribute_dir = os.path.join(category_dir, attribute)
                if not os.path.isdir(attribute_dir):
                    continue
                for subdir, _dirs, files in os.walk(attribute_dir):
                    for image_file in files:
                        # only .jpg / .png files count as images
                        if not image_file.endswith(('.jpg', '.png')):
                            continue
                        self.images.append(os.path.join(subdir, image_file))
                        # map the category to its index
                        self.y.append(self.category_to_index[category])
                        self.attributes.append(
                            self._relative_attribute(category_dir, subdir))

    def load_label(self, file_name):
        """
        Loads the image paths, category indices and attributes from a cache file.

        Args:
            file_name (str): Path of a pickle file holding a dict with the
                keys 'images', 'y' and 'attributes'.
        """
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # cache files that were written by this class.
        with open(file_name, 'rb') as f:
            dataset = pickle.load(f)
            self.images = dataset['images']
            self.y = dataset['y']
            self.attributes = dataset['attributes']
            print("Dataset loaded from file.")

    def __len__(self):
        """
        Returns the number of images in the dataset.

        Returns:
            int: The number of images in the dataset.
        """
        return len(self.images)

    def __getitem__(self, idx):
        """
        Returns the image, category, and attribute at the given index.

        Args:
            idx (int): The index of the image/category/attribute tuple to return.

        Returns:
            tuple: A tuple of (image, category, attribute). image is an RGB
            PIL Image, or the result of ``transform`` when one was given;
            category is an integer index; attribute is a string.
        """
        image_path = self.images[idx]
        category_idx = self.y[idx]
        attribute = self.attributes[idx]
        # convert() loads the pixel data and returns a new image, so the
        # underlying file can be closed as soon as the block exits.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            image = self.transform(image)
        return image, category_idx, attribute


class AttributeCIFAR10DataModule(pl.LightningDataModule):
    """
    Lightning data module that serves AttributeDataset loaders for the
    'train' and 'test' subdirectories of ``args.support_data_dir``.

    Args:
        args: Object exposing ``support_data_dir``, ``batch_size`` and
            ``num_workers`` attributes.
        prefix (str, optional): Forwarded unchanged to AttributeDataset.
    """

    def __init__(self, args, prefix='0/'):
        super().__init__()
        self.prefix = prefix
        # NOTE(review): some Lightning releases expose `hparams` as a
        # read-only property; confirm this assignment works on the pinned version.
        self.hparams = args
        self.root_dir = self.hparams.support_data_dir
        self.categories = self.get_cat()
        # Per-channel statistics handed to T.Normalize below.
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2471, 0.2435, 0.2616)
        pipeline = [
            T.Resize((32, 32)),
            T.ToTensor(),
            T.Normalize(self.mean, self.std),
        ]
        self.transform = T.Compose(pipeline)

    def get_cat(self):
        """Returns the category names in index order."""
        return [
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        ]

    def _split_dataset(self, split):
        """Builds the AttributeDataset for one subdirectory of the root."""
        split_dir = os.path.join(self.root_dir, split)
        return AttributeDataset(
            split_dir,
            self.categories,
            prefix=self.prefix,
            transform=self.transform)

    def train_dataloader(self):
        """Returns a shuffling DataLoader over the 'train' subdirectory."""
        train_set = self._split_dataset('train')
        return DataLoader(
            train_set,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=True)

    def val_dataloader(self):
        """Returns a DataLoader over the 'test' subdirectory (no shuffle requested)."""
        held_out_set = self._split_dataset('test')
        return DataLoader(
            held_out_set,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=True)

    def test_dataloader(self):
        """Returns the same loader configuration as val_dataloader."""
        return self.val_dataloader()


class AttributeCIFAR100DataModule(AttributeCIFAR10DataModule):
    """
    Variant of AttributeCIFAR10DataModule with its own normalization
    statistics and a 100-entry category list.

    Args:
        args: Forwarded to the parent constructor.
        prefix (str, optional): Forwarded to the parent constructor.
    """

    def __init__(self, args, prefix='0/'):
        super().__init__(args, prefix)
        self.mean = (0.5070, 0.4865, 0.4409)
        self.std = (0.2673, 0.2564, 0.2761)
        # The parent constructor has already built `self.transform` from its
        # own mean/std, so the pipeline must be rebuilt here; otherwise the
        # statistics assigned above would never reach T.Normalize.
        self.transform = T.Compose(
            [
                T.Resize((32, 32)),
                T.ToTensor(),
                T.Normalize(self.mean, self.std),
            ]
        )

    def get_cat(self):
        """Returns the category names in index order (as a tuple)."""
        return ("apple", "aquarium fish", "baby", "bear", "beaver",
                "bed", "bee", "beetle", "bicycle", "bottle", "bowl",
                "boy", "bridge", "bus", "butterfly", "camel", "can",
                "castle", "caterpillar", "cattle", "chair", "chimpanzee",
                "clock", "cloud", "cockroach", "couch", "crab",
                "crocodile", "cup", "dinosaur", "dolphin", "elephant",
                "flatfish", "forest", "fox", "girl", "hamster", "house",
                "kangaroo", "keyboard", "lamp", "lawn mower", "leopard",
                "lion", "lizard", "lobster", "man", "maple tree",
                "motorcycle", "mountain", "mouse", "mushroom",
                "oak tree", "orange", "orchid", "otter",
                "palm tree", "pear", "pickup truck",
                "pine tree", "plain", "plate", "poppy",
                "porcupine", "possum", "rabbit", "raccoon",
                "ray", "road", "rocket", "rose", "sea", "seal",
                "shark", "shrew", "skunk", "skyscraper", "snail",
                "snake", "spider", "squirrel", "streetcar", "sunflower",
                "sweet pepper", "table", "tank", "telephone", "television",
                "tiger", "tractor", "train", "trout", "tulip", "turtle",
                "wardrobe", "whale", "willow tree", "wolf", "woman", "worm")

class AttributeImageNetDataModule(AttributeCIFAR10DataModule):
    """
    Variant of AttributeCIFAR10DataModule that reads its category list from
    a text file and uses a resize-256 / center-crop-224 transform.

    Args:
        args: Forwarded to the parent constructor.
        prefix (str, optional): Forwarded to the parent constructor.
    """

    def __init__(self, args, prefix='0/'):
        super().__init__(args, prefix)
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        # Replace the transform built by the parent with this class's pipeline.
        steps = [
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(self.mean, self.std),
        ]
        self.transform = T.Compose(steps)

    def get_cat(self, class_file=None):
        """
        Reads the category names, one per line, from a text file.

        Args:
            class_file (str, optional): Path of the class list. When None,
                "imagenet/imagenet-classes.txt" is opened as given.

        Returns:
            list of str: The lines of the file without line terminators.
        """
        path = "imagenet/imagenet-classes.txt" if class_file is None else class_file
        with open(path) as handle:
            return handle.read().splitlines()