import os

from torchvision import transforms
import torch
import numpy as np
import cv2
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from glob import glob
from utils import *
from iwssim import *
from tqdm import tqdm
from time import time
from dct import DCT
from torchvision.io import decode_jpeg, encode_jpeg

from pytorch_grad_cam import GradCAMPlusPlus # type: ignore
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget # type: ignore
from lbp import get_lbp_features
from kornia.color import rgb_to_lab

import os, random

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from dlm.dlm import dlm
from torchmetrics.image import VisualInformationFidelity


def apply_jpeg(x: torch.Tensor, quality: int) -> torch.Tensor:
    """Round-trip ``x`` through JPEG compression and return the decoded image.

    The tensor is encoded with torchvision's ``encode_jpeg`` at the requested
    ``quality`` and the resulting byte stream is decoded straight back with
    ``decode_jpeg``.

    NOTE(review): torchvision documents ``encode_jpeg`` as taking a uint8
    (C, H, W) tensor -- confirm that callers pass that layout.
    """
    jpeg_bytes = encode_jpeg(x, quality)
    return decode_jpeg(jpeg_bytes)

class IQADatasetPyTorch(Dataset):
    """Dataset of (distorted, reference) image pairs with a quality score.

    The index is read from ``{dataset_root}/{name}/{name}1.txt`` (parsed with
    ``pandas.read_csv``) and must provide the columns ``dis_img_path``,
    ``ref_img_path``, ``score`` and ``part``.  Depending on the flags in
    ``args`` each item is returned as IW-SSIM maps, optionally accompanied by
    DCT differences, DLM features, VIF maps or a Grad-CAM++ saliency map, or
    (``simple`` mode) as a stacked LBP/Lab feature tensor.
    """

    def __init__(self, name, device = "cuda",
                 args = None, dataset_root = '/home/24a_guh/data'):
        """Load the index file and configure which features are produced.

        Args:
            name: Dataset name; selects ``{dataset_root}/{name}/{name}1.txt``.
            device: Device on which images and feature modules are placed.
            args: Mapping with the keys ``return_path``, ``flip``, ``dlm``,
                ``vif``, ``saliency``, ``lbp``, ``saliency_model``, ``part``,
                ``simple`` and ``dct``.
            dataset_root: Directory that contains the dataset folders and
                against which the image paths of the index are resolved.

        Raises:
            TypeError: If ``args`` is not provided.
        """
        if args is None:
            raise TypeError(
                "IQADatasetPyTorch requires an `args` mapping "
                "(keys: return_path, flip, dlm, vif, saliency, lbp, "
                "saliency_model, part, simple, dct)"
            )

        csv_file = f"{dataset_root}/{name}/{name}1.txt"

        self.name, self.device, self.dataset_root = name, device, dataset_root
        self.return_path, self.flip = args['return_path'], args['flip']
        self.dlm, self.vif, self.saliency, self.lbp = args['dlm'], args['vif'], args['saliency'], args['lbp']
        self.saliency_model = args['saliency_model']
        # NOTE(review): `ret_image` is fed from the same key as `return_path`.
        # As a consequence __getitem__ returns only the distorted image when
        # `return_path` is truthy and never reaches its 4-tuple return.  This
        # looks like a copy/paste slip, but the intended key is not visible
        # here, so the behaviour is kept -- confirm with the callers.
        self.ret_image = args['return_path']
        self.part = args['part']
        self.simple = args['simple']
        self.vif_class = VisualInformationFidelity().to(self.device)
        self.iwssim = IWSSSIMpyr(INCLUDE_COLOR=True, LBP=args['lbp']).to(self.device)

        self.df = pd.read_csv(csv_file, dtype=str)
        # Any other value of `part` keeps the complete index.
        if self.part in ('train', 'test'):
            self.df = self.df.loc[self.df['part'] == self.part]
        self.df = self.df.reset_index(drop=True)

        self.DCT = DCT(patch_size=64) if args['dct'] else None

        self.length = len(self.df)

    def to_torch(self, x):
        """Convert an (H, W, C) NumPy array to a float32 (1, C, H, W) tensor on ``self.device``."""
        x = torch.from_numpy(x)
        x = x.permute(2, 0, 1).unsqueeze(0)
        return x.float().to(self.device)

    def get_image(self, path):
        """Read the image at ``path`` as an RGB tensor scaled by 1/255.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.to_torch(image / 255)

    def __str__(self):
        # `attributes` is not set by __init__; honour it if some caller or
        # subclass defines it, otherwise report the enabled feature flags.
        enabled = [
            flag
            for flag in ("dlm", "vif", "saliency", "lbp", "simple", "flip")
            if getattr(self, flag, False)
        ]
        attributes = getattr(self, "attributes", enabled)
        return f"IQADataset ({self.name}), attributes: {attributes}"

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            * the distorted image tensor alone when ``self.ret_image`` is truthy;
            * ``(sample, mos)`` in ``simple`` mode, where ``sample`` stacks the
              LBP features and Lab colour channels of both images and their
              differences along the channel axis;
            * otherwise ``(sample, mos)`` with ``sample`` produced by
              :meth:`transform_image` (or ``(dist_path, gt_path, sample, mos)``
              when ``self.return_path`` is truthy).

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Positional lookup: an out-of-range index raises IndexError (rather
        # than the KeyError of a label lookup), which is what the sequence
        # iteration protocol expects.
        row = self.df.iloc[idx]
        dist_path, gt_path = row["dis_img_path"], row["ref_img_path"]
        # The score is used as stored in the index file (a rescaling for the
        # LIVE dataset existed here but was disabled).
        mos = float(row["score"])

        dist_im = self.get_image(os.path.join(self.dataset_root, dist_path))
        gt_im   = self.get_image(os.path.join(self.dataset_root, gt_path))

        if self.ret_image:
            return dist_im

        # Vertical and horizontal flips are drawn independently and applied to
        # both images so that the pair stays aligned.
        if self.flip and random.random() > 0.5:
            gt_im = transforms.functional.vflip(gt_im)
            dist_im = transforms.functional.vflip(dist_im)
        if self.flip and random.random() > 0.5:
            gt_im = transforms.functional.hflip(gt_im)
            dist_im = transforms.functional.hflip(dist_im)

        if self.simple:
            resize = transforms.Resize((224,224))
            dist_im = resize(dist_im)
            gt_im = resize(gt_im)
            # The Lab maps lose a 2-pixel border on each side before being
            # concatenated with the LBP features -- presumably because
            # get_lbp_features shrinks its input by the same amount; confirm.
            color_x = rgb_to_lab(dist_im)[:,:,2:-2,2:-2]
            color_y = rgb_to_lab(gt_im)[:,:,2:-2,2:-2]
            gray = transforms.Grayscale()
            gt_im   = get_lbp_features(gray(gt_im))
            dist_im = get_lbp_features(gray(dist_im))
            sample = torch.cat((dist_im, gt_im, dist_im-gt_im, color_x, color_y, color_x-color_y), dim=1)
            return (sample.squeeze(0), mos)

        sample = self.transform_image(gt_im, dist_im)
        if not self.return_path:
            return (sample, mos)
        else:
            return (dist_path, gt_path, sample, mos)

    def _vif_maps(self, gt_img_numpy, dist_img_numpy):
        """Compute the multi-scale VIF maps and move them to ``self.device`` as float tensors."""
        vif_maps = vifp_mscale(gt_img_numpy, dist_img_numpy)
        return [torch.from_numpy(i).float().permute(2,0,1).to(self.device) for i in vif_maps]

    def _dlm_features(self, gt_img_numpy, dist_img_numpy):
        """Compute the DLM features and move them to ``self.device`` as a float tensor."""
        dlm_feat = dlm(gt_img_numpy, dist_img_numpy)
        return torch.from_numpy(dlm_feat).float().to(self.device)

    def transform_image(self, gt_img, dist_img):
        """Turn a (reference, distorted) pair into the configured features.

        Args:
            gt_img: Reference image tensor as produced by :meth:`get_image`.
            dist_img: Distorted image tensor as produced by :meth:`get_image`.

        Returns:
            Depending on the configuration, checked in this order:
            ``(maps, dct_difference)`` when DCT is enabled;
            ``(maps, dlm_feat, vif_maps)``, ``(maps, dlm_feat)`` or
            ``(maps, vif_maps)`` when DLM and/or VIF are enabled;
            ``(maps, saliency)`` when saliency is enabled;
            otherwise just ``maps``.
        """
        maps = self.iwssim(gt_img, dist_img)

        # Report and zero out NaNs in the first map.
        nan_mask = torch.isnan(maps[0])
        if nan_mask.any():
            print("there are nans:", int(nan_mask.sum().item()))
        maps[0] = torch.nan_to_num(maps[0])

        if self.DCT:
            gt_dct = self.DCT(gt_img)
            dist_dct = self.DCT(dist_img)
            return maps, gt_dct - dist_dct

        if self.vif or self.dlm:
            # Only these two features need the images as (H, W, C) NumPy
            # arrays, so the device-to-host copies are made just here.
            gt_img_numpy = gt_img.permute(0, 2, 3, 1).cpu().numpy()[0]
            dist_img_numpy = dist_img.permute(0, 2, 3, 1).cpu().numpy()[0]

            vif_maps = self._vif_maps(gt_img_numpy, dist_img_numpy) if self.vif else None
            dlm_feat = self._dlm_features(gt_img_numpy, dist_img_numpy) if self.dlm else None

            if self.vif and self.dlm:
                return (maps, dlm_feat, vif_maps)
            if self.dlm:
                return (maps, dlm_feat)
            return (maps, vif_maps)

        if self.saliency:
            target_layers = [self.saliency_model.layer4[-1]]
            # 281 is passed as the classifier target -- presumably a class
            # index of the saliency model's label set; confirm.
            targets = [ClassifierOutputTarget(281)]
            # The context manager releases the hooks the CAM object registers
            # on the model once the map has been computed.
            with GradCAMPlusPlus(model=self.saliency_model, target_layers=target_layers) as cam:
                saliency = cam(input_tensor=gt_img, targets=targets)
            return (maps, saliency)

        return maps