import json
import base64
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import openai
from openai import OpenAI
import time
import re
import time

# ---- File locations (fill these in before running) ----
input_file = ''  # input JSONL file: one record per line with id, image, conversations
output_file = ''  # output JSONL file: records whose augmented answer passed the check
refusal_file_1 = ''  # refusal JSONL 1: <Conclusion> did not match the standard answer
refusal_file_2 = ''  # refusal JSONL 2: response contained no <Conclusion> tag
image_base_path = ""  # directory that each record's relative 'image' path is joined to

# ---- Concurrency / retry settings ----
num_threads = 8  # worker threads in the pool
max_retries = 1  # attempts per answer before the record is refused

# ---- State shared between worker threads ----
write_lock = threading.Lock()  # serialises file appends and id-set updates
error_limit_times = 0  # number of request errors seen so far
stop_processing_event = threading.Event()  # set once too many errors occurred

processed_ids = set()  # ids already present in the output file
refusal_ids = set()  # ids already present in either refusal file


def load_jsonl_ids(path):
    """Return the set of ``id`` values found in the JSONL file at *path*.

    A missing file (including an empty path) yields an empty set. Lines that
    are not valid JSON, are not JSON objects, or carry no usable ``id`` are
    skipped so that one damaged line cannot abort start-up.
    """
    ids = set()
    if not os.path.exists(path):
        return ids
    with open(path, 'r', encoding='utf-8') as jsonl_file:
        for raw_line in jsonl_file:
            try:
                ids.add(json.loads(raw_line)['id'])
            except (json.JSONDecodeError, KeyError, TypeError):
                # Not JSON / no 'id' key / not an object / unhashable id.
                continue
    return ids


# Records refused in an earlier run are skipped, whichever file they are in.
for _refusal_path in (refusal_file_1, refusal_file_2):
    refusal_ids.update(load_jsonl_ids(_refusal_path))


# Instruction text sent with every request. Reproduced verbatim (wording and
# all) because the requested output format is what the checker parses.
_INSTRUCTION_PROMPT = (
    'In a face anti-spoofing detection system, a person has submitted an image for inspection. Now please assume the role of the examiner and respond to the task. I need you to strictly follow the format containing six specific parts: caption, facical description, facial attributes, reasoning, spoofing description, and conclusion. It is crucial that you adhere exactly to the outlined structure, and that the final answer in conclusion matches the standard correct answer exactly. Further explanation: In caption, provide a holistic description of the image. In facial description, focus on describing the facial features of the person in the image. In facial attributes, list the facial attributes in JSON format. In reasoning, perform step-by-step logical reasoning based on the image description, facial description, and facial attributes. In spoofing description, describe the spoofing method and highlight the critical spoofing characteristics in detail. In conclusion, give the final answer “Yes” or “No” that must match the standard answer exactly. Format as follows: <Caption> [Holistic description of the image] </Caption> <Facial Description> [Detailed description of the person’s facial features] </Facial Description> <Facial Attributes> [Describe facial attributes in JSON format: "eyes": ..., "mouth": ..., "nose": ..., "skin": ... etc.] </Facial Attributes> <Reasoning> [Step-by-step logical reasoning based on the above] </Reasoning> <Spoofing Description> [Details of the spoofing method: ..., critical spoofing features: ...] </Spoofing Description> <Conclusion> [Answer “Yes” or “No” matching the stan dard answer exactly] </Conclusion> Please carefully apply this format to analyze the given image and answer the related question, ensuring your conclusion matches the reference answer exactly.'
)

# Captures whatever the model put between the Conclusion tags.
_CONCLUSION_PATTERN = re.compile(r"<Conclusion>(.*?)</Conclusion>", re.DOTALL)

_POSTER_HINT = 'Photographing a poster constitutes spoofing.'
_PRINTED_PHOTO_HINT = 'Photographing a printed photo constitutes spoofing.'
_PHONE_HINT = 'Photographing a phone screen constitutes spoofing.'
_TABLET_HINT = 'Photographing a tablet display constitutes spoofing.'

# ((low, high), hint): the hint applies when low < id < high (both exclusive).
# First match wins.
# NOTE(review): ids exactly on a boundary (3000, 5500, 8500, ...) and ids in
# 25000-50000 match no range and fall back to the record's own 'hints' field;
# confirm whether that is intended.
_HINT_RANGES = (
    ((0, 3000), 'Photographing a newspaper page constitutes spoofing.'),
    ((3000, 5500), _POSTER_HINT),
    ((8500, 11000), _POSTER_HINT),
    ((5500, 8500), 'Photographing a printed upper-body image constitutes spoofing.'),
    ((11000, 14000), 'Photographing an album page constitutes spoofing.'),
    ((14000, 19500), _PRINTED_PHOTO_HINT),
    ((22500, 25000), _PRINTED_PHOTO_HINT),
    ((19500, 22500), 'Photographing an A4-printed image constitutes spoofing.'),
    # NOTE(review): never selected, the 22500-25000 entry above matches
    # first. Kept so the intended mapping can be reviewed.
    ((22500, 25000), 'Photographing a facial photo constitutes spoofing.'),
    ((50000, 55000), 'Photographing a computer screen constitutes spoofing.'),
    ((55000, 60000), _PHONE_HINT),
    ((65000, 70000), _PHONE_HINT),
    ((60000, 65000), _TABLET_HINT),
    ((70000, 75000), _TABLET_HINT),
    # NOTE(review): the doubled quotes around No are sent to the model as-is;
    # confirm whether a single pair was intended.
    ((75000, 100000), 'If there is no spoofing, fill ""No"" in the spoofing description'),
)


def _select_hints(entry_id, data):
    """Return the list of hint strings for a record (possibly empty)."""
    if isinstance(entry_id, (int, float)):
        for (low, high), hint in _HINT_RANGES:
            if low < entry_id < high:
                return [hint]
    hints = data.get('hints', [])
    if not hints:
        return []
    if isinstance(hints, (list, tuple)):
        return list(hints)
    # A single hint (e.g. a plain string) must stay one hint, not be
    # iterated element by element.
    return [hints]


def _strip_spaces(text):
    """Drop spaces and newlines so answers are compared on content only."""
    return text.replace(' ', '').replace('\n', '')


def _append_jsonl(path, record):
    """Append *record* to *path* as one JSON line; caller holds write_lock."""
    with open(path, 'a+', encoding='utf-8') as jsonl_file:
        json.dump(record, jsonl_file, ensure_ascii=False)
        jsonl_file.write('\n')


def _note_error(entry_id, error):
    """Report a failed record and stop the run after 1000 such errors."""
    global error_limit_times
    print(f"id: {entry_id} error : {error}")
    print("Stop !!!")
    # The counter is shared by all worker threads, so update it under the lock.
    with write_lock:
        error_limit_times += 1
        if error_limit_times >= 1000:
            stop_processing_event.set()


def process_line(line):
    """Augment one JSONL record with a model-written, structured answer.

    *line* is one line of the input file: a JSON object with ``id``,
    ``image`` and ``conversations`` (a list of ``{'from', 'value'}`` turns).
    For every ``gpt`` turn the image, the preceding ``human`` question, the
    original answer (the "standard answer") and any hints are sent to the
    chat API, and the turn's value is replaced by the response.

    Outcome, written at most once per record:
      * every gpt turn's <Conclusion> equals its standard answer -> the record
        is appended to ``output_file`` and its id added to ``processed_ids``;
      * a turn still disagrees after ``max_retries`` attempts -> the record is
        appended to ``refusal_file_1``;
      * a turn's last response has no <Conclusion> tag -> the record is
        appended to ``refusal_file_2``;
      * the request itself fails -> nothing is written, the error is counted,
        and the record is left for a later run.

    Returns None in all cases. Records that are invalid JSON, lack required
    fields, have no readable image, or were already processed/refused are
    skipped.
    """
    if stop_processing_event.is_set():
        return

    try:
        data = json.loads(line)     # id, image, conversations
    except json.JSONDecodeError:
        print("Invalid JSON format, skipping line.")
        return

    if not isinstance(data, dict) or 'id' not in data:
        print("Record has no id, skipping line.")
        return
    entry_id = data['id']
    if entry_id in processed_ids or entry_id in refusal_ids:
        return

    print(f"start processing id: {entry_id}")

    ############ process the image ############
    if 'image' not in data:
        return
    image_full_path = os.path.join(image_base_path, data['image'])
    try:
        with open(image_full_path, 'rb') as img_file:
            base64_image = base64.b64encode(img_file.read()).decode('utf-8')
    except OSError:
        print(f"fail to open image {image_full_path}, skip id {entry_id}")
        return
    # The file extension is used verbatim as the data-URL image subtype.
    image_type = image_full_path.split('.')[-1]

    ############ init gpt ############
    model_name = "gpt-4o-2024-11-20"
    # NOTE(review): avoid committing a real key here; consider reading it
    # from the environment instead.
    api_key = ''    # your api key
    client = OpenAI(api_key=api_key)

    conversations = data.get('conversations')
    if not isinstance(conversations, list):
        print(f"id: {entry_id} has no conversations, skip")
        return

    hints = _select_hints(entry_id, data)
    answered_turns = 0

    for index, convo in enumerate(conversations):
        if convo['from'] != 'gpt':
            continue

        # Explicit check instead of `assert`, which disappears under -O.
        if index == 0 or conversations[index - 1]['from'] != 'human':
            print(f"id: {entry_id} gpt turn {index} has no preceding human turn, skip")
            return

        standard_answer = convo['value']
        human_turn = conversations[index - 1]
        # The image is attached separately, so drop the placeholder tag. This
        # also changes the question stored in the written record.
        human_turn['value'] = human_turn['value'].replace('<image>\n', '').replace('\n<image>', '')
        question = human_turn['value']

        content = [
            {
                "type": "image_url",
                "image_url": {'url': f'data:image/{image_type};base64,' + base64_image}
            },
            {"type": "text", "text": _INSTRUCTION_PROMPT},
            {"type": "text", "text": "Question: " + question + "\n"},
            {"type": "text", "text": "Standard answer: " + standard_answer},
        ]
        if hints:
            content.append({
                "type": "text",
                "text": "".join(f"\nHint: {hint}" for hint in hints)
            })
        messages = [{'role': 'user', 'content': content}]

        succeeded = False
        failure_file = None
        for _attempt in range(max_retries):
            try:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    timeout=60,
                    max_tokens=1024
                )
                augmented_answer = response.choices[0].message.content
                convo['value'] = augmented_answer
                match = _CONCLUSION_PATTERN.search(augmented_answer)
            except Exception as e:
                # Request or response-shape failure: leave the record
                # unwritten so a later run can pick it up again.
                _note_error(entry_id, e)
                return

            if match is None:
                print(f'id: {entry_id} fail to get Conclusion')
                failure_file = refusal_file_2
            elif _strip_spaces(match.group(1)) != _strip_spaces(standard_answer):
                # Both sides are normalised the same way before comparing.
                print(f'id: {entry_id} failed')
                failure_file = refusal_file_1
            else:
                succeeded = True
                break
            time.sleep(1)  # short pause after a rejected answer

        if not succeeded:
            # Record the refusal once, with the last rejected answer in place.
            if failure_file is not None:
                try:
                    with write_lock:
                        _append_jsonl(failure_file, data)
                        refusal_ids.add(entry_id)
                except OSError as e:
                    _note_error(entry_id, e)
            return

        answered_turns += 1

    # Only records that actually had an answer augmented are written out.
    if answered_turns:
        print(f"id: {entry_id} success")
        try:
            with write_lock:
                _append_jsonl(output_file, data)
                processed_ids.add(entry_id)
        except OSError as e:
            _note_error(entry_id, e)

# Resume support: ids already present in the output file are loaded into
# processed_ids so that process_line skips them on this run.
if os.path.exists(output_file):
    with open(output_file, 'r', encoding='utf-8') as outfile:
        for line in outfile:
            try:
                existing_data = json.loads(line)
                processed_ids.add(existing_data['id'])
            except (json.JSONDecodeError, KeyError, TypeError):
                # Damaged line, or a record without a usable id: ignore it
                # rather than abort before any work has started.
                continue

processed_num = len(processed_ids)
print(f'Processed_ids: {processed_num}')

with open(input_file, 'r', encoding='utf-8') as infile:
    lines = infile.readlines()

start_time = time.time()
print('Data to process: ', len(lines))

# One task per input line; process_line decides whether the line is skipped.
with ThreadPoolExecutor(max_workers=num_threads) as executor:
    futures = [executor.submit(process_line, line) for line in lines]

    for future in as_completed(futures):
        try:
            future.result()
        except Exception as exc:
            # One bad record must not abort the run or lose the summary below.
            print(f"worker error: {exc}")
end_time = time.time()
# Number of ids added to processed_ids during this run, and wall-clock time.
print(f"{len(processed_ids)-processed_num} processed in {end_time - start_time} seconds")

