import pandas as pd
from typing import Dict, List, Any, Optional, Tuple
import json
import re
import difflib
import numpy as np
from transformers import AutoTokenizer
from tqdm import tqdm
# Tokenizer used by the main loop to truncate each response to a fixed token
# budget (encode -> slice -> decode) before n-gram scoring.
tokenizer = AutoTokenizer.from_pretrained("/mnt/shared-storage-user/p1-shared/Qwen/Qwen3-4B-Base")
# Dataset to annotate: each row's `reward_model` dict gets a `distinct_ratio`
# entry, and the annotated frame is written to a new parquet file at the end.
data = pd.read_parquet("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/l1/dapo-math-17k_qwen3-add1k.parquet")
# Model generations; each `data` row is matched to a record here by prompt text.
# NOTE(review): the path repeats "0_32000.jsonl" as both a directory and the
# file name -- confirm this is intended.
jsonl_file = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-8b-new/Qwen3-4B-Base-dapo-17k/valid/0_32000.jsonl/0_32000.jsonl"

def calculate_distinct_ngram_ratio(text, n=10):
    """
    Return the fraction of word n-grams in ``text`` that are distinct.

    Tokenization follows the repetition.py logic: split on whitespace, then
    split each piece on ``'_'``. Empty pieces produced by leading, trailing or
    doubled underscores are kept as tokens, as in the original logic.

    Args:
        text: Text to score. Falsy input (e.g. ``""``) counts as fully
            non-repetitive.
        n: n-gram size in words; must be >= 1.

    Returns:
        ``distinct n-grams / total n-grams``, a value in (0, 1]. Returns 1.0
        when ``text`` is empty or has fewer than ``n`` words, because no
        n-gram exists.

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if not text:
        return 1.0

    words = [piece for segment in text.split() for piece in segment.split('_')]
    if len(words) < n:
        return 1.0

    # len(words) >= n >= 1 guarantees at least one n-gram, so no zero-division.
    total_ngrams = len(words) - n + 1
    # Store the n-gram tuples themselves rather than hash(tuple). Distinct
    # tuples can share a hash value, which would undercount distinct n-grams.
    distinct_ngrams = {tuple(words[i:i + n]) for i in range(total_ngrams)}
    return len(distinct_ngrams) / total_ngrams


def load_jsonl(file_path: str) -> List[Dict]:
    """Load a JSON Lines file into a list of parsed records.

    Blank or whitespace-only lines (e.g. an extra empty line at the end of the
    file) are skipped. Passing them to ``json.loads`` would raise
    ``json.JSONDecodeError``.

    Args:
        file_path: Path to a UTF-8 encoded JSONL file.

    Returns:
        One parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        # `records` (not `data`) avoids shadowing the module-level DataFrame.
        records = [json.loads(stripped) for stripped in (line.strip() for line in f) if stripped]
    return records


def _extract_prompt_from_jsonl_item(jsonl_item: Dict[str, Any]) -> str:
    """从jsonl条目里提取prompt文本（兼容 prompt / input 字段）。"""
    prompt_text = jsonl_item.get('prompt', '') or jsonl_item.get('input', '')
    
    prompt_text = prompt_text.split('user')[-1].split('assistant')[0]
    prompt_text = prompt_text.split('<|im_start|>user')[-1].split('<|im_end|> <|im_start|>assistant')[0]
    prompt_text = prompt_text.split("A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. User: You must put your answer inside \\boxed{} and Your final answer will be extracted automatically by the \\boxed{} tag. \n")[-1].split("Assistant:")[0].strip()
    prompt_text = prompt_text.replace("<|endoftext|>", "")
    return prompt_text

def _extract_responses_from_jsonl_item(jsonl_item: Dict[str, Any]) -> List[str]:
    out = jsonl_item.get('output', None)
    if out is None or out == "":
        out = jsonl_item.get('generated_text', None)

    if out is None:
        return [""]
    if isinstance(out, list):
        # 重要：很多生成脚本里一条jsonl代表“一次生成”，即使返回了多个候选，
        # 这里也只取第一个，避免把一行jsonl计成多条response导致“8遍”对不齐。
        if len(out) == 0:
            return [""]
        first = out[0]
        s = str(first).strip() if first is not None else ""
        s = s.replace("<|endoftext|>", "")
        return [s] if s != "" else [""]
    if isinstance(out, str):
        out = out.replace("<|endoftext|>", "")
        return [out] if out.strip() != "" else [""]
    s = str(out).strip()
    s = s.replace("<|endoftext|>", "")
    return [s] if s != "" else [""]

def _approx_prompt_match(a: str, b: str, min_len: int = 20, ratio_threshold: float = 0.8) -> bool:
    """
    “大致匹配”两个 prompt：
    - 先做完全匹配
    - 再做互为子串（避免很短字符串误匹配）
    - 最后用 difflib 的相似度兜底
    - 如果都失败，尝试去掉所有非字母数字字符后再对比
    """
    # a = _normalize_prompt_text(a)
    # b = _normalize_prompt_text(b)
    if not a or not b:
        return False
    if a == b:
        return True
    if len(a) >= min_len and len(b) >= min_len:
        if a in b or b in a:
            return True
    # 相似度兜底（对小的格式差异更稳）
    r = difflib.SequenceMatcher(None, a, b).ratio()
    if r >= ratio_threshold:
        return True

    # 最后的保底：只对比字母和数字
    def _clean(s):
        return re.sub(r'[^a-zA-Z0-9]', '', s).lower()
    
    a_clean, b_clean = _clean(a), _clean(b)
    if len(a_clean) > min_len and len(b_clean) > min_len:
        if a_clean == b_clean or a_clean in b_clean or b_clean in a_clean:
            return True
            
    return False

jsonl_data = load_jsonl(jsonl_file)
distinct_ratio_list = []
response_length_list = []

# Responses are truncated to this many tokens before n-gram scoring.
MAX_RESPONSE_TOKENS = 8192
# Word n-gram size passed to calculate_distinct_ngram_ratio.
DISTINCT_NGRAM_N = 10
# Subtracted from reward_model['num_tokens'] when collecting length stats.
# NOTE(review): presumably undoes the "+1k" suggested by the "add1k" dataset
# filename -- confirm.
NUM_TOKENS_OFFSET = 1000

for i in tqdm(range(len(data))):
    # `data.at[i, ...]` is label-based. Iterating range(len(data)) assumes the
    # frame has a default RangeIndex (TODO confirm for this parquet file).
    prompt = data.at[i, 'prompt'][0]['content']

    # The JSONL is assumed to be in `data` order, possibly with extra rows, so
    # search forward from position i. The search is bounded, so a missing
    # prompt reaches the warning below instead of raising IndexError.
    j = i
    while j < len(jsonl_data) and not _approx_prompt_match(
        prompt, _extract_prompt_from_jsonl_item(jsonl_data[j])
    ):
        j += 1
    if j >= len(jsonl_data):
        print(f"Warning: prompt {prompt} not found in jsonl")
        breakpoint()
        # No response to score: leave this row without a distinct_ratio.
        continue

    jsonl_response = _extract_responses_from_jsonl_item(jsonl_data[j])
    response_text = jsonl_response[0] if jsonl_response else ""
    # Round-trip through the tokenizer to cap the response at a token budget.
    token_ids = tokenizer.encode(response_text, add_special_tokens=False)[:MAX_RESPONSE_TOKENS]
    cal_response = tokenizer.decode(token_ids, skip_special_tokens=True)
    distinct_ratio = calculate_distinct_ngram_ratio(cal_response, n=DISTINCT_NGRAM_N)
    distinct_ratio_list.append(distinct_ratio)

    # Copy before mutating so the stored dict is only replaced via `.at` below.
    curr_reward_model = data.at[i, 'reward_model'].copy()
    curr_reward_model['distinct_ratio'] = distinct_ratio

    # Length stats are collected only for responses with no repeated n-gram.
    if distinct_ratio == 1.0:
        response_length_list.append(curr_reward_model['num_tokens'] - NUM_TOKENS_OFFSET)
    data.at[i, 'reward_model'] = curr_reward_model

# Rows whose prompt was not found are skipped, so the two counts can differ.
if distinct_ratio_list:
    print(len(distinct_ratio_list), len(data), "\nmean: ", np.mean(distinct_ratio_list), "\nmax: ", np.max(distinct_ratio_list), "\nmin: ", np.min(distinct_ratio_list))
else:
    print(len(distinct_ratio_list), len(data), "\nNo distinct ratios computed.")
print("Count of 1.0:", np.sum(np.array(distinct_ratio_list) == 1.0))
# np.max/np.min raise ValueError on an empty array, so guard the summary.
if response_length_list:
    print("mean: ", np.mean(response_length_list), "\nmax: ", np.max(response_length_list), "\nmin: ", np.min(response_length_list))
else:
    print("No responses with distinct_ratio == 1.0; skipping length stats.")
data.to_parquet("/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/l1/dapo-math-17k_qwen3-add1k_distinct_ratio.parquet")
