import json
import yaml
import sys
import os
import aiofiles
import asyncio
import uuid
import re
import csv
from tqdm import tqdm
import aiohttp
import random
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from tqdm.asyncio import tqdm_asyncio

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from tools.tifa import generate_and_evaluate_image
from tools.diffusion_model import async_openjourney_v4, async_sdxl, async_stable_diffusion, async_stable_diffusion_3, async_flux_pro, async_flux_1_1_pro, async_dalle_3, async_stable_diffusion_3_5
from tools.lvm_pool import async_gpt4o, async_gemini_1_5_flash, async_claude_3_5_sonnet,async_claude_3_haiku, async_glm_4v_plus, async_gpt4o_mini, async_llama_3_2, async_gemini_1_5_pro, async_qwen_2_vl, async_gemma_3, async_llava, async_qwen_2_5

# Registry of async text-to-image generators, keyed by a short model name.
# generate_single_image / generate_single_image_1 look up 'flux_1_1_pro' here.
async_diffusion_function = {
    "dall_e_3": async_dalle_3,
    "openjourney_v4": async_openjourney_v4,
    "sdxl": async_sdxl,
    "stable_diffusion": async_stable_diffusion,
    "stable_diffusion_3": async_stable_diffusion_3,
    "stable_diffusion_3_5": async_stable_diffusion_3_5,
    "flux_pro": async_flux_pro,
    "flux_1_1_pro": async_flux_1_1_pro
}

# Vision-language models under evaluation. generate_answers queries every entry
# for every question, and the visualization/CSV helpers map each function's
# __name__ to a display label via lvm_func_to_name. The "gpt-4o" entry is also
# used directly for aspect/guidance generation and as the prompt-generation
# fallback. Entries commented out below are disabled; note that async_phi is
# not imported in this file, so re-enabling "phi_4" also needs an import.
async_lvm_function = {
    "gpt-4o": async_gpt4o,
    "gpt4o_mini": async_gpt4o_mini,
    "gemini_1_5_flash": async_gemini_1_5_flash,
    "gemini_1_5_pro": async_gemini_1_5_pro,
    "claude_3_5_sonnet": async_claude_3_5_sonnet,
    "claude_3_haiku": async_claude_3_haiku,
    "glm_4v_plus": async_glm_4v_plus,
    "qwen2_vl": async_qwen_2_vl,
    "llama_3_2": async_llama_3_2,
    # "gemma_3": async_gemma_3,
    # "phi_4": async_phi,
    # "llava-v1.6-13b": async_llava,
    # "qwen2.5_vl": async_qwen_2_5,
}

# Models that author image prompts and questions; one is chosen at random per
# request in generate_prompt_with_topic_words and generate_questions.
async_examiner_function = {
    "gpt-4o": async_gpt4o,
    "claude_3_5_sonnet": async_claude_3_5_sonnet,
    "gemini_1_5_pro": async_gemini_1_5_pro,
}

# Human-readable labels for plots/CSV, keyed by model function __name__.
# NOTE(review): keys must match the real __name__ of each function in
# async_lvm_function (defined in tools.lvm_pool) — confirm when adding models.
lvm_func_to_name = {
    "async_gpt4o": "GPT-4o",
    "async_gpt4o_mini": "GPT-4o-Mini",
    "async_gemini_1_5_flash": "Gemini-1.5-Flash",
    "async_gemini_1_5_pro": "Gemini-1.5-Pro",
    "async_claude_3_5_sonnet": "Claude-3.5-Sonnet",
    "async_claude_3_haiku": "Claude-3-Haiku",
    "async_glm_4v_plus": "GLM-4v-Plus",
    "async_llama_3_2":"Llama-3.2-90B-Vision",
    "async_qwen_2_vl":"Qwen2-VL"
}

async def load_config(config_file_path):
    """Asynchronously read a UTF-8 YAML file and return the parsed object.

    Parsing uses ``yaml.safe_load``, so only plain YAML types are produced.
    """
    async with aiofiles.open(config_file_path, 'r', encoding='utf-8') as handle:
        raw_text = await handle.read()
    parsed = yaml.safe_load(raw_text)
    return parsed

async def load_json(file_path):
    """Asynchronously read *file_path* and return its decoded JSON content.

    The file is decoded explicitly as UTF-8 (matching ``load_config``) so the
    result does not depend on the platform's default locale encoding.
    """
    async with aiofiles.open(file_path, 'r', encoding='utf-8') as file:
        content = await file.read()
    return json.loads(content)

async def save_json(data, file_path):
    """Serialize *data* as 4-space-indented JSON and write it to *file_path*.

    The data is serialized BEFORE the file is opened: opening in 'w' mode
    truncates the target, so serializing first keeps an existing file intact
    if *data* turns out not to be JSON-serializable. The parent directory must
    already exist.
    """
    serialized = json.dumps(data, indent=4)
    async with aiofiles.open(file_path, 'w', encoding='utf-8') as file:
        await file.write(serialized)

def save_plot(fig, path, dpi=300):
    """Write the matplotlib figure *fig* to *path* as PNG at *dpi* resolution."""
    savefig_options = {'format': 'png', 'dpi': dpi}
    fig.savefig(path, **savefig_options)

async def download_image(session, url, save_path):
    """Download *url* with the aiohttp *session* and write the bytes to *save_path*.

    Non-2xx responses raise through ``raise_for_status``. Any failure is
    reported on stdout and then re-raised to the caller.
    NOTE(review): ``ssl=False`` disables TLS certificate verification for this
    request — confirm that is intended for the image hosts used.
    """
    try:
        async with session.get(url, ssl=False) as resp:
            resp.raise_for_status()
            payload = await resp.read()
        # The body is fully buffered, so the file can be written after the
        # response has been released.
        async with aiofiles.open(save_path, 'wb') as sink:
            await sink.write(payload)
    except Exception as err:
        print(f"Failed to download image from {url}: {err}")
        raise

async def save_aspects(aspects, aspects_file_path):
    """Write *aspects* as 4-space-indented JSON, creating parent directories.

    ``os.path.dirname`` returns '' for a bare filename and ``os.makedirs('')``
    raises FileNotFoundError, so directory creation is skipped in that case.
    The JSON is serialized before the file is opened so a serialization error
    cannot truncate an existing file.
    """
    parent_dir = os.path.dirname(aspects_file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    serialized = json.dumps(aspects, indent=4)
    async with aiofiles.open(aspects_file_path, 'w', encoding='utf-8') as file:
        await file.write(serialized)

def _parse_aspects_response(aspects_response):
    """Parse "<aspect label>"/"Introduction:" line pairs into a list of dicts.

    Each record is started by an aspect line and completed (appended) by the
    next "Introduction:" line. Both the historical misspelled label
    "Fined-grained Aspect:" and the correct "Fine-grained Aspect:" are
    accepted, so a model that "fixes" the typo is still parsed.
    """
    aspect_prefixes = ("Fined-grained Aspect:", "Fine-grained Aspect:")
    aspects = []
    current_aspect = None
    for raw_line in aspects_response.split('\n'):
        line = raw_line.strip()
        prefix = next((p for p in aspect_prefixes if line.startswith(p)), None)
        if prefix is not None:
            current_aspect = {"aspect": line[len(prefix):].strip()}
        elif line.startswith("Introduction:") and current_aspect:
            current_aspect["introduction"] = line[len("Introduction:"):].strip()
            aspects.append(current_aspect)  # record is complete
            current_aspect = None  # wait for the next aspect line
    return aspects


async def generate_fine_grained_aspects(user_input):
    """Ask GPT-4o for fine-grained aspects of *user_input* and save them as JSON.

    The prompt is read from the config key ``f'{user_input}_prompt'`` in
    config/config.yaml. The raw response is printed, parsed into
    ``[{"aspect": ..., "introduction": ...}, ...]`` and written to
    ./document/{user_input}/{user_input}_aspects.json.
    """
    config = await load_config('config/config.yaml')
    aspect_prompt = config.get(f'{user_input}_prompt')
    aspects_response = await async_lvm_function['gpt-4o'](aspect_prompt)
    print(aspects_response)

    aspects = _parse_aspects_response(aspects_response)

    aspects_file_path = f'./document/{user_input}/{user_input}_aspects.json'
    await save_aspects(aspects, aspects_file_path)

    print("Fine-grained aspects generated and saved successfully!")

async def generate_guidance(user_input):
    """Generate guidance text for each saved aspect of *user_input*.

    Reads ./document/{user_input}/{user_input}_aspects.json and formats the
    config's ``guidance_prompt`` with each aspect's name and introduction. The
    formatted prompt is sent to GPT-4o, and "Aspect:", "Introduction:" and
    "Guidance:" lines are parsed from each reply. A record is saved once its
    "Guidance:" line is seen. All records are written to
    ./document/{user_input}/{user_input}_guidance.json.
    """
    aspects_data = await load_json(f'./document/{user_input}/{user_input}_aspects.json')
    config = await load_config('config/config.yaml')
    guidance_prompt = config.get('guidance_prompt')
    guidance = []
    for aspect in aspects_data:
        prompt = guidance_prompt.format(aspect=aspect['aspect'], introduction=aspect['introduction'])
        response = await async_lvm_function['gpt-4o'](prompt)
        current_content = None
        for raw_line in response.split('\n'):
            line = raw_line.strip()
            if line.startswith("Aspect:"):
                current_content = {"aspect": line[len('Aspect:'):].strip()}
            elif line.startswith("Introduction:") and current_content is not None:
                # Guarded: an Introduction line before any Aspect line (or
                # after a completed record) previously raised TypeError.
                current_content["introduction"] = line[len('Introduction:'):].strip()
            elif line.startswith("Guidance:") and current_content is not None:
                current_content["guidance"] = line[len("Guidance:"):].strip()
                guidance.append(current_content)
                current_content = None
    await save_json(guidance, f'./document/{user_input}/{user_input}_guidance.json')

async def _request_prompt_response(model_name, image_description, retry_times=3):
    """Query examiner *model_name*; fall back to GPT-4o if it returns no text.

    Returns ``(response_text, model_used)``, or ``(None, None)`` when neither
    the examiner nor any of *retry_times* GPT-4o attempts returned a string.
    """
    prompt_response = await async_examiner_function[model_name](image_description)
    if isinstance(prompt_response, str):
        return prompt_response, model_name
    for _ in range(retry_times):
        prompt_response = await async_lvm_function['gpt-4o'](image_description)
        if isinstance(prompt_response, str):
            return prompt_response, 'gpt-4o'
    return None, None


def _parse_prompt_response(prompt_response):
    """Extract ``(prompt, topic_word, key_words, key_words_list)`` from a reply.

    Lines are stripped before matching. Parsing stops at the first
    "Key words:"/"Key word:" line, so "Prompt:" and "Topic word:" must come
    before it. Topic and key words are lower-cased; empty key words dropped.
    """
    prompt = None
    topic_word = None
    key_words = None
    key_words_list = []
    for raw_line in prompt_response.split('\n'):
        line = raw_line.strip()
        if line.startswith("Prompt:"):
            prompt = line[len("Prompt:"):].strip()
        elif line.startswith("Topic word:"):
            topic_word = line[len("Topic word:"):].strip().lower()
        else:
            # Slice by the prefix that actually matched; slicing "Key word:"
            # lines by len("Key words:") dropped a character.
            prefix = next((p for p in ("Key words:", "Key word:") if line.startswith(p)), None)
            if prefix is not None:
                key_words = line[len(prefix):].strip().lower()
                key_words_list = [word.strip() for word in key_words.split(',') if word.strip()]
                break
    return prompt, topic_word, key_words, key_words_list


async def generate_prompt_with_topic_words(aspects, image_prompt_template, level, num_prompts_per_aspect):
    """Generate image prompts per aspect while steering away from overused words.

    For every ``(aspect, introduction)`` pair, runs *num_prompts_per_aspect*
    rounds. Each round asks a randomly chosen examiner model (GPT-4o fallback
    on a non-text reply) for a prompt, topic word and key words. Topic/key
    words go into a per-aspect co-occurrence graph; after round r the r+1
    highest-degree words join ``used_words``, which is passed back into the
    template so later prompts avoid them.

    Returns:
        (prompts, all_topic_word_degrees): the prompt records (aspect, prompt,
        topic_word, key_words, model that produced them) and a list of
        ``(topic_word, degree)`` at the time each prompt was accepted.
    """
    prompts = []
    all_topic_word_degrees = []

    for aspect, introduction in aspects:
        graph = nx.Graph()
        used_words = set()

        for round_num in range(num_prompts_per_aspect):
            model_name = random.choice(list(async_examiner_function.keys()))
            used_words_str = ', '.join(used_words)
            image_description = image_prompt_template.format(aspect=aspect, introduction=introduction, level=level, used_words_str=used_words_str)

            prompt_response, model_used = await _request_prompt_response(model_name, image_description)
            if prompt_response is None:
                # Previously this crashed with AttributeError on .split().
                print(f"Round {round_num + 1} - no usable response for aspect '{aspect}', skipping.")
                continue

            prompt, topic_word, key_words, key_words_list = _parse_prompt_response(prompt_response)
            if not (prompt and topic_word and key_words):
                continue

            prompts.append({
                "aspect": aspect,
                "prompt": prompt,
                "topic_word": topic_word,
                "key_words": key_words,
                "model": model_used
            })
            graph.add_node(topic_word)
            for key_word in key_words_list:
                graph.add_edge(topic_word, key_word)  # adds key_word node too

            degree_dict = dict(graph.degree())
            ranked = sorted(degree_dict.items(), key=lambda item: item[1], reverse=True)
            top_nodes = [node for node, _ in ranked[:round_num + 1]]
            used_words.update(top_nodes)

            all_topic_word_degrees.append((topic_word, degree_dict[topic_word]))

            print(f"Round {round_num + 1} - Top node(s) selected: {top_nodes}")

        # Report per aspect (was outside the loop: NameError on empty input
        # and only the last aspect was reported).
        print(f"Final degrees for aspect '{aspect}': {dict(graph.degree())}")

    return prompts, all_topic_word_degrees
    # return prompts

async def generate_prompts(user_input, num_prompts_per_aspect=10):
    """Generate and save image prompts for the easy/medium/hard levels.

    Loads the ``difficulty_control_image_prompt`` template and the saved
    aspects once, then calls ``generate_prompt_with_topic_words`` per level.
    Each level's prompts are written to
    ./document/{user_input}/prompts/{level}_image_prompts.json.

    Args:
        user_input: topic name used in all document paths.
        num_prompts_per_aspect: generation rounds per aspect (default 10, the
            previously hard-coded value).
    """
    config = await load_config('./config/config.yaml')
    # Loop-invariant: the template and aspects do not depend on the level.
    image_prompt_template = config.get('difficulty_control_image_prompt')
    aspects_data = await load_json(f'./document/{user_input}/{user_input}_aspects.json')
    aspects = [(aspect_data['aspect'], aspect_data['introduction']) for aspect_data in aspects_data]

    prompts_dir = f'./document/{user_input}/prompts'
    os.makedirs(prompts_dir, exist_ok=True)

    for level in ['easy', 'medium', 'hard']:
        generated_prompts, _topic_word_degrees = await generate_prompt_with_topic_words(
            aspects, image_prompt_template, level, num_prompts_per_aspect)
        save_file = f'{prompts_dir}/{level}_image_prompts.json'
        await save_json(generated_prompts, save_file)
        print(f"{user_input} {level} prompts generated and saved successfully!")

async def generate_single_image_1(item, level, image_prompt_folder, retry_attempts=3):
    """Render item['prompt'] with FLUX 1.1 Pro, download it, and return metadata.

    Plain-style variant of ``generate_single_image``. Images are stored under
    ``{image_prompt_folder}/extracted_images/{level}/{uuid}.png``. Each failed
    attempt is printed, followed by a 2 s pause before retrying.

    Returns:
        A dict (id, aspect, prefixed prompt, image_url, absolute image_path,
        level, model), or None once all *retry_attempts* have failed.
    """
    aspect = item['aspect']
    # Build the prompt once: prefixing inside the retry loop stacked the
    # prefix again on every retry.
    prompt = "please generate a picture from the perspective of an observer " + item['prompt']
    image_id = str(uuid.uuid4())
    image_folder = f'{image_prompt_folder}/extracted_images/{level}'
    image_path = f'{image_folder}/{image_id}.png'
    for attempt in range(retry_attempts):
        try:
            os.makedirs(image_folder, exist_ok=True)
            image_url = await async_diffusion_function['flux_1_1_pro'](prompt)
            print(image_url)
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=60)) as session:
                await download_image(session, image_url, image_path)
            return {
                "id": image_id,
                "aspect": aspect,
                "prompt": prompt,
                "image_url": image_url,
                "image_path": os.path.abspath(image_path),
                'level': level,
                'model': 'flux_1_1_pro'
            }
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt == retry_attempts - 1:
                return None
            await asyncio.sleep(2)  # same backoff as generate_single_image
    return None
            
async def generate_single_image(item, level, image_prompt_folder, retry_attempts=3):
    """Render item['prompt'] in oil-paint style with FLUX 1.1 Pro and download it.

    Images are stored under
    ``{image_prompt_folder}/extracted_images_oil_paint/{level}/{uuid}.png``.
    Each failed attempt is printed, followed by a 2 s pause before retrying.

    Returns:
        A dict (id, aspect, styled prompt, image_url, absolute image_path,
        level, model), or None once all *retry_attempts* have failed.
    """
    aspect = item['aspect']
    # Build the styled prompt once: re-wrapping inside the retry loop nested
    # the prefix/suffix again on every retry.
    prompt = "please generate a picture from the perspective of an observer, " + item['prompt'] + " In oil paint style."
    image_id = str(uuid.uuid4())
    image_folder = f'{image_prompt_folder}/extracted_images_oil_paint/{level}'
    image_path = f'{image_folder}/{image_id}.png'
    for attempt in range(retry_attempts):
        try:
            os.makedirs(image_folder, exist_ok=True)
            image_url = await async_diffusion_function['flux_1_1_pro'](prompt)
            print(f"Image URL: {image_url}")
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=60)) as session:
                await download_image(session, image_url, image_path)
            return {
                "id": image_id,
                "aspect": aspect,
                "prompt": prompt,
                "image_url": image_url,
                "image_path": os.path.abspath(image_path),
                'level': level,
                'model': 'flux 1.1 pro'
            }
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt == retry_attempts - 1:
                return None
            await asyncio.sleep(2)
    return None
            
async def generate_images(user_input, sample_size=60, max_concurrency=5):
    """Generate oil-paint images for a random sample of prompts per difficulty level.

    For each level, samples up to *sample_size* prompts from
    ./document/{user_input}/prompts/{level}_image_prompts.json. Images are
    rendered concurrently (at most *max_concurrency* in flight) via
    ``generate_single_image``. Successful records are saved to
    ./document/{user_input}/image_json_oil_paint/{level}_images.json.

    Args:
        user_input: topic name used in all document paths.
        sample_size: prompts sampled per level (default 60, the previously
            hard-coded value); capped at the number available.
        max_concurrency: maximum simultaneous image generations.
    """
    image_prompt_folder = f'./document/{user_input}/prompts/'
    store_folder = f'./document/{user_input}/'
    os.makedirs(image_prompt_folder, exist_ok=True)
    os.makedirs(store_folder, exist_ok=True)
    json_folder = f'{store_folder}/image_json_oil_paint'

    semaphore = asyncio.Semaphore(max_concurrency)

    async def sem_task(item, level):
        async with semaphore:
            return await generate_single_image(item, level, store_folder)

    for level in ['easy', 'medium', 'hard']:
        image_prompt_file = f'{image_prompt_folder}/{level}_image_prompts.json'
        image_prompt_data = await load_json(image_prompt_file)

        # random.sample raises ValueError if k exceeds the population size.
        k = min(sample_size, len(image_prompt_data))
        image_prompt_data_sample = random.sample(image_prompt_data, k)

        tasks = [asyncio.create_task(sem_task(item, level)) for item in image_prompt_data_sample]

        save_data = await tqdm_asyncio.gather(*tasks)
        save_data = [data for data in save_data if data is not None]
        os.makedirs(json_folder, exist_ok=True)
        await save_json(save_data, f'{json_folder}/{level}_images.json')
        print(f'{level} photos generated and saved successfully!')

async def align_single_image(item, level, retry_attempts=3, threshold=0.8, align_attempts=3):
    """Check that an image matches its prompt and keep it only if it does.

    Scores ``item['image_path']`` against ``item['prompt']`` with
    ``generate_and_evaluate_image``. On the 'easy' level the required score is
    forced to 1.0 regardless of *threshold*. A below-threshold score is
    re-evaluated up to *align_attempts* more times. If the evaluation raises,
    the whole attempt is retried, up to *retry_attempts* times.

    Returns:
        The aligned record (with score and evaluation results), or None when
        the score never reaches the requirement or every attempt raised.
    """
    aspect, image_path, prompt = item['aspect'], item['image_path'], item['prompt']
    required = 1.0 if level == 'easy' else threshold

    for attempt in range(retry_attempts):
        try:
            score, results = await generate_and_evaluate_image(image_path, prompt)
            evaluations_left = align_attempts
            while score < required and evaluations_left > 0:
                score, results = await generate_and_evaluate_image(image_path, prompt)
                evaluations_left -= 1
            if score >= required:
                return {
                    "aspect": aspect,
                    "prompt": prompt,
                    "image_path": image_path,
                    'level': level,
                    'model': 'gpt4o',
                    'score': score,
                    'align_results': results
                }
            return None
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e}.")
            if attempt == retry_attempts - 1:
                return None

async def align_images(user_input, max_concurrency=20):
    """Filter generated oil-paint images down to those aligned with their prompts.

    For each level, runs ``align_single_image`` concurrently over
    ./document/{user_input}/image_json_oil_paint/{level}_images.json. Surviving
    records are saved to
    ./document/{user_input}/aligned_image_json_oil_paint/{level}_aligned_images.json
    and the kept fraction is printed.
    """
    image_prompt_folder = f'./document/{user_input}'
    os.makedirs(image_prompt_folder, exist_ok=True)
    aligned_folder = f'{image_prompt_folder}/aligned_image_json_oil_paint'

    # Levels run one after another, so a single semaphore bounds concurrency
    # exactly as the per-level ones did.
    semaphore = asyncio.Semaphore(max_concurrency)

    async def sem_task(item, level):
        async with semaphore:
            return await align_single_image(item, level)

    for level in ['easy', 'medium', 'hard']:
        image_prompt_file = f'{image_prompt_folder}/image_json_oil_paint/{level}_images.json'
        image_prompt_data = await load_json(image_prompt_file)

        tasks = [asyncio.create_task(sem_task(item, level)) for item in image_prompt_data]

        save_data = await tqdm_asyncio.gather(*tasks)
        save_data = [data for data in save_data if data is not None]

        os.makedirs(aligned_folder, exist_ok=True)
        await save_json(save_data, f'{aligned_folder}/{level}_aligned_images.json')
        # Guard against an empty input file (previously ZeroDivisionError).
        align_rate = len(save_data) / len(image_prompt_data) if image_prompt_data else 0
        print(f'{level} images aligned and saved successfully! Totally {len(save_data)} images aligned. Align rate: {align_rate}')

async def gen_single_question(model_name, item, level, objective_question_prompt, retry_attempts=3):
    """Ask examiner *model_name* to write a multiple-choice question for one image.

    The template is formatted with the item's aspect and prompt, the *level*,
    and ``elements=None``; element hints from alignment results are currently
    disabled. The reply must be JSON, optionally wrapped in a Markdown code
    fence, with ``question``, ``options`` and ``reference_answer`` keys.

    Returns:
        A question record, or None if all *retry_attempts* failed (model error,
        invalid JSON, or missing keys).
    """
    aspect = item['aspect']
    image_path = item['image_path']
    prompt = item['prompt']
    need_elements = False
    elements = None
    for attempt in range(retry_attempts):
        try:
            objective_prompt = objective_question_prompt.format(aspect=aspect, elements=elements, level=level, prompt=prompt)
            model_function = async_examiner_function[model_name]
            objective_response = await model_function(objective_prompt)
            # Strip a ```json ... ``` or bare ``` ... ``` fence around the payload.
            objective_response = re.sub(r"```(?:json)?(.*?)```", r"\1", objective_response,
                                        flags=re.DOTALL | re.IGNORECASE).strip()
            parsed = json.loads(objective_response)  # parse once, not per key
            return {
                "aspect": aspect,
                "prompt": prompt,
                "image_path": image_path,
                'level': level,
                'model': model_name,
                'objective_question': parsed['question'],
                'choice': parsed['options'],
                'objective_reference_answer': parsed['reference_answer'],
                'need_elements': need_elements,
            }
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt == retry_attempts - 1:
                return None

    return None

async def generate_questions(user_input, max_concurrency=20):
    """Generate one multiple-choice question per aligned image, per level.

    Each aligned image in
    ./document/{user_input}/aligned_image_json_oil_paint/{level}_aligned_images.json
    gets a randomly chosen examiner model. Questions are generated
    concurrently, and successful ones are saved to
    ./document/{user_input}/questions_oil_paint/{level}_questions.json.
    """
    image_prompt_folder = f'./document/{user_input}'
    os.makedirs(image_prompt_folder, exist_ok=True)
    config = await load_config('./config/config.yaml')
    # Loop-invariant lookups hoisted out of the per-item loop.
    objective_question_prompt = config.get('objective_question_prompt')
    examiner_names = list(async_examiner_function.keys())
    questions_folder = f'{image_prompt_folder}/questions_oil_paint'

    semaphore = asyncio.Semaphore(max_concurrency)

    async def sem_task(model_name, item, level, objective_question_prompt):
        async with semaphore:
            return await gen_single_question(model_name, item, level, objective_question_prompt)

    for level in ['easy', 'medium', 'hard']:
        image_prompt_file = f'{image_prompt_folder}/aligned_image_json_oil_paint/{level}_aligned_images.json'
        image_prompt_data = await load_json(image_prompt_file)

        tasks = []
        for item in image_prompt_data:
            model_name = random.choice(examiner_names)
            tasks.append(asyncio.create_task(sem_task(model_name, item, level, objective_question_prompt)))

        save_data = await tqdm_asyncio.gather(*tasks)
        save_data = [data for data in save_data if data is not None]
        os.makedirs(questions_folder, exist_ok=True)
        await save_json(save_data, f'{questions_folder}/{level}_questions.json')
        # Guard against an empty input file (previously ZeroDivisionError).
        generate_rate = len(save_data) / len(image_prompt_data) if image_prompt_data else 0
        print(f'{level} questions generated and saved successfully! Totally {len(save_data)} questions generated. Generate rate: {generate_rate}')

async def adjust_questions(user_input, weights=None):
    """Rebalance which letter (A-D) holds the correct answer, per difficulty level.

    For every question in ./document/{user_input}/questions/{level}_questions_modified.json,
    a target letter is drawn with the given *weights*. The correct option's
    text is then swapped into that slot. Results go to
    {level}_questions_balanced.json in the same folder.

    Args:
        user_input: topic name used in all document paths.
        weights: relative probabilities for A, B, C, D. Defaults to uniform
            ``[0.25, 0.25, 0.25, 0.25]``; None replaces the old mutable list
            default.
    """
    if weights is None:
        weights = [0.25, 0.25, 0.25, 0.25]
    questions_folder = f'./document/{user_input}/questions'
    os.makedirs(questions_folder, exist_ok=True)
    options = ['A', 'B', 'C', 'D']
    for level in ['easy', 'medium', 'hard']:
        questions_file = f'{questions_folder}/{level}_questions_modified.json'
        questions_data = await load_json(questions_file)
        # A private RNG seeded with 42 yields the same sequence as
        # random.seed(42) + random.choices, without resetting global state.
        rng = random.Random(42)
        answer_sequence = rng.choices(options, weights, k=len(questions_data))
        for item, new_answer in zip(questions_data, answer_sequence):
            options_dict = item["choice"]
            correct_answer = item["objective_reference_answer"]
            if correct_answer not in options_dict or new_answer not in options_dict:
                # Previously a KeyError here aborted the whole level.
                print(f"Skipping question: answer {correct_answer!r} or target {new_answer!r} not among options")
                continue
            options_dict[correct_answer], options_dict[new_answer] = options_dict[new_answer], options_dict[correct_answer]
            item["objective_reference_answer"] = new_answer
        save_file = f'{questions_folder}/{level}_questions_balanced.json'
        await save_json(questions_data, save_file)
        print(f'{level} questions adjusted and saved successfully! Weighted answers: {weights}')

def extract_score(text):
    """Pull the numeric rating out of a judge's response text.

    A bracketed ``Rating: [[x]]`` takes precedence over a plain ``Rating: x``;
    within whichever form is found, the last occurrence wins. Returns 0.0
    when no rating is present.
    """
    candidate_patterns = (
        r'Rating:\s*\[\[(\d+(\.\d+)?)\]\]',
        r'Rating:\s*(\d+(\.\d+)?)',
    )
    for regex in candidate_patterns:
        hits = re.findall(regex, text)
        if not hits:
            continue
        last_number, _fraction = hits[-1]
        try:
            return float(last_number)
        except ValueError:
            return 0.0
    return 0.0
    
def extract_choice(text):
    """Return the first ``[[...]]``-wrapped choice in *text*, whitespace-stripped.

    Returns None when *text* is not a string or contains no ``[[...]]`` marker.
    Stripping makes answers such as "[[ B ]]" compare equal to the reference
    letter "B". Explicit checks replace the old bare ``except``, which hid
    unrelated errors.
    """
    if not isinstance(text, str):
        return None
    match = re.search(r'\[\[(.*?)\]\]', text)
    if match is None:
        return None
    return match.group(1).strip()

async def generate_single_answer(model_function, objective_answer_prompt, level, item, retry_attempts=8):
    """Have one model answer one multiple-choice question about an image, then score it.

    The question text plus its JSON-encoded options are formatted into
    *objective_answer_prompt*. ``model_function(prompt, image_path)`` is
    awaited, and the ``[[...]]`` choice in its reply is compared with the
    reference answer. Each failed attempt is followed by a 2 s pause before
    retrying.

    Returns:
        A result record with ``objective_score`` 1 (correct) or 0, or None
        once all *retry_attempts* have failed.
    """
    aspect = item['aspect']
    image_path = item['image_path']
    prompt = item['prompt']
    objective_question = item['objective_question'] + '\n' + json.dumps(item['choice'])
    objective_reference_answer = item['objective_reference_answer']
    need_elements = item.get('need_elements', False)
    for attempt in range(retry_attempts):
        try:
            # Format into a new name. Overwriting the template meant that, on
            # retry, the already-filled prompt (containing the options' JSON
            # braces) was formatted again and raised, so retries never worked.
            filled_prompt = objective_answer_prompt.format(aspect=aspect, question=objective_question)
            objective_answer = await model_function(filled_prompt, image_path)
            objective_choice = extract_choice(objective_answer)
            objective_score = 1 if objective_choice == objective_reference_answer else 0
            return {
                "aspect": aspect,
                "prompt": prompt,
                "image_path": image_path,
                'level': level,
                'model': model_function.__name__,
                'objective_question': objective_question,
                'objective_answer': objective_answer,
                'need_elements': need_elements,
                'objective_choice': objective_choice,
                'objective_score': objective_score,
                'objective_reference_answer': objective_reference_answer
            }
        except Exception as e:
            print(f"Model: {model_function.__name__} Attempt {attempt + 1} failed: {e}")
            if attempt == retry_attempts - 1:
                return None
            await asyncio.sleep(2)
    return None

async def generate_answers(user_input, max_concurrency=5):
    """Have every model in ``async_lvm_function`` answer every question, per level.

    Reads ./document/{user_input}/questions/{level}_questions_modified.json.
    Raw results are saved under answers/{level}_answers_llava.json. Per-model
    average objective scores and answer counts go to
    scores/{level}_scores_llava.json.
    """
    image_prompt_folder = f'./document/{user_input}'
    os.makedirs(image_prompt_folder, exist_ok=True)
    config = await load_config('./config/config.yaml')
    # Loop-invariant: the answer template is the same for every item.
    objective_answer_prompt = config.get('objective_answer_prompt')
    answers_folder = f'{image_prompt_folder}/answers'
    scores_folder = f'{image_prompt_folder}/scores'

    semaphore = asyncio.Semaphore(max_concurrency)

    async def sem_task(model_function, level, item):
        async with semaphore:
            return await generate_single_answer(model_function, objective_answer_prompt, level, item)

    for level in ['easy', 'medium', 'hard']:
        image_prompt_file = f'{image_prompt_folder}/questions/{level}_questions_modified.json'
        image_prompt_data = await load_json(image_prompt_file)

        model_scores = {model_function.__name__: {'objective': []} for model_function in async_lvm_function.values()}
        tasks = [
            asyncio.create_task(sem_task(model_function, level, item))
            for item in image_prompt_data
            for model_function in async_lvm_function.values()
        ]

        results = await tqdm_asyncio.gather(*tasks)
        results = [result for result in results if result is not None]

        for result in results:
            model_scores[result['model']]['objective'].append(result['objective_score'])

        os.makedirs(answers_folder, exist_ok=True)
        await save_json(results, f'{answers_folder}/{level}_answers_llava.json')  # for test
        print(f'{level} answers generated and saved successfully!')

        avg_scores = {}
        for model_name, scores in model_scores.items():
            objective = scores['objective']
            avg_scores[model_name] = {
                'average_objective_score': sum(objective) / len(objective) if objective else 0,
                'objective_num': len(objective),
            }
            print(f'Average objective score for model {model_name} at level {level}: {avg_scores[model_name]["average_objective_score"]:.2f}')

        os.makedirs(scores_folder, exist_ok=True)
        await save_json(avg_scores, f'{scores_folder}/{level}_scores_llava.json')  # for test
        print(f'{level} scores generated and saved successfully!')

async def visualization_scores1(user_input, ablation=False):
    """Plot per-model average objective scores by difficulty (cartoon run) as PNG.

    Reads {difficulty}_scores_modified.json for easy/medium/hard from
    document/{user_input}/scores_cartoon (or ablation_study when *ablation*).
    Each file must contain an ``average_objective_score`` for every model
    function ``__name__`` in ``async_lvm_function``. The grouped bar chart is
    saved to ./document/{user_input}/visualization/.
    """
    sub_dir = "ablation_study" if ablation else "scores_cartoon"
    parent_path = os.path.join("document", user_input, sub_dir)
    difficulties = ['easy', 'medium', 'hard']
    files = [os.path.join(parent_path, f"{difficulty}_scores_modified.json") for difficulty in difficulties]
    models = [model_function.__name__ for model_function in async_lvm_function.values()]
    model_names = [lvm_func_to_name[model_name] for model_name in models]
    final_scores = {model: {'objective': []} for model in models}

    for file in files:
        data = await load_json(file)
        for model in models:
            final_scores[model]['objective'].append(data[model]["average_objective_score"])

    def plot_scores(score_type, title, ylabel, filename):
        bar_width = 0.2
        index = np.arange(len(models))
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            for i, difficulty in enumerate(difficulties):
                scores = [final_scores[model][score_type][i] for model in models]
                ax.bar(index + i * bar_width, scores, bar_width, label=difficulty)

            ax.set_xlabel('Models')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            # Centre each tick under its group of bars; the old bar_width / 2
            # offset only centres a two-bar group.
            ax.set_xticks(index + bar_width * (len(difficulties) - 1) / 2)
            ax.set_xticklabels(model_names, rotation=45)
            ax.set_ylim(0, 1)
            ax.legend()

            fig.tight_layout()

            visualization_path = os.path.join(f'./document/{user_input}', "visualization", filename)
            save_plot(fig, visualization_path, 300)
        finally:
            plt.close(fig)  # release the figure; pyplot keeps it alive otherwise

    os.makedirs(f"./document/{user_input}/visualization", exist_ok=True)
    if ablation:
        plot_scores('objective', 'Final Objective Scores for Ablation Study', 'Average Objective Score', f'{user_input}_ablation_objective_scores.png')
    else:
        plot_scores('objective', f'Final Scores with user input: {user_input} On Cartoon', 'Average Objective Score', f'{user_input}_objective_scores.png')

    print(f"Visualization of scores for user input {user_input} saved successfully!")

async def visualization_scores(user_input, ablation=False):
    """Plot per-model average objective scores by difficulty (oil-painting run).

    Reads {difficulty}_scores_modified.json for easy/medium/hard from
    document/{user_input}/scores_oil_paint (or ablation_study when *ablation*).
    Each file must contain an ``average_objective_score`` for every model
    function ``__name__`` in ``async_lvm_function``. A Seaborn-styled grouped
    bar chart with value labels is saved to ./document/{user_input}/visualization/.
    The output format follows the filename's extension.
    """
    import seaborn as sns

    # Apply the Seaborn "whitegrid" theme (updates global matplotlib style).
    sns.set_theme(style="whitegrid")

    sub_dir = "ablation_study" if ablation else "scores_oil_paint"
    parent_path = os.path.join("document", user_input, sub_dir)
    difficulties = ['easy', 'medium', 'hard']
    files = [os.path.join(parent_path, f"{difficulty}_scores_modified.json") for difficulty in difficulties]
    models = [model_function.__name__ for model_function in async_lvm_function.values()]
    model_names = [lvm_func_to_name[model_name] for model_name in models]
    final_scores = {model: {'objective': []} for model in models}

    for file in files:
        data = await load_json(file)
        for model in models:
            final_scores[model]['objective'].append(data[model]["average_objective_score"])

    def plot_scores(score_type, title, ylabel, filename):
        bar_width = 0.3  # wider bars than the cartoon variant
        index = np.arange(len(models))
        fig, ax = plt.subplots(figsize=(14, 8))
        try:
            # Soft pastel palette, one colour per difficulty.
            colors = sns.color_palette("pastel", len(difficulties))

            for i, difficulty in enumerate(difficulties):
                scores = [final_scores[model][score_type][i] for model in models]
                ax.bar(index + i * bar_width, scores, bar_width, label=difficulty, color=colors[i], edgecolor='black')

            # Title and axis labels.
            ax.set_xlabel('Models', fontsize=18, fontweight='bold')
            ax.set_ylabel(ylabel, fontsize=18, fontweight='bold')
            ax.set_title(title, fontsize=20, fontweight='bold', pad=20)
            # Centre each tick under its group of bars.
            ax.set_xticks(index + bar_width * (len(difficulties) - 1) / 2)
            ax.set_xticklabels(model_names, rotation=45, fontsize=14, fontweight='bold')
            ax.set_ylim(0.1, 1)  # y axis spans 0.1 to 1
            ax.yaxis.set_tick_params(labelsize=14)
            ax.legend(title="Difficulty", fontsize=14, title_fontsize=16, loc='upper right', frameon=True, shadow=True)

            # Value label just above each bar.
            for i, difficulty in enumerate(difficulties):
                scores = [final_scores[model][score_type][i] for model in models]
                for j, score in enumerate(scores):
                    ax.text(index[j] + i * bar_width, score + 0.02, f"{score:.2f}", ha='center', va='bottom', fontsize=12, fontweight='bold')

            # Dashed horizontal grid lines and a light background.
            ax.grid(True, linestyle='--', alpha=0.6, axis='y')
            ax.set_facecolor('#f9f9f9')

            fig.tight_layout()

            visualization_path = os.path.join(f'./document/{user_input}', "visualization", filename)
            # Derive the format from the extension: the ablation branch passes a
            # .png name, which previously received PDF content.
            file_format = os.path.splitext(filename)[1].lstrip('.').lower() or 'pdf'
            fig.savefig(visualization_path, format=file_format, dpi=500)
        finally:
            plt.close(fig)  # release the figure; pyplot keeps it alive otherwise

    os.makedirs(f"./document/{user_input}/visualization", exist_ok=True)
    if ablation:
        plot_scores('objective', 'Final Objective Scores for Ablation Study', 'Average Objective Score', f'{user_input}_ablation_objective_scores.png')
    else:
        plot_scores('objective', f'Final Scores with user input: {user_input} On Oil Painting', '', f'{user_input}_objective_scores.pdf')

    print(f"Visualization of scores for user input {user_input} saved successfully!")

async def to_csv(user_input, ablation=False, total_questions=240):
    """Collect per-difficulty model scores for *user_input* into one CSV file.

    Reads ``<difficulty>_scores_modified.json`` for 'easy', 'medium' and 'hard'
    from ``document/<user_input>/scores_oil_paint`` (or from
    ``document/<user_input>/ablation_study`` when *ablation* is true). It then
    writes one row per (difficulty, model) to ``all_scores_oil_paint.csv`` (or
    ``ablation_study_scores.csv``) under ``./document/<user_input>``.

    Args:
        user_input: Name of the experiment directory under ``document``.
        ablation: Use the ablation-study input folder and output file name.
        total_questions: Denominator of the 'Alignment Rate' column
            (``objective_num / total_questions``). Defaults to 240, the value
            previously hard-coded here. Must be non-zero.
    """
    difficulties = ['easy', 'medium', 'hard']
    if ablation:
        # NOTE(review): ablation_study() writes '<level>_scores.json', not
        # '<level>_scores_modified.json'; confirm another step produces the
        # files read here.
        parent_path = os.path.join("document", user_input, "ablation_study")
        output_name = "ablation_study_scores.csv"
        header = ['User Input', user_input, 'Ablation Study']
    else:
        parent_path = os.path.join("document", user_input, "scores_oil_paint")
        output_name = "all_scores_oil_paint.csv"
        header = ['User Input', user_input]
    output_path = os.path.join(f'./document/{user_input}', output_name)

    # Load every score file before opening the CSV so a missing or corrupt
    # input does not leave a truncated CSV behind.
    level_data = []
    for difficulty in difficulties:
        score_file = os.path.join(parent_path, f"{difficulty}_scores_modified.json")
        level_data.append((difficulty, await load_json(score_file)))

    with open(output_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerow(['Difficulty', 'Model', 'Objective Score', 'Alignment Rate'])
        for difficulty, data in level_data:
            for model, scores in data.items():
                # Fall back to the raw key for models missing from the display-name table.
                model_name = lvm_func_to_name.get(model, model)
                writer.writerow([
                    difficulty,
                    model_name,
                    scores["average_objective_score"],
                    scores["objective_num"] / total_questions,
                ])
    print(f"{user_input} scores saved to csv located at {output_path}")

async def single_ablation_study(level, item, objective_question_prompt, subjective_question_prompt, retry_attempts=3):
    """Generate one subjective and one objective question for an image with GPT-4o.

    This is the ablation variant of question generation: no element list is
    supplied to the prompts (``elements`` is always the string "None").

    Args:
        level: Difficulty label substituted into the prompts.
        item: Dict with at least 'aspect', 'image_path' and 'prompt' keys.
        objective_question_prompt: Template formatted with ``aspect``,
            ``elements``, ``level`` and ``prompt``.
        subjective_question_prompt: Template formatted with the same fields.
        retry_attempts: Number of attempts at generation + JSON parsing.

    Returns:
        A dict describing the generated question pair. Returns None if every
        attempt failed, or if ``retry_attempts`` is less than 1.
    """
    aspect = item['aspect']
    image_path = item['image_path']
    prompt = item['prompt']
    need_elements = False
    elements = "None"

    for attempt in range(retry_attempts):
        try:
            subjective_prompt = subjective_question_prompt.format(aspect=aspect, elements=elements, level=level, prompt=prompt)
            objective_prompt = objective_question_prompt.format(aspect=aspect, elements=elements, level=level, prompt=prompt)
            subjective_response = await async_gpt4o(subjective_prompt)
            objective_response = await async_gpt4o(objective_prompt)
            # Parse each response once; any decode error or missing key
            # triggers another attempt.
            subjective = json.loads(subjective_response)
            objective = json.loads(objective_response)
            subjective_reference_answer = subjective['reference_answer']
            objective_reference_answer = objective['reference_answer']
            subjective_question = subjective['question']
            # The objective question text carries its options serialized as JSON.
            objective_question = objective['question'] + '\n' + json.dumps(objective['options'])
            return {
                "aspect": aspect,
                "prompt": prompt,
                "image_path": image_path,
                'level': level,
                'model': 'gpt4o',
                'subjective_question': subjective_question,
                'subjective_reference_answer': subjective_reference_answer,
                'objective_question': objective_question,
                'objective_reference_answer': objective_reference_answer,
                'need_elements': need_elements,
            }
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e}")
    return None

async def _generate_ablation_questions(level, image_prompt_data, objective_question_prompt, subjective_question_prompt):
    """Run single_ablation_study for every image item concurrently; drop failed (None) results."""
    tasks = [
        asyncio.create_task(single_ablation_study(level, item, objective_question_prompt, subjective_question_prompt))
        for item in image_prompt_data
    ]
    questions = await tqdm_asyncio.gather(*tasks)
    return [question for question in questions if question is not None]


async def _answer_ablation_questions(questions, level, subjective_answer_prompt, objective_answer_prompt, eval_prompt, max_concurrency=20):
    """Answer every question with every model in async_lvm_function, bounded by a semaphore; drop None results."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def sem_task(model_function, item):
        async with semaphore:
            return await generate_single_answer(model_function, subjective_answer_prompt, objective_answer_prompt, eval_prompt, level, item)

    # Same order as before: questions in the outer loop, models in the inner loop.
    tasks = [
        asyncio.create_task(sem_task(model_function, item))
        for item in questions
        for model_function in async_lvm_function.values()
    ]
    results = await tqdm_asyncio.gather(*tasks)
    return [result for result in results if result is not None]


def _average_ablation_scores(results, level):
    """Group answer results by model, print and return each model's average scores.

    Assumes ``result['model']`` equals the answering function's ``__name__``
    (presumably set by generate_single_answer; TODO confirm). Any other value
    raises KeyError.
    """
    model_scores = {
        model_function.__name__: {'objective': [], 'subjective': []}
        for model_function in async_lvm_function.values()
    }
    for result in results:
        model_scores[result['model']]['objective'].append(result['objective_score'])
        model_scores[result['model']]['subjective'].append(result['subjective_score'])

    avg_scores = {}
    for model_name, scores in model_scores.items():
        avg_scores[model_name] = {
            'average_objective_score': sum(scores['objective']) / len(scores['objective']) if scores['objective'] else 0,
            'average_subjective_score': sum(scores['subjective']) / len(scores['subjective']) if scores['subjective'] else 0,
            'objective_num': len(scores['objective']),
            'subjective_num': len(scores['subjective'])
        }
        print(f'Average objective score for model {model_name} at level {level}: {avg_scores[model_name]["average_objective_score"]:.2f}')
        print(f'Average subjective score for model {model_name} at level {level}: {avg_scores[model_name]["average_subjective_score"]:.2f}')
    return avg_scores


async def ablation_study(user_input):
    """Run the GPT-4o question-generation ablation for every difficulty level.

    For each of 'easy', 'medium' and 'hard':
      1. Generate questions with single_ablation_study for every item in
         ``./document/<user_input>/image_json/<level>_images.json``.
      2. Answer every question with every model in async_lvm_function, with at
         most 20 concurrent generate_single_answer calls.
      3. Write ``<level>_questions.json``, ``<level>_answers.json`` and
         ``<level>_scores.json`` under ``./document/<user_input>/ablation_study``.
    """
    image_prompt_folder = f'./document/{user_input}/'
    os.makedirs(image_prompt_folder, exist_ok=True)
    config_file_path = './config/config.yaml'
    config = await load_config(config_file_path)

    # These templates depend on neither the item nor the level, so read them once.
    objective_question_prompt = config.get('objective_question_prompt')
    subjective_question_prompt = config.get('subjective_question_prompt')
    subjective_answer_prompt = config.get('subjective_answer_prompt')
    objective_answer_prompt = config.get('objective_answer_prompt')
    eval_prompt = config.get('eval_model_response_prompt_template')

    ablation_folder = f'{image_prompt_folder}/ablation_study'

    for level in ['easy', 'medium', 'hard']:
        image_prompt_file = f'{image_prompt_folder}/image_json/{level}_images.json'
        image_prompt_data = await load_json(image_prompt_file)

        save_data = await _generate_ablation_questions(level, image_prompt_data, objective_question_prompt, subjective_question_prompt)
        # exist_ok avoids a check-then-create race; one call covers all three outputs.
        os.makedirs(ablation_folder, exist_ok=True)
        await save_json(save_data, f'{ablation_folder}/{level}_questions.json')
        print(f'{level} questions for ablation study generated and saved successfully!')

        results = await _answer_ablation_questions(save_data, level, subjective_answer_prompt, objective_answer_prompt, eval_prompt)
        await save_json(results, f'{ablation_folder}/{level}_answers.json')
        print(f'{level} answers for ablation study generated and saved successfully!')

        avg_scores = _average_ablation_scores(results, level)
        await save_json(avg_scores, f'{ablation_folder}/{level}_scores.json')
        print(f'{level} scores for ablation study generated and saved successfully!')

# Cache of 'modify_choice_prompt' templates keyed by config path, so the YAML
# file is read once per process rather than once per question.
_MODIFY_CHOICE_TEMPLATE_CACHE = {}


def _load_modify_choice_template(config_path='config/config.yaml'):
    """Return the 'modify_choice_prompt' entry of *config_path*, reading the file only once."""
    if config_path not in _MODIFY_CHOICE_TEMPLATE_CACHE:
        # YAML is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(config_path, 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)
        _MODIFY_CHOICE_TEMPLATE_CACHE[config_path] = config.get('modify_choice_prompt')
    return _MODIFY_CHOICE_TEMPLATE_CACHE[config_path]


async def modify_single_choice(item):
    """Regenerate the three incorrect options of an objective question in place.

    Each option in A-D other than ``item['objective_reference_answer']`` is
    rewritten by an examiner model. Examiners are chosen round-robin from
    async_examiner_function. Each prompt includes every option text kept or
    generated so far: the correct option first, then earlier rewrites.
    Markdown json code fences around a model's reply are stripped. If all
    attempts for an option fail, that option is left unchanged.

    Args:
        item: Question dict with 'objective_question',
            'objective_reference_answer' (a key of ``item['choice']``) and
            'choice' (mapping option letter -> text). It is mutated.

    Returns:
        The same *item* object, with its 'choice' entries updated.
    """
    modify_choice_template = _load_modify_choice_template()
    question = item['objective_question']
    answer = item['objective_reference_answer']
    # NOTE: passed to the template as `wrong_answer`, but the first entry is the
    # correct option's text; later entries are the newly generated distractors.
    answer_list = [item['choice'][answer]]
    examiner_names = list(async_examiner_function.keys())
    retry_attempts = 3
    i = 0
    for choice in ['A', 'B', 'C', 'D']:
        if choice == answer:
            continue
        i += 1
        model_name = examiner_names[i % len(examiner_names)]
        modify_choice_prompt = modify_choice_template.format(question=question, wrong_answer=answer_list)
        model_function = async_examiner_function[model_name]
        for attempt in range(retry_attempts):
            try:
                modified_answer = await model_function(modify_choice_prompt)
                modified_answer = re.sub(r"```json(.*?)```", r"\1", modified_answer, flags=re.DOTALL).strip()
            except Exception as e:
                print(f"Attempt {attempt + 1} failed: {e}")
                continue
            item['choice'][choice] = modified_answer
            answer_list.append(modified_answer)
            break
        else:
            # Every attempt failed: say so instead of silently keeping the old option.
            print(f"All {retry_attempts} attempts failed for option {choice}; keeping the original text.")
    return item
        
async def modify_choice(user_input):
    """Rewrite the distractor options of all objective questions, one level at a time.

    For each of 'easy', 'medium' and 'hard', loads
    ``./document/<user_input>/questions_oil_paint/<level>_questions.json``.
    It runs modify_single_choice on every question concurrently, discards
    ``None`` results, and saves the rest as ``<level>_questions_modified.json``
    in the same folder.
    """
    question_dir = f'./document/{user_input}/questions_oil_paint'
    for level in ('easy', 'medium', 'hard'):
        with open(f'{question_dir}/{level}_questions.json', 'r') as file:
            questions = json.load(file)

        pending = [asyncio.create_task(modify_single_choice(question)) for question in questions]
        modified = await tqdm_asyncio.gather(*pending)

        kept = [entry for entry in modified if entry is not None]
        await save_json(kept, f'{question_dir}/{level}_questions_modified.json')
        print(f'{level} questions for modified and saved successfully!')
    

if __name__ == '__main__':
    # Pipeline stages, listed in the order they are meant to run.
    execute_function = {
        "aspect": generate_fine_grained_aspects,
        "prompts": generate_prompts,
        "images": generate_images,
        "alignment": align_images,
        "questions": generate_questions,
        "modify_choice": modify_choice,
        "answers": generate_answers,
        "adjust": adjust_questions,  # optional
        "visualization": visualization_scores,  # two arguments: user_input, ablation (bool)
        "csv": to_csv,  # two arguments: user_input, ablation (bool)
        "ablation": ablation_study  # optional
    }

    # Optional CLI overrides: python <script> [generate_type] [user_input].
    # With no arguments the previous hard-coded defaults are used.
    generate_type = sys.argv[1] if len(sys.argv) > 1 else "visualization"
    user_input = sys.argv[2] if len(sys.argv) > 2 else 'basic_understanding'
    if generate_type not in execute_function:
        sys.exit(f"Unknown generate_type {generate_type!r}; choose one of: {', '.join(execute_function)}")

    # asyncio.run creates and closes a fresh event loop. Calling
    # asyncio.get_event_loop() with no running loop is deprecated and
    # raises on Python 3.14.
    asyncio.run(execute_function[generate_type](user_input))