"""
LLM call logging utility.

Logs each LLM call's inputs (text and images) to a local log folder.
"""
import os
import time
import json
from typing import Optional, List
import numpy as np
from PIL import Image


# Module-level sequence number; log_llm_call bumps it once per logged call so
# each call folder name carries a distinct, increasing prefix.
_call_counter: int = 0


def get_log_folder() -> Optional[str]:
    """Return the configured log folder, or None if logging is disabled.

    Resolution order:
      1. The absl flag ``--log_folder_exp``, when absl is importable, the flag
         is defined, and its value can be read.
      2. The ``LLM_LOG_FOLDER`` environment variable.

    Surrounding whitespace is stripped; empty or whitespace-only values are
    treated as unset.
    """
    try:
        from absl import flags
    except ImportError:
        flags = None

    if flags is not None:
        try:
            # getattr(..., None) covers "flag not defined". Reading a defined
            # flag before absl has parsed argv raises UnparsedFlagAccessError,
            # a flags.Error subclass (not AttributeError), so a plain
            # hasattr() check would let it escape.
            flag_value = getattr(flags.FLAGS, 'log_folder_exp', None)
        except (AttributeError, RuntimeError, flags.Error):
            flag_value = None
        if isinstance(flag_value, str) and flag_value.strip():
            return flag_value.strip()

    # Fallback to environment variable
    env_folder = os.environ.get('LLM_LOG_FOLDER')
    if env_folder and env_folder.strip():
        return env_folder.strip()

    return None


def _save_image_as_jpeg(image, image_path: str) -> None:
    """Convert an array-like image to RGB and write it to *image_path* as JPEG.

    Modes with transparency (RGBA, LA, and palette images converted to RGBA)
    are flattened onto a white background, since JPEG has no alpha channel.
    Raises whatever Pillow raises for unsupported inputs.
    """
    img = Image.fromarray(image)
    if img.mode == 'P':
        img = img.convert('RGBA')
    if img.mode in ('RGBA', 'LA'):
        background = Image.new('RGB', img.size, (255, 255, 255))
        # The last band is the alpha channel; use it as the paste mask.
        background.paste(img, mask=img.split()[-1])
        img = background
    elif img.mode != 'RGB':
        img = img.convert('RGB')
    img.save(image_path, format='JPEG', quality=95)


def _metadata_json_default(obj):
    """JSON fallback for metadata values the json module cannot encode.

    numpy scalars/arrays become native Python numbers/lists; anything else is
    stored as its str() so a logging call never fails on exotic extras.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)


def log_llm_call(
    user_prompt: Optional[str] = None,
    system_prompt: Optional[str] = None,
    images: Optional[List[np.ndarray]] = None,
    model_name: str = "unknown",
    call_type: str = "predict_mm",
    log_folder: Optional[str] = None,
    additional_info: Optional[dict] = None,
) -> Optional[str]:
    """
    Log LLM call inputs to the local log folder.

    Each call is written to its own folder
    ``<log_folder>/llm_calls/call_<NNNNNN>_<YYYYmmdd_HHMMSS>`` containing
    ``system_prompt.txt`` (if non-blank), ``user_prompt.txt`` (if non-empty),
    ``image_<idx>.jpg`` for each non-None image, and ``metadata.json``.

    Args:
        user_prompt: User prompt (optional; skipped when empty)
        system_prompt: System prompt (optional; skipped when blank)
        images: List of images (numpy arrays); None entries are skipped
        model_name: Model name
        call_type: Call type (e.g. 'predict_mm', 'predict')
        log_folder: Log folder path (if None, resolved via get_log_folder())
        additional_info: Extra dict (e.g. temperature, output_format) merged
            into the top level of metadata.json

    Returns:
        Path of the per-call log folder, or None when no log folder is
        configured or the log could not be written.
    """
    global _call_counter

    if log_folder is None:
        log_folder = get_log_folder()

    if not log_folder or not log_folder.strip():
        return None
    log_folder = log_folder.strip()

    if images is None:
        images = []

    _call_counter += 1
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    call_id = f"call_{_call_counter:06d}_{timestamp}"
    call_folder = os.path.join(log_folder, 'llm_calls', call_id)

    # Computed once so the metadata flags always agree with the files written.
    has_system_prompt = bool(system_prompt and system_prompt.strip())
    has_user_prompt = bool(user_prompt)

    try:
        # Also creates the intermediate 'llm_calls' folder.
        os.makedirs(call_folder, exist_ok=True)

        if has_system_prompt:
            with open(os.path.join(call_folder, 'system_prompt.txt'), 'w', encoding='utf-8') as f:
                f.write(system_prompt)

        if has_user_prompt:
            with open(os.path.join(call_folder, 'user_prompt.txt'), 'w', encoding='utf-8') as f:
                f.write(user_prompt)

        image_info = []
        for idx, image in enumerate(images):
            if image is None:
                continue
            image_name = f'image_{idx}.jpg'
            try:
                _save_image_as_jpeg(image, os.path.join(call_folder, image_name))
            except Exception as e:  # best-effort: record the failure and move on
                print(f"⚠️  Warning: Failed to save image {idx}: {e}")
                image_info.append({'index': idx, 'error': str(e)})
                continue
            image_info.append({
                'index': idx,
                'path': image_name,
                'shape': list(image.shape) if hasattr(image, 'shape') else None,
            })

        metadata = {
            'call_id': call_id,
            'model_name': model_name,
            'call_type': call_type,
            'timestamp': timestamp,
            'has_system_prompt': has_system_prompt,
            'has_user_prompt': has_user_prompt,
            'num_images': len(images),
            'image_info': image_info,
        }
        if additional_info:
            metadata.update(additional_info)

        # Serialize before opening the file so an encoding error cannot leave
        # a truncated metadata.json behind.
        metadata_json = json.dumps(
            metadata, indent=2, ensure_ascii=False, default=_metadata_json_default
        )
        with open(os.path.join(call_folder, 'metadata.json'), 'w', encoding='utf-8') as f:
            f.write(metadata_json)
    except (OSError, TypeError, ValueError) as e:
        # Logging must never take down the LLM call it is observing.
        print(f"⚠️  Warning: Failed to write LLM call log to {call_folder}: {e}")
        return None

    return call_folder

