import mimetypes
import re

import uuid
from datetime import datetime

from tqdm import tqdm
import json
import os
from openai import OpenAI
import argparse
from multiprocessing import cpu_count
import prompt_adapter

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='qwen', help='Model name for result save')
parser.add_argument('--judge_model_name', type=str, default='gpt-5-mini-2025-08-07',
                    help='Model used by the LLM judge')
# int(cpu_count() * 0.8) truncates to 0 on a single-core machine; keep at least one.
parser.add_argument('--parallel_num', type=int, default=max(1, int(cpu_count() * 0.8)),
                    help='Number of parallel processes')

parser.add_argument('--api_key', type=str, default='EMPTY', help='API key')
parser.add_argument('--api_url', type=str, default='', help='API URL')
parser.add_argument('--save_path', type=str, default='./final_result_qa', help='Path to save the results')
parser.add_argument('--num_workers', type=int, default=8)
# These "boolean" switches are strings; the code compares them against the literal 'True'.
parser.add_argument('--thinking', type=str, default='False', help="'True' enables thinking mode")
parser.add_argument("--judge", type=str, default='True', help="'True' grades responses with the judge model")
parser.add_argument("--prompt", type=str, default='v0')
parser.add_argument('--bot', type=str, default='<think>', help='Marker opening the reasoning section')
parser.add_argument('--eot', type=str, default='</think>', help='Marker closing the reasoning section')
parser.add_argument('--qa_json_file', type=str, default='./generated_dataset/dataset_qa.json',
                    help='Path to the QA dataset JSON file')
parser.add_argument('--images_dir', type=str, default='./generated_dataset/final_part2')

args = parser.parse_args()

judge = args.judge

openai_api_key = args.api_key
openai_api_base = args.api_url

# Client for the model under evaluation. An empty --api_url (the default) would be
# forwarded verbatim as base_url=''; map it to None so the OpenAI client falls back
# to its own base-URL resolution (OPENAI_BASE_URL env var or the default endpoint).
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base or None,
)

eval_model_name = args.model_name
import base64
import requests
from PIL import Image
from io import BytesIO


def pil_image_to_base64(image):
    """Serialize *image* as PNG and return the bytes as a base64 text string.

    Args:
        image: Any object exposing a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64-encoded PNG bytes, decoded to a UTF-8 ``str``.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        raw_bytes = png_buffer.getvalue()
    return base64.b64encode(raw_bytes).decode("utf-8")


def resize_image(image_path):
    """Save a half-size copy of an image next to the original and return its path.

    The copy is written to the same directory as ``image_path`` under the name
    ``resized_<filename>``; the output format follows that file's extension.

    Args:
        image_path: Path to an image readable by PIL.

    Returns:
        Path of the resized copy.
    """
    # Prefix only the file name: prefixing the whole path ("resized_dir/a.png")
    # points into a directory that does not exist.
    directory, filename = os.path.split(image_path)
    new_image_path = os.path.join(directory, f"resized_{filename}")
    # Context manager closes the source file handle that Image.open keeps open.
    with Image.open(image_path) as img:
        width, height = img.size
        # Clamp to 1px so tiny images do not produce an invalid 0-sized target.
        resized = img.resize((max(1, width // 2), max(1, height // 2)))
    resized.save(new_image_path)
    return new_image_path
def encode_image(image_path, timeout=60):
    """Return the base64 encoding (as UTF-8 text) of an image file's bytes.

    If ``image_path`` starts with "http", the resource is first downloaded into a
    ``downloads`` directory (created if missing) under a random UUID file name,
    and that local copy is encoded.

    Args:
        image_path: Local file path or HTTP(S) URL.
        timeout: Seconds passed to ``requests.get`` for URL downloads.

    Returns:
        Base64 string of the file contents.

    Raises:
        requests.HTTPError: if the download returns an error status.
        OSError: if the file cannot be written or read.
    """
    if image_path.startswith("http"):
        user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
        request_kwargs = {
            "headers": {"User-Agent": user_agent},
            "stream": True,
            # Without a timeout a stalled server would block the run forever.
            "timeout": timeout,
        }

        # Context manager releases the streamed connection even if writing fails.
        with requests.get(image_path, **request_kwargs) as response:
            response.raise_for_status()
            # Drop parameters such as "; charset=..." so mimetypes can match the type.
            content_type = response.headers.get("content-type", "").split(";")[0].strip()
            extension = mimetypes.guess_extension(content_type) or ".download"

            download_dir = os.path.abspath("downloads")
            os.makedirs(download_dir, exist_ok=True)
            download_path = os.path.join(download_dir, str(uuid.uuid4()) + extension)

            with open(download_path, "wb") as fh:
                for chunk in response.iter_content(chunk_size=512):
                    fh.write(chunk)

        image_path = download_path

    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

# Credentials for the judge client are read from env.json (key 'API_SALE').
# JSON is UTF-8 by specification; don't rely on the platform's default encoding.
with open("env.json", 'r', encoding='utf-8') as f:
    api_keys = json.load(f)


# Client used by llm_as_judge; separate from `client`, which serves the model under test.
clients = {
    'openai': OpenAI(
        base_url="https://yunwu.ai/v1",
        api_key=api_keys['API_SALE'],
        timeout=120,
    ),
}

def extract_thinking_and_summary(text: str, bot: str = "<think>", eot: str = "</think>") -> tuple[str, str]:
    """Split a model response into its reasoning section and its final answer.

    Args:
        text: Raw model output.
        bot: Marker that opens the reasoning section.
        eot: Marker that closes the reasoning section.

    Returns:
        ``(thinking, summary)``. ``thinking`` is the stripped text between the first
        ``bot`` and the first ``eot`` that follows it. ``summary`` is the stripped text
        after that ``eot``. If no ``bot ... eot`` pair exists, returns ``("", text)``
        with ``text`` unchanged.
    """
    start = text.find(bot)
    if start == -1:
        return "", text
    body_start = start + len(bot)
    # Look for the closing marker only after the opening one; a stray eot earlier
    # in the text would otherwise produce an empty/incorrect slice.
    end = text.find(eot, body_start)
    if end == -1:
        return "", text
    return text[body_start:end].strip(), text[end + len(eot):].strip()


judge_model_name = args.judge_model_name  # model id sent to the judge client
model_family = "None"  # placeholder string, not a None value
def automatic_model_family(model_name):
    """Map a model name to a prompt-family key by case-sensitive substring match.

    Rules are checked in order and the first match wins:
    OpenAI-style names ("o3", "o4", "4o", "gpt") -> "gemini",
    "gemini"/"qwen" -> "gemini", "claude" -> "claude", "grok" -> "grok",
    "glm" -> "glm", "internvl3.5" -> "internvl3.5". Anything else -> "gemini".
    """
    # Each rule tests every substring explicitly. The form `"o3" or "o4" in name`
    # is always truthy and made every name map to "gemini".
    rules = (
        (("o3", "o4", "4o", "gpt"), "gemini"),
        (("gemini", "qwen"), "gemini"),
        (("claude",), "claude"),
        (("grok",), "grok"),
        (("glm",), "glm"),
        (("internvl3.5",), "internvl3.5"),
    )
    for tags, family in rules:
        if any(tag in model_name for tag in tags):
            return family
    return "gemini"

def get_prompt(sample, thinking_mode):
    """Build the chat messages for *sample*, always in the "gemini" prompt format.

    The chain-of-thought builder is chosen only when ``thinking_mode`` is truthy
    and the evaluated model name contains neither "kimi" nor "mimo". Otherwise
    the plain QA builder is used.
    """
    use_cot = bool(thinking_mode) and not any(
        tag in eval_model_name for tag in ("kimi", "mimo")
    )
    builder = prompt_adapter.get_prompt_qa_cot if use_cot else prompt_adapter.get_prompt_qa
    return builder("gemini", sample)

def llm_as_judge(sample, model_response):
    """Grade *model_response* for *sample* with the judge model.

    Judging runs only when the module-level ``judge`` flag equals the string
    "True". Otherwise ``(False, 0, "[empty]")`` is returned without any API call.

    The judge's reply is expected to contain ``<judge>VERDICT</judge>``. The first
    such tag is used, compared case-insensitively after stripping whitespace.

    Returns:
        ``(correct, score, detail)``:

        - ``correct`` is True only for a "correct" verdict.
        - ``score`` is 1 for "correct", -1 for "typical", and 0 otherwise, including
          when the tag is missing.
        - ``detail`` is the raw judge response text.
    """
    # Read the flag without rebinding it. Assigning the parsed verdict to the
    # global `judge` would switch judging off for every subsequent sample.
    if judge != "True":
        return False, 0, "[empty]"

    prompt = prompt_adapter.llm_as_judge_prompt(sample, model_response)
    completion = clients['openai'].chat.completions.create(
        model=judge_model_name,
        messages=prompt,
    )
    # message.content can be None; treat that as "no verdict" instead of crashing.
    response = completion.choices[0].message.content or ""

    matches = re.findall(r'<judge>(.*?)</judge>', response, re.DOTALL)
    verdict = matches[0].strip().lower() if matches else ""

    if verdict == 'correct':
        return True, 1, response
    if verdict == 'typical':
        # run_model counts score -1 as a hallucination.
        return False, -1, response
    return False, 0, response

def run_model_sample(sample, thinking_mode=False, bot_token="<think>", eot_token="</think>"):
    """Query the evaluated model on one sample and grade its answer.

    Args:
        sample: Dataset item (dict). It is passed to ``get_prompt`` and ``llm_as_judge``
            and copied into the result.
        thinking_mode: Forwarded to ``get_prompt``. When true, only the text after
            ``eot_token`` is sent to the judge.
        bot_token: Marker opening the model's reasoning section.
        eot_token: Marker closing the model's reasoning section.

    Returns:
        A new dict with every key of ``sample`` plus:

        - ``response``: the raw model output.
        - ``score``, ``corrects`` and ``llm_judge_detail``: taken from ``llm_as_judge``.
    """
    prompt = get_prompt(sample, thinking_mode)
    if not thinking_mode and eval_model_name == 'mimo-vl':
        # Appended to the last text part of the final message. This is presumably
        # MiMo-VL's switch for disabling reasoning; confirm against the model docs.
        prompt[-1]['content'][-1]['text'] += '/no_think'
    completion = client.chat.completions.create(
        model=eval_model_name,
        messages=prompt,
        max_tokens=5000,
        temperature=0,
    )
    # message.content may be None; normalize to "" so marker parsing and the
    # judge call below do not raise TypeError.
    response = completion.choices[0].message.content or ""
    if thinking_mode:
        _, choice = extract_thinking_and_summary(response, bot_token, eot_token)
    else:
        choice = response
    correct, score, judge_detail = llm_as_judge(sample, choice)

    return {
        **sample,
        "response": response,
        "score": score,
        "corrects": correct,
        "llm_judge_detail": judge_detail,
    }


def run_model():
    """Evaluate every QA sample, print summary counts and save the results as JSON.

    The function proceeds in these steps:

    1. Load the dataset from ``--qa_json_file``.
    2. Print per-category sample counts.
    3. Run ``run_model_sample`` on each item sequentially.
    4. Print the number of hallucinated (score -1) and correct (score 1) answers.
    5. Write all per-sample records to a timestamped file under ``--save_path``,
       creating the directory if it does not exist.
    """
    with open(args.qa_json_file, 'r', encoding="utf-8") as f:
        dataset = json.load(f)

    thinking = args.thinking == 'True'

    # Pre-seed the known categories so they print even with a count of 0.
    # Unknown labels are added on the fly instead of raising KeyError.
    category = dict.fromkeys(
        ("counting", "detection", "shape", "color", "logic", "ocr", "position"), 0
    )
    for sample in dataset:
        labels = sample['category']
        if not isinstance(labels, list):
            labels = [labels]
        for label in labels:
            category[label] = category.get(label, 0) + 1
    print(category)

    json_answer = [
        run_model_sample(sample, thinking, args.bot, args.eot)
        for sample in tqdm(dataset)
    ]

    correct = sum(1 for answer in json_answer if answer['score'] == 1)
    hallu = sum(1 for answer in json_answer if answer['score'] == -1)
    print("HALLU:", hallu, ' CORRECT:', correct)

    os.makedirs(args.save_path, exist_ok=True)
    # Use distinct inner quotes: reusing " inside an f-string needs Python 3.12+.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    file_name = f"{eval_model_name.replace('/', '_')}_qa_{timestamp}.json"
    save_path = os.path.join(args.save_path, file_name)
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(json_answer, f, ensure_ascii=False, indent=2)

if __name__ == "__main__":
    # Script entry point. CLI flags were already parsed into `args` at import time.
    run_model()