import os
import json
from langchain.chat_models import ChatOpenAI
from langchain import PromptTemplate
from utils import extract_json


class Validate(object):
    """Two-stage LLM check of a student's latest reply in a step-by-step tutoring dialogue.

    Stage 1 (``validata_prompt``, run by ``forward``) classifies the student's
    last reply against the current decomposition step and reports whether
    that step is complete (``is_complete``).

    Stage 2 (``decide_prompt``, run by ``forward_decide``) is run only for
    items whose step was reported complete. It asks whether the student has
    reached the reference answer. If so, ``forward`` marks the item with
    ``is_end == '1'`` and copies the model's reply into ``'res'``.
    """

    def __init__(self, llm):
        """Store the LLM and the two prompt templates.

        Args:
            llm: chat model used for both stages. It must expose
                ``batch(prompts)``, which is the only method called on it here.
        """
        # NOTE: the attribute name "validata_prompt" (sic) is kept unchanged in
        # case other code reads it.
        # Both templates use f-string syntax: {name} is a template variable and
        # {{ / }} are literal braces (used for the JSON examples).
        # The JSON examples use straight double quotes only, so a model that
        # copies the example still emits parseable JSON.
        self.validata_prompt = """Scenario: The teacher is tutoring the student on a problem that has been broken down into multiple independent steps. Each step contains a target question. You are the student status detection assistant, analyzing the student's response in the last round.

            Task description:
            1. Input content:
            - **Original question**: the full context of the question.
            - **Original question analysis**: the analysis of the original question.
            - **Current decomposition step**: the problem to be solved in the current step.
            - **Analysis of current decomposition step**: the analysis of the current decomposition step.
            - **Conversation history**: a record of the historical interactions between the student and the teacher.

            2. Output content: a JSON object containing the following fields:
            - type: the **response category** of the student's last response. Select exactly one of the following categories and output only its number (1/2/3/4/5), without the category name or its description:
                - **Category 1**: Student did not understand the teacher's explanation: The teacher explained the knowledge or provided methods/strategies, but the student could not understand and asked questions about the teacher's explanation
                - **Category 2**: Student gave an incorrect answer: The teacher asked the student a question in the last round, and the student gave an incorrect answer
                - **Category 3**: Student has a poor grasp of the knowledge and cannot answer the teacher's last round of questions/Student said they did not know a certain piece of knowledge/Student did not understand the teacher's question
                - **Category 4**: Student answered correctly: The teacher asked the student a question in the last round, and the student's answer is completely correct
                - **Category 5**: Other: situations that do not fall into the above categories
            - analysis (only for category 4; otherwise an empty string): Please output the content of the **current decomposition step** first, and then judge according to the student's response in the last round whether the student's answer fully and accurately answers the question in the **current decomposition step**. If so, output "sufficiently answered"; otherwise output "not sufficiently answered".
            - is_complete (only for category 4; otherwise an empty string): if the student's response has fully answered the **current decomposition step**, output "1"; otherwise output "0".
            **Output format**:
            The output should be a Markdown code snippet in the following format, including the surrounding "```json" and "```" tags.
            For example:
            {{
            "type": "Response category number (1/2/3/4/5)",
            "analysis": "Analysis for category 4",
            "is_complete": "0/1"
            }}

            **Original question**: {problem}
            **Original question analysis**: {analy}
            **Current decomposition step**: {ques}
            **Current decomposition step analysis**: {ans}
            **Conversation history**:
            #
            {history}
            #

            Your output:"""

        self.decide_prompt = """Scenario: You are an excellent high school teacher and are tutoring a student through a conversation.
        Input:
        **Question:**
        ##
        {problem}
        ##
        **Reference analysis:**
        ##
        {analysis}
        ##
        **Reference answer:**
        ##
        {answer}
        ##
        **Conversation history:**
        ##
        {history}
        ##

        Based on the student's last message and the reference answer, please determine whether the student has clearly arrived at the correct answer in the last round and whether it is the same as the reference answer (the answer is correct). If so, please enter 1 in the type field and reply to the student's last message in the res field. If not, please enter 0 in the type field and leave the res field empty.
        **Output format**:
            The output should be a Markdown code snippet in the following format, including the surrounding "```json" and "```" tags.
            For example:
            {{
            "type": "0/1",
            "res": ""
            }}

        Your output:"""

        self.llm = llm

    @staticmethod
    def _render(template, **variables):
        """Fill an f-string-style prompt ``template`` through LangChain's PromptTemplate."""
        return PromptTemplate.from_template(template).format(**variables)

    @staticmethod
    def _is_flag_set(value):
        """Return True when an LLM-emitted flag equals "1", ignoring surrounding whitespace.

        An exact comparison is used instead of ``"1" in str(value)``. Otherwise
        an echoed placeholder such as "0/1" (as shown in the prompt examples)
        would count as positive. Missing values (``None``) are treated as not set.
        """
        return str(value).strip() == "1"

    def forward(self, histories, problems, current_steps, answers, analysiss):
        """Classify each student's last reply and detect finished problems.

        All arguments are parallel sequences: item ``i`` of each describes the
        same dialogue.

        Args:
            histories: conversation histories, inserted verbatim into the prompts.
            problems: original problem texts.
            current_steps: ``(question, analysis)`` pairs for the current
                decomposition step.
            answers: reference answers; only used by the stage-2 check.
            analysiss: analyses of the original problems.

        Returns:
            The value ``extract_json`` returns for the stage-1 LLM outputs.
            Each ``responses[i][0]`` dict gets an ``'is_end'`` key (``'0'``) if it
            lacks one. Items confirmed by stage 2 get ``'is_end' = '1'`` and a
            ``'res'`` key holding the model's reply.
        """
        prompts = [
            self._render(
                self.validata_prompt,
                problem=problem,
                analy=analysis,
                ques=ques,
                ans=ans,
                history=history,
            )
            for history, problem, (ques, ans), analysis in zip(
                histories, problems, current_steps, analysiss
            )
        ]
        # NOTE(review): assumes extract_json returns one entry per LLM output,
        # in input order, whose [0] element is the parsed dict. Confirm in utils.
        responses = extract_json(self.llm.batch(prompts))

        # Only items whose current step was judged complete go to stage 2.
        # .get() so a field the model omitted counts as "not complete" instead
        # of raising KeyError.
        decide_indices = []
        for i, res in enumerate(responses):
            res[0].setdefault('is_end', '0')
            if self._is_flag_set(res[0].get('is_complete')):
                decide_indices.append(i)

        if not decide_indices:
            return responses

        decide_responses = self.forward_decide(
            [histories[i] for i in decide_indices],
            [problems[i] for i in decide_indices],
            [analysiss[i] for i in decide_indices],
            [answers[i] for i in decide_indices],
        )

        # Merge stage-2 verdicts back into the stage-1 results by original index.
        for idx, decide_response in zip(decide_indices, decide_responses):
            if self._is_flag_set(decide_response[0].get('type')):
                responses[idx][0]['is_end'] = '1'
                responses[idx][0]['res'] = decide_response[0].get('res', '')

        return responses

    def forward_decide(self, histories, problems, analysiss, answers):
        """Ask the LLM whether each student has reached the reference answer.

        Args:
            histories: conversation histories.
            problems: original problem texts.
            analysiss: reference analyses of the problems.
            answers: reference answers of the problems.

        Returns:
            The value ``extract_json`` returns for the LLM outputs. Per the
            prompt, each entry's [0] dict is expected to carry ``"type"``
            ("0"/"1") and ``"res"``.
        """
        prompts = [
            self._render(
                self.decide_prompt,
                problem=problem,
                analysis=analysis,
                answer=answer,
                history=history,
            )
            for history, problem, analysis, answer in zip(
                histories, problems, analysiss, answers
            )
        ]
        return extract_json(self.llm.batch(prompts))