import openai
import pandas as pd
import os
import json
import tiktoken
import random
from dotenv import load_dotenv
from tqdm import tqdm
import re
import logging
import ast
import numpy as np
import argparse
import sys

# Load environment variables.
load_dotenv()

# Always use the cl100k_base encoder for token counting.
encoding = tiktoken.get_encoding("cl100k_base")

def count_gpt_tokens(text):
    """Return how many tokens the module-level ``encoding`` produces for *text*."""
    token_ids = encoding.encode(text)
    return len(token_ids)

def escape_illegal_backslashes(json_str):
    """Double every backslash in *json_str* that does not start a valid JSON escape.

    Valid escapes are a backslash followed by one of ``" \\ / b f n r t`` or
    by ``u`` plus exactly four hex digits.  The scan consumes each valid
    escape as a unit, so an already-escaped backslash pair is left alone even
    when the character after it is not an escape letter (a plain lookahead
    would double the second backslash of the pair and corrupt valid input).

    Args:
        json_str: Raw JSON text, possibly containing stray backslashes
            (e.g. LaTeX such as a backslash followed by ``(``).

    Returns:
        The text with every stray backslash doubled.
    """
    def _keep_or_double(match):
        # group(1) is populated only when the backslash begins a valid escape.
        if match.group(1) is not None:
            return match.group(0)
        return '\\\\'

    return re.sub(r'\\(["\\/bfnrt]|u[0-9a-fA-F]{4})?', _keep_or_double, json_str)

def sanitize_for_ast_literal_eval(s: str) -> str:
    """Rewrite *s* so that ``ast.literal_eval`` is more likely to accept it.

    The passes below run strictly in order and each one sees the output of
    the previous one: backslash doubling for several specific sequences, a
    catch-all backslash doubling, image-tag replacement, control-character
    removal, and finally clean-up of the tail of the string.

    Args:
        s: Text to repair.

    Returns:
        The repaired text.
    """
    # Double the backslash of every backslash-u followed by up to four hex digits.
    s = re.sub(r'\\u([0-9a-fA-F]{0,4})', r'\\\\u\1', s)
    # Double the backslash in front of "(", ")", "[" and "]".
    s = s.replace(r'\(', r'\\(').replace(r'\)', r'\\)')
    s = s.replace(r'\[', r'\\[').replace(r'\]', r'\\]')
    # Double the backslash in front of a fixed list of command names.
    s = re.sub(r'\\(unit|degree|angle|vec|mathrm|dots|text|overrightarrow|bar)', r'\\\\\1', s)
    # Catch-all: double any backslash not followed by one of \ ' " a b f n r t.
    # NOTE(review): this pass does not treat a backslash pair as a unit, so the
    # second backslash of a pair produced above matches here whenever the next
    # character is outside that set (e.g. "(" or "u") and is doubled again,
    # leaving three backslashes - confirm this is intended.
    s = re.sub(r'\\(?![\\\'\"abfnrt])', r'\\\\', s)
    # Replace any "<image...>" tag with the literal placeholder "[IMAGE]".
    s = re.sub(r'<image[^>]*>', '[IMAGE]', s)
    # Delete ASCII control characters (0x00-0x1f and 0x7f); this includes
    # raw newlines and tabs.
    s = re.sub(r'[\x00-\x1f\x7f]', '', s)   
    s = s.rstrip()
    # Drop a single dangling backslash at the very end of the text.
    if s.endswith("\\") and not s.endswith("\\\\"):
        s = s[:-1].rstrip()
    # Drop a trailing backslash pair (plus any whitespace after it).
    s = re.sub(r"(\\\\)\s*$", "", s)
    # Drop a trailing comma (plus any whitespace after it).
    s = re.sub(r",\s*$", "", s)
    return s

def clean_and_parse_json(raw_output):
    """Extract the outermost ``{...}`` span from *raw_output* and parse it.

    Parsing is attempted in three stages, stopping at the first success:

    1. ``json.loads`` - the model is asked for JSON, and only a JSON parser
       understands ``true``/``false``/``null`` and JSON-specific escapes.
    2. ``ast.literal_eval`` - accepts Python-literal syntax such as
       single-quoted strings that is not valid JSON.
    3. ``ast.literal_eval`` on the output of
       ``sanitize_for_ast_literal_eval`` - repairs stray backslashes etc.

    Args:
        raw_output: Raw text returned by the model.

    Returns:
        The parsed object, or ``None`` when no brace-delimited span is found
        or every parsing stage fails.
    """
    match = re.search(r'\{[\s\S]*\}', raw_output)
    if not match:
        logging.error("未找到合法 JSON 块")
        print("未找到合法 JSON 块")
        return None

    json_str = match.group(0)

    # Stage 1: strict JSON.  JSONDecodeError is a ValueError subclass.
    try:
        return json.loads(json_str)
    except ValueError:
        # Not valid JSON; fall through to the more lenient parsers below.
        pass

    # Stage 2: Python literal syntax.  literal_eval may also raise TypeError
    # (e.g. an unhashable dict key), which is treated as a parse failure too.
    try:
        return ast.literal_eval(json_str)
    except (ValueError, SyntaxError, TypeError):
        print("直接解析失败，尝试修复非法反斜杠后再解析")

    # Stage 3: repair the text, then retry as a Python literal.
    fixed_str = sanitize_for_ast_literal_eval(json_str)
    try:
        return ast.literal_eval(fixed_str)
    except (ValueError, SyntaxError, TypeError) as e:
        logging.error(f"JSON 解析失败：{e}")
        logging.error(f"修复后字符串：{repr(fixed_str)}")
        print(f"JSON 解析失败：{e}")
        print(f"修复后字符串：{repr(fixed_str)}")
        return None

def build_messages(input_text) -> list:
    """Build the chat payload for one sentence.

    Returns a two-element list: a system message carrying the fixed rewriting
    instructions, followed by a user message whose content is *input_text*.
    """
    # The system prompt is assembled from four consecutive sections.
    task_description = (
        "Task Description:\n"
        "You are a sentence structure rewriting assistant. Your task is to rewrite a given sentence while altering its structure, "
        "ensuring that the original meaning is preserved. For each sentence, you must generate five distinct rewritten versions, "
        "each applying only one syntactic transformation. The goal is to create varied sentence structures while maintaining semantic accuracy and natural grammar.\n\n"
    )
    transformations = (
        "Syntactic Transformations (Choose One per Rewrite):\n"
        "1. Voice Change (Active ↔ Passive)\n"
        "2. Adverbial Position Adjustment\n"
        "3. Clause Order or Structure Change\n"
        "4. Phrase Structure Simplification or Expansion\n"
        "5. Inversion or Emphatic Structure\n"
        "6. Conditional / Purpose / Result Structure Transformation\n\n"
    )
    constraints = (
        "Key Constraints:\n"
        "- Preserve all steps in multi-step logical reasoning chains.\n"
        "- Do not omit any mathematical derivations, steps, or intermediate expressions.\n"
        "- Do not change numbers or mathematical expressions, including LaTeX formulas.\n"
        "- Preserve meaning, grammar, and naturalness.\n"
        "- Try to keep the length of the rewritten sentence close to the original (within 2–3 words difference). Avoid significant shortening or lengthening unless necessary for syntactic transformation.\n"
        "- Only one syntactic transformation type per rewritten sentence.\n\n"
    )
    output_format = (
        "Output Format:\n"
        "{\n"
        '  "Original Sentence": "The original sentence",\n'
        '  "Rewritten Sentences": [\n'
        '    "rewritten sentence 1",\n'
        '    "rewritten sentence 2",\n'
        '    "rewritten sentence 3",\n'
        '    "rewritten sentence 4",\n'
        '    "rewritten sentence 5"\n'
        "  ]\n"
        "}"
    )
    system_prompt = task_description + transformations + constraints + output_format

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": input_text},
    ]

def is_skippable(text):
    """Return True when *text* should be passed through without rewriting.

    A step is skipped when, after stripping whitespace, it is empty, is a
    bare display-math delimiter, contains an ``<image_`` tag, is a short
    heading / list item, is a short line containing bold markup, or is
    shorter than five characters.
    """
    body = text.strip()

    # Empty text and lone display-math delimiters.
    if not body or body in ("\\[", "\\]"):
        return True

    # Anything that mentions an image placeholder.
    if "<image_" in body or re.match(r"^<image_\d+>", body):
        return True

    # Headings / list items: the word count alone decides.
    # NOTE(review): only "###", "1." and "- " trigger this branch, while the
    # regex below strips any heading level or number and never strips "- ";
    # confirm whether that asymmetry is intended.
    if body.startswith(("###", "1.", "- ")):
        remainder = re.sub(r"^(\#{1,6}|[0-9]+\.)\s*", "", body)
        return len(remainder.split()) < 5

    # Lines with at least two bold markers: short ones are skipped.
    if body.count("**") >= 2:
        without_bold = re.sub(r"\*\*(.*?)\*\*", r"\1", body)
        if len(without_bold.split()) < 5:
            return True

    return len(body) < 5

def _emit_log(level, payload):
    """Send *payload* to the root logger at *level* and echo it on stdout."""
    if isinstance(payload, str):
        text = payload
    else:
        # default=str keeps logging from raising on values json cannot encode
        # (e.g. numpy scalars used as sample ids).
        text = json.dumps(payload, ensure_ascii=False, default=str)
    logging.log(level, text)
    print(text)


def _report_step_failure(sample_id, step_idx, input_text, error, headline=None):
    """Log a failed step: a one-line headline followed by a structured entry."""
    if headline is None:
        headline = f"样本ID {sample_id} 第{step_idx}个步骤处理出错: {error}"
    _emit_log(logging.ERROR, headline)
    _emit_log(logging.ERROR, {
        "sample_id": sample_id,
        "step_index": step_idx,
        "original_text": input_text,
        "error": error,
    })


def _choose_rewrite(rewrites, original_tokens):
    """Pick the rewrite whose token count is closest to *original_tokens*.

    Preference order: a random candidate with an identical token count, then
    a random candidate within 10 tokens, then the single closest candidate.

    Returns:
        ``(selected, selected_tokens, message, token_difference)`` where
        ``token_difference`` is ``None`` except in the last case.
    """
    # Tokenize every candidate exactly once.
    scored = [(text, count_gpt_tokens(text)) for text in rewrites]

    exact = [pair for pair in scored if pair[1] == original_tokens]
    if exact:
        selected, selected_tokens = random.choice(exact)
        return selected, selected_tokens, "完全一致的tokens", None

    close = [pair for pair in scored if abs(pair[1] - original_tokens) <= 10]
    if close:
        selected, selected_tokens = random.choice(close)
        return selected, selected_tokens, "tokens差别在10以内", None

    selected, selected_tokens = min(scored, key=lambda pair: abs(pair[1] - original_tokens))
    token_difference = abs(selected_tokens - original_tokens)
    message = f"tokens差别在10以上，原始token数: {original_tokens}，改写token数: {selected_tokens}，差值: {token_difference}"
    return selected, selected_tokens, message, token_difference


def _rewrite_step(client, model, sample_id, step_idx, input_text):
    """Request rewrites for one step and return the value to store for it.

    Returns the selected rewrite on success, *input_text* when the model
    output cannot be parsed, and ``None`` on any other failure.
    """
    try:
        response = client.chat.completions.create(
            model=model,
            messages=build_messages(input_text),
            max_tokens=5000,
            temperature=0,
            n=1
        )
        choices = getattr(response, "choices", [])
        if not choices:
            _report_step_failure(sample_id, step_idx, input_text, "API 没有返回 choices")
            return None

        raw_output = choices[0].message.content.strip()
        parsed = clean_and_parse_json(raw_output)
        if parsed is None:
            # Unparseable output: keep the original sentence for this step.
            _report_step_failure(
                sample_id, step_idx, input_text,
                "解析JSON失败，已保留原句",
                headline=f"解析失败内容：{raw_output}",
            )
            return input_text

        rewrites = parsed.get("Rewritten Sentences", [])
        if not rewrites:
            _report_step_failure(
                sample_id, step_idx, input_text,
                f"No rewrites returned for: {input_text}",
            )
            return None

        original_tokens = count_gpt_tokens(input_text)
        selected, selected_tokens, message, token_difference = _choose_rewrite(
            rewrites, original_tokens
        )

        log_message = {
            "sample_id": sample_id,
            "step_index": step_idx,
            "message": message,
            "original_text": input_text,
            "selected_rewrite": selected,
        }
        if token_difference is not None:
            log_message["token_difference"] = token_difference
        log_message["original_token_count"] = original_tokens
        log_message["selected_token_count"] = selected_tokens
        _emit_log(logging.INFO, log_message)
        return selected
    except Exception as e:
        # Best-effort batch job: one failing step must not stop the run.
        _report_step_failure(sample_id, step_idx, input_text, str(e))
        return None


def call_gpt_api(df, api_key=None, api_base=None, log_dir="log", model="gpt-4o"):
    """Rewrite every step of every sample in *df* via the chat completions API.

    For each row, each entry of ``row['process_evaluation']`` that is a
    non-empty numpy array has its first element converted to ``str`` and
    either passed through unchanged (when ``is_skippable`` says so) or sent
    to the model; one of the returned rewrites is then chosen by token-count
    proximity to the original.  The per-row result lists are stored in a new
    ``'student_solution_structure'`` column on *df* (the frame is modified in
    place).

    Args:
        df: DataFrame with ``'id'`` and ``'process_evaluation'`` columns.
        api_key: API key forwarded to ``openai.OpenAI``.
        api_base: Base URL forwarded to ``openai.OpenAI``.
        log_dir: Directory for the timestamped log file; created if missing.
        model: Chat model name passed to the API.

    Returns:
        The same *df*, with the ``'student_solution_structure'`` column added.
    """
    from datetime import datetime

    # Log file name carries the run's date and time.
    os.makedirs(log_dir, exist_ok=True)
    run_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f'structure_{run_time}.log')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()
        ]
    )

    client = openai.OpenAI(
        api_key=api_key,
        base_url=api_base
    )

    filtered_rewrites = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="处理进度"):
        sample_id = row['id']
        best_rewrites = []

        for step_idx, item in enumerate(row['process_evaluation']):
            if not (isinstance(item, np.ndarray) and len(item) > 0):
                # NOTE(review): a skipped step adds no entry, so this row's
                # result can be shorter than its process_evaluation - confirm
                # downstream consumers expect that.
                logging.warning(f"样本ID {sample_id} 第{step_idx}步不是有效数组，跳过")
                continue

            input_text = str(item[0])
            if is_skippable(input_text):
                best_rewrites.append(input_text)
                continue

            best_rewrites.append(
                _rewrite_step(client, model, sample_id, step_idx, input_text)
            )

        filtered_rewrites.append(best_rewrites)

    df['student_solution_structure'] = filtered_rewrites

    total_steps = sum(len(steps) for steps in df['process_evaluation'])
    # Counts every non-None entry, including pass-through originals.
    success_count = sum(1 for sample in filtered_rewrites for r in sample if r is not None)
    summary = f"✅ 总处理步骤数: {total_steps}, 成功生成同义改写数: {success_count}"
    logging.info(summary)
    print(summary)

    return df

def main():
    """Command-line entry point: read a parquet file, rewrite it, save the result.

    The output location is resolved and its directory created before any API
    call is made, so that an unusable ``--output`` path fails immediately
    instead of after the whole batch has been processed.

    Exits with status 1 when the input file does not exist.
    """
    parser = argparse.ArgumentParser(description="结构改写批量处理")
    parser.add_argument('--input', '-i', type=str, required=True, help='输入parquet文件路径')
    parser.add_argument('--output', '-o', type=str, default=None, help='输出parquet文件路径（可选，默认自动生成）')
    parser.add_argument('--api_key', type=str, default=None, help='OpenAI API Key（可选）')
    parser.add_argument('--api_base', type=str, default=None, help='OpenAI API Base URL（可选）')
    parser.add_argument('--log_dir', type=str, default="log", help='日志目录（默认log）')
    args = parser.parse_args()

    input_path = args.input
    if not os.path.exists(input_path):
        print(f"输入文件不存在: {input_path}")
        sys.exit(1)

    output_path = args.output
    if output_path is None:
        # Default: <input stem>_after_structure<ext> inside the default directory.
        base, ext = os.path.splitext(os.path.basename(input_path))
        adv_dir = "your_output_dir"
        output_path = os.path.join(adv_dir, f"{base}_after_structure{ext}")

    # Make sure the destination directory exists before doing the costly work.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    df = pd.read_parquet(input_path)
    updated_df = call_gpt_api(
        df,
        api_key=args.api_key,
        api_base=args.api_base,
        log_dir=args.log_dir,
    )

    updated_df.to_parquet(output_path)
    print(f"已保存到: {output_path}")


if __name__ == "__main__":
    main()
