from openai import OpenAI
import base64
import re
import numpy as np
# from PIL import Image
import os, json
from tqdm import tqdm
import random

# Label for the model family in use; not referenced by any code visible in this file.
api_source = "gpt-4o"
# Module-level OpenAI-compatible client shared by every request in this file.
# It is pointed at a custom base_url instead of the library's default endpoint.
# NOTE(review): the API key is hard-coded ("sk-xxxx" looks like a placeholder);
# consider loading it from the environment instead of committing it.
client = OpenAI(
    base_url="https://api2.aigcbest.top/v1",
    api_key="sk-xxxx"
)


def encode_image(image_path):
    """Return the raw bytes of the file at ``image_path`` as a base64 text string."""
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
    
def _text_part(text):
    """Wrap ``text`` as a chat-completions text content part."""
    return {"type": "text", "text": text}


def _image_part(image_path):
    """Wrap the image file at ``image_path`` as a base64 ``image_url`` content part.

    The data URL always declares ``image/jpeg``, whatever the file's real
    format is; this matches the payload this script has always sent.
    """
    return {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path)}"}}


def get_messages(example, dataset, zero_shot):
    """Build the chat-completions ``messages`` list for one benchmark example.

    The first message is always the fixed system prompt. Depending on
    ``dataset`` one user message (two for the ``user_preference_*`` tasks) is
    appended, interleaving demonstration images/texts with the query.

    Args:
        example: Dict describing a single sample. Which keys are read depends
            on ``dataset`` (see the individual branches).
        dataset: Task name selecting the prompt layout.
        zero_shot: Accepted for interface compatibility but not consulted
            here: demonstrations are included whenever the layout has them.

    Returns:
        A list of message dicts. For a ``dataset`` that matches no branch the
        list holds only the system message.

    Raises:
        KeyError: If ``example`` lacks a key required by the chosen layout.
        AssertionError: For ``lego_build`` when the number of ``<image>``
            placeholders in the question does not match the number of images.
    """
    SYSTEM_PROMPT = "You are a assistant with context reasoning capabilities. Your task is to extract relevant information from the context and provide a reasonable answer for each question based on the context. Make sure to focus on the key details that relate to the question and ignore irrelevant or extraneous information.\nThe reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>"
    QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer in <answer> </answer> tags."

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    if dataset == "image_operation":
        # Demonstrations are images only, followed by the query image and question.
        content = [_image_part(demo["images"][0]) for demo in example["demonstration"]]
        content.append(_image_part(example["images"][0]))
        content.append(_text_part(QUESTION_TEMPLATE.format(Question=f'{example["question"]}\n')))
        messages.append({"role": "user", "content": content})

    elif dataset in ("caption_match", "image_rotate", "textocr"):
        task_descriptions = {
            "caption_match": 'According to the few-shot examples, induce what operation to do and determine the output for the new inputs.',
            "image_rotate": "According to the few-shot examples, induce what operation to do and determine the output for the new inputs. Hint: The images on either side of the arrow show the before and after of the transformation",
            "textocr": "Induce the transformation rule related to the text in the red box from the in-context examples and determine the output for the new inputs. Please note the case of the letters.",
        }
        content = [_text_part(task_descriptions[dataset])]

        # Few-shot pairs: image followed by its answer.
        for demo in example["demonstration"]:
            demo_image = demo["images"][0]
            demo_answer = demo["answer"]
            if dataset == "caption_match":
                demo_answer = f"Judgement: {demo_answer}"
            content.extend([_image_part(demo_image), _text_part(demo_answer)])

        # Query: the template is formatted with an empty question for all
        # three tasks.
        # NOTE(review): caption_match demonstrations are prefixed with
        # "Judgement: " but the query carries no such prefix - confirm intended.
        content.extend([
            _image_part(example["images"][0]),
            _text_part(QUESTION_TEMPLATE.format(Question='')),
        ])
        messages.append({"role": "user", "content": content})

    elif dataset == "recognize_dog":
        question = example["question"]
        task_description = "Induce the features of different types from the in-context examples."
        content = [_text_part(task_description)]

        # Each demonstration may carry several images sharing one answer.
        for demo in example["demonstration"]:
            answer = demo["answer"]
            content.extend(_image_part(image) for image in demo["images"])
            content.append(_text_part(answer))

        content.extend([
            _image_part(example["images"][0]),
            _text_part(QUESTION_TEMPLATE.format(Question=f"{question}")),
        ])
        messages.append({"role": "user", "content": content})

    elif dataset in ("inductive_transform", "logicvista_knowledge"):
        image_path = example["image"]
        question = example["question"]
        messages.append({
            "role": "user",
            "content": [
                _image_part(image_path),
                _text_part(QUESTION_TEMPLATE.format(Question=question)),
            ],
        })

    elif dataset == "daily_operation":
        demonstration_list = example["demonstration"]
        question = example["question"]
        query_image = example["images"]

        content = [_text_part(example["description"])]
        # Demonstration steps: image followed by its step description.
        for demo in demonstration_list:
            image = demo['images']
            step_desc = demo["step_desc"]
            content.append(_image_part(image))
            content.append(_text_part(step_desc))

        content.extend([
            _image_part(query_image),
            _text_part(QUESTION_TEMPLATE.format(Question=question)),
        ])
        messages.append({"role": "user", "content": content})

    elif dataset in ("handsign_decryption", "alphabet_understand"):
        if dataset == "alphabet_understand":
            question = "What English word can be formed from the objects shown in the image?"
        else:
            question = example['question']
        messages.append({
            "role": "user",
            "content": [
                _text_part(example['task_description']),
                _image_part(example['support_image']),
                _image_part(example['query_image']),
                _text_part(QUESTION_TEMPLATE.format(Question=question)),
            ],
        })

    elif dataset == "caption_preference":
        task_description = "Given examples of inputs paired with different captions and their ratings, infer the user's underlying preference patterns and apply them to answer the question."
        # Only the first demonstration is used for this task.
        first_demo = example['demonstration'][0]
        messages.append({
            "role": "user",
            "content": [
                _text_part(task_description),
                _image_part(first_demo['images']),
                _text_part(first_demo['captions']),
                _image_part(example['images']),
                _text_part(QUESTION_TEMPLATE.format(Question=f"{example['question']}\n{example['choices']}")),
            ],
        })

    elif dataset == "image_style":
        task_description = "You will be provided with images and their corresponding preference selections. Please infer the user's underlying preference patterns from the examples and apply them to answer the question."
        content = [_text_part(task_description)]
        for demo in example["demonstration"]:
            content.append(_image_part(demo['images']))
            content.append(_text_part(demo["answer"]))
        content.append(_image_part(example['images']))
        content.append(_text_part(QUESTION_TEMPLATE.format(Question="")))
        messages.append({"role": "user", "content": content})

    elif dataset == "manual_understand":
        messages.append({
            "role": "user",
            "content": [
                _text_part(example['task_description']),
                _image_part(example['images']),
                _text_part(QUESTION_TEMPLATE.format(Question=example['question'])),
            ],
        })

    elif dataset == "lego_build":
        # The question text contains one "<image>" placeholder per query image;
        # splitting on it yields the text fragments surrounding each image.
        question_content = example["question"].split("<image>")
        assert len(question_content) == len(example["images"]) + 1, "too much question content"

        content = [
            _text_part(example['task_description']),
            _image_part(example['demonstration_image']),
        ]
        for i, image in enumerate(example["images"]):
            content.extend([_text_part(question_content[i]), _image_part(image)])
        # Trailing text after the last placeholder; no QUESTION_TEMPLATE here.
        content.append(_text_part(question_content[-1]))
        messages.append({"role": "user", "content": content})

    elif dataset == "anomaly_detection":
        content = [_text_part(f"{example['task_description']}\n{example['demo_prompt']}")]
        for demo in example["demonstration"]:
            content.extend([
                _image_part(demo['images']),
                _text_part(demo['anomaly_desc']),
            ])
        content.extend([
            _image_part(example['images']),
            _text_part(QUESTION_TEMPLATE.format(Question=example['choice'])),
        ])
        messages.append({"role": "user", "content": content})

    elif dataset in ("user_preference_microlens", "user_preference_amazon"):
        # The item description lives under a different key per source.
        desc_key = 'title' if dataset == "user_preference_microlens" else 'description'
        content = [_text_part(example['task_description'])]
        for demo in example["demonstration"]:
            obj_desc = demo[desc_key]
            content.extend([
                _text_part(obj_desc),
                _image_part(demo['images']),
                _text_part(demo['user_behavior']),
            ])
        messages.append({"role": "user", "content": content})

        # The query goes in a second user message and is sent without the
        # <think>/<answer> question template.
        messages.append({"role": "user", "content": [
            _text_part(example['question']),
            _image_part(example['images']),
        ]})

    return messages


def inference_single_example(example, dataset, model, max_tokens, zero_shot):
    """Send one example to the chat model and return the reply text.

    The prompt is assembled by ``get_messages``; the text of the first
    returned choice is handed back unchanged.
    """
    prompt_messages = get_messages(example, dataset, zero_shot)
    request_kwargs = {
        "model": model,
        "messages": prompt_messages,
        "max_tokens": max_tokens,
    }
    completion = client.chat.completions.create(**request_kwargs)
    first_choice = completion.choices[0]
    return first_choice.message.content


def inference(data_file, dataset, model, max_tokens, zero_shot):
    """Run the model over every example in ``data_file`` with per-example checkpointing.

    Results are stored under ``inference_results/<model>[_0shot]/`` in a file
    named after ``data_file`` (``.json`` replaced by
    ``_inference_results.json``). If that file already exists, the examples it
    contains are assumed to be the leading entries of ``data_file`` and only
    the remaining ones are processed.

    Args:
        data_file: Path to a JSON file holding a list of example dicts.
        dataset: Task name forwarded to the prompt builder.
        model: Model identifier; also used as the results directory name.
        max_tokens: Completion token limit forwarded to the API call.
        zero_shot: When truthy the results directory gets a ``_0shot`` suffix;
            the flag is also forwarded to the prompt builder.

    Returns:
        Path of the results file.
    """
    with open(data_file, 'r', encoding='utf-8') as f:
        all_data = json.load(f)

    filename = os.path.basename(data_file).replace(".json", "_inference_results.json")
    shot_tag = "_0shot" if zero_shot else ""
    result_path = f"inference_results/{model}{shot_tag}"
    os.makedirs(result_path, exist_ok=True)
    result_file = os.path.join(result_path, filename)

    if os.path.exists(result_file):
        with open(result_file, 'r', encoding='utf-8') as f:
            inference_results = json.load(f)
    else:
        inference_results = []

    # Resume: skip as many leading examples as there are stored results.
    pending = all_data[len(inference_results):]
    tmp_file = result_file + ".tmp"

    for data in tqdm(pending):
        data["prediction"] = inference_single_example(data, dataset, model, max_tokens, zero_shot)
        inference_results.append(data)

        # Write the checkpoint to a sibling temp file and swap it in, so an
        # interruption mid-write cannot leave a truncated results file that
        # would break the next resume.
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(inference_results, f, ensure_ascii=False, indent=4)
        os.replace(tmp_file, result_file)

    return result_file
        

def extract_answer_withgpt(example, generated_answer, dataset, retry=5, model="gpt-4o", max_token=256):
    """Ask the chat model to pull the final answer out of a free-form response.

    The prompt consists of a fixed instruction, optional few-shot examples
    (for ``caption_match`` / ``instance_match``), the question, the model
    response and an ``Extracted answer:`` cue.

    Args:
        example: Sample dict; ``example["question"]`` is read unless
            ``dataset`` is ``"anomaly_detection"`` (a fixed question is used).
        generated_answer: Text the answer should be extracted from.
        dataset: Task name selecting few-shot examples / the fixed question.
        retry: Maximum number of API attempts.
        model: Model used for the extraction.
        max_token: Completion token limit per attempt.

    Returns:
        The stripped extraction, or ``""`` when every attempt failed or
        returned nothing.
    """
    task_description = "Please extract the answer from the model response and type it at the end of the prompt.\n"

    prompt = task_description

    if dataset in ["caption_match", "instance_match"]:
        prompt_example = '''Example 1:\ninput:No, the image and the text do not describe the same content.\noutput:No\nExample 2:\ninput:Yes, the image and the text describe the same content.\noutput:Yes\nExample 3:\ninput:The text describes a man on top of a train that is steaming. However, the image does not show a train or steam.\noutput:No'''
        # Terminate the few-shot block with a newline; without it the question
        # was glued directly onto the last example's "output:No".
        prompt += prompt_example + "\n"

    if dataset == "anomaly_detection":
        prompt += "Is the objects in the image normal? If any abnormalities are present, select the appropriate type of anomaly.\n"
    else:
        prompt += example["question"] + "\n"
    prompt += "Model respone: " + generated_answer + "\n"
    prompt += "Extracted answer:"

    messages = [
        {"role": "user", "content": prompt},
    ]
    for attempt in range(retry):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
                # Temperature rises with each retry to escape an empty reply;
                # capped at 2.0 so a large ``retry`` never requests a value
                # above the API's documented maximum.
                temperature=min(0.5 * attempt, 2.0),
                max_tokens=max_token,
            )
            content = completion.choices[0].message.content
        except Exception as e:
            # Best-effort: report the failure and move on to the next attempt.
            print(e)
            continue

        # The content may be None; check before stripping instead of after.
        prediction = content.strip() if content is not None else ""
        if prediction:
            return prediction
    return ""

    

def analyze_results(results_file, dataset, eval_model=None):
    """Extract the final answer from every stored prediction and save the result.

    For each record the text between ``<answer>`` tags is taken as
    ``predict_answer``. Multi-word answers, and predictions without a complete
    tag pair, are additionally passed through ``extract_answer_withgpt``; the
    outcome is stored as ``extract_answer`` (``''`` for single-word answers).

    Args:
        results_file: Path to the JSON list written by the inference step.
            Each record needs ``prediction``, ``answer`` and either ``index``
            or ``id``.
        dataset: Task name forwarded to the extraction helper.
        eval_model: Optional name inserted as a sub-directory of
            ``eval_results`` in the output path.

    Returns:
        Path of the written answers file.

    Raises:
        KeyError: If a record has neither ``index`` nor ``id``.
    """
    with open(results_file, 'r', encoding='utf-8') as f:
        all_results = json.load(f)

    all_answers = []
    for result in tqdm(all_results):
        # Validate the record before spending an API call on it.
        if "index" not in result and "id" not in result:
            raise KeyError(f"dataset: {dataset}")

        true_answer = result["answer"]
        # A stored null prediction (the API returned no text) counts as empty.
        prediction = result["prediction"] or ""
        # Drop line breaks so the non-greedy regex can match across lines.
        prediction = prediction.replace('\n', '').replace('\r', '')
        match = re.search(r'<answer>(.*?)</answer>', prediction)

        if match:
            predict_answer = match.group(1)
            if len(predict_answer.split(" ")) > 1:
                extract_answer = extract_answer_withgpt(result, predict_answer, dataset)
            else:
                # A single token is already a clean answer.
                extract_answer = ''
        else:
            # No complete tag pair: fall back to whatever follows the last
            # opening tag (or the whole text when there is none).
            predict_answer = prediction.split("<answer>")[-1]
            extract_answer = extract_answer_withgpt(result, predict_answer, dataset)

        result["predict_answer"] = predict_answer
        result["extract_answer"] = extract_answer
        result["true_answer"] = true_answer
        all_answers.append(result)

    save_filename = os.path.basename(results_file)
    save_filename = save_filename.replace("inference_results.json", "answers.json")
    save_path = os.path.dirname(results_file).replace("inference_results", "eval_results")

    if eval_model:
        save_path = save_path.replace("eval_results", f"eval_results/{eval_model}")

    # ``os.makedirs("")`` raises, so skip it when the results file was given
    # without a directory component (output then goes to the working directory).
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    save_file = os.path.join(save_path, save_filename)

    print(f"answers are saved in {save_file}")

    with open(save_file, "w", encoding="utf-8") as f:
        json.dump(all_answers, f, ensure_ascii=False, indent=4)
    return save_file


    
    
    
    
def answer_verify_with_gpt(answer_dict, eval_model, merge=False, model="gpt-4o", max_token=512):
    """Score extracted answers per dataset and persist the accuracies.

    For every ``dataset -> inference results file`` pair the answers are first
    extracted with ``analyze_results``; each record is then sent to a GPT
    judge whose verdict is stored as ``judge_result``. Accuracy is computed as
    follows:

    * ``alphabet_understand`` / ``handsign_decryption``: fraction of positions
      where the whitespace-free, upper-cased prediction matches the reference.
    * ``textocr``: 1 if the reference string occurs in the prediction.
    * everything else: 1 if the judge replied exactly ``"True"``.

    The judged records overwrite the answers file, and the per-dataset summary
    is merged into ``all_results.json`` next to the last answers file.

    Args:
        answer_dict: Mapping of dataset name to inference results file.
        eval_model: Name of the evaluated model (used for paths and logging).
        merge: Unused; kept for interface compatibility.
        model: Judge model.
        max_token: Completion token limit for the judge.

    Returns:
        None.
    """
    system_prompt = "You are an expert in answer evaluation. Given a model-generated answer and the correct reference answer, your task is to determine whether the generated answer is correct. Your evaluation should focus on the semantic meaning and whether the key information matches, rather than requiring exact wording."

    if not answer_dict:
        # Nothing to evaluate; there is also no answers file from which the
        # summary location could be derived.
        return None

    all_acc = {}
    for dataset, result_file in answer_dict.items():
        extract_answer_file = analyze_results(result_file, dataset, eval_model)
        with open(extract_answer_file, 'r', encoding='utf-8') as f:
            all_results = json.load(f)
        acc = []
        judge_results = []
        for result in all_results:
            predict_answer = result["extract_answer"] if result["extract_answer"] != "" else result["predict_answer"]
            true_answer = result["answer"]

            if dataset == "anomaly_detection":
                question_str = "Question:\nIs the objects in the image normal? If any abnormalities are present, select the appropriate type of anomaly.\n"
            else:
                question_str = f"Question:\n{result['question']}\n"

            query_content = f"{question_str}Model-generated answer:\n{predict_answer}\nReference answer:\n{true_answer}\nPlease output True if the generated answer is correct; otherwise, output False."

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": query_content}
            ]

            # Reset for every record: if all attempts fail the verdict must be
            # empty, not unbound or carried over from the previous record.
            judge_result = ""
            for attempt in range(5):
                try:
                    completion = client.chat.completions.create(
                        model=model, messages=messages, temperature=0.5 * attempt, max_tokens=max_token
                    )
                    content = completion.choices[0].message.content
                except Exception as e:
                    # Best-effort: report the failure and retry.
                    print(e)
                    continue

                # The content may be None; check before stripping.
                prediction = content.strip() if content is not None else ""
                print(f"Model-generated answer: {predict_answer}\tReference answer: {true_answer}\tJudge:{prediction}")
                if prediction:
                    judge_result = prediction
                    break
            result["judge_result"] = judge_result
            judge_results.append(result)

            if dataset in ['alphabet_understand', 'handsign_decryption']:
                pred_clean = ''.join(predict_answer.upper().split())
                true_clean = ''.join(true_answer.upper().split())

                # Position-wise matches over the common prefix length,
                # normalised by the reference length.
                correct_letters = sum(p == t for p, t in zip(pred_clean, true_clean))
                letter_score = correct_letters / len(true_clean) if true_clean else 0
                acc.append(letter_score)

            elif dataset == 'textocr':
                acc.append(1 if str(true_answer) in predict_answer else 0)

            else:
                # NOTE(review): only the exact reply "True" counts; variants
                # such as "True." or "true" are scored as incorrect.
                acc.append(1 if judge_result == "True" else 0)

        avg_acc = np.average(acc)
        print("model_name:", eval_model, "dataset:", dataset, "avg_acc:", f"{np.sum(acc)} / {len(acc)}", avg_acc)
        all_acc[dataset] = {
            "true_num": f"{np.sum(acc)} / {len(acc)}",
            "avg_acc": avg_acc
        }

        with open(extract_answer_file, "w", encoding="utf-8") as f:
            json.dump(judge_results, f, ensure_ascii=False, indent=4)

    all_results_file = os.path.join(os.path.dirname(extract_answer_file), "all_results.json")
    if os.path.exists(all_results_file):
        with open(all_results_file, encoding="utf-8") as f:
            his_acc = json.load(f)
        # History is the base and freshly computed entries win, so that
        # re-evaluating a dataset actually updates its stored accuracy.
        all_acc = {**his_acc, **all_acc}

    with open(all_results_file, "w", encoding="utf-8") as f:
        json.dump(all_acc, f, ensure_ascii=False, indent=4)
        
    
    

# Script entry point: judge and score previously generated inference results.
if __name__ == "__main__":

    # NOTE(review): `data_dict` (dataset name -> inference results file) and
    # `model_name` are not defined anywhere in the visible source, so running
    # this as a script raises NameError unless they are defined elsewhere -
    # confirm and supply them before this call.
    answer_verify_with_gpt(data_dict, model_name)
    