import os
from torchvision import transforms
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from tqdm import tqdm
import cv2 as cv
from PIL import Image
from PIL import ImageFile
import torch

# Relax Pillow's safety limits for this module:
#  - no truncated-image error: truncated files are loaded anyway instead of
#    raising "OSError: image file is truncated";
#  - no pixel cap: very large images never raise DecompressionBombError.
# NOTE(review): both relaxations are risky if images can come from untrusted
# sources (memory exhaustion / silently corrupt inputs) — confirm intended.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

# Per-channel RGB mean / std of ImageNet. `normalize` is applied after
# ToTensor() in AVADataset's transform pipelines.
IMAGE_NET_MEAN = [0.485, 0.456, 0.406]
IMAGE_NET_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=IMAGE_NET_MEAN, std=IMAGE_NET_STD)


class AVADataset(Dataset):
    """Image dataset described by a CSV of image file names and scores.

    The CSV must contain an ``image`` column (file name resolved against
    ``images_path``) and a ``score`` column. Each item is
    ``(image_tensor, target)``, where ``target`` is a float32 array of shape
    ``(1,)`` holding ``score / 10``.
    """

    def __init__(self, path_to_csv, images_path, if_train):
        """
        Args:
            path_to_csv: CSV file, read with ``pandas.read_csv``.
            images_path: Directory the ``image`` column is joined onto.
            if_train: When truthy, resize to 256x256 then apply a random
                horizontal flip and a random 224x224 crop; otherwise resize
                directly to 224x224.
        """
        self.df = pd.read_csv(path_to_csv)
        self.images_path = images_path
        self.if_train = if_train

        if not if_train:
            pipeline = [transforms.Resize((224, 224))]
        else:
            pipeline = [
                transforms.Resize((256, 256)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop((224, 224)),
            ]
        # Both modes end with tensor conversion and ImageNet normalization.
        pipeline += [transforms.ToTensor(), normalize]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        record = self.df.iloc[item]
        target = np.array([record['score'] / 10])
        file_name = record['image']
        full_path = os.path.join(self.images_path, f'{file_name}')
        tensor = self.transform(default_loader(full_path))
        return tensor, target.astype('float32')

class BBDataset1(Dataset):
    """Folder of integer-named images, resized (default 224x224) for inference.

    Every entry of ``file_dir`` must be an image whose file-name stem is an
    integer (e.g. ``17.jpg``). Entries are ordered by that integer rather than
    lexicographically, so ``2.jpg`` precedes ``10.jpg``. Each item is
    ``(image_tensor, path)``. Only resizing and ``ToTensor`` are applied (no
    normalization).
    """

    def __init__(self, file_dir, size=(224, 224)):
        """
        Args:
            file_dir: Directory containing the images.
            size: Target size passed to ``transforms.Resize``; defaults to the
                original fixed (224, 224).

        Raises:
            ValueError: if an entry's file-name stem is not an integer.
        """
        self.test_transformer = transforms.Compose(
            [
                transforms.Resize(size),
                transforms.ToTensor(),
            ]
        )
        self.pic_paths = self._sorted_paths(file_dir)

    @staticmethod
    def _sorted_paths(file_dir):
        """Return the full paths of ``file_dir``'s entries, ordered by integer stem."""
        # splitext handles any extension length (.jpg, .jpeg, none), unlike
        # stripping a fixed 4 characters; listdir already yields bare names.
        names = sorted(os.listdir(file_dir), key=lambda name: int(os.path.splitext(name)[0]))
        return [os.path.join(file_dir, name) for name in names]

    def __len__(self):
        return len(self.pic_paths)

    def __getitem__(self, index):
        pic_path = self.pic_paths[index]
        # The context manager closes the file handle deterministically. convert()
        # returns a new, fully loaded image, so it stays valid after close.
        with Image.open(pic_path) as raw:
            img = raw.convert('RGB')
        return self.test_transformer(img), pic_path

class BBDataset2(Dataset):
    """Folder of integer-named images, resized (default 256x256) for inference.

    Every entry of ``file_dir`` must be an image whose file-name stem is an
    integer (e.g. ``17.jpg``). Entries are ordered by that integer rather than
    lexicographically, so ``2.jpg`` precedes ``10.jpg``. Each item is
    ``(image_tensor, path)``. Only resizing and ``ToTensor`` are applied (no
    normalization).
    """

    def __init__(self, file_dir, size=(256, 256)):
        """
        Args:
            file_dir: Directory containing the images.
            size: Target size passed to ``transforms.Resize``; defaults to the
                original fixed (256, 256).

        Raises:
            ValueError: if an entry's file-name stem is not an integer.
        """
        self.test_transformer = transforms.Compose(
            [
                transforms.Resize(size),
                transforms.ToTensor(),
            ]
        )
        self.pic_paths = self._sorted_paths(file_dir)

    @staticmethod
    def _sorted_paths(file_dir):
        """Return the full paths of ``file_dir``'s entries, ordered by integer stem."""
        # splitext handles any extension length (.jpg, .jpeg, none), unlike
        # stripping a fixed 4 characters; listdir already yields bare names.
        names = sorted(os.listdir(file_dir), key=lambda name: int(os.path.splitext(name)[0]))
        return [os.path.join(file_dir, name) for name in names]

    def __len__(self):
        return len(self.pic_paths)

    def __getitem__(self, index):
        pic_path = self.pic_paths[index]
        # The context manager closes the file handle deterministically. convert()
        # returns a new, fully loaded image, so it stays valid after close.
        with Image.open(pic_path) as raw:
            img = raw.convert('RGB')
        return self.test_transformer(img), pic_path


import pandas as pd
from torch.utils.data import Dataset
from PIL import Image


class CustomImageDataset(Dataset):
    """Images listed in a CSV file, each paired with a numeric label.

    The first CSV column holds the image file name (joined onto ``root_dir``)
    and the second column holds the label. Each item is
    ``(image_path, transformed_image, label / 10.0)``.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations (anything
                ``pandas.read_csv`` accepts).
            root_dir (string): Directory with all the images.
            transform (callable, optional): Transform applied to each RGB PIL
                image. When None, images are resized to 224x224 and converted
                to a tensor.
        """
        self.image_labels = pd.read_csv(csv_file)
        self.root_dir = root_dir
        # Previously a caller-supplied transform was silently discarded; honor it
        # and fall back to the original default pipeline only when none is given.
        if transform is None:
            transform = transforms.Compose(
                [
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                ]
            )
        self.transform = transform

    def __len__(self):
        return len(self.image_labels)

    def __getitem__(self, idx):
        # str(): pandas parses purely numeric file names as ints, which
        # os.path.join would reject.
        img_name = os.path.join(self.root_dir, str(self.image_labels.iloc[idx, 0]))
        # Convert to RGB (as the other datasets in this module do) so grayscale,
        # palette or RGBA files all yield 3-channel images; `with` closes the file.
        with Image.open(img_name) as raw:
            image = raw.convert('RGB')
        label = self.image_labels.iloc[idx, 1]

        image = self.transform(image)

        return img_name, image, label / 10.0
