import random
import json
import os
import ast
import re
from teacher_v3_5 import Teacher
from student_4 import Student
from validate_v2_3 import Validate
from typing import List
from stu_simulation import Judgements
from imaging_teacher import Teacher_image
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_community.chat_models import ChatOpenAI
import os
import argparse


# Seed the `random` module with a fixed value (66) at import time.
random.seed(66)


def save_checkpoint(output_file, batch_index, round_num, example_batch, step_pointers, batch_completion_flags,
                    previous_validator_res, previous_predict_lists, previous_strategy_lists,
                    teacher_predict_states, student_see_states, history_full, round_for_step):
    """Serialize the current simulation state to a JSON checkpoint file.

    The file is written to
    ``{output_file}_batch_{batch_index}_round_{round_num-1}checkpoint.json``.
    Note that the file name uses ``round_num - 1`` while the stored
    ``round_num`` field keeps the value that was passed in.

    Args:
        output_file: Path prefix the checkpoint file name is built from.
        batch_index: Index of the batch being processed.
        round_num: Round counter; stored as-is, decremented by one in the name.
        example_batch: The examples of the current batch.
        step_pointers: Per-example index of the step currently being taught.
        batch_completion_flags: Per-example "finished" flags.
        previous_validator_res: Per-example validator verdict of the last round.
        previous_predict_lists: Per-example teacher predictions of the last round.
        previous_strategy_lists: Per-example teacher strategies of the last round.
        teacher_predict_states: Per-example student state as predicted by the teacher.
        student_see_states: Per-example student-side state.
        history_full: Per-example dialogue history.
        round_for_step: Per-example count of rounds spent on the current step.

    Every value must be JSON-serializable.
    """
    target_path = f'{output_file}_batch_{batch_index}_round_{round_num-1}checkpoint.json'

    # Keyword order defines the key order of the emitted JSON object.
    snapshot = dict(
        batch_index=batch_index,
        round_num=round_num,
        example_batch=example_batch,
        step_pointers=step_pointers,
        batch_completion_flags=batch_completion_flags,
        previous_validator_res=previous_validator_res,
        previous_predict_lists=previous_predict_lists,
        previous_strategy_lists=previous_strategy_lists,
        teacher_predict_states=teacher_predict_states,
        student_see_states=student_see_states,
        history_full=history_full,
        round_for_step=round_for_step,
    )

    with open(target_path, 'w', encoding='utf-8') as handle:
        json.dump(snapshot, handle, ensure_ascii=False, indent=4)

def load_checkpoint(output_file, batch_index, tmp_round):
    """Load a saved checkpoint from disk, if one exists.

    The file looked up is
    ``{output_file}_batch_{batch_index}_round_{tmp_round}checkpoint.json``.

    Args:
        output_file: Path prefix the checkpoint file name is built from.
        batch_index: Batch index embedded in the file name.
        tmp_round: Round number embedded in the file name.

    Returns:
        The decoded JSON content of the checkpoint file, or ``None`` when the
        file does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    checkpoint_file = f'{output_file}_batch_{batch_index}_round_{tmp_round}checkpoint.json'
    try:
        with open(checkpoint_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        # A missing checkpoint is the normal "start from scratch" case.
        return None
         
def split_steps(text):
    """Split a "step N: ..." text into (question, analysis) pairs.

    The text is cut at every ``step <digits>:`` marker (case-insensitive).
    Each non-empty chunk is then divided at its last full-width question
    mark ``？``: everything up to and including that mark is the question,
    the remainder is the analysis. Chunks that contain no ``？`` are dropped.

    Args:
        text: The raw steps string.

    Returns:
        A list of ``(question, analysis)`` tuples of stripped strings.
    """
    question_mark = '？'
    chunks = re.split(r'\s*step\s*\d+:\s*', text, flags=re.IGNORECASE)

    pairs = []
    for chunk in (piece.strip() for piece in chunks):
        if not chunk:
            continue
        # rpartition cuts at the last occurrence; an empty separator means
        # the chunk had no question mark at all.
        question, mark, remainder = chunk.rpartition(question_mark)
        if not mark:
            continue
        pairs.append(((question + mark).strip(), remainder.strip()))

    return pairs


def extract_result(output_string):
    """Extract a yes/no list from model output and map it to mastery labels.

    The substring from the first ``[`` to the first ``]`` (inclusive) is
    parsed as a list literal. Every element whose lower-cased string form
    contains ``"yes"`` becomes ``"can be solved independently"``; every other
    element becomes ``"lack of understanding"``.

    Args:
        output_string: Text containing a bracketed list such as
            ``"['yes', 'no']"``.

    Returns:
        A list of label strings, one per element of the parsed list.

    Raises:
        SyntaxError: If no well-formed bracketed list can be sliced out.
        ValueError: If the bracketed text is not a plain literal.
    """
    start_idx = output_string.find('[')
    end_idx = output_string.find(']')

    list_str = output_string[start_idx:end_idx + 1]

    # NOTE(review): this used to be eval(), which would execute arbitrary
    # expressions found in the (model-generated) text. literal_eval accepts
    # only plain literals, so non-literal content now raises ValueError
    # instead of being evaluated.
    answers = ast.literal_eval(list_str)

    return [
        "can be solved independently"
        if "yes" in str(answer).strip().lower()
        else "lack of understanding"
        for answer in answers
    ]

def _parse_json_like(fragment):
    """Parse *fragment* as a Python literal, falling back to strict JSON."""
    try:
        return ast.literal_eval(fragment)
    except (ValueError, SyntaxError):
        # JSON's true / false / null are not Python literals; let the JSON
        # parser handle them. It raises json.JSONDecodeError (a ValueError).
        return json.loads(fragment)


def extract_json(message) -> List[dict]:
    """Extract structured content from a message.

    The text is taken from ``message.content``. If it contains one or more
    blocks fenced with ```json ... ```, each block is parsed and the results
    are returned in order. Otherwise the whole text is parsed as a single
    value. Parsing accepts both Python-literal syntax (e.g. single-quoted
    keys) and strict JSON (e.g. ``true`` / ``false`` / ``null``).

    Parameters:
        message: An object with a ``content`` attribute holding the text.

    Returns:
        list: A list of the parsed objects.

    Raises:
        ValueError: If a fenced block, or the plain text, cannot be parsed.
    """
    # Matches the body of every ```json fenced block.
    pattern = r"```json\s*(.*?)\s*```"

    text = message.content

    matches = re.findall(pattern, text, re.DOTALL)

    if matches:
        try:
            return [_parse_json_like(match.strip()) for match in matches]
        except ValueError as exc:
            raise ValueError(f"Failed to parse JSON from the content wrapped with ```json: {text}") from exc

    # No fenced block: treat the whole text as one value.
    try:
        return [_parse_json_like(text.strip())]
    except ValueError as exc:
        raise ValueError(f"无法解析纯 JSON: {text}") from exc

def load_data(input_file, output_file, batch_size=10, indexa=0, tmp_round=0, use_slow_thinking=True,k_forward=2,k_round_begins=1,k_top=2):
    """Run the batched teacher/student tutoring simulation over a dataset.

    The examples in ``input_file`` are processed in batches. For each batch
    the function loops over dialogue rounds (at most 19): the teacher produces
    a prediction, a strategy and a response; the student state is updated and
    the student answers; the validator judges the answer; step pointers and
    completion flags are advanced. After every round a checkpoint and the
    current batch are written to disk.

    This function relies on the module-level objects ``teacher``,
    ``teacher_image``, ``judgements``, ``student`` and ``validator`` being
    defined before it is called.

    Args:
        input_file: Path to a JSON file holding a list of examples. Each
            example is read for the keys ``question``, ``steps``, ``answer``,
            ``analysis``, ``state`` and ``predict``.
        output_file: Path prefix for checkpoint and per-round output files.
        batch_size: Number of examples per batch.
        indexa: Offset of the first example to process; also the batch index
            used to look up a checkpoint.
        tmp_round: Round number of the checkpoint to resume from.
        use_slow_thinking: When true, samples that have spent more than
            ``k_round_begins`` rounds on the same step, and whose last
            validator verdict was neither "answered correctly" nor
            "answered incorrectly", are handled by
            ``teacher_image.teacher_play`` instead of the regular teacher.
        k_forward: Passed to ``teacher_image.teacher_play`` as ``iterations``.
        k_round_begins: Threshold on rounds-per-step for the slow path.
        k_top: Passed to ``teacher_image.teacher_play`` as first argument.

    Raises:
        Exception: Any exception raised inside the round loop is printed and
            re-raised.
    """
    with open(input_file, 'r', encoding='utf-8') as infile:
        data = json.load(infile)

    # Verdict text stored in previous_validator_res, keyed by the digit looked
    # for in the validator's 'type'. Checked in this order; when several
    # digits occur in the type string the last match wins.
    validator_labels = (
        ("1", "The students don't understand what the teacher is explaining."),
        ("2", "The student answered incorrectly."),
        ("3", "The student's knowledge is incomplete"),
        ("4", "The student answered correctly."),
        ("5", "Other."),
    )
    # Verdicts that keep a sample on the regular (fast) teacher path. The
    # Chinese labels are the ones this check originally compared against;
    # they are kept so state saved with those labels is still recognized.
    # Without the English labels the exclusion could never match the values
    # written below.
    settled_verdicts = (
        "The student answered correctly.",
        "The student answered incorrectly.",
        "学生回答正确",
        "学生回答错误",
    )

    checkpoint = load_checkpoint(output_file, indexa, tmp_round)
    if checkpoint:
        # Restore the saved state.
        batch_index = checkpoint['batch_index']
        example_batch = checkpoint['example_batch']
        step_pointers = checkpoint['step_pointers']
        batch_completion_flags = checkpoint['batch_completion_flags']
        previous_validator_res = checkpoint['previous_validator_res']
        previous_predict_lists = checkpoint['previous_predict_lists']
        previous_strategy_lists = checkpoint['previous_strategy_lists']
        teacher_predict_states = checkpoint['teacher_predict_states']
        student_see_states = checkpoint['student_see_states']
        history_full = checkpoint['history_full']
        round_num = checkpoint['round_num']
        round_for_step = checkpoint['round_for_step']
        is_first_round = False
        print(f"从检查点恢复，继续第 {batch_index} 批次的第 {round_num} 轮")

        # These are derived from the examples rather than stored separately.
        steps = [split_steps(example['steps']) for example in example_batch]
        questions = [example['question'] for example in example_batch]
        answers = [example['answer'] for example in example_batch]
        analy = [example['analysis'] for example in example_batch]
    else:
        # Defaults; every batch re-initializes these in the loop below.
        batch_index = indexa
        round_num = 0
        example_batch = []
        step_pointers = []
        batch_completion_flags = []
        previous_validator_res = []
        previous_predict_lists = []
        previous_strategy_lists = []
        teacher_predict_states = []
        student_see_states = []
        history_full = []
        round_for_step = []
        is_first_round = True

    # Process the data batch by batch.
    for current_batch_index in range(batch_index, len(data), batch_size):
        resumed = checkpoint and checkpoint['batch_index'] == current_batch_index
        if not resumed:
            # Fresh per-batch state (the resumed batch keeps the restored one).
            example_batch = data[current_batch_index:current_batch_index + batch_size]
            questions, steps, teacher_predict_states, student_see_states, answers, analy = [], [], [], [], [], []
            step_pointers = [0] * len(example_batch)
            batch_completion_flags = [False] * len(example_batch)
            is_first_round = True
            previous_validator_res = [None] * len(example_batch)
            previous_predict_lists = [None] * len(example_batch)
            previous_strategy_lists = [None] * len(example_batch)
            round_for_step = [0] * len(example_batch)
            history_full = ['' for _ in range(len(example_batch))]
            round_num = 0

            for example in example_batch:
                answers.append(example['answer'])
                analy.append(example['analysis'])
                questions.append(example['question'])
                steps.append(split_steps(example['steps']))

                student_see_states.append(example['state'])
                teacher_predict_states.append(example['predict'])
                example['tutor'] = []

        try:
            max_rounds = 19

            while not all(batch_completion_flags) and round_num < max_rounds:

                active_indexes = [idx for idx, done in enumerate(batch_completion_flags) if not done]
                current_steps = [steps[idx][step_pointers[idx]] for idx in active_indexes]

                if active_indexes:
                    # One independent dict per slot (not one shared object).
                    branches_info_list = [{} for _ in active_indexes]
                    if is_first_round:
                        predict_lists = teacher.forward_student_step_predict([questions[idx] for idx in active_indexes],current_steps, [teacher_predict_states[idx] for idx in active_indexes])
                        strategy_lists = teacher.generate_strategy_ini([questions[idx] for idx in active_indexes], current_steps, predict_lists)
                        final_response_list = teacher.generate_final_response([questions[idx] for idx in active_indexes], current_steps, strategy_lists, [history_full[idx] for idx in active_indexes], predict_lists)
                        is_first_round = False
                    else:
                        group1_idx = []  # slow path: stuck on the step and last verdict not settled
                        group2_idx = []  # fast path: everything else

                        for active_idx in active_indexes:
                            if use_slow_thinking:
                                if (round_for_step[active_idx] > k_round_begins and previous_validator_res[active_idx] not in settled_verdicts):
                                    group1_idx.append(active_idx)
                                else:
                                    group2_idx.append(active_idx)
                            else:
                                group2_idx.append(active_idx)

                        # Result lists aligned with active_indexes. Slow-path
                        # samples keep the empty prediction placeholder.
                        predict_lists = [[{'predict': '', 'mastered_knowledge': '', 'missing_knowledge': ''}] for _ in range(len(active_indexes))]
                        strategy_lists = [None] * len(active_indexes)
                        final_response_list = [None] * len(active_indexes)

                        # Position of each sample index inside active_indexes.
                        slot_of = {sample: slot for slot, sample in enumerate(active_indexes)}

                        # Slow path.
                        if group1_idx:
                            group1_strategy, group1_response, group1_branches = teacher_image.teacher_play(k_top, [questions[idx] for idx in group1_idx], [steps[idx][step_pointers[idx]] for idx in group1_idx], histories=[history_full[idx] for idx in group1_idx], predict_states = [teacher_predict_states[idx][0]['state_tree'] for idx in group1_idx], previous_vali = [previous_validator_res[idx] for idx in group1_idx],iterations=k_forward)
                            for rank, sample in enumerate(group1_idx):
                                slot = slot_of[sample]
                                strategy_lists[slot] = group1_strategy[rank]
                                final_response_list[slot] = group1_response[rank]
                                branches_info_list[slot] = group1_branches[rank]
                        # Fast path.
                        if group2_idx:
                            group2_predict = teacher.forward_student_step_predict([questions[idx] for idx in group2_idx],[steps[idx][step_pointers[idx]] for idx in group2_idx], [teacher_predict_states[idx][0]['state_tree'] for idx in group2_idx])

                            group2_strategy = teacher.generate_strategy([questions[idx] for idx in group2_idx], [steps[idx][step_pointers[idx]] for idx in group2_idx], group2_predict, [history_full[idx] for idx in group2_idx], [previous_strategy_lists[idx] for idx in group2_idx], [previous_validator_res[idx] for idx in group2_idx])

                            group2_response = teacher.generate_final_response([questions[idx] for idx in group2_idx], [steps[idx][step_pointers[idx]] for idx in group2_idx], group2_strategy, [history_full[idx] for idx in group2_idx], group2_predict)

                            for rank, sample in enumerate(group2_idx):
                                slot = slot_of[sample]
                                predict_lists[slot] = group2_predict[rank]
                                strategy_lists[slot] = group2_strategy[rank]
                                final_response_list[slot] = group2_response[rank]

                    # Append the teacher turn to the history and log the round.
                    for idx, response, current_step, predict_list, strategy_list, branches_info in zip(active_indexes, final_response_list, current_steps,predict_lists, strategy_lists, branches_info_list):
                        history_full[idx] += f" Teacher: {response[0]['response']}|EOM|"
                        (step_question, step_analysis) = current_step
                        example_batch[idx]['tutor'].append({
                            'round': round_num,
                            'current_steps':step_question,
                            'student_predict': predict_list[0]['predict'],
                            'student_mastered_knowledge': predict_list[0]['mastered_knowledge'],
                            'student_missing_knowledge': predict_list[0]['missing_knowledge'],
                            'teacher_stra': strategy_list[0]['strategy'],
                            'teacher_goal': strategy_list[0]['purpose'],
                            'teacher_res': response[0]['response'],
                            'stu_res': None,
                            'vali_res': None,
                            'step_pointer': step_pointers[idx],
                            'curret_step' :current_step,
                            'branches_info': branches_info
                        })

                    # NOTE(review): judging from how they are used below, tmp2
                    # is the updated student state, tmp1 a per-sample state
                    # tag passed to the student, and tmp3 a "state changed"
                    # marker - confirm against Judgements.forward.
                    tmp2, tmp1, tmp3 = judgements.forward([history_full[idx] for idx in active_indexes], [student_see_states[idx] for idx in active_indexes])

                    for idx, (_, new_state) in zip(active_indexes, zip(tmp1, tmp2)):
                        student_see_states[idx] = new_state

                    student_reses = student.forward([questions[idx] for idx in active_indexes], [student_see_states[idx] for idx in active_indexes], [history_full[idx] for idx in active_indexes],tmp1, [steps[idx][step_pointers[idx]] for idx in active_indexes])

                    for idx, response, state, _ in zip(active_indexes, student_reses, tmp1, tmp3):
                        history_full[idx] += f"Student: {response[0]['response']}|EOM|"
                        example_batch[idx]['tutor'][-1]['stu_res'] = f"<state:{state}>  {response[0]['response']}"

                    validator_res = validator.forward([history_full[idx] for idx in active_indexes], [questions[idx] for idx in active_indexes], current_steps, [answers[idx] for idx in active_indexes], [analy[idx] for idx in active_indexes])

                    for idx, res in zip(active_indexes, validator_res):
                        res_type = str(res[0]['type'])  # the type may not be a string
                        for digit, label in validator_labels:
                            if digit in res_type:
                                previous_validator_res[idx] = label

                    teacher_predict_student_states = teacher.generate_new_state([questions[idx] for idx in active_indexes],[teacher_predict_states[idx] for idx in active_indexes],current_steps, [history_full[idx] for idx in active_indexes], [previous_validator_res[idx] for idx in active_indexes], strategy_lists )

                    # Remember this round's outputs for the next round.
                    for idx, predict_list, strategy_list, teacher_predict_student_state in zip(active_indexes, predict_lists, strategy_lists,  teacher_predict_student_states):
                        previous_predict_lists[idx] = predict_list
                        previous_strategy_lists[idx] = strategy_list
                        teacher_predict_states[idx] = teacher_predict_student_state

                # Advance step pointers / completion flags from the verdicts.
                for idx, res, teacher_predict_student_state, tt2,tt3 in zip(active_indexes, validator_res, teacher_predict_student_states,tmp2,tmp3):
                    if res[0]['is_complete']=="1" and not batch_completion_flags[idx]:
                        if "1" in str(res[0]['is_end']):
                            batch_completion_flags[idx] = True
                            history_full[idx] += f" Teacher: {res[0]['res']}|EOM|"
                            example_batch[idx]['tutor'][-1]['is_end']="yes"
                        elif step_pointers[idx] < len(steps[idx]) - 1:
                            # Move on to the next step and restart its round count.
                            step_pointers[idx] += 1
                            round_for_step[idx] = 0
                        else:
                            # Last step completed.
                            batch_completion_flags[idx] = True
                            history_full[idx] += f"Teacher: Congratulations, you have answered correctly.！|EOM|"
                            example_batch[idx]['tutor'][-1]['is_end']="yes"
                    else:
                        round_for_step[idx] += 1
                    tp_need = teacher_predict_student_state[0]['need_update']
                    tp_state = teacher_predict_student_state[0]['state_tree']
                    example_batch[idx]['tutor'][-1]['vali_res'] = f"<predict_state: {res[0]['type']}> {res[0]['analysis']} {res[0]['is_complete']}"
                    example_batch[idx]['tutor'][-1]['teacher_predict_state'] = f"state_is_changed: {tp_need} state:{tp_state}"
                    example_batch[idx]['tutor'][-1]['student_predict_state'] = f"state_is_changed: {tt3} state:{tt2}"

                # Persist a checkpoint and the batch after every round.
                round_num += 1
                save_checkpoint(output_file, current_batch_index, round_num, example_batch, step_pointers, batch_completion_flags,
                              previous_validator_res, previous_predict_lists, previous_strategy_lists,
                              teacher_predict_states, student_see_states, history_full, round_for_step)
                output_filename = f'{output_file}_batch_{current_batch_index}_round_{round_num-1}.json'
                with open(output_filename, 'w', encoding='utf-8') as outfile:
                    json.dump(example_batch, outfile, ensure_ascii=False, indent=4)

        except Exception as e:
            print(f"发生异常: {str(e)}")
            raise



if __name__ == '__main__':
    def _str_to_bool(value):
        """Convert a command-line string such as "true" / "False" / "0" to a bool.

        argparse's ``type=bool`` treats every non-empty string (including
        "False") as True, so the value is parsed explicitly instead.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        # The empty string was the only falsy input under type=bool; keep it.
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Process some parameters.")
    parser.add_argument('--input_dir', type=str, default='', help="Input directory")
    # nargs='+' keeps the value a list; with a bare type=str a single file
    # name would be iterated character by character below.
    parser.add_argument('--json_files', type=str, nargs='+', default=['math.json'], help="List of JSON files")
    # The default keeps the previous hard-coded location ("/<base_name>/").
    parser.add_argument('--output_dir', type=str, default='/', help="Directory under which per-file output folders are created")
    parser.add_argument('--k_forward', type=int, default=2, help="k_forward parameter")
    parser.add_argument('--use_slow_thinking', type=_str_to_bool, default=True, help="Whether to use MCTS")
    parser.add_argument('--k_top', type=int, default=2, help="k_top parameter")

    args = parser.parse_args()

    input_dir = args.input_dir
    json_files = args.json_files

    # These module-level objects are used by load_data().
    llm1 = ChatTongyi(model='qwen2.5-32b-instruct')
    llm2 = ChatOpenAI(
        model_name="gpt-4o",
        temperature=0,
        openai_api_base="https://api.xty.app/v1",
        max_tokens=4096
    )
    teacher = Teacher(llm1)
    teacher_image = Teacher_image(llm1)
    validator = Validate(llm2)
    student = Student(llm2)
    judgements = Judgements(llm2)

    for json_file in json_files:
        file_path = os.path.join(input_dir, json_file)
        base_name = os.path.splitext(json_file)[0]
        # Trailing '' keeps the trailing separator: output files are named
        # "<output_dir>/<base_name>/_batch_..." as before.
        output_file_path = os.path.join(args.output_dir, base_name, '')
        os.makedirs(os.path.dirname(output_file_path), exist_ok=True)

        load_data(file_path, output_file_path, indexa=0, tmp_round=0, use_slow_thinking=args.use_slow_thinking,
            k_forward=args.k_forward, k_top=args.k_top)
        
        


