import os
import re
import tiktoken
import torch
import traceback
from collections import defaultdict
from pathlib import Path

import httpx
from FlagEmbedding import BGEM3FlagModel, FlagReranker
from openai import OpenAI, OpenAIError
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from zhipuai import ZhipuAI

from src.graph_utils import path_to_str
from src.utils import logger

# Root directory for locally stored model weights, read from the MODEL_DIR
# environment variable when this module is imported.
# NOTE(review): os.getenv returns None when MODEL_DIR is unset, and Path(None)
# raises TypeError, so importing this module fails in that case — confirm this
# is intended.
model_path_root = Path(os.getenv("MODEL_DIR"))
# Register additional model names with tiktoken, all mapped to the
# "cl100k_base" encoding. The "default" key is the one that
# OpenAIBase.count_tokens passes to tiktoken.encoding_for_model.
tiktoken.model.MODEL_TO_ENCODING["default"] = "cl100k_base"
tiktoken.model.MODEL_TO_ENCODING["meta-llama/Meta-Llama-3.1-8B-Instruct"] = "cl100k_base"
tiktoken.model.MODEL_TO_ENCODING["deepseek-ai/DeepSeek-V3"] = "cl100k_base"

# Context-window size (in tokens) per model name; any model that is not listed
# falls back to 8192. The value is only used to emit a warning in
# OpenAIBase.predict.
# NOTE(review): verify these figures against the providers' published limits.
MODEL_TO_CONTEXT_WINDOW = defaultdict(lambda: 8192, {
    "gpt-3.5-turbo": 4096,
    "gpt-4o-mini": 16384,
    "gpt-4o": 16384,
    "gpt-4-turbo": 4096,
})

class OpenAIBase():
    """Thin wrapper around an OpenAI-compatible chat-completions endpoint.

    Instances expose ``client`` (the ``OpenAI`` SDK client), ``model_name`` and
    ``context_window`` (looked up in ``MODEL_TO_CONTEXT_WINDOW``).
    """

    def __init__(self, api_key, base_url, model_name):
        """Create the SDK client for ``base_url`` and remember the model name.

        Args:
            api_key: API key forwarded to the ``OpenAI`` client.
            base_url: Base URL of the OpenAI-compatible endpoint.
            model_name: Model identifier sent with every request.
        """
        # NOTE(security): verify=False disables TLS certificate verification
        # for every request made through this client. Kept as-is because
        # existing deployments may depend on it — confirm it is still required.
        self.client = OpenAI(api_key=api_key, base_url=base_url, http_client=httpx.Client(verify=False))
        self.model_name = model_name
        self.context_window = MODEL_TO_CONTEXT_WINDOW[self.model_name]

    def predict(self, message, stream=False):
        """Send a chat request and return the reply.

        Args:
            message: Either a plain string (wrapped into a single user
                message) or an already-built list of chat messages.
            stream: When true, return a generator yielding the ``delta`` of
                each streamed chunk; otherwise return the reply message.

        Returns:
            The first choice's message, or a generator of deltas when
            ``stream`` is true.

        Raises:
            Exception: Any error raised by the client is logged and re-raised.
                In streaming mode the request is only issued once the returned
                generator is iterated, so errors surface (and are logged) at
                that point rather than here.
        """
        if isinstance(message, str):
            messages = [{"role": "user", "content": message}]
        else:
            messages = message

        # Exceeding the window only produces a warning; the request is still sent.
        input_tokens = OpenAIBase.count_messages_tokens(messages)
        if input_tokens > self.context_window:
            logger.warning(f"Input tokens ({input_tokens}) exceed the context window ({self.context_window}) for model {self.model_name}")

        try:
            if stream:
                return self._stream_response(messages)
            return self._get_response(messages)
        except Exception as e:
            logger.error(f"Error getting response from `{self.client}` with model `{self.model_name}`: {e}")
            raise

    def _stream_response(self, messages):
        """Yield the ``delta`` of every streamed chunk that carries a choice.

        Chunks whose ``choices`` list is empty are skipped instead of raising
        ``IndexError``. Errors are logged here because this generator body runs
        outside the ``try`` block in ``predict``.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                stream=True,
                extra_body={"chat_template_kwargs": {"enable_thinking": False}},
            )
            for chunk in response:
                if not chunk.choices:
                    continue
                yield chunk.choices[0].delta
        except Exception as e:
            logger.error(f"Error getting response from `{self.client}` with model `{self.model_name}`: {e}")
            raise

    def _get_response(self, messages):
        """Issue a non-streaming request (temperature 0) and return the first choice's message."""
        response = self.client.chat.completions.create(
            temperature=0,
            model=self.model_name,
            messages=messages,
            stream=False,
            extra_body={"chat_template_kwargs": {"enable_thinking": False}},
        )
        try:
            return response.choices[0].message
        except Exception as e:
            # Log the raw response so an unexpected payload can be diagnosed.
            logger.error(f"Error getting response from {self.model_name}: {e}")
            logger.error(traceback.format_exc())
            logger.error(f"Response: {response}")
            raise

    @staticmethod
    def count_tokens(text):
        """Return the number of tokens in ``text`` using the encoding registered as "default"."""
        return len(tiktoken.encoding_for_model("default").encode(text))

    @staticmethod
    def count_messages_tokens(messages):
        """Return the summed token count of the ``content`` of all messages.

        Messages may be dicts or objects with a ``content`` attribute. A
        missing, ``None`` or empty content contributes zero tokens.
        """
        total_tokens = 0
        for message in messages:
            if isinstance(message, dict):
                content = message.get("content")
            else:
                content = getattr(message, "content", None)
            if not content:
                continue
            total_tokens += OpenAIBase.count_tokens(content)
        return total_tokens


class VLLM(OpenAIBase):
    """OpenAIBase client configured from the ``VLLM_*`` environment variables."""

    def __init__(self, model_name=None, base_url=None):
        """Resolve defaults for model, endpoint and key, then build the client.

        Args:
            model_name: Model identifier; falls back to ``"llama"`` when falsy.
            base_url: Endpoint URL; falls back to ``VLLM_API_BASE`` and then
                to ``http://localhost:8080/v1`` when falsy.
        """
        resolved_model = model_name if model_name else "llama"
        if base_url:
            resolved_url = base_url
        else:
            resolved_url = os.getenv("VLLM_API_BASE", "http://localhost:8080/v1")
        key = os.getenv("VLLM_API_KEY", "EMPTY")

        logger.debug(f"Connecting to VLLM at {resolved_url} with model {resolved_model}")
        super().__init__(api_key=key, base_url=resolved_url, model_name=resolved_model)


class ZhipuEmbedding:
    """Embedding client that calls the ZhipuAI embeddings API in batches."""

    def __init__(self, model_name) -> None:
        """Create the ZhipuAI client using the ``ZHIPUAI_API_KEY`` environment variable."""
        self.model_name = model_name
        self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
        # Stored on the instance but not read by any method of this class.
        self.query_instruction_for_retrieval = "为这个句子生成表示以用于检索相关文章："

    def predict(self, message, batch_size=120, **kwargs):
        """Embed ``message`` slice by slice and return all embeddings in order.

        Args:
            message: Sliceable sequence of inputs; each slice of up to
                ``batch_size`` items is sent as one request.
            batch_size: Number of items per request.
            **kwargs: Extra keyword arguments forwarded to
                ``client.embeddings.create``.

        Returns:
            A list with the ``embedding`` of every item in every response.
        """
        vectors = []
        batch_starts = range(0, len(message), batch_size)

        for start in tqdm(batch_starts, desc="Zhipu Embedding..."):
            reply = self.client.embeddings.create(
                model=self.model_name,
                input=message[start:start + batch_size],
                **kwargs
            )
            for item in reply.data:
                vectors.append(item.embedding)

        return vectors

    def encode(self, message, **kwargs):
        """Alias for :meth:`predict` that forwards all keyword arguments."""
        return self.predict(message, **kwargs)

    def encode_queries(self, queries):
        """Embed ``queries`` via :meth:`predict` with default settings."""
        return self.predict(queries)

class LocalEmbeddingModel():
    """Embedding model loaded from the local ``MODEL_DIR`` directory.

    Supported keys are ``"bge-m3"`` (``BGEM3FlagModel``) and ``"stella"``
    (``SentenceTransformer``, moved to CUDA).
    """

    def __init__(self, model_name):
        """Load the model selected by ``model_name``.

        Raises:
            ValueError: If ``model_name`` is not a supported key.
        """
        self.model_name = model_name
        if model_name in ["bge-m3"]:
            # Pass a string path, consistent with the stella branch below.
            self.model = BGEM3FlagModel(self._local_model_path("BAAI/bge-m3"))
        elif model_name in ["stella"]:
            # The path must be a string here, not a Path object.
            model_path = self._local_model_path("dunzhang/stella_en_400M_v5")
            self.model = SentenceTransformer(model_path, trust_remote_code=True).cuda()
        else:
            raise ValueError(f"Unknown key `{model_name}`")

    @staticmethod
    def _local_model_path(relative_name):
        """Return ``MODEL_DIR/relative_name`` as a string."""
        return str(Path(os.getenv("MODEL_DIR")) / relative_name)

    def encode(self, text, is_query=False, **kwargs):
        """Encode ``text`` with the loaded model.

        Args:
            text: Input forwarded to the underlying model's ``encode``.
            is_query: Only used for ``"stella"``: selects the ``"s2p_query"``
                prompt when true and ``"s2s_query"`` otherwise.
            **kwargs: Extra keyword arguments forwarded to ``encode``.

        Returns:
            For ``"bge-m3"`` the ``'dense_vecs'`` entry of the model output;
            for ``"stella"`` the model output itself.

        Raises:
            ValueError: If ``self.model_name`` is not a supported key.
        """
        if self.model_name in ["bge-m3"]:
            return self.model.encode(text, **kwargs)['dense_vecs']
        elif self.model_name in ["stella"]:
            prompt_name = "s2p_query" if is_query else "s2s_query"
            return self.model.encode(text, prompt_name=prompt_name, **kwargs)
        else:
            raise ValueError(f"Unknown key `{self.model_name}`")

    def stop_self_pool(self):
        """Stop the model's multiprocessing pool so its resources are released.

        Only acts for ``"bge-m3"`` models that expose a ``_pool`` attribute.
        Cleanup is best effort: any error is logged and swallowed.
        """
        try:
            if self.model_name in ["bge-m3"] and hasattr(self.model, "_pool"):
                # Call whichever of close/join/terminate the pool object provides.
                if hasattr(self.model._pool, "close"):
                    self.model._pool.close()
                if hasattr(self.model._pool, "join"):
                    self.model._pool.join()
                if hasattr(self.model._pool, "terminate"):
                    self.model._pool.terminate()
                logger.info(f"已停止 {self.model_name} 的进程池")
        except Exception as e:
            logger.warning(f"清理 {self.model_name} 资源时发生错误: {str(e)}")
            logger.debug(traceback.format_exc())

class RerankerModel:
    """Reranker backed by a ``FlagReranker`` model."""

    def __init__(self, model_name_or_path):
        """Load the reranker.

        The two known BGE reranker names are resolved relative to
        ``model_path_root``; any other value is passed to ``FlagReranker``
        unchanged.
        """
        self.model_name_or_path = model_name_or_path

        if model_name_or_path in ["BAAI/bge-reranker-v2-m3", "BAAI/bge-reranker-v2-gemma"]:
            # Hand the library a plain string path rather than a Path object.
            self.reranker = FlagReranker(str(model_path_root / model_name_or_path))
        else:
            self.reranker = FlagReranker(model_name_or_path)

    @staticmethod
    def _as_text(path):
        """Return ``path`` unchanged if it is a string, else render it with ``path_to_str``."""
        if isinstance(path, str):
            return path
        # presumably path[0][0] is the start node of the path — verify against path_to_str
        return path_to_str(path[0][0], [path])[0]

    def score(self, query, path):
        """Return the normalized relevance score of a single ``path`` for ``query``.

        Args:
            query: Query text.
            path: Document text, or a path object that is rendered to text.
        """
        return self.reranker.compute_score([query, self._as_text(path)], normalize=True)[0]

    def score_batch(self, query, paths, batch_size=256):
        """Return normalized relevance scores for every entry of ``paths``.

        Path objects are rendered to text the same way as in :meth:`score`.
        An empty ``paths`` yields an empty list without calling the model.

        Args:
            query: Query text.
            paths: Iterable of document texts or path objects.
            batch_size: Batch size forwarded to ``compute_score``.
        """
        inputs = [[query, self._as_text(path)] for path in paths]
        if not inputs:
            return []
        return self.reranker.compute_score(inputs, normalize=True, batch_size=batch_size)

class LLMBaseReranker:
    """
    Base class that uses an LLM as a reranker: a prompt asks the model for a
    score between 0 and 1, so in practice it acts as a scorer.
    """

    # A decimal number is preferred; a bare integer is accepted as a fallback
    # so that replies such as "1" or "0" are still parsed.
    _DECIMAL_PATTERN = re.compile(r'\d+\.\d+')
    _INTEGER_PATTERN = re.compile(r'\d+')

    def __init__(self, model_name):
        """Select the LLM for ``model_name`` and build the scoring prompt template."""
        self.model_name = model_name
        self.model = select_llm_model(model_name)
        self.prompt = (
            "You are a helpful assistant that scores a path based on the query. "
            "The path is given in the following format: "
            "```{path}```"
            "The query is given in the following format: "
            "```{query}```"
            "Please give a score between 0 and 1 for the path (such as 0.0, 0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0). without any other text. The score must be a float number."
            "And the score have to can be used to sort the paths. The higher the score, the better the path."
            "score:"
        )

    @staticmethod
    def _parse_score(text):
        """Extract the first number from ``text`` and clamp it to [0, 1].

        Returns:
            The parsed score as a float, or ``None`` when ``text`` is not a
            string or contains no number.
        """
        if not isinstance(text, str):
            return None
        match = LLMBaseReranker._DECIMAL_PATTERN.search(text) or LLMBaseReranker._INTEGER_PATTERN.search(text)
        if match is None:
            return None
        return min(max(float(match.group(0)), 0.0), 1.0)

    def score(self, query, path):
        """Ask the LLM to score ``path`` for ``query``.

        Returns:
            The score parsed from the reply, clamped to [0, 1]. If no number
            can be parsed, a warning is logged and 0.0 is returned so that one
            malformed reply does not abort a whole batch.
        """
        input_text = self.prompt.format(query=query, path=path)
        response = self.model.predict(input_text, stream=False)

        # Pull the numeric score out of the reply text.
        logger.debug(f"Response: {response}")
        parsed = self._parse_score(response.content)
        if parsed is None:
            logger.warning(f"Could not parse a score from LLM response {response.content!r}; using 0.0")
            return 0.0
        return parsed

    def score_batch(self, query, paths, batch_size=256):
        """Score every entry of ``paths`` one by one; ``batch_size`` is accepted but unused."""
        return [self.score(query, path) for path in paths]

class QwenReranker:
    """
    Reranker built on a Qwen3-Reranker causal language model.
    Requires transformers>=4.51.0
    """

    def __init__(self, model_name_or_path="Qwen/Qwen3-Reranker-0.6B", max_length=8192, use_flash_attention=True):
        """
        Load tokenizer and model and precompute the fixed prompt tokens.

        Args:
            model_name_or_path: Model path or name. When it exists below
                ``model_path_root`` that local path is used; otherwise the
                value is handed to ``from_pretrained`` unchanged.
            max_length: Maximum sequence length, including the fixed prefix
                and suffix tokens.
            use_flash_attention: Whether to load the model in float16 with
                flash attention 2 and move it to CUDA.
        """
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError:
            raise ImportError("请安装 transformers: pip install transformers")

        local_candidate = model_path_root / model_name_or_path
        self.model_name_or_path = local_candidate if local_candidate.exists() else model_name_or_path
        self.max_length = max_length

        # Tokenizer pads on the left, so the last position is always a real token.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, padding_side='left')

        if not use_flash_attention:
            network = AutoModelForCausalLM.from_pretrained(self.model_name_or_path)
        else:
            network = AutoModelForCausalLM.from_pretrained(
                self.model_name_or_path,
                torch_dtype=torch.float16,
                attn_implementation="flash_attention_2"
            ).cuda()
        self.model = network.eval()

        # Vocabulary ids of the two answer tokens compared in compute_logits.
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")

        # Fixed chat scaffolding placed before and after every formatted pair.
        self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)

        logger.info(f"QwenReranker 初始化完成，模型: {self.model_name_or_path}")

    def format_instruction(self, instruction, query, doc):
        """Render instruction, query and document into the model's input text.

        A ``None`` instruction is replaced by a default retrieval instruction.
        """
        if instruction is None:
            instruction = 'Given a question, retrieve relevant knowledge passages that can help answer the question'
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs):
        """Tokenize the formatted pairs, wrap them in prefix/suffix and pad them.

        Returns:
            The padded batch with every tensor moved to the model's device.
        """
        # Leave room for the fixed prefix and suffix inside max_length.
        content_budget = self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        encoded = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=content_budget
        )

        token_rows = encoded['input_ids']
        for row_index, row in enumerate(token_rows):
            encoded['input_ids'][row_index] = self.prefix_tokens + row + self.suffix_tokens

        batch = self.tokenizer.pad(encoded, padding=True, return_tensors="pt", max_length=self.max_length)

        for name in batch:
            batch[name] = batch[name].to(self.model.device)

        return batch

    @torch.no_grad()
    def compute_logits(self, inputs):
        """Run the model and return, per row, the probability of "yes".

        The logits at the last position are reduced to the "no"/"yes" token
        pair, log-softmaxed over that pair, and the "yes" entry is
        exponentiated.
        """
        last_position = self.model(**inputs).logits[:, -1, :]
        no_yes = torch.stack(
            [last_position[:, self.token_false_id], last_position[:, self.token_true_id]],
            dim=1,
        )
        log_probs = torch.nn.functional.log_softmax(no_yes, dim=1)
        return log_probs[:, 1].exp().tolist()

    def _doc_text(self, path):
        """Return ``path`` itself when it is a string, else its ``path_to_str`` rendering."""
        if isinstance(path, str):
            return path
        return path_to_str(path[0][0], [path])[0]

    def score(self, query, path, instruction=None):
        """
        Score a single query-document pair.

        Args:
            query: Query text.
            path: Document text or path object.
            instruction: Optional task instruction.

        Returns:
            Relevance score between 0 and 1.
        """
        formatted = self.format_instruction(instruction, query, self._doc_text(path))
        model_inputs = self.process_inputs([formatted])
        return self.compute_logits(model_inputs)[0]

    def score_batch(self, query, paths, instruction=None, batch_size=32):
        """
        Score many documents against one query.

        Args:
            query: Query text.
            paths: List of document texts or path objects.
            instruction: Optional task instruction.
            batch_size: Number of pairs per forward pass.

        Returns:
            List of scores, one per entry of ``paths``.
        """
        # Render every path object to text up front.
        docs = [self._doc_text(path) for path in paths]

        collected = []
        # Process the documents chunk by chunk.
        for start in range(0, len(docs), batch_size):
            chunk = docs[start:start + batch_size]
            formatted = [self.format_instruction(instruction, query, doc) for doc in chunk]
            collected.extend(self.compute_logits(self.process_inputs(formatted)))

        return collected

def select_embedding_model(model_name):
    """Return the embedding model registered under ``model_name``.

    Args:
        model_name: ``'bge-m3'`` or ``'stella'`` for a local model,
            ``'embedding-2'`` or ``'embedding-3'`` for the ZhipuAI API.

    Raises:
        ValueError: If ``model_name`` is not a known key (instead of silently
            returning ``None``).
    """
    if model_name in ['bge-m3', 'stella']:
        return LocalEmbeddingModel(model_name)

    if model_name in ['embedding-2', 'embedding-3']:
        return ZhipuEmbedding(model_name=model_name)

    raise ValueError(f"Unknown key `{model_name}`")

def select_llm_model(model):
    """Build the chat client configured for ``model``.

    Hosted providers need their API key in the environment; the locally
    served models go through ``VLLM``.

    Raises:
        ValueError: If ``model`` is unknown, or if the API key environment
            variable required for its provider is not set.
    """
    # (model names, API-key environment variable, endpoint URL)
    keyed_providers = (
        (
            (
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
                'deepseek-ai/DeepSeek-V3',
                'deepseek-ai/DeepSeek-R1',
                'Qwen/Qwen3-235B-A22B',
                'Qwen/Qwen3-32B',
            ),
            "SILICONFLOW_API_KEY",
            "https://api.siliconflow.cn/v1",
        ),
        (
            ("meta-llama/llama-3.1-8b-instruct:free",),
            "OPENROUTER_API_KEY",
            "https://openrouter.ai/api/v1",
        ),
        (
            ("glm-4-flash", "glm-4-plus"),
            "ZHIPUAI_API_KEY",
            "https://open.bigmodel.cn/api/paas/v4/",
        ),
    )
    for names, env_var, endpoint in keyed_providers:
        if model in names:
            secret = os.getenv(env_var)
            if not secret:
                raise ValueError(f"{env_var} is not set")
            return OpenAIBase(api_key=secret, base_url=endpoint, model_name=model)

    if model in ("gpt-4-turbo", "gpt-3.5-turbo", "gpt-4o-mini", "gpt-4o"):
        return OpenAIBase(
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_API_BASE"),
            model_name=model,
        )

    # (model names, explicit endpoint or None to let VLLM pick its default)
    vllm_targets = (
        (("llama3.1:8b", "qwen2.5:32b", "qwen3:32b"), None),
        (("qwen3:0.6b",), "http://localhost:8080/v1"),
        (("qwen3:8b",), "http://172.19.13.4:8080/v1"),
    )
    for names, endpoint in vllm_targets:
        if model in names:
            return VLLM(model_name=model, base_url=endpoint)

    raise ValueError(f"Unknown key `{model}`")


def select_rerank_model(model_name):
    """Return the reranker implementation that matches ``model_name``.

    ``"qwen3:0.6b"`` and ``"qwen3:8b"`` use the LLM-prompt scorer, names
    containing ``"Qwen3-Reranker"`` use ``QwenReranker``, and everything else
    falls back to ``RerankerModel``.
    """
    if model_name in ("qwen3:0.6b", "qwen3:8b"):
        reranker_cls = LLMBaseReranker
    elif "Qwen3-Reranker" in model_name:
        reranker_cls = QwenReranker
    else:
        reranker_cls = RerankerModel
    return reranker_cls(model_name)


def log_token_usage(stage, input_count, output_count, ins_id=None):
    """Append one token-usage record to the per-run usage log.

    The record is written as ``ins_id | stage | input_tokens | output_tokens``
    to ``outputs/logs/usage/<TIMESTAMP>.usage.log``; the directory is created
    when missing.

    Args:
        stage: Label of the pipeline stage being recorded.
        input_count: The input whose tokens are counted: a string, a single
            message dict with a ``"content"`` key, or a list of messages.
        output_count: The output text whose tokens are counted.
        ins_id: Optional identifier written as the first column.
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid an import cycle — TODO confirm).
    from src.config import TIMESTAMP
    usage_path = f"outputs/logs/usage/{TIMESTAMP}.usage.log"
    os.makedirs(os.path.dirname(usage_path), exist_ok=True)
    if isinstance(input_count, str):
        input_tokens = OpenAIBase.count_tokens(input_count)
    elif isinstance(input_count, dict) and "content" in input_count:
        input_tokens = OpenAIBase.count_tokens(input_count["content"])
    else:
        input_tokens = OpenAIBase.count_messages_tokens(input_count)

    output_tokens = OpenAIBase.count_tokens(output_count)
    # Explicit encoding so non-ASCII stage names or ids are written the same
    # way on every platform.
    with open(usage_path, "a", encoding="utf-8") as f:
        f.write(f"{ins_id} | {stage} | {input_tokens} | {output_tokens}\n")

if __name__ == "__main__":
    # Manual smoke test for QwenReranker.
    print("测试 QwenReranker...")
    reranker = QwenReranker("Qwen/Qwen3-Reranker-0.6B")

    # Sample data
    queries = ["What is the capital of China?", "Explain gravity"]
    documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]
    first_query = queries[0]
    first_document = documents[0]

    # Score a single query-document pair.
    single_score = reranker.score(first_query, first_document)
    print(f"Query: {first_query}")
    print(f"Document: {first_document}")
    print(f"Score: {single_score}")

    # Score all documents against the first query.
    batch_scores = reranker.score_batch(first_query, documents)
    print(f"\nBatch scores for query '{first_query}':")
    for rank, (_, value) in enumerate(zip(documents, batch_scores), start=1):
        print(f"Document {rank}: {value:.4f}")

    print("\n测试完成！")
