#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import json
import glob
import argparse
import random
from rich.progress import track

from src.graph_utils import summary
from src.scripts.generate_neg import path_to_str_with_graph

def load_rerank_results(rerank_dirs):
    """Lazily yield the decoded contents of every ``*.json`` file in the rerank dirs.

    Files are opened one at a time, so the full result set never has to be held
    in memory. A directory that does not exist is reported and skipped. A file
    that fails to load is reported and skipped (best effort).

    Args:
        rerank_dirs: a single directory path (str) or an iterable of paths.

    Yields:
        The object decoded from each JSON file, in ``glob`` order per directory.
    """
    dir_list = [rerank_dirs] if isinstance(rerank_dirs, str) else rerank_dirs

    found = 0
    for folder in dir_list:
        if not os.path.exists(folder):
            print(f"警告: rerank目录 {folder} 不存在，跳过")
            continue

        json_paths = glob.glob(os.path.join(folder, "*.json"))
        n_paths = len(json_paths)
        found += n_paths
        print(f"在目录 {folder} 中找到 {n_paths} 个rerank结果文件")

        progress = track(json_paths, description=f"处理 {os.path.basename(folder)}")
        for path in progress:
            try:
                with open(path, 'r', encoding='utf-8') as handle:
                    yield json.load(handle)
            except Exception as err:
                # A single corrupt file should not abort the whole run.
                print(f"加载文件 {path} 时出错: {err}")

    print(f"总共找到 {found} 个rerank结果文件")

def load_valid_paths(valid_path_files):
    """Load JSON Lines valid-path files into a dict keyed by record ``id``.

    Every file is checked for existence before any parsing starts; if one is
    missing the process exits. Each non-blank line is decoded as JSON. Records
    that contain both ``"id"`` and ``"valid_paths"`` are kept, and a later record
    with the same id overwrites an earlier one. Other records are reported and
    skipped.

    Args:
        valid_path_files: a single path (str) or an iterable of paths.

    Returns:
        dict mapping each record's ``id`` to the full decoded record.

    Raises:
        SystemExit: with code 1 if any of the files does not exist.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Accept a single path, as load_rerank_results does. Iterating a bare str
    # would otherwise treat every character as a file name. Materialize the
    # argument because it is iterated twice below (a generator would be empty
    # on the second pass).
    if isinstance(valid_path_files, str):
        valid_path_files = [valid_path_files]
    else:
        valid_path_files = list(valid_path_files)

    for valid_path_file in valid_path_files:
        if not os.path.exists(valid_path_file):
            print(f"有效路径文件 {valid_path_file} 不存在")
            # Raise SystemExit directly: the bare exit() builtin is injected by
            # the `site` module and is not guaranteed to exist (e.g. `python -S`).
            raise SystemExit(1)

    valid_paths_dict = {}
    for valid_path_file in valid_path_files:
        with open(valid_path_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        progress = track(lines, description=f"load {os.path.basename(valid_path_file)}")
        for lineno, line in enumerate(progress, start=1):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing empty line); json.loads
                # raises on them.
                continue
            _data = json.loads(line)
            if 'id' in _data and 'valid_paths' in _data:
                valid_paths_dict[_data['id']] = _data
            else:
                print(f"Warning: valid_paths file is not valid ({valid_path_file}:{lineno})")

    print(f"加载了 {len(valid_paths_dict)} 个有效路径数据")
    return valid_paths_dict

def process_rerank_results(rerank_data_stream, valid_paths, neg_per_pos=5,
                           max_neg_score=0.5, pool_factor=5, rng=None):
    """Stream rerank results and turn each question into a contrastive sample.

    For every unique question id in ``rerank_data_stream`` this looks up the
    positive paths via ``valid_paths``. If those are empty, it falls back to the
    rerank paths flagged ``is_truth``. It then samples negatives from the rerank
    list and yields ``{"id", "query", "pos", "neg"}``. After the stream is
    exhausted, one final ``{"__stats__": stats}`` dict is yielded.

    Args:
        rerank_data_stream: iterable of dicts. Each dict is read for ``"id"``,
            ``"question"`` and ``"rerank_paths"``, a list of dicts with
            ``"path"``, ``"score"`` and ``"is_truth"``.
        valid_paths: mapping from question id to the record that is passed to
            ``summary(..., with_graph=True)``.
        neg_per_pos: number of negatives to sample per positive path.
        max_neg_score: candidates with ``score >= max_neg_score`` are not used
            as negatives. Defaults to the previously hard-coded 0.5.
        pool_factor: negatives are drawn at random from the top
            ``pool_factor * neg_per_pos * len(pos)`` highest-scoring candidates.
            Defaults to the previously hard-coded 5.
        rng: object with a ``sample(population, k)`` method, e.g.
            ``random.Random(seed)`` for reproducible output. Defaults to the
            global ``random`` module.

    Yields:
        Sample dicts, then a single ``{"__stats__": stats}`` dict as the last item.
    """
    rng = random if rng is None else rng

    stats = {
        "total_questions": 0,
        "questions_with_pos": 0,
        "total_pos_samples": 0,
        "total_neg_samples": 0,
        "avg_neg_per_pos": 0,
        "pos_score_avg": 0,
        "neg_score_avg": 0,
        "questions_without_pos": 0,
        "questions_without_neg": 0
    }

    processed_ids = set()

    def process_question(item_id, question_data):
        """Build one sample dict, or return None and record why it was skipped."""
        if not question_data:
            return None

        question = question_data["question"]

        # Look up the positive paths for this question in valid_paths.
        valid_path_item = valid_paths.get(item_id)
        if not valid_path_item:
            print(f"Warning: valid_path_item not found for {item_id}")
            stats["questions_without_pos"] += 1
            return None

        valid_path_item = summary(valid_path_item, with_graph=True)
        if 'truth_paths' not in valid_path_item:
            print(f"Warning: truth_paths not found for {item_id}")
            stats["questions_without_pos"] += 1
            return None

        pos_paths = valid_path_item['truth_paths']

        if not pos_paths:
            # Fall back to the rerank candidates flagged as ground truth.
            pos_paths = [cand["path"] for cand in question_data["rerank_paths"] if cand["is_truth"]]
            if not pos_paths:
                stats["questions_without_pos"] += 1
                return None

        stats["questions_with_pos"] += 1
        stats["total_pos_samples"] += len(pos_paths)

        # NOTE(review): presumably converts positives into the same textual form
        # as rerank "path" values so the containment check below is meaningful.
        pos_paths = path_to_str_with_graph(valid_path_item['q_entity'], pos_paths, valid_path_item['graph'])

        # Negative candidates:
        # - not flagged as truth;
        # - do not contain any positive path;
        # - score below max_neg_score (high scorers are likely false negatives).
        # They are then sorted hardest-first (highest score first).
        neg_samples = [
            cand for cand in question_data["rerank_paths"]
            if not cand["is_truth"]
            and not any(pos_path in cand["path"] for pos_path in pos_paths)
            and cand["score"] < max_neg_score
        ]
        neg_samples.sort(key=lambda x: x["score"], reverse=True)

        if not neg_samples:
            stats["questions_without_neg"] += 1
            return None

        # Draw N = neg_per_pos * len(pos_paths) negatives uniformly at random
        # (not the first N) from the pool_factor * N highest-scoring candidates.
        n_wanted = neg_per_pos * len(pos_paths)
        sample_pool = neg_samples[:n_wanted * pool_factor]
        selected_negs = rng.sample(sample_pool, min(n_wanted, len(sample_pool)))

        if not selected_negs:
            # Nothing could be sampled (e.g. neg_per_pos == 0 or no positives
            # after conversion); count it rather than dropping it silently.
            stats["questions_without_neg"] += 1
            return None

        stats["total_neg_samples"] += len(selected_negs)
        # Accumulate the raw score sum; it is divided into a mean at the end.
        stats["neg_score_avg"] += sum(neg["score"] for neg in selected_negs)

        return {
            "id": item_id,
            "query": question,
            "pos": pos_paths,
            "neg": [neg['path'] for neg in selected_negs]
        }

    for item in rerank_data_stream:
        item_id = item.get("id", "")
        # Only the first occurrence of an id is used. Later duplicates (e.g. the
        # same question appearing in several rerank dirs) are skipped.
        if item_id in processed_ids:
            continue
        processed_ids.add(item_id)

        # process_question only reads the dict, so no defensive copy is needed.
        result = process_question(item_id, item)
        if result:
            yield result

    stats["total_questions"] = len(processed_ids)
    # Turn the accumulated sums into averages.
    if stats["total_neg_samples"] > 0:
        stats["neg_score_avg"] /= stats["total_neg_samples"]

    if stats["total_pos_samples"] > 0:
        stats["avg_neg_per_pos"] = stats["total_neg_samples"] / stats["total_pos_samples"]

    # The stats dict is delivered as the last value yielded by this generator.
    yield {"__stats__": stats}

def print_stats(stats):
    """Print a human-readable summary of the dataset statistics dict.

    Args:
        stats: dict with the keys read below (the ``__stats__`` payload emitted
            at the end of the processing stream). ``questions_without_neg`` is
            optional.
    """
    total = stats['total_questions']

    def pct(count):
        # Guard against an empty run (e.g. no rerank files found), which would
        # otherwise raise ZeroDivisionError.
        return count / total * 100 if total else 0.0

    print("\n===== 数据集统计信息 =====")
    print(f"总问题数: {total}")
    print(f"有正样本的问题数: {stats['questions_with_pos']} ({pct(stats['questions_with_pos']):.2f}%)")
    print(f"无正样本的问题数: {stats['questions_without_pos']} ({pct(stats['questions_without_pos']):.2f}%)")
    if 'questions_without_neg' in stats:
        print(f"无负样本的问题数: {stats['questions_without_neg']} ({pct(stats['questions_without_neg']):.2f}%)")
    print(f"总正样本数: {stats['total_pos_samples']}")
    print(f"总负样本数: {stats['total_neg_samples']}")
    print(f"每个正样本平均负样本数: {stats['avg_neg_per_pos']:.2f}")
    # Positives come from valid_paths and carry no rerank score, so no positive
    # score average is reported.
    print(f"负样本平均分数: {stats['neg_score_avg']:.4f}")
    print("=========================")

def save_contrastive_data(data_stream, output_file):
    """Write samples from ``data_stream`` to ``output_file`` as JSON Lines.

    Items are written one per line as they arrive, using ``ensure_ascii=False``.
    A dict carrying a ``"__stats__"`` key is treated as the statistics sentinel
    and is not written.

    Args:
        data_stream: iterable of JSON-serializable sample dicts, optionally
            containing one ``{"__stats__": ...}`` sentinel.
        output_file: destination path. Missing parent directories are created.

    Returns:
        The sentinel's ``__stats__`` payload, or None if none was seen.
    """
    stats = None
    count = 0

    out_dir = os.path.dirname(output_file)
    if out_dir:
        # Create the parent directory so open() does not fail on a fresh checkout.
        os.makedirs(out_dir, exist_ok=True)

    with open(output_file, 'w', encoding='utf-8') as f:
        for item in data_stream:
            if isinstance(item, dict) and "__stats__" in item:
                stats = item["__stats__"]
                # Keep consuming instead of breaking, so nothing yielded after
                # the sentinel is silently dropped.
                continue
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
            count += 1

    print(f"对比学习数据集已保存到: {output_file}，共保存 {count} 条数据")
    return stats

def main():
    """Command-line entry point: build a contrastive dataset from rerank results.

    Steps:
    1. Parse the CLI arguments.
    2. Load the valid-path files.
    3. Stream the rerank results through processing into the output JSONL file.
    4. Print the collected statistics.
    """
    default_rerank_dir = "data/rerank_results/2025-03-13_15-53-37_bge-reranker-v2-m3"
    default_valid_path_file = "data/webqsp/test.valid.jsonl"
    default_output_file = "data/webqsp/rerank.test.results.0427.jsonl"
    parser = argparse.ArgumentParser(description="从rerank结果生成对比学习数据集")
    parser.add_argument("--rerank_dir", type=str, nargs='+', default=[default_rerank_dir], help="rerank结果目录，可以指定多个目录，用空格分隔")
    parser.add_argument("--output_file", type=str, default=default_output_file, help="输出文件路径")
    parser.add_argument("--neg_per_pos", type=int, default=5, help="每个正样本的负样本数量")
    parser.add_argument("--no-appendix", action="store_true", help="不添加后缀")
    parser.add_argument("--valid_path_files", type=str, nargs='+', default=[default_valid_path_file], help="有效路径文件路径")
    parser.add_argument("--seed", type=int, default=None, help="负样本采样的随机种子（用于复现）")
    args = parser.parse_args()

    # A non-positive count would silently produce an empty dataset.
    if args.neg_per_pos < 1:
        parser.error("--neg_per_pos 必须为正整数")

    if args.seed is not None:
        # Negative sampling draws from the global `random` module by default,
        # so seeding it here makes the output reproducible.
        random.seed(args.seed)

    print(f"从 {args.valid_path_files} 加载有效路径...")
    valid_paths = load_valid_paths(args.valid_path_files)

    print(f"从以下目录流式加载并处理rerank结果: {', '.join(args.rerank_dir)}")

    # Lazy data source: files are only read as the pipeline is consumed.
    rerank_data_stream = load_rerank_results(args.rerank_dir)

    print(f"开始流式处理并保存，每个正样本选择 {args.neg_per_pos} 个负样本...")

    # Still lazy; nothing runs until save_contrastive_data iterates it.
    result_stream = process_rerank_results(
        rerank_data_stream,
        valid_paths,
        neg_per_pos=args.neg_per_pos
    )

    # Tag the output with the negative count, touching only the file suffix.
    # str.replace would also rewrite ".jsonl" inside directory names.
    suffix = ".jsonl"
    if not args.no_appendix and args.output_file.endswith(suffix):
        args.output_file = args.output_file[:-len(suffix)] + f".{args.neg_per_pos}{suffix}"

    stats = save_contrastive_data(result_stream, args.output_file)

    if stats:
        print_stats(stats)
    else:
        print("警告: 未获取到统计信息")

    print("完成!")

if __name__ == "__main__":
    # Script entry point; main() returns None, so this exits with status 0.
    raise SystemExit(main())
