import logging

from collections import Counter,defaultdict

from transformers import StoppingCriteriaList, MaxLengthCriteria, AutoTokenizer, AutoModelForSequenceClassification
from nltk.corpus import stopwords
# ASCII punctuation characters, kept as a module-level constant.
punctuation = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'
# English stopwords from the NLTK corpus, built once at import time.
# NOTE(review): requires the NLTK "stopwords" corpus to be available locally.
stopword_set = set(stopwords.words('english'))
import time

import torch 
from itertools import combinations
from .helper import clean_str, StopOnTokens
import copy
from tqdm import tqdm
from torch import LongTensor, FloatTensor
import numpy as np
from numpy import dot
from numpy.linalg import norm
import os
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from sentence_transformers import util

import spacy
import os
import json
import random
from transformers import pipeline
from itertools import chain, combinations
import math

from src.decoding_methods import secure_decoding

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import time
from .pairwise_em import PairwiseConflictEM

from sklearn.metrics.pairwise import cosine_similarity
import networkx as nx
from networkx.algorithms import approximation

import re

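# Defense / aggregation strategies for RAG (one base class plus several concrete methods). Their goal is to
# still produce a reliable answer when the retrieved top-k documents may have been attacked or poisoned.
# Every defense takes the same input format, ``data_item`` (a dict with at least the fields ``question``,
# ``topk_content`` and ``answer``), and its ``query`` method returns the final answer text.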

# Module-wide logger; every class below logs through the 'RRAG-main' logger.
logger = logging.getLogger('RRAG-main')

# Attack-model switch (not referenced by the code visible in this part of the file).
INJECTION = True # injection attack. if False, we consider passage modification attacks discussed in the appendix

def save_all_responses(save_path, response_list, data_item):
    """Append one query record to the JSON array stored at ``save_path``.

    The file is read in full (when it exists and is non-empty), the new record
    is appended, and the whole array is written back with 4-space indentation.

    :param save_path: path of the JSON file holding a list of records.
    :param response_list: value stored under the ``"response"`` key.
    :param data_item: dict providing ``'question'`` and ``'answer'``.
    """
    # TODO: switch to JSONL so that every call does not re-read and re-write the file.
    has_previous_records = os.path.exists(save_path) and os.path.getsize(save_path) > 0
    if has_previous_records:
        with open(save_path, 'r') as src:
            records = json.load(src)
    else:
        records = []

    record = {
        "query": data_item['question'],
        "answer": data_item['answer'],
        "response": response_list,
    }
    records.append(record)

    with open(save_path, 'w') as dst:
        json.dump(records, dst, indent=4)

# Base class shared by the defenses below.
class RRAG:
    """Base class of the RAG defenses in this module.

    Holds the LLM wrapper and offers an undefended baseline query plus a
    substring-based correctness check. Subclasses override :meth:`query`.
    """

    def __init__(self, llm):
        # LLM wrapper; this class calls its ``wrap_prompt`` and ``query`` methods.
        self.llm = llm

    def query_undefended(self, data_item):
        """Answer ``data_item`` with one plain LLM call and return the response.

        The prompt is built as multiple-choice when ``data_item`` has a
        ``'choices'`` key. Prompt, response and gold answer are logged at
        DEBUG level.
        """
        is_multi_choice = 'choices' in data_item
        query_prompt = self.llm.wrap_prompt(data_item, as_multi_choice=is_multi_choice)
        response = self.llm.query(query_prompt)

        logger.debug(f'Query_prompt:\n{query_prompt}')
        logger.debug(f'Response:\n{response}')
        logger.debug(f'Answer:\n{data_item["answer"]}')
        return response

    def query(self, data_item):
        """Defended query; must be implemented by subclasses."""
        raise NotImplementedError

    def _eval_response(self, response, data_item):
        """Return True when any cleaned gold answer is a substring of the cleaned response."""
        gold_answers = data_item['answer']
        cleaned_response = clean_str(response)
        return any(clean_str(gold) in cleaned_response for gold in gold_answers)

class UnionFind:
    """Small disjoint-set structure.

    Supports:
    - union by size;
    - tracking, per set, the smallest original document index it contains
      (available to callers for tie-breaking).
    """

    def __init__(self, n, index_to_original):
        # n: number of elements (valid documents left after filtering).
        # index_to_original: maps element index -> original document index.
        self.parent = list(range(n))
        self.size = [1] * n
        # For a root, the smallest original document index inside its set.
        self.min_rank = [index_to_original[pos] for pos in range(n)]

    def find(self, x):
        """Return the root of ``x``'s set, compressing the traversed path."""
        # First pass: walk up to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: point every node on the path directly at the root.
        node = x
        while self.parent[node] != root:
            above = self.parent[node]
            self.parent[node] = root
            node = above
        return root

    def union(self, x, y):
        """Merge the sets of ``x`` and ``y``; a no-op if they already share a root."""
        keep, absorb = self.find(x), self.find(y)
        if keep == absorb:
            return
        # Attach the smaller tree under the larger one; on equal size x's root stays.
        if self.size[absorb] > self.size[keep]:
            keep, absorb = absorb, keep
        self.parent[absorb] = keep
        self.size[keep] += self.size[absorb]
        if self.min_rank[absorb] < self.min_rank[keep]:
            self.min_rank[keep] = self.min_rank[absorb]


class ClusterBasedRRAG(RRAG):
    """Cluster-and-vote RAG defense based on NLI entailment and union-find.

    Steps performed by :meth:`query`:

    1. Ask the LLM about every retrieved document separately.
    2. Drop the documents whose answer contains "I don't know".
    3. Score every pair of remaining answers with the NLI model and read the
       entailment probability (earlier document as premise, later one as
       hypothesis).
    4. Merge two answers into one cluster when that probability is at least
       ``sim_threshold``.
    5. Use the connected components of the union-find as clusters.
    6. Keep the largest cluster; ties go to the cluster that contains the
       smallest original document index.
    7. Query the LLM once more with only the documents of that cluster.
    """

    def __init__(self, llm, err=0.0, sim_threshold=0.5):
        """Load the NLI model and locate its entailment / contradiction labels.

        :param llm: LLM wrapper; this class calls its ``batch_query``,
            ``wrap_prompt`` and ``_query`` methods.
        :param err: reserved "noise" parameter; stored but not used here.
        :param sim_threshold: minimum entailment probability for two answers
            to be placed in the same cluster.
        :raises ValueError: if no label name of the model contains "entail".
        """
        super().__init__(llm)
        self.err = err
        self.sim_threshold = sim_threshold

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model_name = "DeBERTa-v3-large-mnli-fever-anli-ling-wanli"
        self.nli_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.nli_model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

        # Resolve label indices from the model config instead of hard-coding them.
        id2label = {idx: label.lower() for idx, label in self.nli_model.config.id2label.items()}
        entail_idx = None
        contradict_idx = None
        for idx, name in id2label.items():
            if "entail" in name:
                entail_idx = idx
            if "contradict" in name:
                contradict_idx = idx
        if entail_idx is None:
            raise ValueError(f"Cannot find entailment label in id2label: {self.nli_model.config.id2label}")
        self.entail_idx = entail_idx
        # Not used by query(); kept for a possible "non-contradiction" similarity.
        self.contradict_idx = contradict_idx

        logger.info(
            f"[ClusterBasedRRAG] NLI labels: {self.nli_model.config.id2label}, "
            f"entail_idx={self.entail_idx}, contradict_idx={self.contradict_idx}"
        )

    @staticmethod
    def _flush_log_handlers():
        """Flush every handler of the root logger (safe when none is configured)."""
        for handler in logging.getLogger().handlers:
            handler.flush()

    def query(self, data_item):
        """Run one defended RAG query and return the final answer.

        :param data_item: dict with at least ``'question'`` (question text) and
            ``'topk_content'`` (list of retrieved documents); a ``'choices'``
            key switches the prompts to multiple-choice.
        :return: the value returned by the final ``llm._query`` call.
        """
        num_docs = len(data_item['topk_content'])

        # 1. Ask the LLM about each document on its own.
        start_time = time.perf_counter()
        seperate_responses = self.llm.batch_query(
            self.llm.wrap_prompt(
                data_item,
                as_multi_choice='choices' in data_item,
                seperate=True
            )
        )
        end_time = time.perf_counter()
        logger.info(f"[ClusterBasedRRAG] Single-doc responses: {seperate_responses}")
        logger.info(f"[ClusterBasedRRAG] Time for batch single-doc query: {end_time - start_time:.4f}s")
        self._flush_log_handlers()

        # 2. Keep only the documents that produced an actual answer.
        valid_indices = [
            i for i, ans in enumerate(seperate_responses)
            if "I don't know" not in ans
        ]
        m = len(valid_indices)

        # Nothing usable: fall back to a single query over all documents.
        if m == 0:
            logger.info("[ClusterBasedRRAG] All single-doc answers are 'I don't know'; "
                        "fallback to using all documents.")
            return self._final_query_with_docs(data_item, list(range(num_docs)))

        # 3. Pairwise NLI over the valid answers. Union-find positions are
        #    0..m-1; position p corresponds to original document valid_indices[p].
        uf = UnionFind(m, index_to_original=valid_indices)
        pair_indices = [(a, b) for a in range(m) for b in range(a + 1, m)]

        if pair_indices:
            q = data_item['question']
            premises = [
                f"The answer to the question: {q}\nis {seperate_responses[valid_indices[a]]}."
                for a, _ in pair_indices
            ]
            hypotheses = [
                f"The answer to the question: {q}\nis {seperate_responses[valid_indices[b]]}."
                for _, b in pair_indices
            ]

            inputs = self.nli_tokenizer(
                premises,
                hypotheses,
                return_tensors='pt',
                truncation=True,
                padding=True
            )
            inputs = {key: value.to(self.nli_model.device) for key, value in inputs.items()}

            start_time = time.perf_counter()
            with torch.no_grad():
                outputs = self.nli_model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            end_time = time.perf_counter()
            logger.info(f"[ClusterBasedRRAG] Time for NLI: {end_time - start_time:.4f}s")

            # Merge the two answers when the entailment probability reaches the
            # threshold (``self.err`` is intentionally not applied here).
            for row, (a, b) in enumerate(pair_indices):
                entail_prob = probs[row][self.entail_idx].item()
                if entail_prob >= self.sim_threshold:
                    uf.union(a, b)

        # 4. Connected components -> clusters (root -> union-find positions).
        clusters = defaultdict(list)
        for pos in range(m):
            clusters[uf.find(pos)].append(pos)

        # 5. Largest cluster wins; ties go to the cluster holding the smallest
        #    original document index. m >= 1 here, so ``clusters`` is never empty.
        best_members = min(
            clusters.values(),
            key=lambda members: (-len(members), min(valid_indices[p] for p in members)),
        )
        # Ascending original index, i.e. the retrieval order of the documents.
        chosen_doc_indices = sorted(valid_indices[p] for p in best_members)
        logger.info(f"[ClusterBasedRRAG] Selected document indices: {chosen_doc_indices}")
        self._flush_log_handlers()

        # 6. Final query restricted to the chosen documents.
        final_answer = self._final_query_with_docs(data_item, chosen_doc_indices)
        logger.info(f"[ClusterBasedRRAG] Final answer: {final_answer}")
        return final_answer

    def _final_query_with_docs(self, data_item, doc_indices):
        """Build the final prompt from the given document indices and query the LLM once.

        ``data_item`` itself is not modified; a shallow copy receives the
        reduced ``'topk_content'`` list.
        """
        docs = data_item['topk_content']
        new_data_item = data_item.copy()
        new_data_item['topk_content'] = [docs[i] for i in doc_indices]

        ultimate_prompt = self.llm.wrap_prompt(
            new_data_item,
            as_multi_choice='choices' in data_item,
            seperate=False
        )

        start_time = time.perf_counter()
        final_answer = self.llm._query(ultimate_prompt)
        end_time = time.perf_counter()
        # Timing and answer go to stdout (kept from the original debug output).
        print("time for the ultimate query: ", end_time - start_time)
        print("final_answer:", final_answer)

        return final_answer

class MinCutRRAG(RRAG):
    """RAG defense that selects trustworthy documents with an s-t minimum cut.

    Each document is answered separately by the LLM. Pairwise NLI between the
    answers gives an agreement matrix ``M`` and a conflict matrix ``C``; from
    these a support score ``S`` and a conflict score ``F`` are derived per
    answer. A flow network (source -> answer with capacity ``S``, answer ->
    sink with capacity ``F``, answer <-> answer with capacity ``M``) is cut,
    and the answers on the source side are used for the final LLM query.
    """

    def __init__(self, llm, nli_model_path="DeBERTa-v3-large-mnli-fever-anli-ling-wanli"):
        """Load the NLI model/tokenizer and the sentence-embedding model.

        :param llm: LLM wrapper; this class calls its ``wrap_prompt`` and
            ``_query`` methods.
        :param nli_model_path: name or path passed to ``from_pretrained``.
        """
        super().__init__(llm)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_path)
        self.nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_path).to(self.device)
        self.nli_model.eval()
        # NOTE(review): stored but not consulted; the NLI loop below uses a
        # fixed batch of 16 — confirm which value is intended.
        self.nli_batch_size = 32
        self.embed_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Positions of the entailment / contradiction classes in the NLI output.
        self.entail_idx, self.contradict_idx = self._resolve_nli_label_indices()

    def _resolve_nli_label_indices(self):
        """Return ``(entail_idx, contradict_idx)`` for the loaded NLI model.

        Label names are read from ``self.nli_model.config.id2label``. When a
        label is not found, the previously hard-coded positions are kept
        (0 for entailment, 2 for contradiction).
        """
        entail_idx, contradict_idx = 0, 2
        id2label = getattr(self.nli_model.config, "id2label", None) or {}
        for idx, name in id2label.items():
            label = str(name).lower()
            if label.startswith("entail"):
                entail_idx = int(idx)
            elif label.startswith("contradict"):
                contradict_idx = int(idx)
        return entail_idx, contradict_idx

    @staticmethod
    def _flush_log_handlers():
        """Flush every handler of the root logger (safe when none is configured)."""
        for handler in logging.getLogger().handlers:
            handler.flush()

    def query(self, data_item):
        """Run one defended RAG query and return the final answer.

        :param data_item: dict with at least ``'question'`` and
            ``'topk_content'``; a ``'choices'`` key switches the prompts to
            multiple-choice.
        :return: the value returned by the final ``llm._query`` call.
        """
        docs = data_item['topk_content']
        as_multi_choice = 'choices' in data_item

        # Step 1: one LLM answer per document; documents answered with
        # "I don't know" are skipped.
        responses = []
        valid_docs = []  # original index of the document behind each kept response
        for i, doc in enumerate(docs):
            single_data_item = data_item.copy()
            single_data_item['topk_content'] = [doc]
            single_prompt = self.llm.wrap_prompt(single_data_item, as_multi_choice=as_multi_choice, seperate=False)
            resp = self.llm._query(single_prompt)

            if "I don't know" in resp:
                logger.info(f"[MinCut] Skipping document {i} as it contains 'I don't know'.")
                continue

            responses.append(resp)
            valid_docs.append(i)

        # No usable single-document answer: query once with the unmodified data_item.
        if not responses:
            logger.info("[MinCut] No valid documents found, querying LLM directly for the final answer.")
            final_prompt = self.llm.wrap_prompt(data_item, as_multi_choice=as_multi_choice, seperate=False)
            return self.llm._query(final_prompt)

        logger.info(f"[MinCut] Responses: {responses}")
        self._flush_log_handlers()

        # Step 2: pairwise agreement (M) and conflict (C) between the answers.
        M, C = self._build_sim_and_conflict_matrices(data_item['question'], responses)
        logger.info(f"[MinCut] M: {M.tolist()}")

        # Step 3: per-answer support (S) and conflict (F) scores.
        S, F = self.compute_scores_balanced(M, C, valid_docs=valid_docs, total_docs=len(docs))
        logger.info(f"[MinCut] S: {S.tolist()}")
        logger.info(f"[MinCut] F: {F.tolist()}")

        # Step 4: flow network. Nodes 0..n-1 index into ``responses``.
        num_nodes = len(responses)
        G = nx.DiGraph()
        source = "source"
        sink = "sink"
        G.add_node(source)
        G.add_node(sink)
        for i in range(num_nodes):
            G.add_edge(source, i, capacity=S[i])
        for i in range(num_nodes):
            G.add_edge(i, sink, capacity=F[i])
        for i in range(num_nodes):
            for j in range(i + 1, num_nodes):
                G.add_edge(i, j, capacity=M[i, j])
                G.add_edge(j, i, capacity=M[i, j])

        # Step 5: minimum s-t cut; the source side holds the answers we keep.
        _, (reachable, _) = nx.minimum_cut(G, source, sink)
        selected_docs = sorted(node for node in reachable if isinstance(node, int))
        logger.info(f"[MinCut] Selected: {selected_docs}")

        # Step 6: drop "isolated" answers, i.e. answers whose mean cosine
        # similarity to the other selected answers is below 0.3.
        if len(selected_docs) > 1:
            selected_responses = [responses[i] for i in selected_docs]
            selected_embeddings = self.embed_model.encode(selected_responses)
            selected_cos_sim = util.cos_sim(selected_embeddings, selected_embeddings).cpu().numpy()

            n_selected = len(selected_docs)
            off_diag_mask = ~np.eye(n_selected, dtype=bool)
            avg_cosine = np.mean(selected_cos_sim[off_diag_mask])
            logger.info(f"[MinCut] Average cosine of selected: {avg_cosine}")

            # NOTE(review): the gate is 1.0, so the filter runs for practically
            # every selection; an earlier comment mentioned 0.7 — confirm.
            if avg_cosine < 1.0:
                doc_avg_cos = [
                    np.mean(selected_cos_sim[idx, [j for j in range(n_selected) if j != idx]])
                    for idx in range(n_selected)
                ]

                isolation_threshold = 0.3
                new_selected = [
                    selected_docs[idx] for idx in range(n_selected)
                    if doc_avg_cos[idx] >= isolation_threshold
                ]
                if new_selected != selected_docs:
                    selected_docs = sorted(new_selected)
                    logger.info(f"[MinCut] After excluding isolated: {selected_docs}")

        if not selected_docs:
            # NOTE(review): behaviour kept as-is — the final prompt is built
            # from an empty 'topk_content' list. Confirm whether a fallback to
            # all valid documents would be preferable.
            logger.warning("[MinCut] No document survived selection; final prompt is built without documents.")

        # Step 7: final answer from the selected documents (mapped back to
        # their original positions through ``valid_docs``).
        selected_data_item = data_item.copy()
        selected_data_item['topk_content'] = [docs[valid_docs[i]] for i in selected_docs]

        final_prompt = self.llm.wrap_prompt(selected_data_item, as_multi_choice=as_multi_choice, seperate=False)
        final_answer = self.llm._query(final_prompt)
        return final_answer

    def _build_sim_and_conflict_matrices(self, question, responses):
        """Return ``(M, C)``: pairwise agreement and conflict between answers.

        Both are symmetric ``float32`` arrays of shape ``(k, k)`` with
        ``k = len(responses)``. ``M`` has ones on the diagonal. NLI is run in
        both directions for every pair in which neither answer contains
        "I don't know"; all other off-diagonal entries stay 0.

        For a scored pair, the conflict score is the geometric mean of the two
        contradiction probabilities. If it exceeds 0.5 the pair is treated as
        contradictory (``M = 0``); otherwise ``M`` is the geometric mean of the
        two entailment probabilities. ``C`` holds the conflict score either way.
        """
        k = len(responses)
        M = np.zeros((k, k), dtype=np.float32)
        C = np.zeros((k, k), dtype=np.float32)
        np.fill_diagonal(M, 1.0)  # every answer agrees with itself

        # Class positions in the NLI output; default to the historical 0 / 2
        # when the attributes were not set by __init__.
        entail_idx = getattr(self, "entail_idx", 0)
        contra_idx = getattr(self, "contradict_idx", 2)

        # Ordered pairs (i = premise, j = hypothesis) of informative answers.
        informative = ["I don't know" not in resp for resp in responses]
        indices = [
            (i, j)
            for i in range(k)
            for j in range(k)
            if i != j and informative[i] and informative[j]
        ]
        if not indices:
            return M, C

        # The question is part of both sentences so the NLI model sees the context.
        pairs = [
            (f"Question: {question} Answer: {responses[i]}",
             f"Question: {question} Answer: {responses[j]}")
            for i, j in indices
        ]

        # Batched NLI inference.
        batch_size = 16
        nli_probs_map = {}
        for start in range(0, len(pairs), batch_size):
            batch_pairs = pairs[start:start + batch_size]
            inputs = self.nli_tokenizer(
                [p[0] for p in batch_pairs],
                [p[1] for p in batch_pairs],
                return_tensors='pt', truncation=True, padding=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.nli_model(**inputs)
                probs = torch.softmax(outputs.logits, dim=1).cpu().numpy()

            for offset, prob in enumerate(probs):
                nli_probs_map[indices[start + offset]] = prob

        # Combine both directions into symmetric scores.
        for i in range(k):
            for j in range(i + 1, k):
                if (i, j) not in nli_probs_map or (j, i) not in nli_probs_map:
                    continue

                p_ij = nli_probs_map[(i, j)]  # i as premise, j as hypothesis
                p_ji = nli_probs_map[(j, i)]

                # Geometric mean: both directions must see a contradiction.
                contra_score = np.sqrt(p_ij[contra_idx] * p_ji[contra_idx])
                C[i, j] = C[j, i] = contra_score

                if contra_score > 0.5:
                    # Contradictory answers cannot count as agreeing.
                    M[i, j] = M[j, i] = 0.0
                else:
                    # Geometric mean penalises one-directional entailment
                    # (e.g. 0.9 one way and 0.1 the other gives 0.3).
                    M[i, j] = M[j, i] = np.sqrt(p_ij[entail_idx] * p_ji[entail_idx])

        return M, C

    def compute_scores_balanced(self, M, C, valid_docs, total_docs):
        """Derive per-answer support ``S`` and conflict ``F`` scores.

        :param M: ``(k, k)`` agreement matrix.
        :param C: ``(k, k)`` conflict matrix.
        :param valid_docs: original rank (document index) of each of the ``k`` answers.
        :param total_docs: total number of documents, used to scale the rank decay.
        :return: ``(S, F)``, two length-``k`` arrays with values in ``[0.01, 0.99]``
            (``[0.9]`` / ``[0.1]`` for a single answer, empty arrays for none).
        """
        k = len(valid_docs)
        if k == 0:
            return np.zeros(0), np.zeros(0)
        if k == 1:
            return np.array([0.9]), np.array([0.1])

        # 1. Centrality: ten power-iteration steps on M (plus a small diagonal
        #    term), then min-max scaled.
        adj = M + np.eye(k) * 0.01
        v = np.ones(k) / k
        for _ in range(10):
            v = np.dot(adj, v)
            v /= (np.linalg.norm(v) + 1e-8)
        centrality = (v - v.min()) / (v.max() - v.min() + 1e-8)

        # 2. Raw scores.
        S_raw = np.zeros(k)
        F_raw = np.zeros(k)
        centrality_total = np.sum(centrality)
        for i in range(k):
            # Support: centrality damped by rank (earlier documents weigh more).
            rank_weight = np.exp(-valid_docs[i] / total_docs)
            S_raw[i] = (centrality[i] + 1e-8) * rank_weight

            # Conflict: conflicts with central answers count more.
            weighted_conflict = 0
            for j in range(k):
                if i == j:
                    continue
                weighted_conflict += C[i, j] * centrality[j]
            F_raw[i] = weighted_conflict / (centrality_total - centrality[i] + 1e-8)

        # 3. Bring S and F onto the same [0, 1] scale so that neither side
        #    dominates the cut.
        def final_scale(arr):
            # np.ptp() instead of ndarray.ptp(): the method was removed in NumPy 2.0.
            spread = np.ptp(arr)
            if spread < 1e-12:
                return np.ones_like(arr) * 0.5
            return (arr - arr.min()) / (spread + 1e-12)

        S = np.clip(final_scale(S_raw), 0.01, 0.99)
        F = np.clip(final_scale(F_raw), 0.01, 0.99)
        return S, F

class DynamicMinCutRRAG(MinCutRRAG):
    """Multi-round variant of :class:`MinCutRRAG`.

    The final answer of the previous round is added to the pool of
    single-document answers as one more node. Its source/sink capacities are
    replaced by a Bayesian posterior computed from the priors carried over from
    the previous round and from how strongly the new answers agree or conflict
    with it.
    """

    def __init__(self, llm, nli_model_path="DeBERTa-v3-large-mnli-fever-anli-ling-wanli"):
        super().__init__(llm, nli_model_path)

    def _bayesian_update(self, prior, likelihood):
        """Return the posterior ``P(H|E)`` for a binary hypothesis.

        ``P(H|E) = P(E|H) P(H) / (P(E|H) P(H) + P(E|~H) P(~H))`` with the
        simplifying symmetry assumption ``P(E|~H) = 1 - P(E|H)``.

        :param prior: ``P(H)``; clipped to ``[0.01, 0.99]`` before use.
        :param likelihood: ``P(E|H)``.
        """
        prior = np.clip(prior, 0.01, 0.99)

        numerator = likelihood * prior
        # Total probability of the evidence (normalisation term).
        denominator = numerator + (1 - likelihood) * (1 - prior)

        # The epsilon guards against a zero denominator.
        return numerator / (denominator + 1e-9)

    def dynamic_query(self, data_item, previous_answer=None, previous_priors=None):
        """Answer a question given new documents and the previous round's result.

        :param data_item: dict with ``'question'`` and ``'topk_content'`` (the
            new documents); a ``'choices'`` key switches the single-document
            prompts to multiple-choice.
        :param previous_answer: final answer of the previous round, or ``None``
            on the first round.
        :param previous_priors: dict with keys ``'S'`` and ``'F'`` returned by
            the previous round; defaults to ``{'S': 1.0, 'F': 0.0}``.
        :return: ``(final_answer, new_priors)``.
        """
        docs = data_item['topk_content']
        question = data_item['question']
        k = len(docs)
        as_multi_choice = 'choices' in data_item

        # Default priors: fully trust the previous answer.
        if previous_priors is None:
            previous_priors = {'S': 1.0, 'F': 0.0}

        # --- Step 1: single-document answers; uninformative ones are skipped ---
        responses = []
        doc_sources = []  # text behind each response (document text, or the old answer)
        valid_docs = []   # rank of each response; the old answer gets rank k

        # 1.1 New documents.
        for i, doc in enumerate(docs):
            single_data_item = data_item.copy()
            single_data_item['topk_content'] = [doc]
            single_prompt = self.llm.wrap_prompt(single_data_item, as_multi_choice=as_multi_choice, seperate=False)
            resp = self.llm._query(single_prompt)

            if "I don't know" in resp:
                logger.info(f"[DynamicMinCut] Skipping new document {i} as it contains 'I don't know'.")
                continue

            responses.append(resp)
            doc_sources.append(doc)
            valid_docs.append(i)

        # 1.2 Previous answer, if there is an informative one.
        old_answer_idx = -1
        if previous_answer and "I don't know" not in previous_answer:
            logger.info(f"[DynamicMinCut] Adding previous answer to the pool.")
            responses.append(previous_answer)
            doc_sources.append(f"Previous reliable conclusion: {previous_answer}")
            old_answer_idx = len(responses) - 1
            valid_docs.append(k)  # treated as the document ranked after all new ones

        logger.info(f"[DynamicMinCut] Valid responses: {responses}")
        # Flush every root-logger handler (safe when none is configured).
        for handler in logging.getLogger().handlers:
            handler.flush()

        # Neither a usable new document nor a usable old answer: plain query;
        # the priors are passed through unchanged.
        if not responses:
            logger.info("[DynamicMinCut] No valid info found. Querying LLM directly.")
            final_prompt = self.llm.wrap_prompt(data_item, as_multi_choice=as_multi_choice, seperate=False)
            final_answer = self.llm._query(final_prompt)
            return final_answer, previous_priors

        # --- Step 2: pairwise agreement / conflict (new vs new and new vs old) ---
        M, C = self._build_sim_and_conflict_matrices(question, responses)
        logger.info(f"[DynamicMinCut] M: {M.tolist()}")

        # --- Step 3: scores and flow network ---
        total_docs = len(data_item['topk_content']) + (1 if previous_answer else 0)
        S, F = self.compute_scores_balanced(M, C, valid_docs=valid_docs, total_docs=total_docs)
        logger.info(f"[DynamicMinCut] new S: {S.tolist()}")
        logger.info(f"[DynamicMinCut] new F: {F.tolist()}")

        # Bayesian update, applied to the previous answer only.
        if old_answer_idx != -1:
            # Priors P(H) carried over from the previous round.
            prior_S = previous_priors.get('S', 1.0)
            prior_F = previous_priors.get('F', 0.0)

            # Likelihoods: mean agreement / conflict between the old answer and
            # every other answer in the pool.
            other_indices = [i for i in range(len(responses)) if i != old_answer_idx]
            if other_indices:
                likelihood_S = np.mean(M[old_answer_idx, other_indices])
                likelihood_F = np.mean(C[old_answer_idx, other_indices])
            else:
                # Only the old answer is present: no new evidence against it.
                likelihood_S = 1.0
                likelihood_F = 0.0

            logger.info(f"[DynamicBayes] Prior S: {prior_S:.2f}, Likelihood S: {likelihood_S:.2f}")

            posterior_S = self._bayesian_update(prior_S, likelihood_S)
            posterior_F = self._bayesian_update(prior_F, likelihood_F)

            # Overwrite the capacities of the old-answer node.
            S[old_answer_idx] = posterior_S
            F[old_answer_idx] = posterior_F

            logger.info(f"[DynamicBayes] Posterior S: {posterior_S:.2f}, Posterior F: {posterior_F:.2f}")

        # Flow network: source -> i (support), i -> sink (conflict),
        # i <-> j (agreement).
        G = nx.DiGraph()
        source = "source"
        sink = "sink"
        G.add_node(source)
        G.add_node(sink)

        num_nodes = len(responses)
        for i in range(num_nodes):
            G.add_edge(source, i, capacity=S[i])
            G.add_edge(i, sink, capacity=F[i])
            for j in range(i + 1, num_nodes):
                G.add_edge(i, j, capacity=M[i, j])
                G.add_edge(j, i, capacity=M[i, j])

        # --- Step 4: minimum cut; keep the nodes on the source side ---
        _, (reachable, _) = nx.minimum_cut(G, source, sink)
        selected_indices = sorted(node for node in reachable if isinstance(node, int))
        logger.info(f"[DynamicMinCut] Selected indices: {selected_indices}")

        if old_answer_idx != -1:
            is_old_kept = old_answer_idx in selected_indices
            logger.info(f"[DynamicMinCut] Previous answer kept? {is_old_kept}")

        # --- Step 5: final generation from the selected single-document answers ---
        selected_contents = [responses[i] for i in selected_indices]

        # Extreme case: the cut removed every node. Fall back to all sources.
        # NOTE(review): this fallback uses the source texts (``doc_sources``),
        # not the answers, although the prompt labels them "Answer N".
        if not selected_contents:
            logger.warning("[DynamicMinCut] Min-cut removed all nodes! Fallback to using all available info.")
            selected_contents = doc_sources

        context_str = "".join(
            f"Answer {idx+1}: {content}\n"
            for idx, content in enumerate(selected_contents)
        )

        final_prompt = (
            f"Question: {question}\n\n"
            f"The following are all the reference answers obtained (synthesize **strictly based on this content only**. "
            f"Do NOT add, correct, question, or challenge any information in it, even if you believe it may be outdated):\n"
            f"{context_str}\n\n"
            f"Strictly follow the reference answers provided above and synthesize the most consistent and main conclusion as the final answer.\n"
            f"Output **only the final answer itself**. Do NOT write any explanations, reminders, supplements, or comments about dates."
        )

        logger.info(f"[DynamicMinCut] Final Prompt: {final_prompt}")

        final_answer = self.llm._query(final_prompt)

        # --- Priors for the next round ---
        # Compare the new final answer with every answer in this round's pool
        # (``responses`` is non-empty here; it includes the previous answer
        # when that was added above).
        new_priors = {'S': 1.0, 'F': 0.0}  # default: trust the answer

        prior_list = responses + [final_answer]
        M_prior, C_prior = self._build_sim_and_conflict_matrices(question, prior_list)
        final_idx = len(prior_list) - 1

        # Mean agreement / conflict of the final answer with the pool.
        avg_S = np.mean(M_prior[final_idx, :len(responses)])
        avg_F = np.mean(C_prior[final_idx, :len(responses)])

        # The value replaces the prior directly (no Bayesian update here).
        # NOTE(review): only 'S' is stored; 'F' stays 0.0 although avg_F is
        # logged below as a new prior — confirm whether this is intentional.
        new_priors['S'] = avg_S

        logger.info(f"[DynamicPriors] New priors based on final_answer: S={avg_S:.3f}, F={avg_F:.3f}")

        return final_answer, new_priors
        
"""
class DynamicMinCutRRAG(MinCutRRAG):
    def __init__(
        self, 
        llm, 
        nli_model_path="DeBERTa-v3-large-mnli-fever-anli-ling-wanli",
        # 方案一的三个超参数（可根据需要在初始化时改）
        old_answer_S_prior: float = 0.7,
        old_answer_F_prior: float = 0.3,
        old_answer_lambda: float = 0.5,
    ):
        super().__init__(llm, nli_model_path)
        # 保存先验和权重参数
        self.old_answer_S_prior = old_answer_S_prior
        self.old_answer_F_prior = old_answer_F_prior
        self.old_answer_lambda = old_answer_lambda

    def dynamic_query(self, data_item, previous_answer=None):
        
        处理动态场景的查询。
        :param data_item: 包含 'question' 和 'topk_content' (新文档) 的字典
        :param previous_answer: 上一轮运行得出的最终答案 (String)，如果是第一次运行则为 None
        
        docs = data_item['topk_content']
        question = data_item['question']
        k = len(docs)

        # --- Step 1: 生成单文档答案，并过滤无效文档 ---
        responses = []
        doc_sources = []  # 存储来源内容，用于最后拼接 prompt (新文档存原文，旧答案存答案本身)
        
        # 1.1 处理新文档
        for i in range(k):
            single_data_item = data_item.copy()
            single_data_item['topk_content'] = [docs[i]]
            # 调用 LLM 生成单个文档的答案
            single_prompt = self.llm.wrap_prompt(
                single_data_item,
                as_multi_choice=('choices' in data_item),
                seperate=False
            )
            resp = self.llm._query(single_prompt)
            
            if "I don't know" in resp:
                logger.info(f"[DynamicMinCut] Skipping new document {i} as it contains 'I don't know'.")
                continue
            
            responses.append(resp)
            doc_sources.append(docs[i])  # 对应的新闻档内容

        # 1.2 处理旧答案 (如果存在)
        old_answer_idx = -1
        if previous_answer and "I don't know" not in previous_answer:
            logger.info(f"[DynamicMinCut] Adding previous answer to the pool.")
            responses.append(previous_answer)
            # 旧答案本身就是“根据资料”得到的内容，直接作为 source
            doc_sources.append(f"Previous reliable conclusion: {previous_answer}") 
            old_answer_idx = len(responses) - 1  # 记录旧答案在 responses 列表中的索引

        logger.info(f"[DynamicMinCut] Valid responses: {responses}")
        logging.getLogger().handlers[0].flush()  # 强制刷新

        # 如果没有任何有效响应（既没有有效新文档，也没有旧答案）
        if not responses:
            logger.info("[DynamicMinCut] No valid info found. Querying LLM directly.")
            final_prompt = self.llm.wrap_prompt(
                data_item,
                as_multi_choice=('choices' in data_item),
                seperate=False
            )
            return self.llm._query(final_prompt)

        # --- Step 2: 使用 NLI 计算两两相似度 (M矩阵) ---
        # 注意：这里包含了 (新文档 vs 新文档) 以及 (新文档 vs 旧答案) 的计算
        logger.info(f"[DynamicMinCut] Computing NLI matrix for {len(responses)} items.")
        M = self._build_sim_matrix(question, responses)
        
        # --- Step 3: 计算 S 和 F，并构建最小割图 ---
        
        # 3.1 计算原始 S 和 F
        # S_raw: 平均支持度 (该回答被其他回答蕴含的概率)
        # F_raw: 平均被反驳度
        num_responses = len(responses)
        S_raw = M.sum(axis=1) / num_responses
        F_raw = ((1 - M).sum(axis=1)) / num_responses

        # 3.2 缩放 S 和 F (复用父类的逻辑，或者简单归一化)
        def scale_flat(arr, target_mean=0.5, jitter=0.2):
            arr = np.array(arr, dtype=float)
            if arr.ptp() < 1e-8:
                return np.ones_like(arr) * target_mean
            arr = (arr - arr.mean()) / (arr.std() + 1e-8)
            arr = arr * jitter + target_mean
            return arr.clip(target_mean - jitter, target_mean + jitter)

        S = scale_flat(S_raw)
        F = scale_flat(F_raw)

        # === 方案一：对旧答案应用“软先验偏置”（凸组合） ===
        # 数学形式：
        #   S_o_new = (1 - λ) * S_o + λ * S_o_prior
        #   F_o_new = (1 - λ) * F_o + λ * F_o_prior
        if old_answer_idx != -1:
            lambda_ = self.old_answer_lambda
            S_prior = self.old_answer_S_prior
            F_prior = self.old_answer_F_prior

            S_old_orig = float(S[old_answer_idx])
            F_old_orig = float(F[old_answer_idx])

            S[old_answer_idx] = (1.0 - lambda_) * S_old_orig + lambda_ * S_prior
            F[old_answer_idx] = (1.0 - lambda_) * F_old_orig + lambda_ * F_prior

            logger.info(
                "[DynamicMinCut] Applying soft prior bias to previous answer at index %d: "
                "S_old_orig=%.4f, F_old_orig=%.4f, S_new=%.4f, F_new=%.4f, "
                "S_prior=%.2f, F_prior=%.2f, lambda=%.2f",
                old_answer_idx,
                S_old_orig,
                F_old_orig,
                S[old_answer_idx],
                F[old_answer_idx],
                S_prior,
                F_prior,
                lambda_,
            )

        # 3.3 建图
        G = nx.DiGraph()
        source = "source"
        sink = "sink"
        G.add_node(source)
        G.add_node(sink)

        num_nodes = len(responses)
        
        # 添加边
        for i in range(num_nodes):
            # Source -> Doc (Capacity = Support)
            G.add_edge(source, i, capacity=float(S[i]))
            # Doc -> Sink (Capacity = Falsehood)
            G.add_edge(i, sink, capacity=float(F[i]))
            
            # Doc i <-> Doc j (Capacity = Consistency M_ij)
            for j in range(i + 1, num_nodes):
                cap_ij = float(M[i, j])
                G.add_edge(i, j, capacity=cap_ij)
                G.add_edge(j, i, capacity=cap_ij)

        # --- Step 4: 计算最小割，选出正确答案集合 ---
        min_cut_value, partition = nx.minimum_cut(G, source, sink)
        reachable, non_reachable = partition
        
        # 筛选出被 Source 连通的节点 (即保留下来的节点)
        selected_indices = sorted([node for node in reachable if isinstance(node, int)])
        logger.info(f"[DynamicMinCut] Selected indices: {selected_indices}")
        
        if old_answer_idx != -1:
            is_old_kept = old_answer_idx in selected_indices
            logger.info(f"[DynamicMinCut] Previous answer kept? {is_old_kept}")

        # --- Step 5: 最终生成 ---
        # 收集被选中节点的原始内容
        selected_contents = [responses[i] for i in selected_indices]
        
        # 如果割完之后啥都没了 (极端情况)，回退到全选或者直接问
        if not selected_contents:
            logger.warning("[DynamicMinCut] Min-cut removed all nodes! Fallback to using all available info.")
            selected_contents = doc_sources

        # 构造最终的 Prompt
        # 格式："[query]问题经查阅资料得到答案……，请综合以上答案给出最终答案。"
        context_str = ""
        for idx, content in enumerate(selected_contents):
            context_str += f"Answer {idx+1}: {content}\n"

        final_prompt = (
            f"Question: {question}\n"
            f"The answers obtained after consulting the materials are:\n"
            f"{context_str}\n"
            f"Please synthesize the above answers to give the final answer."
        )
        
        logger.info(f"[DynamicMinCut] Final Prompt: {final_prompt}")
        
        final_answer = self.llm._query(final_prompt)
        return final_answer
    """

# 简单的并查集辅助类，支持自定义初始权重
class WeightedUnionFind:
    """Disjoint-set (union-find) structure whose sets carry an additive weight.

    Every element starts as its own set. ``size[root]`` holds the total weight
    of the set whose representative is ``root``; merging two sets adds their
    weights together. ``count`` is the current number of disjoint sets.
    """

    def __init__(self, n, initial_weights=None):
        """Create ``n`` singleton sets.

        :param n: number of elements, indexed ``0 .. n-1``.
        :param initial_weights: optional per-element starting weights; when
            omitted every element weighs 1.
        """
        self.parent = list(range(n))
        if initial_weights is None:
            self.size = [1] * n
        else:
            # Copy the weights: union() updates ``size`` in place and must not
            # silently modify the list owned by the caller.
            self.size = list(initial_weights)
        self.count = n

    def find(self, p):
        """Return the representative of ``p``'s set, compressing the path."""
        # First pass: walk up to the root (iterative, so long parent chains
        # cannot hit the interpreter's recursion limit).
        root = p
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: point every node on the walked path directly at the root.
        while self.parent[p] != root:
            self.parent[p], p = root, self.parent[p]
        return root

    def union(self, p, q):
        """Merge the sets containing ``p`` and ``q``.

        The lighter set is attached under the heavier one (ties keep ``p``'s
        root as the representative).

        :return: True if two distinct sets were merged, False if ``p`` and
            ``q`` were already in the same set.
        """
        root_p = self.find(p)
        root_q = self.find(q)
        if root_p == root_q:
            return False

        if self.size[root_p] < self.size[root_q]:
            self.parent[root_p] = root_q
            self.size[root_q] += self.size[root_p]
        else:
            self.parent[root_q] = root_p
            self.size[root_p] += self.size[root_q]
        self.count -= 1
        return True

class DynamicClusterBasedRRAG(ClusterBasedRRAG):
    """Cluster-based aggregation that carries cluster summaries across rounds.

    In each round the single-document answers for the newly retrieved
    documents are clustered together with the cluster summaries returned by
    the previous round. Every resulting cluster is re-summarised by the LLM,
    and the summary of the best-ranked cluster is returned as the answer.

    NOTE(review): ``self.nli_tokenizer``, ``self.nli_model``,
    ``self.entail_idx`` and ``self.sim_threshold`` are read below but are not
    assigned in this class -- presumably ``ClusterBasedRRAG.__init__`` sets
    them; confirm against the parent class.
    """

    def __init__(self, llm, err=0.0, sim_threshold=0.5):
        super().__init__(llm, err, sim_threshold)

    def dynamic_query(self, data_item, previous_state=None):
        """Answer ``data_item`` while merging new evidence into prior clusters.

        :param data_item: dict holding at least 'question'; it is passed
            unchanged to ``self.llm.wrap_prompt`` (which presumably reads the
            new documents from 'topk_content' -- verify against wrap_prompt).
            A non-empty 'choices' entry switches to multiple-choice prompts.
        :param previous_state: cluster state returned by the previous round,
            e.g. ``[{'summary': '...', 'size': 5}, {'summary': '...', 'size': 2}]``.
            ``None`` or ``[]`` for the first round.
        :return: ``(final_answer, current_state_for_next_round)`` where the
            state is a list of ``{'summary', 'size'}`` dicts, one per cluster.
        """
        question = data_item['question']
        previous_clusters = previous_state if previous_state else []
        is_multi = ('choices' in data_item and bool(data_item.get('choices')))

        # --- Step 1: query the LLM once per new document, drop "I don't know" ---
        seperate_responses = self.llm.batch_query(
            self.llm.wrap_prompt(
                data_item,
                as_multi_choice=is_multi,
                seperate=True
            )
        )
        logger.info(f"[DynamicCluster] Single-doc responses: {seperate_responses}")
        # Force buffered log output out. Iterating (instead of indexing
        # handlers[0]) keeps this safe when the root logger has no handlers.
        for handler in logging.getLogger().handlers:
            handler.flush()

        # Each entry remembers the document's position in the original top-k list.
        new_valid_docs = [
            {'orig_idx': i, 'answer': ans}
            for i, ans in enumerate(seperate_responses)
            if "I don't know" not in ans
        ]

        m = len(new_valid_docs)
        n = len(previous_clusters)
        total_nodes = m + n

        logger.info(f"[DynamicCluster] New valid docs: {m}, Old clusters: {n}")

        # Neither new evidence nor previous state: fall back to one plain
        # query and hand an empty state to the next round.
        if total_nodes == 0:
            logger.warning("[DynamicCluster] No valid info. Fallback to raw query.")
            prompt = self.llm.wrap_prompt(
                data_item,
                seperate=False,
                as_multi_choice=is_multi
            )
            final_answer = self.llm._query(prompt)
            return final_answer, []

        # --- Step 2: NLI scoring and union-find clustering ---
        # Node layout:
        #   0 .. m-1     : new documents (index into new_valid_docs)
        #   m .. m+n-1   : old clusters  (index into previous_clusters, offset by m)
        # New documents weigh 1; an old cluster keeps the size it had before.
        initial_weights = [1] * m + [c['size'] for c in previous_clusters]
        uf = WeightedUnionFind(total_nodes, initial_weights=initial_weights)

        premises = []
        hypotheses = []
        pair_indices = []  # (u, v) node pairs, aligned with premises/hypotheses

        # 2.1 new vs new
        for i in range(m):
            for j in range(i + 1, m):
                premises.append(f"The answer to the question: {question}\nis {new_valid_docs[i]['answer']}.")
                hypotheses.append(f"The answer to the question: {question}\nis {new_valid_docs[j]['answer']}.")
                pair_indices.append((i, j))

        # 2.2 new vs old (does a new document fit into an existing cluster?)
        for i in range(m):
            for j in range(n):
                old_summary = previous_clusters[j]['summary']
                premises.append(f"The answer to the question: {question}\nis {new_valid_docs[i]['answer']}.")
                hypotheses.append(f"The answer to the question: {question}\nis {old_summary}.")
                pair_indices.append((i, m + j))

        # Run NLI on all pairs in one batch and merge pairs whose entailment
        # probability reaches the similarity threshold.
        if premises:
            inputs = self.nli_tokenizer(premises, hypotheses, return_tensors='pt', truncation=True, padding=True)
            inputs = {name: tensor.to(self.nli_model.device) for name, tensor in inputs.items()}

            with torch.no_grad():
                outputs = self.nli_model(**inputs)
                probs = torch.softmax(outputs.logits, dim=1)

            for idx, (u, v) in enumerate(pair_indices):
                entail_prob = probs[idx][self.entail_idx].item()
                if entail_prob >= self.sim_threshold:
                    uf.union(u, v)

        # --- Step 3: collect clusters and rank them ---
        # clusters[root] = {'new_indices': [...], 'old_indices': [...]}
        clusters = defaultdict(lambda: {'new_indices': [], 'old_indices': []})

        for node in range(total_nodes):
            root = uf.find(node)
            if node < m:
                clusters[root]['new_indices'].append(node)
            else:
                # Store the index relative to previous_clusters.
                clusters[root]['old_indices'].append(node - m)

        candidate_clusters = []
        for root, members in clusters.items():
            has_new = len(members['new_indices']) > 0
            if has_new:
                # Best (smallest) original top-k position among the new documents.
                min_rank = min(new_valid_docs[ni]['orig_idx'] for ni in members['new_indices'])
            else:
                min_rank = float('inf')  # clusters without new documents rank last

            candidate_clusters.append({
                'root': root,
                'size': uf.size[root],  # accumulated weight kept by the union-find
                'has_new': has_new,
                'min_rank': min_rank,
                'members': members
            })

        # Ranking: larger size first, then clusters containing a new document,
        # then the smaller min_rank (negated so one descending sort covers all keys).
        candidate_clusters.sort(key=lambda x: (
            x['size'],
            x['has_new'],
            -x['min_rank']
        ), reverse=True)

        logger.info(f"[DynamicCluster] Candidates sorted: {[(c['size'], c['has_new'], c['min_rank']) for c in candidate_clusters]}")

        best_cluster = candidate_clusters[0]

        # --- Step 4: synthesize one summary per cluster ---
        next_round_state = []
        final_answer_text = None

        for cluster in candidate_clusters:
            contents = [new_valid_docs[ni]['answer'] for ni in cluster['members']['new_indices']]
            contents += [previous_clusters[oi]['summary'] for oi in cluster['members']['old_indices']]

            context_str = ""
            for idx, content in enumerate(contents):
                context_str += f"Answer {idx+1}: {content}\n"

            prompt = (
                f"Question: {question}\n"
                f"The answers obtained after consulting the materials are:\n"
                f"{context_str}\n"
                f"Please synthesize the above answers to give the final answer."
            )

            synthesized_ans = self.llm._query(prompt)

            # Every cluster (not only the winner) is carried into the next round.
            next_round_state.append({
                'summary': synthesized_ans,
                'size': cluster['size']
            })

            if cluster['root'] == best_cluster['root']:
                final_answer_text = synthesized_ans
                if not cluster['has_new']:
                    logger.info("[DynamicCluster] Winner cluster has no new docs. Effectively retaining old belief.")

        return final_answer_text, next_round_state


# 对每个检索文档单独让 LLM 给出独立回答（separate responses）；
# 用序贯 NLI（DeBERTa 多分类）判断两两回答是否“矛盾”；
# 构建有向图并迭代去掉出度大于剩余顶点/2 的节点；
# 从剩余节点中选入度为 0 的文档作为可信集，按原始顺序拼回去做最终回答。
# 复杂度O(k^2)
class GraphBasedRRAG(RRAG):
    """Graph-based filtering of retrieved documents via pairwise NLI.

    Each retrieved document is answered independently; an NLI classifier then
    scores every ordered pair (i, j) with i < j, and a directed edge i -> j is
    added when the two answers are judged contradictory. Vertices whose
    out-degree exceeds half of the remaining vertices are removed iteratively,
    and the surviving documents with in-degree 0 are used for the final query.
    """

    def __init__(self, llm, nli_model_path="/scratch/gpfs/zs7353/DeBERTa-v3-large-mnli-fever-anli-ling-wanli"):
        """
        :param llm: LLM wrapper; ``wrap_prompt``, ``batch_query`` and
            ``_query`` are called on it.
        :param nli_model_path: path or identifier handed to ``from_pretrained``
            for both the tokenizer and the sequence-classification model. The
            default keeps the previously hard-coded location.
        """
        self.llm = llm
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # device = "cpu" # for gpt-4o
        self.nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_path)
        self.nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_path).to(device)

    def query(self, data_item):
        """Return the final answer built from the documents that survive filtering.

        :param data_item: dict with 'question' and 'topk_content'; the
            presence of a 'choices' key switches to multiple-choice prompts.
        :return: the LLM's answer to the prompt built from the selected documents.
        """
        docs = data_item['topk_content']
        seperate_responses = self.llm.batch_query(self.llm.wrap_prompt(data_item,as_multi_choice='choices' in data_item,seperate=True))
        k = len(docs)

        # Contradiction edges, stored in both directions for cheap degree lookups.
        out_edges = {i: set() for i in range(k)}
        in_edges = {i: set() for i in range(k)}

        premises = []
        hypotheses = []
        pair_indices = []

        for i in range(k):
            for j in range(i + 1, k):
                premises.append(f"The answer to the question: {data_item['question']}\nis {seperate_responses[i]}.")
                hypotheses.append(f"The answer to the question: {data_item['question']}\nis {seperate_responses[j]}.")
                pair_indices.append((i, j))

        if premises:
            inputs = self.nli_tokenizer(premises, hypotheses, return_tensors='pt', truncation=True, padding=True)
            inputs = {key: value.to(self.nli_model.device) for key, value in inputs.items()}

            # Score all pairs in a single batch.
            with torch.no_grad():
                outputs = self.nli_model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)

            for idx, (i, j) in enumerate(pair_indices):
                # Abstaining answers never take part in a contradiction edge.
                if "I don't know" in seperate_responses[i] or "I don't know" in seperate_responses[j]:
                    continue
                # NOTE(review): class index 2 is read as "contradiction";
                # confirm against the NLI model's label mapping.
                if probs[idx][2].item() >= 0.5:
                    out_edges[i].add(j)
                    in_edges[j].add(i)

        # Iteratively remove vertices whose out-degree is greater than
        # (number of remaining vertices) / 2. Each round is evaluated against
        # the vertex set as it stood at the start of that round.
        remaining = set(range(k))
        while True:
            limit = len(remaining) // 2
            to_remove = [v for v in remaining if len(out_edges[v] & remaining) > limit]
            if not to_remove:
                break
            remaining.difference_update(to_remove)

        # From the remaining documents, select those with in-degree 0.
        selected = [v for v in remaining if len(in_edges[v] & remaining) == 0]

        logger.info(selected)
        # Fallback: if no document has in-degree 0, use all remaining documents.
        if not selected:
            selected = list(remaining)

        # Restore the original rank order.
        selected.sort()

        # Query once more with only the selected documents.
        new_data_item = data_item.copy()
        new_data_item['topk_content'] = [docs[i] for i in selected]

        ultimate_prompt = self.llm.wrap_prompt(new_data_item, as_multi_choice='choices' in data_item, seperate=False)

        final_answer = self.llm._query(ultimate_prompt)
        print("final_answer: ", final_answer)
        return final_answer

# 同样先对每个文档单独生成回答并用 NLI 判断是否矛盾，构建无向图（矛盾边）。
# 然后在候选顶点集合 z（非 "I don't know" 的回答集合）上穷举求最大独立集（MIS），若有多个取字典序最小者；
# 把 MIS 对应文档作为可信集合，拼接最终 prompt 并查询 LLM。
# 复杂度O(2^k)
class MISBasedRRAG(RRAG):
    """Document filtering via a maximum independent set of a contradiction graph.

    Each retrieved document is answered independently and an NLI classifier
    scores every pair of answers. Contradicting pairs become undirected edges;
    the largest set of non-abstaining documents without any edge between them
    (ties broken by the lexicographically smallest index tuple) is used for
    the final query. The subset search is exponential in the number of
    candidate documents.
    """

    def __init__(self, llm, err):
        """
        :param llm: LLM wrapper; ``wrap_prompt``, ``batch_query`` and
            ``_query`` are called on it.
        :param err: probability in [0, 1] with which each pairwise
            contradiction decision is flipped (compared against
            ``random.random()`` in ``query``).
        """
        self.llm = llm
        self.err = err
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # device = "cpu"  # for gpt-4o
        self.nli_tokenizer = AutoTokenizer.from_pretrained("DeBERTa-v3-large-mnli-fever-anli-ling-wanli")
        self.nli_model = AutoModelForSequenceClassification.from_pretrained("DeBERTa-v3-large-mnli-fever-anli-ling-wanli").to(device)

    def query(self, data_item):
        """Return the final answer built from the selected independent set.

        :param data_item: dict with 'question' and 'topk_content'; the
            presence of a 'choices' key switches to multiple-choice prompts.
        :return: the LLM's answer to the prompt built from the selected documents.
        """
        # Retrieve the documents and get separate responses.
        docs = data_item['topk_content']
        seperate_responses = self.llm.batch_query(self.llm.wrap_prompt(data_item, as_multi_choice='choices' in data_item, seperate=True))
        logger.info(f"[MIS] Responses: {seperate_responses}")
        # Force buffered log output out. Iterating (instead of indexing
        # handlers[0]) keeps this safe when the root logger has no handlers.
        for handler in logging.getLogger().handlers:
            handler.flush()
        k = len(docs)

        # Undirected graph: graph[i] holds every j whose answer contradicts answer i.
        graph = {i: set() for i in range(k)}
        premises, hypotheses, pair_indices = [], [], []

        for i in range(k):
            for j in range(i + 1, k):
                premises.append(f"The answer to the question: {data_item['question']}\nis {seperate_responses[i]}.")
                hypotheses.append(f"The answer to the question: {data_item['question']}\nis {seperate_responses[j]}.")
                pair_indices.append((i, j))

        if premises:
            inputs = self.nli_tokenizer(premises, hypotheses, return_tensors='pt', truncation=True, padding=True)
            inputs = {key: value.to(self.nli_model.device) for key, value in inputs.items()}
            start_time = time.perf_counter()
            with torch.no_grad():
                outputs = self.nli_model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            end_time = time.perf_counter()
            print("time for NLI: ", end_time - start_time)

            for idx, (i, j) in enumerate(pair_indices):
                # NOTE(review): class index 2 is read as "contradiction";
                # confirm against the NLI model's label mapping.
                contradiction_probability = probs[idx][2].item()
                # Drawn for every pair (even skipped ones) so the random
                # stream does not depend on which answers abstained.
                x = random.random()
                if "I don't know" in seperate_responses[i] or "I don't know" in seperate_responses[j]:
                    continue
                if contradiction_probability >= 0.5:
                    # Real contradiction: keep the edge unless the simulated
                    # error drops it.
                    add_edge = x >= self.err
                else:
                    # No contradiction: the simulated error may add a spurious edge.
                    add_edge = x <= self.err
                if add_edge:
                    graph[i].add(j)
                    graph[j].add(i)

        # Only non-abstaining answers are candidates for the independent set.
        z = {i for i in range(k) if "I don't know" not in seperate_responses[i]}

        # Among all maximum independent sets, the lexicographically smallest is chosen.
        start_time = time.perf_counter()
        best_set = self._max_independent_set(graph, z)
        end_time = time.perf_counter()
        print("time for finding MIS: ", end_time - start_time)

        # Fallback: an empty result means there were no candidates; use all
        # candidates, and if that is still empty, every document.
        if not best_set:
            best_set = list(z)
            if not best_set:
                best_set = [i for i in range(k)]
        else:
            best_set = list(best_set)

        best_set.sort()  # ascending order (better ranked docs have lower indices)
        logger.info(f"Selected document indices: {best_set}")

        # Query once more with only the selected documents.
        new_data_item = data_item.copy()
        new_data_item['topk_content'] = [docs[i] for i in best_set]

        ultimate_prompt = self.llm.wrap_prompt(new_data_item, as_multi_choice='choices' in data_item, seperate=False)
        print(ultimate_prompt)
        start_time = time.perf_counter()
        final_answer = self.llm._query(ultimate_prompt)
        end_time = time.perf_counter()
        print("time for the ultimate query: ", end_time - start_time)
        print("final_answer:", final_answer)
        return final_answer

    def _max_independent_set(self, graph, vertices):
        """Return the lexicographically smallest maximum independent set.

        :param graph: mapping vertex -> collection of adjacent vertices.
        :param vertices: candidate vertices to choose from.
        :return: sorted tuple of vertices; ``()`` when ``vertices`` is empty.
        """
        ordered = sorted(vertices)
        # combinations() over a sorted input yields sorted tuples in
        # lexicographic order, so scanning sizes from largest to smallest and
        # stopping at the first independent subset gives the smallest tuple
        # among the maximum-size independent sets without storing the rest.
        for size in range(len(ordered), 0, -1):
            for candidate in combinations(ordered, size):
                if self._is_independent(candidate, graph):
                    return candidate
        return ()

    def _is_independent(self, subset, graph):
        """Return True when no two distinct vertices of ``subset`` are adjacent."""
        for v in subset:
            for u in subset:
                if u != v and u in graph[v]:
                    return False
        return True

# 对每个文档分别预测（separate responses），把每个位置的预测按指数权重 gamma^i 加权计数，取最多票项作为最终预测（适用于多选题/离散选项）。
class WeightedMajorityVoting(RRAG):
    """Position-weighted majority vote over per-document multiple-choice answers.

    Each retrieved document is answered independently; the prediction at
    position ``i`` votes with weight ``gamma ** i`` and the option with the
    largest total weight wins. Abstentions ('E.') do not vote.
    """

    def __init__(self, llm, gamma=1):
        """
        :param llm: LLM wrapper; ``wrap_prompt``, ``batch_query`` and
            ``model_name`` are used.
        :param gamma: per-position decay factor of the vote weight.
        """
        self.llm = llm
        self.gamma = gamma

    def _extract_choice(self, response):
        """Map one raw LLM response to 'A.'..'D.' or 'E.' (no information found)."""
        if "gpt" in self.llm.model_name:
            marker = response.find('Answer')
            if marker != -1:
                # Skip past the marker plus one separator character.
                response = response[(marker + 7):].strip()
            else:
                response = response.strip()
            # Guard against an empty response before looking at its first
            # character (indexing '' would raise IndexError).
            if response and response[0] in 'ABCD':
                return response[0] + '.'
            return 'E.'

        response = response.strip()
        if len(response) >= 2 and response[1] == '.' and response[0] in 'ABCD':
            return response[:2]
        return 'E.'

    def query(self, data_item):
        """Return the winning option ('A.'..'D.'), or 'E.' if every document abstained.

        :param data_item: passed to ``self.llm.wrap_prompt``; the presence of
            a 'choices' key switches to multiple-choice prompts.
        """
        # assume the prompt ask the LLM to output A., B., C., D., or E. No information found
        seperate_responses = self.llm.batch_query(self.llm.wrap_prompt(data_item,as_multi_choice='choices' in data_item,seperate=True))
        seperate_preds = [self._extract_choice(response) for response in seperate_responses]

        logger.debug(f'Seperate responses: {seperate_preds}')

        # (prediction, weight) for every non-abstaining document; position 0
        # has weight 1, position 1 has weight gamma, then gamma^2, ...
        votes = [(pred, self.gamma ** i) for i, pred in enumerate(seperate_preds) if pred != 'E.']
        if not votes:
            return 'E.'  # No information found.

        # Rescale so the weights sum to the number of voters. When the weights
        # sum to zero (e.g. gamma == 0 and the first document abstained) the
        # rescaling is skipped instead of dividing by zero.
        total_weight = sum(weight for _, weight in votes)
        scale = len(votes) / total_weight if total_weight else 1.0

        cntr = defaultdict(float)
        for pred, weight in votes:
            cntr[pred] += weight * scale

        return Counter(cntr).most_common(1)[0][0]

# 对每个文档单独生成回答，使用 spaCy 提取短语/关键字并以加权计数过滤出高频关键词（基于 absolute / relative 阈值）；
# 把关键词合并成 hints 放入 prompt 再做一次聚合查询以获取最终答案。
class WeightedKeywordAgg(RRAG):
    """Keyword-based aggregation of per-document answers.

    Each retrieved document is answered independently. Keywords/keyphrases are
    extracted from the non-abstaining answers with spaCy, counted with a
    position-dependent weight (``gamma ** i``), filtered by a count threshold,
    and passed as hints to one final aggregated LLM query.
    """

    def __init__(self,llm,relative_threshold=0.3, absolute_threshold=3, abstention_threshold=1, gamma=1, longgen=False):
        """
        :param llm: LLM wrapper; ``wrap_prompt``, ``batch_query`` and ``query`` are used.
        :param relative_threshold: keyword count threshold as a fraction of
            the number of retained responses.
        :param absolute_threshold: absolute keyword count threshold; the
            smaller of the two thresholds is applied.
        :param abstention_threshold: minimum number of retained responses
            needed to produce an answer.
        :param gamma: per-position decay factor of a response's weight.
        :param longgen: True for long-form generation (a different prompt
            template is requested and single-word keywords are dropped).
        """
        self.llm = llm
        # Honour the constructor argument (it used to be overwritten with 1).
        self.abstention_threshold = abstention_threshold
        self.keyword_extractor = spacy.load("en_core_web_sm")
        # Part-of-speech tags that terminate a keyphrase.
        self.ignore_set = {'VERB','INTJ','ADP','AUX','CCONJ','DET','PART','PRON','SCONJ','PUNCT','SPACE'}
        self.absolute = absolute_threshold
        self.relative = relative_threshold
        self.gamma = gamma
        self.longgen = longgen # if it is long-form generation or short-form (we use slightly different prompt template)
        logger.info(f'abs: {absolute_threshold}, relative: {relative_threshold}')

    def query(self, data_item, abstention_threshold=None):
        """Return the keyword-aggregated answer, or "I don't know." on abstention.

        :param data_item: passed to ``self.llm.wrap_prompt``; when
            ``self.longgen`` is set, the key 'genhint' is written into it.
        :param abstention_threshold: per-call override of the instance threshold.
        """
        # override original threshold parameters if given
        abstention_threshold = abstention_threshold if abstention_threshold is not None else self.abstention_threshold
        if self.longgen:
            data_item['genhint'] = True # add a flag so that wrap_prompt() can retrieve the correct prompt template
        # make seperate predictions
        seperate_responses_raw = self.llm.batch_query(self.llm.wrap_prompt(data_item,as_multi_choice='choices' in data_item,seperate=True))

        # Keep (response, weight) for every response that does not abstain.
        seperate_responses = []
        logger.debug(f'Seperate responses:\n')
        for i,x in enumerate(seperate_responses_raw):
            logger.debug(f'{i}: {x}\n')
            if "I don't" not in x:
                seperate_responses.append((x, self.gamma ** i))

        logger.debug(f'Number of retained responses: {len(seperate_responses)}')

        if len(seperate_responses) < abstention_threshold:
            logger.warning('Abstain from making response...')
            return "I don't know."

        # Rescale weights so they sum to the number of retained responses.
        # When the weights sum to zero (e.g. gamma == 0 and the first
        # document abstained) the rescaling is skipped instead of dividing
        # by zero.
        total_weight = sum(weight for _, weight in seperate_responses)
        scale = len(seperate_responses) / total_weight if total_weight else 1.0

        def flush_run(run, phrases):
            # Emit the lemmatised phrase for a run of kept tokens, followed by
            # each token's lemma on its own.
            phrases.append(''.join([t.lemma_+t.whitespace_ for t in run]).strip())
            phrases.extend(t.lemma_ for t in run)

        # extract keyword/keyphrase
        all_extracted_phrase = []
        token_counter = defaultdict(int)
        for response, weight in seperate_responses:
            doc = self.keyword_extractor(response)
            phrase_list = [response.strip()]  # the full response is a candidate too
            run = []
            for token in doc:
                if token.pos_ in self.ignore_set:
                    # An ignored tag closes the current run of kept tokens.
                    if run:
                        flush_run(run, phrase_list)
                        run = []
                else:
                    run.append(token)
            # Close the trailing run (this adds '' when the run is empty; the
            # empty string is removed by the punctuation filter below).
            flush_run(run, phrase_list)

            phrase_set = set(phrase_list) # only consider unique keywords
            all_extracted_phrase.append(phrase_set)
            for phrase in phrase_set:
                token_counter[phrase] += weight * scale

        # filtering
        logger.debug(all_extracted_phrase)
        count_threshold = min(self.absolute,self.relative*len(seperate_responses))
        logger.debug(sorted(token_counter.items(), key=lambda x: (len(x[0]),x[0]), reverse=True))
        logger.debug(f'count_threshold,{count_threshold}')
        for token,count in list(token_counter.items()):
            if (count < count_threshold) or (token in punctuation) or (token in stopword_set)  or (self.longgen and ' ' not in token): # if it is long generation, we remove single words to reduce the size the keyword set...
                del token_counter[token]

        # generate keyword hints (longest keywords first)
        sorted_tokens = sorted(token_counter.items(), key=lambda x: (len(x[0]),x[0]), reverse=True)
        hints = ', '.join([f'{token}' for token,count in sorted_tokens])
        logger.debug(sorted_tokens)
        query_prompt = self.llm.wrap_prompt(data_item,as_multi_choice='choices' in data_item,hints=hints)
        logger.debug(f'Hint prompt:\n{query_prompt}')
        response = self.llm.query(query_prompt)
        logger.debug(f'Keyword aggregated response:\n{response}')

        return response

# 基于 model 的 secure_decoding（自定义解码器）对各个 prompt 做并行/批量分析；
# 先估计“我不知道”概率来过滤掉可能无关 prompt，然后对保留输入运行 secure_decoding（带 stopping criteria、eta、gamma 等超参），生成聚合输出。
class WeightedDecodingAgg(RRAG):
    """Defense that aggregates per-passage prompts through ``secure_decoding``.

    One prompt is built per retrieved passage plus one zero-shot prompt
    (no passages). Prompts for which the model assigns a high probability to
    the continuation " I don't know" are dropped; the remaining prompts are
    decoded jointly by the custom ``secure_decoding`` routine.
    """

    def __init__(self, llm, eta, gamma=1, abstention_prob=None):
        """Bind ``secure_decoding`` to the model and store hyperparameters.

        Args:
            llm: wrapper exposing ``model``, ``tokenizer``, ``model_name``,
                ``wrap_prompt`` and ``max_output_tokens``.
            eta: forwarded to ``secure_decoding``.
            gamma: forwarded to ``secure_decoding``.
            abstention_prob: prompts whose "I don't know" probability is at
                least this value are filtered out. When ``None`` a per-model
                default is looked up (falling back to 0.99).
        """
        # NOTE(review): unlike the sibling defenses in this file, this class
        # does not call super().__init__(llm) -- confirm that is intentional.
        self.llm = llm
        # Attach the free function as a bound method of the model instance.
        self.llm.model.secure_decoding = secure_decoding.__get__(self.llm.model, type(self.llm.model))
        self.temperature = 1.0  # args.temperature
        abstention_prob_list = {'/scratch/gpfs/zs7353/Llama-3.2-3B-Instruct': 0.99,
                                '/scratch/gpfs/zs7353/Mistral-7B-Instruct-v0.2': 0.99,
                                '/scratch/gpfs/zs7353/DeepSeek-R1-Distill-Qwen-7B': 0.99}
        if abstention_prob is None:
            self.abstention_prob = abstention_prob_list.get(llm.model_name, 0.99)
            logger.debug(f"Using default abstention probability: {self.abstention_prob}")
        else:
            # Previously an explicit value was silently dropped, leaving the
            # attribute undefined and crashing later in preprocess_input.
            self.abstention_prob = abstention_prob

        self.gamma = gamma
        self.eta = eta

    def preprocess_input(self, data_item):
        """Tokenize the prompts and filter out those likely to abstain.

        Args:
            data_item: dict with at least ``question`` and whatever
                ``llm.wrap_prompt`` needs to build the per-passage prompts.

        Returns:
            Tuple ``(input_ids, attention_mask, ab_record)``. ``ab_record`` is
            a boolean mask over all prompts (the zero-shot prompt is last);
            ``input_ids`` / ``attention_mask`` contain only the rows where the
            mask is true.
        """
        prompt_list = self.llm.wrap_prompt(data_item, as_multi_choice='choices' in data_item, seperate=True)
        data_item_zero_shot = {"question": data_item["question"], "topk_content": [], "long_gen": True}
        prompt_zero_shot = self.llm.wrap_prompt(data_item_zero_shot, as_multi_choice='choices' in data_item, seperate=False)
        prompt_list.append(prompt_zero_shot)

        # Draft prompts: every prompt followed by the abstention phrase.
        prompt_list_draft = [prompt + " I don't know" for prompt in prompt_list]

        # Batched forward pass over the draft prompts.
        input_dict_draft = self.llm.tokenizer(prompt_list_draft, return_tensors="pt", padding=True).to("cuda")
        input_ids_draft = input_dict_draft.input_ids
        attention_mask_draft = input_dict_draft.attention_mask

        with torch.no_grad():
            output_token_draft = self.llm.model(input_ids_draft, attention_mask=attention_mask_draft)
            logits_draft = output_token_draft.logits

        probs = torch.softmax(logits_draft, dim=-1)
        total_probability = torch.ones(input_ids_draft.shape[0]).to("cuda")

        input_dict = self.llm.tokenizer(prompt_list, return_tensors="pt", padding=True)
        # Padded length of the plain prompts; positions from here on in the
        # draft batch are treated as the appended abstention phrase.
        start_index = input_dict.input_ids.size(1)

        # Multiply the probabilities of the appended tokens. The target token
        # is read from row 0 for every row.
        # NOTE(review): this presumably relies on left padding so the suffix
        # occupies the same trailing positions in each row -- confirm.
        for i in range(start_index, input_ids_draft.size(1) - 1):  # the last position has no next token
            next_token_id = input_ids_draft[0, i + 1]
            next_token_prob = probs[:, i, next_token_id]
            total_probability *= next_token_prob

        input_ids = input_dict.input_ids.to("cuda")
        attention_mask = input_dict.attention_mask.to("cuda")

        # Keep only prompts whose "I don't know" probability is below the
        # abstention threshold. The zero-shot prompt (last row) gets
        # probability 0.0 so it is never the one filtered out by its score.
        total_probability[-1] = 0.0
        ab_record = total_probability < self.abstention_prob
        input_ids = input_ids[ab_record]
        attention_mask = attention_mask[ab_record]
        return input_ids, attention_mask, ab_record

    def query(self, data_item):
        """Answer ``data_item`` by secure decoding over the retained prompts.

        Returns:
            The decoded text. When only the zero-shot prompt survives
            filtering, the tuple ``("I don't know.", False)`` is returned.
        """
        input_ids, attention_mask, ab_record = self.preprocess_input(data_item)

        if input_ids.shape[0] == 1:  # only the no-retrieval prediction
            # NOTE(review): this returns a tuple while the normal path returns
            # a plain string; kept as-is because callers may depend on it.
            return "I don't know.", False

        stop_list = ["\n#", "\n##", "\n###", "\n####", "\n#####"] + ["\n\n"]
        stop_token_ids = [
            LongTensor(
                self.llm.tokenizer(x, return_tensors='pt', add_special_tokens=False)['input_ids']
            ).to("cuda")
            for x in stop_list
        ]
        stopping_criteria = StoppingCriteriaList([
            MaxLengthCriteria(max_length=len(input_ids[0]) + self.llm.max_output_tokens),
            StopOnTokens(stop_token_ids=stop_token_ids)
        ])

        generated_outputs = self.llm.model.secure_decoding(input_ids,
                                                           attention_mask=attention_mask,
                                                           stopping_criteria=stopping_criteria,
                                                           use_cache=False,
                                                           pad_token_id=self.llm.tokenizer.pad_token_id,
                                                           eos_token_id=self.llm.tokenizer.eos_token_id,
                                                           return_dict_in_generate=True,
                                                           temperature=self.temperature,
                                                           tokenizer=self.llm.tokenizer,
                                                           eta=self.eta,
                                                           gamma=self.gamma)

        generated_output_text = self.llm.tokenizer.decode(generated_outputs, skip_special_tokens=True)
        return generated_output_text
    
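# Assigns geometric weights (gamma^i) to the top-k passages, repeatedly samples weighted subsets of the top-k, and has the LLM generate a candidate answer for each subset;
# embeds the candidate answers (OpenAI or sentence-transformers) and returns the answer closest to the mean embedding (centroid) as the final answer.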
class RandomSamplingReQueryAgg(RRAG):
    """Sample passage subsets, answer each, and return the most central answer.

    Passages are weighted geometrically by rank (``gamma ** i``). Several
    subsets are drawn without replacement, the LLM answers each subset, and
    the answer whose embedding has the highest cosine similarity to the mean
    embedding of all answers is returned.
    """

    def __init__(
        self,
        llm,
        sample_size=5,
        num_samples=3,
        gamma=1,
        *,
        use_openai=True,
        openai_model="text-embedding-ada-002",
        hf_model_name="/scratch/gpfs/bi0600/all-mpnet-base-v2",
    ):
        """Store sampling parameters and set up the embedding backend.

        Args:
            llm: LLM wrapper passed to the base class.
            sample_size: number of passages per sampled subset (capped at the
                number of available passages).
            num_samples: number of subsets / candidate answers.
            gamma: base of the geometric rank weights.
            use_openai: embed with the OpenAI API when true, otherwise with a
                local sentence-transformers model.
            openai_model: OpenAI embedding model name.
            hf_model_name: sentence-transformers model name or path.
        """
        super().__init__(llm)
        self.sample_size = sample_size
        self.num_samples = num_samples
        self.gamma = gamma

        self.use_openai = use_openai
        self.openai_model = openai_model
        self.hf_model_name = hf_model_name

        if not self.use_openai:
            self.hf_model = SentenceTransformer(self.hf_model_name)
        else:
            self.client = OpenAI()

    def get_openai_embeddings(self, text_list):
        """Return one OpenAI embedding (list of floats) per text."""
        response = self.client.embeddings.create(
            model=self.openai_model,
            input=text_list
        )
        return [item.embedding for item in response.data]

    def get_hf_embeddings(self, text_list):
        """Return the sentence-transformers encodings of ``text_list``."""
        return self.hf_model.encode(text_list)

    def query(self, data_item):
        """Return the sampled response closest to the centroid embedding.

        Args:
            data_item: dict with ``question`` and ``topk_content`` (the ranked
                passages).
        """
        question = data_item["question"]
        all_chunks = data_item["topk_content"]
        n = len(all_chunks)

        # 1) Geometric weights gamma^i. Built as float explicitly: with an
        #    integer gamma (the default) the array would be integer-typed and
        #    the in-place normalisation below would raise a casting error.
        weights = np.array([self.gamma ** i for i in range(n)], dtype=float)
        weights /= weights.sum()  # normalize

        # 2) First stage: sample several subsets and query the LLM on each.
        #    Indices are sampled (rather than the passages themselves) so the
        #    passages keep their original Python type.
        subset_size = min(self.sample_size, n)
        sampled_responses = []
        for _ in range(self.num_samples):
            indices = np.random.choice(n, size=subset_size, replace=False, p=weights)
            sampled_chunks = [all_chunks[j] for j in indices]
            prompt = self.build_prompt(question, sampled_chunks)
            sampled_responses.append(self.llm.query(prompt))

        logger.debug(f"First-stage sampled responses:\n{sampled_responses}")

        # 3) Second stage: pick the response closest to the mean embedding.
        if self.use_openai:
            response_embeddings = self.get_openai_embeddings(sampled_responses)
        else:
            response_embeddings = self.get_hf_embeddings(sampled_responses)

        response_embeddings = np.array(response_embeddings)
        avg_embedding = np.mean(response_embeddings, axis=0)
        avg_norm = norm(avg_embedding)

        # Strict ">" keeps the earliest response on ties; starting at index 0
        # means the first response is used if every similarity is NaN.
        best_idx = 0
        best_sim = -float("inf")
        for i, emb in enumerate(response_embeddings):
            cos_sim = dot(emb, avg_embedding) / (norm(emb) * avg_norm)
            if cos_sim > best_sim:
                best_idx = i
                best_sim = cos_sim

        final_response = sampled_responses[best_idx]
        logger.debug(f"Second-stage final response:\n{final_response}")
        return final_response

    def build_prompt(self, question, chunks):
        """Build the context-grounded QA prompt for one sampled subset."""
        context_text = "\n\n".join(chunks)
        return f"Answer the following question based on the context below. It is very important that the answer should be based solely on evidence found in the context information. The answer should be as short as possible and can only use words found in the context information. \n\nContext:\n{context_text}\n\nQuestion: {question}\nAnswer:"


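# Combines sampling with keyword aggregation:
# repeatedly samples weighted subsets, extracts keywords from the non-abstaining sampled answers, weights and filters them, and finally issues a query that uses the keywords as hints.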
class SamplingWithKeyWordAggregation(RRAG):
    """Weighted passage sampling followed by keyword aggregation.

    Several passage subsets are sampled with geometric rank weights, the LLM
    answers each subset with keywords, keywords/phrases are extracted from the
    non-abstaining answers and counted with weights, rare ones are filtered
    out, and the survivors are passed as hints to a final LLM query.
    """

    def __init__(
        self,
        llm,
        sample_size=5,
        num_samples=3,
        gamma=1,
        relative_threshold=0.3,
        absolute_threshold=3,
        abstention_threshold=1,
    ):
        """Store sampling/filtering parameters and load the spaCy pipeline.

        Args:
            llm: LLM wrapper passed to the base class.
            sample_size: number of passages per sampled subset (capped at the
                number of available passages).
            num_samples: number of subsets / candidate answers.
            gamma: base of the geometric rank weights.
            relative_threshold: keyword count threshold as a fraction of the
                number of kept responses.
            absolute_threshold: absolute keyword count threshold; the smaller
                of the two thresholds is applied.
            abstention_threshold: minimum number of non-abstaining responses
                required; below it the defense answers "I don't know.".
        """
        super().__init__(llm)
        self.sample_size = sample_size
        self.num_samples = num_samples
        self.gamma = gamma

        self.keyword_extractor = spacy.load("en_core_web_sm")
        # Part-of-speech tags that terminate a phrase and are never keywords.
        self.ignore_set = {'VERB','INTJ','ADP','AUX','CCONJ','DET','PART','PRON','SCONJ','PUNCT','SPACE'}

        self.abstention_threshold = abstention_threshold
        self.absolute = absolute_threshold
        self.relative = relative_threshold
        logger.debug(f'Sampling+keyword. abs: {absolute_threshold}, relative: {relative_threshold}')

    def query(self, data_item):
        """Return the hint-guided final answer, or "I don't know." on abstention.

        Args:
            data_item: dict with ``question`` and ``topk_content`` (the ranked
                passages), plus whatever ``llm.wrap_prompt`` needs.
        """
        question = data_item["question"]
        all_chunks = data_item["topk_content"]
        n = len(all_chunks)

        # 1) Geometric weights gamma^i. Built as float explicitly: with an
        #    integer gamma (the default) the array would be integer-typed and
        #    the in-place normalisation below would raise a casting error.
        weights = np.array([self.gamma ** i for i in range(n)], dtype=float)
        weights /= weights.sum()  # normalize

        # 2) First stage: sample subsets, query the LLM, keep the
        #    non-abstaining responses together with their subset weight.
        subset_size = min(self.sample_size, n)
        sampled_responses = []
        total_weight = 0
        total_weight_orig = 0  # number of kept responses

        for _ in range(self.num_samples):
            indices = np.random.choice(n, size=subset_size, replace=False, p=weights)
            sampled_chunks = [all_chunks[j] for j in indices]
            chunk_weight = weights[indices].sum()
            prompt = self.build_prompt(question, sampled_chunks)
            response = self.llm.query(prompt)

            if "I don't" not in response:
                sampled_responses.append((response, chunk_weight))
                total_weight += chunk_weight
                total_weight_orig += 1

        logger.debug(f"Sampled responses:\n{sampled_responses}")

        if len(sampled_responses) < self.abstention_threshold:
            logger.warning("Abstain from making response...")
            return "I don't know."

        # 3) Keyword aggregation: each response votes once per unique phrase,
        #    with its subset weight rescaled so the votes of all kept
        #    responses sum to the number of kept responses.
        token_counter = defaultdict(float)
        all_extracted_phrase = []
        for response, weight in sampled_responses:
            phrase_set = self._extract_phrases(response)
            all_extracted_phrase.append(phrase_set)
            for phrase in phrase_set:
                token_counter[phrase] += weight * total_weight_orig / total_weight

        logger.debug(f"Extracted phrases:\n{all_extracted_phrase}")

        # Filtering. `token in punctuation` is a substring test on the
        # punctuation string, so it also discards the empty phrase.
        count_threshold = min(self.absolute, self.relative * len(sampled_responses))
        kept_tokens = {
            token: count
            for token, count in token_counter.items()
            if not (count < count_threshold or token in punctuation or token in stopword_set)
        }

        # Generate the keyword-hinted final query (longest hints first).
        sorted_tokens = sorted(kept_tokens.items(), key=lambda x: (len(x[0]), x[0]), reverse=True)
        hints = ', '.join([token for token, _ in sorted_tokens])
        logger.debug("Sorted tokens for hints:")
        logger.debug(sorted_tokens)
        hint_prompt = self.llm.wrap_prompt(data_item, as_multi_choice='choices' in data_item, hints=hints)
        logger.debug(f'Hint prompt:\n{hint_prompt}')
        final_response = self.llm.query(hint_prompt)

        logger.debug(f"Final response:\n{final_response}")
        return final_response

    def _extract_phrases(self, response):
        """Return the set of unique keywords/phrases found in ``response``.

        The set contains the stripped response itself, every maximal run of
        consecutive tokens whose part of speech is not in ``ignore_set``
        (joined by lemma), and the lemma of each token in such a run.
        """
        doc = self.keyword_extractor(response)
        phrase_list = [response.strip()]
        run = []  # current run of consecutive keyword tokens

        for token in doc:
            if token.pos_ not in self.ignore_set:
                run.append(token)
                continue
            if run:
                phrase_list.append(''.join([x.lemma_ + x.whitespace_ for x in run]).strip())
                phrase_list += [x.lemma_ for x in run]
                run = []

        # Flush the trailing run (adds an empty phrase when the run is empty;
        # the empty phrase is removed later by the punctuation filter).
        phrase_list.append(''.join([x.lemma_ + x.whitespace_ for x in run]).strip())
        phrase_list += [x.lemma_ for x in run]
        return set(phrase_list)

    def build_prompt(self, question, chunks):
        """Build the keyword-only answering prompt for one sampled subset."""
        context_text = "\n\n".join(chunks)
        # The indentation and line breaks inside the literal are part of the
        # prompt sent to the model; do not reformat.
        return f"""
        Given the context information below and not prior knowledge, answer the query with only keywords.
        If there is no relevant information, just say "I don't know".\n\n
        Context:\n
        {context_text}\n\n
        Query: {question}\n
        Answer:
        """