#!/usr/bin/env python3
"""
Convert results from raw_predictions to predictions format
"""

import json
import os
import glob
from pathlib import Path

def extract_answer_from_solution(solution: str) -> str:
    r"""Return the contents of the last ``\boxed{...}`` found in *solution*.

    Braces are matched by counting depth, so arbitrarily nested LaTeX such
    as ``\boxed{\frac{1}{\sqrt{2}}}`` is extracted intact.

    Args:
        solution: Full solution text, typically containing LaTeX.

    Returns:
        The whitespace-stripped contents of the last balanced
        ``\boxed{...}``, or *solution* unchanged when no balanced
        ``\boxed{...}`` is present.
    """
    marker = "\\boxed{"
    last_answer = None
    search_from = 0

    while True:
        start = solution.find(marker, search_from)
        if start == -1:
            break

        body_start = start + len(marker)
        depth = 1
        pos = body_start
        # Walk forward until the brace opened by the marker is closed.
        while pos < len(solution) and depth:
            char = solution[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
            pos += 1

        if depth == 0:
            last_answer = solution[body_start:pos - 1]
            # Resume after this match so matches never overlap.
            search_from = pos
        else:
            # Unbalanced braces: skip this marker and try any later one.
            search_from = start + 1

    if last_answer is not None:
        return last_answer.strip()

    # No \boxed found: fall back to the entire solution.
    return solution

def normalize_math_answer(answer: str) -> str:
    """Normalize a mathematical answer so equivalent spellings compare equal.

    The following rewrites are applied, in order:

    1. Unicode math symbols are replaced by ASCII spellings
       (``π`` -> ``pi``, ``√`` -> ``sqrt``, ``×`` -> ``*``, ``÷`` -> ``/``,
       ``±`` -> ``+/-``).
    2. The LaTeX commands ``\\pi`` and ``\\sqrt`` lose their backslash.
    3. A ``*`` is inserted between a run of digits and a following ASCII
       letter (``7pi`` -> ``7*pi``).

    Args:
        answer: The answer to normalize. Values that are not strings (for
            example a number decoded from JSON) are converted with ``str``
            first instead of raising ``AttributeError``.

    Returns:
        The normalized text. ``None`` and the empty string are returned
        unchanged.
    """
    import re

    if answer is None:
        return answer
    if not isinstance(answer, str):
        answer = str(answer)
    if not answer:
        return answer

    # Single-pass replacement of common Unicode mathematical symbols.
    unicode_symbols = str.maketrans({
        'π': 'pi',
        '√': 'sqrt',
        '×': '*',
        '÷': '/',
        '±': '+/-',
    })
    normalized = answer.translate(unicode_symbols)

    # LaTeX commands: drop the leading backslash.
    normalized = normalized.replace('\\pi', 'pi')
    normalized = normalized.replace('\\sqrt', 'sqrt')

    # Make implicit multiplication explicit: digits followed by a letter.
    normalized = re.sub(r'(\d+)([a-zA-Z])', r'\1*\2', normalized)

    return normalized

def handle_corner_cases(content: str, task_id: str) -> "dict | None":
    """Recover a numeric answer from GSM8K output that is not valid JSON.

    Three heuristics are tried in order:

    1. The whole (stripped) content is digits only: use it as the answer.
    2. The content contains ``=`` and a digit: use the last number that
       directly follows an equals sign (an optional ``$`` is skipped).
    3. Otherwise use the last standalone number anywhere in the text.

    Numbers may carry a leading minus sign, thousands separators and a
    decimal part; thousands separators are removed from the returned
    answer (``"5,600"`` -> ``"5600"``).

    Args:
        content: Raw text of the result file.
        task_id: Task identifier, used only in the log messages.

    Returns:
        ``{"answer": <number as text>}`` or ``None`` when no number could
        be extracted.
    """
    import re

    # Digits with optional thousands separators and optional decimals.
    number = r'(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?'

    content = content.strip()

    # Corner case 1: Pure numeric answer (e.g., "5600")
    if content.isdigit():
        print(f"SUCCESS {task_id} Processing pure numeric format: {content}")
        return {"answer": content}

    # Corner case 2: Calculation process format
    # (e.g., "Total = 68 + 73 + 61 + 96 = 298")
    if "=" in content and any(char.isdigit() for char in content):
        matches = re.findall(r'=\s*\$?(-?' + number + r')', content)
        if matches:
            answer = matches[-1].replace(',', '')  # last result wins
            print(f"SUCCESS {task_id} Extracted answer from calculation process: {answer}")
            return {"answer": answer}

    # Corner case 3: Text containing numbers, take the last one.
    # The minus sign only counts when it does not follow a word character
    # or ")", so "10-20" yields "20" rather than "-20"; the word-boundary
    # lookarounds skip digits glued to letters such as "x2" or "5kg".
    numbers = re.findall(r'(?:(?<![\w)])-)?(?<!\w)' + number + r'(?!\w)', content)
    if numbers:
        answer = numbers[-1].replace(',', '')
        print(f"SUCCESS {task_id} Extracted numeric answer from text: {answer}")
        return {"answer": answer}

    print(f"ERROR {task_id} Cannot handle corner case")
    return None

def fix_json_format(content: str) -> str:
    """Best-effort repair of malformed JSON text, aimed at code-generation results.

    Repairs are attempted in order and the first text that parses with
    the strict ``json.loads`` is returned:

    1. The stripped input as-is.
    2. A lenient parse (``strict=False``, which accepts raw control
       characters inside strings), re-serialized as strict JSON.
    3. The input with trailing ``]``/``}`` characters removed and a single
       closing ``}`` ensured, parsed strictly and then leniently.
    4. Escaping backslashes, quotes and control characters inside each
       ``"answer": "..."`` value matched by a regex.
    5. Treating everything between ``"answer": "`` and the last ``"`` in
       the text as the answer value and escaping it.
    6. Rebuilding ``{"status": ..., "answer": ...}`` from regex matches.
    7. Escaping every newline / carriage return in the text up to the
       last ``}``.

    Args:
        content: Raw text of a result file.

    Returns:
        JSON text that parses strictly when a repair succeeded; otherwise
        the trimmed text from step 3, which the caller should expect to
        fail parsing.
    """
    import re

    def parses(text):
        """Return True when *text* is valid strict JSON."""
        try:
            json.loads(text)
        except (ValueError, RecursionError):
            return False
        return True

    def lenient_reserialize(text):
        """Parse *text* allowing raw control characters; return strict JSON or None."""
        try:
            parsed = json.loads(text, strict=False)
        except (ValueError, RecursionError):
            return None
        return json.dumps(parsed, ensure_ascii=False)

    content = content.strip()

    # Step 1: already valid.
    if parses(content):
        return content

    # Step 2: only raw control characters inside strings are wrong. This
    # runs before any trimming because trimming would break nested objects.
    reserialized = lenient_reserialize(content)
    if reserialized is not None:
        return reserialized

    # Step 3: remove surplus closing brackets/braces at the end and make
    # sure the text ends with exactly one "}".
    content = re.sub(r'[\]\}]+$', '', content)
    if not content.rstrip().endswith('}'):
        content = content.rstrip() + '}'
    if parses(content):
        return content
    reserialized = lenient_reserialize(content)
    if reserialized is not None:
        return reserialized

    # Step 4: escape the content of each "answer": "..." field.
    answer_pattern = r'"answer":\s*"([^"]*(?:\\.[^"]*)*)"'

    def fix_answer_content(match):
        """Escape one matched answer value so it becomes a valid JSON string."""
        answer_content = match.group(1)
        # Backslashes must be doubled first; otherwise the backslashes
        # introduced by the later escapes would be doubled as well.
        fixed = answer_content.replace('\\', '\\\\').replace('"', '\\"')
        fixed = fixed.replace('\n', '\\n').replace('\r', '\\r').replace('\t', '\\t')
        return f'"answer": "{fixed}"'

    content_fixed = re.sub(answer_pattern, fix_answer_content, content)
    if parses(content_fixed):
        return content_fixed

    # Step 5: more aggressive - take everything from the opening quote of
    # the answer value up to the last quote in the text as the value.
    answer_key = '"answer": "'
    key_start = content.find(answer_key)
    if key_start != -1:
        answer_start = key_start + len(answer_key)
        answer_end = content.rfind('"')

        if answer_end > answer_start:
            answer_content = content[answer_start:answer_end]
            fixed_answer = (
                answer_content.replace('\\', '\\\\')
                .replace('"', '\\"')
                .replace('\n', '\\n')
                .replace('\r', '\\r')
                .replace('\t', '\\t')
            )
            new_content = content[:answer_start] + fixed_answer + content[answer_end:]
            if parses(new_content):
                return new_content

    # Step 6: rebuild the object from the status and answer fields.
    # json.dumps performs all required escaping, so the captured text is
    # passed through untouched.
    status_match = re.search(r'"status":\s*"([^"]*)"', content)
    answer_match = re.search(answer_pattern, content)
    if status_match and answer_match:
        return json.dumps({
            "status": status_match.group(1),
            "answer": answer_match.group(1),
        })

    # Step 7: escape every newline up to the last closing brace. This only
    # helps when all newlines sit inside string values.
    last_brace_pos = content.rfind('}')
    if last_brace_pos != -1:
        potential_json = content[:last_brace_pos + 1]
        potential_json_escaped = potential_json.replace('\n', '\\n').replace('\r', '\\r')
        if parses(potential_json_escaped):
            return potential_json_escaped

    # Nothing worked: hand back the trimmed text.
    return content

def extract_code_from_py_file(py_file_path: str) -> str:
    """Return the full text of a Python source file.

    The file is read as UTF-8. Any failure is reported on stdout and an
    empty string is returned instead of raising.
    """
    try:
        with open(py_file_path, mode='r', encoding='utf-8') as source:
            return source.read()
    except Exception as read_error:
        print(f"ERROR Failed to read Python file {py_file_path}: {read_error}")
    return ""

def _find_task_code(task_dir):
    """Locate the generated Python source inside one task directory.

    Locations are tried in order: the first ``*.py`` file (alphabetically)
    in ``coderun/``, the file ``code_run.py``, then the first ``*.py`` file
    in ``code_run/``. A location whose file is unreadable or empty is
    skipped.

    Returns:
        ``(code, source_label)``, or ``(None, "")`` when nothing was found.
    """
    lookups = (
        ("coderun", True),
        ("code_run.py", False),
        ("code_run", True),
    )
    for name, is_directory in lookups:
        location = task_dir / name
        if not location.exists():
            continue
        if is_directory:
            # Sorted so the chosen file does not depend on filesystem order.
            py_files = sorted(location.glob("*.py"))
            if not py_files:
                continue
            py_file = py_files[0]
            label = f"{name}/{py_file.name}"
        else:
            py_file = location
            label = name
        code = extract_code_from_py_file(str(py_file))
        if code:
            return code, label
    return None, ""


def _load_task_result(result_file, task_id):
    """Parse one ``result.json``, repairing it when it is malformed.

    The file is read once. On a JSON error the text is passed through
    ``fix_json_format`` and, if that still fails, ``handle_corner_cases``.

    Returns:
        The parsed dictionary, or ``None`` when the content could not be
        recovered or is not a JSON object.
    """
    with open(result_file, 'r', encoding='utf-8') as f:
        content = f.read()

    try:
        result = json.loads(content)
    except json.JSONDecodeError as e:
        print(f"警告: {task_id} 的 result.json 格式错误，尝试修复: {e}")
        try:
            result = json.loads(fix_json_format(content))
            print(f"SUCCESS {task_id} JSON format fix successful")
        except Exception as fix_error:
            # Best-effort boundary: any repair failure falls back to the
            # plain-text heuristics rather than aborting the whole run.
            print(f"ERROR {task_id} JSON format fix failed: {fix_error}")
            result = handle_corner_cases(content, task_id)
            if result is None:
                return None

    if not isinstance(result, dict):
        print(f"ERROR {task_id} result is not dictionary format, skipping")
        return None
    return result


def convert_raw_to_predictions(dataset_name="math", raw_predictions_dir="raw_predictions", extract_py: bool = False):
    """Convert per-task raw results into a predictions JSONL file.

    Each task directory under *raw_predictions_dir* is matched against the
    dataset's test file by id. For every match one JSON line is written
    containing the problem text, the prediction, the expected answer, a
    default cost of ``0.0``, the id and any configured extra fields.

    Args:
        dataset_name: One of ``math``, ``gsm8k``, ``humaneval``, ``mbpp``,
            ``hotpotqa`` or ``drop``.
        raw_predictions_dir: Directory containing one sub-directory per task.
        extract_py: When true, take the prediction from a generated Python
            file in the task directory instead of from ``result.json``.

    Returns:
        The number of predictions written, or ``0`` when the dataset is
        unsupported or the test file / raw directory does not exist.
    """
    # Per-dataset configuration.
    dataset_configs = {
        "math": {
            "test_file": "data/datasets/math_test.jsonl",
            "output_file": "predictions/math_predictions.jsonl",
            "task_pattern": "math-*",
            "fields": {
                "problem": "problem",
                "expected": "solution"
            },
            "answer_extractor": extract_answer_from_solution,
            "normalize_answers": True
        },
        "gsm8k": {
            "test_file": "data/datasets/gsm8k_test.jsonl",
            "output_file": "predictions/gsm8k_predictions.jsonl",
            "task_pattern": "gsm8k-*",
            "fields": {
                "problem": "question",
                "expected": "answer"
            },
            "answer_extractor": None,  # use the answer field directly
            "normalize_answers": False
        },
        "humaneval": {
            "test_file": "data/datasets/humaneval_test.jsonl",
            "output_file": "predictions/humaneval_predictions.jsonl",
            "task_pattern": "HumanEval_*",
            "fields": {
                "problem": "prompt",
                "expected": "canonical_solution"
            },
            "answer_extractor": None,  # use the canonical_solution field directly
            "normalize_answers": False,
            "id_field": "task_id",  # keyed by task_id rather than id
            "extra_fields": ["entry_point", "test"]  # copied through for code-execution tests
        },
        "mbpp": {
            "test_file": "data/datasets/mbpp_test.jsonl",
            "output_file": "predictions/mbpp_predictions.jsonl",
            "task_pattern": "mbpp-*",  # kept for reference; directories are selected by numeric name below
            "fields": {
                "problem": "prompt",
                "expected": "code"
            },
            "answer_extractor": None,  # use the code field directly
            "normalize_answers": False,
            "id_field": "task_id",  # keyed by task_id rather than id
            "extra_fields": ["entry_point", "test"]  # copied through for code-execution tests
        },
        "hotpotqa": {
            "test_file": "data/datasets/hotpotqa_test.jsonl",
            "output_file": "predictions/hotpotqa_predictions.jsonl",
            "task_pattern": "hotpotqa-*",
            "fields": {
                "problem": "question",
                "expected": "answer"
            },
            "answer_extractor": None,  # use the answer field directly
            "normalize_answers": False
        },
        "drop": {
            "test_file": "data/datasets/drop_test.jsonl",
            "output_file": "predictions/drop_predictions.jsonl",
            "task_pattern": "drop-*",
            "fields": {
                "problem": "context",
                "expected": "ref_text"
            },
            "answer_extractor": None,  # use the ref_text field directly
            "normalize_answers": False
        }
    }

    if dataset_name not in dataset_configs:
        print(f"ERROR Unsupported dataset: {dataset_name}")
        print(f"支持的数据集: {', '.join(dataset_configs.keys())}")
        return 0

    config = dataset_configs[dataset_name]

    # Load the original test data, keyed by the dataset's id field.
    test_file = Path(config["test_file"])
    if not test_file.exists():
        print(f"ERROR Test file does not exist: {test_file}")
        return 0

    id_field = config.get('id_field', 'id')
    test_data = {}
    print(f"读取 {dataset_name.upper()} 测试数据...")
    with open(test_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            data = json.loads(line)
            test_data[data[id_field]] = data

    # Locate the raw results.
    raw_predictions_path = Path(raw_predictions_dir)
    if not raw_predictions_path.exists():
        print(f"ERROR raw_predictions directory does not exist: {raw_predictions_path}")
        return 0

    predictions = []
    print(f"读取 {raw_predictions_dir} 结果...")

    # Task directories are sorted so the output order is reproducible.
    if dataset_name == "mbpp":
        # MBPP directories are named by their numeric task id. isdecimal()
        # only accepts names that int() can convert.
        task_dirs = sorted(
            (d for d in raw_predictions_path.iterdir() if d.is_dir() and d.name.isdecimal()),
            key=lambda d: int(d.name),
        )
    else:
        task_dirs = sorted(raw_predictions_path.glob(config["task_pattern"]))

    expected_field = config["fields"]["expected"]
    problem_field = config["fields"]["problem"]

    for task_dir in task_dirs:
        raw_task_id = task_dir.name

        # Map the directory name onto the id used in the test data.
        if dataset_name == "humaneval":
            # HumanEval_84 -> HumanEval/84
            if raw_task_id.startswith("HumanEval_"):
                task_id = raw_task_id.replace("HumanEval_", "HumanEval/")
            else:
                task_id = raw_task_id
        elif dataset_name == "mbpp":
            task_id = int(raw_task_id)
        else:
            task_id = raw_task_id

        if extract_py:
            code_content, source_info = _find_task_code(task_dir)
            if not code_content:
                print(f"ERROR {task_id} No Python code file found (coderun/, code_run.py, code_run/)")
                continue
            result = {"answer": code_content}
            print(f"SUCCESS {task_id} Successfully extracted code from Python file: {source_info}")
        else:
            result_file = task_dir / "result.json"
            if not result_file.exists():
                print(f"警告: 找不到 {task_id} 的结果文件")
                continue
            result = _load_task_result(result_file, task_id)
            if result is None:
                continue

        if task_id not in test_data:
            print(f"警告: 找不到 {task_id} 的原始数据")
            continue
        original_data = test_data[task_id]

        # Expected answer, optionally run through the dataset's extractor.
        expected_answer = original_data.get(expected_field, "")
        if config["answer_extractor"]:
            expected_answer = config["answer_extractor"](expected_answer)

        predicted_answer = result.get("answer", "")

        if config["normalize_answers"]:
            # A JSON result may carry the answer as a number; normalization
            # works on text.
            if predicted_answer is not None and not isinstance(predicted_answer, str):
                predicted_answer = str(predicted_answer)
            expected_answer = normalize_math_answer(expected_answer)
            predicted_answer = normalize_math_answer(predicted_answer)

        prediction_data = {
            problem_field: original_data[problem_field],
            "prediction": predicted_answer,
            "expected": expected_answer,
            "cost": 0.0,  # default cost
            "id": task_id
        }

        # Copy through extra fields such as entry_point and test.
        for field in config.get("extra_fields") or []:
            if field in original_data:
                prediction_data[field] = original_data[field]

        predictions.append(prediction_data)
        print(f"转换 {task_id}: {predicted_answer}")

    # Write the predictions file, one JSON object per line.
    output_file = Path(config["output_file"])
    output_file.parent.mkdir(parents=True, exist_ok=True)

    print(f"\n写入 {len(predictions)} 个预测结果到 {output_file}")

    with open(output_file, 'w', encoding='utf-8') as f:
        for pred in predictions:
            f.write(json.dumps(pred, ensure_ascii=False) + '\n')

    print(f"SUCCESS Conversion completed! Total {len(predictions)} prediction results")
    return len(predictions)

def convert_all_datasets(base_raw_dir="raw_predictions"):
    """Convert every dataset in the built-in batch list.

    MATH results are read from *base_raw_dir* itself; every other dataset
    is read from ``<base_raw_dir>_<dataset>``.

    Returns:
        The sum of the counts reported by the individual conversions.
    """
    banner = "=" * 60
    print(banner)
    print("批量转换所有数据集")
    print(banner)

    grand_total = 0
    # Only the datasets that currently exist are converted.
    for name in ("math", "gsm8k"):
        label = name.upper()
        print(f"\n🔄 转换 {label} 数据集...")

        if name == "math":
            source_dir = base_raw_dir
        else:
            source_dir = f"{base_raw_dir}_{name}"

        converted = convert_raw_to_predictions(name, source_dir)
        if converted <= 0:
            print(f"WARNING {label}: No prediction results found or conversion failed")
            continue

        print(f"SUCCESS {label}: Converted {converted} prediction results")
        grand_total += converted

    print(f"\nSUMMARY Total converted {grand_total} prediction results")
    return grand_total

def build_arg_parser():
    """Build the command-line parser for the conversion script.

    Options: ``--dataset/-d`` (default ``math``), ``--raw-dir/-r``
    (default ``raw_predictions``), ``--py/-p`` and ``--all`` (both flags).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Raw Predictions Conversion Script",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Usage Examples:
  # Convert MATH dataset (default)
  python convert_raw_predictions.py

  # Convert other datasets
  python convert_raw_predictions.py --dataset gsm8k
  python convert_raw_predictions.py --dataset humaneval
  python convert_raw_predictions.py --dataset mbpp
  python convert_raw_predictions.py --dataset hotpotqa
  python convert_raw_predictions.py --dataset drop

  # Use custom raw_predictions directory
  python convert_raw_predictions.py --dataset gsm8k --raw-dir raw_predictions_gsm8k
  python convert_raw_predictions.py --dataset humaneval --raw-dir raw_predictions_humaneval

  # Extract code directly from Python code files (supports coderun/, code_run.py, code_run/)
  python convert_raw_predictions.py --dataset humaneval --py
  python convert_raw_predictions.py --dataset mbpp --raw-dir tasks --py
        """
    )

    parser.add_argument(
        '--dataset', '-d',
        type=str,
        default='math',
        help='Dataset name to convert (default: math)'
    )

    parser.add_argument(
        '--raw-dir', '-r',
        type=str,
        default='raw_predictions',
        help='raw_predictions directory path (default: raw_predictions)'
    )

    parser.add_argument(
        '--py', '-p',
        action='store_true',
        help='Extract code directly from Python code files instead of reading from result.json. Search order: coderun/ folder -> code_run.py file -> code_run/ folder'
    )

    parser.add_argument(
        '--all',
        action='store_true',
        help='Batch convert all datasets'
    )

    return parser


def main(argv=None):
    """Run the conversion described by the command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    With ``--all`` the batch conversion runs and the function returns;
    otherwise the single selected dataset is converted and a summary is
    printed. The function returns normally in both cases, so the process
    exit status stays 0 even when nothing was converted.
    """
    args = build_arg_parser().parse_args(argv)
    banner = "=" * 60

    # Batch conversion of all datasets.
    if args.all:
        convert_all_datasets(args.raw_dir)
        print(banner)
        print("SUCCESS Batch conversion completed!")
        print(banner)
        return

    print(banner)
    print("Raw Predictions Conversion Script")
    print(banner)

    # Convert the selected dataset.
    count = convert_raw_to_predictions(args.dataset, args.raw_dir, args.py)

    print("\n" + banner)
    if count > 0:
        print(f"SUCCESS Conversion completed! Total converted {count} prediction results")
    else:
        print("ERROR Conversion failed or no prediction results found")
    print(banner)


if __name__ == "__main__":
    main()
