import openai
import pandas as pd
import os
import json
import tiktoken
import random
from dotenv import load_dotenv
from tqdm import tqdm
import re
import logging
from datetime import datetime
import numpy as np
import ast
import argparse
import sys

def parse_args():
    """Read the command-line options for a synonym-replacement run.

    Returns:
        argparse.Namespace with ``input_path``, ``output_dir``, ``log_dir``,
        ``openai_api_key`` and ``openai_api_base``.
    """
    cli = argparse.ArgumentParser(description="Synonym replacement for process_evaluation in a DataFrame.")
    option_specs = [
        ('--input_path', dict(type=str, required=True, help='Input parquet file path')),
        ('--output_dir', dict(type=str, required=True, help='Output directory for the result parquet')),
        ('--log_dir', dict(type=str, default="log", help='Directory to save logs')),
        ('--openai_api_key', dict(type=str, default=None, help='OpenAI API key (optional, else use env)')),
        ('--openai_api_base', dict(type=str, default=None, help='OpenAI API base url (optional, else use env)')),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

# Load environment variables from a .env file, so settings such as
# OPENAI_API_KEY / OPENAI_API_BASE (read later via os.getenv) can live there.
load_dotenv()

def setup_logging(log_dir):
    """Send root-logger output to a timestamped file in ``log_dir`` and to the console.

    The directory is created if needed. Note that ``logging.basicConfig`` does
    nothing if the root logger already has handlers.

    Args:
        log_dir: Directory that receives the log file.

    Returns:
        Path of the log file, ``<log_dir>/synonym_<YYYYmmdd_HHMMSS>.log``.
    """
    os.makedirs(log_dir, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(log_dir, f'synonym_{stamp}.log')

    file_handler = logging.FileHandler(log_path, encoding='utf-8')
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(message)s',
        handlers=[file_handler, console_handler],
    )
    return log_path

# Tokenizer used for every token count in this script; fixed to the cl100k_base encoding.
encoding = tiktoken.get_encoding("cl100k_base")

def count_gpt_tokens(text):
    """Return the number of cl100k_base tokens in ``text``.

    ``disallowed_special=()`` makes tiktoken encode special-token strings such
    as ``<|endoftext|>`` as ordinary text. By default it raises ``ValueError``
    on them, which would make an arbitrary input or model rewrite impossible
    to measure.
    """
    return len(encoding.encode(text, disallowed_special=()))

def escape_illegal_backslashes(json_str):
    r"""Double every backslash that does not begin a valid JSON escape sequence.

    Valid escapes (\", \\, \/, \b, \f, \n, \r, \t and \uXXXX with exactly four
    hex digits) are kept. Any other backslash becomes "\\", so ``json.loads``
    accepts it and decodes it back to one literal backslash. Typical examples
    are unescaped LaTeX such as \( or \alpha, or a truncated \u12.

    Escape sequences are matched as whole units, left to right. This means the
    second backslash of an already-escaped \\ pair is never mistaken for the
    start of a new escape.

    Note: LaTeX commands that start with a valid escape letter (e.g. \frac,
    \beta, \times, \nabla) are left untouched and will decode as control
    characters.
    """
    def _fix(match):
        # group(1) is set only when the backslash starts a legal escape; keep those as-is.
        return match.group(0) if match.group(1) else '\\\\'

    return re.sub(r'\\(["\\/bfnrt]|u[0-9a-fA-F]{4})?', _fix, json_str)

def sanitize_for_ast_literal_eval(s: str) -> str:
    """Heuristically rewrite a dict-like string so ``ast.literal_eval`` is more likely to accept it.

    Aimed at model output that contains unescaped LaTeX. Each step is a
    substitution applied in order, and later steps see the output of earlier
    ones. The rewrite is lossy: image tags are replaced and control characters
    are removed. It is meant only as a parsing fallback.

    Args:
        s: Candidate literal text (e.g. the ``{...}`` span of a model reply).

    Returns:
        The rewritten text; it may still fail to parse.
    """
    # Put an extra backslash in front of every "\u" (any following hex digits are kept).
    # This stops literal_eval from reading e.g. LaTeX "\unit" or a truncated "\u12"
    # as a unicode escape.
    s = re.sub(r'\\u([0-9a-fA-F]{0,4})', r'\\\\u\1', s)
    # Double the backslash of the LaTeX delimiters \( \) \[ \].
    s = s.replace(r'\(', r'\\(').replace(r'\)', r'\\)')
    s = s.replace(r'\[', r'\\[').replace(r'\]', r'\\]')
    # Double the backslash before these specific LaTeX command names.
    s = re.sub(r'\\(unit|degree|angle|vec|mathrm|dots|text|overrightarrow|bar)', r'\\\\\1', s)
    # Double any backslash that is not followed by \ ' " a b f n r t.
    # NOTE(review): backslashes added by the steps above are re-scanned here, so e.g.
    # "\(" ends up as three backslashes + "(" — confirm this is intended.
    s = re.sub(r'\\(?![\\\'\"abfnrt])', r'\\\\', s)
    # Replace <image...> tags with the literal text "[IMAGE]" (changes the content).
    s = re.sub(r'<image[^>]*>', '[IMAGE]', s)
    # Drop ASCII control characters, including newlines and tabs.
    s = re.sub(r'[\x00-\x1f\x7f]', '', s)
    s = s.rstrip()
    # Drop one dangling trailing backslash, unless the text ends with two backslashes.
    if s.endswith("\\") and not s.endswith("\\\\"):
        s = s[:-1].rstrip()
    # Remove a trailing pair of backslashes and any whitespace after it.
    s = re.sub(r"(\\\\)\s*$", "", s)
    # Remove a trailing comma.
    s = re.sub(r",\s*$", "", s)
    return s
    
def _json_loads_lenient(text):
    """Run ``json.loads`` while also accepting raw control characters (newlines, tabs) inside strings."""
    return json.loads(text, strict=False)


def _try_parse_dict(parser, text):
    """Apply ``parser`` to ``text`` and return the result only if it is a ``dict``.

    Parse errors and non-dict results are logged at WARNING level and yield ``None``.
    """
    try:
        result = parser(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as exc:
        logging.warning(f"{getattr(parser, '__name__', parser)} 解析失败：{exc}")
        return None
    if not isinstance(result, dict):
        logging.warning(f"解析结果不是 dict：{type(result).__name__}")
        return None
    return result


def clean_and_parse_json(raw_output):
    r"""Extract the ``{...}`` span from a model reply and parse it into a dict.

    The span runs from the first "{" to the last "}". Parsers are tried in
    order, and the first one that yields a ``dict`` wins:

    1. ``json.loads`` on the span (the prompt asks for JSON).
    2. ``ast.literal_eval`` on the span (Python-style literals, e.g. single quotes).
    3. ``json.loads`` after ``escape_illegal_backslashes`` (unescaped LaTeX such as \().
    4. ``ast.literal_eval`` after ``sanitize_for_ast_literal_eval``.
    5. ``json.loads`` on that sanitized text after ``escape_illegal_backslashes``.

    ``json.loads`` runs with ``strict=False``, so raw newlines and tabs inside
    strings are accepted. ``ast.literal_eval`` only evaluates literals, so model
    output cannot execute code.

    Args:
        raw_output: Text returned by the model. Non-string input is treated as unparsable.

    Returns:
        The parsed dict, or ``None`` if no JSON object could be recovered.
    """
    match = re.search(r'\{[\s\S]*\}', raw_output) if isinstance(raw_output, str) else None
    if not match:
        logging.error("未找到合法 JSON 块")
        print("未找到合法 JSON 块")
        return None
    json_str = match.group(0)

    # Steps 1-2: the span as-is, first as JSON, then as a Python literal.
    result = _try_parse_dict(_json_loads_lenient, json_str)
    if result is None:
        result = _try_parse_dict(ast.literal_eval, json_str)
    if result is not None:
        return result

    logging.error("直接解析失败，尝试修复后再解析")
    print("直接解析失败，尝试修复后再解析")

    # Step 3: JSON with non-escape backslashes (typically LaTeX) doubled.
    result = _try_parse_dict(_json_loads_lenient, escape_illegal_backslashes(json_str))
    if result is not None:
        return result

    # Steps 4-5: the Python-literal repair path, then JSON on the repaired text.
    fixed_str = sanitize_for_ast_literal_eval(json_str)
    result = _try_parse_dict(ast.literal_eval, fixed_str)
    if result is None:
        result = _try_parse_dict(_json_loads_lenient, escape_illegal_backslashes(fixed_str))
    if result is None:
        logging.error(f"修复后字符串：{repr(fixed_str)}")
        print(f"修复后字符串：{repr(fixed_str)}")
        logging.error("json.loads 解析也失败")
        print("json.loads 解析也失败")
    return result

def build_messages(input_text) -> list:
    """Assemble the chat payload for one synonym-replacement request.

    The system prompt describes the task, defines technical terms, lists the
    formatting constraints, and specifies the required JSON reply shape: a
    ``"Synonym_replacements"`` list of five rewrites. The user turn carries
    ``input_text`` unchanged.

    Returns:
        ``[system_message, user_message]`` as role/content dicts.
    """
    task_section = (
        "Task Description:\n"
        "You are a synonym replacement assistant. Given an input sentence, your task is to generate five distinct rewrites. "
        "In each version, you must replace at least one non-technical term with an appropriate synonym, and should replace as many non-technical terms as possible. "
        "Use different combinations of synonyms while keeping the original sentence structure and meaning intact. "
        "All outputs must be grammatically correct and sound natural.\n\n"
    )
    definition_section = (
        "Definition:\n"
        "Technical terms refer to specialized vocabulary that is specific to a particular field or discipline and should remain unchanged. "
        "These include, but are not limited to: mathematical symbols, scientific terminology, programming syntax, technical jargon, and domain-specific abbreviations.\n\n"
    )
    constraints_section = (
        "Key Constraints:\n"
        "- Do not modify any structural elements (e.g., markdown tags like \"1. **Title:**\", section headers like \"### Section\").\n"
        "- Do not alter any numbers, numerical values, or mathematical expressions, including both plain numbers (e.g., 1, 100, 2%) and LaTeX formulas (e.g., \\( a \\), \\( m < n^2 - n \\)).\n"
        "- Do not change list symbols, bullet points, or any other sequence markers (e.g., “-” in a list, “1.” as an item number, etc.). These structural symbols should be preserved exactly as they are.\n"
        "- Replace only the natural language content—do not alter formatting, technical terms, or domain-specific vocabulary.\n"
        "- Ensure all rewritten sentences are grammatically correct, natural, and maintain the original meaning.\n"
        "- Each rewritten version must replace at least one non-technical word, and should replace as many non-technical words as reasonably possible.\n\n"
    )
    format_section = (
        "Output Format:\n"
        "Provide your output in the following JSON structure:\n"
        "{\n"
        "  \"Original Sentence\": \"The original sentence\",\n"
        "  \"Synonym_replacements\": [\n"
        "    \"Synonym_replacement 1\",\n"
        "    \"Synonym_replacement 2\",\n"
        "    \"Synonym_replacement 3\",\n"
        "    \"Synonym_replacement 4\",\n"
        "    \"Synonym_replacement 5\"\n"
        "  ]\n"
        "}"
    )
    prompt = task_section + definition_section + constraints_section + format_section
    return [
        {"role": "system", "content": prompt},
        {"role": "user", "content": input_text},
    ]

def is_skippable(text):
    r"""Return True if ``text`` should be stored verbatim instead of being sent for rewriting.

    A line is skipped when it:
    - is empty or whitespace-only;
    - is a lone LaTeX display delimiter \[ or \];
    - contains an image placeholder such as <image_1>;
    - starts with a markdown heading / list marker ("#" to "######", "N." or "-"
      followed by whitespace) and has fewer than 5 words after the marker.
      Longer headings and items are rewritten;
    - contains **bold** markup and has fewer than 5 words once the "**" pairs
      are removed;
    - is shorter than 5 characters.
    """
    stripped = text.strip()

    if not stripped:
        return True

    if stripped in ("\\[", "\\]"):
        return True

    if "<image_" in stripped:
        return True

    # Markdown heading / list marker. Requiring whitespace (or end of line) after the
    # marker keeps decimals such as "1.5" and tokens like "-5" from being treated as
    # markers. Every marker the guard accepts is also the one stripped before counting words.
    marker = re.match(r"(?:#{1,6}|\d+\.|-)(?=\s|$)", stripped)
    if marker:
        content_only = stripped[marker.end():]
        return len(content_only.split()) < 5

    # Short bold-only fragments such as "**Answer:** 42".
    if stripped.count("**") >= 2:
        clean_bold = re.sub(r"\*\*(.*?)\*\*", r"\1", stripped)
        if len(clean_bold.split()) < 5:
            return True

    # Very short fragments (fewer than 5 characters) are not worth rewriting.
    return len(stripped) < 5


def _make_openai_client(openai_api_key=None, openai_api_base=None):
    """Build the OpenAI client from explicit arguments, falling back to env vars.

    Empty values are normalised to ``None`` so the SDK applies its own defaults.
    Passing an empty ``base_url``/``api_key`` instead would make every request fail.
    """
    api_key = openai_api_key or os.getenv("OPENAI_API_KEY") or None
    base_url = openai_api_base or os.getenv("OPENAI_API_BASE") or None
    return openai.OpenAI(api_key=api_key, base_url=base_url)


def _emit(level, message):
    """Log ``message`` at ``level`` and echo it to stdout (the script's log+print convention)."""
    logging.log(level, message)
    print(message)


def _emit_record(level, record):
    """Emit ``record`` as a JSON line; ``default=str`` keeps non-JSON ids (e.g. numpy ints) from raising."""
    _emit(level, json.dumps(record, ensure_ascii=False, default=str))


def _log_step_error(sample_id, step_idx, input_text, error):
    """Emit the structured error record for one step."""
    _emit_record(logging.ERROR, {
        "sample_id": sample_id,
        "step_index": step_idx,
        "original_text": input_text,
        "error": error,
    })


def _usable_rewrites(value):
    """Normalise the model's ``Synonym_replacements`` field to a list of non-empty strings."""
    if isinstance(value, str):
        value = [value]
    elif not isinstance(value, (list, tuple)):
        return []
    return [r for r in value if isinstance(r, str) and r.strip()]


def _select_rewrite(sample_id, step_idx, input_text, rewrites):
    """Pick the rewrite whose token count best matches ``input_text`` and log the choice.

    Preference order: a random candidate with the same token count; otherwise a
    random candidate within 2 tokens; otherwise the candidate with the smallest
    absolute difference (the first one on ties).
    """
    original_tokens = count_gpt_tokens(input_text)
    scored = [(r, abs(count_gpt_tokens(r) - original_tokens)) for r in rewrites]
    exact_matches = [r for r, diff in scored if diff == 0]
    close_matches = [r for r, diff in scored if 0 < diff <= 2]

    if exact_matches:
        selected, message = random.choice(exact_matches), "完全一致的tokens"
    elif close_matches:
        selected, message = random.choice(close_matches), "tokens差别在2以内"
    else:
        selected, message = min(scored, key=lambda pair: pair[1])[0], None

    selected_tokens = count_gpt_tokens(selected)
    if message is not None:
        log_message = {
            "sample_id": sample_id,
            "step_index": step_idx,
            "message": message,
            "original_text": input_text,
            "selected_rewrite": selected,
            "original_token_count": original_tokens,
            "selected_token_count": selected_tokens,
        }
    else:
        token_difference = abs(selected_tokens - original_tokens)
        log_message = {
            "sample_id": sample_id,
            "step_index": step_idx,
            "message": f"tokens差别在2以上，原始token数: {original_tokens}，改写token数: {selected_tokens}，差值: {token_difference}",
            "original_text": input_text,
            "selected_rewrite": selected,
            "token_difference": token_difference,
            "original_token_count": original_tokens,
            "selected_token_count": selected_tokens,
        }
    _emit_record(logging.INFO, log_message)
    return selected


def _rewrite_step(client, sample_id, step_idx, input_text):
    """Ask the model for synonym rewrites of one step and return the text to store.

    Returns:
        The selected rewrite. Returns ``input_text`` unchanged when the reply
        cannot be parsed into a JSON object. Returns ``None`` when the API
        returns no choices, gives no usable rewrites, or raises.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=build_messages(input_text),
            max_tokens=8000,
            temperature=0,
            n=1,
        )

        choices = getattr(response, "choices", None) or []
        if not choices:
            msg = "API 没有返回 choices"
            _emit(logging.ERROR, msg)
            _log_step_error(sample_id, step_idx, input_text, msg)
            return None

        # message.content may be None; treat that as an empty (unparsable) reply.
        raw_output = (choices[0].message.content or "").strip()
        parsed = clean_and_parse_json(raw_output)
        if not isinstance(parsed, dict):
            _emit(logging.ERROR, f"解析失败内容：{raw_output}")
            # On a parse failure, keep the original sentence.
            _log_step_error(sample_id, step_idx, input_text, "解析JSON失败，已保留原句")
            return input_text

        rewrites = _usable_rewrites(parsed.get("Synonym_replacements", []))
        if not rewrites:
            msg = f"No rewrites returned for: {input_text}"
            _emit(logging.ERROR, msg)
            _log_step_error(sample_id, step_idx, input_text, msg)
            return None

        return _select_rewrite(sample_id, step_idx, input_text, rewrites)
    except Exception as e:
        # Best effort: one failing step is logged and must not abort the whole run.
        _emit(logging.ERROR, f"样本ID {sample_id} 第{step_idx}个步骤处理出错: {e}")
        _log_step_error(sample_id, step_idx, input_text, str(e))
        return None


def call_gpt_api(df, openai_api_key=None, openai_api_base=None):
    """Produce a synonym rewrite for every step of every row and store them in ``df``.

    For each row, ``row['process_evaluation']`` is iterated, and a step's text
    is ``str(step[0])``. Steps accepted by ``is_skippable`` are stored
    verbatim; all others are rewritten through the OpenAI chat API.

    The per-row lists are written to ``df['student_solution_synonym']``,
    overwriting any existing column. Each list has exactly one entry per step,
    so its indices line up with ``process_evaluation``. ``None`` marks steps
    that were invalid or whose rewrite failed.

    Args:
        df: DataFrame with ``id`` and ``process_evaluation`` columns.
        openai_api_key: API key; falls back to ``OPENAI_API_KEY``.
        openai_api_base: Base URL; falls back to ``OPENAI_API_BASE``, then the SDK default.

    Returns:
        The same ``df`` object, modified in place.
    """
    client = _make_openai_client(openai_api_key, openai_api_base)

    filtered_rewrites = []
    total_steps = 0
    for _, row in tqdm(df.iterrows(), total=len(df), desc="处理进度"):
        sample_id = row['id']
        steps = row['process_evaluation']
        if not isinstance(steps, (np.ndarray, list, tuple)):
            logging.warning(f"样本ID {sample_id} 的 process_evaluation 不是有效数组，按空处理")
            steps = []

        best_rewrites = []
        for step_idx, item in enumerate(steps):
            total_steps += 1
            if not isinstance(item, (np.ndarray, list, tuple)) or len(item) == 0:
                logging.warning(f"样本ID {sample_id} 第{step_idx}步不是有效数组，跳过")
                # Placeholder keeps output indices aligned with process_evaluation.
                best_rewrites.append(None)
                continue

            input_text = str(item[0])
            if is_skippable(input_text):
                # Store the text verbatim without calling the API.
                best_rewrites.append(input_text)
                continue

            best_rewrites.append(_rewrite_step(client, sample_id, step_idx, input_text))

        filtered_rewrites.append(best_rewrites)

    # Assigning replaces any existing column with the same name.
    df['student_solution_synonym'] = filtered_rewrites

    # Summary statistics.
    success_count = sum(1 for sample in filtered_rewrites for r in sample if r is not None)
    logging.info(f"✅ 总处理步骤数: {total_steps}, 成功生成同义改写数: {success_count}")

    return df

def main():
    """Command-line entry point: rewrite the input parquet and save a ``*_after_synonym`` copy.

    The output file keeps the input's base name and extension, with
    ``_after_synonym`` appended to the name, and is written into
    ``--output_dir`` (created if missing). Exits with status 1 if
    ``--input_path`` does not exist.
    """
    args = parse_args()
    log_file = setup_logging(args.log_dir)

    if not os.path.exists(args.input_path):
        print(f"输入文件不存在: {args.input_path}")
        sys.exit(1)
    os.makedirs(args.output_dir, exist_ok=True)

    result_df = call_gpt_api(
        pd.read_parquet(args.input_path),
        openai_api_key=args.openai_api_key,
        openai_api_base=args.openai_api_base,
    )

    stem, suffix = os.path.splitext(os.path.basename(args.input_path))
    destination = os.path.join(args.output_dir, f"{stem}_after_synonym{suffix}")
    result_df.to_parquet(destination)
    print(f"✅ 已保存到: {destination}")
    print(f"日志文件: {log_file}")

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
