import base64
import mimetypes

import tempfile
import uuid
from datetime import datetime

from PIL import Image
import tempfile
import traceback
import torch
from tqdm import tqdm
import json
import json
import os
from openai import OpenAI
import math
from io import BytesIO
import argparse
import requests
from multiprocessing import Pool, cpu_count
# Set TOKENIZERS_PARALLELISM for any HuggingFace tokenizers loaded in this process.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Command-line configuration, parsed at import time.
# NOTE(review): --counterfactual and --num_workers are parsed but not read by
# the code visible in this file (the entry point evaluates both settings).
parser = argparse.ArgumentParser()
parser.add_argument('--counterfactual', type=str, default="True", help='With counterfactual or original image.')
parser.add_argument('--model_name', type=str, default='qwen2.5-vl-32b', help='Model name for result save')
parser.add_argument('--api_key', type=str, default='EMPTY', help='API key')
parser.add_argument('--api_url', type=str, default='', help='API URL')
parser.add_argument('--save_path', type=str, default='./result/', help='Path to save the results')
parser.add_argument('--num_workers', type=int, default=8)
args = parser.parse_args()

openai_api_key = args.api_key
openai_api_base = args.api_url

# An empty --api_url (the default) means "not given": pass None so the OpenAI
# client falls back to OPENAI_BASE_URL / its built-in endpoint instead of
# issuing requests against an empty, scheme-less base URL.
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base or None,
)

# Model identifier sent as ``model`` in every chat-completion request.
eval_model_name = args.model_name

import base64
import requests
from PIL import Image
from io import BytesIO


def get_option_probs(logprobs_entry):
    """Return renormalised probabilities for the answer letters A-D.

    Only ``logprobs_entry.top_logprobs`` is read; each candidate there must
    expose ``.token`` and ``.logprob``. For every letter, the first candidate
    whose token equals the letter exactly contributes ``exp(logprob)``; a
    letter with no matching candidate contributes 0.0. The four values are
    then scaled to sum to 1. If none of the letters was found, every letter
    maps to 0.0.

    Returns:
        dict with keys 'A', 'B', 'C', 'D' (in that order) mapping to floats.
    """
    letters = ('A', 'B', 'C', 'D')
    raw = {
        letter: next(
            (math.exp(cand.logprob)
             for cand in logprobs_entry.top_logprobs
             if cand.token == letter),
            0.0,
        )
        for letter in letters
    }
    total = sum(raw.values())
    if total > 0:
        return {letter: value / total for letter, value in raw.items()}
    return dict.fromkeys(letters, 0.0)


def pil_image_to_base64(image):
    """Encode ``image`` as PNG and return those bytes as a base64 ``str``.

    ``image`` only needs a PIL-style ``save(fp, format=...)`` method.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")


def resize_image(image_path):
    """Save a half-size copy of an image next to the original and return its path.

    The copy is written to the same directory as ``image_path`` and named
    ``resized_<original filename>``. A bare filename therefore still yields
    ``resized_<filename>`` in the current directory. Each side is halved
    (integer division) but never drops below 1 pixel.

    Args:
        image_path: path of the source image file.

    Returns:
        Path of the written, resized copy.
    """
    # Prefix only the filename; prefixing the whole path turned
    # "./dir/a.jpg" into the non-existent "resized_./dir/a.jpg".
    directory, filename = os.path.split(image_path)
    new_image_path = os.path.join(directory, f"resized_{filename}")

    # The context manager releases the file handle that Image.open keeps open.
    with Image.open(image_path) as img:
        width, height = img.size
        # PIL rejects zero-sized targets, so clamp each side to at least 1 px.
        resized = img.resize((max(1, width // 2), max(1, height // 2)))
    resized.save(new_image_path)
    return new_image_path
def _download_image(url, timeout=60):
    """Stream ``url`` into ./downloads/<uuid><ext> and return that absolute path."""
    user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
    # The context manager closes the streamed connection even if writing fails;
    # the timeout stops one stalled server from hanging the whole evaluation.
    with requests.get(
        url,
        headers={"User-Agent": user_agent},
        stream=True,
        timeout=timeout,
    ) as response:
        response.raise_for_status()
        # Drop MIME parameters such as "; charset=..." so guess_extension works.
        content_type = response.headers.get("content-type", "").split(";")[0].strip()
        extension = mimetypes.guess_extension(content_type) or ".download"

        download_dir = os.path.abspath("downloads")
        os.makedirs(download_dir, exist_ok=True)
        download_path = os.path.join(download_dir, str(uuid.uuid4()) + extension)

        with open(download_path, "wb") as fh:
            for chunk in response.iter_content(chunk_size=512):
                fh.write(chunk)
    return download_path


def encode_image(image_path):
    """Return the base64 encoding (as a UTF-8 ``str``) of an image file or URL.

    An ``http://`` or ``https://`` URL is first downloaded into
    ``./downloads/<uuid><ext>``. The extension is guessed from the response
    Content-Type, falling back to ``.download``. That local copy is then
    encoded. Any other value is read as a local file path.

    Raises:
        requests.HTTPError: if the download returns an error status.
        OSError: if the (local or downloaded) file cannot be read.
    """
    # Match real URL schemes only; a bare "http" prefix also matched local
    # paths such as "http_cache/img.jpg".
    if image_path.startswith(("http://", "https://")):
        image_path = _download_image(image_path)

    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


# Option orderings used by run_model_sample: each question is asked once per
# entry. Under ordering ``perm``, choice letter "ABCD"[i] shows
# ``options[perm[i]]``. Only the first three options are shuffled;
# options[3] always stays under D.
permutation_list = [
    [0, 1, 2, 3], [0, 2, 1, 3],
    [1, 0, 2, 3], [1, 2, 0, 3],
    [2, 0, 1, 3], [2, 1, 0, 3],
]

# Inverse of each ordering above, index-aligned with permutation_list:
# ``permutation_list2[k][i]`` is the letter under which ``options[i]`` is shown
# when ``permutation_list[k]`` is used. Deriving it keeps the two tables in sync.
permutation_list2 = [
    ["ABCD"[perm.index(option)] for option in range(4)]
    for perm in permutation_list
]


def get_prompt(sample, perm, perm_rev, image_path):
    """Build the chat ``messages`` list for one permuted multiple-choice query.

    Args:
        sample: record providing ``'question'`` and an indexable ``'options'``.
        perm: sequence of option indices; choice letter "ABCD"[i] shows
            ``sample['options'][perm[i]]`` (positions 0-3 are read).
        perm_rev: not used by this function.
        image_path: local path or URL handed to ``encode_image``; the result is
            embedded as a ``data:image/jpeg;base64,...`` URL.

    Returns:
        Two messages: the fixed system instruction (with a worked example) and a
        user turn holding the image followed by the question/choices text.
    """
    # NOTE(review): the data URL always declares image/jpeg, whatever the file is.
    image_url = f"data:image/jpeg;base64,{encode_image(image_path)}"
    question = sample['question']
    choice_lines = "\n".join(
        f"{letter}. {sample['options'][perm[pos]]}"
        for pos, letter in enumerate("ABCD")
    )

    system_text = """
You are a helpful agent.Here is an image with a multiple choice question about the image content. You should reply the question according to the image faithfully. Please note that the question maybe confusing or the image content might be uncommon, you should answer the question ONLY with the correct choice letter.
Here is an example:
#########
[IMAGE]
Question:Does the Teapot in the picture have a handle? If so, where is it located?
Choices:
A. Not visible / Can't see.
B. Yes, on the side.
C. Yes, arched over the top.
D. The correct answer is not listed.

Your answer: A.
#########
Now please answer the question following the above format STRICTLY.
            """
    user_text = f"""
Question:{question}
Choices:
{choice_lines}
Your answer:"""

    return [
        {"role": "system", "content": system_text},
        {"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": user_text},
        ]},
    ]

def _parse_choice(text):
    """Return the answer letter ('A'-'D') found in a cleaned reply, or None.

    A reply that is exactly a letter, or starts with "<letter>.", wins.
    Otherwise the first of "A.", "B.", "C.", "D." (checked in that order)
    that appears anywhere in the text is used.
    """
    for letter in "ABCD":
        if text == letter or text.startswith(letter + "."):
            return letter
    for letter in "ABCD":
        if letter + "." in text:
            return letter
    return None


def run_model_sample(sample):
    """Ask the model one question under every option ordering and score it.

    Args:
        sample: dataset record. Reads 'counterfactual' (the string "False"
            selects the original image, anything else the counterfactual one),
            'image_path' (first 4 characters form the image id), 'question',
            'options', and 'id' (printed for unparseable replies).

    Returns:
        A copy of ``sample`` extended with per-ordering lists ``responses``,
        ``choices``, ``scores``, ``corrects`` and ``options_probs``.

    Scoring per ordering: picking the letter under which ``options[0]`` is
    shown is correct (+1). Picking ``options[1]``'s letter scores -1. Any other
    letter scores 0. A reply with no recognisable letter counts as 'D'.
    """
    if sample['counterfactual'] == 'False':
        image_path = "./generated_dataset/subset/" + sample['image_path'][:4] + "_o.jpeg"
    else:
        image_path = "./generated_dataset/final_part2/" + sample['image_path'][:4] + ".jpeg"

    responses, choices, scores, corrects, probs = [], [], [], [], []
    answer_prefix = "Your answer:"
    for permutation, permutation_reverse in zip(permutation_list, permutation_list2):
        prompt = get_prompt(sample, permutation, permutation_reverse, image_path)
        completion = client.chat.completions.create(
            model=eval_model_name,
            messages=prompt,
            max_tokens=6000,
            temperature=0,
            top_p=1,
            logprobs=True,
            top_logprobs=20,
        )
        first_choice = completion.choices[0]
        # Option probabilities are taken from the first generated token only.
        probs_dict = get_option_probs(first_choice.logprobs.content[0])

        # message.content may be None; treat that as an empty reply.
        response = (first_choice.message.content or "").strip().replace('*', '')
        if response.startswith(answer_prefix):
            response = response[len(answer_prefix):]

        # Strip again: "Your answer: A" leaves " A", which previously matched
        # nothing and fell through to the default 'D'.
        choice = _parse_choice(response.strip())
        if choice is None:
            choice = 'D'
            print(sample['id'])

        correct = choice == permutation_reverse[0]
        if correct:
            score = 1
        elif choice == permutation_reverse[1]:
            score = -1
        else:
            score = 0

        responses.append(response)
        choices.append(choice)
        corrects.append(correct)
        scores.append(score)
        probs.append(probs_dict)

    return {
        **sample,
        "responses": responses,
        "choices": choices,
        "scores": scores,
        "corrects": corrects,
        "options_probs": probs,
    }

def run_model(counterfactual):
    """Evaluate every sample in subset.json and write the results to JSON.

    Args:
        counterfactual: "True" or "False" (string). It is stored on every
            sample so run_model_sample picks the matching image variant.

    Prints the per-category sample counts and an aggregate score. Then writes
    the list of per-sample results to
    ``<--save_path>/EXP2_<model_name>_<counterfactual>_shuffle_<timestamp>.json``,
    creating the directory if needed.
    """
    with open("subset.json", 'r', encoding="utf-8") as f:
        dataset = json.load(f)
    for sample in dataset:
        sample['counterfactual'] = counterfactual

    category = {
        "counting": 0,
        "detection": 0,
        "shape": 0,
        "color": 0,
        "logic": 0,
        "ocr": 0,
        "position": 0,
    }
    for sample in dataset:
        labels = sample["category"]
        if not isinstance(labels, list):
            labels = [labels]
        for label in labels:
            # Count unexpected labels too instead of aborting with KeyError.
            category[label] = category.get(label, 0) + 1
    print(category)
    # NOTE(review): informational only; samples below are evaluated sequentially.
    print(int(cpu_count() * 0.9))

    json_answer = [run_model_sample(sample) for sample in tqdm(dataset)]

    # Normalise by the real number of trials (one per option ordering) rather
    # than a hard-coded 6, and avoid dividing by zero on an empty dataset.
    trials = sum(len(result['corrects']) for result in json_answer)
    correct = sum(sum(result['corrects']) for result in json_answer)
    score = sum(sum(result['scores']) for result in json_answer)
    print("SCORE:", score / trials if trials else 0.0, ' CORRECT:', correct)

    # Build the timestamp separately: nesting the same quote type inside an
    # f-string is a SyntaxError before Python 3.12.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"EXP2_{args.model_name}_{counterfactual}_shuffle_{timestamp}.json"
    save_dir = args.save_path
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, filename), 'w', encoding='utf-8') as f:
        json.dump(json_answer, f, ensure_ascii=False, indent=2)

if __name__ == "__main__":
    # Evaluate with the counterfactual images first, then with the originals.
    for _setting in ("True", "False"):
        run_model(_setting)
