import os
import json
from PIL import Image

import pickle
import imageio
import scipy.io as sio
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms

from datasets import register


'''
Adapted from https://github.com/junjun-jiang/US3RN/blob/main/dataset.py
'''
def is_image_file(filename):
    """Tell whether *filename* looks like a loadable image or MATLAB file.

    The suffix test is case-sensitive and accepts ``.png``, ``.jpg``,
    ``.jpeg`` and ``.mat``.
    """
    accepted_suffixes = (".png", ".jpg", ".jpeg", ".mat")
    return filename.endswith(accepted_suffixes)

def load_img(filepath, load_img_tag='msi'):
    """Read one array variable from a MATLAB ``.mat`` file as float64.

    Args:
        filepath: path of the ``.mat`` file to read with ``scipy.io.loadmat``.
        load_img_tag: name of the variable stored in the file, e.g. ``'msi'``
            for the hyperspectral cube or ``'RGB'`` for the colour image.

    Returns:
        ``numpy.ndarray`` of dtype ``float64``. The file's callers treat it as
        (H, W, C), e.g. (512, 512, 31) for a CAVE hyperspectral image.
    """
    mat_vars = sio.loadmat(filepath)
    image = mat_vars[load_img_tag]
    return image.astype(np.float64)

# def load_img1(filepath):
#     # load RGB image
#     # x: shape (H, W, c) => (512, 512, 3)
#     x = sio.loadmat(filepath)
#     x = x['RGB'].astype(np.float64)
#     # x = torch.tensor(x).float()
#     return x

@register('cave-image-folder')
class CAVEImageFolder(Dataset):
    """Dataset over a folder of CAVE ``.mat`` images.

    Each item is a ``float32`` tensor in (C, H, W) layout. It is built by
    reading variable ``load_img_tag`` with ``load_img`` and moving the last
    axis to the front.
    """

    _CACHE_MODES = ('none', 'bin', 'in_memory')

    def __init__(self, root_path, load_img_tag='msi', split_file=None, split_key=None, first_k=None,
                 repeat=1, cache='none'):
        '''
        Args:
            root_path: image folder directory
            load_img_tag: name of the variable read from each ``.mat`` file
            split_file: json file path, a dict whose split_key has the list of image file names
            split_key: str, key into ``split_file``
            first_k: int, keep only the first K files
            repeat: int, the dataset length is ``repeat`` times the number of files
            cache:  'none': self.files is a list of image file paths; images are read in __getitem__
                    'in_memory': self.files is a list of float32 (C, H, W) tensors
                    'bin': images are converted once to pickled (C, H, W) float64 arrays
                           stored in a sibling '_bin_<folder>' directory; self.files
                           is a list of those .pkl paths

        Raises:
            ValueError: if ``cache`` is not one of 'none', 'bin', 'in_memory'.
        '''
        if cache not in self._CACHE_MODES:
            raise ValueError(
                f"cache must be one of {self._CACHE_MODES}, got {cache!r}")

        self.repeat = repeat
        self.cache = cache
        self.load_img_tag = load_img_tag

        if split_file is None:
            filenames = [x for x in sorted(os.listdir(root_path)) if is_image_file(x)]
        else:
            with open(split_file, 'r') as f:
                filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        bin_root = None
        if cache == 'bin':
            # normpath drops a trailing separator. Without it, basename() would
            # return '' and the cache dir would end up inside root_path.
            norm_root = os.path.normpath(root_path)
            bin_root = os.path.join(os.path.dirname(norm_root),
                                    '_bin_' + os.path.basename(norm_root))
            if not os.path.isdir(bin_root):
                os.makedirs(bin_root, exist_ok=True)
                print('mkdir', bin_root)

        self.files = []
        for filename in filenames:
            file = os.path.join(root_path, filename)

            if cache == 'none':
                self.files.append(file)

            elif cache == 'in_memory':
                self.files.append(torch.tensor(self._load_chw(file)).float())

            else:  # cache == 'bin'
                # splitext strips only the final extension. split('.')[0] broke
                # on names such as './a.mat' or 'img.v2.mat'.
                # NOTE(review): the .pkl name does not include load_img_tag, so
                # reusing one root_path with a different tag would hit a stale cache.
                bin_file = os.path.join(
                    bin_root, os.path.splitext(filename)[0] + '.pkl')
                if not os.path.exists(bin_file):
                    # Only decode the .mat file when the cache entry is missing.
                    with open(bin_file, 'wb') as f:
                        pickle.dump(self._load_chw(file), f)
                    print('dump', bin_file)
                self.files.append(bin_file)

    def _load_chw(self, path):
        """Load ``self.load_img_tag`` from ``path`` and reorder axes (H, W, C) -> (C, H, W)."""
        return load_img(path, load_img_tag=self.load_img_tag).transpose(2, 0, 1)

    def __len__(self):
        return len(self.files) * self.repeat

    def __getitem__(self, idx):
        '''
        Return the float32 (C, H, W) tensor for ``idx`` (wrapped modulo the file count).

        Per the original note, CAVE data are assumed to be normalized to [0, 1];
        no normalization is applied here.
        '''
        x = self.files[idx % len(self.files)]

        if self.cache == 'none':
            return torch.from_numpy(self._load_chw(x)).float()

        if self.cache == 'bin':
            # NOTE: pickle.load is only safe because these files were written by
            # __init__ above. Do not point the cache dir at untrusted data.
            with open(x, 'rb') as f:
                arr = pickle.load(f)
            return torch.from_numpy(arr).float()

        # cache == 'in_memory': the tensor was built in __init__.
        return x



@register('cave-paired-image-folders')
class CAVEPairedImageFolders(Dataset):
    """Pair an RGB folder with a hyperspectral (MSI) folder, item by item.

    ``root_path_1`` is read with ``load_img_tag='RGB'`` and ``root_path_2``
    with ``load_img_tag='msi'``. All other keyword arguments are forwarded
    unchanged to both ``CAVEImageFolder`` instances. Items are paired purely
    by index, so both folders must yield their files in the same order
    (sorted filenames, or the same split file/key).
    """

    def __init__(self, root_path_1, root_path_2, **kwargs):
        self.dataset_1 = CAVEImageFolder(root_path_1, load_img_tag='RGB', **kwargs)
        self.dataset_2 = CAVEImageFolder(root_path_2, load_img_tag='msi', **kwargs)
        # Pairing is index-based. On a size mismatch, the shorter dataset's
        # modulo indexing would silently match images from different scenes.
        if len(self.dataset_1) != len(self.dataset_2):
            raise ValueError(
                f"paired folders differ in size: {root_path_1!r} has "
                f"{len(self.dataset_1)} items, {root_path_2!r} has "
                f"{len(self.dataset_2)}")

    def __len__(self):
        return len(self.dataset_1)

    def __getitem__(self, idx):
        """Return ``(rgb, msi)`` tensors for the same index."""
        return self.dataset_1[idx], self.dataset_2[idx]
