"""
Tamper Detection Data Handler

This module provides a data handler for the tamper detection task.
"""

import base64
import os
from pathlib import Path
from typing import List, Tuple

from PIL import Image
import io

from src.core.config import config_manager
from src.utils.decorator_utils import with_logger


class TamperDetectionDataHandler:
    """
    Data handler for the tamper detection task.

    Loads every (non-hidden) file from an ``original`` and a ``tampered``
    directory, base64-encodes it and records its image type. Files are sorted
    by path, and the i-th original and the i-th tampered image form pair ``i``.

    Attributes:
        font_semantics: Dataset variant, either "font" or "semantics".
        original_b64_lst / original_type_lst: Base64 payloads and lowercase
            image types of the original images.
        tampered_b64_lst / tampered_type_lst: Same for the tampered images.
    """

    @with_logger
    def __init__(self, font_semantics: str):
        """
        Initialise the tamper detection data handler and load the dataset.

        Args:
            font_semantics: Whether to use "font" or "semantics" data.

        Raises:
            ValueError: If ``font_semantics`` is neither "font" nor "semantics".
        """
        logger.info(f"Initialising TamperDetectionDataHandler with font_semantics={font_semantics}")
        self.font_semantics = font_semantics
        self.original_b64_lst: List[str] = []
        self.original_type_lst: List[str] = []
        self.tampered_b64_lst: List[str] = []
        self.tampered_type_lst: List[str] = []

        self.init_dataset()
        logger.info("TamperDetectionDataHandler initialisation complete")

    @with_logger
    def init_dataset(self) -> None:
        """
        Initialise the dataset by loading the original and tampered images.

        The directories are taken from the ``paths.data`` config value:
        - if it is the default ``"data"``, images are read from
          ``data/tampering_detection_<font|semantics>/{original,tampered}``;
        - otherwise the configured path is assumed to contain
          ``original/`` and ``tampered/`` directly.

        Raises:
            ValueError: If ``self.font_semantics`` is not a known variant.
        """
        logger.info("Initialising dataset")

        @with_logger
        def _load_data(dir_path: Path) -> Tuple[List[str], List[str]]:
            """
            Load and base64-encode every image in a directory.

            Args:
                dir_path: The directory path.

            Returns:
                A tuple of (base64_encoded_images, image_types), both ordered
                by sorted file path.
            """
            b64e_lst: List[str] = []
            type_lst: List[str] = []

            logger.info(f"Loading images from directory: {dir_path}")
            # Skip subdirectories and hidden files (names starting with ".").
            # Sorting by path string keeps originals and tampered images aligned.
            img_lst = sorted(
                str(file)
                for file in dir_path.iterdir()
                if file.is_file() and not file.name.startswith(".")
            )
            logger.info(f"Found {len(img_lst)} images in directory")

            for img_path in img_lst:
                logger.info(f"Processing image: {img_path}")
                # Read the file once; the same bytes feed both the base64
                # encoding and the PIL format detection below.
                with open(img_path, "rb") as image:
                    raw = image.read()
                b64e_lst.append(base64.b64encode(raw).decode())

                # Detect the real format with PIL (e.g. 'JPEG' -> 'jpeg').
                # This is best-effort: PIL can raise several unrelated exception
                # types for unreadable data, so any failure falls back to the
                # file extension instead of aborting the whole load.
                img_type = ""
                try:
                    with Image.open(io.BytesIO(raw)) as img:
                        # format may be None when PIL cannot name it.
                        img_type = (img.format or "").lower()
                except Exception as e:
                    logger.warning(f"Failed to detect image type using PIL for {img_path}: {str(e)}")

                if img_type:
                    logger.info(f"Detected image type: {img_type}")
                else:
                    # Fallback to the extension; default to jpeg if there is none.
                    img_type = os.path.splitext(img_path)[1].lstrip('.').lower() or 'jpeg'
                    logger.info(f"Using fallback image type: {img_type}")

                type_lst.append(img_type)

            return b64e_lst, type_lst

        # Sub-directory used under the default "data" root for each variant.
        default_subdirs = {
            "font": "tampering_detection_font",
            "semantics": "tampering_detection_semantics",
        }

        logger.info(f"Setting up dataset paths for font_semantics: {self.font_semantics}")
        if self.font_semantics not in default_subdirs:
            logger.error(f"Unknown font_semantics value: {self.font_semantics}")
            raise ValueError(f"Unknown font_semantics: {self.font_semantics}")

        data_path = config_manager.get("paths.data", "data")
        logger.info(f"PATHS.data: {data_path}")
        if data_path != "data":
            # External data path - assume it already has the correct structure
            base_dir = Path(data_path)
        else:
            # Default data path - use the standard per-variant structure
            base_dir = Path(data_path) / default_subdirs[self.font_semantics]
        dir_path_original = base_dir / "original"
        dir_path_tampered = base_dir / "tampered"
        logger.info(
            f"Using {self.font_semantics} dataset paths: "
            f"original={dir_path_original}, tampered={dir_path_tampered}"
        )

        logger.info("Loading original images")
        self.original_b64_lst, self.original_type_lst = _load_data(dir_path_original)
        logger.info("Loading tampered images")
        self.tampered_b64_lst, self.tampered_type_lst = _load_data(dir_path_tampered)

        n_original = len(self.original_b64_lst)
        n_tampered = len(self.tampered_b64_lst)
        if n_original != n_tampered:
            # Pairs are formed purely by sorted position, so unequal counts
            # leave some images without a partner (and may misalign pairs).
            logger.warning(
                f"Mismatched dataset: {n_original} original vs {n_tampered} tampered images; "
                f"only the first {self._pair_count()} pairs are usable"
            )

        logger.info(f"Dataset initialisation complete. Loaded {self._pair_count()} image pairs")

    def _pair_count(self) -> int:
        """Return the number of complete (original, tampered) image pairs."""
        return min(len(self.original_b64_lst), len(self.tampered_b64_lst))

    @with_logger
    def get_data(self, data_id: int) -> Tuple[str, str, str, str]:
        """
        Get the image pair for a specific ID.

        Args:
            data_id: The data ID, in ``range(get_size())``.

        Returns:
            A tuple of (original_image_base64, original_image_type,
                        tampered_image_base64, tampered_image_type)

        Raises:
            IndexError: If ``data_id`` does not refer to a complete pair.
        """
        logger.debug(f"Retrieving data for ID: {data_id}")
        size = self._pair_count()
        if not 0 <= data_id < size:
            logger.error(f"Invalid data_id: {data_id}. Valid range is 0-{size-1}")
            raise IndexError(f"Invalid data_id: {data_id}")

        return (
            self.original_b64_lst[data_id],
            self.original_type_lst[data_id],
            self.tampered_b64_lst[data_id],
            self.tampered_type_lst[data_id],
        )

    @with_logger
    def get_size(self) -> int:
        """
        Get the size of the dataset.

        Returns:
            The number of complete (original, tampered) image pairs; equal to
            the number of original images when both directories match.
        """
        size = self._pair_count()
        logger.info(f"Dataset size: {size} image pairs")
        return size