import copy
import json
import os
import random
import re
import sys
import pandas as pd
from mathruler.grader import extract_boxed_content, grade_answer
import traceback
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from call_python_tool import extract_code_blocks, multi_process_batch_call_python_code, reassemble_response,reassemble_response_real_interpreter_output
from tool.qwen3_32b_judge_vllm import judge_two_texts_same
from collections import defaultdict

def end_code_close_normalize(response: str) -> str:
    """Make every ``<code>...</code>`` body end with exactly one closing fence.

    Trailing whitespace and one trailing triple-backtick fence are stripped
    from each body, then a newline, a closing fence and a newline are
    appended. Closed and unclosed fenced blocks therefore end identically.
    Text outside ``<code>`` tags is left untouched.
    """

    def _close_fence(block: re.Match) -> str:
        body = re.sub(r'\s*```$', '', block.group(1).rstrip()).rstrip()
        return "<code>" + body + "\n```\n" + "</code>"

    return re.sub(r"<code>(.*?)</code>", _close_fence, response, flags=re.DOTALL)

def has_code_block(text):
    """Return True if *text* contains at least one ``<code>...</code>`` span (possibly multi-line)."""
    return bool(re.search(r"<code>.*?</code>", text, re.DOTALL))

def count_code_blocks(text: str) -> int:
    """Count non-overlapping ``<code>...</code>`` spans (tags matched case-insensitively)."""
    return sum(1 for _ in re.finditer(r"<code>(.*?)</code>", text, flags=re.S | re.I))

def count_interpreter_blocks(text: str) -> int:
    """Count non-overlapping ``<interpreter>...</interpreter>`` spans (case-insensitive tags)."""
    return sum(1 for _ in re.finditer(r"<interpreter>(.*?)</interpreter>", text, flags=re.S | re.I))

def count_code_blocks_interpreter(text: str) -> int:
    """Count ``<code>`` blocks that are followed (at any later point) by an ``<interpreter>`` block.

    Each match consumes one code block, whatever text lies in between, and the
    next interpreter block. Matching is non-overlapping and the tags are
    case-insensitive.
    """
    paired = r"<code>(.*?)</code>(.*?)<interpreter>(.*?)</interpreter>"
    return sum(1 for _ in re.finditer(paired, text, flags=re.S | re.I))

def extract_solution(solution_str):
    """Return the stripped text of the first ``<answer>...</answer>`` span, or None if there is none."""
    found = re.search(r'<answer>(.*?)</answer>', solution_str, re.DOTALL)
    return found.group(1).strip() if found else None


def all_code_blocks_are_python(text: str) -> bool:
    """Return True when every ``<code>`` body contains a python fence.

    The check is case-insensitive. Text without any ``<code>`` block
    trivially satisfies it.
    """
    bodies = re.findall(r"<code>(.*?)</code>", text, flags=re.S | re.I)
    return not bodies or all("```python" in body.lower() for body in bodies)


def has_unwrapped_python_block(text: str) -> bool:
    """Return True if some python fence in *text* is not inside a ``<code>`` wrapper.

    A fence counts as wrapped when the nearest "<code>" + newline before it is
    more recent than the nearest newline + "</code>" before it.
    """
    fence = re.compile(r"```python\n.*?```", re.DOTALL)

    def _is_wrapped(pos: int) -> bool:
        opened = text.rfind("<code>\n", 0, pos)
        closed = text.rfind("\n</code>", 0, pos)
        return opened != -1 and opened > closed

    return any(not _is_wrapped(m.start()) for m in fence.finditer(text))


def wrap_unwrapped_code_python_blocks(text: str) -> str:
    """Wrap each python fence that is not already inside ``<code>`` in ``<code>`` tags.

    The wrapper check looks at the ORIGINAL text: a fence is considered
    wrapped when the last "<code>" + newline before it comes after the last
    newline + "</code>" before it. Unwrapped fences become
    "<code>" + newline + fence + newline + "</code>". Everything else is
    copied unchanged.
    """

    def _wrap_if_bare(fence: re.Match) -> str:
        pos = fence.start()
        opened = text.rfind("<code>\n", 0, pos)
        closed = text.rfind("\n</code>", 0, pos)
        if opened != -1 and opened > closed:
            return fence.group(0)
        return "<code>\n" + fence.group(0) + "\n</code>"

    return re.sub(r"```python\n.*?```", _wrap_if_bare, text, flags=re.DOTALL)


def normalize_python_blocks_old(text: str) -> str:
    """Legacy variant of ``normalize_python_blocks``.

    It first removes any ``<code>`` shell around a python fence, then wraps
    every python fence in "<code>" + newline ... newline + "</code>". Unlike
    the newer variant, the closing fence must start on its own line.
    """
    unwrapped = re.sub(
        r"<code>\s*(```python\n.*?\n```)\s*</code>",
        r"\1",
        text,
        flags=re.DOTALL | re.IGNORECASE,
    )
    return re.sub(
        r"```python\n.*?\n```",
        lambda fence: "<code>\n" + fence.group(0) + "\n</code>",
        unwrapped,
        flags=re.DOTALL,
    )

def normalize_python_blocks(text: str) -> str:
    """Re-wrap every python fence in a canonical ``<code>`` shell.

    Existing ``<code>`` shells around a python fence are removed first
    (case-insensitively, surrounding whitespace included). Every python fence
    is then wrapped as "<code>" + newline + fence + newline + "</code>", so
    wrapped and bare fences end up in the same shape.
    """
    unwrapped = re.sub(
        r"<code>\s*(```python\n.*?```)\s*</code>",
        r"\1",
        text,
        flags=re.DOTALL | re.IGNORECASE,
    )
    return re.sub(
        r"```python\n.*?```",
        lambda fence: "<code>\n" + fence.group(0) + "\n</code>",
        unwrapped,
        flags=re.DOTALL,
    )

import re

def code_block_line_counts(text: str, count_blank: bool = False):
    """Count the lines inside every python-fenced ``<code>`` block.

    Only blocks shaped as ``<code>``, optional whitespace, a python fence
    (optionally followed by spaces/tabs), a newline, the body, a newline, the
    closing fence, optional whitespace and ``</code>`` are considered. Tags
    and the fence language are matched case-insensitively.

    Args:
        text: Text to scan.
        count_blank: When False (default), whitespace-only lines are ignored.

    Returns:
        A tuple ``(per_block_counts, total)``.
    """
    block_re = re.compile(r"<code>\s*```python[ \t]*\n(.*?)\n```\s*</code>", re.S | re.I)

    per_block = []
    for body in block_re.findall(text):
        lines = body.splitlines()
        per_block.append(len(lines) if count_blank else sum(1 for ln in lines if ln.strip()))
    return per_block, sum(per_block)

def ensure_python_fenced(text: str) -> str:
    """Force every ``<code>`` body to be a python-fenced block.

    Bodies that already start with a python fence (case-insensitive, after
    stripping) are kept verbatim. For the others, any leading fence line
    (for example one with another language tag) and one trailing fence are
    removed. The remainder is re-fenced as a python block inside a canonical
    ``<code>`` shell.
    """

    def _refence(block: re.Match) -> str:
        body = block.group(1).strip()
        if body.lower().startswith("```python"):
            return block.group(0)
        body = re.sub(r"^```.*?\n", "", body, flags=re.I | re.S)
        body = re.sub(r"```$", "", body, flags=re.S).strip()
        return "<code>\n```python\n" + body + "\n```\n</code>"

    return re.sub(r"<code>(.*?)</code>", _refence, text, flags=re.S | re.I)

def compute_score_format(solution_str):
    """Return 1.0 if *solution_str* contains a revised thinking block, else 0.0.

    A block is ``<revised_thinking_process>``, a newline, any content, then
    ``</revised_thinking_process>``; the opening tag must be followed directly
    by a newline. If matching raises (for example on a non-string input), the
    error is printed and 0.0 is returned.
    """
    try:
        found = re.search(
            r'<revised_thinking_process>\n(.*?)</revised_thinking_process>',
            solution_str,
            re.DOTALL,
        )
    except Exception as e:
        print(f"[DEBUG] Error in compute_score_format: {e}")
        return 0.0

    return 1.0 if found is not None else 0.0


def math_verify_reward_function(solution_str, ground_truth):
    """Score *solution_str* against one or more reference answers with math_verify.

    Args:
        solution_str: Model output to parse.
        ground_truth: A single reference string or an iterable of references.

    Returns:
        1.0 if math_verify's string form of the parsed answer is one of the
        references, or if any reference (wrapped in ``\\boxed{}``) verifies
        against the parsed answer. Otherwise 0.0, including when parsing fails
        or yields fewer than two items.
    """
    from math_verify import parse, verify

    references = [ground_truth] if isinstance(ground_truth, str) else ground_truth

    try:
        parsed = parse(solution_str, parsing_timeout=5)
    except Exception:
        return 0.0

    # parse() is expected to return (expression, string form); anything shorter is unusable.
    if len(parsed) < 2:
        return 0.0

    # Cheap exact-string check first.
    if parsed[1] in references:
        return 1.0

    # Fall back to semantic equivalence, skipping references that fail to parse/verify.
    for ref in references:
        try:
            boxed_ref = parse(f"\\boxed{{{ref}}}", parsing_timeout=5)
            if verify(boxed_ref, parsed, timeout_seconds=5):
                return 1.0
        except Exception:
            continue

    return 0.0

def compute_score_mathverify_judge(solution_str: str, ground_truth: str):
    """Score *solution_str* with math_verify, ignoring reasoning before ``</think>``.

    Only the text after the LAST ``</think>`` is scored; the whole string is
    scored when the tag is absent. Returns the score from
    ``math_verify_reward_function`` (1.0 or 0.0). Unexpected errors are
    printed and scored 0.0.
    """
    # rpartition gives ('', '', s) when the tag is missing, so index 2 is the full string.
    solution_str = solution_str.rpartition("</think>")[2]

    try:
        return math_verify_reward_function(solution_str, ground_truth)
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate;
        # return a float like the success path instead of False.
        traceback.print_exc(10)
        return 0.0

def compute_score_answer(solution_str: str, ground_truth: str) -> float:
    """Grade the final answer in the last revised thinking block.

    Takes the last ``<revised_thinking_process>`` block (opening tag followed
    by a newline) and reads its ``<answer>`` content with ``extract_solution``.
    That content goes through ``extract_boxed_content`` and is compared with
    ``ground_truth`` via ``grade_answer``.

    Returns:
        1.0 when ``grade_answer`` accepts the answer. 0.0 otherwise, including
        when there is no revised block or no ``<answer>`` tag.
    """
    assistant_blocks = re.findall(
        r'<revised_thinking_process>\n(.*?)</revised_thinking_process>', solution_str, re.DOTALL
    )
    if not assistant_blocks:
        return 0.0

    answer = extract_solution(assistant_blocks[-1])
    if answer is None:
        return 0.0

    return 1.0 if grade_answer(extract_boxed_content(answer), ground_truth) else 0.0


def statistic(hendrycks_math_ins, hendrycks_math_ins_tokenizer, k=4000):
    """Bucket prompts by length into fixed-width ranges and print bucket sizes.

    Args:
        hendrycks_math_ins: Sequence of prompts.
        hendrycks_math_ins_tokenizer: Parallel sequence of lengths (e.g. token counts).
        k: Bucket width.

    Returns:
        A ``defaultdict(list)`` mapping ``"<start>-<end>"`` to the prompts whose
        length falls in ``[start, end)``. Buckets are printed in ascending order
        of start.
    """
    from collections import defaultdict

    buckets = defaultdict(list)
    for prompt, length in zip(hendrycks_math_ins, hendrycks_math_ins_tokenizer):
        lo = (length // k) * k
        buckets[f"{lo}-{lo + k}"].append(prompt)

    for name in sorted(buckets, key=lambda b: int(b.split('-')[0])):
        print(f"{name} token: {len(buckets[name])} prompt")
    return buckets

def get_extra_datas(extra_path_lists):
    """Load JSONL records and keep only the uids whose turns are all present.

    Records from all files are grouped by their ``uid``. A group is kept only
    when its record count equals the ``total_turns`` value of its first record.

    Args:
        extra_path_lists: Paths of UTF-8 JSON Lines files.

    Returns:
        The kept records, grouped by uid in first-seen order and, within a uid,
        in file/line order. Blank lines in the input files are skipped.
    """
    uids_dict = defaultdict(list)
    for extra_path in extra_path_lists:
        with open(extra_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # json.loads("") would raise on a blank or trailing empty line.
                    continue
                data = json.loads(line)
                uids_dict[data['uid']].append(data)

    all_extra_datas = []
    for records in uids_dict.values():
        if len(records) == records[0]['total_turns']:
            all_extra_datas.extend(records)

    return all_extra_datas

def recall_code_fail_datas(code_fail_data_lists):
    """Build re-prompt records asking the model to repair a failed code run.

    Each input record is deep-copied. The copy's original ``question`` is kept
    under ``only_question``. Its new ``question`` is the template below, with
    the record's ``call_dv3_prompt`` filled in as Instruction 1 and its
    ``formal_ans_code`` as Response 1. Input records are not modified.
    """
    code_fail_recall_system = """The following is your Response 1 based on Instruction 1. But there was a code interpreter execution error during the process. Please do the following:
    
Please:

1. Based on the interpreter's failed execution output, identify the exact code segment that caused the error and explain the reason for the failure.
2. Immediately after the interpreter's failed output, add a transition sentence , such as: "Oops, the code above appears to be throwing an error. I need to fix this to ensure it runs successfully."
3. Correct the erroneous code to ensure it runs successfully.
4. Continue the process from where you left off in response 1, completing the remaining steps as planned.
5. Wrap the final output in `<output></output>` tags.

**Instruction 1:**
{call_dv_instruction}
**Response 1:**
<revised_thinking_process>
{dv_first_call_response}
</revised_thinking_process>
"""

    def _to_recall_item(source):
        recall_item = copy.deepcopy(source)
        prompt = code_fail_recall_system.format(
            call_dv_instruction=recall_item['call_dv3_prompt'],
            dv_first_call_response=recall_item['formal_ans_code'],
        )
        recall_item['only_question'] = recall_item['question']
        recall_item['question'] = prompt
        return recall_item

    return [_to_recall_item(source) for source in code_fail_data_lists]

# Input configuration.
# NOTE(review): `data_path` is never read below; only the files listed in
# `extra_path_lists` are loaded (and that list is empty here).
data_path = "/data_path/call_dv3_data_systhesis_response.jsonl"
extra_path_lists = []

# Records from uids whose turns are all present (see get_extra_datas).
all_datas = []
all_extra_datas = get_extra_datas(extra_path_lists)
all_datas += all_extra_datas
print(f"all_datas ===== {len(all_datas)}")


# Per-turn normalization pass. For each loaded record, pick the response text
# to work on and normalize its <code>/<interpreter> markup. Then decide whether
# the normalized text or the original thinking text (`text_solution_part`) is
# kept as `formal_ans`.
valid_datas = []  # temp_dicts whose response contains a revised_thinking_process block
invalid_datas = []  # temp_dicts whose response does not
invalid_dict = defaultdict(int)  # uid -> number of turns without that block
invalid_datas_2 = []  # raw items whose answer is missing or shorter than 500 characters
invalid_code_inter = []  # temp_dicts reverted because <code> and <interpreter> counts differ
less_5_list = []  # temp_dicts reverted because their single code block is too short
formal_datas = []  # every temp_dict, in input order

dv3_data_dict = defaultdict(list)  # question -> temp_dicts for that question
for item in all_datas:

    question = item['instruction']
    uid = item['uid']
    turn = item['turn']
    # Original thinking text embedded in the prompt between
    # "<original_thinking_process>\n" and "\n</original_thinking_process>".
    text_solution_part = item['question'].split("<original_thinking_process>\n")[-1].split("\n</original_thinking_process>")[0]
    # Use the original text when there is no answer or the first answer is
    # shorter than 500 characters; otherwise use the first answer.
    if len(item['answer']) == 0 or len(item['answer'][0]) < 500:
        invalid_datas_2.append(item)
        dv3_response = text_solution_part
    else:
        dv3_response = item['answer'][0]

    # NOTE(review): `is_code` is never read.
    is_code = has_code_block(dv3_response)
    # 1.0 when the response contains a <revised_thinking_process> block, else 0.0.
    format_score = compute_score_format(dv3_response)
    temp_dict = {}
    temp_dict['uid'] = uid
    temp_dict['turn'] = turn
    temp_dict['question'] = question
    temp_dict['dv3_response'] = dv3_response
    temp_dict['format_score'] = format_score
    temp_dict['call_dv3_prompt'] = item['question']
    # Text of the first revised_thinking_process block. Each split leaves the
    # string unchanged when its tag is absent.
    formal_ans = dv3_response.split("</revised_thinking_process>")[0].split("<revised_thinking_process>")[-1]
    # Canonicalize code markup: wrap python fences in <code>, force a python
    # fence inside every <code> body, and wrap any fence still left bare.
    formal_ans = normalize_python_blocks(formal_ans)
    is_all_code = all_code_blocks_are_python(formal_ans)
    if is_all_code == False:
        formal_ans = ensure_python_fenced(formal_ans)
    has_no_code = has_unwrapped_python_block(formal_ans)
    if has_no_code == True:
        formal_ans = wrap_unwrapped_code_python_blocks(formal_ans)

    # Make every <code> body end with exactly one closing fence.
    formal_ans = end_code_close_normalize(formal_ans)
    # Drop leftover original-process tags and collapse empty or doubled code tags.
    formal_ans = formal_ans.replace("<original_thinking_process>", "")
    formal_ans = formal_ans.replace("</original_thinking_process>", "")
    formal_ans = formal_ans.replace("<code>\n```python\n\n```\n</code>python", "")
    formal_ans = formal_ans.replace("<code>\n<code>", "<code>")
    formal_ans = formal_ans.replace("<code><code>", "<code>")
    formal_ans = formal_ans.replace("</code>\n</code>", "</code>")
    formal_ans = formal_ans.replace("</code></code>", "</code>")
    # Add a transition sentence before every <code> and after every </interpreter>.
    formal_ans = formal_ans.replace("<code>", "\n\nOkay, below is the Python code to implement it.\n<code>")
    formal_ans = formal_ans.replace("</interpreter>", "</interpreter>\n\nOkay, the above is the result of using Python code.\n")

    # Only responses that contain the revised block keep the normalized text.
    if format_score == 1.0:
        temp_dict['formal_ans'] = formal_ans
    else:
        temp_dict['formal_ans'] = text_solution_part
    temp_dict['text_solution_part'] = text_solution_part
    # NOTE(review): these statistics describe formal_ans *before* the reverts
    # below, so the stored blocks_lines/total_lines can belong to text that is
    # then replaced. `code_interpreter_count` is never read.
    code_count = count_code_blocks(temp_dict['formal_ans'])
    code_interpreter_count = count_code_blocks_interpreter(temp_dict['formal_ans'])
    count_interpreter = count_interpreter_blocks(temp_dict['formal_ans'])
    blocks_lines, total_lines = code_block_line_counts(temp_dict['formal_ans'])
    temp_dict['blocks_lines'] = blocks_lines
    temp_dict['total_lines'] = total_lines
    # Revert to the original text when <code> and <interpreter> counts differ.
    if code_count != count_interpreter:
        temp_dict['formal_ans'] = text_solution_part
        invalid_code_inter.append(temp_dict)
    # Revert when there is exactly one <code> block and the python-fenced body
    # has fewer than 5 non-blank lines.
    if len(blocks_lines) != 0 and code_count == 1 and min(blocks_lines) < 5:
        temp_dict['formal_ans'] = text_solution_part
        less_5_list.append(temp_dict)
    # NOTE(review): classification looks only at format_score, so a record
    # reverted above can still land in valid_datas.
    if format_score == 1.0:
        valid_datas.append(temp_dict)
    else:
        invalid_datas.append(temp_dict)
        invalid_dict[uid] += 1

    formal_datas.append(temp_dict)
    dv3_data_dict[question].append(temp_dict)

# Fan out: each extracted code block gets its code text queued for execution
# plus a deep copy of its parent record tagged with that block's metadata.
# Records without code blocks are collected separately.
all_code_response_lists = []
all_code_response_item_lists = []
text_item_lists = []
for item in formal_datas:
    code_dict_info = extract_code_blocks(item['formal_ans'])
    if len(code_dict_info) == 0:
        text_item_lists.append(item)
        continue
    for kk in code_dict_info:
        all_code_response_lists.append(kk['code_format'])
        block_item = copy.deepcopy(item)
        block_item['extra_info_python'] = kk
        all_code_response_item_lists.append(block_item)

# Execute every queued code block on the sandbox servers.
ip_lists = "node1_ip,node2_ip,node3_ip....."  # sandbox server ip lists

(
    all_tool_responses,
    all_new_active_masks,
    all_tool_successes,
    all_tool_images,
) = multi_process_batch_call_python_code(all_code_response_lists, ip_lists)

# A block succeeded only if its success list is non-empty and its first
# entry does not compare equal to False.
for code_item_dict, flag in zip(all_code_response_item_lists, all_tool_successes):
    code_item_dict['extra_info_python']['code_flag'] = len(flag) != 0 and not (flag[0] == False)

# Keep the real sandbox output, padded with newlines, next to each block.
for code_item_dict, python_res in zip(all_code_response_item_lists, all_tool_responses):
    code_item_dict['extra_info_python']['interpreter_output_real_result'] = f"\n{python_res}\n"




# When enabled, an LLM judge (judge_two_texts_same) compares the output each
# code block claimed (its `interpreter_output` from extract_code_blocks) with
# the real sandbox output. In both modes the real output replaces
# `interpreter_output`. Only `dv3_python_consistency` differs: the judge
# verdict in one mode, always True in the other.
is_dv3_python_judge = True

if is_dv3_python_judge:
    model = "Qwen/Qwen3-32B"

    dv3_results_lists = [item['extra_info_python']['interpreter_output'] for item in all_code_response_item_lists]

    # NOTE(review): not read below; kept for interactive inspection.
    dv3_python_consistency_datas = {"dv3_results_lists": dv3_results_lists, "all_tool_responses": all_tool_responses}

    tensor_parallel_size = 4

    # Judge results are cached here and reused on later runs; delete the file
    # whenever the inputs change.
    save_vllm_results_path = "/root/data_path/model_python_response_consistency/all_results.json"

    if os.path.exists(save_vllm_results_path):
        with open(save_vllm_results_path, 'r', encoding='utf-8') as f:
            python_exe_results_dict = json.load(f)

        res_completions = python_exe_results_dict['res_completions']
        judge_res_list = python_exe_results_dict['judge_res_list']
    else:
        res_completions, judge_res_list = judge_two_texts_same(model=model, code_results_lists=all_tool_responses,
                                                               dv3_results_lists=dv3_results_lists, start=0,
                                                               end=10000000,
                                                               batch_size=1000000,
                                                               tensor_parallel_size=tensor_parallel_size)
        python_exe_results_dict = {"res_completions": res_completions, "judge_res_list": judge_res_list}
        # The cache directory may not exist yet on a fresh machine.
        os.makedirs(os.path.dirname(save_vllm_results_path), exist_ok=True)
        with open(save_vllm_results_path, 'w', encoding='utf-8') as f:
            json.dump(python_exe_results_dict, f, ensure_ascii=False)

    # Verdicts from a cache written for different inputs would be silently
    # misaligned (or truncated) by zip() below, so fail loudly instead.
    if len(judge_res_list) != len(all_code_response_item_lists):
        raise ValueError(
            f"Got {len(judge_res_list)} judge results for {len(all_code_response_item_lists)} code blocks; "
            f"possibly a stale cache at {save_vllm_results_path}"
        )

    # Debug aid: (claimed output, real output, verdict) for every disagreement.
    mm = [(p, v, k) for p, v, k in zip(dv3_results_lists, all_tool_responses, judge_res_list) if k == False]

    # Both verdicts store the same fields, so no branching on `state` is needed.
    for code_item_dict, state, python_res in zip(all_code_response_item_lists, judge_res_list, all_tool_responses):
        code_item_dict['extra_info_python']['interpreter_output'] = f"\n{python_res}\n"
        code_item_dict['extra_info_python']['dv3_python_consistency'] = state

else:
    for code_item_dict, python_res in zip(all_code_response_item_lists, all_tool_responses):
        code_item_dict['extra_info_python']['interpreter_output'] = f"\n{python_res}\n"
        code_item_dict['extra_info_python']['dv3_python_consistency'] = True



# Regroup per-block metadata by (question, turn). One representative item per
# key carries the shared fields; all items in a group are deep copies of the
# same parent record.
gather_extra_python_info_dict = defaultdict(list)
gather_lists_dict = {}

for ga_item in all_code_response_item_lists:
    key = f"question-{ga_item['question']}-turn-{ga_item['turn']}"
    gather_extra_python_info_dict[key].append(ga_item['extra_info_python'])
    gather_lists_dict[key] = ga_item  # last block of the group wins

# Records without any code block are grouped by question unchanged.
new_dv3_data_dict = defaultdict(list)
for item_text in text_item_lists:
    new_dv3_data_dict[item_text['question']].append(item_text)

# Rebuild each code-augmented response from its processed blocks (in block
# order) and decide per (question, turn) whether it can be kept.
# A response falls back to the original thinking text when any block failed
# in the sandbox, or when the judge found claimed and real output inconsistent.
Fail_code_items_list = []
dv3_python_no_consistency = []
for key in gather_extra_python_info_dict.keys():
    mt_item = gather_lists_dict[key]
    sorted_extra_info_python_lists = sorted(gather_extra_python_info_dict[key], key=lambda d: d["block_idx"])
    code_flags = [m_flag['code_flag'] for m_flag in sorted_extra_info_python_lists]
    dv3_python_consistency_lists = [m_flag['dv3_python_consistency'] for m_flag in sorted_extra_info_python_lists]

    formal_ans_ori = mt_item['formal_ans']
    text_solution_part = mt_item['text_solution_part']
    new_formal_code = reassemble_response(original_text=formal_ans_ori, processed_blocks=sorted_extra_info_python_lists)
    new_formal_code_real_interpreter_output = reassemble_response_real_interpreter_output(original_text=formal_ans_ori, processed_blocks=sorted_extra_info_python_lists)

    code_failed = False in code_flags
    mt_item['formal_ans_code'] = new_formal_code_real_interpreter_output
    # True when at least one block failed in the sandbox. This was previously
    # stored inverted (False on failure, True on success).
    mt_item['is_code_fail'] = code_failed
    if code_failed:
        mt_item['formal_ans'] = text_solution_part
        Fail_code_items_list.append(mt_item)
    else:
        mt_item['formal_ans'] = new_formal_code

    if False in dv3_python_consistency_lists:
        mt_item['formal_ans'] = text_solution_part
        dv3_python_no_consistency.append(mt_item)

    question = mt_item['question']
    new_dv3_data_dict[question].append(mt_item)


# Re-prompts for responses whose code failed, plus an intermediate dump of
# every record grouped by question.
recall_code_fail_datas_lists = recall_code_fail_datas(Fail_code_items_list)
recall_code_fail_save_path = '/root/data_path/recall_fail_code.json'
save_temp_code_results_train_data_item = '/root/data_path/save_temp_code_results_train_data.json'

# Written as JSON Lines (one object per line) despite the .json suffix.
with open(recall_code_fail_save_path, 'w', encoding='utf-8') as f:
    f.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in recall_code_fail_datas_lists)

with open(save_temp_code_results_train_data_item, "w", encoding="utf-8") as f:
    json.dump(new_dv3_data_dict, f, ensure_ascii=False, indent=4)


# One entry per question: the turns' final texts joined with newlines (in turn
# order), alongside the matching original thinking texts.
recongnize_data = []

for k, v in new_dv3_data_dict.items():
    question = v[0]['question']
    uid = v[0]['uid']
    sorted_v = sorted(v, key=lambda d: d["turn"])
    tool_response_str = "\n".join(itm['formal_ans'] for itm in sorted_v)
    raw_text_response_str = "\n".join(itm['text_solution_part'] for itm in sorted_v)
    recongnize_data.append({
        'uid': uid,
        'question': question,
        'tool_response_str': tool_response_str,
        'raw_text_response_str': raw_text_response_str,
    })


# Reference dataset, one JSON object per line despite the .json suffix,
# indexed by instruction so each question's original answer can be looked up.
# Later duplicates of an instruction overwrite earlier ones.
text_path = "/root/data/AM-Thinking-DS-Distill-0528.json"
all_text_datas_dict = {}
# Explicit UTF-8: the default encoding is locale-dependent.
with open(text_path, 'r', encoding='utf-8') as file:
    for line in file:
        if not line.strip():
            # json.loads would raise on blank lines.
            continue
        item = json.loads(line)
        ques = item['instruction']
        all_text_datas_dict[ques] = item

# Turn each question's tool-augmented reasoning into a chat sample: the
# reasoning goes in <think>, the reference answer in <answer>. Samples are
# then filtered on code-markup validity.
# NOTE(review): `train_datas` built here is rebuilt from `valid_code` further
# below before anything is saved.
train_datas = []
valid_code = []  # samples whose code markup passes the final check
invalid_code = []  # items whose <code> and <interpreter> counts differ (collected only)
invalid_code_mm = []  # samples failing the final markup check
less_all_5 = []  # items whose single code block is short (collected only)
code_count_lists = []
# Redundant: defaultdict is already imported at the top of the file.
from collections import defaultdict
bins = defaultdict(list)  # "0" or "<start>-<end>" code-block-count bucket -> responses
group_size = 1
prompt_template = "<think>\n{thought}\n</think>\n<answer>\n{solution}\n</answer>"
for item in recongnize_data:
    ques = item['question']
    formal_ans = item['tool_response_str']
    # Strip stray <code_snippet> tags and an empty-block artifact.
    formal_ans = formal_ans.replace("<code_snippet>", "").replace("</code_snippet>", "")
    formal_ans = formal_ans.replace("<code>\n```python\n\n```\n</code>python", "")
    # NOTE(review): `raw_text_response_str` is never read.
    raw_text_response_str = item['raw_text_response_str']
    # Same code-markup canonicalization as the per-turn pass.
    formal_ans = normalize_python_blocks(formal_ans)
    is_all_code = all_code_blocks_are_python(formal_ans)
    if is_all_code == False:
        formal_ans = ensure_python_fenced(formal_ans)

    has_no_code = has_unwrapped_python_block(formal_ans)
    if has_no_code == True:
        formal_ans = wrap_unwrapped_code_python_blocks(formal_ans)
    # Re-check after normalization; these two flags decide valid_code below.
    is_all_code_up = all_code_blocks_are_python(formal_ans)
    has_no_code_up = has_unwrapped_python_block(formal_ans)

    code_count = count_code_blocks(formal_ans)
    # NOTE(review): `code_interpreter_count` is never read.
    code_interpreter_count = count_code_blocks_interpreter(formal_ans)
    count_interpreter = count_interpreter_blocks(formal_ans)
    blocks_lines, total_lines = code_block_line_counts(formal_ans)

    # NOTE(review): the two lists below are only collected; they do not
    # influence whether the sample ends up in valid_code.
    if code_count != count_interpreter:
        invalid_code.append(item)

    if len(blocks_lines) != 0 and code_count == 1 and min(blocks_lines) < 5:
        less_all_5.append(item)

    # Final answer from the reference response: the text after its last
    # </think>, with <answer> tags removed. Raises KeyError if the question is
    # missing from the reference file.
    r1_response = all_text_datas_dict[ques]['output']
    summary_answer = r1_response.split('</think>')[-1].replace("<answer>", "").replace("</answer>", "").strip()
    tool_r1_solution = prompt_template.format(thought=formal_ans.strip(), solution=summary_answer)

    messages = [{'content': ques, 'role': 'user'},
                {'content': tool_r1_solution, 'role': 'assistant'}]
    extra_info = {}
    save_items = {
        "messages": messages,
        'extra_info': extra_info
    }
    train_datas.append(save_items)
    # NOTE(review): `is_code` is never read.
    is_code = has_code_block(formal_ans)

    # Average non-blank python lines per <code> block (integer division).
    if code_count == 0:
        mean_code_line = 0
    else:
        mean_code_line = total_lines // code_count
    save_items["extra_info"]['code_count'] = code_count
    save_items["extra_info"]['blocks_lines'] = blocks_lines
    save_items["extra_info"]['mean_code_line'] = mean_code_line
    # Histogram of responses by number of <code> blocks.
    if code_count == 0:
        bucket = "0"
        bins[bucket].append(formal_ans)
    else:
        start = (code_count // group_size) * group_size
        end = start + group_size
        bucket = f"{start}-{end}"
        bins[bucket].append(formal_ans)

    code_count_lists.append(code_count)
    # Keep the sample only if every <code> body is python-fenced and no bare
    # python fence remains outside <code>.
    if has_no_code_up == False and is_all_code_up == True:
        valid_code.append(save_items)
    else:
        invalid_code_mm.append(save_items)
# Bucket keys are "0" or "<start>-<end>". Sort by the numeric start so that,
# e.g., "10-11" prints after "2-3" (plain string sorting put it before).
bins_sorted_keys = sorted(bins, key=lambda b: int(b.split('-')[0]))
for k in bins_sorted_keys:
    print(k, len(bins[k]))


# Token statistics for the kept samples, using the cl100k_base tokenizer on
# the assistant message.
import tiktoken
enc = tiktoken.get_encoding("cl100k_base")
questions_responses = [item["messages"][1]['content'] for item in valid_code]
all_tokens = enc.encode_batch(questions_responses, disallowed_special=())
token_lengths = [len(tokens) for tokens in all_tokens]

len_bins = statistic(questions_responses, token_lengths, k=4000)

# code_density = tokens per <code> block (integer division; 0 without code).
for item, n_tokens in zip(valid_code, token_lengths):
    extra = item['extra_info']
    extra['token_length'] = n_tokens
    extra['code_density'] = n_tokens // extra['code_count'] if extra['code_count'] > 0 else 0

print('\n=============after filter===============\n')

# Alpaca-style records (instruction / input / output) for the kept samples.
train_datas = [
    {"instruction": item['messages'][0]['content'], "input": "", "output": item['messages'][1]['content']}
    for item in valid_code
]

# NOTE(review): `df_20k` is not used below; kept for interactive inspection.
df_20k = pd.DataFrame(valid_code)

# ensure_ascii=False writes raw non-ASCII characters, so pin UTF-8 instead of
# relying on the locale's default encoding.
with open("/data/root/save_process_systhesis_data.json", "w", encoding="utf-8") as fw:
    json.dump(train_datas, fw, indent=4, ensure_ascii=False)


