from openai import OpenAI, DefaultHttpxClient
import json
from typing import List, Dict, Optional
import time
import os

# 禁用代理设置
os.environ.pop('http_proxy', None)
os.environ.pop('https_proxy', None)
os.environ.pop('HTTP_PROXY', None)
os.environ.pop('HTTPS_PROXY', None)


# -------------------------------------------------------------
# 通过配置文件按需加载不同模型的 IPv6 地址和端口
# -------------------------------------------------------------
CONFIG_PATH = "vllm_config.json"
DEFAULT_MODEL = "train_ppt"
# DEFAULT_MODEL = "stu-qwen2.5-7B-Instruct"

def get_client_for_model(model_name: str, config_path: str = CONFIG_PATH) -> OpenAI:
    """
    根据模型名称加载对应的IPv6地址和端口，并返回 OpenAI 客户端实例。
    
    Args:
        # model_name: 模型名称
        config_path: 配置文件路径，默认为项目根目录下的 vllm_config.json
        
    Returns:
        OpenAI 客户端对象
    """
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        model_cfg = config.get(model_name)
        if not model_cfg:
            raise ValueError(f"模型 {model_name} 未在 {config_path} 中找到配置")
        ipv6 = model_cfg["ipv6"]
        port = model_cfg["port"]
        return OpenAI(
            base_url=f"http://[{ipv6}]:{port}/v1",
            api_key="EMPTY",
            http_client=DefaultHttpxClient(trust_env=False, timeout=30),
        )
    except FileNotFoundError:
        raise FileNotFoundError(f"配置文件 {config_path} 不存在，请创建后重试")

# 为向后兼容，创建一个默认客户端实例
try:
    # print(f"🔧 正在初始化默认客户端，模型: {DEFAULT_MODEL}")
    client = get_client_for_model(DEFAULT_MODEL)
    # print("✅ 默认客户端初始化成功")
except Exception as e:
    # print(f"⚠️ 无法初始化默认客户端: {e}")
    client = None


class VLLMChatSession:
    """支持连续对话的VLLM API客户端类"""
    
    def __init__(self, model_name: str = "stu-qwen2.5-7B-Instruct"):
        self.model_name = model_name
        # 为当前模型创建对应的客户端
        self.client = get_client_for_model(model_name)
        self.conversation_history: List[Dict[str, str]] = []
        self.system_prompt = None
    
    def set_system_prompt(self, system_prompt: str):
        """设置系统提示词"""
        self.system_prompt = system_prompt
        # 如果已有对话历史，在开头插入系统消息
        if self.conversation_history and self.conversation_history[0].get("role") != "system":
            self.conversation_history.insert(0, {"role": "system", "content": system_prompt})
        elif not self.conversation_history:
            self.conversation_history.append({"role": "system", "content": system_prompt})
        else:
            # 更新现有的系统消息
            self.conversation_history[0] = {"role": "system", "content": system_prompt}
    
    def add_message(self, role: str, content: str):
        """添加消息到对话历史"""
        self.conversation_history.append({"role": role, "content": content})
    
    def send_message(self, message: str, append_to_history: bool = True) -> Optional[str]:
        """
        发送消息并获取回复
        
        Args:
            message: 用户消息
            append_to_history: 是否将消息和回复添加到历史记录
        
        Returns:
            模型回复内容
        """
        # 构建当前会话的消息列表
        current_messages = self.conversation_history.copy()
        current_messages.append({"role": "user", "content": message})
        
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=current_messages,
                temperature=0.7,
                max_tokens=16384
            )
            
            reply = response.choices[0].message.content
            
            # 如果需要，添加到对话历史
            if append_to_history:
                self.add_message("user", message)
                self.add_message("assistant", reply)
            
            return reply
            
        except Exception as e:
            print(f"❌ 调用API时发生错误: {e}")
            return None
    
    def get_conversation_history(self) -> List[Dict[str, str]]:
        """获取完整对话历史"""
        return self.conversation_history.copy()
    
    def clear_history(self):
        """清空对话历史（保留系统提示词）"""
        if self.system_prompt:
            self.conversation_history = [{"role": "system", "content": self.system_prompt}]
        else:
            self.conversation_history = []
    
    def save_conversation(self, filename: str):
        """保存对话到文件"""
        try:
            with open(filename, 'w', encoding='utf-8') as f:
                json.dump({
                    "model_name": self.model_name,
                    "conversation": self.conversation_history
                }, f, ensure_ascii=False, indent=2)
            print(f"✅ 对话已保存到: {filename}")
        except Exception as e:
            print(f"❌ 保存对话失败: {e}")
    
    def load_conversation(self, filename: str):
        """从文件加载对话"""
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                data = json.load(f)
                self.model_name = data.get("model_name", self.model_name)
                # 根据新的模型名称重新创建客户端
                self.client = get_client_for_model(self.model_name)
                self.conversation_history = data.get("conversation", [])
                
                # 提取系统提示词
                if self.conversation_history and self.conversation_history[0].get("role") == "system":
                    self.system_prompt = self.conversation_history[0]["content"]
                    
            print(f"✅ 对话已从文件加载: {filename}")
        except Exception as e:
            print(f"❌ 加载对话失败: {e}")


def call_vllm_api(message: str, model_name: str = "stu-qwen2.5-7B-Instruct"):
    """
    简单的单次API调用函数（向后兼容）
    
    Args:
        message: 要发送的消息内容
        model_name: 模型名称
    
    Returns:
        API响应结果
    """
    try:
        client = get_client_for_model(model_name)
        response = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": open("simple_system_prompt-sft.txt", "r").read()},
                {"role": "user", "content": message}]
        )
        return response
    except Exception as e:
        print(f"❌ 调用API时发生错误: {e}")
        return None


def demo_simple_chat():
    """演示简单的单次对话"""
    print("🔹 简单对话演示")
    print("=" * 50)
    
    test_message = "The following is a file titled 'Supply Chain Performance'. Based on the file, answer: What is the title of the deck?\ndeck file:\ndeck_scene_14.pptx. \nPlease use the provided code tools and terminal tools to read the file and complete the task. When encountering failed attempts, please try another method."
    print(f"💬 发送消息: {test_message}")
    
    response = call_vllm_api(test_message)
    if response:
        print("✅ API响应成功")
        print(f"🤖 回答内容: {response.choices[0].message.content}")
    else:
        print("❌ API调用失败")
    print()


def demo_continuous_chat():
    """演示连续对话功能"""
    print("🔹 连续对话演示")
    print("=" * 50)
    
    # 创建对话会话
    session = VLLMChatSession()
    
    # 设置一个简单的系统提示词
    session.set_system_prompt("你是一个友好的助手，能够记住之前的对话内容。")
    
    # 进行多轮对话
    messages = [
        "你好！我叫小明。",
        "我喜欢编程，特别是Python。",
        "你记得我的名字吗？",
        "你知道我喜欢什么编程语言吗？"
    ]
    
    for i, message in enumerate(messages, 1):
        print(f"💬 第{i}轮 - 用户: {message}")
        reply = session.send_message(message)
        if reply:
            print(f"🤖 第{i}轮 - 助手: {reply}")
        else:
            print("❌ 调用失败")
        print("-" * 30)
    
    # 显示完整对话历史
    print("📚 完整对话历史:")
    for msg in session.get_conversation_history():
        role_name = {"system": "系统", "user": "用户", "assistant": "助手"}
        print(f"{role_name.get(msg['role'], msg['role'])}: {msg['content'][:100]}...")
    print()


def demo_data_analysis_chat():
    """演示数据分析对话功能"""
    print("🔹 数据分析对话演示")
    print("=" * 50)
    
    # 创建数据分析会话
    session = create_data_analysis_session()
    
    # 模拟数据分析任务
    analysis_query = """我有一个CSV文件包含销售数据，文件路径是 'test_csv/gaia.xlsx'。
请帮我分析哪个供应商的收入租金比最低，并告诉我该供应商的类型。"""
    
    print(f"💬 数据分析任务: {analysis_query}")
    
    reply = session.send_message(analysis_query)
    if reply:
        print(f"🤖 分析回复: {reply}")
        
        # 保存这次对话
        session.save_conversation("data_analysis_conversation.json")
    else:
        print("❌ 分析失败")
    print()


def demo_conversation_management():
    """演示对话管理功能"""
    print("🔹 对话管理演示")
    print("=" * 50)
    
    session = VLLMChatSession()
    session.set_system_prompt("你是一个能够记住对话内容的助手。")
    
    # 添加一些对话
    session.send_message("请记住：我的生日是5月15日")
    session.send_message("我住在北京")
    
    # 保存对话
    session.save_conversation("test_conversation.json")
    print("✅ 对话已保存")
    
    # 创建新会话并加载对话
    new_session = VLLMChatSession()
    new_session.load_conversation("test_conversation.json")
    
    # 继续对话
    reply = new_session.send_message("你记得我的生日和住址吗？")
    if reply:
        print(f"🤖 加载对话后的回复: {reply}")
    print()


def demo_complex_analysis():
    """演示复杂的多轮数据分析"""
    print("🔹 复杂数据分析演示")
    print("=" * 50)
    
    session = create_data_analysis_session()
    
    # 第一轮：了解数据结构
    print("💬 第一轮：了解数据结构")
    reply1 = session.send_message("请预览 test_csv/gaia.xlsx 文件的前5行数据")
    if reply1:
        print(f"🤖 回复: {reply1[:200]}...")
    
    # 模拟用户提供工具返回结果
    mock_tool_return = """
    <tool_returns>[{"result": {"success": true, "structure": {"total_rows": 24, "total_columns": 5}, "preview": [{"Name": "店铺A", "Revenue": 10000, "Rent": 2000}]}}]</tool_returns>
    """
    
    # 第二轮：继续分析
    print("\n💬 第二轮：分析收入租金比")
    reply2 = session.send_message(f"根据数据计算所有店铺的收入租金比，找出比例最低的店铺。{mock_tool_return}")
    if reply2:
        print(f"🤖 回复: {reply2[:200]}...")
    
    print("✅ 复杂分析演示完成")
    print()


if __name__ == "__main__":
    print("🚀 VLLM API 连续对话功能演示")
    print("=" * 60)
    
    # 运行各种演示
    demo_simple_chat()
    # demo_continuous_chat() 
    # demo_data_analysis_chat()
    # demo_conversation_management()
    # demo_complex_analysis()
    
    
