import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Union

from PIL import Image, UnidentifiedImageError
import numpy as np
import pandas as pd
from torchvision.datasets import VisionDataset
import torch

"""
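This module defines a custom dataset class NICODataset for loading the NICO dataset.
The expected directory layout is root/<class>/<context>/<image file>.
Main features include:
- Loading images together with their class and context (environment) labels from a root directory.
- Exposing all samples as a single test set (the split argument is ignored).
- Applying optional transforms to images and labels.
- Keeping only files whose extension marks them as images (.jpg, .jpeg, .png, .bmp, .tif, .tiff).
- Corrupted image files are not filtered by default; reading one raises an error when that sample is loaded.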
"""

def pil_loader(path: str) -> Image.Image:
    """Read the image stored at ``path`` and return it in RGB mode.

    The file is opened explicitly in binary mode. The RGB conversion is done
    while the file is still open, and the handle is closed afterwards by the
    ``with`` block.

    Args:
        path: Filesystem path of the image file.

    Returns:
        The image converted to mode ``"RGB"``.
    """
    with open(path, "rb") as handle:
        rgb_image = Image.open(handle).convert("RGB")
    return rgb_image

class NICODataset(VisionDataset):
    """NICO image dataset laid out on disk as ``root/<class>/<context>/<image>``.

    Each entry of ``samples`` is an ``(image_path, target)`` pair. ``target`` is
    an ``int64`` NumPy array ``[class_index, context_index]``.

    Class indices are positions in the sorted list of class directories.
    Context indices are positions in the sorted list of that class's own context
    directories. The same context index can therefore refer to different context
    names in different classes; see ``contexts_by_class``.

    All samples are exposed as one (test) split.
    """

    def __init__(
        self,
        root: str,
        split: str = "test",  # Parameter kept but not used, all samples are treated as test set
        loader: Callable[[str], Any] = pil_loader,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        *,
        skip_corrupted: bool = False,
    ) -> None:
        """Scan ``root`` and index every image file by class and context.

        Args:
            root: Dataset root directory; each sub-directory is a class.
            split: Ignored; kept for interface compatibility.
            loader: Callable that turns an image path into a sample.
            transform: Optional transform applied to each loaded sample.
            target_transform: Optional transform applied to each target array.
            skip_corrupted: If True, open and verify every candidate file while
                indexing and drop those PIL cannot read. This costs one file
                open per image at construction time. Defaults to False, in which
                case unreadable files are kept and fail when accessed.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.loader = loader
        self.samples: List[Tuple[str, np.ndarray]] = []
        self.img_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']

        # Sorted so that class indices are deterministic across runs/machines.
        self.classes = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        # Sorted context-directory names per class. A sample's context index is a
        # position in its own class's list, not a global context id.
        self.contexts_by_class: Dict[str, List[str]] = {}

        num_skipped = 0
        for class_name in self.classes:
            class_dir = os.path.join(root, class_name)
            contexts = sorted(d for d in os.listdir(class_dir) if os.path.isdir(os.path.join(class_dir, d)))
            self.contexts_by_class[class_name] = contexts
            class_idx = self.class_to_idx[class_name]

            for context_idx, context in enumerate(contexts):
                context_dir = os.path.join(class_dir, context)

                for file in os.listdir(context_dir):
                    # Only files with a known image extension (case-insensitive) are kept.
                    ext = os.path.splitext(file)[1].lower()
                    if ext not in self.img_extensions:
                        continue
                    image_path = os.path.join(context_dir, file)
                    if skip_corrupted and not self._is_readable_image(image_path):
                        num_skipped += 1
                        continue
                    self.samples.append(
                        (image_path, np.array([class_idx, context_idx], dtype=np.int64))
                    )

        # os.listdir returns entries in arbitrary order; sorting by path makes
        # sample indices reproducible.
        self.samples.sort(key=lambda x: x[0])
        print(f"Dataset loaded: {len(self.samples)} samples")
        if num_skipped:
            print(f"Skipped {num_skipped} unreadable image files")

    @staticmethod
    def _is_readable_image(path: str) -> bool:
        """Return True if PIL can identify and verify the file at ``path``.

        ``Image.verify`` checks the file without decoding all pixel data, so some
        forms of truncation may still go undetected.
        """
        try:
            with Image.open(path) as img:
                img.verify()
        except (UnidentifiedImageError, OSError, SyntaxError, ValueError):
            return False
        return True

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is [class_index, context_index]
        """
        path, target = self.samples[index]
        # Loader errors (e.g. a corrupted file) propagate from here unless the
        # dataset was built with skip_corrupted=True, which drops such files up front.
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Return the number of indexed image files."""
        return len(self.samples)

    def save_label_info(self, output_file: str) -> None:
        """
        Save class and context information to a text file, sorted by label index.

        Context indices are listed per class because they are assigned
        independently within each class.

        Args:
            output_file (str): Path to the output file
        """
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("# NICO Dataset Label Information\n\n")

            # Write class information
            f.write("## Class Labels\n")
            for cls_name, cls_idx in sorted(self.class_to_idx.items(), key=lambda x: x[1]):
                f.write(f"{cls_idx}: {cls_name}\n")
            f.write("\n")

            # Write context information for each class
            f.write("## Context Labels (Grouped by Class)\n")
            for cls_name in self.classes:
                f.write(f"\n### Class: {cls_name} (Class ID: {self.class_to_idx[cls_name]})\n")
                for ctx_idx, ctx_name in enumerate(self.contexts_by_class[cls_name]):
                    f.write(f"{ctx_idx}: {ctx_name}\n")

            # Add dataset statistics
            f.write("\n## Dataset Statistics\n")
            f.write(f"Total valid images: {len(self.samples)}\n")

        print(f"Label information saved to: {output_file}")

# Usage example:
# dataset = NICODataset(root="path/to/NICO")
# dataset.save_label_info("nico_labels.txt")