#!/usr/bin/env python3
import os
import argparse
import json
from pathlib import Path
import dotenv
from openai import OpenAI    # pip install --upgrade openai>=1.0.0
import litellm  # pip install litellm

# Load environment variables from a .env file at import time, before any
# client below is constructed.
# NOTE(review): nothing visible in this file reads os.environ directly (the
# api_key values passed below are hardcoded); presumably the OpenAI/litellm
# clients pick settings up from the environment — confirm this is still needed.
dotenv.load_dotenv()

def test_opeani_inference(endpoint: str, model_name: str) -> None:
    """
    Send a tool-calling chat request through the OpenAI client and print it.

    The conversation and the tool schema are read from ``messages_tool.json``
    and ``tools.json`` in the current working directory. The request sets
    ``tool_choice="required"`` so the model is asked to emit a tool call.

    Args:
        endpoint: ``host:port`` of an OpenAI-compatible server; requests are
            sent to ``http://{endpoint}/v1``.
        model_name: Model identifier passed through unchanged to the server.

    Raises:
        FileNotFoundError: If either JSON input file is missing.
        json.JSONDecodeError: If either input file is not valid JSON.
    """
    base_url = f"http://{endpoint}/v1"

    client = OpenAI(
        base_url=base_url,
        # Placeholder key; presumably the target server does not validate it
        # — confirm before pointing this at an authenticated endpoint.
        api_key="ollama",
        timeout=60,
    )

    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open("messages_tool.json", "r", encoding="utf-8") as f:
        messages = json.load(f)
    with open("tools.json", "r", encoding="utf-8") as f:
        tools = json.load(f)

    resp = client.chat.completions.create(
        model=model_name,
        messages=messages,
        tools=tools,
        tool_choice="required",
    )

    message = resp.choices[0].message
    print("=== Response ===")
    if message.tool_calls:
        print("Tool calls:")
        for call in message.tool_calls:
            print(f"  {call.function.name}: {call.function.arguments}")
    else:
        # No tool call came back; show whatever text the model produced.
        print(message.content)
    print(f"Usage: {resp.usage}")

def test_litellm_inference(endpoint: str, model_name: str) -> None:
    """
    Send a tool-calling chat request through litellm and print it.

    Reads the same ``messages_tool.json`` / ``tools.json`` inputs from the
    current working directory and sets ``tool_choice="required"``.

    Args:
        endpoint: ``host:port`` of an OpenAI-compatible server; requests are
            sent to ``http://{endpoint}/v1``.
        model_name: Model identifier on the server; it is prefixed with
            ``openai/`` for litellm.

    Raises:
        FileNotFoundError: If either JSON input file is missing.
        json.JSONDecodeError: If either input file is not valid JSON.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open("messages_tool.json", "r", encoding="utf-8") as f:
        messages = json.load(f)
    with open("tools.json", "r", encoding="utf-8") as f:
        tools = json.load(f)

    resp = litellm.completion(
        # The "openai/" prefix selects litellm's OpenAI-compatible provider,
        # which is what lets it talk to the custom base_url below.
        model=f"openai/{model_name}",
        messages=messages,
        tools=tools,
        tool_choice="required",
        base_url=f"http://{endpoint}/v1",
        # Placeholder key; presumably not validated by the server — confirm.
        api_key="ollama",
    )

    message = resp.choices[0].message
    print("=== LiteLLM Response ===")
    if message.tool_calls:
        print("Tool calls:")
        for call in message.tool_calls:
            print(f"  {call.function.name}: {call.function.arguments}")
    else:
        # No tool call came back; show whatever text the model produced.
        print(message.content)
    print(f"Usage: {resp.usage}")

def main(argv: "list[str] | None" = None) -> None:
    """
    Parse command-line options and run one tool-calling request.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) parses the real command line.
    """
    parser = argparse.ArgumentParser(description="Quick RITS inference test")
    parser.add_argument(
        "--endpoint",
        default="34.66.144.59:8002",
        # The value is interpolated as http://{endpoint}/v1, so it must be a
        # host:port pair, not a UI slug.
        help=(
            "host:port of the OpenAI-compatible server; requests go to "
            "http://<endpoint>/v1 (default: %(default)s)"
        ),
    )
    parser.add_argument(
        "--model",
        default="gpt-oss:120b",
        help="Model name as served by the endpoint (default: %(default)s)",
    )
    parser.add_argument(
        "--litellm",
        action="store_true",
        help="Use litellm instead of OpenAI client",
    )
    args = parser.parse_args(argv)

    if args.litellm:
        test_litellm_inference(args.endpoint, args.model)
    else:
        test_opeani_inference(args.endpoint, args.model)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
