"""
pip install gradio    # proxy_on first
python vis_geochat_data.py
# browse data in http://127.0.0.1:10064
"""

import os
import io
import json
import copy
import time
import gradio as gr
import base64
from PIL import Image
from io import BytesIO
from argparse import Namespace
from llava import conversation as conversation_lib
from typing import Sequence
from vlmeval import *
from vlmeval.dataset import SUPPORTED_DATASETS, build_dataset

SYS = "You are a helpful assistant. Your job is to faithfully translate all provided text into Chinese faithfully. "

# Translator = SiliconFlowAPI(model='Qwen/Qwen2.5-7B-Instruct', system_prompt=SYS)
Translator = OpenAIWrapper(model='gpt-4o-mini', system_prompt=SYS)


def image_to_mdstring(image):
    return f"![image](data:image/jpeg;base64,{image})"


def images_to_md(images):
    return '\n\n'.join([image_to_mdstring(image) for image in images])


def mmqa_display(question, target_size=768):
    question = {k.lower() if len(k) > 1 else k: v for k, v in question.items()}
    keys = list(question.keys())
    keys = [k for k in keys if k not in ['index', 'image']]

    idx = question.pop('index', 'XXX')
    text = f'\n- INDEX: {idx}\n'

    images = question.pop('image')
    if images[0] == '[' and images[-1] == ']':
        images = eval(images)
    else:
        images = [images]

    qtext = question.pop('question', None)
    if qtext is not None:
        text += f'- QUESTION: {qtext}\n'

    if 'A' in question:
        text += f'- Choices: \n'
        for k in string.ascii_uppercase:
            if k in question:
                text += f'\t-{k}: {question.pop(k)}\n'
    answer = question.pop('answer', None)
    
    for k in question:
        if not pd.isna(question[k]):
            text += f'- {k.upper()}. {question[k]}\n'
    
    if answer is not None:
        text += f'- ANSWER: {answer}\n'

    image_md = images_to_md(images)

    return text, image_md


def parse_args():
    parser = argparse.ArgumentParser()
    # Essential Args, Setting the Names of Datasets and Models
    parser.add_argument('--port', type=int, default=7860)
    args = parser.parse_args()
    return args


def gradio_app_vis_dataset(port=7860):
    data, loaded_obj = None, {}

    def btn_submit_click(filename, ann_id):
        if filename not in loaded_obj:
            return filename_change(filename, ann_id)
        nonlocal data
        data_desc = gr.Markdown(f'Visualizing {filename}, {len(data)} samples in total. ')
        if ann_id < 0 or ann_id >= len(data):
            return filename, ann_id, data_desc, gr.Markdown('Invalid Index'), gr.Markdown(f'Index out of range [0, {len(data) - 1}]')
        item = data.iloc[ann_id]
        text, image_md = mmqa_display(item)
        return filename, ann_id, data_desc, image_md, text

    def btn_next_click(filename, ann_id):
        return btn_submit_click(filename, ann_id + 1)

    # def translate_click(anno_en):
    #     return gr.Markdown(Translator.generate(anno_en))

    def filename_change(filename, ann_id):
        nonlocal data, loaded_obj

        def legal_filename(filename):
            LMURoot = LMUDataRoot()
            if filename in SUPPORTED_DATASETS:
                return build_dataset(filename).data
            elif osp.exists(filename):
                data = load(filename)
                assert 'index' in data and 'image' in data
                image_map = {i: image for i, image in zip(data['index'], data['image'])}
                for k, v in image_map.items():
                    if (not isinstance(v, str) or len(v) < 64) and v in image_map:
                        image_map[k] = image_map[v]
                data['image'] = [image_map[k] for k in data['index']]
                return data
            elif osp.exists(osp.join(LMURoot, filename)):
                filename = osp.join(LMURoot, filename)
                return legal_filename(filename)
            else:
                return None

        data = legal_filename(filename)
        if data is None:
            return filename, 0, gr.Markdown(''), gr.Markdown("File not found"), gr.Markdown("File not found")
        
        loaded_obj[filename] = data
        return btn_submit_click(filename, 0)

    with gr.Blocks() as app:
        
        filename = gr.Textbox(
            value='Dataset Name (supported by VLMEvalKit) or TSV FileName (Relative under `LMURoot` or Real Path)', 
            label='Dataset', 
            interactive=True,
            visible=True)
            
        with gr.Row():
            ann_id = gr.Number(0, label='Sample Index (Press Enter)', interactive=True, visible=True)
            btn_next = gr.Button("Next")
            # btn_translate = gr.Button('CN Translate')

        with gr.Row():
            data_desc = gr.Markdown('Dataset Description', label='Dataset Description')
        
        with gr.Row():
            image_output = gr.Markdown('Image PlaceHolder', label='Image Visualization')
            anno_en = gr.Markdown('Image Annotation', label='Image Annotation')
            # anno_cn = gr.Markdown('Image Annotation (Chinese)', label='Image Annotation (Chinese)')

        input_components = [filename, ann_id]
        all_components = [filename, ann_id, data_desc, image_output, anno_en]

        filename.submit(filename_change, input_components, all_components)
        ann_id.submit(btn_submit_click, input_components, all_components)
        btn_next.click(btn_next_click, input_components, all_components)
        # btn_translate.click(translate_click, anno_en, anno_cn)

    # app.launch()
    app.launch(server_name='0.0.0.0', debug=True, show_error=True, server_port=port)


if __name__ == "__main__":
    args = parse_args()
    gradio_app_vis_dataset(port=args.port)

