# Score charts and select high-scoring charts

import argparse
import base64
import json
import mimetypes
import os
import re
import shutil
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial

import requests
from openai import OpenAI
from tqdm import tqdm
#import anthropic

# from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
# from lmdeploy.vl import load_image

##############################################

# Switch the process working directory to the directory containing this
# script, so the relative prompt path read at import time
# ("./prompt/Score_Prompt.txt") resolves no matter where the script is
# launched from. This is a process-wide side effect: a relative --data_path
# is consequently also resolved against this directory.
os.chdir(os.path.dirname(os.path.realpath(__file__)))

def read_prompt_from_file(file_path):
    """Return the UTF-8 text stored at ``file_path`` without surrounding whitespace."""
    with open(file_path, mode='r', encoding='utf-8') as prompt_file:
        raw_text = prompt_file.read()
    return raw_text.strip()

# System prompt sent with every scoring request. It is loaded once at import
# time from a path relative to the current working directory.
Score_Prompt = read_prompt_from_file("./prompt/Score_Prompt.txt")
##############################################

# Placeholder API key passed to the OpenAI client below.
# NOTE(review): no base_url is given here, so the target endpoint comes from
# the client's defaults/environment -- confirm it points at the intended server.
openai_api_key = "EMPTY"

# Module-level client shared by create_chat_response_by_messages for all
# requests (including those issued from worker threads).
client = OpenAI(
    api_key=openai_api_key,
)

def extract_rating(input_string):
    """Extract the integer score from a model reply.

    The reply is searched for the first occurrence of ``Rating: <d>`` where
    ``<d>`` is a single digit from 1 to 5.

    Args:
        input_string: Text returned by the model. May be ``None`` (the chat
            API allows an empty ``content``) or any other non-string value.

    Returns:
        The matched digit as an ``int``, or ``0`` when the input is not a
        string or contains no rating in the expected format.
    """
    # A missing reply must not raise inside the batch loop; treat it the same
    # way as a reply without a rating.
    if not isinstance(input_string, str):
        return 0

    match = re.search(r"Rating: ([1-5])", input_string)
    return int(match.group(1)) if match else 0


def batch(iterable, n=1):
    """Yield consecutive slices of ``iterable`` holding at most ``n`` items each.

    ``iterable`` must support ``len()`` and slicing; the final slice may be
    shorter than ``n``.
    """
    total = len(iterable)
    for start in range(0, total, n):
        stop = start + n
        if stop > total:
            # Clamp the last chunk to the end of the sequence.
            stop = total
        yield iterable[start:stop]


def file_to_data_url(file_path: str):
    """
    Convert a local image file to a base64 ``data:`` URL.

    The MIME type is looked up from the file extension via the standard
    ``mimetypes`` registry, so ``.jpg`` becomes ``image/jpeg`` and ``.svg``
    becomes ``image/svg+xml``. Extensions the registry does not map to an
    image type fall back to ``image/<extension>``.

    Args:
        file_path: Path of the image file to read.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(file_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')

    _, extension = os.path.splitext(file_path)
    extension = extension.lower()

    # Look up only the extension (not the full path) so directory names or
    # unusual characters in the path cannot influence the guess.
    guessed_type, _ = mimetypes.guess_type(f"file{extension}")
    if guessed_type and guessed_type.startswith("image/"):
        mime_type = guessed_type
    else:
        # Unknown extension: keep the historical best-effort naming.
        mime_type = f"image/{extension[1:]}"

    return f"data:{mime_type};base64,{encoded_string}"



def create_chat_response_by_messages(
    messages,
):
    """Send one scoring request and return the model's reply text.

    Args:
        messages: A ``(prompt, image_url)`` pair. ``prompt`` is sent as the
            system message; ``image_url`` is sent as the only part of the
            user message.

    Returns:
        The ``content`` of the first choice in the completion.
    """
    system_prompt, picture_url = messages

    started_at = time.time()
    image_part = {"type": "image_url", "image_url": {"url": picture_url}}
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": [image_part]},
    ]
    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=conversation,
    )
    finished_at = time.time()

    # Log the request latency first, then the raw reply.
    print('########################### result, time:', finished_at - started_at)
    reply_text = completion.choices[0].message.content
    print(reply_text)
    return reply_text


def filter_images(
        data_path,
        batch_size=100,
        num_workers=100,
):
    """Score every image listed in ``all_info.jsonl`` and rewrite the file.

    Each JSONL record must carry an ``image`` path relative to ``data_path``.
    Images are scored in batches of ``batch_size`` using ``num_workers``
    threads, and the parsed rating is stored under
    ``record['rating'][model_name]`` before the file is written back in place.

    Args:
        data_path: Directory containing ``all_info.jsonl`` and the images.
        batch_size: Number of images encoded and submitted per batch.
        num_workers: Size of the thread pool issuing API requests.
    """
    info_path = os.path.join(data_path, "all_info.jsonl")
    with open(info_path, 'r') as file:
        # Blank lines (e.g. a stray trailing newline) are not records.
        meta_data = [json.loads(line.strip()) for line in file if line.strip()]

    # range() rejects a zero step; treat non-positive sizes as "one at a time".
    step = max(1, batch_size)
    all_outputs = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for start in tqdm(range(0, len(meta_data), step), desc="Processing Batches"):
            chunk = meta_data[start:start + step]
            # Encode only the current batch so that memory use is bounded by
            # batch_size rather than by the size of the whole dataset.
            prompts = [
                (Score_Prompt, file_to_data_url(os.path.join(data_path, data['image'])))
                for data in chunk
            ]
            # executor.map yields results in submission order, which keeps
            # all_outputs aligned with meta_data.
            results = executor.map(create_chat_response_by_messages, prompts)
            all_outputs.extend(tqdm(results, total=len(prompts), desc="filter image"))

    # NOTE(review): this label does not match the "gpt-4o" model requested in
    # create_chat_response_by_messages -- confirm which name should be stored.
    model_name = 'Pixtral-Large-Instruct-2411'

    for item, output in zip(meta_data, all_outputs):
        item.setdefault('rating', {})[model_name] = extract_rating(output)

    with open(info_path, "w") as f:
        for item in meta_data:
            f.write(json.dumps(item) + "\n")


def _score_and_write_batch(items, data_path, model_name, executor, out_file):
    """Score one batch of unrated records and append them to ``out_file``."""
    prompts = [
        (Score_Prompt, file_to_data_url(os.path.join(data_path, record['image'])))
        for record in items
    ]
    # Collect every reply before writing so a failed request never leaves a
    # half-written batch; executor.map preserves input order for the zip below.
    outputs = list(executor.map(create_chat_response_by_messages, prompts))

    for record, output in zip(items, outputs):
        record.setdefault('rating', {})[model_name] = extract_rating(output)
        out_file.write(json.dumps(record) + "\n")


def filter_images_flow(
        data_path,
        batch_size=100,
        num_workers=100,
):
    """Score the not-yet-rated images in ``all_info.jsonl`` in streaming batches.

    Records that already contain a ``rating`` key are re-written as they are.
    All other records are collected into batches of ``batch_size``, scored
    with ``num_workers`` threads and written with the parsed rating stored
    under ``record['rating'][model_name]``. Output goes to
    ``all_info_filter.jsonl``, which replaces ``all_info.jsonl`` once the whole
    run has succeeded. Because rated records are written immediately while
    unrated ones wait for their batch, the output order can differ from the
    input order.

    Args:
        data_path: Directory containing ``all_info.jsonl`` and the images.
        batch_size: Number of unrated records scored per batch.
        num_workers: Size of the thread pool issuing API requests.
    """
    # NOTE(review): ratings are stored under an empty-string key here, whereas
    # filter_images uses a named model -- confirm the intended key.
    model_name = ''
    open_file = os.path.join(data_path, "all_info.jsonl")
    save_file = os.path.join(data_path, "all_info_filter.jsonl")

    # Count lines up front for the progress bar; the context manager makes
    # sure this extra handle is closed again.
    with open(open_file, 'r') as file:
        total_lines = sum(1 for _ in file)

    with open(open_file, 'r') as file, open(save_file, "w") as f, \
            ThreadPoolExecutor(max_workers=num_workers) as executor:
        meta_data = []
        for line in tqdm(file, desc="filter image", total=total_lines):
            line = line.strip()
            if not line:
                # Blank lines are not records.
                continue

            item = json.loads(line)
            if 'rating' in item:
                f.write(json.dumps(item) + "\n")
                continue

            meta_data.append(item)
            if len(meta_data) >= batch_size:
                _score_and_write_batch(meta_data, data_path, model_name, executor, f)
                meta_data = []

        # Flush the final, possibly partial batch.
        if meta_data:
            _score_and_write_batch(meta_data, data_path, model_name, executor, f)

    print(f'all done, cover {open_file} from {save_file}')
    shutil.move(save_file, open_file)


def arg_parser():
    """Parse the command-line options of the scoring script.

    Returns:
        argparse.Namespace with ``data_path``, ``temperature``, ``top_p``,
        ``num_workers`` and ``batch_size``.

    Raises:
        SystemExit: If ``--data_path`` is missing or an option is invalid
            (standard argparse behaviour).
    """
    parser = argparse.ArgumentParser()

    #parser.add_argument("--model_path", type=str, default='/dfs/data/InternVL2-40B')
    # Required: without it the run would only fail later, when the path is
    # joined with the metadata file name.
    parser.add_argument("--data_path", type=str, required=True,
                        help="directory containing all_info.jsonl and the images")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="sampling temperature (parsed but not used by the current entry point)")
    parser.add_argument("--top_p", type=float, default=0.95,
                        help="nucleus sampling value (parsed but not used by the current entry point)")
    parser.add_argument("--num_workers", type=int, default=10,
                        help="number of threads issuing API requests")
    parser.add_argument("--batch_size", type=int, default=20,
                        help="number of images scored per batch")

    return parser.parse_args()


if __name__ == "__main__":
    # Parse the CLI options and echo them for the run log.
    args = arg_parser()
    print(args)

    # Disabled local-inference setup, kept for reference:
    # backend_config = TurbomindEngineConfig(
    #     tp=args.num_gpus,
    #     session_len=8192,
    # )
    # gen_config = GenerationConfig(
    #     top_p=args.top_p,
    #     temperature=args.temperature,
    #     max_new_tokens=1024,
    # )
    #
    # pipe = pipeline(args.model_path, backend_config=backend_config)

    # Score all unrated images and update the metadata file in place.
    filter_images_flow(
        data_path=args.data_path,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

