import torch
from FlagEmbedding import FlagAutoReranker
from modelscope import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM


class BaseRerankCls:
    """Shared scaffolding for rerankers.

    Subclasses override ``_generate`` to rank the candidate contexts of a
    single query; ``generate`` applies that hook to every record of a dataset.
    """

    @staticmethod
    def get_top_context(context_list, score_list, top_k=1):
        """Return the ``top_k`` contexts with the highest scores.

        Args:
            context_list: Candidate contexts, indexable by position.
            score_list: Scores aligned position-by-position with
                ``context_list``.
            top_k: Maximum number of contexts to return.

        Returns:
            Contexts ordered from highest to lowest score, truncated to
            ``top_k`` entries. Equal scores keep their original relative
            order because the sort is stable.
        """
        # Rank the positions instead of the contexts so each context can be
        # looked up by its original index afterwards.
        ranked_positions = sorted(
            range(len(score_list)),
            key=lambda position: score_list[position],
            reverse=True,
        )
        ranked_contexts = [context_list[position] for position in ranked_positions]
        return ranked_contexts[:top_k]

    def _generate(self, *args, **kwargs):
        """Per-query ranking hook; the base implementation returns ``None``."""
        return None

    def generate(self, eval_dataset, top_k=1):
        """Rank the candidates of every record in ``eval_dataset``.

        Args:
            eval_dataset: Iterable of mappings with the keys ``"query"``,
                ``"pos"`` and ``"neg"``; ``"pos"`` and ``"neg"`` are
                concatenated to form the candidate list.
            top_k: Forwarded to ``_generate`` for every record.

        Returns:
            A list with one ``_generate`` result per record, in input order.
        """
        return [
            self._generate(record["query"], record["pos"] + record["neg"], top_k)
            for record in eval_dataset
        ]


class BgeRerankCls(BaseRerankCls):
    """Reranker backed by a FlagEmbedding ``encoder-only-base`` model."""

    def __init__(self, llm_path, devices=None):
        """Load the fine-tuned reranker.

        Args:
            llm_path: Model name or path passed to
                ``FlagAutoReranker.from_finetuned``.
            devices: Devices forwarded to FlagEmbedding. When omitted,
                ``['cuda:0']`` is used if CUDA is available and ``['cpu']``
                otherwise, so the class can also be built on CPU-only hosts.
        """
        if devices is None:
            devices = ['cuda:0'] if torch.cuda.is_available() else ['cpu']
        # Half precision is only requested when an accelerator is targeted.
        cpu_only = devices in ('cpu', ['cpu'])

        self.reranker = FlagAutoReranker.from_finetuned(
            llm_path,
            model_class='encoder-only-base',
            query_max_length=256,
            passage_max_length=512,
            use_fp16=not cpu_only,
            devices=devices,
        )

    def _generate(self, query, context_list, top_k=1):
        """Return the ``top_k`` contexts most relevant to ``query``.

        Args:
            query: Query text.
            context_list: Candidate contexts to score against ``query``.
            top_k: Maximum number of contexts to return.

        Returns:
            The best-scoring contexts, highest score first; an empty list
            when there are no candidates.
        """
        if not context_list:
            # Nothing to score; avoid handing an empty batch to the reranker.
            return []

        pairs = [[query, context] for context in context_list]
        scores = self.reranker.compute_score(pairs, normalize=True)
        # NOTE(review): presumably some FlagEmbedding releases return a bare
        # number for a single pair - wrap it so ranking always sees a sequence.
        if isinstance(scores, (int, float)):
            scores = [scores]
        return self.get_top_context(context_list, scores, top_k)


class GteRerankCls(BaseRerankCls):
    """Reranker backed by a sequence-classification model (GTE style)."""

    def __init__(self, llm_path):
        """Load tokenizer and model and place the model on the detected device.

        Args:
            llm_path: Model name or path passed to ``from_pretrained``.
        """
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(llm_path)
        # NOTE: trust_remote_code=True lets the checkpoint's own Python code
        # run; only load checkpoints from trusted sources.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            llm_path,
            trust_remote_code=True
        ).to(self.device)
        self.model.eval()

    def _generate(self, query, context_list, top_k=1, batch_size=None):
        """Return the ``top_k`` contexts most relevant to ``query``.

        Args:
            query: Query text.
            context_list: Candidate contexts to score against ``query``.
            top_k: Maximum number of contexts to return.
            batch_size: Number of (query, context) pairs per forward pass.
                ``None`` scores all candidates in a single batch.

        Returns:
            The best-scoring contexts, highest score first; an empty list
            when there are no candidates.

        Raises:
            ValueError: If ``batch_size`` is given and is smaller than 1.
        """
        if not context_list:
            return []
        if batch_size is None:
            batch_size = len(context_list)
        elif batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        scores = []
        with torch.no_grad():
            for start in range(0, len(context_list), batch_size):
                pairs = [
                    [query, context]
                    for context in context_list[start:start + batch_size]
                ]
                # Inputs must live on the same device as the model.
                inputs = self.tokenizer(
                    pairs,
                    padding=True,
                    truncation=True,
                    return_tensors='pt',
                    max_length=8192
                ).to(self.device)
                logits = self.model(**inputs, return_dict=True).logits
                # Convert to plain Python floats so ranking does not compare
                # (possibly GPU-resident) tensors element by element.
                scores.extend(logits.view(-1).float().tolist())

        return self.get_top_context(context_list, scores, top_k)


class QwenRerankCls(BaseRerankCls):
    """Reranker that asks a causal LM to answer "yes"/"no" for each pair.

    The relevance score of a document is the probability of "yes" after a
    softmax over the "yes" and "no" logits at the last position.
    """

    def __init__(self, llm_path, instruction=None, max_length=8192):
        """Load tokenizer and model and prepare the prompt template.

        Args:
            llm_path: Model name or path passed to ``from_pretrained``.
            instruction: Task instruction inserted into every prompt; a
                default web-search instruction is used when falsy.
            max_length: Maximum total sequence length, including the prompt
                prefix and suffix tokens.
        """
        # One device decision shared by the model and by all input batches.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Left padding keeps the final real token at position -1, which is
        # where _compute_logits reads the answer logits.
        self.tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side='left')
        self.model = AutoModelForCausalLM.from_pretrained(llm_path).eval()
        self.model = self.model.to(self.device)

        # Vocabulary ids of the two answer tokens.
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.max_length = max_length

        # Chat-template prefix/suffix wrapped around every formatted pair.
        self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n"
        self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)

        # Default instruction text.
        self.instruction = instruction or 'Given a web search query, retrieve relevant passages that answer the query'

    def _format_instruction(self, query, doc):
        """Build the user-turn text for one (query, document) pair."""
        return f"<Instruct>: {self.instruction}\n<Query>: {query}\n<Document>: {doc}"

    def _process_inputs(self, texts):
        """Tokenize ``texts``, wrap them in the prompt and pad to a batch.

        Args:
            texts: Formatted pair strings produced by ``_format_instruction``.

        Returns:
            The padded batch as tensors placed on ``self.device``.
        """
        # Reserve room for the prefix/suffix tokens added below.
        body_budget = self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        encoded = self.tokenizer(
            texts,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=body_budget
        )

        # Surround each tokenized text with the prefix and suffix tokens.
        for i, token_ids in enumerate(encoded['input_ids']):
            encoded['input_ids'][i] = self.prefix_tokens + token_ids + self.suffix_tokens

        # Pad to a rectangular batch and convert to tensors.
        batch = self.tokenizer.pad(encoded, padding=True, return_tensors="pt")

        # Move every tensor to the device the model was placed on.
        for key in batch:
            batch[key] = batch[key].to(self.device)

        return batch

    def _compute_logits(self, inputs):
        """Return the "yes" probability for every sequence in ``inputs``.

        Args:
            inputs: Batch produced by ``_process_inputs``.

        Returns:
            A list of floats, one per sequence.
        """
        with torch.no_grad():
            last_step_logits = self.model(**inputs).logits[:, -1, :]
            yes_logits = last_step_logits[:, self.token_true_id]
            no_logits = last_step_logits[:, self.token_false_id]
            # Normalise over just the two answer tokens; column 1 is "yes".
            answer_logits = torch.stack([no_logits, yes_logits], dim=1)
            log_probs = torch.nn.functional.log_softmax(answer_logits, dim=1)
            return log_probs[:, 1].exp().tolist()

    def _generate(self, query, context_list, top_k=1, batch_size=10):
        """Return the ``top_k`` contexts most relevant to ``query``.

        Args:
            query: Query text.
            context_list: Candidate contexts to score against ``query``.
            top_k: Maximum number of contexts to return.
            batch_size: Number of candidates scored per forward pass.

        Returns:
            The best-scoring contexts, highest score first.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A negative step would skip every candidate and silently return
            # an empty ranking.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        scores = []
        for start in range(0, len(context_list), batch_size):
            batch_docs = context_list[start:start + batch_size]
            prompts = [self._format_instruction(query, doc) for doc in batch_docs]
            inputs = self._process_inputs(prompts)
            scores.extend(self._compute_logits(inputs))
        return self.get_top_context(context_list, scores, top_k)
