import csv
import json
import os
import warnings

import numpy as np
import pandas as pd
import torch.utils.data as data
from PIL import Image
from scipy import io
import random

def getTIDFileName(path, suffix):
    """Return the two-character reference ids of the image files in ``path``.

    ``suffix`` is a concatenation of accepted extensions, e.g. ``".bmp.BMP"``.
    A file matches only when its extension equals one of them exactly
    (case-sensitive). Extension-less entries such as ``.DS_Store`` and
    partial extensions such as ``.bm`` are rejected. For each match, the
    characters at positions 1-2 of the file name are returned
    (``"I01.BMP"`` -> ``"01"``), in ``os.listdir`` order; callers sort.
    """
    # ".bmp.BMP" -> {".bmp", ".BMP"}. A substring test would also accept ""
    # (str.find("") == 0), letting stray files shift the sorted id list.
    extensions = {"." + ext for ext in suffix.split(".") if ext}
    return [
        name[1:3]
        for name in os.listdir(path)
        if os.path.splitext(name)[1] in extensions
    ]

def getFileName(path, suffix):
    """List the entries of ``path`` whose extension is exactly ``suffix``.

    The comparison is case-sensitive (``".png"`` does not match ``"a.PNG"``).
    The result keeps ``os.listdir`` order, so callers that need a stable
    ordering must sort it themselves.
    """
    def has_suffix(entry):
        return os.path.splitext(entry)[1] == suffix

    return list(filter(has_suffix, os.listdir(path)))

class KONIQDATASET(data.Dataset):
    """KonIQ-10k dataset (no pristine reference images).

    ``__getitem__`` returns ``(LQ, LQ, HQ_diff_content, target)``:
    - the distorted image twice; it also fills the "HQ" slot that the
      full-reference datasets in this module use for the reference;
    - an image drawn at random from ``HQ_diff_content_root``;
    - the ``MOS_zscore`` CSV column as a 0-d float32 array.

    Args:
        root: dataset root containing ``koniq10k_scores_and_distributions.csv``
            and the ``1024x768`` image folder.
        HQ_diff_content_root: folder whose png/jpg/bmp files form the random
            "HQ_diff_content" pool.
        index: CSV row indices selecting the images of this split.
        patch_num: number of times each selected image is repeated in
            ``self.samples``.
        transform: optional callable applied to the distorted image.
        HQ_diff_content_transform: optional callable applied to the pool image.

    Raises:
        ValueError: if ``HQ_diff_content_root`` contains no png/jpg/bmp file.
    """

    def __init__(self, root, HQ_diff_content_root, index, patch_num, transform=None, HQ_diff_content_transform=None):
        super(KONIQDATASET, self).__init__()

        self.HQ_diff_content_root = HQ_diff_content_root
        self.data_path = root
        imgname = []
        mos_all = []
        csv_file = os.path.join(root, "koniq10k_scores_and_distributions.csv")
        # newline="" is what the csv module expects for files it parses.
        with open(csv_file, newline="") as f:
            for row in csv.DictReader(f):
                imgname.append(row["image_name"])
                mos_all.append(np.array(float(row["MOS_zscore"])).astype(np.float32))

        sample = []
        for item in index:
            entry = (os.path.join(root, "1024x768", imgname[item]), mos_all[item])
            sample.extend([entry] * patch_num)

        # Sorted so a seeded `random` draws the same files regardless of
        # os.listdir order.
        self.HQ_diff_content_path = [
            os.path.join(HQ_diff_content_root, name)
            for name in sorted(os.listdir(HQ_diff_content_root))
            if name.endswith(("png", "jpg", "bmp"))
        ]
        if not self.HQ_diff_content_path:
            # Fail at construction instead of inside random.randint at the
            # first __getitem__.
            raise ValueError(
                "no png/jpg/bmp images found in HQ_diff_content_root %r" % HQ_diff_content_root
            )

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        Best-effort: if the file cannot be opened or decoded, print the path
        and return a 224x224 random-noise RGB image instead.
        """
        try:
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception:
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            return Image.fromarray(np.uint8(random_img))

    def __getitem__(self, index):
        """Return ``(LQ, LQ, HQ_diff_content, target)`` for sample ``index``.

        Each transform is applied only when it is not None.
        """
        LQ_path, target = self.samples[index]
        HQ_diff_content_path = random.choice(self.HQ_diff_content_path)
        LQ = self._load_image(LQ_path)
        HQ_diff_content = self._load_image(HQ_diff_content_path)
        if self.transform is not None:
            LQ = self.transform(LQ)
        if self.HQ_diff_content_transform is not None:
            HQ_diff_content = self.HQ_diff_content_transform(HQ_diff_content)
        return LQ, LQ, HQ_diff_content, target

    def __len__(self):
        return len(self.samples)


class LIVECDATASET(data.Dataset):
    """LIVE In the Wild Challenge dataset (no pristine reference images).

    ``__getitem__`` returns ``(LQ, LQ, HQ_diff_content, target)``:
    - the distorted image twice (it also fills the "HQ" slot);
    - an image drawn at random from ``HQ_diff_content_root``;
    - its float32 label from ``AllMOS_release.mat``.

    Args:
        root: dataset root with ``Data/AllImages_release.mat``,
            ``Data/AllMOS_release.mat`` and the ``Images`` folder.
        HQ_diff_content_root: folder whose png/jpg/bmp files form the random
            "HQ_diff_content" pool.
        index: positions (after dropping the first 7 entries) of the images
            of this split.
        patch_num: number of times each selected image is repeated.
        transform: optional callable applied to the distorted image.
        HQ_diff_content_transform: optional callable applied to the pool image.

    Raises:
        ValueError: if ``HQ_diff_content_root`` contains no png/jpg/bmp file.
    """

    def __init__(self, root, HQ_diff_content_root, index, patch_num, transform=None, HQ_diff_content_transform=None):
        super(LIVECDATASET, self).__init__()

        self.HQ_diff_content_root = HQ_diff_content_root

        imgpath = io.loadmat(os.path.join(root, "Data", "AllImages_release.mat"))["AllImages_release"]
        # Entries 7..1168 are kept and the first 7 are dropped (presumably the
        # challenge's separate training images -- confirm against the dataset
        # readme). Names and labels are sliced identically so they stay aligned.
        imgpath = imgpath[7:1169]
        mos = io.loadmat(os.path.join(root, "Data", "AllMOS_release.mat"))
        labels = mos["AllMOS_release"].astype(np.float32)[0][7:1169]

        sample = []
        for item in index:
            # loadmat yields nested arrays per cell; [0][0] extracts the name.
            entry = (os.path.join(root, "Images", imgpath[item][0][0]), labels[item])
            sample.extend([entry] * patch_num)

        # Sorted so a seeded `random` draws the same files regardless of
        # os.listdir order.
        self.HQ_diff_content_path = [
            os.path.join(HQ_diff_content_root, name)
            for name in sorted(os.listdir(HQ_diff_content_root))
            if name.endswith(("png", "jpg", "bmp"))
        ]
        if not self.HQ_diff_content_path:
            raise ValueError(
                "no png/jpg/bmp images found in HQ_diff_content_root %r" % HQ_diff_content_root
            )

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        Best-effort: if the file cannot be opened or decoded, print the path
        and return a 224x224 random-noise RGB image instead.
        """
        try:
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception:
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            return Image.fromarray(np.uint8(random_img))

    def __getitem__(self, index):
        """Return ``(LQ, LQ, HQ_diff_content, target)`` for sample ``index``.

        Each transform is applied only when it is not None.
        """
        LQ_path, target = self.samples[index]
        HQ_diff_content_path = random.choice(self.HQ_diff_content_path)
        LQ = self._load_image(LQ_path)
        HQ_diff_content = self._load_image(HQ_diff_content_path)
        if self.transform is not None:
            LQ = self.transform(LQ)
        if self.HQ_diff_content_transform is not None:
            HQ_diff_content = self.HQ_diff_content_transform(HQ_diff_content)
        return LQ, LQ, HQ_diff_content, target

    def __len__(self):
        return len(self.samples)


class LIVEDataset(data.Dataset):
    """LIVE (release 2) full-reference dataset.

    ``__getitem__`` returns ``(LQ, HQ, HQ_diff_content, target)``:
    - the distorted image and its reference image from ``refimgs``;
    - an image drawn at random from ``HQ_diff_content_root``;
    - the float32 ``dmos_new`` value from ``dmos_realigned.mat``.

    ``index`` selects reference images by position in the sorted list of
    ``refimgs/*.bmp`` names. Every distorted image of each selected reference
    (entries with ``orgs == 0``) is added ``patch_num`` times.

    Raises:
        ValueError: if ``HQ_diff_content_root`` contains no png/jpg/bmp file.
    """

    # (sub-folder, number of images). The concatenation order must match the
    # column order of dmos_realigned.mat / refnames_all.mat, because the same
    # position indexes all three.
    _DISTORTION_FOLDERS = (
        ("jp2k", 227),
        ("jpeg", 233),
        ("wn", 174),
        ("gblur", 174),
        ("fastfading", 174),
    )

    def __init__(self, root, HQ_diff_content_root, index, patch_num, transform=None, HQ_diff_content_transform=None):
        super(LIVEDataset, self).__init__()

        self.HQ_diff_content_root = HQ_diff_content_root

        refpath = os.path.join(root, "refimgs")
        refname = getFileName(refpath, ".bmp")

        imgpath = []
        for folder, count in self._DISTORTION_FOLDERS:
            imgpath += self.getDistortionTypeFileName(os.path.join(root, folder), count)

        dmos = io.loadmat(os.path.join(root, "dmos_realigned.mat"))
        labels = dmos["dmos_new"].astype(np.float32)
        orgs = dmos["orgs"]
        refnames_all = io.loadmat(os.path.join(root, "refnames_all.mat"))["refnames_all"]

        refname.sort()
        sample = []
        for ref_idx in index:
            # Columns whose reference name matches, excluding entries flagged
            # in `orgs` (presumably the undistorted originals).
            train_sel = refname[ref_idx] == refnames_all
            train_sel = train_sel * ~orgs.astype(np.bool_)
            for item in np.where(train_sel == True)[1].tolist():
                LQ_path = imgpath[item]
                HQ_path = os.path.join(root, "refimgs", refnames_all[0][item][0])
                label = labels[0][item]
                sample.extend([(LQ_path, HQ_path, label)] * patch_num)

        # Sorted so a seeded `random` draws the same files regardless of
        # os.listdir order.
        self.HQ_diff_content_path = [
            os.path.join(HQ_diff_content_root, name)
            for name in sorted(os.listdir(HQ_diff_content_root))
            if name.endswith(("png", "jpg", "bmp"))
        ]
        if not self.HQ_diff_content_path:
            raise ValueError(
                "no png/jpg/bmp images found in HQ_diff_content_root %r" % HQ_diff_content_root
            )

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        Best-effort: if the file cannot be opened or decoded, print the path
        and return a 224x224 random-noise RGB image instead.
        """
        try:
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception:
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            return Image.fromarray(np.uint8(random_img))

    def __getitem__(self, index):
        """Return ``(LQ, HQ, HQ_diff_content, target)`` for sample ``index``.

        ``transform`` is applied to the distorted and reference images, and
        ``HQ_diff_content_transform`` to the pool image. Each is applied
        independently, only when it is not None.
        """
        LQ_path, HQ_path, target = self.samples[index]
        HQ_diff_content_path = random.choice(self.HQ_diff_content_path)
        LQ = self._load_image(LQ_path)
        HQ = self._load_image(HQ_path)
        HQ_diff_content = self._load_image(HQ_diff_content_path)
        if self.transform is not None:
            LQ = self.transform(LQ)
            HQ = self.transform(HQ)
        if self.HQ_diff_content_transform is not None:
            HQ_diff_content = self.HQ_diff_content_transform(HQ_diff_content)
        return LQ, HQ, HQ_diff_content, target

    def __len__(self):
        return len(self.samples)

    def getDistortionTypeFileName(self, path, num):
        """Return ``[path/img1.bmp, path/img2.bmp, ..., path/img{num}.bmp]``."""
        return [os.path.join(path, "img%d.bmp" % k) for k in range(1, num + 1)]

class TID2013Dataset(data.Dataset):
    """TID2013 full-reference dataset.

    ``__getitem__`` returns ``(LQ, HQ, HQ_diff_content, target)``:
    - the distorted image and its reference image;
    - an image drawn at random from ``HQ_diff_content_root``;
    - the float32 score from ``mos_with_names.txt``.

    ``index`` selects reference images by position in the sorted list of
    two-character ids returned by ``getTIDFileName``. Every distorted image
    of each selected reference is added ``patch_num`` times.

    Raises:
        ValueError: if ``HQ_diff_content_root`` contains no png/jpg/bmp file.
    """

    def __init__(self, root, HQ_diff_content_root, index, patch_num, transform=None, HQ_diff_content_transform=None):
        super(TID2013Dataset, self).__init__()

        self.HQ_diff_content_root = HQ_diff_content_root

        refpath = os.path.join(root, "reference_images")
        refname = getTIDFileName(refpath, ".bmp.BMP")
        txtpath = os.path.join(root, "mos_with_names.txt")
        imgnames = []
        target = []
        refnames_all = []
        # Each line is "<score> <distorted file name>". The reference id is
        # the name's first "_"-separated token without its leading letter
        # (e.g. "X01_02_3.bmp" -> "01").
        with open(txtpath, "r") as fh:
            for line in fh:
                words = line.split()
                if not words:
                    continue  # tolerate blank / trailing lines
                imgnames.append(words[1])
                target.append(words[0])
                refnames_all.append(words[1].split("_")[0][1:])
        labels = np.array(target).astype(np.float32)
        refnames_all = np.array(refnames_all)

        refname.sort()
        sample = []
        for ref_idx in index:
            selected = np.where(refnames_all == refname[ref_idx])[0].tolist()
            for item in selected:
                LQ_path = os.path.join(root, "distorted_images", imgnames[item])
                HQ_name = "I" + imgnames[item].split("_")[0][1:] + ".BMP"
                HQ_path = os.path.join(refpath, HQ_name)
                sample.extend([(LQ_path, HQ_path, labels[item])] * patch_num)

        # Sorted so a seeded `random` draws the same files regardless of
        # os.listdir order.
        self.HQ_diff_content_path = [
            os.path.join(HQ_diff_content_root, name)
            for name in sorted(os.listdir(HQ_diff_content_root))
            if name.endswith(("png", "jpg", "bmp"))
        ]
        if not self.HQ_diff_content_path:
            raise ValueError(
                "no png/jpg/bmp images found in HQ_diff_content_root %r" % HQ_diff_content_root
            )

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        Best-effort: if the file cannot be opened or decoded, print the path
        and return a 224x224 random-noise RGB image instead.
        """
        try:
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception:
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            return Image.fromarray(np.uint8(random_img))

    def __getitem__(self, index):
        """Return ``(LQ, HQ, HQ_diff_content, target)`` for sample ``index``.

        ``transform`` is applied to the distorted and reference images, and
        ``HQ_diff_content_transform`` to the pool image. Each is applied
        independently, only when it is not None.
        """
        LQ_path, HQ_path, target = self.samples[index]
        HQ_diff_content_path = random.choice(self.HQ_diff_content_path)
        LQ = self._load_image(LQ_path)
        HQ = self._load_image(HQ_path)
        HQ_diff_content = self._load_image(HQ_diff_content_path)
        if self.transform is not None:
            LQ = self.transform(LQ)
            HQ = self.transform(HQ)
        if self.HQ_diff_content_transform is not None:
            HQ_diff_content = self.HQ_diff_content_transform(HQ_diff_content)
        return LQ, HQ, HQ_diff_content, target

    def __len__(self):
        return len(self.samples)


class CSIQDataset(data.Dataset):
    """CSIQ full-reference dataset.

    ``__getitem__`` returns ``(LQ, HQ, HQ_diff_content, target)``:
    - the distorted image from ``dst_imgs_all`` and its reference from
      ``src_imgs``;
    - an image drawn at random from ``HQ_diff_content_root``;
    - the float32 score from ``csiq_label.txt``.

    ``index`` selects reference images by position in the sorted list of
    ``src_imgs/*.png`` names. Every distorted image of each selected
    reference is added ``patch_num`` times.

    Raises:
        ValueError: if ``HQ_diff_content_root`` contains no png/jpg/bmp file.
    """

    def __init__(self, root, HQ_diff_content_root, index, patch_num, transform=None, HQ_diff_content_transform=None):
        super(CSIQDataset, self).__init__()

        self.HQ_diff_content_root = HQ_diff_content_root

        refpath = os.path.join(root, "src_imgs")
        # Sorted, like the other full-reference datasets here, so `index`
        # maps to the same reference images on every filesystem.
        refname = sorted(getFileName(refpath, ".png"))
        txtpath = os.path.join(root, "csiq_label.txt")
        imgnames = []
        target = []
        refnames_all = []
        with open(txtpath, "r") as fh:
            for line in fh:
                words = line.split()  # e.g. "1600.AWGN.1 0.062"
                if not words:
                    continue  # tolerate blank / trailing lines
                imgnames.append(words[0] + ".png")  # "1600.AWGN.1.png"
                target.append(words[1])  # "0.062"
                refnames_all.append(words[0].split(".")[0] + ".png")  # "1600.png"

        labels = np.array(target).astype(np.float32)
        refnames_all = np.array(refnames_all)

        sample = []
        for ref_idx in index:
            selected = np.where(refnames_all == refname[ref_idx])[0].tolist()
            for item in selected:
                LQ_path = os.path.join(root, "dst_imgs_all", imgnames[item])
                HQ_path = os.path.join(root, "src_imgs", refnames_all[item])
                sample.extend([(LQ_path, HQ_path, labels[item])] * patch_num)

        # Sorted so a seeded `random` draws the same files regardless of
        # os.listdir order.
        self.HQ_diff_content_path = [
            os.path.join(HQ_diff_content_root, name)
            for name in sorted(os.listdir(HQ_diff_content_root))
            if name.endswith(("png", "jpg", "bmp"))
        ]
        if not self.HQ_diff_content_path:
            raise ValueError(
                "no png/jpg/bmp images found in HQ_diff_content_root %r" % HQ_diff_content_root
            )

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        Best-effort: if the file cannot be opened or decoded, print the path
        and return a 224x224 random-noise RGB image instead.
        """
        try:
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception:
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            return Image.fromarray(np.uint8(random_img))

    def __getitem__(self, index):
        """Return ``(LQ, HQ, HQ_diff_content, target)`` for sample ``index``.

        ``transform`` is applied to the distorted and reference images, and
        ``HQ_diff_content_transform`` to the pool image. Each is applied
        independently, only when it is not None.
        """
        LQ_path, HQ_path, target = self.samples[index]
        HQ_diff_content_path = random.choice(self.HQ_diff_content_path)
        LQ = self._load_image(LQ_path)
        HQ = self._load_image(HQ_path)
        HQ_diff_content = self._load_image(HQ_diff_content_path)
        if self.transform is not None:
            LQ = self.transform(LQ)
            HQ = self.transform(HQ)
        if self.HQ_diff_content_transform is not None:
            HQ_diff_content = self.HQ_diff_content_transform(HQ_diff_content)
        return LQ, HQ, HQ_diff_content, target

    def __len__(self):
        return len(self.samples)


class KADIDDataset(data.Dataset):
    """KADID-10k dataset, served in no-reference form.

    ``__getitem__`` returns ``(LQ, LQ, HQ_diff_content, target)``:
    - the distorted image twice; the reference path is kept in
      ``self.samples`` but is not loaded;
    - an image drawn at random from ``HQ_diff_content_root``;
    - the float32 ``dmos`` column of ``dmos.csv``.

    ``index`` selects reference images by position in the sorted list of
    ``reference_images/*.png`` names. Every distorted image of each selected
    reference is added ``patch_num`` times.

    Raises:
        ValueError: if ``HQ_diff_content_root`` contains no png/jpg/bmp file.
    """

    def __init__(self, root, HQ_diff_content_root, index, patch_num, transform=None, HQ_diff_content_transform=None):
        super(KADIDDataset, self).__init__()

        self.HQ_diff_content_root = HQ_diff_content_root

        refpath = os.path.join(root, "reference_images")
        refname = getFileName(refpath, ".png")

        imgnames = []
        target = []
        refnames_all = []
        csv_file = os.path.join(root, "dmos.csv")
        # newline="" is what the csv module expects for files it parses.
        with open(csv_file, newline="") as f:
            for row in csv.DictReader(f):
                imgnames.append(row["dist_img"])
                refnames_all.append(row["ref_img"])
                target.append(np.array(float(row["dmos"])).astype(np.float32))

        labels = np.array(target).astype(np.float32)
        refnames_all = np.array(refnames_all)

        refname.sort()
        sample = []
        for ref_idx in index:
            selected = np.where(refnames_all == refname[ref_idx])[0].tolist()
            for item in selected:
                entry = (
                    os.path.join(root, "images", imgnames[item]),
                    os.path.join(root, "reference_images", refnames_all[item]),
                    labels[item],
                )
                sample.extend([entry] * patch_num)

        # Sorted so a seeded `random` draws the same files regardless of
        # os.listdir order.
        self.HQ_diff_content_path = [
            os.path.join(HQ_diff_content_root, name)
            for name in sorted(os.listdir(HQ_diff_content_root))
            if name.endswith(("png", "jpg", "bmp"))
        ]
        if not self.HQ_diff_content_path:
            raise ValueError(
                "no png/jpg/bmp images found in HQ_diff_content_root %r" % HQ_diff_content_root
            )

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        Best-effort: if the file cannot be opened or decoded, print the path
        and return a 224x224 random-noise RGB image instead.
        """
        try:
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception:
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            return Image.fromarray(np.uint8(random_img))

    def __getitem__(self, index):
        """Return ``(LQ, LQ, HQ_diff_content, target)`` for sample ``index``.

        ``self.transform`` (when not None) is applied to both the distorted
        image and the pool image. The reference image is not returned, so it
        is not loaded either.
        """
        LQ_path, _HQ_path, target = self.samples[index]
        HQ_diff_content_path = random.choice(self.HQ_diff_content_path)
        LQ = self._load_image(LQ_path)
        HQ_diff_content = self._load_image(HQ_diff_content_path)
        if self.transform is not None:
            LQ = self.transform(LQ)
            HQ_diff_content = self.transform(HQ_diff_content)
        return LQ, LQ, HQ_diff_content, target

    def __len__(self):
        return len(self.samples)


class SPAQDATASET(data.Dataset):
    """SPAQ dataset (no pristine reference images).

    ``__getitem__`` returns ``(LQ, LQ, HQ_diff_content, target)``:
    - the test image twice (it also fills the "HQ" slot);
    - an image drawn at random from ``HQ_diff_content_root``;
    - the ``MOS`` column of the annotation spreadsheet as a 0-d float32
      array.

    Args:
        root: dataset root with ``Annotations/MOS and Image attribute
            scores.xlsx`` and the ``TestImage`` folder.
        HQ_diff_content_root: folder whose png/jpg/bmp files form the random
            "HQ_diff_content" pool.
        index: spreadsheet row indices selecting the images of this split.
        patch_num: number of times each selected image is repeated.
        transform: optional callable applied to both the test image and the
            pool image.
        HQ_diff_content_transform: stored but not used by ``__getitem__``.

    Raises:
        ValueError: if ``HQ_diff_content_root`` contains no png/jpg/bmp file.
    """

    def __init__(self, root, HQ_diff_content_root, index, patch_num, transform=None, HQ_diff_content_transform=None):
        super(SPAQDATASET, self).__init__()

        self.HQ_diff_content_root = HQ_diff_content_root
        self.data_path = root
        anno_folder = os.path.join(self.data_path, "Annotations")
        xlsx_file = os.path.join(anno_folder, "MOS and Image attribute scores.xlsx")
        read = pd.read_excel(xlsx_file)
        imgname = read["Image name"].values.tolist()
        mos_all = [np.array(mos).astype(np.float32) for mos in read["MOS"].values.tolist()]

        sample = []
        for item in index:
            entry = (os.path.join(self.data_path, "TestImage", imgname[item]), mos_all[item])
            sample.extend([entry] * patch_num)

        # Sorted so a seeded `random` draws the same files regardless of
        # os.listdir order.
        self.HQ_diff_content_path = [
            os.path.join(HQ_diff_content_root, name)
            for name in sorted(os.listdir(HQ_diff_content_root))
            if name.endswith(("png", "jpg", "bmp"))
        ]
        if not self.HQ_diff_content_path:
            raise ValueError(
                "no png/jpg/bmp images found in HQ_diff_content_root %r" % HQ_diff_content_root
            )

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        Best-effort: if the file cannot be opened or decoded, print the path
        and return a 224x224 random-noise RGB image instead.
        """
        try:
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception:
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            return Image.fromarray(np.uint8(random_img))

    def __getitem__(self, index):
        """Return ``(LQ, LQ, HQ_diff_content, target)`` for sample ``index``.

        ``self.transform`` (when not None) is applied to both images.
        """
        LQ_path, target = self.samples[index]
        HQ_diff_content_path = random.choice(self.HQ_diff_content_path)
        LQ = self._load_image(LQ_path)
        HQ_diff_content = self._load_image(HQ_diff_content_path)
        if self.transform is not None:
            LQ = self.transform(LQ)
            HQ_diff_content = self.transform(HQ_diff_content)
        return LQ, LQ, HQ_diff_content, target

    def __len__(self):
        return len(self.samples)

class CID2013Folder(data.Dataset):
    """CID2013 dataset (no pristine reference images).

    The images are listed in six files ``IS1.txt`` .. ``IS6.txt`` under
    ``root``, one ``<image name> <score>`` pair per line. Each value ``n`` in
    ``index`` (0-5) reads ``IS{n+1}.txt``; any other value is silently
    skipped. Images are read from ``root/images/<name>.jpg``, and each one is
    repeated ``patch_num`` times.

    ``__getitem__`` returns ``(LQ, LQ, HQ_diff_content, target)``:
    - the image twice (it also fills the "HQ" slot);
    - an image drawn at random from ``HQ_diff_content_root``;
    - the float32 score.

    Raises:
        ValueError: if ``HQ_diff_content_root`` contains no png/jpg/bmp file.
    """

    # Number of ISk.txt list files this loader reads (IS1.txt .. IS6.txt).
    _NUM_IMAGE_SETS = 6

    def __init__(self, root, HQ_diff_content_root, index, patch_num, transform=None, HQ_diff_content_transform=None):  # root: .../CID2013
        super(CID2013Folder, self).__init__()

        self.HQ_diff_content_root = HQ_diff_content_root
        imgnames = []
        target = []
        print('index', index)

        for n in index:
            n = int(n)  # image-set id, valid range 0..5
            if not 0 <= n < self._NUM_IMAGE_SETS:
                continue
            txtpath = os.path.join(root, "IS%d.txt" % (n + 1))
            with open(txtpath, "r") as fh:
                for line in fh:
                    words = line.split()
                    if not words:
                        continue  # tolerate blank / trailing lines
                    imgnames.append(words[0])  # image name (without extension)
                    target.append(words[1])  # image score

        labels = np.array(target).astype(np.float32)

        sample = []
        for name, label in zip(imgnames, labels):
            entry = (os.path.join(root, "images", name + ".jpg"), label)
            sample.extend([entry] * patch_num)

        # Sorted so a seeded `random` draws the same files regardless of
        # os.listdir order.
        self.HQ_diff_content_path = [
            os.path.join(HQ_diff_content_root, name)
            for name in sorted(os.listdir(HQ_diff_content_root))
            if name.endswith(("png", "jpg", "bmp"))
        ]
        if not self.HQ_diff_content_path:
            raise ValueError(
                "no png/jpg/bmp images found in HQ_diff_content_root %r" % HQ_diff_content_root
            )

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        Best-effort: if the file cannot be opened or decoded, print the path
        and return a 224x224 random-noise RGB image instead.
        """
        try:
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception:
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            return Image.fromarray(np.uint8(random_img))

    def __getitem__(self, index):
        """Return ``(LQ, LQ, HQ_diff_content, target)`` for sample ``index``.

        ``self.transform`` (when not None) is applied to both images.
        """
        LQ_path, target = self.samples[index]
        HQ_diff_content_path = random.choice(self.HQ_diff_content_path)
        LQ = self._load_image(LQ_path)
        HQ_diff_content = self._load_image(HQ_diff_content_path)
        if self.transform is not None:
            LQ = self.transform(LQ)
            HQ_diff_content = self.transform(HQ_diff_content)
        return LQ, LQ, HQ_diff_content, target

    def __len__(self):
        return len(self.samples)


class FBLIVEFolder(data.Dataset):
    """LIVE-FB (PaQ-2-PiQ) image dataset (no pristine reference images).

    ``__getitem__`` returns ``(LQ, LQ, HQ_diff_content, target)``:
    - the image twice (it also fills the "HQ" slot);
    - an image drawn at random from ``HQ_diff_content_root``;
    - the ``mos`` column of ``labels_image.csv`` as a 0-d float32 array.

    Args:
        root: dataset root with ``labels_image.csv`` and the ``database``
            image folder.
        HQ_diff_content_root: folder whose png/jpg/bmp files form the random
            "HQ_diff_content" pool.
        index: CSV row indices selecting the images of this split.
        patch_num: number of times each selected image is repeated.
        transform: optional callable applied to both the image and the pool
            image.
        HQ_diff_content_transform: stored but not used by ``__getitem__``.

    Raises:
        ValueError: if ``HQ_diff_content_root`` contains no png/jpg/bmp file.
    """

    def __init__(self, root, HQ_diff_content_root, index, patch_num, transform=None, HQ_diff_content_transform=None):
        super(FBLIVEFolder, self).__init__()

        self.HQ_diff_content_root = HQ_diff_content_root

        imgname = []
        mos_all = []
        csv_file = os.path.join(root, "labels_image.csv")
        # newline="" is what the csv module expects for files it parses.
        with open(csv_file, newline="") as f:
            for row in csv.DictReader(f):
                imgname.append(row["name"])
                mos_all.append(np.array(float(row["mos"])).astype(np.float32))

        sample = []
        for item in index:
            entry = (os.path.join(root, "database", imgname[item]), mos_all[item])
            sample.extend([entry] * patch_num)

        # Sorted so a seeded `random` draws the same files regardless of
        # os.listdir order.
        self.HQ_diff_content_path = [
            os.path.join(HQ_diff_content_root, name)
            for name in sorted(os.listdir(HQ_diff_content_root))
            if name.endswith(("png", "jpg", "bmp"))
        ]
        if not self.HQ_diff_content_path:
            raise ValueError(
                "no png/jpg/bmp images found in HQ_diff_content_root %r" % HQ_diff_content_root
            )

        self.samples = sample
        self.transform = transform
        self.HQ_diff_content_transform = HQ_diff_content_transform

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image.

        Best-effort: if the file cannot be opened or decoded, print the path
        and return a 224x224 random-noise RGB image instead.
        """
        try:
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception:
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            return Image.fromarray(np.uint8(random_img))

    def __getitem__(self, index):
        """Return ``(LQ, LQ, HQ_diff_content, target)`` for sample ``index``.

        ``self.transform`` (when not None) is applied to both images.
        """
        LQ_path, target = self.samples[index]
        HQ_diff_content_path = random.choice(self.HQ_diff_content_path)
        LQ = self._load_image(LQ_path)
        HQ_diff_content = self._load_image(HQ_diff_content_path)
        if self.transform is not None:
            LQ = self.transform(LQ)
            HQ_diff_content = self.transform(HQ_diff_content)
        return LQ, LQ, HQ_diff_content, target

    def __len__(self):
        return len(self.samples)
