import re
from multiprocessing import Pool, cpu_count

from openai import OpenAI
import json
import uuid
import mimetypes
from datetime import datetime
import os
from tqdm import tqdm
import prompt_adapter
from zai import ZaiClient
import argparse


# API credentials for the SDK clients built further down (OPENAI_API_KEY,
# ZAI_API_KEY), loaded from env.json in the current working directory.
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's locale encoding (e.g. cp1252 on Windows).
with open("env.json", 'r', encoding="utf-8") as _env_file:
    api_keys = json.load(_env_file)

def encode_image(image_path, timeout=60):
    """Return the base64 encoding (as a UTF-8 ``str``) of an image file's bytes.

    If ``image_path`` is an ``http://`` / ``https://`` URL, the image is first
    downloaded into a ``downloads`` directory under the current working
    directory (created if missing). The file gets a random UUID name whose
    extension is guessed from the response Content-Type (``.download`` when
    unknown). The local copy is then encoded and left on disk.

    Args:
        image_path: Local file path or http(s) URL.
        timeout: Seconds passed to ``requests.get`` as its connect/read
            timeout for URL downloads.

    Returns:
        The base64 text of the (possibly downloaded) file's contents.

    Raises:
        requests.HTTPError: if the download returns an error status.
        OSError: if the file cannot be read or written.
    """
    if image_path.startswith(("http://", "https://")):
        user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
        # `with` releases the pooled connection of the streamed response even
        # if writing the file fails; the timeout prevents an indefinite hang.
        with requests.get(
            image_path,
            headers={"User-Agent": user_agent},
            stream=True,
            timeout=timeout,
        ) as response:
            response.raise_for_status()
            # guess_extension only understands the bare MIME type, so drop
            # parameters such as "; charset=binary".
            content_type = response.headers.get("content-type", "").split(";")[0].strip()
            extension = mimetypes.guess_extension(content_type) or ".download"

            download_dir = os.path.abspath("downloads")
            os.makedirs(download_dir, exist_ok=True)
            download_path = os.path.join(download_dir, str(uuid.uuid4()) + extension)

            with open(download_path, "wb") as fh:
                for chunk in response.iter_content(chunk_size=512):
                    fh.write(chunk)

        image_path = download_path

    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


# JSON request headers carrying a bearer token.
# NOTE(review): the token here comes from the OPENAI_API_KEY *environment
# variable*, whereas the SDK clients below take their key from env.json;
# `headers` is not referenced elsewhere in the code visible here.
headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer " + str(os.getenv('OPENAI_API_KEY')),
}

# Shared settings for the two OpenAI-compatible endpoints (model under test
# and judge). NOTE(review): base_url looks like a placeholder string —
# replace with the real endpoint before running.
_openai_kwargs = dict(
    base_url="openai_base_url",
    api_key=api_keys['OPENAI_API_KEY'],
    timeout=120,
)

# API clients keyed by role. NOTE(review): run_model_sample looks up
# clients['intern'] for the 'internvl3.5' family, but no such entry exists.
clients = {
    'openai': OpenAI(**_openai_kwargs),
    'glm': ZaiClient(api_key=api_keys['ZAI_API_KEY']),
    'judge': OpenAI(**_openai_kwargs),
}


import base64
import requests
from PIL import Image
from io import BytesIO


def pil_image_to_base64(image):
    """Encode an image as PNG in memory and return the bytes as base64 text.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 encoding of the PNG bytes, decoded to ``str`` as UTF-8.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")


def resize_image(image_path):
    """Save a half-size copy of an image next to the original and return its path.

    Each side is floor-halved but kept at least 1 px. The copy is written as
    ``resized_<filename>`` in the same directory as ``image_path``. For a bare
    filename this is simply ``resized_<image_path>``.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        Path of the written half-size copy.
    """
    # Context manager closes the source file handle once the pixels are resampled.
    with Image.open(image_path) as img:
        width, height = img.size
        # PIL rejects zero-sized targets, so a 1-px side must stay 1 px.
        half = img.resize((max(1, width // 2), max(1, height // 2)))
    # Prefix only the filename; prefixing the whole path would point into a
    # different (usually nonexistent) directory, e.g. "resized_dir/a.png".
    directory, filename = os.path.split(image_path)
    new_image_path = os.path.join(directory, f"resized_{filename}")
    half.save(new_image_path)
    return new_image_path


# Run configuration shared by the worker functions below. These are only
# defaults: run_model() overwrites them from the command-line arguments.
# All values are strings, and flags are compared as strings elsewhere
# (e.g. thinking == 'False').
model_name: str = "o4-mini"       # model under evaluation
model_family: str = 'gemini'      # selects API client and prompt format
prompt_type: str = "wo"           # prompt variant ("wo" / "normal" / "sbs" per CLI help)
judge_model_name: str = "None"    # model used by llm_as_judge
thinking: str = "Yes"             # 'False' selects the no-thinking request path
def _score_judge_verdict(text):
    """Map the first ``<judge>...</judge>`` label in *text* to ``(correct, score)``.

    'correct' -> (True, 1); 'typical' -> (False, -1); anything else, or no
    tag at all -> (False, 0). The label is compared after stripping
    whitespace and lower-casing.
    """
    labels = re.findall(r'<judge>(.*?)</judge>', text, re.DOTALL)
    if not labels:
        # Unformatted judge output is scored as neither correct nor -1.
        return False, 0
    verdict = labels[0].strip().lower()
    if verdict == 'correct':
        return True, 1
    if verdict == 'typical':
        return False, -1
    return False, 0


def llm_as_judge(sample, model_response):
    """Grade one model response with the judge model.

    Builds the judge prompt via ``prompt_adapter.llm_as_judge_prompt`` and
    sends it to ``clients['judge']`` using the global ``judge_model_name``.

    Args:
        sample: The dataset sample the response answers.
        model_response: Text produced by the model under test.

    Returns:
        ``(correct, score, response)``:
        - ``correct`` is a bool.
        - ``score`` is 1 (correct), -1 ('typical' verdict, tallied as a
          hallucination by run_model) or 0.
        - ``response`` is the judge's raw message content.
    """
    prompt = prompt_adapter.llm_as_judge_prompt(sample, model_response)
    # NOTE: temperature is deliberately not sent (it was commented out);
    # presumably the judge model rejects non-default values — confirm.
    completion = clients['judge'].chat.completions.create(
        model=judge_model_name,
        messages=prompt,
        seed=42,
        top_p=1,
    )
    response = completion.choices[0].message.content
    # content may be None (e.g. refusal/filtered output); parse it as empty text.
    correct, score = _score_judge_verdict(response or "")
    return correct, score, response



def _select_client():
    """Return the API client registered for the current global ``model_family``."""
    if model_family == 'glm':
        key = 'glm'
    elif model_family == 'internvl3.5':
        key = 'intern'
    else:
        key = 'openai'
    try:
        return clients[key]
    except KeyError:
        raise KeyError(
            f"no API client registered as clients[{key!r}] "
            f"for model family {model_family!r}"
        ) from None


def _completion_kwargs(prompt):
    """Build ``chat.completions.create`` arguments for the global model/thinking settings.

    Raises:
        ValueError: if ``thinking == 'False'`` for a model without a known
            no-thinking configuration.
    """
    if thinking != 'False':
        return dict(model=model_name, temperature=0, seed=42, top_p=1,
                    messages=prompt, max_tokens=10002)
    if model_name == 'o3':
        return dict(model=model_name, temperature=0, seed=42, top_p=1,
                    messages=prompt, max_tokens=10002, reasoning_effort="low")
    if model_name == 'glm-4.5v':
        # This is the no-thinking path, so GLM's reasoning mode must be off.
        return dict(model=model_name, thinking={"type": "disabled"},
                    temperature=0, top_p=1, messages=prompt, max_tokens=8192)
    if model_name == 'gemini-2.5-flash':
        return dict(model=model_name + "-nothinking", temperature=0, seed=42,
                    top_p=1, messages=prompt, max_tokens=10002)
    raise ValueError(
        f"thinking='False' is not supported for model {model_name!r}; "
        "supported: 'o3', 'glm-4.5v', 'gemini-2.5-flash'"
    )


def run_model_sample(sample):
    """Ask the model under test one QA sample and grade the answer with the judge.

    Uses the global run configuration (``model_name``, ``model_family``,
    ``thinking``) to pick the client and request parameters.

    Args:
        sample: A dataset entry (mapping) passed to ``prompt_adapter.get_prompt_qa``.

    Returns:
        A new dict containing every key of ``sample`` plus:
        - ``"response"``: model text.
        - ``"score"``: 1 / 0 / -1 from ``llm_as_judge``.
        - ``"corrects"``: bool.
        - ``"llm_judge_detail"``: raw judge output.
    """
    client = _select_client()
    prompt = prompt_adapter.get_prompt_qa(model_family, sample)
    completion = client.chat.completions.create(**_completion_kwargs(prompt))
    response = completion.choices[0].message.content

    correct, score, judge_detail = llm_as_judge(sample, response)
    return {
        **sample,
        "response": response,
        "score": score,
        "corrects": correct,
        "llm_judge_detail": judge_detail,
    }

def automatic_model_family(model_name):
    """Infer the model family from substrings of *model_name*.

    Rules are tried in order and the first match wins:
    - OpenAI-style names ("o3", "o4", "4o", "gpt") and "qwen" -> "gemini".
    - "claude" -> "claude".
    - "grok" -> "grok".
    - "glm" -> "glm".
    - "internvl3.5" -> "internvl3.5".

    Anything unmatched falls back to "gemini".
    """
    ordered_rules = (
        (("o3", "o4", "4o", "gpt"), "gemini"),
        (("qwen",), "gemini"),
        (("claude",), "claude"),
        (("grok",), "grok"),
        (("glm",), "glm"),
        (("internvl3.5",), "internvl3.5"),
    )
    for needles, family in ordered_rules:
        if any(needle in model_name for needle in needles):
            return family
    return "gemini"

def _init_worker(config):
    """Install the run configuration tuple into this process's module globals.

    Called in the parent and as the Pool initializer. Under the "spawn" start
    method, workers re-import this module and would otherwise only see the
    module-level defaults, not the CLI settings.
    """
    global model_name, model_family, prompt_type, judge_model_name, thinking
    model_name, model_family, prompt_type, judge_model_name, thinking = config


def _count_categories(dataset):
    """Count samples per category label.

    ``sample['category']`` may be one label or a list of labels; each label is
    counted. "attribute" is additionally incremented once per sample labelled
    "color" or "shape". Labels outside the predefined set are counted as well
    rather than raising KeyError.
    """
    counts = dict.fromkeys(
        ("counting", "detection", "shape", "color",
         "logic", "ocr", "position", "attribute"),
        0,
    )
    for sample in dataset:
        labels = sample['category']
        if not isinstance(labels, list):
            labels = [labels]
        for label in labels:
            counts[label] = counts.get(label, 0) + 1
        if 'color' in labels or 'shape' in labels:
            counts['attribute'] += 1
    return counts


def _result_path(save_dir):
    """Return the timestamped output file path for the current global run settings."""
    # Timestamp computed separately: nesting same-type quotes inside an
    # f-string is a SyntaxError before Python 3.12.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = "no_think_" if thinking == 'False' else ""
    return os.path.join(save_dir, f"{prefix}{model_name.replace('/', '_')}_qa_{stamp}.json")


def run_model(args):
    """Evaluate a model on a QA dataset and save per-sample judged answers.

    Args:
        args: Parsed CLI namespace. Fields read:
            - ``judge_model_name``, ``model_name``, ``prompt_type``,
              ``thinking``: run settings.
            - ``model_family``: the string "None" means infer via
              automatic_model_family.
            - ``save_dir``, ``qa_json_file``, ``parallel_num``: I/O and
              parallelism.

    Side effects:
        - Sets the module-level run configuration.
        - Prints category counts and the HALLU/CORRECT tallies (score -1 / 1).
        - Writes the answers as JSON to ``save_dir`` (created if missing).
    """
    if args.model_family == "None":
        family = automatic_model_family(args.model_name)
    else:
        family = args.model_family
    config = (args.model_name, family, args.prompt_type, args.judge_model_name, args.thinking)
    _init_worker(config)

    # Create the output directory before the expensive API calls, so a bad
    # path fails fast instead of discarding all results at the end.
    os.makedirs(args.save_dir, exist_ok=True)

    with open(args.qa_json_file, 'r', encoding="utf-8") as f:
        dataset = json.load(f)
    print(_count_categories(dataset))

    with Pool(processes=args.parallel_num, initializer=_init_worker, initargs=(config,)) as pool:
        json_answer = list(tqdm(pool.imap(run_model_sample, dataset), total=len(dataset)))

    hallu = sum(1 for answer in json_answer if answer['score'] == -1)
    correct = sum(1 for answer in json_answer if answer['score'] == 1)
    print("HALLU:", hallu, ' CORRECT:', correct)

    with open(_result_path(args.save_dir), 'w', encoding='utf-8') as f:
        json.dump(json_answer, f, ensure_ascii=False, indent=2)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate a model on a QA dataset and grade its answers with an LLM judge."
    )
    parser.add_argument('--judge_model_name', type=str, default='gpt-5-mini-2025-08-07',
                        help='Model used as the LLM judge')
    parser.add_argument('--model_name', type=str, default='gemini-2.5-flash', help='Model name for evaluating')
    parser.add_argument('--model_family', type=str, default='None',
                        help='Client/prompt family of the model; "None" infers it from --model_name')
    # At least one worker: int(cpu_count() * 0.8) is 0 on a single-CPU host,
    # and Pool(processes=0) raises ValueError.
    parser.add_argument('--parallel_num', type=int, default=max(1, int(cpu_count() * 0.8)),
                        help='Number of parallel processes')
    parser.add_argument('--save_dir', type=str, default='./final_result_qa/',
                        help='Directory for the result JSON (created if missing)')
    parser.add_argument('--prompt_type', type=str, default='wo', help='wo / normal / sbs')
    parser.add_argument('--thinking', type=str, default='True',
                        help="'False' selects the no-thinking request configuration")
    parser.add_argument('--qa_json_file', type=str, default='./generated_dataset/dataset_qa.json',
                        help='Path of the QA dataset JSON')
    # NOTE(review): --images_dir is parsed but not read anywhere in this file.
    parser.add_argument('--images_dir', type=str, default='./generated_dataset/final_part2')

    args = parser.parse_args()
    run_model(args)