import mimetypes
import re

import uuid
from datetime import datetime

from tqdm import tqdm
import json
import os
from openai import OpenAI
import argparse
from supplymentary_materials.inference.mcq import prompt_adapter, prompt_adapter_CoT

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='qwen', help='Model name for result save')
parser.add_argument('--api_key', type=str, default='EMPTY', help='API key')
parser.add_argument('--api_url', type=str, default='', help='API URL')
parser.add_argument('--save_path', type=str, default='./', help='Path to save the results')
# NOTE(review): --num_workers is parsed but nothing visible in this file reads it.
parser.add_argument('--num_workers', type=int, default=8)
# --cot and --thinking are compared against the literal string 'True'; any other value means "off".
parser.add_argument('--cot', type=str, default='False',
                    help="'True' selects the chain-of-thought prompt")
parser.add_argument('--thinking', type=str, default='False',
                    help="'True' removes the --bot ... --eot reasoning span before parsing the answer")
parser.add_argument('--bot', type=str, default='<think>', help='Marker that opens the reasoning span')
parser.add_argument('--eot', type=str, default='</think>', help='Marker that closes the reasoning span')
parser.add_argument('--mcq_json_file', type=str, default='./generated_dataset/dataset.json',
                    help='JSON file holding the MCQ samples')
parser.add_argument('--images_dir', type=str, default='./generated_dataset/final_part2',
                    help="Directory that each sample's image_path is relative to")

args = parser.parse_args()


openai_api_key = args.api_key
openai_api_base = args.api_url

# An empty --api_url (the default) is not a usable base URL: the client would build
# relative request URLs. Passing None lets the SDK apply its own default endpoint.
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base or None,
)

# Sent as `model=` on every request and used as the result-file name prefix.
eval_model_name = args.model_name
import base64
import requests
from PIL import Image
from io import BytesIO


def pil_image_to_base64(image):
    """Serialise *image* as PNG in memory and return it as a base64 text string.

    *image* only needs a ``save(file_obj, format=...)`` method (e.g. a PIL Image).
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")


def resize_image(image_path):
    """Save a half-size copy of the image at *image_path* and return its path.

    The copy is written next to the original with ``resized_`` prefixed to the
    file name (``imgs/a.png`` -> ``imgs/resized_a.png``; a bare ``a.png`` ->
    ``resized_a.png``).
    """
    # Context manager closes the underlying file handle once the pixels are read.
    with Image.open(image_path) as img:
        width, height = img.size
        # Clamp to 1px so a 1-pixel-wide/tall image doesn't yield an invalid 0 size.
        resized = img.resize((max(1, width // 2), max(1, height // 2)))
    # Prefix only the file name; prefixing the whole path would point into a
    # non-existent "resized_<dir>" directory.
    directory, filename = os.path.split(image_path)
    new_image_path = os.path.join(directory, f"resized_{filename}")
    resized.save(new_image_path)
    return new_image_path
def encode_image(image_path):
    """Return the base64 encoding (as a UTF-8 str) of an image file's raw bytes.

    If *image_path* starts with "http", the resource is first downloaded into a
    ``downloads`` directory under the current working directory (created if
    missing) with a random UUID file name, and that local copy is encoded.

    Raises:
        requests.HTTPError: if the download returns an error status.
        OSError: if the local file cannot be opened.
    """
    if image_path.startswith("http"):
        user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"

        # Send a HTTP request to the URL; the `with` releases the connection even on error.
        with requests.get(image_path, headers={"User-Agent": user_agent}, stream=True) as response:
            response.raise_for_status()
            # guess_extension only understands a bare MIME type, so drop
            # parameters such as "; charset=...".
            content_type = response.headers.get("content-type", "").split(";")[0].strip()
            extension = mimetypes.guess_extension(content_type) or ".download"

            # Opening a file in a missing directory would raise FileNotFoundError.
            os.makedirs("downloads", exist_ok=True)
            fname = str(uuid.uuid4()) + extension
            download_path = os.path.abspath(os.path.join("downloads", fname))

            with open(download_path, "wb") as fh:
                for chunk in response.iter_content(chunk_size=512):
                    fh.write(chunk)

        image_path = download_path

    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


# Six orderings of the option indices 0-3; index 3 stays in the last slot in every one.
permutation_list = [
    [0, 1, 2, 3],
    [0, 2, 1, 3],
    [1, 0, 2, 3],
    [1, 2, 0, 3],
    [2, 0, 1, 3],
    [2, 1, 0, 3],
]
# Inverse of each ordering spelled as letters: entry [k][j] is "ABCD"[i] where
# permutation_list[k][i] == j. NOTE(review): presumably the letter under which
# original option j is shown to the model — confirm against prompt_adapter.
permutation_list2 = [
    ["ABCD"[perm.index(option)] for option in range(len(perm))]
    for perm in permutation_list
]

def extract_thinking_and_summary(text: str, bot: str = "<think>", eot: str = "</think>") -> tuple[str, str]:
    """Split a reasoning-model response into ``(thinking, summary)``.

    ``thinking`` is the stripped text between the first *bot* marker and the
    first *eot* marker that follows it; ``summary`` is the stripped text after
    that *eot*. If *bot* is absent, or no *eot* follows it, returns
    ``("", text)`` with *text* unchanged.
    """
    start = text.find(bot)
    if start == -1:
        return "", text
    body_start = start + len(bot)
    # Look for eot only after bot, so a stray eot earlier in the text cannot
    # produce an inverted (empty) slice.
    end = text.find(eot, body_start)
    if end == -1:
        return "", text
    return text[body_start:end].strip(), text[end + len(eot):].strip()


def get_prompt(sample, perm, perm_rev, cot="False", *, model_family="internvl"):
    """Build the chat ``messages`` for one (sample, option ordering) pair.

    Args:
        sample: dataset record; its ``image_path`` is resolved against ``args.images_dir``.
        perm: option ordering forwarded to the prompt adapter.
        perm_rev: not used here; kept so existing call sites keep working.
        cot: the string "True" selects ``prompt_adapter_CoT.get_prompt``; any
            other value selects ``prompt_adapter.get_prompt_v4``.
        model_family: prompt-format family passed to the adapter (previously
            hard-coded to "internvl", which remains the default).

    Returns:
        Whatever the selected adapter returns; run_model_sample passes it as
        ``messages=`` to ``chat.completions.create``.
    """
    image_path = os.path.join(args.images_dir, sample['image_path'])
    if cot == "True":
        return prompt_adapter_CoT.get_prompt(model_family, sample, perm, image_path)
    # Other adapters used during development: prompt_adapter.get_prompt_v3_icl
    # ("step by step thinking") and prompt_adapter.get_prompt ("normal prompt").
    return prompt_adapter.get_prompt_v4(model_family, sample, perm, image_path)


# Text between <answer> and </answer>, allowed to span newlines.
_ANSWER_PATTERN = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)


def _parse_choice(text):
    """Heuristically map a model answer to one of the letters 'A'-'D'.

    Precedence: the first ``<answer>...</answer>`` span (if any) replaces
    *text*; then the last character, if it is A-D; then the first letter,
    checked in order A, B, C, D, whose cue ("X.", "is X", "is: X", "{X}")
    occurs in the text; otherwise 'D'.
    """
    tagged = _ANSWER_PATTERN.findall(text)
    if tagged:
        text = tagged[0]
    # Surrounding whitespace (e.g. "<answer> B </answer>" or a trailing newline)
    # would otherwise hide the letter from the last-character check.
    text = text.strip()
    if text and text[-1] in 'ABCD':
        return text[-1]
    for letter in 'ABCD':
        cues = (f'{letter}.', f'is {letter}', f'is: {letter}', f'{{{letter}}}')
        if any(cue in text for cue in cues):
            return letter
    # Unparseable or empty answers fall back to 'D'. For the orderings in this
    # file 'D' is never permutation_reverse[0] or [1], so it scores 0.
    return 'D'


def run_model_sample(sample,cot,thinking_mode=False,bot_token="<think>",eot_token="</think>"):
    """Query the model on *sample* once per option ordering and score each answer.

    Args:
        sample: dataset record forwarded to get_prompt and copied into the result.
        cot: forwarded to get_prompt ("True" selects the chain-of-thought prompt).
        thinking_mode: if true, the reasoning span delimited by *bot_token* /
            *eot_token* is removed (via extract_thinking_and_summary) before the
            answer is parsed.
        bot_token, eot_token: markers delimiting the reasoning span.

    Returns:
        ``sample`` merged with four parallel lists, one entry per ordering in
        ``permutation_list``: ``responses`` (stripped raw text), ``choices``
        (parsed letter), ``scores`` (+1 / -1 / 0) and ``corrects`` (bool).
    """
    responses = []
    choices = []
    scores = []
    corrects = []
    for permutation, permutation_reverse in zip(permutation_list, permutation_list2):
        prompt = get_prompt(sample, permutation, permutation_reverse, cot)
        completion = client.chat.completions.create(
            model=eval_model_name,
            messages=prompt,
            max_tokens=6000,
            temperature=0,
        )
        # message.content may be None; treat that as an empty (unparseable) answer
        # instead of crashing on .strip() / substring checks.
        response = completion.choices[0].message.content or ""
        if thinking_mode:
            _, answer_text = extract_thinking_and_summary(response, bot_token, eot_token)
        else:
            answer_text = response
        choice = _parse_choice(answer_text)

        # +1 if the letter equals permutation_reverse[0], -1 if it equals
        # permutation_reverse[1], else 0. NOTE(review): presumably [0] is where the
        # correct option is shown and [1] a designated distractor — confirm.
        if choice == permutation_reverse[0]:
            score, correct = 1, True
        elif choice == permutation_reverse[1]:
            score, correct = -1, False
        else:
            score, correct = 0, False

        responses.append(response.strip())
        choices.append(choice)
        corrects.append(correct)
        scores.append(score)
    return {
        **sample,
        "responses": responses,
        "choices": choices,
        "scores": scores,
        "corrects": corrects,
    }

def run_model():
    """Evaluate every sample in ``args.mcq_json_file`` and save the results.

    Prints per-category sample counts, runs run_model_sample on each sample,
    prints the mean score over all (sample, ordering) trials and the number of
    correct trials, then writes all per-sample results as JSON into
    ``args.save_path`` (created if missing).
    """
    with open(args.mcq_json_file, 'r', encoding="utf-8") as f:
        dataset = json.load(f)

    # Known categories are pre-seeded so they always print (even at 0);
    # unexpected labels are counted too instead of raising KeyError.
    category = dict.fromkeys(
        ("counting", "detection", "shape", "color", "logic", "ocr", "position"), 0
    )
    for sample in dataset:
        labels = sample["category"]
        for label in labels if isinstance(labels, list) else [labels]:
            category[label] = category.get(label, 0) + 1
    print(category)

    thinking = args.thinking == 'True'
    json_answer = [
        run_model_sample(sample, args.cot, thinking, args.bot, args.eot)
        for sample in tqdm(dataset)
    ]

    trials = sum(len(answer['scores']) for answer in json_answer)
    correct = sum(sum(answer['corrects']) for answer in json_answer)
    score = sum(sum(answer['scores']) for answer in json_answer)
    # An empty dataset would otherwise raise ZeroDivisionError.
    print("SCORE:", score / trials if trials else 0.0, ' CORRECT:', correct)

    # Single quotes inside the f-string: reusing the outer quote type is a
    # SyntaxError before Python 3.12.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    # Model ids like "org/model" would otherwise be treated as a sub-directory.
    safe_model_name = eval_model_name.replace(os.sep, '_').replace('/', '_')
    os.makedirs(args.save_path, exist_ok=True)
    out_path = os.path.join(args.save_path, f"{safe_model_name}_v4_{timestamp}_{args.cot}.json")
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(json_answer, f, ensure_ascii=False, indent=2)

if __name__ == "__main__":
    # Echo the prompt mode and endpoint before starting the evaluation run.
    print(f"PROMPT TYPE: {args.cot} URL: {args.api_url}")
    run_model()