# 跟踪学生状态的教师
from utils import extract_json
import faiss
import numpy as np
import random
import time

import re
from langchain import PromptTemplate

class Teacher_image(object):
    def __init__(self, llm):
        """Store the LLM handle and define the prompt templates used by the tutor.

        Args:
            llm: Object whose ``batch(prompts)`` method is called by the other
                methods of this class with a list of rendered prompt strings.

        The templates are rendered through ``PromptTemplate.from_template`` using
        ``str.format``-style substitution, so every *literal* brace of the JSON
        examples is doubled (``{{`` / ``}}``); only real placeholders such as
        ``{question}`` use single braces.
        """
        # Prompt: pick the top-n tutoring strategies (uses {n} and {vali_state}).
        self.teacher_generate_topn_strategy = """Role: You are a teacher guiding a student to answer the following question. This question has been broken down into multiple steps, and each step may require multiple rounds of dialogue to fully answer. The student is stuck at the current step, and your task is to select the **top {n}** most suitable tutoring strategies based on the student's last response.

        ### Question:
        ##
        {question}
        ##

        ### Current Breakdown Step:
        ##
        {ques}
        ##

        ### Breakdown Step Explanation:
        ##
        {ans}
        ##

        ### Conversation History:
        ##
        {history}
        ##

        ### Student's Last Response State:
        ##
        {vali_state}
        ##

        ### Available Tutoring Strategies:
        a) **Explain a concept**: Explain a concept or method to the student without directly providing the answer to the question.  
        b) **Suggest a strategy**: Guide the student to use a specific method or strategy to help them think through the problem.  
        c) **Confirm the student's answer**: Affirm the student's correct answer and encourage them to continue to the next step.  
        d) **Correct the student's answer**: Gently point out the student's mistake and guide them to rethink the problem to self-correct.  
        e) **Ask an open-ended question**: Encourage the student to think deeply or help them complete the next step of the solution.  
        f) **Ask a closed-ended question**: Test the student's understanding and memory, focusing on specific knowledge points.  
        g) **Simplify the question**: Simplify the question to help the student focus on the core issue.  
        h) **Break down the question**: Split the question into smaller parts to help the student solve it step by step.  
        i) **Provide an analogy or example**: Use real-life examples or analogies to help the student understand abstract concepts more easily.  
        j) **Other**.

        ### Task:
        Please consider the above information and select the **top {n}** tutoring strategies that you think are most suitable.

        ### Output Format:
        Output the strategy list in the corresponding content of "strategy" in the following markdown format, including the "```json" markers before and after the code:
        ```json
        {{"strategy": "['Strategy1 Number','Strategy2 Number',...]" }}

            """
        # Prompt: produce the concrete teacher reply for one given strategy.
        self.teacher_response = """Role: You are a teacher guiding a student to answer the following question. The question has been broken down into multiple steps, and each breakdown step may require multiple rounds of dialogue to fully answer. The student is stuck at the current breakdown step, and your task is to respond to the student's last message based on the historical records and the given strategy.

        ### Question:
        ##
        {question}
        ##

        ### Current Breakdown Step:
        ##
        {ques}
        ##

        ### Breakdown Step Explanation:
        ##
        {ans}
        ##

        ### Conversation History:
        ##
        {history}
        ##

        ### Current Strategy:
        ##
        {strategy}
        ##

        ### Current Goal:
        The student is having difficulty with the **current breakdown step**, and your task is to respond to the student's last message based on the current strategy.

        ### Notes:
        1. **Avoid over-intervention**: When explaining concepts or providing strategies, do not directly give the final answer to the current step.
        2. **One question at a time**: Each response should only raise one question (if necessary), and not multiple or compound questions.

        ### Output Format:
        The response should be in the following markdown format, including the "```json" markers before and after the code:
        ```json
        {{ "response": "Your response" }}
"""


        # Prompt: simulate the student's reply given a knowledge state.
        self.play_student = """Role: You are a student, and you are struggling with a question. Your teacher is providing guidance.

        ### Question:
        ##
        {question}
        ##

        ### Your Knowledge State:
        ##
        {state}
        ##

        In the knowledge state, the number of '#' symbols represents the hierarchy of knowledge points, and the number following each knowledge point represents your level of mastery: (1) indicates mastered, (0) indicates not mastered.

        ### Conversation History:
        ##
        {history}
        ##

        ### Task Instructions:
        Based on the teacher's last round of guidance, please consider the following two aspects to judge whether you can understand the teacher's explanation:

        1. **Changes in the conversation history**: Review the previous guidance and judge whether the teacher's explanation this time is clearer or better explained compared to before. If this round of explanation has significantly improved or made things clearer, you may consider that you can understand.
        2. **Impact on your knowledge state**: Based on your current knowledge state, judge whether you have enough foundation to understand this round of explanation. If the explanation involves knowledge points you haven't mastered yet, or the teaching method does not align with your cognitive state, you may find it difficult to understand.

        Based on the conversation history and knowledge state, make a judgment and choose an appropriate response:

        ### Response Status Options:
        1. **You do not understand a certain knowledge point, process, etc., explained by the teacher**:
        - If there are parts of the teacher's explanation that you are unsure of or do not understand, please ask relevant questions about what you don’t understand.
        
        2. **You attempt to answer the teacher's question but give an incorrect answer**:
        - If you try to answer the teacher's question but give the wrong answer, you can generate an incorrect answer to reflect your lack of understanding of the knowledge point.
        
        3. **Your knowledge level is poor, and you cannot answer the teacher's question**:
        - If you feel that you do not have enough knowledge to answer the question, you can admit that you cannot answer and express your confusion or difficulties.
        
        4. **You can correctly answer the teacher's question**:
        - If you can accurately answer the teacher's question, you can simply provide the correct answer.
        
        5. **Other**:
        - If your situation does not fit any of the above statuses, you can select "Other" and briefly explain your situation.

        ### Output Format:
        The output should be in the following markdown format, including the "```json" markers before and after the code:
        ```json
        {{ "can_understand": "yes/no",
            "type": "your selected response status",
            "response": "your response" }}
        """
        # Prompt: classify the student's last reply and decide whether the
        # current step's target question has been sufficiently answered.
        self.validate_student = """Scene Setup: The teacher is guiding the student through a problem that has been broken down into multiple independent steps. Each step contains a target question. You are the Student Status Detection Assistant, responsible for analyzing the student's last response.

        ### Task Description:
        1. Input:
        - **Original Question**: The full background of the question.
        - **Target Question for the Current Step**: The final question that needs to be solved in this step. (The current step is not equivalent to the current dialogue round; one step may require multiple rounds of dialogue to complete.)
        - **Analysis of Learning Objectives**: The explanation related to the question.
        - **Dialogue History**: Contains the historical interaction between the student and the teacher.

        2. Output:
        The output should be in JSON format and include the following fields:
        - **type**: The category corresponding to the student's last response. Choose one from the following categories (only output the category number, do not output the actual content):
            - **Category 1**: The student does not understand the teacher's explanation: The teacher explained something or provided methods/strategies, but the student cannot understand. The student asks a question about the explanation.
            - **Category 2**: The student answers incorrectly: The teacher asked a question in the last round, and the student answered incorrectly.
            - **Category 3**: The student's knowledge level is poor and cannot answer the teacher's question/the student doesn't know a certain knowledge point/the student does not understand the teacher's question.
            - **Category 4**: The student answers correctly.
            - **Category 5**: Other: Any situation that does not fall into the above categories.
        - **analysis** (only for Category 4, otherwise leave empty): First, output the **"Target Question for the Current Step"**. Then, based on the history, output the student's answer to the target question. Finally, judge whether the student's answer fully and accurately addresses the target question. If the student's answer fully answers the target question, output "sufficiently answered", otherwise output "insufficiently answered".
        - **is_complete** (only for Category 4, otherwise leave empty): If the student's answer fully answers the target question, output "1", otherwise output "0".

        ### Original Question:
        ##
        {question}
        ##

        ### Target Question for the Current Step:
        ##
        {ques}
        ##

        ### Analysis of Learning Objectives:
        ##
        {ans}
        ##

        ### Dialogue History:
        ##
        {history}
        ##

        ### Output Format:
        The output should be in the following markdown format, including the "```json" markers before and after the code:
        ```json
        {{
        "type": "Response Category Number (1/2/3/4/5)",
        "analysis": "Analysis for Category 4",
        "is_complete": "0/1"
        }}
"""

        # Prompt: write the final teacher reply from the chosen strategy and a
        # reference reply ({aim_res}).
        self.teacher_final_response = """You are a teacher, guiding a student to solve the following problem. The problem has been broken down into multiple steps, and each step may require multiple rounds of dialogue to fully solve. The student is currently stuck on the current step. Your task is to respond to the student's last message based on the historical records and the given strategy.

        **Problem**:
        ##
        {question}
        ##

        **Current Step**:
        ##
        {ques}
        ##

        **Analysis of Current Step**:
        ##
        {ans}
        ##

        **Dialogue History**:
        ##
        {history}
        ##

        **Current Strategy**:
        ##
        {strategy}
        ##

        **Reference Response**:
        ##
        {aim_res}
        ##

        **Current Objective**:
        The reference response is an initial version of the response to the student's last message. Your task is: combining the current teaching strategy and the reference response, respond to the student's last message. If you feel that the reference response is good enough, you can simply choose the reference response.

        ### Notes:
        1. **Avoid over-intervention**: When explaining concepts or providing strategies, do not directly give the final answer to the current step.
        2. **Single question per response**: Each response should only ask one question (if needed), do not ask multiple or compound questions.

        ### Output Format:
        The output should be in the following markdown format, including the "```json" markers before and after the code:

        ```json
        {{
        "response": "Your response"
        }}
"""

        # Strategy letter -> full description handed to the response prompts.
        self.strategy_map = {
            "a": "a. Explain a concept: Explain a certain concept or method to the student, without directly providing the answer to the problem.",
            "b": "b. Suggest a strategy: Guide the student to adopt a certain method or strategy to advance their thinking.",
            "c": "c. Confirm the student's answer: Affirm the student's correct answer and encourage them to continue thinking deeply for the next step.",
            "d": "d. Correct the student's answer: Gently point out the student's mistake and guide them to reassess the problem for self-correction.",
            "e": "e. Ask an open-ended question: Trigger deeper thinking or encourage the student to complete the next step of the answer.",
            "f": "f. Ask a closed-ended question: Test the student's understanding and memory, focusing on specific knowledge points.",
            "g": "g. Simplify the problem: Simplify the problem into a more manageable form to help the student grasp the core.",
            "h": "h. Break the problem down: Divide the problem into smaller questions to help the student solve it step by step.",
            "i": "i. Provide analogies or examples: Use real-life examples or analogies to help the student better understand abstract concepts.",
            "j": "j. Other."
        }

        self.llm = llm


    def func1(self, questions, steps, histories, n, vali_states):
        """Ask the LLM for the top-``n`` tutoring strategies of each dialogue.

        Args:
            questions: Original questions, one per dialogue.
            steps: ``(step_question, step_analysis)`` pairs, one per dialogue.
            histories: Conversation histories, one per dialogue.
            n: Number of strategies requested in the prompt.
            vali_states: Description of each student's last response state.

        Returns:
            One list per dialogue holding the full strategy descriptions from
            ``self.strategy_map``; an entry that contains no recognisable
            strategy letter becomes ``"unknown strategy"``.

        Note:
            ``extract_json`` is assumed to return, per prompt, a sequence whose
            first element is the parsed JSON dict (the code reads
            ``[0]['strategy']``) - confirm against ``utils.extract_json``.
        """
        template = PromptTemplate.from_template(self.teacher_generate_topn_strategy)
        final_prompts = []
        for question, step, history, vali_state in zip(questions, steps, histories, vali_states):
            step_question, step_analysis = step
            final_prompts.append(
                template.partial(
                    question=question,
                    ques=step_question,
                    ans=step_analysis,
                    history=history,
                    n=n,
                    vali_state=vali_state,
                ).format()
            )
        strategies = extract_json(self.llm.batch(final_prompts))

        # A strategy letter (a-j) standing on its own, e.g. "a", "'b'", "c)".
        letter_pattern = re.compile(r'\b[a-j]\b', re.IGNORECASE)

        mapped_strategies = []
        for strategy_group in strategies:
            raw = strategy_group[0]['strategy']
            if isinstance(raw, list):
                # str() guards against non-string items (e.g. numbers).
                items = [str(item) for item in raw]
            else:
                # The prompt asks for the list encoded as a string such as
                # "['a','b']": split it per entry instead of per character, so
                # "Strategy a" is not shredded into s, t, r, a, ...
                text = str(raw)
                items = [
                    piece
                    for piece in text.strip().strip('[]').split(',')
                    if piece.strip()
                ]
                if not any(letter_pattern.search(piece) for piece in items):
                    # Compact answers like "abc": treat every letter as an item.
                    items = [char for char in text if char.isalpha()]

            mapped_group = []
            for item in items:
                match = letter_pattern.search(item)
                if match:
                    mapped_group.append(
                        self.strategy_map.get(match.group().lower(), "unknown strategy")
                    )
                else:
                    mapped_group.append("unknown strategy")
            mapped_strategies.append(mapped_group)
        return mapped_strategies
    

    def func2(self, questions, steps, strategies, histories):
        """Generate one teacher reply per (dialogue, candidate strategy) pair.

        Args:
            questions: Original questions, one per dialogue.
            steps: ``(step_question, step_analysis)`` pairs, one per dialogue.
            strategies: For each dialogue, a list of strategy descriptions.
            histories: Conversation histories, one per dialogue.

        Returns:
            ``(grouped_results, prompt_sources)`` where ``grouped_results[i]``
            lists the reply texts for dialogue ``i`` in strategy order (an empty
            list when that dialogue had no strategies, so positions always line
            up with the inputs) and ``prompt_sources`` lists the
            ``(dialogue_index, strategy_index)`` pair of every prompt sent.

        Raises:
            Exception: whatever ``self.llm.batch`` raised on the fifth
                consecutive failed attempt.
        """
        template = PromptTemplate.from_template(self.teacher_response)
        finall_prompts = []
        prompt_sources = []
        dialogue_count = 0

        for i, (question, step, strategy_group, history) in enumerate(zip(questions, steps, strategies, histories)):
            step_question, step_analysis = step
            dialogue_count = i + 1

            for j, strategy_desc in enumerate(strategy_group):
                finall_prompts.append(
                    template.partial(
                        question=question,
                        ques=step_question,
                        ans=step_analysis,
                        history=history,
                        strategy=strategy_desc,
                    ).format()
                )
                prompt_sources.append((i, j))

        # Send everything in one batch; retry transient failures.
        max_retries = 5
        for attempt in range(1, max_retries + 1):
            try:
                raw_outputs = self.llm.batch(finall_prompts)
                break
            except Exception:
                if attempt == max_retries:
                    raise  # out of attempts: surface the last error
                time.sleep(1)  # brief pause before the next attempt
        teacher_res = extract_json(raw_outputs)

        # Bucket replies by dialogue index. Pre-sizing the buckets keeps
        # grouped_results[i] aligned with dialogue i even when a dialogue
        # contributed no prompts.
        grouped_results = [[] for _ in range(dialogue_count)]
        for flat_idx, (source_idx, _) in enumerate(prompt_sources):
            grouped_results[source_idx].append(teacher_res[flat_idx][0]['response'])

        return grouped_results, prompt_sources
    

    #模拟学生响应
    def func3(self, questions, histories, predict_states):
        """Simulate the student's next reply for every dialogue.

        Renders the ``play_student`` template with each question, knowledge
        state and history, sends all prompts to ``self.llm.batch`` and returns
        whatever ``extract_json`` produces from the raw outputs.
        """
        student_prompts = [
            PromptTemplate.from_template(self.play_student).partial(
                question=one_question,
                state=one_state,
                history=one_history,
            ).format()
            for one_question, one_state, one_history in zip(questions, predict_states, histories)
        ]
        raw_outputs = self.llm.batch(student_prompts)
        return extract_json(raw_outputs)

    #验证者，验证学生状态，返回：type, analysis, is_complete
    def func4(self, questions, steps, histories):
        """Classify each student's last reply and replace the category number by text.

        The question, step question, step analysis and history of every
        dialogue are passed to the ``validate_student`` template; the parsed
        results are returned after their ``type`` field has been rewritten.
        The ``analysis`` / ``is_complete`` fields are left untouched.
        """
        validation_prompts = []
        for question, (step_question, step_analysis), history in zip(questions, steps, histories):
            rendered = PromptTemplate.from_template(self.validate_student).partial(
                question=question,
                ques=step_question,
                ans=step_analysis,
                history=history,
            ).format()
            validation_prompts.append(rendered)

        parsed = extract_json(self.llm.batch(validation_prompts))

        # Checked in this order against the ORIGINAL type text, so when several
        # digits occur the last matching row wins.
        category_labels = (
            ("1", "The student does not understand the teacher's explanation"),
            ("2", "The student answers incorrectly"),
            ("3", "The student's knowledge level is poor"),
            ("4", "The student answers correctly"),
            ("5", "other"),
        )
        for entry in parsed:
            verdict = entry[0]
            # Stringify first: the model may emit a number instead of a string.
            raw_type = str(verdict['type'])
            for digit, label in category_labels:
                if digit in raw_type:
                    verdict['type'] = label
        return parsed


    def func5(self, questions, steps, strategies, histories, states, aim_responses):
        """Produce the final teacher reply for every dialogue.

        Each prompt combines the question, step, history, chosen strategy and
        the reference reply (``aim_responses``) via the
        ``teacher_final_response`` template. ``states`` only takes part in the
        lock-step iteration; its values are not placed in the prompt.

        Returns:
            The output of ``extract_json`` applied to the batched LLM results.
        """
        final_prompts = []
        rows = zip(questions, steps, strategies, histories, states, aim_responses)
        for question, step, strategy, history, _unused_state, aim_response in rows:
            step_question, step_analysis = step
            final_prompts.append(
                PromptTemplate.from_template(self.teacher_final_response).partial(
                    question=question,
                    ques=step_question,
                    ans=step_analysis,
                    history=history,
                    strategy=strategy,
                    aim_res=aim_response,
                ).format()
            )

        # One batched call for all prompts.
        raw_outputs = self.llm.batch(final_prompts)
        return extract_json(raw_outputs)

   
    

    def calculate_strategy_scores(self, all_branches):
        """Score every first-step strategy per original dialogue and pick the best.

        A completed branch scores ``max(1 - 0.4 * (len(strategy_path) - 1), 0)``;
        an incomplete branch scores 0. Scores are summed per
        (original index, first strategy); the strategy with the highest sum
        wins, ties being broken with ``random.choice``.

        Returns:
            Dict keyed by original dialogue index with the chosen ``strategy``,
            the matching teacher ``response`` (or ``None`` when none was found
            in the history) and the ``question``, ``step``, ``history`` and
            ``predict_state`` of the first branch that belongs to that index.
        """
        per_origin = {}

        for branch in all_branches:
            origin = branch['path'][0]
            opening_strategy = branch['strategy_path'][0]
            record = per_origin.setdefault(origin, {}).setdefault(
                opening_strategy, {'scores': [], 'response': None}
            )

            rounds_back = len(branch['strategy_path']) - 1
            if branch['complete']:
                record['scores'].append(max(1 - 0.4 * rounds_back, 0))
            else:
                record['scores'].append(0)

            if record['response'] is not None:
                continue

            # Teacher turns, newest first; the opening strategy's reply is the
            # one `rounds_back` turns before the latest teacher turn.
            teacher_turns = [
                chunk.strip()
                for chunk in reversed(branch['history'].split('|EOM|'))
                if chunk.strip().startswith('Teacher:')
            ]
            if rounds_back < len(teacher_turns):
                record['response'] = teacher_turns[rounds_back]

        best_data = {}
        for origin, records in per_origin.items():
            totals = {name: sum(entry['scores']) for name, entry in records.items()}
            top_total = max(totals.values())
            tied = [
                (name, records[name]['response'])
                for name, total in totals.items()
                if total == top_total
            ]
            chosen_strategy, chosen_response = random.choice(tied)

            # NOTE(review): this takes the history of the first branch found for
            # the index, which already includes simulated rounds - confirm
            # whether callers expect the untouched original history here.
            first_branch = next(
                (candidate for candidate in all_branches if candidate['path'][0] == origin),
                None,
            )
            best_data[origin] = {
                'strategy': chosen_strategy,
                'response': chosen_response,
                'question': first_branch['question'],
                'step': first_branch['step'],
                'history': first_branch['history'],
                'predict_state': first_branch['predict_state'],
            }

        return best_data




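    # Original note: "the student's state from the previous round is not passed in".
    # NOTE(review): `previous_vali` below now appears to carry that state - confirm and drop this note.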
    def teacher_play(self, topn, questions, steps, histories, predict_states, previous_vali, iterations=2):
        """Explore tutoring strategies through simulated roll-outs and answer.

        For up to ``iterations`` rounds every active branch is expanded with up
        to ``topn`` strategies (``func1``), a teacher reply per strategy
        (``func2``), a simulated student reply (``func3``) and a validation of
        that reply (``func4``). Branches whose validation reports
        ``is_complete == 1`` stop expanding. ``calculate_strategy_scores`` then
        picks the best first-step strategy per original dialogue and ``func5``
        writes the final reply against the ORIGINAL history.

        Args:
            topn: Maximum number of strategies expanded per branch and round.
            questions, steps, histories, predict_states, previous_vali:
                Parallel per-dialogue inputs; ``steps`` holds
                ``(step_question, step_analysis)`` pairs and ``previous_vali``
                the description of each student's last response state.
            iterations: Maximum number of simulated teacher/student rounds.

        Returns:
            ``(strategy_lists, final_responses_list, branches_info)``, each with
            one entry per original dialogue that produced at least one
            simulated branch, ordered by original index.
        """
        completed_branches = []  # branches whose step was answered
        stalled_branches = []    # branches that could not be expanded further
        active_branches = [{
            'history': history,
            'path': [i],             # index of the original dialogue
            'strategy_path': [],     # strategies applied so far
            'complete': False,
            'question': questions[i],
            'step': steps[i],
            'predict_state': predict_states[i],
            'previous_vali': previous_vali[i],
        } for i, history in enumerate(histories)]

        for _ in range(iterations):
            if not active_branches:
                break

            active_questions = [branch['question'] for branch in active_branches]
            active_steps = [branch['step'] for branch in active_branches]
            active_histories = [branch['history'] for branch in active_branches]
            active_previous_vali = [branch['previous_vali'] for branch in active_branches]

            strategies = self.func1(active_questions, active_steps, active_histories, topn, active_previous_vali)
            grouped_results, prompt_sources = self.func2(active_questions, active_steps, strategies, active_histories)

            # Index replies by (branch, strategy) so a missing or short group
            # can never shift replies onto the wrong branch.
            flat_replies = [reply for group in grouped_results for reply in group]
            reply_for = dict(zip(prompt_sources, flat_replies))

            new_branches_data = []
            for i, branch in enumerate(active_branches):
                strategy_group = strategies[i] if i < len(strategies) else []
                spawned = 0
                # The LLM may return fewer than topn strategies.
                for j, strategy_desc in enumerate(strategy_group[:topn]):
                    if (i, j) not in reply_for:
                        continue
                    new_branches_data.append({
                        'history': branch['history'] + f" Teacher: {reply_for[(i, j)]}|EOM|",
                        'question': branch['question'],
                        'step': branch['step'],
                        'predict_state': branch['predict_state'],
                        'parent_branch': branch,
                        'strategy_index': strategy_desc,
                    })
                    spawned += 1
                if spawned == 0 and branch['strategy_path']:
                    # Keep an already-scored branch as a failed roll-out
                    # instead of silently losing it.
                    stalled_branches.append(branch)

            if not new_branches_data:
                active_branches = []
                break

            new_histories = [data['history'] for data in new_branches_data]
            new_questions = [data['question'] for data in new_branches_data]
            new_steps = [data['step'] for data in new_branches_data]
            new_predict_states = [data['predict_state'] for data in new_branches_data]

            stu_responses = self.func3(new_questions, new_histories, new_predict_states)
            # Only the reply text belongs in the history, both for validation
            # and for the branch that is carried forward.
            histories_with_reply = [
                history + f" Student: {response[0]['response']}|EOM|"
                for history, response in zip(new_histories, stu_responses)
            ]
            validation_results = self.func4(new_questions, new_steps, histories_with_reply)

            new_active_branches = []
            for idx, data in enumerate(new_branches_data):
                parent_branch = data['parent_branch']
                verdict = validation_results[idx][0]
                # Accept "1" as well as the integer 1; a missing or empty field
                # counts as not complete.
                is_complete = str(verdict.get('is_complete', '')).strip() == "1"

                new_branch = {
                    'history': histories_with_reply[idx],
                    'path': parent_branch['path'].copy(),
                    'strategy_path': parent_branch['strategy_path'] + [data['strategy_index']],
                    'complete': is_complete,
                    'question': parent_branch['question'],
                    'step': parent_branch['step'],
                    'predict_state': parent_branch['predict_state'],
                    'previous_vali': verdict['type'],
                }

                if is_complete:
                    completed_branches.append(new_branch)
                else:
                    new_active_branches.append(new_branch)

            active_branches = new_active_branches

        all_branches = completed_branches + stalled_branches + active_branches

        # Pick the best first-step strategy per original dialogue.
        best_res = self.calculate_strategy_scores(all_branches)
        sorted_indices = sorted(best_res.keys())

        best_questions = [best_res[idx]['question'] for idx in sorted_indices]
        best_steps = [best_res[idx]['step'] for idx in sorted_indices]
        best_strategies = [best_res[idx]['strategy'] for idx in sorted_indices]
        ori_histories = [histories[idx] for idx in sorted_indices]
        states = [best_res[idx]['predict_state'] for idx in sorted_indices]
        aim_responses = [best_res[idx]['response'] for idx in sorted_indices]

        final_responses = self.func5(best_questions, best_steps, best_strategies, ori_histories, states, aim_responses)

        final_responses_list = []
        branches_info = []
        strategy_lists = []

        # Walk by original index (not by position) so the results stay correct
        # even if some dialogue produced no branch.
        for position, idx in enumerate(sorted_indices):
            final_response = final_responses[position][0]['response']
            final_responses_list.append([{'response': final_response}])

            strategy_lists.append([{
                'strategy': best_res[idx]['strategy'],
                'purpose': ""  # placeholder; no purpose is computed here
            }])

            branches_info.append([
                {
                    'strategy_path': branch['strategy_path'],
                    'complete': branch['complete'],
                    'history': branch['history'],
                }
                for branch in all_branches
                if branch['path'][0] == idx
            ])

        return strategy_lists, final_responses_list, branches_info