#!/usr/bin/env python3
"""
使用 SGLang 测试 Qwen2.5-3B-Instruct 模型
支持 Tensor Parallel (TP) 在 GPU 1 和 2 上
"""

import os
import time
import requests
import json
import sys
from typing import List, Dict, Optional

def _print_section(title: str, leading_blank: bool = False) -> None:
    """Print ``title`` framed by two 60-character ``=`` rules.

    When ``leading_blank`` is true, the first rule is preceded by an empty
    line so the banner is visually separated from the previous output.
    """
    rule = "=" * 60
    print(f"\n{rule}" if leading_blank else rule)
    print(title)
    print(rule)


def _timed_post(url: str, payload: dict, timeout: float):
    """POST ``payload`` as JSON to ``url`` and return ``(response, seconds)``."""
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # (or become negative) if the wall clock is adjusted mid-request.
    started = time.perf_counter()
    response = requests.post(url, json=payload, timeout=timeout)
    return response, time.perf_counter() - started


def _first_choice(result) -> Optional[dict]:
    """Return the first entry of ``result["choices"]``, or ``None``.

    ``None`` is returned when the body is not a JSON object, or when
    ``choices`` is missing, null, not a list, or empty.
    """
    if not isinstance(result, dict):
        return None
    choices = result.get("choices")
    if not isinstance(choices, list) or not choices:
        return None
    return choices[0]


def _check_health(base_url: str) -> bool:
    """Query ``/health``; return ``False`` only if the request itself fails.

    A non-200 answer is reported as a warning but still counts as reachable.
    """
    try:
        response = requests.get(f"{base_url}/health", timeout=5)
        print(f"状态码: {response.status_code}")
        if response.status_code == 200:
            print("✅ 服务器健康\n")
        else:
            print(f"⚠️ 响应: {response.text}\n")
    except requests.exceptions.ConnectionError:
        print(f"❌ 无法连接到服务器 {base_url}")
        print("请确保 SGLang 服务器已启动\n")
        return False
    except Exception as e:
        print(f"❌ 健康检查失败: {e}\n")
        return False
    return True


def _show_models(base_url: str, timeout: float) -> None:
    """Fetch ``/v1/models`` and pretty-print the JSON answer (best effort)."""
    try:
        response = requests.get(f"{base_url}/v1/models", timeout=timeout)
        if response.status_code == 200:
            models = response.json()
            print(f"可用模型: {json.dumps(models, indent=2, ensure_ascii=False)}\n")
        else:
            print(f"⚠️ 获取模型信息失败: {response.status_code}")
            print(f"响应: {response.text}\n")
    except Exception as e:
        print(f"⚠️ 获取模型信息失败: {e}\n")


def _test_completions(base_url: str, model: str, timeout: float) -> None:
    """Send three fixed prompts to ``/v1/completions``, one request each."""
    test_prompts = [
        "你好，请介绍一下你自己。",
        "What is 2+2? Answer briefly.",
        "如何学习编程？请给出3个建议。",
    ]

    for i, prompt in enumerate(test_prompts, 1):
        # Each prompt has its own try so one failure does not stop the rest.
        try:
            print(f"\n提示 {i}: {prompt}")
            response, elapsed = _timed_post(
                f"{base_url}/v1/completions",
                {
                    "model": model,
                    "prompt": prompt,
                    "max_tokens": 128,
                    "temperature": 0.7,
                    "top_p": 0.95,
                },
                timeout,
            )

            if response.status_code != 200:
                print(f"❌ 生成失败: {response.status_code}")
                print(f"   响应: {response.text}")
                continue

            result = response.json()
            choice = _first_choice(result)
            if choice is None:
                print(f"⚠️ 响应格式异常: {result}")
                continue

            generated_text = choice.get("text", "")
            print(f"✅ 生成文本: {generated_text[:200]}...")
            print(f"   耗时: {elapsed:.2f}秒")
        except Exception as e:
            print(f"❌ 生成失败: {e}")


def _test_chat(base_url: str, model: str, timeout: float) -> None:
    """Send two fixed conversations to ``/v1/chat/completions``."""
    test_chats = [
        {
            "messages": [
                {"role": "system", "content": "你是一个有帮助的助手。"},
                {"role": "user", "content": "什么是机器学习？"},
            ]
        },
        {
            "messages": [
                {"role": "user", "content": "解释一下量子计算的基本原理。"},
            ]
        },
    ]

    for i, chat in enumerate(test_chats, 1):
        # Each conversation has its own try so one failure does not stop the rest.
        try:
            print(f"\n对话 {i}:")
            for msg in chat["messages"]:
                print(f"  {msg['role']}: {msg['content']}")

            response, elapsed = _timed_post(
                f"{base_url}/v1/chat/completions",
                {
                    "model": model,
                    "messages": chat["messages"],
                    "max_tokens": 256,
                    "temperature": 0.7,
                },
                timeout,
            )

            if response.status_code != 200:
                print(f"❌ 聊天失败: {response.status_code}")
                print(f"   响应: {response.text}")
                continue

            result = response.json()
            choice = _first_choice(result)
            if choice is None:
                print(f"⚠️ 响应格式异常: {result}")
                continue

            message = choice["message"]["content"]
            print(f"✅ 回复: {message[:200]}...")
            print(f"   耗时: {elapsed:.2f}秒")
        except Exception as e:
            print(f"❌ 聊天失败: {e}")


def _test_batch(base_url: str, model: str, timeout: float) -> None:
    """Send one ``/v1/completions`` request carrying a list of five prompts."""
    try:
        num_requests = 5
        prompts = ["Hello, how are you?"] * num_requests
        print(f"发送 {num_requests} 个并发请求...")

        # The five prompts travel in a single HTTP request, so the timeout is
        # doubled relative to the single-prompt calls.
        response, elapsed = _timed_post(
            f"{base_url}/v1/completions",
            {
                "model": model,
                "prompt": prompts,
                "max_tokens": 64,
                "temperature": 0.5,
            },
            timeout * 2,
        )

        if response.status_code != 200:
            print(f"❌ 性能测试失败: {response.status_code}")
            print(f"   响应: {response.text}")
            return

        result = response.json()
        if "choices" in result:
            print("✅ 批量请求成功")
            print(f"   请求数: {num_requests}")
            print(f"   总耗时: {elapsed:.2f}秒")
            print(f"   平均耗时: {elapsed / num_requests:.2f}秒/请求")
            print(f"   吞吐量: {num_requests / elapsed:.2f} 请求/秒")
        else:
            print(f"⚠️ 响应格式异常: {result}")
    except Exception as e:
        print(f"❌ 性能测试失败: {e}")


def _test_logprobs(base_url: str, model: str, timeout: float) -> None:
    """Request ``logprobs=5`` on a short completion and report whether it is returned."""
    try:
        print("\n测试 prompt logprobs...")
        response, elapsed = _timed_post(
            f"{base_url}/v1/completions",
            {
                "model": model,
                "prompt": "Hello",
                "max_tokens": 5,
                "logprobs": 5,
                "temperature": 0.0,
            },
            timeout,
        )

        if response.status_code != 200:
            print(f"⚠️ Logprobs 测试失败: {response.status_code}")
            return

        choice = _first_choice(response.json())
        if choice is None:
            print("⚠️ 响应格式异常")
            return

        # The key may be present with a null value; treat that the same as a
        # missing key instead of crashing on ``None.keys()``.
        logprobs = choice.get("logprobs")
        if logprobs is not None:
            print("✅ Logprobs 可用")
            print(f"   Logprobs keys: {list(logprobs.keys())}")
            print(f"   耗时: {elapsed:.2f}秒")
        else:
            print("⚠️ Logprobs 不在响应中")
    except Exception as e:
        print(f"⚠️ Logprobs 测试失败: {e}")


def test_sglang_server(
    host: str = "localhost",
    port: int = 30000,
    timeout: int = 60,
    model: str = "Qwen/Qwen2.5-3B-Instruct",
) -> bool:
    """Run a six-step smoke test against an SGLang HTTP server.

    The steps are: health check, model listing, text completions, chat
    completions, one batched completion request, and a logprobs request.
    Progress and results are printed to stdout.

    Args:
        host: Host name or address of the server.
        port: TCP port of the server.
        timeout: Per-request timeout in seconds. The health check always
            uses 5 seconds and the batched request uses twice this value.
        model: Model identifier sent in the ``model`` field of every
            generation request.

    Returns:
        ``False`` if the health-check request could not be completed
        (connection error or any other exception). ``True`` otherwise, which
        means the run reached the end, not that every step succeeded;
        failures in steps 2-6 are only reported on stdout.
    """
    base_url = f"http://{host}:{port}"

    print(f"测试 SGLang 服务器: {base_url}")
    print(f"超时时间: {timeout}秒\n")

    # 1. Health check: the only step that can abort the whole run.
    _print_section("1. 健康检查")
    if not _check_health(base_url):
        return False

    # 2. Model information.
    _print_section("2. 模型信息")
    _show_models(base_url, timeout)

    # 3. Text generation (Completions API).
    _print_section("3. 测试文本生成 (Completions API)")
    _test_completions(base_url, model, timeout)

    # 4. Chat (Chat Completions API).
    _print_section("4. 测试聊天完成 (Chat Completions API)", leading_blank=True)
    _test_chat(base_url, model, timeout)

    # 5. Batched request timing.
    _print_section("5. 性能测试 (批量请求)", leading_blank=True)
    _test_batch(base_url, model, timeout)

    # 6. Logprobs (if the server supports them).
    _print_section("6. 测试 Logprobs", leading_blank=True)
    _test_logprobs(base_url, model, timeout)

    _print_section("✅ 测试完成！", leading_blank=True)
    return True


if __name__ == "__main__":
    # Command line: [port] [host]; both are optional and positional.
    port = 30000
    host = "localhost"

    if len(sys.argv) > 1:
        try:
            port = int(sys.argv[1])
        except ValueError:
            # Exit with a short message instead of a raw traceback.
            sys.exit(f"❌ 无效的端口号: {sys.argv[1]!r}（应为整数）")
    if len(sys.argv) > 2:
        host = sys.argv[2]

    # Surface the result as the process exit status so shell scripts and CI
    # can detect an unreachable server (previously the script always exited 0).
    sys.exit(0 if test_sglang_server(host=host, port=port) else 1)

