import json
from collections import Counter, deque
from typing import Dict, List, Set




def load_json_file(filepath: str) -> List[Dict]:
    """Read *filepath* as UTF-8 text and return the decoded JSON document."""
    with open(filepath, mode='r', encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload

def save_json_file(filepath: str, data: List[Dict]):
    """Write *data* to *filepath* as 2-space-indented UTF-8 JSON, keeping non-ASCII characters literal."""
    with open(filepath, mode='w', encoding='utf-8') as handle:
        json.dump(data, handle, indent=2, ensure_ascii=False)
        
def extract_between(text: str, tag: str) -> str | None:
    open_tag, close_tag = f"<{tag}>", f"</{tag}>"
    start = text.find(open_tag) + len(open_tag)
    end = text.find(close_tag)
    if start > len(open_tag)-1 and end > -1:
        return text[start:end].strip()
    return None




def extract_correct_thinking(dataset, max_candidates: int | None = 4):
    """Keep records with a judged-correct answer that contains a ``<think>`` block.

    For each record, the first ``max_candidates`` index-aligned pairs of
    ``generated_answer`` / ``judgments`` are scanned in order. The first pair
    whose judgment ``score`` equals ``"Yes"`` and whose answer text yields a
    non-empty ``<think>...</think>`` span (via ``extract_between``) is used to
    build the output record. Records with no such pair are dropped.

    Args:
        dataset: Iterable of dicts, each with ``generated_answer`` and
            ``judgments`` sequences.
        max_candidates: Number of leading candidates to inspect; ``None``
            inspects all. Defaults to 4, the previously hard-coded limit.

    Returns:
        A new list of shallow copies of the accepted records. Each copy has
        ``generated_answer`` and ``judgments`` removed, plus ``cot`` (the think
        text) and ``summary`` (text after the last ``</think>``, stripped).
        The input records are not modified.

    Raises:
        KeyError: if a record lacks ``generated_answer`` or ``judgments``.
    """
    processed_dataset = []
    for d in dataset:
        answers = d['generated_answer']
        judgments = d['judgments']
        # Malformed (non-sequence) candidate lists were silently dropped before; keep that.
        if not isinstance(answers, (list, tuple)) or not isinstance(judgments, (list, tuple)):
            continue

        for answer, judgment in zip(answers[:max_candidates], judgments[:max_candidates]):
            # Skip malformed or incorrect candidates instead of catching every exception.
            if not isinstance(judgment, dict) or judgment.get('score') != "Yes":
                continue
            text = answer.get('answer') if isinstance(answer, dict) else None
            if not isinstance(text, str):
                continue
            cot = extract_between(text, 'think')
            if not cot:
                continue

            # Build a copy so callers' data is untouched and repeated calls stay safe.
            record = {k: v for k, v in d.items() if k not in ('generated_answer', 'judgments')}
            record['cot'] = cot
            record['summary'] = text.split("</think>")[-1].strip()
            processed_dataset.append(record)
            break
    return processed_dataset




def get_category_distribution(data: List[Dict]) -> Dict[str, float]:
    """Return the fraction of records that fall into each ``category``.

    Keys appear in first-seen order; an empty input yields an empty dict.
    """
    size = len(data)
    shares: Dict[str, float] = {}
    for category, hits in Counter(entry['category'] for entry in data).items():
        shares[category] = hits / size
    return shares




def get_category_counts(data: List[Dict]) -> Dict[str, int]:
    """Tally how many records belong to each ``category`` (returned as a ``Counter``)."""
    tally: Counter = Counter()
    for entry in data:
        tally[entry['category']] += 1
    return tally




def _match_existing_records(data_c: List[Dict], data_b_dict: Dict[str, Dict]) -> List[Dict]:
    """Return processed-B records for file C's IDs, in file C order, each ID at most once."""
    matched: List[Dict] = []
    seen: Set[str] = set()
    for item in data_c:
        item_id = item['id']
        if item_id in seen:
            continue
        seen.add(item_id)
        if item_id in data_b_dict:
            matched.append(data_b_dict[item_id])
        else:
            print(f"警告: ID {item_id} 在文件C中但不在处理后的文件B中")
    return matched


def _build_available_pool(data_b: List[Dict], existing_ids: Set[str]) -> Dict[str, deque]:
    """Group not-yet-selected B records by category, preserving B's order within each group."""
    pool: Dict[str, deque] = {}
    for item in data_b:
        if item['id'] not in existing_ids:
            pool.setdefault(item['category'], deque()).append(item)
    return pool


def _pick_deficit_category(
    current_counts: Dict[str, int],
    target_distribution: Dict[str, float],
    pool: Dict[str, deque],
    tolerance: float,
) -> str | None:
    """Return the under-represented category (with data left) furthest from target, or None."""
    current_total = sum(current_counts.values())
    candidates = []
    for cat, target_ratio in target_distribution.items():
        current_count = current_counts.get(cat, 0)
        current_ratio = current_count / current_total if current_total > 0 else 0
        if current_ratio < target_ratio - tolerance and pool.get(cat):
            # Remaining gap to the target after adding one more item of this category.
            new_ratio = (current_count + 1) / (current_total + 1)
            candidates.append((target_ratio - new_ratio, cat))
    # max() over (gap, cat) tuples: largest gap wins, ties go to the larger category name.
    return max(candidates)[1] if candidates else None


def _print_final_report(
    selected_data: List[Dict],
    target_distribution: Dict[str, float],
    tolerance: float,
) -> None:
    """Print the target-vs-actual category ratio table for the final selection."""
    print("\n最终category分布对比:")
    final_counts = get_category_counts(selected_data)
    final_total = len(selected_data)

    print(f"{'Category':<20} {'目标比例':<12} {'实际比例':<12} {'数量':<10} {'状态'}")
    print("-" * 70)
    for cat in sorted(target_distribution):
        target_ratio = target_distribution[cat]
        actual_count = final_counts.get(cat, 0)
        actual_ratio = actual_count / final_total if final_total > 0 else 0
        diff = actual_ratio - target_ratio

        if abs(diff) <= tolerance:
            status = "✓ 达标"
        elif diff > 0:
            status = f"↑ 超出 {diff:.2%}"
        else:
            status = f"↓ 不足 {-diff:.2%}"

        print(f"{cat:<20} {target_ratio:>10.2%} {actual_ratio:>10.2%} {actual_count:>10} {status}")


def select_minimal_data(
    file_a_path: str,
    file_b_path: str,
    file_c_path: str,
    output_path: str,
    tolerance: float = 0.001,  # allowed absolute error on each category ratio
    max_iterations: int = 100000,
):
    """
    Select as few extra records from file B as needed to top up under-represented categories.

    Records already chosen in file C are kept. New records are only ever added to
    categories whose share is below file A's share by more than ``tolerance``;
    categories that already meet their target never receive more data.

    Args:
        file_a_path: Path to file A, which defines the target category distribution.
        file_b_path: Path to file B, the data source (filtered by ``extract_correct_thinking``).
        file_c_path: Path to file C, the already-selected records (IDs expected to be in B).
        output_path: Where to write the combined selection as JSON.
        tolerance: Allowed ratio error per category.
        max_iterations: Upper bound on the number of records added; a warning is
            printed if the bound is reached while deficits remain.
    """
    # 1. Load all inputs.
    print("正在加载文件...")
    data_a = load_json_file(file_a_path)

    data_b_raw = load_json_file(file_b_path)
    data_b = extract_correct_thinking(data_b_raw)

    data_c = load_json_file(file_c_path)

    # 2. Target distribution from file A.
    target_distribution = get_category_distribution(data_a)
    print(f"\n文件A的category分布:")
    for cat, ratio in sorted(target_distribution.items()):
        print(f"  {cat}: {ratio:.2%}")

    # 3. IDs already selected in file C.
    existing_ids: Set[str] = {item['id'] for item in data_c}
    print(f"\n文件C中已有 {len(data_c)} 条数据")

    # 4. Index processed B by ID.
    data_b_dict: Dict[str, Dict] = {item['id']: item for item in data_b}
    print(f"文件B处理后有 {len(data_b)} 条数据")

    # 5. Resolve C's IDs against processed B, following file C's order so the output is deterministic.
    selected_data = _match_existing_records(data_c, data_b_dict)
    # Baseline = existing records actually present in the output; used for correct "added" counts.
    baseline = len(selected_data)
    print(f"从文件B中匹配到已选数据: {baseline} 条")

    existing_counts = get_category_counts(selected_data)
    print("\n已有数据的category分布:")
    for cat, count in sorted(existing_counts.items()):
        current_ratio = count / baseline if baseline > 0 else 0
        target_ratio = target_distribution.get(cat, 0)
        status = "✓" if current_ratio >= target_ratio - tolerance else "✗ 不足"
        print(f"  {cat}: {count} 条 ({current_ratio:.2%}) 目标: {target_ratio:.2%} {status}")

    # 6. Pool of unselected B records per category (deque for O(1) popleft).
    available_data_by_category = _build_available_pool(data_b, existing_ids)
    print(f"\n文件B中可用数据（排除已选）:")
    for cat, items in sorted(available_data_by_category.items()):
        print(f"  {cat}: {len(items)} 条")

    # 7. Greedily add one record at a time to the most under-represented category.
    current_counts = existing_counts.copy()
    for cat in target_distribution:
        current_counts.setdefault(cat, 0)

    print("\n开始智能选择补充数据...")
    print("策略：只补充比例不足的category，已达标的不再添加\n")

    added = 0
    while added < max_iterations:
        best_category = _pick_deficit_category(
            current_counts, target_distribution, available_data_by_category, tolerance
        )
        if best_category is None:
            break
        selected_data.append(available_data_by_category[best_category].popleft())
        current_counts[best_category] += 1
        added += 1

        if added % 100 == 0:
            print(f"  已添加 {added} 条数据，当前总数: {len(selected_data)}")
    else:
        # Loop ran out of budget rather than converging; tell the user if deficits remain.
        if _pick_deficit_category(
            current_counts, target_distribution, available_data_by_category, tolerance
        ) is not None:
            print(f"警告: 已达到最大迭代次数 {max_iterations}，仍有category比例不足")

    # 8. Summary.
    print(f"\n✓ 完成！最终选择了 {len(selected_data)} 条数据")
    print(f"  - 原有数据: {baseline} 条")
    print(f"  - 新增数据: {len(selected_data) - baseline} 条")

    _print_final_report(selected_data, target_distribution, tolerance)

    # 9. Save.
    save_json_file(output_path, selected_data)
    print(f"\n结果已保存到: {output_path}")


