import torch
import numpy as np
import cv2
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from glob import glob
from utils import *
from iwssim import *
from tqdm import tqdm
from torchvision import transforms as T
import random
from dlm.dlm import dlm
from torchmetrics.image import VisualInformationFidelity


def to_torch(x, device="cpu"):
    """Convert an HWC numpy array into a batched CHW float tensor.

    Args:
        x: numpy array indexed as (height, width, channels).
        device: target device for the returned tensor.

    Returns:
        A float32 tensor of shape (1, channels, height, width) on ``device``.
    """
    # Move the channel axis to the front, then add a leading batch axis.
    chw = torch.from_numpy(x).permute(2, 0, 1)
    batched = chw[None]
    return batched.type(torch.FloatTensor).to(device)

def get_image(path):
    """Load an image from disk, convert it to RGB, and build a tensor copy.

    Args:
        path: filesystem path passed to ``cv2.imread``.

    Returns:
        A tuple ``(image, image_torch)``. ``image`` is the OpenCV array after
        BGR -> RGB reordering. ``image_torch`` is ``image / 255`` converted by
        ``to_torch`` into a float32 tensor of shape (1, C, H, W) on the CPU.

    Raises:
        FileNotFoundError: if OpenCV could not read the file (missing,
            unreadable, or an unsupported format).
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the problem surfaces later as an opaque
        # cvtColor error that does not mention the offending path.
        raise FileNotFoundError(f"Could not read image: {path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_torch = image / 255
    return image, to_torch(image_torch)

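# NOTE: VideoDataset is disabled. If it is re-enabled, note that get_image
# returns an (image, tensor) tuple, so the `.to(self.device)` calls in
# __getitem__ below must be applied to the second element of that tuple.
# class VideoDataset(Dataset):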
#     def __init__(self, path, device="cuda"):
#         self.path = path
#         self.device = device
#         self.df = pd.read_csv(os.path.join(path, 'dmos.csv'))[['dist_img', 'ref_img', 'dmos']]

#     def __len__(self):
#         return len(self.df)

#     def __getitem__(self, idx):

#         dist_path, gt_path, mos = self.df.iloc[idx]
#         mos = torch.tensor(mos).float()

#         dist_im = get_image(os.path.join(self.path, "images", dist_path)).to(self.device)
#         gt_im = get_image(os.path.join(self.path, "reference", gt_path)).to(self.device)

#         iwssim = IWSSSIM(self.device).to(self.device)
        
#         maps = iwssim(gt_im, dist_im)[0]

#         return (maps, mos)
    

# class VideoDatasetPyr(Dataset):
#     def __init__(self, path, device="cuda", INCLUDE_COLOR=False):
#         self.path = path
#         self.device = device
#         self.INCLUDE_COLOR = INCLUDE_COLOR
#         self.df = pd.read_csv(os.path.join(path, 'dmos.csv'))[['dist_img', 'ref_img', 'dmos']]

#     def __len__(self):
#         return len(self.df)

#     def __getitem__(self, idx):

#         dist_path, gt_path, mos = self.df.iloc[idx]
#         mos = torch.tensor(mos).float()
#         print(dist_path, gt_path)

#         dist_im, dist_im_torch = get_image(os.path.join(self.path, "images", dist_path))
#         gt_im, gt_im_torch = get_image(os.path.join(self.path, "reference", gt_path))

#         dist_im_torch = dist_im_torch.to(self.device)
#         gt_im_torch = gt_im_torch.to(self.device)

#         iwssim = IWSSSIMpyr(INCLUDE_COLOR=self.INCLUDE_COLOR)

#         iwssim = iwssim.to(self.device)

#         if random.random() > 0.5:
#             gt_im_torch = T.functional.vflip(gt_im_torch)
#             dist_im_torch = T.functional.vflip(dist_im_torch)
#         if random.random() > 0.5:
#             gt_im_torch = T.functional.hflip(gt_im_torch)
#             dist_im_torch = T.functional.hflip(dist_im_torch)
        
#         maps = iwssim(gt_im_torch, dist_im_torch)

#         dlm_feat = dlm(gt_im, dist_im)
#         dlm_feat = torch.from_numpy(dlm_feat).type(torch.FloatTensor)

#         return (maps, dlm_feat, mos)