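"""
Answer generation and scoring for the InternVL model (non-async).

generate_single_answer_intern_vl(prompt, image_path, level, item, retry_attempts=5)
    Generate and score the model's answer to a single question.

    Parameters:
    prompt (str): prompt template
    image_path (str): image path (note: the image actually used is item['image_path'])
    level (str): difficulty level
    item (dict): dictionary containing the question data
    retry_attempts (int): total number of attempts (retries happen only when the model call raises)

    Returns:
    dict: result containing the generated answer, or None if every attempt fails

generate_answers_intern_vl(config, user_input, level='medium')
    Run the single-question generator over a random sample of questions and
    save the answers and average scores to JSON files.
"""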
import yaml
import json
import re
import time
import os
import sys
import dotenv
import random
dotenv.load_dotenv()

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from tools.lvm_pool import intern_VL

def generate_single_answer_intern_vl(prompt, image_path, level, item, retry_attempts=5):
    """Ask InternVL one multiple-choice question and score its reply.

    Parameters:
    prompt (str): Prompt template, formatted with ``aspect`` and ``question``.
    image_path (str | None): Not used. The image sent to the model (and
        recorded in the result) is always ``item['image_path']``; the argument
        is kept only so existing callers keep working.
    level (str): Difficulty level, copied verbatim into the result.
    item (dict): Question record. Required keys: ``aspect``, ``image_path``,
        ``prompt``, ``objective_question``, ``choice`` (JSON-serialisable) and
        ``objective_reference_answer``. Optional: ``need_elements``.
    retry_attempts (int): Total number of ``intern_VL`` calls to try. Only a
        raised exception triggers another attempt; an answer without a
        ``[[...]]`` choice is returned (and scored 0) immediately.

    Returns:
    dict | None: Result record with the raw answer, the extracted choice and
    a 0/1 ``objective_score``; ``None`` if every attempt raised or
    ``retry_attempts`` < 1.
    """
    aspect = item['aspect']
    # The item's own image path always wins over the argument (see docstring).
    image_path = item['image_path']
    prompt_text = item['prompt']
    # ensure_ascii=False keeps non-ASCII (e.g. Chinese) choices readable for the
    # model instead of sending them as \uXXXX escape sequences.
    objective_question = (
        item['objective_question'] + '\n' + json.dumps(item['choice'], ensure_ascii=False)
    )
    objective_reference_answer = item['objective_reference_answer']
    formatted_prompt = prompt.format(aspect=aspect, question=objective_question)

    def extract_choice(text):
        """Return the stripped contents of the first ``[[...]]`` in *text*, or None."""
        if not isinstance(text, str):
            return None
        match = re.search(r'\[\[(.*?)\]\]', text)
        # Strip so that answers like "[[ A ]]" still match a reference of "A".
        return match.group(1).strip() if match else None

    for attempt in range(retry_attempts):
        try:
            # Only the model call is expected to fail transiently; keep the try narrow.
            objective_answer = intern_VL(formatted_prompt, image_path)
        except Exception as e:
            print(f"InternVL Attempt {attempt + 1} failed: {e}")
            if attempt == retry_attempts - 1:
                return None
            time.sleep(3)  # back off before the next attempt
            continue

        objective_choice = extract_choice(objective_answer)

        # An empty or unextractable choice never scores, even if the reference
        # answer happens to be None/empty.
        objective_score = (
            1 if objective_choice and objective_choice == objective_reference_answer else 0
        )

        return {
            "aspect": aspect,
            "prompt": prompt_text,
            "image_path": image_path,
            'level': level,
            'model': 'intern_VL',
            'objective_question': objective_question,
            'objective_answer': objective_answer,
            'need_elements': item.get('need_elements', False),
            'objective_choice': objective_choice if objective_choice else "未能提取选项",
            'objective_score': objective_score,
            'objective_reference_answer': objective_reference_answer
        }

    # Reached only when retry_attempts < 1: no attempt was made.
    return None

def generate_answers_intern_vl(config, user_input, level='medium', sample_size=150, seed=None):
    """
    Main driver: answer a random sample of questions with InternVL and save the results.

    Reads ``./document/{user_input}/questions/{level}_questions_modified.json``,
    answers up to ``sample_size`` randomly chosen questions, writes every
    successful result to ``answers/{level}_answers_intern_vl.json`` and the
    average objective score to ``scores/{level}_scores_intern_vl.json``
    (both under ``./document/{user_input}``).

    Parameters:
    config (dict): configuration; must contain ``objective_answer_prompt``
    user_input (str): task name, used to build the file paths
    level (str): difficulty level, defaults to 'medium'
    sample_size (int): maximum number of questions to sample; if fewer
        questions exist, all of them are used
    seed (int | None): seed for a reproducible sample; ``None`` uses the
        global ``random`` state (the original behavior)

    Returns:
    None: results are written to files

    Raises:
    ValueError: if ``config`` has no ``objective_answer_prompt``
    """
    # Create the task folder (harmless if it already exists).
    image_prompt_folder = f'./document/{user_input}'
    os.makedirs(image_prompt_folder, exist_ok=True)

    # Load the question file.
    questions_file = f'{image_prompt_folder}/questions/{level}_questions_modified.json'
    with open(questions_file, 'r', encoding='utf-8') as f:
        questions_data = json.load(f)

    # Answer prompt template; fail fast instead of crashing later on None.format().
    objective_answer_prompt = config.get('objective_answer_prompt')
    if objective_answer_prompt is None:
        raise ValueError("config is missing 'objective_answer_prompt'")

    # Clamp the sample size: random.sample raises if asked for more items than exist.
    rng = random if seed is None else random.Random(seed)
    questions_data_sample = rng.sample(questions_data, min(sample_size, len(questions_data)))

    results = []
    objective_scores = []
    total = len(questions_data_sample)

    print(f"开始处理 {total} 个问题...")

    for i, item in enumerate(questions_data_sample):
        print(f"处理问题 {i+1}/{total}...")
        result = generate_single_answer_intern_vl(
            objective_answer_prompt,
            None,  # the image path comes from the item itself
            level,
            item
        )

        # A None result means every attempt failed; skip it rather than score 0.
        if result:
            results.append(result)
            objective_scores.append(result['objective_score'])

    # Save the answers. ensure_ascii=False keeps Chinese text human-readable.
    answers_dir = f'{image_prompt_folder}/answers'
    os.makedirs(answers_dir, exist_ok=True)
    save_file = f'{answers_dir}/{level}_answers_intern_vl.json'

    with open(save_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4, ensure_ascii=False)

    print(f"保存了 {len(results)} 个回答到 {save_file}")

    # Compute and save the average score.
    scores_dir = f'{image_prompt_folder}/scores'
    os.makedirs(scores_dir, exist_ok=True)
    scores_file = f'{scores_dir}/{level}_scores_intern_vl.json'

    average = sum(objective_scores) / len(objective_scores) if objective_scores else 0
    avg_scores = {
        'intern_VL': {
            'average_objective_score': average,
            'objective_num': len(objective_scores),
        }
    }
    print(f'InternVL 平均客观分数 ({level}): {average:.2f}')

    with open(scores_file, 'w', encoding='utf-8') as f:
        json.dump(avg_scores, f, indent=4, ensure_ascii=False)

    print(f"分数已保存到 {scores_file}")

# Script entry point: load the YAML config, then evaluate every difficulty level in turn.
if __name__ == "__main__":
    config_path = './config/config.yaml'
    with open(config_path, 'r', encoding='utf-8') as config_file:
        config = yaml.safe_load(config_file)

    task_name = 'atmospheric_understanding'
    for level in ('easy', 'medium', 'hard'):
        generate_answers_intern_vl(config, task_name, level)