"""
Poster Data Loader
Load prompt-logo pairs and answer images
"""

import os
import sys
from pathlib import Path
from typing import List, Optional

# Add current directory to path
_current_dir = Path(__file__).parent
if str(_current_dir) not in sys.path:
    sys.path.insert(0, str(_current_dir))

from poster_config import PosterSample


def find_answer_image(prompt_index: int, answer_base_dir: str) -> Optional[Path]:
    """
    Find the answer image that belongs to a prompt index.

    The base directory is searched first. If nothing matches there, each
    immediate subdirectory is searched in sorted (name) order, so the result
    does not depend on the filesystem's listing order.

    Args:
        prompt_index: Prompt index (e.g., 1, 2, 3...)
        answer_base_dir: Base directory for answer images

    Returns:
        Path of the first matching image file, or None if the base path is
        missing / not a directory or no candidate file exists.
    """
    answer_base_path = Path(answer_base_dir)

    # is_dir() rather than exists(): a path that points at a regular file
    # would otherwise make iterdir() below raise NotADirectoryError.
    if not answer_base_path.is_dir():
        return None

    # Candidate file names in priority order; the first hit wins.
    possible_names = [
        f"{prompt_index:03d}.png",
        f"{prompt_index:03d}.jpg",
        f"{prompt_index}.png",
        f"{prompt_index}.jpg",
        f"answer_{prompt_index:03d}.png",
        f"answer_{prompt_index}.png",
    ]

    def _first_match(directory: Path) -> Optional[Path]:
        """Return the first candidate name that is a file inside *directory*."""
        for name in possible_names:
            candidate = directory / name
            # is_file() so a directory that happens to be named "001.png"
            # is never returned as an image.
            if candidate.is_file():
                return candidate
        return None

    found = _first_match(answer_base_path)
    if found is not None:
        return found

    # Fall back to one level of subdirectories, sorted for determinism.
    subdirs = sorted(entry for entry in answer_base_path.iterdir() if entry.is_dir())
    for subdir in subdirs:
        found = _first_match(subdir)
        if found is not None:
            return found

    return None


def _index_or_none(token: str) -> Optional[int]:
    """Return *token* parsed as an int, or None if it is not an integer."""
    try:
        return int(token)
    except ValueError:
        return None


def _find_banner_pairs(logos_path: Path, num_pairs: int, start_index: int) -> List[tuple]:
    """Collect pairs for the banner layout (prompts and logos share one directory)."""
    candidates = []
    for prompt_file in logos_path.glob("*_prompt.txt"):
        # Format: 001_ethicai_prompt.txt -> ["001", "ethicai", "prompt"]
        parts = prompt_file.stem.split('_')
        if len(parts) < 2:
            continue
        idx_num = _index_or_none(parts[0])
        if idx_num is None or idx_num < start_index:
            continue
        candidates.append((idx_num, prompt_file.name, prompt_file, parts))

    # Numeric order so unpadded names like "9_..." come before "10_...";
    # the file name breaks ties deterministically.
    candidates.sort(key=lambda c: (c[0], c[1]))

    pairs = []
    for idx_num, _, prompt_file, parts in candidates[:max(num_pairs, 0)]:
        index = f"{idx_num:03d}"
        # Everything between the index and the trailing "prompt" token.
        brand_name = '_'.join(parts[1:-1]) if len(parts) > 2 else "unknown"

        # Prefer the cropped logo, fall back to the plain one.
        logo_file = None
        for candidate in (
            logos_path / f"{index}_{brand_name}_cropped.png",
            logos_path / f"{index}_{brand_name}.png",
        ):
            if candidate.exists():
                logo_file = candidate
                break

        pairs.append((prompt_file, logo_file, brand_name, index))
    return pairs


def _pick_poster_logo(logos_path: Path, index: str) -> Optional[Path]:
    """Return the logo for *index*, preferring a ``*_cropped.png`` variant."""
    # Sorted so the choice is deterministic when several logos share an index.
    candidates = sorted(logos_path.glob(f"{index}_*.png"))
    for candidate in candidates:
        if candidate.name.endswith("_cropped.png"):
            return candidate
    return candidates[0] if candidates else None


def _brand_from_logo(logo_file: Path) -> str:
    """Derive the brand name from a ``<index>_<brand>[_cropped].png`` logo file."""
    stem = logo_file.stem
    # Strip only a trailing "_cropped" so brand names that contain the word
    # in the middle are preserved.
    if stem.endswith("_cropped"):
        stem = stem[:-len("_cropped")]
    parts = stem.split('_')
    return '_'.join(parts[1:]) if len(parts) > 1 else "unknown"


def _find_poster_pairs(prompt_path: Path, logos_path: Path, num_pairs: int, start_index: int) -> List[tuple]:
    """Collect pairs for the poster layout (numbered prompts, optional logos)."""
    candidates = []
    for prompt_file in prompt_path.glob("*.txt"):
        # Format: 001.txt
        idx_num = _index_or_none(prompt_file.stem)
        if idx_num is None or idx_num < start_index:
            continue
        candidates.append((idx_num, prompt_file.name, prompt_file))

    candidates.sort(key=lambda c: (c[0], c[1]))

    logos_available = logos_path.is_dir()
    pairs = []
    for idx_num, _, prompt_file in candidates[:max(num_pairs, 0)]:
        index = f"{idx_num:03d}"
        # Logos are optional in this layout.
        logo_file = _pick_poster_logo(logos_path, index) if logos_available else None
        if logo_file is not None:
            brand_name = _brand_from_logo(logo_file)
        else:
            brand_name = f"sample_{index}"
        pairs.append((prompt_file, logo_file, brand_name, index))
    return pairs


def find_prompt_logo_pairs(logos_dir: str, prompt_dir: str, num_pairs: int = 10, start_index: int = 1, use_logos: bool = True) -> List[tuple]:
    """
    Find prompt-logo pairs in logos directory and prompt directory

    Supports two modes:
    1. Logo-based setup: logos_dir contains logo files and prompt files
       named like ``001_ethicai_prompt.txt`` (banner format)
    2. No-logo setup: prompts named like ``001.txt`` live in prompt_dir;
       logos in logos_dir are optional (poster format)

    Prompt files are ordered by their numeric index (ties broken by file
    name). Files whose index is not an integer or is below start_index are
    skipped.

    Args:
        logos_dir: Logos directory path (contains prompt files if use_logos=True)
        prompt_dir: Prompt files directory path (used if use_logos=False)
        num_pairs: Maximum number of pairs to return; values <= 0 yield []
        start_index: Smallest index to include (1-based)
        use_logos: Whether to use logo-based setup (look for prompts in logos_dir)

    Returns:
        List of tuples: (prompt_file, logo_file, brand_name, index), where
        index is the numeric index zero-padded to 3 digits and logo_file may
        be None.
    """
    logos_path = Path(logos_dir)

    if use_logos:
        # Banner-style setup: prompts and logos are in the same directory
        if not logos_path.is_dir():
            print(f"❌ Logos directory not found: {logos_dir}")
            return []
        return _find_banner_pairs(logos_path, num_pairs, start_index)

    # No-logo setup: separate prompt directory
    prompt_path = Path(prompt_dir)
    if not prompt_path.is_dir():
        print(f"❌ Prompt directory not found: {prompt_dir}")
        return []
    return _find_poster_pairs(prompt_path, logos_path, num_pairs, start_index)


def read_prompt(prompt_file: Path) -> Optional[str]:
    """
    Load the text of a prompt file.

    The file is decoded as UTF-8 and surrounding whitespace is stripped. If
    the remaining text both starts and ends with a double quote, one quote
    character is removed from each end.

    Args:
        prompt_file: Path to the prompt file

    Returns:
        The cleaned prompt text, or None if opening/reading the file raised
    """
    try:
        with open(prompt_file, 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
    except Exception as e:
        # Best-effort: report the problem and let the caller skip this prompt.
        print(f"❌ Error reading prompt file {prompt_file}: {e}")
        return None

    text = raw_text.strip()
    is_quoted = text.startswith('"') and text.endswith('"')
    return text[1:-1] if is_quoted else text


def load_poster_samples(
    logos_dir: str,
    prompt_dir: str,
    answer_base_dir: str,
    num_samples: int = 10,
    start_index: int = 1,
    use_logos: bool = True
) -> List[PosterSample]:
    """
    Load poster samples

    Discovers prompt/logo pairs with find_prompt_logo_pairs, looks up the
    answer image for each with find_answer_image, and builds one PosterSample
    per pair. Pairs with a non-numeric index or without an answer image are
    skipped with a warning.

    Args:
        logos_dir: Logos directory path (contains logos and prompts if use_logos=True)
        prompt_dir: Prompt files directory path (used if use_logos=False)
        answer_base_dir: Answer images base directory
        num_samples: Number of pairs to consider
        start_index: Start index (1-based)
        use_logos: Whether to use logo-based setup

    Returns:
        List of PosterSample objects (may be shorter than num_samples when
        pairs are skipped)
    """
    samples = []
    pairs = find_prompt_logo_pairs(logos_dir, prompt_dir, num_samples, start_index, use_logos=use_logos)

    for prompt_file, logo_file, brand_name, index in pairs:
        # Find corresponding answer image
        try:
            idx_int = int(index)
        except ValueError:
            print(f"⚠️ Invalid index: {index}, skipping")
            continue

        answer_image_path = find_answer_image(idx_int, answer_base_dir)

        if not answer_image_path:
            print(f"⚠️ Answer image not found for {index}_{brand_name}, skipping")
            continue

        # Resolve the logo once so the sample and the status message agree;
        # a logo path that no longer exists is treated as "no logo".
        logo = logo_file if logo_file and logo_file.exists() else None

        sample = PosterSample(
            index=index,
            brand_name=brand_name,
            prompt_file=prompt_file,
            logo_file=logo,
            answer_image_path=answer_image_path
        )
        samples.append(sample)
        logo_status = f" (logo: {logo.name})" if logo else " (no logo)"
        print(f"✅ Loaded sample {index}: {brand_name}{logo_status}")

    return samples


def get_total_samples(logos_dir: str, prompt_dir: str, use_logos: bool = True) -> int:
    """
    Get total number of available samples

    Counts prompt files whose index token is an integer. This is the same
    filename rule find_prompt_logo_pairs uses, so stray files such as
    ``notes.txt`` or ``readme_prompt.txt`` are not counted. start_index is
    not applied here.

    Args:
        logos_dir: Logos directory path
        prompt_dir: Prompt files directory path
        use_logos: Whether to use logo-based setup

    Returns:
        Number of prompt files with an integer index (0 if the directory is
        missing or not a directory)
    """
    def _is_int(token: str) -> bool:
        try:
            int(token)
        except ValueError:
            return False
        return True

    if use_logos:
        # Banner-style: "<index>_<brand>_prompt.txt" files in logos_dir
        search_dir = Path(logos_dir)
        if not search_dir.is_dir():
            return 0
        tokens = [f.stem.split('_')[0] for f in search_dir.glob("*_prompt.txt")]
    else:
        # No-logo: "<index>.txt" files in a separate prompt_dir
        search_dir = Path(prompt_dir)
        if not search_dir.is_dir():
            return 0
        tokens = [f.stem for f in search_dir.glob("*.txt")]

    return sum(1 for token in tokens if _is_int(token))

