import os
import pickle
import torch
import numpy as np
from torch.utils.data import Dataset

from tqdm import trange
from PIL import Image
from python_tools import file_io

class LightFieldDataset(Dataset):
    """Dataset pairing one PNG view of each scene with that scene's depth map.

    Every sub-directory of ``<root>/<split folder>`` is treated as a scene.
    For each scene the depth map is read eagerly at construction time via
    ``file_io.read_depth(..., highres=False)`` and kept in memory, while only
    the path of the alphabetically first ``.png`` file is stored; the image
    itself is opened lazily in ``__getitem__``.

    Args:
        root: Directory that contains the split folders.
        split: ``'train'`` selects the ``training`` folder; any other value
            selects the ``additional`` folder.
        transform: Optional callable applied to the loaded PIL image.
        target_transform: Optional callable applied to the depth tensor.

    Raises:
        FileNotFoundError: If a scene directory contains no ``.png`` file.
    """

    def __init__(self, root, split, transform=None, target_transform=None):
        super().__init__()
        self.root = root

        # Only 'train' maps to 'training'; everything else falls back to
        # 'additional' (there is no validation of the split name).
        split = 'training' if split == 'train' else 'additional'
        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        # Images and depth maps live side by side in the same scene folders.
        self.image_dir = os.path.join(root, split)
        self.depth_dir = os.path.join(root, split)

        self.depths = []
        self.images = []

        # os.listdir() order is arbitrary; sort so that sample indices are
        # reproducible across runs and machines.
        for scene_dir in sorted(os.listdir(self.depth_dir)):
            scene_path = os.path.join(self.depth_dir, scene_dir)
            if not os.path.isdir(scene_path):
                continue
            # Resolve the image first so a scene without PNGs fails with a
            # clear message before any depth data is loaded for it.
            image = self._first_png(os.path.join(self.image_dir, scene_dir))
            # NOTE(review): read_depth is assumed to return a NumPy array
            # (it is fed to torch.from_numpy in __getitem__) -- confirm.
            depth = file_io.read_depth(scene_path, highres=False)
            self.depths.append(depth)
            self.images.append(image)

        print(len(self.images), len(self.depths))

    @staticmethod
    def _first_png(scene_path):
        """Return the path of the alphabetically first ``.png`` in *scene_path*.

        Raises:
            FileNotFoundError: If the directory holds no ``.png`` file.
        """
        for name in sorted(os.listdir(scene_path)):
            if name.endswith('.png'):
                return os.path.join(scene_path, name)
        raise FileNotFoundError(
            'no .png image found in {!r}'.format(scene_path)
        )

    def __getitem__(self, idx):
        """Return ``(image, depth)`` for the scene at position *idx*.

        The image is the PIL image (after ``transform`` when one was given);
        the depth is a float tensor with a new leading dimension of size 1
        (after ``target_transform`` when one was given).
        """
        # self.images already stores paths that include image_dir; joining
        # image_dir again would duplicate the prefix for a relative root.
        img_path = self.images[idx]
        depth = self.depths[idx]

        # Force the pixel data to be read while the file is open so the
        # file handle is released deterministically.
        with Image.open(img_path) as img:
            img.load()
        if self.transform is not None:
            img = self.transform(img)

        # Copy so the tensor never shares memory with the cached array.
        depth = torch.from_numpy(depth.copy()).float().unsqueeze(0)
        if self.target_transform is not None:
            depth = self.target_transform(depth)

        return img, depth

    def __len__(self):
        """Return the number of scenes found at construction time."""
        return len(self.images)
