import argparse
from pathlib import Path

from components import RetrievalConfig, RetrievalPipeline


def _build_parser():
    """Build the argument parser for the standalone retrieval runner.

    Returns:
        argparse.ArgumentParser: parser exposing every option consumed by
        ``main``.
    """
    parser = argparse.ArgumentParser(description='Standalone Retrieval Runner')

    # Which data to run on and where to write the results.
    parser.add_argument('--game', default='dyinglight2', help='Game name')
    parser.add_argument('--segment_id', type=int, required=True, help='Segment ID')
    parser.add_argument('--output_file', help='Output file path (optional, auto-generated)')

    # Retrieval backend and embedding settings. The key/URL defaults are
    # placeholders; real values are expected on the command line.
    parser.add_argument('--retrieval_method', default='vector', choices=['vector', 'bm25'],
                        help='Retrieval method: vector (vector retrieval) or bm25 (BM25 retrieval)')
    parser.add_argument('--api_key', default='{your api_key}')
    parser.add_argument('--base_url', default='{your base_url}')
    parser.add_argument('--siliconflow_api_key', default='{your api_key}')
    parser.add_argument('--siliconflow_base_url', default='{your base_url}')
    parser.add_argument('--embedding_model', default='text-embedding-3-small')
    parser.add_argument('--embedding_service', default='openai', choices=['openai', 'siliconflow'],
                        help='Embedding service type: openai or siliconflow')
    parser.add_argument('--top_k', type=int, default=5)

    # BM25 tuning parameters.
    parser.add_argument('--bm25_k1', type=float, default=1.2, help='BM25 parameter k1 (term frequency saturation)')
    parser.add_argument('--bm25_b', type=float, default=0.75, help='BM25 parameter b (document length normalization)')

    # Run-time switches.
    parser.add_argument('--force_rebuild', action='store_true', help='Force rebuild index')

    # Both switches below default to True. A lone ``store_true`` flag with
    # ``default=True`` can never be turned off, so each one gets a paired
    # ``--no_*`` flag that writes False to the same destination.
    timeless_group = parser.add_mutually_exclusive_group()
    timeless_group.add_argument('--include_timeless', action='store_true', default=True,
                                help='Include timeless data')
    timeless_group.add_argument('--no_include_timeless', action='store_false', dest='include_timeless',
                                help='Exclude timeless data')

    verbose_group = parser.add_mutually_exclusive_group()
    verbose_group.add_argument('--verbose', action='store_true', default=True, help='Verbose output')
    verbose_group.add_argument('--no_verbose', action='store_false', dest='verbose',
                               help='Disable verbose output')

    return parser


def _resolve_output_file(args, config):
    """Return the output path for this run, creating its parent directory.

    Args:
        args: Parsed command-line namespace produced by ``_build_parser``.
        config: The ``RetrievalConfig`` for this run; its ``output_dir`` is
            used when no explicit ``--output_file`` was given.

    Returns:
        pathlib.Path: the file the retrieval results should be written to.
    """
    if args.output_file:
        output_file = Path(args.output_file)
    else:
        # Encode the retrieval setup in the file name so runs with different
        # methods/models do not overwrite each other.
        if args.retrieval_method == 'bm25':
            method_name = "bm25"
        else:
            model_name = args.embedding_model.replace('-', '_').replace('/', '_')
            method_name = f"{args.embedding_service}_{model_name}"

        filename = f"retrieval_{args.game}_segment_{args.segment_id}_{method_name}_k{args.top_k}.jsonl"
        output_file = Path(config.output_dir) / filename

    # Make sure the destination directory exists for explicit paths too,
    # not only for the auto-generated one.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    return output_file


def main():
    """Run retrieval for every QA pair of one segment from the command line.

    Parses the command-line options, builds a ``RetrievalConfig`` and
    ``RetrievalPipeline`` from them, loads the QA pairs of the requested
    segment and passes them to ``batch_retrieve_qa_pairs`` together with the
    output file path.

    Returns:
        int: process exit status. 1 when the pipeline fails to initialize or
        no QA pairs are found, 0 otherwise. Per-item retrieval errors are
        reported on stdout but do not change the exit status.
    """
    args = _build_parser().parse_args()

    config = RetrievalConfig(
        game_name=args.game,
        target_segment_id=args.segment_id,
        retrieval_method=args.retrieval_method,
        api_key=args.api_key,
        base_url=args.base_url,
        siliconflow_api_key=args.siliconflow_api_key,
        siliconflow_base_url=args.siliconflow_base_url,
        embedding_model=args.embedding_model,
        embedding_service=args.embedding_service,
        top_k=args.top_k,
        bm25_k1=args.bm25_k1,
        bm25_b=args.bm25_b,
        force_rebuild=args.force_rebuild,
        include_timeless=args.include_timeless,
        verbose=args.verbose,
    )

    pipeline = RetrievalPipeline(config)
    if not pipeline.initialize():
        print("Failed to initialize retrieval pipeline")
        return 1

    qa_pairs = pipeline.load_qa_pairs(args.segment_id)
    if not qa_pairs:
        print("No QA pairs found")
        return 1

    output_file = _resolve_output_file(args, config)
    results = pipeline.batch_retrieve_qa_pairs(qa_pairs, str(output_file))

    # A result entry carrying an 'error' key marks a failed item.
    error_count = sum(1 for result in results if 'error' in result)
    if error_count > 0:
        print(f"Error retrieving for {error_count} items")

    return 0


if __name__ == '__main__':
    # Propagate main()'s status so shells and schedulers can detect failures.
    raise SystemExit(main())