import json
import os
import glob
from transformers import AutoTokenizer
from math_verify import parse, verify
from tqdm import tqdm
import signal
from typing import List, Dict
import time
import pandas as pd


def timeout(timeout_seconds: int = 10):
    """Return a decorator that aborts the wrapped call after a time limit.

    On POSIX systems the decorator arms ``SIGALRM`` for ``timeout_seconds``
    before calling the wrapped function; if the alarm fires, a
    ``TimeoutError`` is raised from inside the call. The alarm is always
    cancelled and the previous ``SIGALRM`` handler restored afterwards.
    On any other platform the decorator returns the function unchanged,
    i.e. no timeout is enforced.

    Args:
        timeout_seconds: Number of seconds passed to ``signal.alarm``.

    Returns:
        A decorator preserving the wrapped function's signature and metadata.
    """
    # Local import: the module header does not import functools.
    import functools

    if os.name != "posix":
        def passthrough(func):
            return func
        return passthrough

    def decorator(func):
        def handler(signum, frame):
            raise TimeoutError("verify timed out!")

        # Keep __name__/__doc__ of the wrapped function so that logs and
        # introspection do not report every decorated function as "wrapper".
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            old_handler = signal.getsignal(signal.SIGALRM)
            signal.signal(signal.SIGALRM, handler)
            signal.alarm(timeout_seconds)
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)
                # getsignal() returns None when the previous handler was not
                # installed from Python; signal.signal() rejects None, so
                # there is nothing we can restore in that case.
                if old_handler is not None:
                    signal.signal(signal.SIGALRM, old_handler)
        return wrapper
    return decorator


@timeout(timeout_seconds=1000)
def labeling_responses_batch(responses: List[str], golden_answers: List[str]) -> List[bool]:
    """Verify a batch of responses against their reference answers.

    Each response is parsed with ``parse`` and compared, via ``verify``,
    against its reference answer wrapped in ``$...$``. Items are scored
    independently: an item whose answer is ``None`` or whose parsing or
    verification raises is labelled ``False`` without affecting the rest of
    the batch.

    Args:
        responses: Model outputs to check.
        golden_answers: Reference answers, positionally aligned with
            ``responses``. ``None`` entries are labelled ``False``.

    Returns:
        One boolean per response. If the batch-level timeout fires, the
        items not yet scored are labelled ``False``.
    """
    # NOTE(review): if math_verify's parse/verify arm their own SIGALRM-based
    # timeouts, they would cancel the outer alarm installed by @timeout —
    # confirm against the installed math_verify version.
    labels: List[bool] = []
    try:
        for response, answer in zip(responses, golden_answers):
            if answer is None:
                # No reference answer available, so nothing can be verified.
                labels.append(False)
                continue
            try:
                predicted = parse(response)
                golden = parse("$" + answer + "$")
                labels.append(verify(golden, predicted))
            except TimeoutError:
                # The batch-level alarm fired: stop scoring altogether.
                raise
            except Exception as e:
                print(f"验证单条答案时出错: {e}")
                labels.append(False)
    except TimeoutError as e:
        print(f"批量验证过程中出错: {e}")

    # Always hand back one label per response, padding whatever was not
    # scored (timeout, or fewer answers than responses) with False.
    labels.extend([False] * (len(responses) - len(labels)))
    return labels


def truncate_text_by_tokens(text: str, tokenizer, max_tokens: int) -> str:
    """Cut ``text`` down to at most ``max_tokens`` tokens.

    Args:
        text: Text to truncate.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            and ``decode(tokens, skip_special_tokens=True)``.
        max_tokens: Maximum number of tokens to keep; values below zero are
            treated as zero.

    Returns:
        ``text`` itself when it already fits within the limit, otherwise the
        decoded first ``max_tokens`` tokens. If the tokenizer raises, the
        error is printed and the original ``text`` is returned.
    """
    try:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) <= max_tokens:
            # Nothing to cut: return the input untouched rather than a
            # decode(encode(text)) round trip, which may not reproduce it.
            return text
        # Clamp so a negative limit cannot turn into a slice from the end.
        kept = tokens[:max(max_tokens, 0)]
        return tokenizer.decode(kept, skip_special_tokens=True)
    except Exception as e:
        print(f"截断文本时出错: {e}")
        return text


def get_answer_from_parquet_row(row):
    """Extract the reference answer stored in a parquet row.

    The answer is read from ``row['reward_model']['ground_truth']`` and
    converted to ``str``. ``None`` is returned when ``reward_model`` is
    missing, is not a dict, or has no ``ground_truth`` key.
    """
    if 'reward_model' not in row:
        return None

    reward_info = row['reward_model']
    if not isinstance(reward_info, dict):
        return None
    if 'ground_truth' not in reward_info:
        return None

    return str(reward_info['ground_truth'])


def check_file_complete(input_file: str, output_file: str, parquet_df: pd.DataFrame, verbose: bool = True) -> bool:
    """Tell whether ``output_file`` already covers everything to be processed.

    The number of non-blank lines in the output is compared against the
    smaller of the non-blank input line count and ``len(parquet_df)``.

    Args:
        input_file: Path of the source JSONL file.
        output_file: Path of the processed JSONL file.
        parquet_df: Reference table; only its length is used here.
        verbose: When true, print the counts and any read error.

    Returns:
        ``True`` when the output has at least the expected number of
        non-blank lines; ``False`` when it has fewer, does not exist, or
        either file cannot be read.
    """
    def count_nonblank(path: str) -> int:
        # Number of lines that contain something other than whitespace.
        with open(path, 'r', encoding='utf-8') as handle:
            return sum(1 for raw in handle if raw.strip())

    if not os.path.exists(output_file):
        return False

    # Input first, then output, so the reported error matches the file read.
    try:
        input_lines = count_nonblank(input_file)
    except Exception as e:
        if verbose:
            print(f"检查输入文件行数时出错: {e}")
        return False

    try:
        output_lines = count_nonblank(output_file)
    except Exception as e:
        if verbose:
            print(f"检查输出文件行数时出错: {e}")
        return False

    # Only as many rows as both the input and the parquet table provide
    # can ever be written.
    expected_lines = min(input_lines, len(parquet_df))
    is_complete = output_lines >= expected_lines

    if verbose:
        if is_complete:
            print(
                f"  文件已完整处理: 输入 {input_lines} 行, 输出 {output_lines} 行, 期望 {expected_lines} 行")
        else:
            print(
                f"  文件未完整: 输入 {input_lines} 行, 输出 {output_lines} 行, 期望 {expected_lines} 行")
    return is_complete


def _write_scored_batch(f_out, batch: List[Dict], tokenizer) -> int:
    """Score one batch of prepared items and append them to ``f_out`` as JSONL.

    Every record is written with the same set of fields, regardless of
    whether it belongs to a full batch or to the final partial one:
    ``output`` / ``generated_text`` (the possibly truncated text),
    ``answer``, ``score`` / ``correctness`` (the verification label),
    ``length`` (token count of the written text) and ``original_length``.

    Returns:
        The number of records written.
    """
    labels = labeling_responses_batch(
        [item['truncated_text'] for item in batch],
        [item['answer'] for item in batch])

    # Serialise the whole batch before writing anything, so a failure while
    # building one record cannot leave a half-written batch in the file.
    lines = []
    for item, label in zip(batch, labels):
        text = item['truncated_text']
        record = item['data'].copy()
        record['output'] = text
        record['generated_text'] = text
        record['answer'] = item['answer']
        record['score'] = label
        record['correctness'] = label
        record['length'] = len(tokenizer.encode(
            text, add_special_tokens=False))
        record['original_length'] = item['original_length']
        lines.append(json.dumps(record, ensure_ascii=False) + '\n')

    f_out.writelines(lines)
    return len(lines)


def process_single_file(input_file: str, output_file: str, tokenizer, parquet_df: pd.DataFrame,
                        max_tokens: int = 8192, batch_size: int = 32, resume: bool = True):
    """Truncate and re-score one JSONL file, pairing rows with the parquet table.

    Record ``i`` of the JSONL file is paired with row ``i`` of ``parquet_df``
    to obtain its reference answer. Outputs longer than ``max_tokens`` are
    truncated, all outputs are re-verified in batches, and the results are
    written to ``output_file`` as JSONL.

    Args:
        input_file: Source JSONL file; each record is expected to carry an
            ``output`` field and optionally a ``length`` field.
        output_file: Destination JSONL file.
        tokenizer: Tokenizer used for length measurement and truncation.
        parquet_df: Table providing the reference answers, by position.
        max_tokens: Token budget above which an output is truncated.
        batch_size: Number of records verified and written together.
        resume: When true, skip the file if it is already complete, or
            append after the lines that are already present in the output.

    Returns:
        None. Progress and errors are reported on stdout.
    """
    # Local import: the module header does not import traceback.
    import traceback

    if not os.path.exists(input_file):
        print(f"文件不存在: {input_file}")
        return

    # Make sure the output directory exists.
    output_dir = os.path.dirname(output_file)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # Nothing to do when a previous run already finished this file.
    if resume and check_file_complete(input_file, output_file, parquet_df):
        print(f"  跳过已完成的文件: {os.path.basename(output_file)}")
        return

    # Load every record of the JSONL file.
    jsonl_data = []
    skipped_lines = 0
    try:
        with open(input_file, 'r', encoding='utf-8') as f_in:
            for line in f_in:
                stripped = line.strip()
                if not stripped:
                    continue
                try:
                    jsonl_data.append(json.loads(stripped))
                except json.JSONDecodeError:
                    skipped_lines += 1
    except Exception as e:
        print(f"读取文件 {input_file} 时出错: {e}")
        return

    if skipped_lines:
        # Dropping a record shifts every later record by one position with
        # respect to the parquet rows, so this must not pass unnoticed.
        print(
            f"警告: 跳过了 {skipped_lines} 行无法解析的JSON，后续记录与parquet的对应关系可能错位")

    # Warn when the JSONL file and the parquet table disagree in length.
    if len(jsonl_data) != len(parquet_df):
        print(
            f"警告: jsonl文件有 {len(jsonl_data)} 条记录，但parquet文件有 {len(parquet_df)} 条记录")
        print(f"将使用较小的长度: {min(len(jsonl_data), len(parquet_df))}")

    # Count the lines already present in the output to find the resume point.
    min_len = min(len(jsonl_data), len(parquet_df))
    start_idx = 0
    if resume and os.path.exists(output_file):
        try:
            with open(output_file, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        start_idx += 1
            if start_idx >= min_len:
                # A complete file should have been skipped by the check above.
                print(f"  警告: 文件应该已完整，但未在检查时跳过")
                return
            elif start_idx > 0:
                print(f"  从第 {start_idx} 行恢复处理（共需处理 {min_len} 行）")
        except Exception as e:
            print(f"  检查已处理行数时出错: {e}，从头开始处理")
            start_idx = 0
            # The existing output is unreadable: remove it and start over.
            try:
                if os.path.exists(output_file):
                    os.remove(output_file)
                    print(f"  删除有问题的输出文件，重新开始处理")
            except Exception as e2:
                print(f"  删除文件时出错: {e2}")

    batch = []
    total_processed = 0
    total_truncated = 0

    try:
        # Append when resuming from a checkpoint, otherwise overwrite.
        mode = 'a' if (resume and start_idx > 0) else 'w'
        with open(output_file, mode, encoding='utf-8') as f_out:
            for idx in tqdm(range(start_idx, min_len), desc=f"处理 {os.path.basename(input_file)}", initial=start_idx, total=min_len):
                try:
                    data = jsonl_data[idx]
                    generated_text = data.get('output', '')

                    # The reference answer comes from the parquet row at the
                    # same position.
                    answer = get_answer_from_parquet_row(parquet_df.iloc[idx])

                    # Prefer the recorded length; measure it when unavailable.
                    try:
                        original_length = data['length']
                    except Exception as e:
                        print(f"获取原始长度时出错: {e}")
                        original_length = len(tokenizer.encode(
                            generated_text, add_special_tokens=False))

                    if original_length > max_tokens:
                        truncated_text = truncate_text_by_tokens(
                            generated_text, tokenizer, max_tokens)
                        total_truncated += 1
                    else:
                        truncated_text = generated_text

                    batch.append({
                        'data': data,
                        'truncated_text': truncated_text,
                        'answer': answer,
                        'original_length': original_length
                    })

                    if len(batch) >= batch_size:
                        total_processed += _write_scored_batch(
                            f_out, batch, tokenizer)
                        batch = []

                except Exception as e:
                    # NOTE(review): a skipped record leaves the output one
                    # line short, so a later resume (which counts output
                    # lines) would restart at the wrong index.
                    print(f"处理索引 {idx} 时出错: {e}")
                    traceback.print_exc()
                    continue

            # Flush whatever is left after the last full batch.
            if batch:
                total_processed += _write_scored_batch(
                    f_out, batch, tokenizer)

        print(f"  处理完成: {total_processed} 条记录, 其中 {total_truncated} 条被截断")

    except Exception as e:
        print(f"处理文件 {input_file} 时出错: {e}")
        traceback.print_exc()


def process_directory(input_dir: str, output_dir: str, parquet_file: str,
                      max_tokens: int = 8192, batch_size: int = 32):
    """Process every ``*.jsonl`` file of a directory against one parquet table.

    Loads the parquet table and a tokenizer (chosen from keywords found in
    ``input_dir``, falling back to ``gpt2`` when loading fails), reports
    which files are already complete, then runs ``process_single_file`` on
    each file in sorted order with resume enabled.

    Args:
        input_dir: Directory containing the source JSONL files.
        output_dir: Directory receiving the processed files (same names).
        parquet_file: Path of the parquet table with the reference answers.
        max_tokens: Token budget forwarded to ``process_single_file``.
        batch_size: Batch size forwarded to ``process_single_file``.

    Returns:
        None. Returns early when the parquet file is missing or unreadable,
        or when the directory holds no JSONL file.
    """
    import traceback

    def overflow_note(names):
        # Suffix announcing the total when more than five names exist.
        return f" ... (共{len(names)}个)" if len(names) > 5 else ""

    # --- Reference table -------------------------------------------------
    if not os.path.exists(parquet_file):
        print(f"错误: parquet文件不存在: {parquet_file}")
        return

    print(f"读取parquet文件: {parquet_file}")
    try:
        parquet_df = pd.read_parquet(parquet_file)
        print(f"parquet文件包含 {len(parquet_df)} 条记录")
        print(f"parquet文件列名: {list(parquet_df.columns)}")
    except Exception as e:
        print(f"读取parquet文件时出错: {e}")
        traceback.print_exc()
        return

    # --- Tokenizer -------------------------------------------------------
    try:
        # First keyword contained in input_dir wins; order matters.
        keyword_paths = (
            ("llama", "/mnt/shared-storage-gpfs2/p1-shared-2/wangfuting/OctoThinker-3B-Long-Base"),
            ("polaris", "/mnt/shared-storage-user/p1-shared/Qwen/Qwen3-4B"),
            ("sft", "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/LlamaFactory/trainer_output/checkpoint-162"),
        )
        tokenizer_path = next(
            (path for keyword, path in keyword_paths if keyword in input_dir),
            "/mnt/shared-storage-user/p1-shared/Qwen/Qwen3-4B-Base")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        print(f"成功加载tokenizer: {tokenizer_path}")
    except Exception as e:
        print(f"无法加载Qwen tokenizer: {e}")
        print("使用默认tokenizer: gpt2")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # --- File discovery --------------------------------------------------
    # Sorted so that the processing order is reproducible.
    jsonl_files = sorted(glob.glob(os.path.join(input_dir, "*.jsonl")))

    if not jsonl_files:
        print(f"在目录 {input_dir} 中未找到jsonl文件")
        return

    print(f"找到 {len(jsonl_files)} 个jsonl文件")
    print(f"输出目录: {output_dir}")
    print(f"最大token数: {max_tokens}")
    print("=" * 60)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # --- Status report ---------------------------------------------------
    print("\n检查已处理的文件...")
    completed_files = []
    incomplete_files = []
    for source_path in jsonl_files:
        name = os.path.basename(source_path)
        done = check_file_complete(
            source_path, os.path.join(output_dir, name), parquet_df, verbose=False)
        (completed_files if done else incomplete_files).append(name)

    print(f"已完成文件: {len(completed_files)} 个")
    if completed_files:
        print(f"  已完成的文件: {', '.join(completed_files[:5])}" + overflow_note(completed_files))
    print(f"待处理文件: {len(incomplete_files)} 个")
    if incomplete_files:
        print(f"  待处理的文件: {', '.join(incomplete_files[:5])}" + overflow_note(incomplete_files))
    print("=" * 60)

    # --- Processing ------------------------------------------------------
    # Every file is handed over; completed ones are skipped by the resume
    # logic of process_single_file.
    start_time = time.time()
    for source_path in jsonl_files:
        name = os.path.basename(source_path)

        print(f"\n处理文件: {name}")
        process_single_file(source_path, os.path.join(output_dir, name), tokenizer,
                            parquet_df, max_tokens, batch_size, resume=True)

    elapsed = time.time() - start_time
    print(f"\n所有文件处理完成！总耗时: {elapsed:.2f} 秒")


if __name__ == "__main__":
    # Input directories used by earlier runs, kept for reference:
    #   /mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/skip-right-skip-limits10-gspo-dapo-math-wo-repetition-redo/valid
    #   /mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/llama-baseline-gspo-deepmath
    #   /mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/llama-add1k-gspo-deepmath/valid
    #   /mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/qwen3-4b-polaris-add1k-gspo/valid
    #   /mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/qwen3-4b-polaris-baseline-gspo/valid
    #   /mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/ours-gspo-dapo-math-add600/valid
    input_dirs = [
        "/mnt/shared-storage-gpfs2/p1-shared-2/wangfuting/LIE/models/verl-qwen3-4b-oct/skip-right-skip-limits10-gspo-dapo-math-redo2/valid",
    ]

    for input_dir in input_dirs:
        # Results go to a "valid_8k" folder next to the input directory.
        output_dir = os.path.join(os.path.dirname(input_dir), "valid_8k")

        # The data root comes from DATA_DIR when set, else a default path.
        data_dir = os.environ.get(
            'DATA_DIR', '/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data')
        parquet_file = os.path.join(
            data_dir, "luffy/valid-polaris-qwen3.parquet")

        # Processing parameters.
        max_tokens = 8192
        batch_size = 64

        print("开始处理文件...")
        print(f"输入目录: {input_dir}")
        print(f"输出目录: {output_dir}")
        print(f"Parquet文件: {parquet_file}")
        print(f"最大token数: {max_tokens}")
        print("=" * 60)

        process_directory(input_dir, output_dir,
                          parquet_file, max_tokens, batch_size)

        print("\n处理完成！")
