# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

import os
import PIL
import numpy as np
import pandas as pd
from PIL import Image, ImageEnhance

from torchvision import datasets, transforms
from torch.utils.data import Dataset

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def build_dataset(is_train, args, data_path=None):
    """Build a ``torchvision.datasets.ImageFolder`` for the train or eval split.

    Args:
        is_train: when true use the training split and training transform,
            otherwise the evaluation split/transform.
        args: namespace providing ``is_gray``, ``data_path``, ``ds_type`` and
            the options read by ``build_transform``.
        data_path: optional explicit image folder. When given it is used as
            the dataset root as-is and grayscale conversion is disabled.

    Returns:
        The constructed ``ImageFolder`` (also printed).

    Raises:
        ValueError: if ``data_path`` is None and no root can be derived from
            ``args`` (``args.data_path`` is not ``'../data/imagenet'`` and
            ``args.ds_type`` is falsy). Previously this surfaced as an
            ``UnboundLocalError`` on ``root``.
    """
    # An explicit data_path disables the grayscale conversion.
    is_gray = args.is_gray if data_path is None else False
    transform = build_transform(is_train, args, is_gray)

    if data_path is not None:
        root = data_path
    elif args.data_path == '../data/imagenet':
        # The ImageNet layout uses 'val' as its evaluation split.
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
    elif args.ds_type:
        root = os.path.join(args.data_path, 'train' if is_train else 'test')
    else:
        raise ValueError(
            "Cannot determine dataset root: pass data_path, use "
            "args.data_path == '../data/imagenet', or set args.ds_type "
            f"(got args.data_path={args.data_path!r}, args.ds_type={args.ds_type!r})"
        )

    dataset = datasets.ImageFolder(root, transform=transform)

    print(dataset)

    return dataset


def build_transform(is_train, args, is_gray=False):
    """Return the image transform pipeline for training or evaluation.

    Training delegates to timm's ``create_transform`` with bicubic
    interpolation plus the colour-jitter, auto-augment and random-erasing
    settings taken from ``args``. Evaluation resizes (keeping a 224/256 crop
    ratio for inputs of at most 224 px, no margin otherwise), center-crops to
    ``args.input_size``, converts to a tensor and normalizes with the ImageNet
    mean/std.

    When ``is_gray`` is true, a 3-channel ``Grayscale`` step is appended at
    the end of either pipeline.
    """
    norm_mean, norm_std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    if is_train:
        # This should always dispatch to timm's transforms_imagenet_train.
        train_tf = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=norm_mean,
            std=norm_std,
        )
        if not is_gray:
            return train_tf
        # NOTE(review): Grayscale runs on the already-normalized tensor, so it
        # mixes normalized channels. The eval path below does the same.
        return transforms.Compose([train_tf, transforms.Compose([transforms.Grayscale(3)])])

    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    # Resize the short side so the center crop keeps the same ratio as for 224 px images.
    resize_to = int(args.input_size / crop_pct)
    steps = [
        transforms.Resize(resize_to, interpolation=PIL.Image.BICUBIC),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]
    if is_gray:
        steps.append(transforms.Grayscale(3))
    return transforms.Compose(steps)


class CustomDataset(Dataset):
    """Image-classification dataset laid out as ``root/<class_dir>/<image>``.

    Class directories are enumerated in sorted name order and numbered
    0, 1, 2, ... (the same convention as ``torchvision.datasets.ImageFolder``).
    Labels are therefore stable across machines instead of depending on the
    filesystem's ``os.listdir`` order.

    Args:
        root: directory whose sub-directories are the classes.
        transform: optional callable applied to each loaded PIL image.
        labels_include: optional container of class ids to keep. Ids are
            assigned over *all* class directories before filtering, so kept
            samples retain their global id.
        color_type: ``'RGB'`` or ``'L'`` converts to that mode. Any other value
            loads RGB and boosts color saturation by a factor of 10.
    """

    _IMAGE_EXTENSIONS = (".png", ".jpg", ".JPEG", ".jpeg")

    def __init__(self, root, transform=None, labels_include=None, color_type='RGB'):
        print(root)
        self.root = root
        self.transform = transform
        self.image_paths = []
        self.labels = []

        self.color_type = color_type

        label_mapping = {}  # class-directory name -> integer label
        label_count = 0

        # Sorted so label ids (and labels_include) are reproducible.
        for label in sorted(os.listdir(self.root)):
            label_path = os.path.join(self.root, label)
            if not os.path.isdir(label_path):
                continue
            label_mapping[label] = label_count
            label_count += 1
            if labels_include is not None and label_mapping[label] not in labels_include:
                continue
            print(label)
            for image_name in sorted(os.listdir(label_path)):
                if image_name.endswith(self._IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(label_path, image_name))
                    self.labels.append(label_mapping[label])
            print(len(self.labels))

    def __getitem__(self, index):
        """Return ``(image, label)``; ``image`` is transformed if a transform is set."""
        image_path = self.image_paths[index]
        label = self.labels[index]
        # Context manager releases the underlying file handle promptly.
        with Image.open(image_path) as raw:
            if self.color_type in ['RGB', 'L']:
                image = raw.convert(self.color_type)
            else:
                image = ImageEnhance.Color(raw.convert('RGB')).enhance(10)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.image_paths)
    

class CustomDataset_longtail(Dataset):
    """Folder dataset restricted to a precomputed per-class subset of file names.

    ``root`` holds one sub-directory per class. Only files whose name appears
    in the sample dictionary entry for that class are kept.

    Args:
        root: directory whose sub-directories are the classes.
        transform: optional callable applied to each loaded PIL image.
        sample_dict_path: ``.npy`` file holding a pickled object. ``.item()``
            of it is indexed by ``reserve_idx``, and the result is indexed by
            class position. Each entry exposes ``.values`` listing the file
            names to keep (presumably a pandas Series/DataFrame — confirm).
        reserve_idx: which top-level entry of the sample dictionary to use.
        color_type: PIL mode each image is converted to.
    """

    def __init__(self, root, transform=None, sample_dict_path='imagenet10_sample_dict_all.npy', reserve_idx=0, color_type='RGB'):
        print(root)
        self.root = root
        self.transform = transform
        self.image_paths = []
        self.labels = []

        self.color_type = color_type

        # allow_pickle is needed because the file stores a Python dict.
        # NOTE(review): unpickling can execute code — only load trusted files.
        self.sample_dict = np.load(sample_dict_path, allow_pickle=True).item()[reserve_idx]

        label_mapping = {}  # class-directory name -> integer label
        label_count = 0

        # NOTE(review): os.listdir order is filesystem-dependent, and
        # sample_dict is indexed by each class directory's position in this
        # listing. Left unsorted to stay aligned with however the sample dict
        # was produced — sort here if it was built from sorted class names.
        for label in os.listdir(self.root):
            label_path = os.path.join(self.root, label)
            if not os.path.isdir(label_path):
                continue
            label_mapping[label] = label_count
            label_count += 1
            print(label)
            # Build the allowed-name set once per class. Membership is then
            # O(1) instead of a full array scan per file.
            keep = set(np.ravel(self.sample_dict[label_count - 1].values).tolist())
            for image_name in os.listdir(label_path):
                if image_name in keep:
                    self.image_paths.append(os.path.join(label_path, image_name))
                    self.labels.append(label_mapping[label])
            print(len(self.labels))

    def __getitem__(self, index):
        """Return ``(image, label)``; ``image`` is transformed if a transform is set."""
        image_path = self.image_paths[index]
        label = self.labels[index]
        # Context manager releases the underlying file handle promptly.
        with Image.open(image_path) as raw:
            image = raw.convert(self.color_type)

        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.image_paths)


class CustomDataset_selected(Dataset):
    """Folder dataset that keeps at most ``max_per_class`` images per class.

    ``root`` holds one sub-directory per class. Class directories and their
    files are enumerated in sorted order, so label ids and the selected subset
    are deterministic.

    Args:
        root: directory whose sub-directories are the classes.
        transform: optional callable applied to each loaded PIL image.
        labels_include: optional container of class ids to keep. Ids are
            assigned over *all* class directories before filtering.
        color_type: ``'RGB'`` or ``'L'`` converts to that mode. Any other value
            loads RGB and boosts color saturation by a factor of 10.
        max_per_class: cap on images taken per class (default 1000, as
            before). ``None`` disables the cap.
    """

    _IMAGE_EXTENSIONS = (".png", ".jpg", ".JPEG", ".jpeg")

    def __init__(self, root, transform=None, labels_include=None, color_type='RGB', max_per_class=1000):
        self.root = root
        self.transform = transform
        self.image_paths = []
        self.labels = []

        self.color_type = color_type

        label_mapping = {}  # class-directory name -> integer label
        label_count = 0

        for label in sorted(os.listdir(self.root)):
            label_path = os.path.join(self.root, label)
            if not os.path.isdir(label_path):
                continue
            label_mapping[label] = label_count
            label_count += 1
            print(label)
            if labels_include is not None and label_mapping[label] not in labels_include:
                continue
            # Per-class counter. The old fixed np.zeros(11) array raised
            # IndexError once there were more than 10 classes.
            taken = 0
            for image_name in sorted(os.listdir(label_path)):
                if not image_name.endswith(self._IMAGE_EXTENSIONS):
                    continue
                if max_per_class is not None and taken >= max_per_class:
                    break
                taken += 1
                self.image_paths.append(os.path.join(label_path, image_name))
                self.labels.append(label_mapping[label])

    def __getitem__(self, index):
        """Return ``(image, label)``; ``image`` is transformed if a transform is set."""
        image_path = self.image_paths[index]
        label = self.labels[index]
        # Context manager releases the underlying file handle promptly.
        with Image.open(image_path) as raw:
            if self.color_type in ['RGB', 'L']:
                image = raw.convert(self.color_type)
            else:
                image = ImageEnhance.Color(raw.convert('RGB')).enhance(10)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.image_paths)
    

class CustomDataset_noclass(Dataset):
    """Folder dataset that keeps at most ``max_per_class`` images per class.

    ``root`` holds one sub-directory per class. Class directories and their
    files are enumerated in sorted order, so label ids and the selected subset
    are deterministic.

    Args:
        root: directory whose sub-directories are the classes.
        transform: optional callable applied to each loaded PIL image.
        labels_include: optional container of class ids to keep. Ids are
            assigned over *all* class directories before filtering.
        color_type: ``'RGB'`` or ``'L'`` converts to that mode. Any other value
            loads RGB and boosts color saturation by a factor of 10.
        max_per_class: cap on images taken per class (default 1000, as
            before). ``None`` disables the cap.
    """

    _IMAGE_EXTENSIONS = (".png", ".jpg", ".JPEG", ".jpeg")

    def __init__(self, root, transform=None, labels_include=None, color_type='RGB', max_per_class=1000):
        self.root = root
        self.transform = transform
        self.image_paths = []
        self.labels = []

        self.color_type = color_type

        label_mapping = {}  # class-directory name -> integer label
        label_count = 0

        for label in sorted(os.listdir(self.root)):
            label_path = os.path.join(self.root, label)
            if not os.path.isdir(label_path):
                continue
            label_mapping[label] = label_count
            label_count += 1
            print(label)
            if labels_include is not None and label_mapping[label] not in labels_include:
                continue
            # Per-class counter. The old fixed np.zeros(11) array raised
            # IndexError once there were more than 10 classes.
            taken = 0
            for image_name in sorted(os.listdir(label_path)):
                if not image_name.endswith(self._IMAGE_EXTENSIONS):
                    continue
                if max_per_class is not None and taken >= max_per_class:
                    break
                taken += 1
                self.image_paths.append(os.path.join(label_path, image_name))
                self.labels.append(label_mapping[label])

    def __getitem__(self, index):
        """Return ``(image, label)``; ``image`` is transformed if a transform is set."""
        image_path = self.image_paths[index]
        label = self.labels[index]
        # Context manager releases the underlying file handle promptly.
        with Image.open(image_path) as raw:
            if self.color_type in ['RGB', 'L']:
                image = raw.convert(self.color_type)
            else:
                image = ImageEnhance.Color(raw.convert('RGB')).enhance(10)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.image_paths)
    

class CustomDataset_ImgList(Dataset):
    """Dataset driven by a text image list with one ``<relative_path> <label>`` per line.

    Args:
        root: directory that the relative image paths are joined onto.
        file_path: path to the image-list file. Tokens are whitespace-separated:
            the first is the image path and the second an integer label.
            Blank lines are skipped.
        transform: optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root, file_path='', transform=None):
        self.root = root
        self.transform = transform
        self.image_paths = []
        self.labels = []
        with open(file_path, 'r') as f:
            for line in f:
                # split() with no argument tolerates trailing newlines, tabs and
                # repeated spaces. The old split(' ') crashed on blank lines and
                # on double spaces (int('')).
                tokens = line.split()
                if not tokens:
                    continue
                self.image_paths.append(tokens[0])
                self.labels.append(int(tokens[1]))

        print('Count: ', len(self.image_paths))

    def __getitem__(self, index):
        """Return ``(image, label)``; ``image`` is transformed if a transform is set."""
        image_path = os.path.join(self.root, self.image_paths[index])
        label = self.labels[index]
        # Context manager releases the underlying file handle promptly.
        with Image.open(image_path) as raw:
            image = raw.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.image_paths)
    

if __name__ == '__main__':
    # Manual smoke test: load an OpenOOD image list and report how many
    # entries it contains. The paths are specific to the author's machine.
    images_root = '/data4/jiangy/OpenOOD-main/data/images_classic'
    imglist_file = '/data4/jiangy/OpenOOD-main/data/benchmark_imglist/cifar10/test_cifar100.txt'
    dataset = CustomDataset_ImgList(images_root, imglist_file)
    print(len(dataset))