#!/usr/bin/env python3
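"""
Simple test script: run the refinement engine on a single image.

Set the TEST_IMAGE_PATH environment variable to choose the input image.
"""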
import logging
import sys
import os

# Put this script's own directory first on the import path so that sibling
# modules (e.g. refinement_engine) resolve regardless of the working directory.
sys.path[:0] = [os.path.dirname(os.path.abspath(__file__))]

from refinement_engine import RefinementEngine

# Root logging setup: INFO and above goes both to stdout and to a UTF-8 log
# file named refinement_test.log in the current working directory.
# Note: logging.basicConfig is a no-op if the root logger already has handlers.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_log_handlers = [
    logging.StreamHandler(sys.stdout),
    logging.FileHandler('refinement_test.log', encoding='utf-8'),
]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_log_handlers)

# Logger used throughout this script.
logger = logging.getLogger(__name__)


def main():
    """Run one end-to-end refinement pass on a single image and report it.

    All model/engine settings are hard-coded below; only the input image can be
    overridden, via the ``TEST_IMAGE_PATH`` environment variable. The final
    response and metadata are written to ``logger`` and also printed to stdout.

    Raises:
        FileNotFoundError: If the test image does not exist. This is checked
            before the engine (and its model) is constructed, so a bad path
            fails immediately.
        Exception: Any error raised while constructing the engine or running
            ``refine`` is logged with its traceback and then re-raised.
    """
    # --- Configuration -----------------------------------------------------
    model_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "exp/llada_v_lora_rank64_1227",
    )
    model_base = "GSAI-ML/LLaDA-V"
    model_name = "llava_llada_lora"
    vision_tower_path = "google/siglip2-so400m-patch14-384"
    device = "cuda:0"
    # Image input mode: "original" (full image only), "crop" (local crop only),
    # "both" (full image + local crop).
    image_input_mode = "both"
    # Mask mode: "span" (parse spans with TextMiner), "single" (mask only the
    # high-jitter token), "expand" (high-jitter token plus 4 tokens on each side).
    # NOTE(review): this says 4 tokens while mask_expansion=2 is passed below;
    # confirm which one "expand" actually uses.
    mask_mode = "span"
    # Token selection mode: "jitter" (highest-jitter token), "random" (one token
    # at random), "confidence" (lowest-confidence token), "jitter_confidence"
    # (highest-jitter token, falling back to the lowest-confidence token when
    # the jitter is below the threshold).
    token_selection_mode = "jitter"
    # Local refinement mode: "crop" (crop the image to obtain local evidence),
    # "bbox" (draw a red bounding box on the original image).
    local_refinement_mode = "crop"

    # --- Test input (adjust to your setup) ----------------------------------
    image_path = os.getenv("TEST_IMAGE_PATH", "/path/to/test_image.jpg")
    base_instruction = "Please describe the image in detail. Use less absolute directional descriptions. Do not repeat information."

    logger.info("=" * 60)
    logger.info("Starting Refinement Test")
    logger.info(f"Image: {image_path}")
    logger.info(f"Instruction: {base_instruction}")
    logger.info("=" * 60)

    # The default above is only a placeholder: report a missing image before
    # constructing the engine rather than after the model has been loaded.
    if not os.path.isfile(image_path):
        logger.error(f"Test image not found: {image_path} (set TEST_IMAGE_PATH)")
        raise FileNotFoundError(f"Test image not found: {image_path}")

    try:
        engine = RefinementEngine(
            model_path=model_path,
            model_base=model_base,
            model_name=model_name,
            vision_tower_path=vision_tower_path,
            device=device,
            max_steps=6,
            jitter_threshold=0.35,
            mask_expansion=2,
            temp_dir="./cropped_image",
            logger=logger,
            image_input_mode=image_input_mode,
            mask_mode=mask_mode,
            token_selection_mode=token_selection_mode,
            local_refinement_mode=local_refinement_mode,
        )
        final_response, metadata = engine.refine(
            image_path=image_path,
            base_instruction=base_instruction,
        )
    except Exception as e:
        # Top-level boundary: record the full traceback, then let it propagate.
        logger.exception(f"Error during refinement: {e}")
        raise

    # --- Report --------------------------------------------------------------
    separator = "=" * 60
    report = [
        "\n" + separator,
        "FINAL RESULT",
        separator,
        f"Response:\n{final_response}",
        f"\nMetadata: {metadata}",
        separator,
    ]
    for line in report:
        logger.info(line)
    # Also print to stdout without the log-record prefix.
    print("\n".join(report))


if __name__ == "__main__":
    # Entry point: run the refinement test only when executed as a script.
    main()

