# Reliability: scores how reliable the teacher's explanations are in tutoring dialogues.
import json
import os
import argparse
from langchain.chat_models import ChatOpenAI
from langchain import PromptTemplate
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_models.tongyi import ChatTongyi
from utils import extract_json


# Prompt sent to the LLM judge for each item. The {problem}, {answer} and
# {dialogue} placeholders are filled in load_and_process_json; the doubled
# braces ("{{" / "}}") are escapes that render as literal braces.
# NOTE(review): the JSON example inside the prompt uses typographic quotes and
# a trailing comma, so it is not valid JSON as written — confirm whether the
# judge model should be shown a strictly valid example instead.
template = """
You are a scoring expert and you are evaluating the “reliability” of the teacher's explanations, specifically whether there are any errors in the explanations and whether the teacher corrects students if they give the wrong answer.
Input:
**Question**: {problem}
**Reference answer**: {answer}
**Dialogue history**: {dialogue}
Instructions:
Compare the **reference answer** with the **dialogue history** to determine whether the teacher's explanation is correct and whether the teacher corrects the student if the student gives an incorrect answer. Determine whether the answer is consistent with the reference answer.
Scoring range: 0-5, of which:
0: extremely unreliable, the teacher does not correct the student's mistakes and even confirms the incorrect answer as the correct one.
5: completely reliable, no mistakes, the explanation is accurate and believable.
Note that whether the student ultimately gives the correct answer does not affect your grading, and you need to pay attention to the teacher's behavior.
Output: Give the score and the corresponding reason for the score. The reason should be concise (one sentence). According to the following format:
The output should be a Markdown code snippet in the following format, including the surrounding tags “```json” and “```”
{{
“score”: “0-5”,
“reason“: ‘Reason for the score’,
}}

"""
# Module-level chat model shared by every scoring request. temperature=0 is
# used to reduce sampling randomness in the returned scores.
# NOTE(review): no API key is passed explicitly — presumably it is read from
# the environment; confirm before running.
llm = ChatOpenAI(
    model_name="gpt-4o",
    temperature=0,
    openai_api_base="https://api.xty.app/v1",  # non-default, OpenAI-compatible endpoint
    max_tokens=4096
)

def extract_score_from_text(text):
    """Extract the 0-5 reliability score from a judge response.

    The digit attached to a ``score`` label (for example ``"score": "4"`` or
    ``Score: 4``) is preferred, so that unrelated digits earlier in the text
    (such as "step 2") are not mistaken for the score. If no labelled score
    is present, the first standalone digit in the range 0-5 is used.

    Args:
        text: The response text; non-string values are converted with str().

    Returns:
        The score as an int in the range 0-5. Returns 0 when no score can be
        found, which callers cannot distinguish from a genuine score of 0.
    """
    import re

    text = str(text)

    # Preferred: a digit that directly follows a "score" label, separated only
    # by punctuation/whitespace (quotes, colon, spaces, ...).
    labelled = re.search(r'score\W*([0-5])\b', text, re.IGNORECASE)
    if labelled:
        return int(labelled.group(1))

    # Fallback: first standalone digit 0-5 anywhere in the text.
    standalone = re.search(r'\b[0-5]\b', text)
    return int(standalone.group(0)) if standalone else 0

def _build_dialogue(item):
    """Flatten the 'tutor' turns of one item into a single dialogue string."""
    history = []
    for turn in item.get('tutor', []):
        history.append(f"Teacher: {turn.get('teacher_res', '')} |EOM|")
        stu_res = turn.get('stu_res', '')
        if stu_res:
            if stu_res.startswith('state:'):
                # Drop the "state:" tag and any digits immediately after it,
                # keeping only the remaining student text.
                stu_res = stu_res.split(':', 1)[1].lstrip('0123456789').strip()
            history.append(f"Student: {stu_res} |EOM|")
    return " ".join(history)


def load_and_process_json(file_path, output_dir='output', batch_size=10, output_file='output.json'):
    """Score every item of a JSON file with the LLM judge and save the results.

    The input is read with json.load and processed in slices of ``batch_size``
    (assumes the top-level JSON value is a list of dicts with 'question',
    'analysis' and 'tutor' keys — TODO confirm against the data producer).
    For each item a prompt is rendered from the module-level ``template`` and
    the batch is sent to the module-level ``llm``. One JSON object per item is
    written per line (JSON Lines) to ``output_dir/output_file``; the dialogue
    text is included only for items scoring below 4.

    Args:
        file_path: Path to the input JSON file.
        output_dir: Directory for the output file; created if missing.
        batch_size: Number of items sent to the model per batch (>= 1).
        output_file: Name of the output file inside ``output_dir``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Fail before touching the filesystem: a zero step would make range() raise
    # later, and a negative one would silently produce an empty output file.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_file)

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Parse the prompt template once instead of once per item.
    prompt_template = PromptTemplate.from_template(template)
    # Ceiling division so exact multiples of batch_size are not over-counted.
    total_batches = (len(data) + batch_size - 1) // batch_size

    with open(output_path, 'w', encoding='utf-8') as f:
        batch_starts = range(0, len(data), batch_size)
        for batch_number, start in enumerate(batch_starts, start=1):
            batch = data[start:start + batch_size]
            dialogues = [_build_dialogue(item) for item in batch]
            prompts = [
                prompt_template.format(
                    problem=item.get('question', ''),
                    answer=item.get('analysis', ''),
                    dialogue=dialogue,
                )
                for item, dialogue in zip(batch, dialogues)
            ]

            responses = llm.batch(prompts)
            # Score only the message text: str(message) would also include
            # metadata whose digits could be mistaken for the score.
            scores = [
                extract_score_from_text(str(getattr(response, 'content', response)))
                for response in responses
            ]

            # Write results as JSON Lines.
            for item, score, dialogue in zip(batch, scores, dialogues):
                record = {
                    'problem': item.get('question', ''),
                    'score': score,
                }
                if score < 4:
                    # Keep the dialogue only for low scores, for later review.
                    record['dialogue'] = dialogue
                f.write(json.dumps(record, ensure_ascii=False) + '\n')

            print(f"Processed batch {batch_number}/{total_batches}")

    print(f"All results saved to {output_path}")

def main():
    """Command-line entry point: parse the options and run the evaluation."""
    cli = argparse.ArgumentParser(
        description='Process reliability evaluation for teaching dialogues'
    )

    # Flag -> add_argument keyword arguments, registered in this order.
    option_specs = (
        ('--input', {'required': True, 'help': 'Path to input JSON file'}),
        ('--output_dir', {'default': 'output', 'help': 'Directory to save output files'}),
        ('--output_file', {'default': 'output.json', 'help': 'Output file name'}),
        ('--batch_size', {'type': int, 'default': 10, 'help': 'Batch size for processing'}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    options = cli.parse_args()

    # Hand the parsed options to the processing routine by keyword.
    run_kwargs = {
        'file_path': options.input,
        'output_dir': options.output_dir,
        'output_file': options.output_file,
        'batch_size': options.batch_size,
    }
    load_and_process_json(**run_kwargs)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
