
import os
import torch 
from torch.utils.data import Dataset, DataLoader
import cv2 
import numpy as np 
from pathlib import Path

def collect_image_paths(input_folder, max_depth=10):
    """Recursively collect image paths under a folder, up to a maximum depth.

    Depth is counted in directory levels below ``input_folder``:
    - files directly inside ``input_folder`` are at depth 0;
    - files in its immediate subdirectories are at depth 1;
    - and so on.

    Args:
        input_folder (str | os.PathLike): Root folder to search for images.
        max_depth (int, optional): Deepest subdirectory level to search.
            ``0`` searches only ``input_folder`` itself; a negative value
            searches nothing. Defaults to 10.

    Raises:
        NotADirectoryError: If ``input_folder`` does not exist or is not a
            directory.

    Returns:
        list[str]: Sorted paths of the found image files, relative to
        ``input_folder``.
    """
    if not os.path.isdir(input_folder):
        raise NotADirectoryError(f"Input folder not found: {input_folder}")

    # Compared against the lower-cased extension, so e.g. '.PNG' also matches.
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif', '.webp'}

    # Walk the normalized absolute path so every `root` shares one prefix and
    # relative paths are computed consistently.
    base_path = os.path.abspath(input_folder)
    image_paths = []

    for root, dirs, files in os.walk(base_path):
        rel_root = os.path.relpath(root, base_path)
        # relpath yields '.' for the base itself (depth 0). Otherwise each
        # path component is one level: 'a' is depth 1, 'a/b' is depth 2.
        depth = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1

        if depth >= max_depth:
            # Clearing `dirs` in place stops os.walk from descending further;
            # this level's files are still collected below.
            dirs.clear()
        if depth > max_depth:
            # Only reachable at the root when max_depth < 0.
            continue

        for file in files:
            if os.path.splitext(file)[1].lower() in image_extensions:
                full_path = os.path.join(root, file)
                image_paths.append(os.path.relpath(full_path, base_path))

    return sorted(image_paths)


class ImageFolderDataset(Dataset):
    """A PyTorch Dataset for loading images from a specified folder.

    This dataset recursively finds all images in ``root_dir`` and loads them
    as RGB float tensors scaled to [0, 1]. It then applies a processing method
    (crop, resize, or pad) so that their height and width are divisible by
    ``proc_mult``.
    """

    def __init__(self, root_dir, max_depth=10, proc_method='resize', proc_mult=64):
        """Initializes the ImageFolderDataset.

        Args:
            root_dir (str): The root directory containing the images.
            max_depth (int, optional): Maximum recursion depth for finding images. Defaults to 10.
            proc_method (str, optional): Method to make image dimensions divisible by `proc_mult`.
                                         One of 'resize', 'crop', or 'pad'. Defaults to 'resize'.
            proc_mult (int, optional): The multiple that image dimensions should be divisible by.
                                       Defaults to 64.

        Raises:
            ValueError: If `proc_method` is not supported or `proc_mult` is not
                a positive integer.
            NotADirectoryError: If `root_dir` is not an existing directory.
        """
        # Validate the cheap arguments before the potentially slow directory
        # walk. Use explicit raises because asserts are stripped under `-O`.
        if proc_method not in ('resize', 'crop', 'pad'):
            raise ValueError(
                f"proc_method must be one of 'resize', 'crop', 'pad'; got {proc_method!r}"
            )
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(proc_mult, bool) or not isinstance(proc_mult, int) or proc_mult <= 0:
            raise ValueError(f"proc_mult must be a positive integer; got {proc_mult!r}")

        print("Initializing dataset, collecting image paths.")
        self.image_paths = collect_image_paths(root_dir, max_depth=max_depth)
        print(f"{len(self.image_paths)} paths collected.")
        self.root_dir = root_dir
        self.proc_method = proc_method
        self.proc_mult = proc_mult

    def __len__(self):
        return len(self.image_paths)

    def _read_image(self, fn):
        """Read `fn` with OpenCV and return a (1, 3, H, W) float32 RGB tensor in [0, 1]."""
        image = cv2.imread(fn)
        if image is None:
            # cv2.imread reports unreadable or undecodable files by returning
            # None instead of raising; surface that with the offending path.
            raise OSError(f"Could not read image: {fn}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.astype(np.float32) / 255.0
        # HWC -> CHW, then add a leading batch dimension of 1.
        return torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)

    def apply_proc(self, image):
        """Make the spatial dimensions of `image` divisible by `self.proc_mult`.

        Args:
            image (torch.Tensor): Tensor of shape (N, C, H, W).

        Returns:
            torch.Tensor: The processed tensor, depending on `self.proc_method`:
                - 'crop': keep the top-left region whose size is the largest
                  multiple not exceeding each dimension.
                - 'resize': bilinearly resize to the largest multiple not
                  exceeding each dimension, but never below one multiple.
                - 'pad': zero-pad on the right and bottom up to the smallest
                  multiple not below each dimension; already-divisible
                  dimensions are left unchanged.
            Any other `proc_method` value returns `image` unchanged.

        Raises:
            ValueError: For 'crop', if H or W is smaller than `proc_mult`.
                Cropping would leave an empty image.
        """
        mult = self.proc_mult
        height, width = image.shape[-2], image.shape[-1]

        if self.proc_method == 'crop':
            new_h = height // mult * mult
            new_w = width // mult * mult
            if new_h == 0 or new_w == 0:
                raise ValueError(
                    f"Cannot crop a {height}x{width} image to a multiple of {mult}"
                )
            return image[:, :, :new_h, :new_w]

        if self.proc_method == 'resize':
            # Upscale images smaller than one multiple to exactly one multiple
            # rather than asking interpolate for a zero-sized output.
            new_h = max(mult, height // mult * mult)
            new_w = max(mult, width // mult * mult)
            return torch.nn.functional.interpolate(
                image, size=(new_h, new_w), mode='bilinear', align_corners=False
            )

        if self.proc_method == 'pad':
            # `-x % mult` is the distance to the next multiple (0 if x already is one).
            pad_h = -height % mult
            pad_w = -width % mult
            return torch.nn.functional.pad(
                image, (0, pad_w, 0, pad_h), mode='constant', value=0
            )

        return image

    def __getitem__(self, idx):
        """Retrieves an image and its metadata at the given index.

        Reads an image, converts it to RGB, normalizes it to [0, 1],
        converts it to a PyTorch tensor with a leading batch dimension, and
        applies the configured processing (crop, resize, or pad).

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            dict: A dictionary containing:
                  - 'image' (torch.Tensor): The processed image tensor of shape (1, 3, H', W').
                  - 'image_path' (str): `root_dir` joined with the image's relative path.
                  - 'image_name' (str): The name of the image file.

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        fn = os.path.join(self.root_dir, self.image_paths[idx])
        image = self.apply_proc(self._read_image(fn))
        return {'image': image, 'image_path': fn, 'image_name': Path(fn).name}
    

def image_folder_collate_fn(data):
    """Collate ImageFolderDataset samples into a dict of per-key lists.

    Images are not stacked into a single tensor; each batch key simply holds
    the per-sample values in their original order.

    Args:
        data (list[dict]): Samples as returned by
            `ImageFolderDataset.__getitem__`, each carrying the keys
            'image', 'image_path' and 'image_name'.

    Returns:
        dict: Keys 'images', 'image_paths' and 'image_names' (in that order),
              each mapped to a list with one entry per sample.
    """
    # (batch key, per-sample key) pairs, in output order.
    field_pairs = (
        ("images", "image"),
        ("image_paths", "image_path"),
        ("image_names", "image_name"),
    )
    return {
        batch_key: [item[sample_key] for item in data]
        for batch_key, sample_key in field_pairs
    }