import sqlite3
import pandas as pd
import numpy as np
import torch

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from vllm import LLM, SamplingParams

# Helper functions for the lightblue (lb-reranker) model.
def make_reranker_input(context, query):
    """Build the user-turn text for the lightblue reranker.

    The query section comes first, then a blank line, then the context section.
    """
    header = f"<<<Query>>>\n{query}"
    body = f"<<<Context>>>\n{context}"
    return f"{header}\n\n{body}"

def make_reranker_inference_conversation(context, query):
    """Build the chat messages used to score one (query, context) pair.

    Returns a two-element list of ``{"role": ..., "content": ...}`` dicts:
    the fixed 1-7 rating instruction as the system message, followed by the
    user turn produced by ``make_reranker_input``.
    """
    instructions = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related."
    user_turn = make_reranker_input(context, query)
    return [
        dict(role="system", content=instructions),
        dict(role="user", content=user_turn),
    ]

def get_prob(logprob_dict, tok_id):
    """Return the probability of ``tok_id`` from a ``{token_id: entry}`` mapping.

    Each entry must expose a ``.logprob`` attribute holding a natural-log
    probability, which is exponentiated. A token id absent from the mapping
    contributes probability 0.
    """
    if tok_id not in logprob_dict:
        return 0
    return np.exp(logprob_dict[tok_id].logprob)

# Run-time parameters and device selection.
BATCH_SIZE = 32    # mini-batch size used during model inference
CHUNKSIZE = 10000  # number of rows read from the database per chunk
device = "cuda" if torch.cuda.is_available() else "cpu"

# Open new.db. Table document_pairs is expected to hold the columns
# pid, query, doc, processed_score, ali, bge, jina, lb, rerank_score, score;
# the score columns are added below.
db_path = 'new.db'
conn = sqlite3.connect(db_path)
cursor = conn.cursor()

# One ALTER per score column. Any sqlite3.OperationalError (e.g. when the
# column already exists on a re-run) is only printed as a warning.
alter_stmts = [
    f"ALTER TABLE document_pairs ADD COLUMN {column} FLOAT;"
    for column in ("ali", "bge", "jina", "lb", "rerank_score", "score")
]
for ddl in alter_stmts:
    try:
        cursor.execute(ddl)
    except sqlite3.OperationalError as err:
        print(f"Warning: {err}")
conn.commit()

# Phase 1: score the pairs model by model, filling only rows whose column is still NULL.

# === Alibaba model (ali) ===
def update_ali():
    """Fill ``document_pairs.ali`` where it is NULL using gte-multilingual-reranker-base.

    Rows are streamed in chunks of ``CHUNKSIZE`` and scored in mini-batches of
    ``BATCH_SIZE``. Each chunk is written back and committed before the next
    one is processed. A chunk that raises is reported and skipped; its rows are
    normally left NULL, so a later run selects them again. The model is
    released (and the CUDA cache emptied) at the end.
    """
    print("开始更新 ali 模型得分……")
    model_name = "Alibaba-NLP/gte-multilingual-reranker-base"
    print("加载 gte-multilingual-reranker-base 模型 …")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, trust_remote_code=True, torch_dtype=torch.float16
    )
    model.to(device)
    model.eval()

    def get_scores(queries, docs, batch_size=BATCH_SIZE):
        """Return a float32 array of flattened logits for the (query, doc) pairs."""
        scores = []
        with torch.no_grad():
            for i in range(0, len(queries), batch_size):
                pair_batch = [[q, d] for q, d in zip(queries[i:i+batch_size], docs[i:i+batch_size])]
                inputs = tokenizer(pair_batch, padding=True, truncation=True, return_tensors='pt', max_length=8192)
                inputs = {k: v.to(device) for k, v in inputs.items()}
                outputs = model(**inputs, return_dict=True)
                scores.append(outputs.logits.view(-1).float().cpu().numpy())
        # np.concatenate([]) raises ValueError, so return an empty array for empty input.
        return np.concatenate(scores) if scores else np.empty(0, dtype=np.float32)

    sql_select = "SELECT pid, query, doc FROM document_pairs WHERE ali IS NULL"
    total_updated = 0
    # NOTE(review): the SELECT cursor stays open while rows are UPDATEd and committed
    # on the same connection; SQLite does not define whether a running query sees
    # such changes. Updated rows no longer match "ali IS NULL", so this looks benign.
    for chunk in pd.read_sql_query(sql_select, conn, chunksize=CHUNKSIZE):
        # pandas can yield a single empty frame when nothing matches (e.g. a re-run
        # after every row is scored); skip it instead of reporting a bogus error.
        if chunk.empty:
            continue
        try:
            scores = get_scores(chunk['query'].tolist(), chunk['doc'].tolist())
            if len(scores) != len(chunk):
                raise ValueError(f"expected {len(chunk)} scores, got {len(scores)}")
            # Pair scores with pids column-wise: DataFrame.values on a float+int frame
            # would upcast pid to float64 and lose precision for large ids.
            update_data = list(zip(scores.tolist(), chunk['pid'].tolist()))
            cursor.executemany(
                "UPDATE document_pairs SET ali = ? WHERE pid = ?",
                update_data
            )
            conn.commit()
            total_updated += len(chunk)
            print(f"ali: 已更新 {len(chunk)} 条记录 ({total_updated})")
        except Exception as e:
            print(f"ali 更新时发生错误: {e}")
    print(f"ali 模型得分更新完毕，共更新记录 {total_updated}")
    # Drop the model references and release cached GPU memory before the next model.
    del model, tokenizer
    if device == "cuda":
        torch.cuda.empty_cache()

update_ali()

# === BAAI model (bge) ===
def update_bge():
    """Fill ``document_pairs.bge`` where it is NULL using bge-reranker-v2-m3.

    Rows are streamed in chunks of ``CHUNKSIZE`` and scored in mini-batches of
    ``BATCH_SIZE``. Each chunk is written back and committed before the next
    one is processed. A chunk that raises is reported and skipped; its rows are
    normally left NULL, so a later run selects them again. The model is
    released (and the CUDA cache emptied) at the end.
    """
    print("开始更新 bge 模型得分……")
    model_name = "BAAI/bge-reranker-v2-m3"
    print("加载 bge-reranker-v2-m3 模型 …")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, torch_dtype=torch.float16
    )
    model.to(device)
    model.eval()

    def get_scores(queries, docs, batch_size=BATCH_SIZE):
        """Return a float32 array of flattened logits for the (query, doc) pairs."""
        scores = []
        with torch.no_grad():
            for i in range(0, len(queries), batch_size):
                pair_batch = [[q, d] for q, d in zip(queries[i:i+batch_size], docs[i:i+batch_size])]
                inputs = tokenizer(pair_batch, padding=True, truncation=True, return_tensors='pt', max_length=8192)
                inputs = {k: v.to(device) for k, v in inputs.items()}
                outputs = model(**inputs, return_dict=True)
                scores.append(outputs.logits.view(-1).float().cpu().numpy())
        # np.concatenate([]) raises ValueError, so return an empty array for empty input.
        return np.concatenate(scores) if scores else np.empty(0, dtype=np.float32)

    sql_select = "SELECT pid, query, doc FROM document_pairs WHERE bge IS NULL"
    total_updated = 0
    # NOTE(review): the SELECT cursor stays open while rows are UPDATEd and committed
    # on the same connection; SQLite does not define whether a running query sees
    # such changes. Updated rows no longer match "bge IS NULL", so this looks benign.
    for chunk in pd.read_sql_query(sql_select, conn, chunksize=CHUNKSIZE):
        # pandas can yield a single empty frame when nothing matches (e.g. a re-run
        # after every row is scored); skip it instead of reporting a bogus error.
        if chunk.empty:
            continue
        try:
            scores = get_scores(chunk['query'].tolist(), chunk['doc'].tolist())
            if len(scores) != len(chunk):
                raise ValueError(f"expected {len(chunk)} scores, got {len(scores)}")
            # Pair scores with pids column-wise: DataFrame.values on a float+int frame
            # would upcast pid to float64 and lose precision for large ids.
            update_data = list(zip(scores.tolist(), chunk['pid'].tolist()))
            cursor.executemany(
                "UPDATE document_pairs SET bge = ? WHERE pid = ?",
                update_data
            )
            conn.commit()
            total_updated += len(chunk)
            print(f"bge: 已更新 {len(chunk)} 条记录 ({total_updated})")
        except Exception as e:
            print(f"bge 更新时发生错误: {e}")
    print(f"bge 模型得分更新完毕，共更新记录 {total_updated}")
    # Drop the model references and release cached GPU memory before the next model.
    del model, tokenizer
    if device == "cuda":
        torch.cuda.empty_cache()

update_bge()


# === jina model (jina) ===
def update_jina():
    """Fill ``document_pairs.jina`` where it is NULL using jina-reranker-v2-base-multilingual.

    Scoring goes through the model's own ``compute_score`` (from its remote
    code), which takes raw text pairs, so no separate tokenizer is loaded here.
    Rows are streamed in chunks of ``CHUNKSIZE``. Each chunk is written back
    and committed before the next one is processed. A chunk that raises is
    reported and skipped; its rows are normally left NULL for a later run.
    """
    print("开始更新 jina 模型得分……")
    model_name = "jinaai/jina-reranker-v2-base-multilingual"
    print("加载 jina-reranker-v2-base-multilingual 模型 …")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, trust_remote_code=True
    )
    model.to(device)
    model.eval()

    def get_scores(queries, docs):
        """Return a 1-D float32 array with one score per (query, doc) pair."""
        if not queries:
            return np.empty(0, dtype=np.float32)
        # Build the sentence-pair list and hand it to compute_score.
        pairs = [[q, d] for q, d in zip(queries, docs)]
        scores = model.compute_score(pairs, max_length=1024)
        # Guard against a scalar return (e.g. possibly for a single pair) so the
        # result is always 1-D.
        return np.atleast_1d(np.array(scores, dtype=np.float32))

    sql_select = "SELECT pid, query, doc FROM document_pairs WHERE jina IS NULL"
    total_updated = 0
    # NOTE(review): the SELECT cursor stays open while rows are UPDATEd and committed
    # on the same connection; SQLite does not define whether a running query sees
    # such changes. Updated rows no longer match "jina IS NULL", so this looks benign.
    for chunk in pd.read_sql_query(sql_select, conn, chunksize=CHUNKSIZE):
        # pandas can yield a single empty frame when nothing matches (e.g. a re-run
        # after every row is scored); skip it instead of reporting a bogus error.
        if chunk.empty:
            continue
        try:
            scores = get_scores(chunk['query'].tolist(), chunk['doc'].tolist())
            if len(scores) != len(chunk):
                raise ValueError(f"expected {len(chunk)} scores, got {len(scores)}")
            # Pair scores with pids column-wise: DataFrame.values on a float+int frame
            # would upcast pid to float64 and lose precision for large ids.
            update_data = list(zip(scores.tolist(), chunk['pid'].tolist()))
            cursor.executemany(
                "UPDATE document_pairs SET jina = ? WHERE pid = ?",
                update_data
            )
            conn.commit()
            total_updated += len(chunk)
            print(f"jina: 已更新 {len(chunk)} 条记录 ({total_updated})")
        except Exception as e:
            print(f"jina 更新时发生错误: {e}")
    print(f"jina 模型得分更新完毕，共更新记录 {total_updated}")
    # Drop the model reference and release cached GPU memory before the next model.
    del model
    if device == "cuda":
        torch.cuda.empty_cache()

update_jina()


# === lightblue model (lb) ===
def update_lb():
    """Fill ``document_pairs.lb`` where it is NULL using lightblue/lb-reranker-v1.0 on vLLM.

    Each pair is sent as a chat (``make_reranker_inference_conversation``) and
    one token is generated. The stored score is the expected rating
    ``sum_k k * P(first token == str(k))`` for k = 1..7, with probabilities
    read from that token's logprobs via ``get_prob``. Chunks of ``CHUNKSIZE``
    rows are written back and committed one at a time; a chunk that raises is
    reported and skipped.
    """
    print("开始更新 lb 模型得分……")
    # The lightblue model runs on vLLM and is only initialised when this update runs.
    print("加载 lb-reranker-v1.0 模型 …")
    llm = LLM("lightblue/lb-reranker-v1.0", gpu_memory_utilization=0.5)
    # Greedy decoding of a single token, requesting 14 logprobs for it.
    sampling_params = SamplingParams(temperature=0.0, logprobs=14, max_tokens=1)
    llm_tok = llm.llm_engine.tokenizer.tokenizer
    # Token ids of the digits "1".."7". add_special_tokens=False ensures that a
    # tokenizer which prepends BOS cannot return the BOS id as element 0.
    idx_tokens = [llm_tok.encode(str(i), add_special_tokens=False)[0] for i in range(1, 8)]

    def get_scores(queries, docs):
        """Return the expected 1-7 rating for each (query, doc) pair."""
        if not queries:
            return np.empty(0)
        # One conversation (system prompt + user turn) per pair; query and doc are
        # cut to their first 16000 characters.
        chats = [make_reranker_inference_conversation(doc[:16000], query[:16000]) for query, doc in zip(queries, docs)]
        responses = llm.chat(chats, sampling_params)
        # probs[m, k]: probability of digit k+1 as the first generated token of pair m
        # (0 when that token id is absent from the returned logprob dict).
        probs = np.array([[get_prob(r.outputs[0].logprobs[0], token) for token in idx_tokens] for r in responses])
        M, N = probs.shape
        idxs = np.tile(np.arange(1, N+1), M).reshape(M, N)
        return (probs * idxs).sum(axis=1)

    sql_select = "SELECT pid, query, doc FROM document_pairs WHERE lb IS NULL"
    total_updated = 0
    # NOTE(review): the SELECT cursor stays open while rows are UPDATEd and committed
    # on the same connection; SQLite does not define whether a running query sees
    # such changes. Updated rows no longer match "lb IS NULL", so this looks benign.
    for chunk in pd.read_sql_query(sql_select, conn, chunksize=CHUNKSIZE):
        # pandas can yield a single empty frame when nothing matches (e.g. a re-run
        # after every row is scored); skip it instead of reporting a bogus error.
        if chunk.empty:
            continue
        try:
            scores = get_scores(chunk['query'].tolist(), chunk['doc'].tolist())
            if len(scores) != len(chunk):
                raise ValueError(f"expected {len(chunk)} scores, got {len(scores)}")
            # Pair scores with pids column-wise: DataFrame.values on a float+int frame
            # would upcast pid to float64 and lose precision for large ids.
            update_data = list(zip(scores.tolist(), chunk['pid'].tolist()))
            cursor.executemany(
                "UPDATE document_pairs SET lb = ? WHERE pid = ?",
                update_data
            )
            conn.commit()
            total_updated += len(chunk)
            print(f"lb: 已更新 {len(chunk)} 条记录 ({total_updated})")
        except Exception as e:
            print(f"lb 更新时发生错误: {e}")
    print(f"lb 模型得分更新完毕，共更新记录 {total_updated}")
    del llm, llm_tok, idx_tokens
    if device == "cuda":
        torch.cuda.empty_cache()

update_lb()

# Phase 2: global MIN/MAX of every model score over the whole table, used for
# min-max normalisation in phase 3. MIN/MAX are NULL (None) for a column that
# holds no non-NULL values.
query_stats = """
SELECT 
    MIN(ali) as min_ali, MAX(ali) as max_ali,
    MIN(bge) as min_bge, MAX(bge) as max_bge,
    MIN(jina) as min_jina, MAX(jina) as max_jina,
    MIN(lb) as min_lb, MAX(lb) as max_lb
FROM document_pairs;
"""
res = cursor.execute(query_stats).fetchone()
(min_ali, max_ali,
 min_bge, max_bge,
 min_jina, max_jina,
 min_lb, max_lb) = res
print("全局评估数据统计：")
for model_col, col_min, col_max in (
    ("ali", min_ali, max_ali),
    ("bge", min_bge, max_bge),
    ("jina", min_jina, max_jina),
    ("lb", min_lb, max_lb),
):
    print(f"{model_col}: min = {col_min}, max = {col_max}")

# Phase 3: process the table in chunks. Min-max normalise each model score,
# average them into rerank_score, then blend rerank_score with the bin-wise
# rescaled processed_score to obtain the final score.
def process_chunk(df, stats=None):
    """Compute ``rerank_score`` and ``score`` for one chunk of ``document_pairs`` rows.

    Each model column (ali, bge, jina, lb) is min-max normalised as
    ``(x - min) / (max - min)``; a zero range maps to 0.0, and missing stats
    (None, i.e. no scores in the table) map to NaN. ``rerank_score`` is the
    mean of the four normalised values.

    ``processed_score`` p is placed in one of the bins [0, .1), [.1, .25),
    [.25, .5), [.5, .75), [.75, .9), [.9, 1]. Within bin [lo, hi] the final
    score is ``lo + (hi - lo) * 0.5 * ((p - lo) / (hi - lo) + r)``, i.e. the
    average of p's relative position in the bin and r, mapped back into the
    bin. Values of p outside [0, 1], or NaN, are kept unchanged.

    Args:
        df: DataFrame with columns pid, ali, bge, jina, lb, processed_score.
            The *_norm, rerank_score and score columns are added to it in place.
        stats: optional mapping ``{column: (min, max)}`` for the four model
            columns. Defaults to the module-level min_*/max_* values from phase 2.

    Returns:
        ``df[['pid', 'rerank_score', 'score']]``.
    """
    if stats is None:
        stats = {
            'ali': (min_ali, max_ali),
            'bge': (min_bge, max_bge),
            'jina': (min_jina, max_jina),
            'lb': (min_lb, max_lb),
        }

    def normalize(values, lo, hi):
        # MIN/MAX are None when the column has no scores at all; yield NaN
        # (unknown, like a NULL score) instead of failing on None arithmetic.
        if lo is None or hi is None:
            return np.full(len(values), np.nan)
        # A constant column cannot be rescaled; map it to 0.0.
        if hi - lo == 0:
            return np.zeros(len(values))
        return (values - lo) / (hi - lo)

    for col in ('ali', 'bge', 'jina', 'lb'):
        lo, hi = stats[col]
        # astype(float): pandas reads an all-NULL column as object dtype holding
        # None, which would break the arithmetic; this turns None into NaN.
        df[f'{col}_norm'] = normalize(df[col].astype(float), lo, hi)

    df['rerank_score'] = (df['ali_norm'] + df['bge_norm'] + df['jina_norm'] + df['lb_norm']) / 4.0
    p = df['processed_score'].astype(float)
    r = df['rerank_score']

    breakpoints = (0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0)
    last_bin = len(breakpoints) - 2
    new_score = p.copy()
    for i, (lo, hi) in enumerate(zip(breakpoints[:-1], breakpoints[1:])):
        # Bins are half-open [lo, hi), except the last one, which also includes 1.0.
        below_upper = (p <= hi) if i == last_bin else (p < hi)
        in_bin = (p >= lo) & below_upper
        width = hi - lo
        blended = ((0.5 * (((p - lo) / width) + r)) * width) + lo
        new_score = np.where(in_bin, blended, new_score)

    df['score'] = new_score
    return df[['pid', 'rerank_score', 'score']]

print("第三阶段：开始计算归一化后的 rerank_score 与最终 score ……")
sql_select2 = "SELECT pid, ali, bge, jina, lb, processed_score FROM document_pairs"
# NOTE(review): rows are UPDATEd and committed on the same connection while this
# SELECT is still being consumed; SQLite does not define whether an in-progress
# query observes such changes. Confirm this is acceptable for the table's indexes.
for chunk in pd.read_sql_query(sql_select2, conn, chunksize=CHUNKSIZE):
    processed_chunk = process_chunk(chunk)
    # Build (rerank_score, score, pid) tuples column-wise so pid keeps its own
    # type; DataFrame.values on this float+int frame would upcast pid to float64.
    update_data = list(zip(
        processed_chunk['rerank_score'].tolist(),
        processed_chunk['score'].tolist(),
        processed_chunk['pid'].tolist(),
    ))
    cursor.executemany(
        "UPDATE document_pairs SET rerank_score = ?, score = ? WHERE pid = ?",
        update_data
    )
    conn.commit()
    print(f"第三阶段：已处理更新 {len(processed_chunk)} 条记录")

conn.close()
print("所有数据处理完毕！")
