import gradio as gr
import os
import json
from tqdm import tqdm

def write_json(fn, data):
    """Serialize ``data`` to the file ``fn`` as pretty-printed JSON.

    Non-ASCII characters are written as-is (``ensure_ascii=False``), so the file
    is opened with an explicit UTF-8 encoding rather than the locale default,
    which may be unable to encode them (e.g. cp1252 on Windows).

    Args:
        fn: Destination path; the file is created or overwritten.
        data: Any JSON-serializable object.
    """
    with open(fn, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

def load_json(fn):
    """Parse and return the JSON document stored in the file ``fn``.

    The file is decoded as UTF-8 explicitly so that non-ASCII content (as
    produced by ``write_json``) loads the same on every platform, regardless
    of the locale's default encoding.
    """
    with open(fn, encoding="utf-8") as f:
        return json.load(f)

CONFIG = load_json("config.json")['annotation_config']

# (key, value) pairs from the config's auth mapping; passed to
# demo.launch(auth=...), which treats them as (username, password) tuples.
auth_list = list(CONFIG['auth_list'].items())
fn = CONFIG["input_fns"]["qa_data_fn"]  # Your data file

output_dir = CONFIG["output_dir"]
os.makedirs(output_dir, exist_ok=True)

# Per-dataset result directory named after the data file's stem (the text
# before the first "."). os.path.basename also honours the platform path
# separator, unlike a plain split on "/".
save_dir = os.path.join(output_dir, os.path.basename(fn).split('.')[0])
os.makedirs(save_dir, exist_ok=True)

# Load your data
data = load_json(fn)
if not data:
    # Fail fast with a clear message; every handler below indexes into `data`.
    raise ValueError(f"No items found in data file {fn!r}")
# NOTE(review): not referenced elsewhere in this file — kept for external use.
max_choices = max(len(item['choices']) for item in data)

def get_result_fn(username):
    """Return the path of ``username``'s annotation file (``<username>.json``) in ``save_dir``."""
    return os.path.join(save_dir, f"{username}.json")

def get_header(username, idx, annotation):
    """Build the Markdown progress header shown above the current item.

    Args:
        username: Logged-in user (not used here; kept so all UI helpers share
            the same call shape).
        idx: Index of the item being shown; coerced to ``int``.
        annotation: Per-item annotation list; ``None`` marks an item that has
            not been annotated yet.

    Returns:
        A fenced ``shell`` code block with a text progress bar (finished vs.
        total items), the 1-based current id, and the stored annotation for
        ``idx``.
    """
    idx = int(idx)
    n_finished = sum(item is not None for item in annotation)
    # Use tqdm's static formatter rather than instantiating a tqdm object:
    # an instance writes to stderr when created and again when it is
    # garbage-collected, and here it would never be closed explicitly.
    progress_bar = tqdm.format_meter(
        n=n_finished,
        total=len(data),
        elapsed=0,
        ascii=" .oO0",
        bar_format='{l_bar}{bar:30}|Finished:{n_fmt}|Total:{total_fmt}',
    ).strip().replace(" ", "+")  # every space (incl. the empty part of the bar) becomes "+"
    return f"``` shell\n{progress_bar}\n\nCurrent id: {idx+1}\nYour annotation: {annotation[idx]}\n```"

def format_required_state(required_state):
    """Render ``required_state`` as a Markdown bullet list under a bold heading.

    Each key has underscores replaced by spaces and is title-cased; each value
    is shown as inline code. An empty mapping yields only the heading.
    """
    heading = "**Required State Information:**\n\n"
    bullets = [
        f"- **{name.replace('_', ' ').title()}**: `{value}`\n"
        for name, value in required_state.items()
    ]
    return heading + "".join(bullets)

def get_annotation_instructions():
    """Return the Markdown task description displayed alongside each item."""
    task_text = "**Task**: Select the ONE choice that best matches the given query considering the required state information."
    return task_text

def initialize(request: gr.Request):
    """Set up the session for the logged-in user when the page loads.

    If the user already has a result file, their saved annotations are loaded
    and the view resumes at the first unannotated item (or the last item when
    everything is done). Otherwise a fresh all-``None`` annotation list is used
    and the view starts at item 0.

    Args:
        request: Gradio request; ``request.username`` identifies the user.

    Returns:
        ``[greeting] + load_item(...)``: the greeting Markdown followed by the
        outputs for the item to display first.
    """
    username = request.username
    result_fn = get_result_fn(username)

    if os.path.exists(result_fn):
        annotation = load_json(result_fn)
        # The data file may hold more items than the saved list (or the saved
        # list may be empty); pad with None so every item index is addressable.
        if len(annotation) < len(data):
            annotation.extend([None] * (len(data) - len(annotation)))
        # Resume at the first unannotated item; if all are done, the last item.
        idx = next(
            (i for i, ann in enumerate(annotation) if ann is None),
            len(annotation) - 1,
        )
    else:
        annotation = [None for _ in range(len(data))]
        idx = 0

    greeting = f"Thanks for logging in, {request.username}. Select the appropriate choice for each query based on the required state information. Click **Submit** to save your annotation."
    outputs = load_item(username, idx, annotation)

    return [greeting] + outputs

def load_item(username, idx, annotation):
    """Build the UI outputs for item ``idx``.

    Args:
        username: Logged-in user; echoed back into the session state.
        idx: Requested item index. Values outside ``[0, len(data) - 1]`` are
            clamped, so Previous/Next simply stop at either end.
        annotation: Per-item annotation list; ``None`` marks an item that has
            not been annotated yet.

    Returns:
        A list in the order of ``all_outputs``: item label, query text,
        required-state Markdown, instructions, progress header, radio update,
        comments text, then the ``username``, clamped ``idx`` and
        ``annotation`` session states.
    """
    # Coerce rather than assert: asserts are stripped under ``python -O``.
    idx = max(0, min(int(idx), len(data) - 1))
    current_item = data[idx]

    query = current_item['query']
    choices = current_item['choices']
    required_state = current_item['required_state']

    current_annotation = annotation[idx]
    if current_annotation is None:
        selected_choice = None
        comments_anno = ""
    else:
        selected_choice = current_annotation["selected_choice"]
        comments_anno = current_annotation["comments"]

    # Create radio button choices for current item
    radio_choices = [f"Choice {i+1}: {choice}" for i, choice in enumerate(choices)]
    # Reuse the exact label so the Radio value always matches one of its choices.
    radio_value = radio_choices[selected_choice] if selected_choice is not None else None

    return [
        f"Item {idx + 1}",
        query,
        format_required_state(required_state),
        get_annotation_instructions(),
        get_header(username, idx, annotation),
        gr.update(choices=radio_choices, value=radio_value),
        comments_anno,
        username, idx, annotation
    ]

def load_prev(user_name, idx, annotation):
    """Show the item before ``idx``; ``load_item`` clamps at the first item."""
    return load_item(user_name, idx - 1, annotation)

def load_next(username, idx, annotation):
    """Show the item after ``idx``; ``load_item`` clamps at the last item."""
    return load_item(username, idx + 1, annotation)

def submit(selected_choice_text, comments_anno_input, username, idx, annotation):
    """Record the chosen option for item ``idx`` and save the user's annotations.

    Args:
        selected_choice_text: Radio value such as ``"Choice 2: ..."``, or
            ``None``/empty when nothing is selected.
        comments_anno_input: Free-text comments from the comments box.
        username: Logged-in user; selects the result file via ``get_result_fn``.
        idx: Index of the item being annotated.
        annotation: Per-item annotation list; updated in place at ``idx`` and
            written to the user's result file.

    Returns:
        The ``load_item`` outputs for ``idx`` so the UI reflects the saved state.

    Raises:
        gr.Error: If ``idx`` is not a valid index into ``data``.
    """
    # Negative indices would otherwise silently annotate items from the end.
    if not 0 <= idx < len(data):
        # gr.Error is only shown in the UI when raised, not when constructed.
        raise gr.Error(f"Index {idx} out of range. Total items: {len(data)}")

    # Check if a choice is selected
    if not selected_choice_text:
        gr.Warning("Please select a choice before submitting!")
        return load_item(username, idx, annotation)

    # Radio labels are built as "Choice {n}: {text}" with 1-based n; recover
    # the 0-based index from the part before the first colon.
    selected_choice = int(selected_choice_text.split(":")[0].replace("Choice ", "")) - 1

    current_item = data[idx]

    annotation[idx] = {
        "query": current_item['query'],
        "required_state": current_item['required_state'],
        "selected_choice": selected_choice,
        "original_answers": current_item['is_answer'],
        "choices": current_item['choices'],
        "comments": comments_anno_input
    }

    write_json(get_result_fn(username), annotation)
    gr.Info("Annotation recorded.")
    return load_item(username, idx, annotation)

# Gradio Interface
# Components are laid out in the order they are created inside this Blocks
# context, so statement order here defines the page layout.
with gr.Blocks(title="Single Choice QA Annotation") as demo:
    # Session states. These module-level names are gr.State components, not
    # values; the same-named handler parameters receive their current values.
    # All three are overwritten by `initialize` on page load.
    username = gr.State(None)
    idx = gr.State(0)
    annotation = gr.State([None for _ in range(len(data))])

    # Header content: `top_instruction` receives the greeting from
    # `initialize`; `header` receives the progress text from `get_header`.
    top_instruction = gr.Markdown("")
    header = gr.Markdown()

    # Data display section
    with gr.Column():
        gr.Markdown("## Current Item")
        item_id_display = gr.Textbox(label="Item ID", interactive=False)
        query_display = gr.Textbox(label="Query", interactive=False, lines=3)
        required_state_display = gr.Markdown()
        instructions_display = gr.Markdown()

    # Choice selection section. Choices start empty and are filled per item by
    # `load_item`; type="value" hands the selected label string to `submit`.
    gr.Markdown("## Select the Best Choice")
    choice_radio = gr.Radio(
        label="Available Choices",
        choices=[],
        value=None,
        type="value"
    )

    # Comments section
    comments = gr.Textbox(
        label="Additional Comments",
        placeholder="Any additional observations about the choice or annotation decision...",
        lines=2
    )

    # Define input and output lists.
    # `all_outputs` must stay in the same order as the list `load_item`
    # returns: 5 display values, radio update, comments, then the 3 states.
    input_list = [choice_radio, comments]
    display_list = [item_id_display, query_display, required_state_display, instructions_display, header]
    state_list = [username, idx, annotation]
    all_outputs = display_list + [choice_radio, comments] + state_list

    # Initialize page: `initialize` returns the greeting first, hence the
    # extra leading `top_instruction` output.
    demo.load(initialize, inputs=None, outputs=[top_instruction] + all_outputs)

    # Control buttons
    with gr.Row():
        prev_btn = gr.Button("Previous")
        submit_btn = gr.Button("Submit", variant="primary")
        next_btn = gr.Button("Next")

    # Navigation handlers take (username, idx, annotation) positionally.
    prev_btn.click(fn=load_prev,
                   inputs=state_list,
                   outputs=all_outputs)

    next_btn.click(fn=load_next,
                   inputs=state_list,
                   outputs=all_outputs)

    # `submit` takes (radio value, comments, username, idx, annotation).
    submit_btn.click(fn=submit,
                     inputs=input_list + state_list,
                     outputs=all_outputs)

# Launch the demo
if __name__ == "__main__":
    # Login is required against the configured credentials; no public share link.
    demo.launch(share=False, auth=auth_list)