import re
from multiprocessing import Pool, cpu_count

from openai import OpenAI
import json
import uuid
import mimetypes
from datetime import datetime
import os
from tqdm import tqdm
from supplymentary_materials.inference import prompt_adapter, prompt_adapter_CoT
from zai import ZaiClient
import argparse
# API credentials, read once at import time from env.json in the current working
# directory. Keys read below: 'OPENAI_API_KEY' and 'ZAI_API_KEY'.
# JSON is UTF-8 by spec, so decode explicitly instead of using the locale default.
with open("env.json", 'r', encoding="utf-8") as f:
    api_keys = json.load(f)

def encode_image(image_path):
    """Return the base64 text encoding of a local file or an ``http(s)`` URL.

    Paths starting with ``"http"`` are first downloaded into a ``downloads/``
    directory under the current working directory, which is created if
    missing. The copy gets a random UUID file name. Its extension is guessed
    from the response ``Content-Type`` header, falling back to ``.download``.
    The downloaded copy is left on disk.

    :param image_path: local file path or HTTP(S) URL.
    :return: base64-encoded file contents, decoded as UTF-8 text.
    :raises requests.HTTPError: if the download returns an error status.
    """
    if image_path.startswith("http"):
        user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
        request_kwargs = {
            "headers": {"User-Agent": user_agent},
            "stream": True,
            # Without a timeout a stalled server hangs the worker forever.
            "timeout": 120,
        }

        # `with` releases the streamed connection even if writing fails midway.
        with requests.get(image_path, **request_kwargs) as response:
            response.raise_for_status()
            # Drop parameters such as "; charset=binary" so guess_extension can match.
            content_type = response.headers.get("content-type", "").split(";", 1)[0].strip()
            extension = mimetypes.guess_extension(content_type) or ".download"

            download_dir = os.path.abspath("downloads")
            os.makedirs(download_dir, exist_ok=True)
            download_path = os.path.join(download_dir, str(uuid.uuid4()) + extension)

            with open(download_path, "wb") as fh:
                for chunk in response.iter_content(chunk_size=512):
                    fh.write(chunk)

        image_path = download_path

    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


# Raw HTTP headers for OpenAI-style REST calls. The bearer token here comes from
# the OPENAI_API_KEY *environment variable*. The SDK client below instead takes
# its key from env.json.
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}",
}

# SDK clients keyed by provider. run_model_sample looks up 'openai' and 'glm'.
# It also looks up 'intern' for the internvl3.5 family, which is NOT registered here.
clients = dict(
    openai=OpenAI(
        # NOTE(review): "openai_base_url" looks like a placeholder literal —
        # confirm the real endpoint is substituted before running.
        base_url="openai_base_url",
        api_key=api_keys['OPENAI_API_KEY'],
        timeout=120,
    ),
    glm=ZaiClient(api_key=api_keys['ZAI_API_KEY']),
)

import base64
import requests
from PIL import Image
from io import BytesIO


def pil_image_to_base64(image):
    """Serialise an image as PNG and return the bytes as base64 text.

    :param image: object with a PIL-style ``save(fp, format=...)`` method.
    :return: base64-encoded PNG data, decoded as UTF-8 text.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")


def resize_image(image_path):
    """Save a copy of an image scaled to half its width and height.

    The copy is written next to the original, with ``resized_`` prefixed to the
    file name only (``imgs/a.png`` -> ``imgs/resized_a.png``). The output format
    is inferred by PIL from that file name.

    :param image_path: path of the source image.
    :return: path of the resized copy.
    """
    with Image.open(image_path) as img:
        width, height = img.size
        # Clamp to 1 px so 1-pixel-wide/tall inputs don't request a zero-sized image.
        resized = img.resize((max(1, width // 2), max(1, height // 2)))
    # Prefix only the file name; prefixing the whole path pointed into a
    # non-existent "resized_<dir>" directory for any path with a directory part.
    directory, file_name = os.path.split(image_path)
    new_image_path = os.path.join(directory, f"resized_{file_name}")
    resized.save(new_image_path)
    return new_image_path


# Default run configuration. run_model() overwrites these three globals from the
# command-line arguments before evaluating samples.
model_name = "o4-mini"
model_family = 'gemini'
prompt_type = "wo"

# Six orderings of the first three answer options. Option 3 always stays in the
# last slot.
permutation_list = [
    [0, 1, 2, 3],
    [0, 2, 1, 3],
    [1, 0, 2, 3],
    [1, 2, 0, 3],
    [2, 0, 1, 3],
    [2, 1, 0, 3],
]

# permutation_list2[i][k] is the letter of the slot j where permutation_list[i][j] == k,
# i.e. the inverse permutation spelled as letters.
# NOTE(review): assumes prompt_adapter shows option permutation[j] in slot j —
# confirm against the adapter.
# run_model_sample scores a match on entry 0 as +1 (correct) and on entry 1 as -1.
permutation_list2 = [
    ["ABCD"[order.index(option)] for option in range(len(order))]
    for order in permutation_list
]

_ANSWER_TAG_PATTERN = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)
_RESPONSE_PREFIX = "Your answer:"
_OPTION_LETTERS = ('A', 'B', 'C', 'D')


def process_response(id, response):
    """Extract the chosen option letter (A-D) from a model response.

    Parsing steps:
      1. ``None`` is treated as an empty response.
      2. Surrounding whitespace and every ``*`` are removed, then a leading
         ``"Your answer:"`` is dropped.
      3. If the text contains ``<answer>...</answer>``, the first tagged span
         (whitespace-stripped) is used instead of the whole text.
      4. A bare letter, or text whose last character is a letter, yields that
         letter. Otherwise the first of A-D that appears as ``X.``, ``X:``,
         ``[X]`` or ``{X}`` is returned.
      5. Anything else, including empty text, falls back to ``'D'``.

    :param id: sample identifier; unused, kept for call compatibility
        (the name shadows the builtin but is part of the signature).
    :param response: raw text returned by the model; may be ``None``.
    :return: one of ``'A'``, ``'B'``, ``'C'``, ``'D'``.
    """
    text = (response or "").strip().replace('*', '')
    if text.startswith(_RESPONSE_PREFIX):
        text = text[len(_RESPONSE_PREFIX):]
    text = text.strip()

    tagged = _ANSWER_TAG_PATTERN.findall(text)
    # Strip the tagged span so "<answer> B </answer>" parses as 'B', not 'D'.
    choice = tagged[0].strip() if tagged else text

    if not choice:
        return 'D'
    # Also covers an exact single-letter answer.
    if choice[-1] in _OPTION_LETTERS:
        return choice[-1]
    # '*X*' markers are unnecessary here: every '*' was removed above.
    for letter in _OPTION_LETTERS:
        markers = (f'{letter}.', f'{letter}:', f'[{letter}]', f'{{{letter}}}')
        if any(marker in choice for marker in markers):
            return letter
    return 'D'

def run_model_sample(sample):
    """Query the configured model once per answer-order permutation of a sample.

    Reads the module-level ``model_family``, ``model_name`` and ``prompt_type``.
    For each ``(permutation, permutation_reverse)`` pair from
    ``permutation_list`` / ``permutation_list2``:

    - build a prompt with the adapter selected by ``prompt_type``;
    - send it to the client selected by ``model_family``;
    - parse the reply with ``process_response``.

    The parsed choice scores 1 (and counts as correct) when it equals
    ``permutation_reverse[0]``, -1 when it equals ``permutation_reverse[1]``,
    and 0 otherwise.

    :param sample: dataset record; must contain ``'id'`` plus whatever the
        prompt adapters read (not visible here).
    :return: a shallow copy of ``sample`` extended with per-permutation lists
        ``responses``, ``choices``, ``scores`` and ``corrects``.
    :raises KeyError: if no client is registered for the model family.
    :raises ValueError: if ``prompt_type`` is not ``'wo'``, ``'sbs'`` or ``'cot'``.
    """
    client_key = {'glm': 'glm', 'internvl3.5': 'intern'}.get(model_family, 'openai')
    if client_key not in clients:
        # Previously a bare KeyError('intern'); make the misconfiguration obvious.
        raise KeyError(
            f"no client registered as clients[{client_key!r}] "
            f"for model family {model_family!r}"
        )
    client = clients[client_key]

    responses = []
    scores = []
    corrects = []
    choices = []
    for permutation, permutation_reverse in zip(permutation_list, permutation_list2):
        if prompt_type == 'wo':
            prompt = prompt_adapter.get_prompt(model_family, sample, permutation)
        elif prompt_type == 'sbs':
            prompt = prompt_adapter.get_prompt_v3(model_family, sample, permutation)
        elif prompt_type == 'cot':
            prompt = prompt_adapter_CoT.get_prompt(model_family, sample, permutation)
        else:
            # Without this branch `prompt` is unbound and fails with a confusing
            # UnboundLocalError before the first API call.
            raise ValueError(
                f"unknown prompt_type {prompt_type!r}; expected 'wo', 'sbs' or 'cot'"
            )

        completion = client.chat.completions.create(
            model=model_name,
            temperature=0,
            seed=42,
            top_p=1,
            messages=prompt,
            reasoning_effort="low",
            max_tokens=10002,
        )
        response = completion.choices[0].message.content

        choice = process_response(sample['id'], response)
        correct = choice == permutation_reverse[0]
        if correct:
            score = 1
        elif choice == permutation_reverse[1]:
            score = -1
        else:
            score = 0

        responses.append(response)
        choices.append(choice)
        corrects.append(correct)
        scores.append(score)

    return {
        **sample,
        "responses": responses,
        "choices": choices,
        "scores": scores,
        "corrects": corrects,
    }

def automatic_model_family(model_name):
    """Infer the prompt/client family from a model name.

    Matching is a case-insensitive substring test. Names containing
    ``claude``, ``grok``, ``glm`` or ``internvl3.5`` map to that family.
    Everything else maps to ``"gemini"``. That includes OpenAI names (``o3``,
    ``o4``, ``4o``, ``gpt``), Gemini and Qwen names, and unknown names.

    The original ``"o3" or "o4" or ... in model_name`` test was always true, so
    every name returned "gemini" and the specific families were unreachable.

    :param model_name: model identifier, e.g. ``"gpt-4o"`` or ``"zai-org/GLM-4.5V"``.
    :return: family name string.
    """
    name = model_name.lower()
    # Check specific families first so short generic tokens such as "o3" or "o4"
    # cannot shadow them.
    for family in ("claude", "grok", "glm", "internvl3.5"):
        if family in name:
            return family
    return "gemini"

def _init_worker(name, family, kind):
    """Pool initializer: copy the run configuration into a worker's globals.

    Workers started with the "spawn"/"forkserver" methods re-import this module
    and would otherwise see the module defaults instead of the values run_model set.
    """
    global model_name, model_family, prompt_type
    model_name, model_family, prompt_type = name, family, kind


def run_model(args):
    """Evaluate a model on every sample of an MCQ dataset and save the results.

    Steps:
      1. Set the module-level run configuration. ``model_family`` is inferred
         from the model name when ``--model_family`` is the string ``"None"``.
      2. Print per-category sample counts.
      3. Evaluate every sample with ``run_model_sample`` in a process pool.
      4. Print the mean score per query and the number of correct answers.
      5. Write the augmented samples to a timestamped JSON file in
         ``args.save_dir``, creating the directory if needed.

    :param args: parsed CLI namespace with ``model_name``, ``model_family``,
        ``prompt_type``, ``mcq_json_file``, ``parallel_num`` and ``save_dir``.
    """
    global model_name, model_family, prompt_type
    model_name = args.model_name
    prompt_type = args.prompt_type
    if args.model_family == "None":
        model_family = automatic_model_family(model_name)
    else:
        model_family = args.model_family

    with open(args.mcq_json_file, 'r', encoding="utf-8") as f:
        dataset = json.load(f)

    # Per-category sample counts; a sample may list several categories.
    # Unknown categories are counted too instead of raising KeyError.
    category = {
        "counting": 0,
        "detection": 0,
        "shape": 0,
        "color": 0,
        "logic": 0,
        "ocr": 0,
        "position": 0,
    }
    for sample in dataset:
        labels = sample['category'] if isinstance(sample['category'], list) else [sample['category']]
        for label in labels:
            category[label] = category.get(label, 0) + 1
    print(category)

    with Pool(
        processes=args.parallel_num,
        initializer=_init_worker,
        initargs=(model_name, model_family, prompt_type),
    ) as pool:
        json_answer = list(tqdm(pool.imap(run_model_sample, dataset), total=len(dataset)))

    score = sum(sum(answer['scores']) for answer in json_answer)
    correct = sum(sum(answer['corrects']) for answer in json_answer)
    queries = sum(len(answer['scores']) for answer in json_answer)
    # Each query scores -1/0/1, so this mean lies in [-1, 1]; guard the empty dataset.
    print("SCORE:", score / queries if queries else 0.0, ' CORRECT:', correct)

    os.makedirs(args.save_dir, exist_ok=True)
    # Single quotes inside the f-string: reusing double quotes is a SyntaxError before Python 3.12.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    out_path = os.path.join(args.save_dir, f"{model_name.replace('/', '_')}_shuffle_{timestamp}.json")
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(json_answer, f, ensure_ascii=False, indent=2)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='gemini-2.5-flash', help='Model name for evaluating')
    parser.add_argument('--model_family', type=str, default='None',
                        help="Prompt/client family; the string 'None' infers it from --model_name")
    # int(cpu_count() * 0.8) is 0 on a single-core machine, which Pool rejects.
    parser.add_argument('--parallel_num', type=int, default=max(1, int(cpu_count() * 0.8)),
                        help='Number of parallel processes')
    parser.add_argument('--save_dir', type=str, default='./result_exp1/', help='Directory for the result JSON')
    parser.add_argument('--prompt_type', type=str, default='wo', choices=['wo', 'cot', 'sbs'],
                        help='wo / cot / sbs')
    parser.add_argument('--mcq_json_file', type=str, default='./generated_dataset/dataset.json')
    # NOTE(review): --images_dir is not read anywhere in this file.
    parser.add_argument('--images_dir', type=str, default='./generated_dataset/final_part2')

    args = parser.parse_args()
    run_model(args)