import base64
import mimetypes
import re

import tempfile
import uuid
from datetime import datetime

import PIL.Image
from PIL import Image
import tempfile
import traceback
import torch
from tqdm import tqdm
import json
import json
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, AutoModel, AutoTokenizer
import os
from openai import OpenAI
import math
from io import BytesIO
import argparse
import requests
from multiprocessing import Pool, cpu_count
import multiprocessing as mp
from choice_probe import *
from conversations import *
from vllm import LLM, SamplingParams
import logging
import sys
# Command-line interface. Flags are declared in a table (flag, add_argument
# keyword arguments) and registered in the order listed.
_CLI_OPTIONS = (
    # Identification of the evaluated model and the OpenAI-compatible endpoint.
    ('--model_name', {'type': str, 'default': 'qwen', 'help': 'Model name for result save'}),
    ('--api_key', {'type': str, 'default': 'EMPTY', 'help': 'API key'}),
    ('--api_url', {'type': str, 'default': '', 'help': 'API URL'}),
    ('--save_path', {'type': str, 'default': './result/', 'help': 'Path to save the results'}),
    ('--num_workers', {'type': int, 'default': 8}),
    # Which prompt_cot_instruction_* template the run functions select.
    ('--prompt_template', {'type': str, 'default': "v2"}),
    # Settings forwarded to the vLLM engine and tokenizer.
    ('--model_path', {'type': str, 'default': "None"}),
    ('--tokenizer_name', {'type': str, 'default': "OpenGVLab/InternVL3-38B"}),
    ('--tensor_parallel_size', {'type': int, 'default': 8}),
    ('--max_model_len', {'type': int, 'default': 12000}),
    ('--gpu_memory_utilization', {'type': float, 'default': 0.85}),
    ('--max_num_seqs', {'type': int, 'default': 32}),
    # Sampling settings.
    ('--temperature', {'type': float, 'default': 0}),
    ('--max_tokens', {'type': int, 'default': 3000}),
    # Dataset slice selector and image placeholder token.
    ('--part', {'type': int, 'default': 1}),
    ('--place_holder', {'type': str, 'default': '<image>'}),
)

parser = argparse.ArgumentParser()
for _flag, _settings in _CLI_OPTIONS:
    parser.add_argument(_flag, **_settings)

# Arguments are parsed at import time; the rest of the module reads `args`.
args = parser.parse_args()

# Client for the OpenAI-compatible endpoint used by run_model_sample.
openai_api_key, openai_api_base = args.api_key, args.api_url
client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)

# Model name sent with API requests and used in output/log file names.
eval_model_name = args.model_name


def get_option_probs(logprobs_entry):
    """Return normalized probabilities of the four answer options at one token position.

    Args:
        logprobs_entry: Mapping from token id to a candidate object exposing
            ``decoded_token`` and ``logprob`` (for example one position of a
            vLLM output's ``logprobs`` list). ``None`` or an empty mapping is
            treated as "no option found".

    Returns:
        A dict keyed by the bare option letters ``'A'``..``'D'``. The values
        are the probabilities of the tokens ``'>A'``..``'>D'`` renormalized so
        they sum to 1; options that are absent from the entry get 0.0. If none
        of the option tokens is present, every value is 0.0.
    """
    option_tokens = ['>A', '>B', '>C', '>D']
    candidates = list(logprobs_entry.values()) if logprobs_entry else []

    # Always key by the bare letter so found and missing options share one
    # key scheme (previously missing options were stored under '>A' etc.).
    raw_probs = {}
    for opt in option_tokens:
        letter = opt[1]
        raw_probs[letter] = 0.0
        for candidate in candidates:
            if candidate.decoded_token == opt:
                raw_probs[letter] = math.exp(candidate.logprob)
                break

    total = sum(raw_probs.values())
    if total > 0:
        return {letter: p / total for letter, p in raw_probs.items()}
    return raw_probs


# Option orderings tried for every sample: all six arrangements of indices
# 0-2, with index 3 always kept last. Each row is handed to the prompt builder.
permutation_list = [
    [0, 1, 2, 3], [0, 2, 1, 3],
    [1, 0, 2, 3], [1, 2, 0, 3],
    [2, 0, 1, 3], [2, 1, 0, 3],
]

# Letter-encoded inverse of each row above: entry j is the letter ('A' for
# position 0, 'B' for 1, ...) of the position where index j sits in that row.
# The run functions compare the model's choice with entries 0 and 1 of the row.
permutation_list2 = [
    [chr(ord('A') + order.index(slot)) for slot in range(len(order))]
    for order in permutation_list
]

PLACEHOLDER = '<image>'

def _letter_from_api_answer(candidate):
    """Return the option letter named by ``candidate``, or ``None`` if none is recognised.

    A trailing ``A``-``D`` wins; otherwise the first letter (checked in
    A, B, C, D order) that appears as ``X.``, ``X:`` or ``[X]`` is used.
    """
    options = ('A', 'B', 'C', 'D')
    if not candidate:
        return None
    if candidate[-1] in options:
        return candidate[-1]
    for letter in options:
        markers = (letter + '.', letter + ':', '[' + letter + ']')
        if any(marker in candidate for marker in markers):
            return letter
    return None


def run_model_sample(sample):
    """Evaluate one sample through the OpenAI-compatible ``client`` under every option ordering.

    For each pairing of ``permutation_list`` / ``permutation_list2`` rows the
    function requests a full response, extracts the chosen option letter,
    scores it, and then sends every probe prompt returned by
    ``get_probe_prompt`` to collect option probabilities.

    Args:
        sample: Dataset record. ``sample['image_path']`` and ``sample['id']``
            are read here; the whole record is forwarded to the prompt template.

    Returns:
        A new dict containing the fields of ``sample`` plus per-ordering lists:
        ``responses`` (cleaned response text), ``choices`` (letter; ``'D'``
        when nothing could be parsed), ``scores`` (1 if the choice equals
        entry 0 of the ``permutation_list2`` row, -1 if it equals entry 1,
        otherwise 0), ``corrects``, ``history_probs`` and ``history_thinkings``.

    Raises:
        ValueError: If ``args.prompt_template`` is not one of v0, v1, v2, v3.
    """
    counterfactual = True
    option_letters = ('A', 'B', 'C', 'D')
    responses = []
    scores = []
    corrects = []
    choices = []
    history_probs = []
    history_thinkings = []

    if args.prompt_template == 'v0':
        prompt_template = Conversation("base64", True, prompt_cot_instruction_v0_long_response)
    elif args.prompt_template == 'v1':
        prompt_template = Conversation("base64", True, prompt_cot_instruction_v1)
    elif args.prompt_template == 'v2':
        prompt_template = Conversation("base64", True, prompt_cot_instruction_v2)
    elif args.prompt_template == 'v3':
        prompt_template = Conversation("base64", True, prompt_cot_instruction_v3)
    else:
        # Fail with a clear message instead of an AttributeError on None later.
        raise ValueError(f"Unknown prompt template: {args.prompt_template!r}")

    if not counterfactual:
        image_path = "./generated_dataset/final_part2/" + sample['image_path'][:4] + "_o.jpeg"
    else:
        image_path = "./generated_dataset/final_part2/" + sample['image_path']

    for permutation, permutation_reverse in zip(permutation_list, permutation_list2):
        prompt = prompt_template.format(sample, image_path, permutation)
        completion = client.chat.completions.create(
            model=eval_model_name,
            messages=prompt,
            max_tokens=6000,
            temperature=0,
            top_p=1
        )
        # The API may return no content; treat that as an empty response.
        response = completion.choices[0].message.content or ''
        # Probe prompts are built from the raw (uncleaned) response.
        probe_prompts = get_probe_prompt(" Based on above analysis, I will stop thinking, and the final answer is <answer>",response,prompt,True,client,eval_model_name)

        # Clean the response and isolate the answer text.
        response = response.strip()
        response = response.replace('*', '')
        if response.startswith("Your answer:"):
            response = response[12:]
        response = response.strip()
        tagged = re.findall(r'<answer>(.*?)</answer>', response, re.DOTALL)
        # Strip so that "<answer> A </answer>" is recognised as 'A'.
        candidate = (tagged[0] if tagged else response).strip()
        choice = _letter_from_api_answer(candidate)
        if choice is None:
            if not candidate:
                print(response, ' ', tagged)
            # Unparseable answers are counted as option D.
            choice = 'D'

        correct = choice == permutation_reverse[0]
        if correct:
            score = 1
        elif choice == permutation_reverse[1]:
            score = -1
        else:
            score = 0

        history_prob = []
        history_thinking = []
        for probe in probe_prompts:
            probe_response = client.chat.completions.create(
                model=eval_model_name,
                messages=probe,
                max_tokens=6000,
                temperature=0,
                top_p=1,
                logprobs=True,
                top_logprobs=20
            )
            probe_choice = probe_response.choices[0]
            response_text = probe_choice.message.content or ''
            history_thinking.append(probe[-1]['content'][0]['text'] + "|" + response_text)

            logprobs_list = []
            if probe_choice.logprobs is not None and probe_choice.logprobs.content:
                logprobs_list = probe_choice.logprobs.content

            # The option letter is expected as the first or second character;
            # slicing keeps short or empty outputs from raising IndexError.
            if response_text[:1] in option_letters:
                token_index = 0
            elif response_text[1:2] in option_letters:
                token_index = 1
            else:
                token_index = None

            if token_index is not None and token_index < len(logprobs_list):
                # NOTE(review): get_option_probs iterates `.items()` and reads
                # `.decoded_token`; the OpenAI SDK returns token-logprob objects
                # (`.token`, `.top_logprobs`) here - confirm this path works.
                probs_dict = get_option_probs(logprobs_list[token_index])
                history_prob.append(probs_dict)
            else:
                print('Error at :', sample['id'])
                print(history_thinking[-1])

        responses.append(response)
        choices.append(choice)
        corrects.append(correct)
        scores.append(score)
        history_probs.append(history_prob)
        history_thinkings.append(history_thinking)

    answer = {
        **sample,
        "responses": responses,
        "choices": choices,
        "scores": scores,
        "corrects": corrects,
        "history_probs": history_probs,
        "history_thinkings": history_thinkings
    }
    return answer
def _letter_from_generated_answer(candidate):
    """Return the option letter named by ``candidate``, or ``None`` if none is recognised.

    A trailing ``A``-``D`` wins; otherwise the first letter (checked in
    A, B, C, D order) that appears as ``X.``, ``X:`` or ``[X]`` is used.
    """
    options = ('A', 'B', 'C', 'D')
    if not candidate:
        return None
    if candidate[-1] in options:
        return candidate[-1]
    for letter in options:
        markers = (letter + '.', letter + ':', '[' + letter + ']')
        if any(marker in candidate for marker in markers):
            return letter
    return None


def run_model_sample_wo_deploy(llm, sampling_params, tokenizer, sample, placeholder='image'):
    """Evaluate one sample with an in-process ``llm`` under every option ordering.

    For each pairing of ``permutation_list`` / ``permutation_list2`` rows the
    function generates a full response, extracts and scores the chosen option
    letter, then generates all probe continuations in one batch and records
    the option probabilities of their first token.

    Args:
        llm: Engine exposing ``generate(inputs, sampling_params=...)`` whose
            results carry ``outputs[0].text`` and ``outputs[0].logprobs``.
        sampling_params: Sampling configuration passed to every ``generate`` call.
        tokenizer: Forwarded to ``get_multi_modal_inputs``.
        sample: Dataset record. ``sample['image_path']`` and ``sample['id']``
            are read here; the whole record is forwarded to the prompt builder.
        placeholder: Image placeholder forwarded to ``Conversation``.
            NOTE(review): the default ``'image'`` differs from the module
            constant ``PLACEHOLDER = '<image>'`` - confirm which is intended.

    Returns:
        A new dict containing the fields of ``sample`` plus per-ordering lists:
        ``responses`` (cleaned response text), ``choices`` (letter; ``'D'``
        when nothing could be parsed), ``scores`` (1 if the choice equals
        entry 0 of the ``permutation_list2`` row, -1 if it equals entry 1,
        otherwise 0), ``corrects``, ``history_probs`` and ``history_thinkings``.

    Raises:
        ValueError: If ``args.prompt_template`` is not one of v0, v1, v2, v3.
    """
    counterfactual = True
    probe_prefixes = ('>A', '>B', '>C', '>D')
    responses = []
    scores = []
    corrects = []
    choices = []
    history_probs = []
    history_thinkings = []

    if args.prompt_template == 'v0':
        conversation = Conversation("placeholder", True, prompt_cot_instruction_v0_long_response, placeholder)
    elif args.prompt_template == 'v1':
        conversation = Conversation("placeholder", True, prompt_cot_instruction_v1, placeholder)
    elif args.prompt_template == 'v2':
        conversation = Conversation("placeholder", True, prompt_cot_instruction_v2, placeholder)
    elif args.prompt_template == 'v3':
        conversation = Conversation("placeholder", True, prompt_cot_instruction_v3, placeholder)
    else:
        # Fail with a clear message instead of passing None to the input builder.
        raise ValueError(f"Unknown prompt template: {args.prompt_template!r}")

    if not counterfactual:
        image_path = "./generated_dataset/final_part2/" + sample['image_path'][:4] + "_o.jpeg"
    else:
        image_path = "./generated_dataset/final_part2/" + sample['image_path']

    for permutation, permutation_reverse in zip(permutation_list, permutation_list2):
        inputs = get_multi_modal_inputs(conversation, sample, image_path, permutation, tokenizer)
        request_response = llm.generate(
            inputs,
            sampling_params=sampling_params
        )[0]
        response = request_response.outputs[0].text
        print(response)
        # The suffix ends in "<answer" (no '>') so each probe continuation is
        # expected to start with '>A'..'>D'. Probes use the raw response.
        batch_probe_inputs, add_suffix = get_multimodal_probe_inputs(
            " Based on above analysis, I will stop thinking, and the final answer is <answer", response, inputs['prompt'], image_path, True)

        # Clean the response and isolate the answer text.
        response = response.strip()
        response = response.replace('*', '')
        if response.startswith("Your answer:"):
            response = response[12:]
        response = response.strip()
        tagged = re.findall(r'<answer>(.*?)</answer>', response, re.DOTALL)
        # Strip so that "<answer> A </answer>" is recognised as 'A'.
        candidate = (tagged[0] if tagged else response).strip()
        choice = _letter_from_generated_answer(candidate)
        if choice is None:
            if not candidate:
                print('##################', response, ' ', tagged)
            # Unparseable answers are counted as option D and reported by id.
            choice = 'D'
            print(sample['id'])

        correct = choice == permutation_reverse[0]
        if correct:
            score = 1
        elif choice == permutation_reverse[1]:
            score = -1
        else:
            score = 0

        history_prob = []
        history_thinking = []
        batch_request_response = llm.generate(
            batch_probe_inputs,
            sampling_params=sampling_params
        )
        for idx, probe_response in enumerate(batch_request_response):
            probe_output = probe_response.outputs[0]
            logprobs_list = probe_output.logprobs
            response_text = probe_output.text
            if response_text[:2] in probe_prefixes and logprobs_list:
                history_thinking.append(add_suffix[idx] + "|" + response_text)
                probs_dict = get_option_probs(logprobs_list[0])
                history_prob.append(probs_dict)
            else:
                print('Error at :', sample['id'])
                print(add_suffix[idx] + "|" + response_text)

        responses.append(response)
        choices.append(choice)
        corrects.append(correct)
        scores.append(score)
        history_probs.append(history_prob)
        history_thinkings.append(history_thinking)

    answer = {
        **sample,
        "responses": responses,
        "choices": choices,
        "scores": scores,
        "corrects": corrects,
        "history_probs": history_probs,
        "history_thinkings": history_thinkings
    }
    return answer

def run_model_wo_deploy():
    """Load the model in-process, evaluate the selected dataset slice, and save the results.

    Reads its configuration from the module-level ``args``. ``--part`` picks
    the slice of ``generated_dataset/dataset.json``: 1 -> items [0, 500),
    3 -> items [0, 1000), any other value -> items [500, 1000).

    Prints the mean score per run and the number of correct runs, then writes
    all per-sample results to
    ``{eval_model_name}_probe{part}_{YYYYmmdd_HHMMSS}.json`` in the working
    directory.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in stop_tokens]
    stop_token_ids = [token_id for token_id in stop_token_ids if token_id is not None]
    print('PARALLEL:', args.tensor_parallel_size)
    llm = LLM(model=args.model_path,
              trust_remote_code=True,
              tensor_parallel_size=args.tensor_parallel_size,
              max_model_len=args.max_model_len,
              gpu_memory_utilization=args.gpu_memory_utilization,
              max_num_seqs=args.max_num_seqs,
              )
    sampling_params = SamplingParams(temperature=args.temperature,
                                     seed=42,
                                     max_tokens=args.max_tokens,
                                     stop_token_ids=stop_token_ids,
                                     logprobs=20)
    print('vLLM model started.')
    with open("generated_dataset/dataset.json", 'r', encoding="utf-8") as f:
        dataset = json.load(f)

    if args.part == 1:
        start_index, end_index = 0, 500
    elif args.part == 3:
        start_index, end_index = 0, 1000
    else:
        start_index, end_index = 500, 1000

    json_answer = []
    for sample in tqdm(dataset[start_index:end_index]):
        json_answer.append(
            run_model_sample_wo_deploy(llm, sampling_params, tokenizer, sample, args.place_holder))

    # Aggregate over every (sample, ordering) run; the run count is taken from
    # the results themselves rather than a hard-coded number of orderings.
    total_runs = sum(len(item['scores']) for item in json_answer)
    score = sum(sum(item['scores']) for item in json_answer)
    correct = sum(sum(item['corrects']) for item in json_answer)
    # Guard against an empty slice, which previously divided by zero.
    mean_score = score / total_runs if total_runs else 0.0
    print("SCORE:", mean_score, ' CORRECT:', correct)

    # Format the timestamp outside the f-string: reusing the enclosing quote
    # character inside an f-string is a SyntaxError before Python 3.12.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    with open(f"{eval_model_name}_probe{args.part}_{timestamp}.json", 'w', encoding='utf-8') as f:
        json.dump(json_answer, f, ensure_ascii=False, indent=2)



class LoggerWriter:
    """Minimal file-like object that forwards written text to a logging callable.

    In tqdm mode (``is_tqdm=True``) text bypasses the logger and is written
    straight to the interpreter's original stdout, flushed after every write.
    Otherwise each message is passed to ``logger_func``, except for a message
    that is exactly a single newline, which is dropped.
    """

    def __init__(self, logger_func, is_tqdm=False):
        self.logger_func, self.is_tqdm = logger_func, is_tqdm
        # Initialised empty; not used by the methods of this class.
        self.buffer = []

    def write(self, message):
        """Route ``message`` to the logger or, in tqdm mode, to the real stdout."""
        if not self.is_tqdm:
            if message == '\n':
                return
            self.logger_func(message)
            return
        console = sys.__stdout__
        console.write(message)
        console.flush()

    def flush(self):
        """No-op; present so the object can stand in for a stream."""
        return None

if __name__ == "__main__":
    # Use the "spawn" start method for any multiprocessing children.
    mp.set_start_method("spawn", force=True)

    # Root logger writes to a per-run log file, truncated on start.
    logging.basicConfig(
        filename=f"run_{eval_model_name}_0822_{args.part}.log",
        filemode="w",
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
    )

    # Replace the standard streams: stdout passes through to the real stdout
    # (tqdm mode), stderr is forwarded to logging.error.
    sys.stdout, sys.stderr = (
        LoggerWriter(logging.info, is_tqdm=True),
        LoggerWriter(logging.error),
    )

    logger = logging.getLogger(__name__)

    # Attach a dedicated file handler to the "vllm" logger.
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    file_handler = logging.FileHandler(f"vllm_{eval_model_name}_0822_{args.part}.log", mode="w")
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    vllm_logger = logging.getLogger("vllm")
    vllm_logger.addHandler(file_handler)

    run_model_wo_deploy()