"""
Dataset module for image fusion training
Handles loading and preprocessing of IR and VIS image pairs
"""
import os
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
import cv2
from typing import Tuple, List, Optional


class ImageLoader:
    """Utility class for loading and preprocessing images.

    Every loader returns a ``float32`` tensor in channel-first layout.
    No value scaling is applied here; callers decide whether to normalize.
    """

    @staticmethod
    def load_grayscale_image(image_path: str) -> torch.Tensor:
        """
        Load grayscale image using OpenCV

        Args:
            image_path: Path to the image file

        Returns:
            Tensor with shape [1, H, W]

        Raises:
            FileNotFoundError: If ``image_path`` does not exist
            ValueError: If OpenCV cannot decode the file
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found: {image_path}")

        # cv2.imread signals failure by returning None instead of raising.
        img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise ValueError(f"Failed to load image: {image_path}")

        img = img.astype('float32')
        # Add the leading channel axis: [H, W] -> [1, H, W].
        return torch.from_numpy(img).unsqueeze(0)

    @staticmethod
    def _load_with_pil(image_path: str, mode: Optional[str], label: str) -> torch.Tensor:
        """
        Shared PIL loading path for the color and IR loaders

        Args:
            image_path: Path to the image file
            mode: PIL mode to convert to (e.g. 'L'), or None to keep the
                file's own mode
            label: Noun used in the error message (e.g. "image", "IR image")

        Returns:
            Tensor with shape [C, H, W]

        Raises:
            FileNotFoundError: If ``image_path`` does not exist
            ValueError: If the file cannot be opened or converted
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found: {image_path}")

        try:
            # PIL opens files lazily; the context manager guarantees the
            # underlying file handle is released once the pixels are read.
            with Image.open(image_path) as opened:
                img = opened if mode is None else opened.convert(mode)
                img_array = np.asarray(img)

            # Ensure 3D array and transpose [H, W, C] -> [C, H, W].
            # astype() copies, so the tensor owns writable memory.
            img_array = np.atleast_3d(img_array).transpose(2, 0, 1).astype(np.float32)
            return torch.from_numpy(img_array)
        except Exception as e:
            raise ValueError(f"Failed to load {label} {image_path}: {str(e)}") from e

    @staticmethod
    def load_color_image(image_path: str) -> torch.Tensor:
        """
        Load color image using PIL

        The image is kept in whatever mode the file is stored in, so the
        channel count follows the file.

        Args:
            image_path: Path to the image file

        Returns:
            Tensor with shape [C, H, W] where C is number of channels

        Raises:
            FileNotFoundError: If ``image_path`` does not exist
            ValueError: If the file cannot be opened or converted
        """
        return ImageLoader._load_with_pil(image_path, None, "image")

    @staticmethod
    def load_ir_image(image_path: str) -> torch.Tensor:
        """
        Load IR image and convert to grayscale

        Args:
            image_path: Path to the IR image file

        Returns:
            Tensor with shape [1, H, W]

        Raises:
            FileNotFoundError: If ``image_path`` does not exist
            ValueError: If the file cannot be opened or converted
        """
        return ImageLoader._load_with_pil(image_path, 'L', "IR image")

    @staticmethod
    def rgb_to_y_channel(img_tensor: torch.Tensor) -> torch.Tensor:
        """
        Convert RGB image to Y channel (luminance)

        Computes the weighted sum ``0.299 * R + 0.587 * G + 0.114 * B``.

        Args:
            img_tensor: RGB image tensor with shape [3, H, W]

        Returns:
            Y channel tensor with shape [1, H, W]

        Raises:
            ValueError: If the first dimension is not of size 3
        """
        if img_tensor.shape[0] != 3:
            raise ValueError("Input tensor must have 3 channels for RGB conversion")

        # Splitting with chunk size 1 keeps the channel axis on each piece.
        r, g, b = torch.split(img_tensor, 1, dim=0)
        return 0.299 * r + 0.587 * g + 0.114 * b


class DatasetBuilder:
    """Helpers that assemble the file lists backing a dataset"""

    @staticmethod
    def create_image_pairs(data_dir: str, num_samples: int = 8000,
                           ir_prefix: str = 'IR', vis_prefix: str = 'VIS',
                           file_extension: str = '.png') -> List[Tuple[str, str]]:
        """
        Collect the (IR, VIS) file pairs that are present on disk

        File names are built as ``<prefix><number><extension>`` with numbers
        running from 1 to ``num_samples`` inclusive. A number is kept only
        when both of its files exist; otherwise a warning is printed and the
        number is skipped.

        Args:
            data_dir: Directory containing the images
            num_samples: Highest file number to probe
            ir_prefix: Prefix for IR image files
            vis_prefix: Prefix for VIS image files
            file_extension: File extension for images

        Returns:
            List of tuples containing (IR_path, VIS_path)

        Raises:
            FileNotFoundError: If ``data_dir`` does not exist
            ValueError: If not a single complete pair was found
        """
        if not os.path.exists(data_dir):
            raise FileNotFoundError(f"Data directory not found: {data_dir}")

        found = []

        # Files are numbered from 1, so iterate over the numbers directly.
        for number in range(1, num_samples + 1):
            ir_file = os.path.join(data_dir, f"{ir_prefix}{number}{file_extension}")
            vis_file = os.path.join(data_dir, f"{vis_prefix}{number}{file_extension}")

            complete = os.path.exists(ir_file) and os.path.exists(vis_file)
            if not complete:
                print(f"Warning: Missing image pair for index {number}")
                continue

            found.append((ir_file, vis_file))

        if len(found) == 0:
            raise ValueError(f"No valid image pairs found in {data_dir}")

        return found


class FusionDataset(data.Dataset):
    """
    Dataset class for loading IR and VIS image pairs for fusion training

    Each item is a pair of single-channel tensors loaded with
    ``ImageLoader.load_grayscale_image``. The list of pairs is built once in
    ``__init__`` from the files that exist at that moment.
    """

    def __init__(self,
                 root_dir: Optional[str] = None,
                 data_dir: str = r'D:\Image_Data\RoadScene-master\128/',
                 num_samples: int = 8000,
                 normalize: bool = True,
                 transform: Optional[callable] = None,
                 train: bool = True):
        """
        Initialize the fusion dataset

        Args:
            root_dir: Root directory (for compatibility, can be None); it is
                accepted but not used
            data_dir: Directory containing IR and VIS images
            num_samples: Number of image pairs to load
            normalize: Whether to divide pixel values by 255.0
            transform: Optional transform to apply to images
            train: Whether this is for training (for compatibility); stored
                but not otherwise consulted by this class

        Raises:
            FileNotFoundError: If ``data_dir`` does not exist
            ValueError: If no complete image pair is found in ``data_dir``
        """
        self.data_dir = data_dir
        self.num_samples = num_samples
        self.normalize = normalize
        self.transform = transform
        self.train = train

        # Validate data directory
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"Data directory not found: {self.data_dir}")

        # Create dataset pairs (only pairs whose two files both exist)
        self.image_pairs = DatasetBuilder.create_image_pairs(
            self.data_dir,
            self.num_samples
        )

        print(f"Loaded {len(self.image_pairs)} image pairs from {self.data_dir}")

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get an image pair by index

        NOTE: the transform is called separately on the IR and the VIS
        image, so a transform that draws random parameters may treat the
        two images differently.

        Args:
            index: Index of the image pair (negative indices count from the
                end, as with a list)

        Returns:
            Tuple of (IR_image, VIS_image) tensors

        Raises:
            IndexError: If ``index`` is outside the dataset
        """
        size = len(self.image_pairs)
        if not -size <= index < size:
            raise IndexError(f"Index {index} out of range for dataset of size {size}")

        ir_path, vis_path = self.image_pairs[index]

        try:
            # Load images
            ir_image = ImageLoader.load_grayscale_image(ir_path)
            vis_image = ImageLoader.load_grayscale_image(vis_path)

            # Normalize if required
            if self.normalize:
                ir_image = ir_image / 255.0
                vis_image = vis_image / 255.0

            # Apply transforms if provided
            if self.transform:
                ir_image = self.transform(ir_image)
                vis_image = self.transform(vis_image)

            return ir_image, vis_image

        except Exception as e:
            print(f"Error loading image pair {index}: {str(e)}")
            # Bare raise re-raises the active exception.
            raise

    def __len__(self) -> int:
        """Return the size of the dataset"""
        return len(self.image_pairs)

    def get_image_info(self, index: int) -> dict:
        """
        Get information about an image pair

        Args:
            index: Index of the image pair

        Returns:
            Dictionary with the index, both paths, and whether each file
            currently exists on disk

        Raises:
            IndexError: If ``index`` is outside the dataset
        """
        size = len(self.image_pairs)
        if not -size <= index < size:
            raise IndexError(f"Index {index} out of range")

        ir_path, vis_path = self.image_pairs[index]

        return {
            'index': index,
            'ir_path': ir_path,
            'vis_path': vis_path,
            'ir_exists': os.path.exists(ir_path),
            'vis_exists': os.path.exists(vis_path)
        }

    def validate_dataset(self) -> dict:
        """
        Validate the dataset and return statistics

        Re-checks every stored pair against the file system. Missing IR and
        missing VIS files are counted independently, so a pair with both
        files missing contributes to both counters.

        Returns:
            Dictionary with validation statistics
        """
        valid_pairs = 0
        missing_ir = 0
        missing_vis = 0

        for ir_path, vis_path in self.image_pairs:
            ir_exists = os.path.exists(ir_path)
            vis_exists = os.path.exists(vis_path)

            if ir_exists and vis_exists:
                valid_pairs += 1
                continue

            # Separate checks: an elif chain would hide a missing VIS file
            # whenever the IR file is missing as well.
            if not ir_exists:
                missing_ir += 1
            if not vis_exists:
                missing_vis += 1

        return {
            'total_pairs': len(self.image_pairs),
            'valid_pairs': valid_pairs,
            'missing_ir': missing_ir,
            'missing_vis': missing_vis,
            'validity_ratio': valid_pairs / len(self.image_pairs) if self.image_pairs else 0
        }


# For backward compatibility: `fusiondata` is an alias bound to the same class
# object as FusionDataset, so either name can be used to construct the dataset.
fusiondata = FusionDataset