"""
LLM call logging utility.

Logs each LLM call's inputs (text and images) to a local log folder.
"""
import os
import time
import json
from typing import Optional, List
import numpy as np
from PIL import Image


# Module-level count of logged calls; log_llm_call increments it before each
# write and embeds it (zero-padded to six digits) in the call folder name.
_call_counter = 0


def get_log_folder() -> Optional[str]:
    """Resolve the folder that LLM call logs should be written to.

    Lookup order:

    1. The absl flag ``log_folder_exp``, when absl is importable, the flag is
       defined, and its value is a non-blank string.
    2. The ``LLM_LOG_FOLDER`` environment variable, when non-blank.

    Any problem reading the absl flag (absl missing, flag undefined, flags not
    parsed yet, unexpected value type) is treated as "not configured" and the
    lookup falls through to the environment variable.

    Returns:
        The folder path with surrounding whitespace stripped, or ``None`` when
        neither source provides a usable value.
    """
    try:
        from absl import flags
        try:
            # getattr's default only absorbs AttributeError (flag not defined).
            # Reading a defined flag before flags are parsed raises
            # UnparsedFlagAccessError, a flags.Error subclass, which is why
            # flags.Error is caught below rather than only
            # IllegalFlagValueError.
            flag_value = getattr(flags.FLAGS, 'log_folder_exp', None)
            if flag_value and flag_value.strip():
                return flag_value.strip()
        except (AttributeError, RuntimeError, flags.Error):
            pass
    except (ImportError, AttributeError):
        pass

    env_folder = os.environ.get('LLM_LOG_FOLDER', None)
    if env_folder and env_folder.strip():
        return env_folder.strip()

    return None


def log_llm_call(
    user_prompt: Optional[str] = None,
    system_prompt: Optional[str] = None,
    images: Optional[List[np.ndarray]] = None,
    model_name: str = "unknown",
    call_type: str = "predict_mm",
    log_folder: Optional[str] = None,
    additional_info: Optional[dict] = None,
) -> Optional[str]:
    """Write the inputs of one LLM call to ``<log_folder>/llm_calls/<call_id>/``.

    The call folder is named ``call_<counter>_<YYYYmmdd_HHMMSS>`` and may
    contain:

    * ``system_prompt.txt`` - only when ``system_prompt`` is non-blank.
    * ``user_prompt.txt`` - only when ``user_prompt`` is non-empty.
    * ``image_<idx>.jpg`` - one JPEG per non-``None`` entry of ``images``.
    * ``metadata.json`` - call id, model name, call type, timestamp, prompt
      flags, image count and per-image info.

    Args:
        user_prompt: User prompt text, or ``None``.
        system_prompt: System prompt text, or ``None``.
        images: Images as arrays accepted by ``PIL.Image.fromarray``.
            ``None`` is treated as "no images"; ``None`` entries are skipped
            but still counted in ``num_images``.
        model_name: Recorded verbatim in the metadata.
        call_type: Recorded verbatim in the metadata.
        log_folder: Root log folder. When ``None`` it is resolved with
            ``get_log_folder()``.
        additional_info: Extra key/value pairs merged into the metadata.
            Keys that clash with the built-in ones override them.

    Returns:
        The path of the call folder, or ``None`` when no log folder is
        configured (nothing is written in that case).

    Raises:
        OSError: If a directory or file cannot be created or written.
        TypeError: If ``additional_info`` holds values ``json`` cannot
            serialise.

    A failure to save an individual image does not raise: a warning is
    printed and the error text is stored in that image's metadata entry.
    """
    global _call_counter

    if log_folder is None:
        log_folder = get_log_folder()

    if not log_folder or not log_folder.strip():
        return None

    # Normalise once so a missing/None argument behaves like "no images" and
    # the caller's sequence is never shared with a default object.
    images = [] if images is None else list(images)

    llm_calls_folder = os.path.join(log_folder, 'llm_calls')
    os.makedirs(llm_calls_folder, exist_ok=True)

    _call_counter += 1
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    call_id = f"call_{_call_counter:06d}_{timestamp}"

    call_folder = os.path.join(llm_calls_folder, call_id)
    os.makedirs(call_folder, exist_ok=True)

    # The same conditions decide both whether a prompt file is written and
    # what the metadata flags report, so the two can never disagree.
    has_system_prompt = bool(system_prompt and system_prompt.strip())
    has_user_prompt = bool(user_prompt)

    if has_system_prompt:
        system_prompt_file = os.path.join(call_folder, 'system_prompt.txt')
        with open(system_prompt_file, 'w', encoding='utf-8') as f:
            f.write(system_prompt)

    if has_user_prompt:
        user_prompt_file = os.path.join(call_folder, 'user_prompt.txt')
        with open(user_prompt_file, 'w', encoding='utf-8') as f:
            f.write(user_prompt)

    image_info = []
    for idx, image in enumerate(images):
        if image is None:
            continue
        try:
            image_name = f'image_{idx}.jpg'
            img = Image.fromarray(image)
            # JPEG has no alpha channel or palette support; without this
            # conversion saving such images always fails.
            if img.mode in ('RGBA', 'LA', 'P'):
                img = img.convert('RGB')
            img.save(os.path.join(call_folder, image_name),
                     format='JPEG', quality=95)
            image_info.append({
                'index': idx,
                'path': image_name,
                'shape': list(image.shape) if hasattr(image, 'shape') else None,
            })
        except Exception as e:
            # Best effort: one bad image must not abort logging of the call.
            print(f"⚠️  Warning: Failed to save image {idx}: {e}")
            image_info.append({
                'index': idx,
                'error': str(e),
            })

    metadata = {
        'call_id': call_id,
        'model_name': model_name,
        'call_type': call_type,
        'timestamp': timestamp,
        'has_system_prompt': has_system_prompt,
        'has_user_prompt': has_user_prompt,
        'num_images': len(images),
        'image_info': image_info,
    }

    if additional_info:
        metadata.update(additional_info)

    metadata_file = os.path.join(call_folder, 'metadata.json')
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    return call_folder

