from vllm import LLM, SamplingParams, AsyncLLMEngine
from transformers import AutoTokenizer
import openai
from openai import OpenAI
import time
import torch
from concurrent.futures import ThreadPoolExecutor
import os

######################### open-source model path ##########################
# Root directories that hold the local checkpoints. Each path below is built
# from one of them, so only the model-specific suffix is spelled out per entry.
_SONGFEIFAN_MODEL_ROOT = "/home/songfeifan/workspace/models"
_WEISHAOHANG_MODEL_ROOT = "/home/weishaohang/workspace/models/models"

# Model name -> local checkpoint directory for every open-source model.
mapping_modelname_to_path = {
    # Llama series
    "Llama-3.1-8B": f"{_SONGFEIFAN_MODEL_ROOT}/llama31-8b",
    "Llama-3.1-70B": f"{_SONGFEIFAN_MODEL_ROOT}/llama31-70b",
    "Llama-3.1-8B-Instruct": f"{_SONGFEIFAN_MODEL_ROOT}/llama-3.1-8b-instruct",
    "Llama-3.1-70B-Instruct": f"{_SONGFEIFAN_MODEL_ROOT}/llama31-70b-instruct",
    # Qwen series (NOTE: verified to run end to end)
    "Qwen2.5-0.5B": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-0.5B",
    "Qwen2.5-1.5B": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-1.5B",
    "Qwen2.5-3B": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-3B",
    "Qwen2.5-7B": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-7B",
    "Qwen2.5-14B": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-14B",
    "Qwen2.5-32B": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-32B",
    "Qwen2.5-72B": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-72B",
    "Qwen2.5-0.5B-Instruct": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-0.5B-Instruct",
    "Qwen2.5-1.5B-Instruct": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen2.5-3B-Instruct": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-3B-Instruct",
    "Qwen2.5-7B-Instruct": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-7B-Instruct",
    "Qwen2.5-14B-Instruct": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-14B-Instruct",
    "Qwen2.5-32B-Instruct": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-32B-Instruct",
    # The local Qwen2.5-72B-Instruct entry is disabled; that model name appears
    # in API_MODEL_NAMES instead.
    # "Qwen2.5-72B-Instruct": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-72B-Instruct",
    "Qwen2.5-7B-Instruct-1M": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-7B-Instruct-1M",
    "Qwen2.5-14B-Instruct-1M": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/Qwen2.5-14B-Instruct-1M",
    # Mistral series
    "Mistral-Small-3.1-24B-Instruct-2503": f"{_WEISHAOHANG_MODEL_ROOT}/mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    # Gemma 3
    "gemma-3-27b-it": f"{_WEISHAOHANG_MODEL_ROOT}/google/gemma-3-27b-it",
    # DeepSeek-R1 distilled into Qwen
    "Deepseek-R1-Distill-Qwen-7B": f"{_WEISHAOHANG_MODEL_ROOT}/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    "Deepseek-R1-Distill-Qwen-14B": f"{_WEISHAOHANG_MODEL_ROOT}/deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
    "Deepseek-R1-Distill-Qwen-32B": f"{_WEISHAOHANG_MODEL_ROOT}/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
    # DeepSeek-R1 distilled into Llama
    "Deepseek-R1-Distill-Llama-8B": f"{_WEISHAOHANG_MODEL_ROOT}/deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "Deepseek-R1-Distill-Llama-70B": f"{_WEISHAOHANG_MODEL_ROOT}/deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
    # QwQ
    "QwQ-32B": f"{_WEISHAOHANG_MODEL_ROOT}/Qwen/QwQ-32B",
    # s1
    "s1-32B": f"{_WEISHAOHANG_MODEL_ROOT}/simplescaling/s1-32B",
}

# Models whose prompts should be wrapped in the tokenizer's chat template.
# NOTE(review): Qwen2.5-0.5B-Instruct and Qwen2.5-1.5B-Instruct are not listed
# here even though they are instruct checkpoints. Confirm this is intentional.
models_need_instruct_template = (
    [f"Llama-3.1-{size}-Instruct" for size in ("8B", "70B")]
    + [f"Qwen2.5-{size}-Instruct" for size in ("3B", "7B", "14B", "32B", "72B")]
    + [f"Qwen2.5-{size}-Instruct-1M" for size in ("7B", "14B")]
    + ["Mistral-Small-3.1-24B-Instruct-2503", "gemma-3-27b-it"]
    + [f"Deepseek-R1-Distill-Llama-{size}" for size in ("8B", "70B")]
    + [f"Deepseek-R1-Distill-Qwen-{size}" for size in ("7B", "14B", "32B")]
    + ["QwQ-32B", "s1-32B"]
)

# Maximum input context length (in tokens) for each model. For the Qwen
# checkpoints the values follow max_position_embeddings from their configs.
_CTX_32K = 32_768
_CTX_128K = 131_072

mapping_modelname_to_max_input_context_len = {
    # ---- open-source models ----
    "Mistral-Small-3.1-24B-Instruct-2503": _CTX_128K,
    "gemma-3-27b-it": _CTX_128K,
    "Llama-3.1-8B": _CTX_128K,
    "Llama-3.1-70B": _CTX_128K,
    "Llama-3.2-1B-Instruct": _CTX_128K,
    "Llama-3.2-3B-Instruct": _CTX_128K,
    "Llama-3.1-8B-Instruct": _CTX_128K,
    "Llama-3.1-70B-Instruct": _CTX_128K,
    "Llama-3.3-70B-Instruct": _CTX_128K,
    # Qwen2.5 base models: 0.5B and 3B are capped at 32K, the rest at 128K.
    "Qwen2.5-0.5B": _CTX_32K,
    "Qwen2.5-1.5B": _CTX_128K,
    "Qwen2.5-3B": _CTX_32K,
    "Qwen2.5-7B": _CTX_128K,
    "Qwen2.5-14B": _CTX_128K,
    "Qwen2.5-32B": _CTX_128K,
    "Qwen2.5-72B": _CTX_128K,
    # Qwen2.5 instruct models: max_position_embeddings=32768 throughout.
    "Qwen2.5-0.5B-Instruct": _CTX_32K,
    "Qwen2.5-1.5B-Instruct": _CTX_32K,
    "Qwen2.5-3B-Instruct": _CTX_32K,
    "Qwen2.5-7B-Instruct": _CTX_32K,
    "Qwen2.5-14B-Instruct": _CTX_32K,
    "Qwen2.5-32B-Instruct": _CTX_32K,
    "Qwen2.5-72B-Instruct": _CTX_32K,
    "QwQ-32B": 40_960,
    "Qwen2.5-7B-Instruct-1M": 1_010_000,
    "Qwen2.5-14B-Instruct-1M": 1_010_000,
    # DeepSeek-R1 distills and s1
    "Deepseek-R1-Distill-Qwen-1.5B": _CTX_128K,
    "Deepseek-R1-Distill-Qwen-7B": _CTX_128K,
    "Deepseek-R1-Distill-Qwen-14B": _CTX_128K,
    "Deepseek-R1-Distill-Llama-8B": _CTX_128K,
    "Deepseek-R1-Distill-Qwen-32B": _CTX_128K,
    "Deepseek-R1-Distill-Llama-70B": _CTX_128K,
    "s1-32B": _CTX_32K,
    # ---- API models ----
    "Deepseek-V3": 64_000,
    "gpt-4o": _CTX_128K,
    "o3-mini": 200_000,
    "Deepseek-R1": 64_000,
}


# Open-source models that cannot run on vLLM and would need plain transformers.
TRANSFORMERS_MODEL_NAMES = ["Flan-t5-xxl"]
# Open-source models that can run on vLLM: every locally mapped checkpoint
# except Flan-t5-xxl.
VLLM_MODEL_NAMES = [name for name in mapping_modelname_to_path if name != "Flan-t5-xxl"]
# Open-source models that can run on SGLang. For now this is the same list as
# vLLM, built as a separate object so the two can diverge later.
SGLANG_MODEL_NAMES = [name for name in mapping_modelname_to_path if name != "Flan-t5-xxl"]
######################### closed-source model API #########################
API_MODEL_NAMES = ["qwen-max", "Deepseek-V3", "gpt-4o", "o3-mini", "Deepseek-R1", "Qwen2.5-72B-Instruct"]
# NOTE(review): "qwen-max" is listed above but has no API_MODEL_MAPPING entry,
# so API_MODELS("qwen-max") cannot be constructed until one is added.

# SECURITY: API keys should not be committed to source control. Each key can be
# supplied through an environment variable. The hard-coded fallbacks are kept
# only so existing runs keep working; they should be rotated and then removed.
_DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "sk-d39f38c6e0324e9da74e0ba521a3ffeb")
_ZHIZENGZENG_API_KEY = os.environ.get("ZHIZENGZENG_API_KEY", "sk-zk29d48b675d6e73666177431904f2e760567ff9afcb3060")
_SILICONFLOW_API_KEY = os.environ.get("SILICONFLOW_API_KEY", "sk-rtjgclzcdpketmgsdwrummuvfnigbuursvkqqucroiykqwpk")

# key: model_name
# value: (api_key, api_url, temperature, max_tokens, top_p, top_k,
#         frequency_penalty, presence_penalty)
# API_MODELS currently reads only the first four fields.
API_MODEL_MAPPING = {
    "Deepseek-R1": (_DEEPSEEK_API_KEY, "https://api.deepseek.com", 0.0, 10000, 0.0, 0, 0.0, 0.0),
    "Deepseek-V3": (_DEEPSEEK_API_KEY, "https://api.deepseek.com", 0.0, 10000, 0.0, 0, 0.0, 0.0),
    "gpt-4o": (_ZHIZENGZENG_API_KEY,
               "https://api.zhizengzeng.com/v1/", 0.0, 10000, 0.0, 0, 0.0, 0.0),
    "o3-mini": (_ZHIZENGZENG_API_KEY,
                "https://api.zhizengzeng.com/v1/", 0.0, 10000, 0.0, 0, 0.0, 0.0),
    "Qwen2.5-72B-Instruct": (_SILICONFLOW_API_KEY,
                             "https://api.siliconflow.cn/v1/", 0.0, 10000, 0.0, 0, 0.0, 0.0),
}
########################################################################### 

def get_visible_gpu_count():
    """Return how many GPUs this process may use, based on CUDA_VISIBLE_DEVICES.

    The variable is read as a comma-separated device list, and blank entries
    are ignored. If it is unset, empty or whitespace-only, 1 is returned.
    """
    devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    if not devices or not devices.strip():
        return 1  # unset or blank: assume a single GPU
    return sum(1 for device in devices.split(',') if device.strip())

class VLLM_MODEL():
    """Offline batch-generation wrapper around a local checkpoint served by vLLM.

    The checkpoint location comes from ``mapping_modelname_to_path``. Tensor
    parallelism spans every GPU listed in ``CUDA_VISIBLE_DEVICES``.
    """

    # Size tags treated as "large" models. They get a higher
    # gpu_memory_utilization and a smaller generation batch.
    _LARGE_MODEL_TAGS = ("14B", "32B", "70B", "72B")

    def __init__(self, model_name):
        """Load ``model_name`` into a vLLM engine together with its tokenizer.

        Args:
            model_name: key of ``mapping_modelname_to_path``.

        Raises:
            AssertionError: if ``model_name`` has no local checkpoint path.
        """
        self.model_name = model_name
        assert model_name in mapping_modelname_to_path, f"model_name {model_name} not in mapping_modelname_to_path"
        self.model_path = mapping_modelname_to_path[model_name]

        memory_util = self._pick_gpu_memory_utilization(model_name)
        # One tensor-parallel shard per visible GPU. If a large model proves
        # unstable with many shards, capping this (e.g. min(tp_size, 2)) is an
        # option.
        tp_size = get_visible_gpu_count()

        print(f"初始化模型 {model_name} 使用设置: gpu_memory_utilization={memory_util}, tensor_parallel_size={tp_size}")

        self.llm = LLM(
            model=self.model_path,
            gpu_memory_utilization=memory_util,
            # May speed up prompts that share a prefix; the original author saw
            # frequent (unexplained) errors with it enabled.
            enable_prefix_caching=True,
            device="cuda",  # optional
            dtype="bfloat16",  # optional
            tensor_parallel_size=tp_size,
            enforce_eager=True,  # optional
            enable_chunked_prefill=True,  # recommended; without it GPU memory may run out
            swap_space=1,  # 1 GB of CPU swap space per GPU to ease GPU memory pressure
            # max_model_len=500000  # TODO: only valid for Qwen2.5-14B-Instruct-1M
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

    @classmethod
    def _is_large_model(cls, model_name):
        """Return True when ``model_name`` contains one of the large-size tags."""
        return any(tag in model_name for tag in cls._LARGE_MODEL_TAGS)

    @classmethod
    def _pick_gpu_memory_utilization(cls, model_name):
        """Return the GPU memory fraction vLLM may reserve: 0.95 for large models, else 0.9."""
        return 0.95 if cls._is_large_model(model_name) else 0.9

    @classmethod
    def _pick_batch_size(cls, model_name):
        """Return the number of prompts per ``LLM.generate`` call: 128 for large models, else 256."""
        return 128 if cls._is_large_model(model_name) else 256

    def _uses_chat_template(self):
        """Return whether prompts must be wrapped in the tokenizer's chat template.

        This is true for models listed in ``models_need_instruct_template``
        (e.g. gemma-3-27b-it, DeepSeek-R1 distills, QwQ-32B, s1-32B, whose
        paths do not contain "instruct"). It is also true, as before, for any
        checkpoint whose path contains "instruct".
        """
        return (
            self.model_name in models_need_instruct_template
            or "instruct" in self.model_path.lower()
        )

    def _build_prompts(self, texts):
        """Turn raw user texts into prompt strings ready for the engine."""
        if not self._uses_chat_template():
            return list(texts)  # base model: feed the raw text unchanged
        return [
            self.tokenizer.apply_chat_template(
                [{"role": "user", "content": text}],
                add_generation_prompt=True,
                tokenize=False,
            )
            for text in texts
        ]

    def _run_batch(self, batch_prompts, sampling_params, batch_no):
        """Generate for one batch; if the batch fails, retry prompt by prompt.

        The batch output is collected before anything is returned. A failure
        therefore never leaves partial results behind that the per-prompt
        fallback would then duplicate. A prompt that still fails on its own
        gets an error-marker string, so every prompt keeps exactly one entry.
        """
        try:
            # Release cached allocator blocks to reduce fragmentation.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            outputs = self.llm.generate(batch_prompts, sampling_params=sampling_params)
            # n=1, so each RequestOutput carries a single completion.
            return [sub.text for out in outputs for sub in out.outputs]
        except Exception as e:
            print(f"警告: 处理批次 {batch_no} 时发生错误: {str(e)}")
            print("尝试单条处理该批次中的每个prompt...")

        texts = []
        for j, prompt in enumerate(batch_prompts):
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                outputs = self.llm.generate([prompt], sampling_params=sampling_params)
                texts.extend([sub.text for out in outputs for sub in out.outputs])
            except Exception as e2:
                print(f"警告: 跳过问题性prompt (批次 {batch_no}, 索引 {j}): {str(e2)}")
                texts.append("[ERROR: 处理该prompt时出错]")
        return texts

    def generate(self, texts):
        """Greedy-decode one completion per input text.

        Args:
            texts: list of prompt strings. A single ``str`` is treated as one
                prompt rather than being iterated character by character.

        Returns:
            list[str]: one generated string per input text, in input order.
        """
        if isinstance(texts, str):
            texts = [texts]
        prompts = self._build_prompts(texts)

        eos_token = self.tokenizer.eos_token
        sampling_params = SamplingParams(
            n=1,  # one completion per prompt
            temperature=0.0,  # greedy decoding
            max_tokens=10000,  # generate at most 10000 tokens
            # Stop at the EOS string. This is guarded because some tokenizers
            # define no EOS token, and a None stop string would be invalid.
            stop=[eos_token] if eos_token else None,
            repetition_penalty=1.2,
            min_tokens=1,  # always produce at least one token
        )

        # Generate in chunks to lower the OOM risk; larger models use smaller chunks.
        batch_size = self._pick_batch_size(self.model_name)
        num_batches = (len(prompts) + batch_size - 1) // batch_size

        results = []
        for start in range(0, len(prompts), batch_size):
            batch_prompts = prompts[start:start + batch_size]
            batch_no = start // batch_size + 1
            print(f"处理批次 {batch_no}/{num_batches}, 批次大小: {len(batch_prompts)}")
            results.extend(self._run_batch(batch_prompts, sampling_params, batch_no))
        return results


class API_MODELS():
    """Chat-completion client for hosted models via the OpenAI-compatible SDK.

    Connection settings come from ``API_MODEL_MAPPING``. Every provider is
    reached through ``OpenAI`` with a custom ``base_url``.
    """

    def __init__(self, model_name):
        """Look up the API credentials and sampling settings for ``model_name``.

        Raises:
            AssertionError: if ``model_name`` is not in ``API_MODEL_NAMES``.
            KeyError: if it is listed but has no ``API_MODEL_MAPPING`` entry.
        """
        self.model_name = model_name
        assert self.model_name in API_MODEL_NAMES, f"model_name {self.model_name} not in API_MODEL_NAMES"
        if model_name not in API_MODEL_MAPPING:
            raise KeyError(f"model_name {model_name} has no entry in API_MODEL_MAPPING")
        self.api_key, self.api_url, self.temperature, self.max_tokens = API_MODEL_MAPPING[model_name][:4]
        # NOTE(review): max_tokens is stored but not sent with the request;
        # confirm whether that is intended.
        self.is_r1 = self.model_name == "Deepseek-R1"

    @staticmethod
    def _resolve_api_model_name(model_name):
        """Map a local model name to the identifier the provider's API expects.

        Raises:
            ValueError: for a DeepSeek name that is neither R1 nor V3.
        """
        lowered = model_name.lower()
        if "deepseek" in lowered:
            if "r1" in lowered:
                return "deepseek-reasoner"
            if "v3" in lowered:
                return "deepseek-chat"
            raise ValueError(f"model_name {model_name} not supported")
        if "qwen" in lowered:
            return "Qwen/Qwen2.5-72B-Instruct-128K"
        return model_name

    @staticmethod
    def _format_message(message, is_r1):
        """Render a chat-completion message as the string returned to callers.

        For DeepSeek-R1 the reasoning trace and the final answer are combined
        as ``<think>...</think><answer>...</answer>``. A missing (None) field is
        treated as an empty string instead of raising.
        """
        answer = (message.content or "").strip()
        if not is_r1:
            return answer
        reasoning = (getattr(message, "reasoning_content", None) or "").strip()
        return f"<think>{reasoning}</think><answer>{answer}</answer>"

    def generate(self, prompts, asyn=True, max_workers=100, max_retries=10000000):
        """Query the API for one or many prompts.

        Args:
            prompts: a single prompt string or an iterable of prompt strings.
            asyn: if True, send all prompts concurrently from a thread pool and
                return a list of responses in input order. If False, send only
                one prompt (the string itself, or the first element) and return
                a single string.
            max_workers: thread-pool size used when ``asyn`` is True.
            max_retries: attempts per prompt before the last error is re-raised.
                The default effectively retries forever, as before.

        Raises:
            ValueError: if the model name cannot be mapped to an API model, or
                if ``max_retries`` < 1.
            openai.AuthenticationError, openai.PermissionDeniedError,
            openai.NotFoundError: re-raised immediately, because retrying cannot
                fix bad credentials or a wrong model identifier.
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")
        # Resolve once, outside the retry loop. Previously an unsupported name
        # raised inside the try block and was then retried (practically) forever.
        api_model = self._resolve_api_model_name(self.model_name)
        client = OpenAI(api_key=self.api_key, base_url=self.api_url)
        # Errors that no amount of retrying will fix.
        non_retryable = (openai.AuthenticationError, openai.PermissionDeniedError, openai.NotFoundError)

        def _call_api(prompt):
            # Crude pacing: a budget of 30000 requests/minute means ~2 ms per call.
            time.sleep(60.0 / 30000)
            for attempt in range(max_retries):
                try:
                    response = client.chat.completions.create(
                        model=api_model,
                        messages=[
                            {"role": "system", "content": "You are a helpful assistant"},
                            {"role": "user", "content": prompt},
                        ],
                        stream=False,
                        temperature=self.temperature,
                    )
                    return self._format_message(response.choices[0].message, self.is_r1)
                except non_retryable:
                    raise
                except Exception as e:
                    print(f"Attempt {attempt + 1} failed: {e}")
                    if attempt == max_retries - 1:
                        raise  # out of attempts: surface the last error
                    time.sleep(1)  # wait one second before retrying

        if not asyn:
            # Sequential mode handles exactly one prompt.
            return _call_api(prompts if isinstance(prompts, str) else prompts[0])

        batch = [prompts] if isinstance(prompts, str) else list(prompts)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # map() yields results in the same order as the input prompts.
            return list(executor.map(_call_api, batch))

class MODEL():
    """Unified front end that picks the right backend for ``model_name``.

    Hosted models (``API_MODEL_NAMES``) go through ``API_MODELS``. Models with a
    local checkpoint (``mapping_modelname_to_path``) go through ``VLLM_MODEL``.
    """

    def __init__(self, model_name):
        """Pick and build the backend for ``model_name`` and record its context limit.

        Raises:
            NotImplementedError: for transformers-only models.
            ValueError: for names unknown to every backend.
            AssertionError: if no maximum input context length is configured.
        """
        self.model_name = model_name

        # Choose the backend class first (cheap checks only).
        if model_name in API_MODEL_NAMES:
            backend = API_MODELS
        # SGLang backend (disabled; it would need an SGLANG_MODEL class):
        # elif model_name in SGLANG_MODEL_NAMES and os.environ.get('USE_SGLANG', 'false').lower() == 'true':
        #     print(f"使用SGLang推理模型: {model_name}")
        #     backend = SGLANG_MODEL
        elif model_name in mapping_modelname_to_path:
            backend = VLLM_MODEL
        elif model_name in TRANSFORMERS_MODEL_NAMES:
            # Transformers models are not supported yet; add an implementation here.
            raise NotImplementedError(f"模型 {model_name} 是Transformers模型，目前尚未实现")
        else:
            raise ValueError(f"未知的模型名称: {model_name}")

        # Validate the context-length entry BEFORE constructing the backend:
        # building a VLLM_MODEL loads the weights onto the GPUs, which is
        # wasteful if we would fail on a missing config entry right afterwards.
        assert model_name in mapping_modelname_to_max_input_context_len, f"model_name {model_name} not in mapping_modelname_to_max_input_context_len"
        self.max_input_context_len = mapping_modelname_to_max_input_context_len[model_name]

        if backend is VLLM_MODEL:
            print(f"使用VLLM推理模型: {model_name}")
        self.model = backend(model_name)

    def generate(self, texts, asyn=True, max_workers=100):
        """Generate completions through the selected backend.

        ``asyn`` and ``max_workers`` only apply to API models. The vLLM backend
        batches internally and ignores them.
        """
        if isinstance(self.model, API_MODELS):
            return self.model.generate(texts, asyn=asyn, max_workers=max_workers)
        return self.model.generate(texts)

# Smoke test
if __name__ == "__main__":
    # Opt in to SGLang (the default is 'false').
    # NOTE(review): the SGLang branch in MODEL.__init__ is commented out, so
    # this flag currently has no effect. The run below goes through vLLM
    # despite the banner text.
    os.environ['USE_SGLANG'] = 'true'

    banner = "=" * 50
    print(banner)
    print("测试使用SGLang进行推理...")
    print(banner)

    model = MODEL("Qwen2.5-7B")  # a base (non-instruct) 7B checkpoint

    texts = ["What is the capital of France?", "What is the capital of Germany?"]
    results = model.generate(texts)
    print("生成结果:")
    for number, answer in enumerate(results, start=1):
        question = texts[number - 1]
        print(f"问题 {number}: {question}")
        print(f"回答 {number}: {answer}")
        print("-" * 30)
