import math
import os
import random
import shutil
import sqlite3
import traceback
from contextlib import closing
from datetime import datetime

import h5py
import numpy as np
from sentence_transformers import SentenceTransformer

# Embedding models to run, keyed by a short name used in output file names.
# "model_name" is the Hugging Face id passed to SentenceTransformer; "dim" is
# the embedding width used to size the HDF5 'qemb'/'demb' datasets.
models_info = {
    "gte-base": {"model_name": "Alibaba-NLP/gte-multilingual-base", "dim": 768},
    "bge-m3": {"model_name": "BAAI/bge-m3", "dim": 1024},
    "MiniLM-L6-v2": {"model_name": "sentence-transformers/all-MiniLM-L6-v2", "dim": 384},
    "jina-v3": {"model_name": "jinaai/jina-embeddings-v3", "dim": 1024},
    "KaLM-mini-v1.5": {
        "model_name": "HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5",
        "dim": 896,
    },
}

def get_total_rows(db_path):
    """Return the number of rows in the ``document_pairs`` table.

    Args:
        db_path: Path to the SQLite database file.

    Returns:
        The ``COUNT(*)`` of ``document_pairs`` as an int.

    Raises:
        sqlite3.Error: If the database cannot be opened or the table is missing.
    """
    # closing() releases the connection even when the query raises; using a
    # sqlite3 connection directly as a context manager only manages the
    # transaction and would leave the connection open.
    with closing(sqlite3.connect(db_path)) as conn:
        (total_rows,) = conn.execute("SELECT COUNT(*) FROM document_pairs;").fetchone()
    return total_rows

def load_data_from_db(db_path, batch_size, offset):
    """Load one page of ``(pid, query, doc, score)`` rows from ``document_pairs``.

    Args:
        db_path: Path to the SQLite database file.
        batch_size: Maximum number of rows to return (SQL ``LIMIT``).
        offset: Number of rows to skip first (SQL ``OFFSET``).

    Returns:
        A list of ``(pid, query, doc, score)`` tuples; empty once ``offset``
        is past the end of the table.
    """
    # NOTE(review): there is no ORDER BY, so paging relies on SQLite returning
    # rows in the same order on every call (a plain table scan does). Callers
    # map page positions to train/test membership, so confirm this holds or
    # add an ORDER BY on a stable key.
    query = """
        SELECT pid, query, doc, score
        FROM document_pairs
        LIMIT ? OFFSET ?;
    """
    # Bound parameters instead of f-string interpolation; int() also accepts
    # numpy integers, which sqlite3 cannot bind directly.
    with closing(sqlite3.connect(db_path)) as conn:
        rows = conn.execute(query, (int(batch_size), int(offset))).fetchall()
    return rows

def generate_embeddings(model, texts, prompt_name="passage", batch_size=1):
    """Encode ``texts`` with ``model`` and return the embeddings.

    Args:
        model: Object exposing a SentenceTransformer-style ``encode`` method.
        texts: Sequence of strings to encode.
        prompt_name: Name of a prompt configured on the model, or None to
            encode the raw text.
        batch_size: Batch size forwarded to ``model.encode``. Defaults to 1,
            the value that was previously hard-coded; raise it to trade
            memory for throughput.

    Returns:
        Whatever ``model.encode`` returns with ``convert_to_numpy=True``
        (for SentenceTransformer, a NumPy array with one row per text).
    """
    return model.encode(
        texts,
        show_progress_bar=False,
        batch_size=batch_size,
        convert_to_numpy=True,
        prompt_name=prompt_name,
    )

def is_hdf5_valid(file_path):
    """Report whether ``file_path`` is an openable HDF5 file with all expected datasets.

    The expected datasets are 'pid', 'qemb', 'demb' and 'score'. Returns False
    when the path does not exist, when h5py fails to open/inspect the file
    (OSError, RuntimeError or KeyError), or when any dataset is missing.
    """
    if not os.path.exists(file_path):
        return False

    required = ('pid', 'qemb', 'demb', 'score')
    try:
        with h5py.File(file_path, 'r') as handle:
            has_everything = all(name in handle for name in required)
    except (OSError, RuntimeError, KeyError):  # IOError is an alias of OSError
        return False
    return has_everything

def backup_corrupted_file(file_path):
    """Copy ``file_path`` to ``<file_path>.corrupted_<YYYYmmdd_HHMMSS>`` beside it.

    The copy is made with ``shutil.copy2`` (contents plus metadata) and the
    backup path is printed.

    Returns:
        The backup path, or None if ``file_path`` does not exist (nothing is copied).
    """
    if not os.path.exists(file_path):
        return None

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = f"{file_path}.corrupted_{stamp}"
    shutil.copy2(file_path, target)
    print(f"已备份损坏文件: {target}")
    return target

def extract_embeddings_from_hdf5(file_path, chunk_size=1024):
    """Recover as many embedding records as possible from an HDF5 output file.

    Reads the 'pid', 'qemb', 'demb' and 'score' datasets and maps each pid
    to ``{'qemb': ..., 'demb': ..., 'score': ...}``. The function is meant
    for salvaging damaged or partially written files. Errors are printed with
    a traceback and skipped, never raised, and whatever was read before an
    error is kept.

    Args:
        file_path: Path to the HDF5 file. A falsy value (e.g. None) or a
            non-existent path yields an empty dict.
        chunk_size: Rows fetched per HDF5 slice. Rows are read in slices
            because every single-row h5py read is its own I/O call. If a slice
            fails, it is retried row by row, so a damaged region loses only
            its unreadable rows.

    Returns:
        Dict mapping pid (decoded to str when stored as bytes) to its record.
        Only the first ``min(len(dataset))`` rows across the four datasets are
        considered. A later duplicate pid overwrites an earlier one. The dict
        is empty if any of the four datasets is missing.
    """
    embeddings_dict = {}

    if not file_path or not os.path.exists(file_path):
        return embeddings_dict

    names = ('pid', 'qemb', 'demb', 'score')
    try:
        with h5py.File(file_path, 'r') as h5:
            if all(name in h5 for name in names):
                datasets = [h5[name] for name in names]
                # Only rows present in every dataset form complete records.
                n = min(len(ds) for ds in datasets)
                step = max(1, int(chunk_size))

                for start in range(0, n, step):
                    stop = min(start + step, n)
                    try:
                        rows = list(zip(*[ds[start:stop] for ds in datasets]))
                    except Exception as e:
                        print(f"从文件 {file_path} 读取数据时出错: {e}")
                        traceback.print_exc()
                        # Retry this slice one row at a time.
                        rows = []
                        for i in range(start, stop):
                            try:
                                rows.append(tuple(ds[i] for ds in datasets))
                            except Exception as row_err:
                                print(f"从文件 {file_path} 读取数据时出错: {row_err}")
                                traceback.print_exc()

                    for pid, qemb, demb, score in rows:
                        # h5py may return variable-length strings as bytes.
                        if isinstance(pid, bytes):
                            try:
                                pid = pid.decode('utf-8')
                            except UnicodeDecodeError as e:
                                print(f"从文件 {file_path} 读取数据时出错: {e}")
                                traceback.print_exc()
                                continue
                        embeddings_dict[pid] = {
                            'qemb': qemb,
                            'demb': demb,
                            'score': score,
                        }
    except Exception as e:
        # Opening or inspecting the file failed; keep whatever was recovered.
        print(f"从文件 {file_path} 读取数据时出错: {e}")
        traceback.print_exc()

    print(f"从文件 {file_path} 中读取了 {len(embeddings_dict)} 条记录")
    return embeddings_dict

def _prompt_names(model_key):
    """Return the (query_prompt, doc_prompt) names passed to ``encode`` for ``model_key``.

    Only "nomic-v2" and "gte-Qwen2-1.5B" use prompts; every other key gets (None, None).
    """
    # NOTE(review): neither key appears in models_info, so currently every
    # model is encoded without a prompt — confirm this is intended.
    if model_key == "nomic-v2":
        return "query", "passage"
    if model_key == "gte-Qwen2-1.5B":
        return "query", None
    return None, None


def _clean_text(text):
    """Strip ``text``; substitute the placeholder "none" when it is empty, blank or None."""
    stripped = (text or "").strip()
    return stripped if stripped else "none"


def _create_split_datasets(h5, dim, pid_dtype):
    """Create empty, resizable, chunked 'pid'/'qemb'/'demb'/'score' datasets in ``h5``.

    Returns the four datasets as a tuple in that order.
    """
    return (
        h5.create_dataset('pid', shape=(0,), maxshape=(None,), dtype=pid_dtype, chunks=True),
        h5.create_dataset('qemb', shape=(0, dim), maxshape=(None, dim), dtype='float32', chunks=True),
        h5.create_dataset('demb', shape=(0, dim), maxshape=(None, dim), dtype='float32', chunks=True),
        h5.create_dataset('score', shape=(0,), maxshape=(None,), dtype='float32', chunks=True),
    )


def _embed_rows(model, rows, embeddings_dict, dim, prompt_name_q, prompt_name_d):
    """Return float32 ``(len(rows), dim)`` query and doc embedding arrays for ``rows``.

    ``rows`` are ``(pid, query, doc, score)`` tuples. Embeddings already in
    ``embeddings_dict`` (keyed by pid) are reused. The remaining rows are
    encoded with ``model``, all their queries first and then all their docs.
    """
    # NOTE(review): cached pids come from HDF5 decoded as str; if the DB pid
    # column is INTEGER these lookups never hit — confirm the column type.
    q_embs = np.zeros((len(rows), dim), dtype='float32')
    d_embs = np.zeros((len(rows), dim), dtype='float32')
    missing = []
    for i, (pid, _query, _doc, _score) in enumerate(rows):
        cached = embeddings_dict.get(pid)
        if cached is None:
            missing.append(i)
        else:
            q_embs[i] = cached['qemb']
            d_embs[i] = cached['demb']

    if missing:
        queries = [_clean_text(rows[i][1]) for i in missing]
        docs = [_clean_text(rows[i][2]) for i in missing]
        q_embs[missing] = generate_embeddings(model, queries, prompt_name=prompt_name_q).astype('float32')
        d_embs[missing] = generate_embeddings(model, docs, prompt_name=prompt_name_d).astype('float32')
    return q_embs, d_embs


def _append_split(datasets, start, rows, q_embs, d_embs):
    """Grow ``datasets`` by ``len(rows)``, write rows at ``[start, start + len(rows))``, return the new size."""
    pid_ds, qemb_ds, demb_ds, score_ds = datasets
    stop = start + len(rows)
    dim = qemb_ds.shape[1]
    pid_ds.resize((stop,))
    qemb_ds.resize((stop, dim))
    demb_ds.resize((stop, dim))
    score_ds.resize((stop,))

    pid_ds[start:stop] = [row[0] for row in rows]
    qemb_ds[start:stop, :] = q_embs
    demb_ds[start:stop, :] = d_embs
    score_ds[start:stop] = np.array([row[3] for row in rows], dtype='float32')
    return stop


def main():
    """Embed every ``document_pairs`` row of new.db with each model in models_info.

    Rows are split 80/20 into train/test by a shuffle seeded with 42. For each
    model they are written to ``<model_key>_train.hdf5`` and
    ``<model_key>_test.hdf5`` with datasets 'pid', 'qemb', 'demb' and 'score'.
    Embeddings found in existing output files are reused by pid, so an
    interrupted run can resume. Existing files are copied aside before being
    overwritten. A model is skipped when both files are valid and together
    hold at least as many pids as the table has rows.
    """
    db_path = "new.db"
    total_rows = get_total_rows(db_path)
    print(f"满足条件的总行数：{total_rows}")

    # Random train/test split.
    train_count = int(total_rows * 0.8)
    test_count = total_rows - train_count
    print(f"训练集样本数：{train_count}；测试集样本数：{test_count}")

    # RandomState(42).shuffle gives the same permutation as the former
    # np.random.seed(42); np.random.shuffle(...), so the split matches files
    # written by earlier runs, without mutating NumPy's global RNG.
    all_indices = np.arange(total_rows)
    np.random.RandomState(42).shuffle(all_indices)
    # is_train[i] is True iff global row i belongs to the training split.
    is_train = np.zeros(total_rows, dtype=bool)
    is_train[all_indices[:train_count]] = True

    batch_size = 64  # rows per DB batch; tune to available memory
    num_batches = math.ceil(total_rows / batch_size)

    # HDF5 dtype for pid strings (UTF-8).
    pid_dtype = h5py.string_dtype(encoding='utf-8')

    for model_key, config in models_info.items():
        print(f"\n正在处理模型：{model_key}")
        dim = config['dim']

        train_hdf5_file = f"{model_key}_train.hdf5"
        test_hdf5_file = f"{model_key}_test.hdf5"

        train_valid = is_hdf5_valid(train_hdf5_file)
        test_valid = is_hdf5_valid(test_hdf5_file)

        # Salvage embeddings from previous outputs (read-only) for reuse.
        embeddings_dict = {}
        if os.path.exists(train_hdf5_file):
            print(f"训练集文件 {train_hdf5_file} 恢复数据...")
            embeddings_dict.update(extract_embeddings_from_hdf5(train_hdf5_file))
        if os.path.exists(test_hdf5_file):
            print(f"测试集文件 {test_hdf5_file} 恢复数据...")
            embeddings_dict.update(extract_embeddings_from_hdf5(test_hdf5_file))

        if train_valid and test_valid and len(embeddings_dict) >= total_rows:
            print(f"模型 {model_key} 已处理完毕，跳过...")
            continue

        # Both files are about to be truncated (mode 'w'), so keep a copy first.
        # Backing up only here, after the completeness check, avoids piling up
        # copies of already-finished files on every re-run.
        backup_corrupted_file(train_hdf5_file)
        backup_corrupted_file(test_hdf5_file)

        print(f"加载模型: {config['model_name']}")
        model = SentenceTransformer(config["model_name"], trust_remote_code=True)
        prompt_name_q, prompt_name_d = _prompt_names(model_key)

        with h5py.File(train_hdf5_file, 'w') as h5_train, h5py.File(test_hdf5_file, 'w') as h5_test:
            train_sets = _create_split_datasets(h5_train, dim, pid_dtype)
            test_sets = _create_split_datasets(h5_test, dim, pid_dtype)
            train_current_size = 0
            test_current_size = 0

            for batch_num in range(num_batches):
                offset = batch_num * batch_size
                current_batch_rows = min(batch_size, total_rows - offset)
                if current_batch_rows <= 0:
                    break

                print(f"正在处理第 {batch_num + 1}/{num_batches} 批次，行 {offset} - {offset + current_batch_rows}")

                # Request exactly the rows the split mask covers, and size the
                # mask from what actually came back. Otherwise rows and split
                # flags could misalign if the table changed after COUNT(*).
                data = load_data_from_db(db_path, current_batch_rows, offset)
                if not data:
                    break  # no more data
                batch_is_train = is_train[offset:offset + len(data)]
                train_rows = [row for row, flag in zip(data, batch_is_train) if flag]
                test_rows = [row for row, flag in zip(data, batch_is_train) if not flag]

                if train_rows:
                    q_embs, d_embs = _embed_rows(model, train_rows, embeddings_dict, dim,
                                                 prompt_name_q, prompt_name_d)
                    train_current_size = _append_split(train_sets, train_current_size,
                                                       train_rows, q_embs, d_embs)
                if test_rows:
                    q_embs, d_embs = _embed_rows(model, test_rows, embeddings_dict, dim,
                                                 prompt_name_q, prompt_name_d)
                    test_current_size = _append_split(test_sets, test_current_size,
                                                      test_rows, q_embs, d_embs)

                # Flush to disk every 5 batches.
                if batch_num % 5 == 0:
                    h5_train.flush()
                    h5_test.flush()

            print(f"模型 {model_key} 处理完成")
            print(f"训练集记录数: {train_current_size}，测试集记录数: {test_current_size}")
            print(f"模型 {model_key} 的训练集保存在 {train_hdf5_file}")
            print(f"模型 {model_key} 的测试集保存在 {test_hdf5_file}")

        # Drop the reference so this model can be freed before the next is loaded.
        del model

# Run the embedding pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
