import torch
import sys
import time
import random
import argparse
import numpy as np
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import re
import torch.distributed as dist
import secrets

class VLLM_models:
    """Thin wrapper around a vLLM ``LLM`` engine plus its Hugging Face tokenizer.

    The wrapper builds chat-formatted prompts, runs batched generation and
    splits every completion into a "think" part and a "response" part.
    """

    def __init__(self, model_path, llm=None, device=0, gpu_memory_utilization=0.7, tensor_parallel_size=1):
        """Load the tokenizer and start a vLLM engine for ``model_path``.

        Args:
            model_path: Model identifier or path passed to the tokenizer, the
                config loader and the vLLM engine.
            llm: Accepted for interface compatibility; currently unused.
            device: Forwarded unchanged to ``LLM(device=...)``.
            gpu_memory_utilization: Forwarded to ``LLM(gpu_memory_utilization=...)``.
            tensor_parallel_size: Forwarded to ``LLM(tensor_parallel_size=...)``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        config = AutoConfig.from_pretrained(model_path)

        # Cap the context length at 18192 tokens; when the config does not
        # expose ``max_position_embeddings`` pass None through to the engine.
        # NOTE(review): 18192 is an unusual cap value - confirm it is intended.
        max_model_len = getattr(config, 'max_position_embeddings', None)
        max_model_len = min(max_model_len, 18192) if max_model_len is not None else None

        # Marked as critical by the original author: initialize CUDA
        # explicitly before the engine is constructed.
        torch.cuda.init()
        self.llm = LLM(
            model=model_path,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=max_model_len,
            tensor_parallel_size=tensor_parallel_size,
            enable_prefix_caching=True,
            device=device,
            dtype="bfloat16",
        )

    def build_prompts(self, querys):
        """Wrap each query as a single user turn and render the chat template.

        Args:
            querys: Iterable of user message contents.

        Returns:
            A list with one rendered prompt per query (not tokenized, with the
            generation prompt appended).
        """
        return [
            self.tokenizer.apply_chat_template(
                [{"role": "user", "content": q}],
                tokenize=False,
                add_generation_prompt=True,
            )
            for q in querys
        ]

    @staticmethod
    def _split_think_and_response(text):
        """Split one completion into ``(think, response)``.

        A completion without ``<tool_call>`` is returned untouched with the
        placeholder think ``'None'``. Otherwise the response is everything
        after the first ``<tool_call>`` (stripped) and the think is everything
        before the first ``</think>`` (stripped), or ``'None'`` when there is
        no ``</think>``.
        """
        if '<tool_call>' not in text:
            return 'None', text

        # Both parts are cut from the ORIGINAL text: the reasoning normally
        # precedes the tool call, so it must not be searched for in the
        # already-trimmed response.
        head, think_end, _ = text.partition('</think>')
        think = head.strip() if think_end else 'None'
        response = text.partition('<tool_call>')[2].strip()
        return think, response

    def generate(self, prompts, apply_chat_template=None, max_model_len=1000, temperature=0, top_p=1.0, num_generation=1):
        """Generate completions for ``prompts`` and split off the reasoning.

        Args:
            prompts: Raw prompts/queries, one generation request per entry.
            apply_chat_template: Optional callable applied to every prompt to
                produce the final model input. When None, ``build_prompts``
                is used.
            max_model_len: Passed as ``max_tokens`` to the sampling
                parameters, i.e. the per-completion generation limit.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            num_generation: Number of completions sampled per prompt.

        Returns:
            ``(thinks, responses)``: two lists parallel to ``prompts``. Each
            element is a list with one string per sampled completion. A think
            entry is the literal string ``'None'`` when nothing was extracted.
        """
        self.sampling_params = SamplingParams(
            n=num_generation,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_model_len,
            skip_special_tokens=False,
            # A fresh seed is drawn from the OS entropy source on every call.
            seed=secrets.randbelow(2**32),
        )
        if apply_chat_template is None:
            prompts_input = self.build_prompts(prompts)
        else:
            prompts_input = list(map(apply_chat_template, prompts))

        outputs = self.llm.generate(prompts_input, sampling_params=self.sampling_params, use_tqdm=True)

        responses = []
        thinks = []
        for request_output in outputs:
            # Every sample is parsed on its own, so one sample without a tool
            # call no longer disables parsing for its siblings, and an empty
            # sample list simply yields two empty lists.
            pairs = [self._split_think_and_response(o.text) for o in request_output.outputs]
            thinks.append([think for think, _ in pairs])
            responses.append([response for _, response in pairs])

        return thinks, responses

    def close(self):
        """Shut the vLLM engine down and release GPU memory.

        After this call the instance is no longer usable; create a new
        ``VLLM_models`` to run inference again. Calling it more than once is
        harmless because every step checks for the attribute first.
        """
        import gc

        if hasattr(self, "llm"):
            if hasattr(self.llm, "shutdown"):
                self.llm.shutdown()
            del self.llm

        if hasattr(self, "tokenizer"):
            del self.tokenizer

        # Key step per the original author: tear down the torch.distributed
        # process group (described there as the NCCL group).
        if dist.is_initialized():
            dist.destroy_process_group()

        # Collect reference cycles first so tensors they keep alive are really
        # freed, then ask PyTorch to hand cached GPU memory back.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()