import gradio as gr
import os
import json
from tqdm import tqdm

def write_json(fn, data):
    """Serialize *data* as pretty-printed JSON and write it to *fn*.

    Non-ASCII characters are written as-is (``ensure_ascii=False``), so the
    file is opened explicitly as UTF-8. Otherwise the platform's locale
    encoding would be used and could fail on, or mis-encode, such text.
    """
    with open(fn, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

def load_json(fn):
    """Read the JSON file *fn* and return the decoded Python object.

    The file is decoded as UTF-8 to match ``write_json``, rather than with
    the platform's locale encoding.
    """
    with open(fn, encoding="utf-8") as f:
        return json.load(f)

CONFIG = load_json("config.json")['annotation_config']

# Passed to demo.launch(auth=...): a list of (key, value) pairs from the
# config's auth_list mapping (presumably username -> password).
auth_list = list(CONFIG['auth_list'].items())
fn = CONFIG["input_fns"]["env_data_fn"]  # Your data file

output_dir = CONFIG["output_dir"]
os.makedirs(output_dir, exist_ok=True)

# Annotations for this data file go into a sub-directory named after the
# file's base name, up to its first dot. os.path.basename understands the
# platform's path separators; a plain split('/') would keep backslash-separated
# directories in the name on Windows.
save_dir = os.path.join(output_dir, os.path.basename(fn).split('.')[0])
os.makedirs(save_dir, exist_ok=True)

# Load your data - adapt this based on your actual data structure
data = load_json(fn)
if not data:
    # Fail with a clear message instead of max()'s "empty sequence" error.
    raise ValueError(f"No items to annotate in data file: {fn}")
# Number of radio slots the UI must pre-allocate: the largest
# exposed_states mapping over all items.
max_states = max(len(item[1]['exposed_states']) for item in data)


def get_result_fn(username):
    """Return the path of the JSON file that stores *username*'s annotations."""
    return os.path.join(save_dir, "{}.json".format(username))


def get_header(username, idx, annotation):
    """Build the markdown header showing overall progress and the current item.

    Args:
        username: Logged-in annotator (not used here; kept so all view
            helpers share the same signature).
        idx: Index of the current item; coerced with ``int()``.
        annotation: Per-item annotation list; ``None`` marks an item that has
            not been submitted yet.

    Returns:
        A fenced ``shell`` markdown block containing an ASCII progress bar
        (spaces replaced by ``+``), the 1-based current id, and the stored
        annotation for that item.
    """
    idx = int(idx)
    n_finished = sum(item is not None for item in annotation)
    # Format the bar with tqdm's static formatter instead of str(tqdm(...)):
    # constructing a tqdm instance draws a progress bar on the server's
    # stderr (on creation and again on close) every time a header is rendered.
    progress_bar = tqdm.format_meter(
        n=n_finished,
        total=len(data),
        elapsed=0,
        ascii=" .oO0",
        bar_format='{l_bar}{bar:30}|Finished:{n_fmt}|Total:{total_fmt}',
    ).strip().replace(" ", "+")
    return f"``` shell\n{progress_bar}\n\nCurrent id: {idx+1}\nYour annotation: {annotation[idx]}\n```"


def format_exposed_states(exposed_states):
    """Format exposed states for display with enhanced context information.

    Args:
        exposed_states: Mapping of state name -> dict with the keys
            ``this_state_value``, ``past_state_value`` and
            ``all_state_values`` (an iterable of possible values).

    Returns:
        A Markdown string: a title line, then one ``###`` section per state
        listing its current, previous and possible values.
    """
    parts = ["**Exposed States to Annotate:**\n\n"]
    for state_name, state_info in exposed_states.items():
        # str() each option so numeric/boolean values don't make join() raise.
        all_values = ', '.join(map(str, state_info['all_state_values']))
        parts.append(
            f"### **{state_name}**\n"
            f"- **Current Value**: `{state_info['this_state_value']}`\n"
            f"- **Previous Value**: `{state_info['past_state_value']}`\n"
            f"- **All Possible Values**: `{all_values}`\n\n"
        )
    return "".join(parts)


def get_annotation_scale_info():
    """Describe the 0-2 point annotation scale as Markdown shown to annotators."""
    scale_lines = (
        "",
        "## Annotation Scale (0-2 points):",
        "",
        "**2 points (Fully Implied)**: User naturally reveals complete state information. State can be determined without additional reasoning. Information exposure is reasonable and natural.",
        "",
        # The trailing space before the example line is part of the text.
        "**1 point (Partly Implied)**: Most information is exposed, but may lack some details. Requires reasoning to determine complete state. ",
        "Example: Query mentions \"limited budget\" but doesn't specify exact range.",
        "",
        "**0 points (Not Reflected)**: Query is completely unrelated to this state. Cannot infer any relevant information from the query.",
        "",
    )
    return "\n".join(scale_lines)


def initialize(request: gr.Request):
    """Set up a user's session when the page loads.

    Restores the user's saved annotations if their result file exists and
    positions them on the first item that has no annotation yet.

    Returns:
        The greeting markdown followed by the outputs of ``load_item`` for
        the starting item.
    """
    username = request.username
    result_fn = get_result_fn(username)

    if os.path.exists(result_fn):
        annotation = load_json(result_fn)
        # Give every data item a slot even if the saved file is shorter
        # (e.g. empty, or saved before items were added). Otherwise
        # annotation[idx] would raise IndexError.
        annotation.extend([None] * max(0, len(data) - len(annotation)))
        # Resume at the first unannotated item; if all are annotated, stay on
        # the last one.
        idx = next(
            (i for i, ann in enumerate(annotation) if ann is None),
            len(annotation) - 1,
        )
    else:
        annotation = [None for _ in range(len(data))]
        idx = 0

    greeting = f"Thanks for logging in, {username}. Rate how well each exposed state is reflected in the user query (0-2 points). Click **Submit** to save your annotation."
    outputs = load_item(username, idx, annotation)

    return [greeting] + outputs


def load_item(username, idx, annotation):
    """Build the output values that display item *idx* and its saved annotation.

    Args:
        username: Logged-in annotator; passed through unchanged.
        idx: Requested item index (must be an ``int``). It is clamped into
            ``[0, len(data) - 1]``, so stepping past either end stays on the
            first/last item.
        annotation: Per-item annotation list; ``None`` marks an item that has
            not been submitted yet.

    Returns:
        A list: item id, query, states markdown, scale info, header,
        username, clamped idx, annotation, then exactly ``max_states`` radio
        updates (visible for this item's states, hidden for the rest), then
        the saved comments text.
    """
    assert isinstance(idx, int)
    idx = max(0, min(idx, len(data) - 1))
    current_item = data[idx]
    item_id = current_item[0]
    item_data = current_item[1]
    query = item_data['query']
    exposed_states = item_data['exposed_states']

    current_annotation = annotation[idx]
    if current_annotation is None:
        saved_scores = {}
        comments_anno = ""
    else:
        saved_scores = current_annotation["state_scores"]
        comments_anno = current_annotation["comments"]

    updates = []
    for state_name, state_info in exposed_states.items():
        # A saved annotation may predate a change to the data file; show a
        # state without a stored score as unscored instead of raising KeyError.
        state_score = saved_scores.get(state_name, {}).get('score')
        all_values = ', '.join(map(str, state_info['all_state_values']))
        info_text = (
            f"Current: '{state_info['this_state_value']}' | "
            f"Past: '{state_info['past_state_value']}' | "
            f"All: [{all_values}] | "
            "Rate: 0=Not reflected, 1=Partly implied, 2=Fully implied"
        )
        updates.append(gr.update(
            visible=True,
            label=state_name,
            info=info_text,
            value=state_score
        ))

    # Hide the remaining radio slots. Count from the states actually shown
    # (not the saved scores) so exactly max_states updates are returned.
    updates.extend(
        gr.update(visible=False) for _ in range(len(exposed_states), max_states)
    )

    outputs = [
        item_id,
        query,
        format_exposed_states(exposed_states),
        get_annotation_scale_info(),
        get_header(username, idx, annotation),
        username, idx, annotation
    ] + updates + [comments_anno]

    return outputs


def load_prev(user_name, idx, annotation):
    """Display the item before *idx*; ``load_item`` keeps it at the first item."""
    return load_item(user_name, idx - 1, annotation)


def load_next(username, idx, annotation):
    """Display the item after *idx*; ``load_item`` keeps it at the last item."""
    return load_item(username, idx + 1, annotation)


def submit(*args):
    """Validate and save the annotation for the current item.

    Positional arguments, in order: ``max_states`` radio scores, the comments
    text, then username, idx and annotation (the last three arguments).

    Raises:
        gr.Error: If idx does not point at an item in ``data``.

    Returns:
        The outputs of ``load_item`` for the same item. If any displayed
        state is left unscored, nothing is stored or written. In that case
        the user's in-progress scores and comments are echoed back so they
        are not lost.
    """
    username = args[-3]
    idx = int(args[-2])
    annotation = args[-1]

    if not 0 <= idx < len(data):
        # gr.Error only reaches the UI when raised; merely constructing it
        # did nothing, and returning None left Gradio without outputs.
        raise gr.Error(f"Index {idx} out of range. Total items: {len(data)}")

    state_anno_input = args[:max_states]
    comments_anno_input = args[max_states]

    current_item = data[idx]
    item_data = current_item[1]
    exposed_states = item_data['exposed_states']
    state_names = list(exposed_states.keys())
    # Radio slots are positional: the first len(state_names) belong to this
    # item's states, in exposed_states order.
    state_anno_values = dict(zip(state_names, state_anno_input))

    for state_name in state_names:
        if state_anno_values.get(state_name) is None:
            gr.Warning(f"Submission NOT saved: please complete all state scoring fields ({state_name} empty)!")
            return _echo_unsaved_inputs(
                load_item(username, idx, annotation),
                [state_anno_values.get(name) for name in state_names],
                comments_anno_input,
            )

    # Only store the entry once it is complete. `annotation` is the session
    # state, so storing a partial entry would count it as finished and a
    # later submit would write it to disk.
    annotation[idx] = {
        "item_id": current_item[0],
        "state_scores": {
            state_name: {
                "score": state_anno_values[state_name],
                "expected_value": state_info['this_state_value'],
                "past_value": state_info['past_state_value'],
                "all_possible_values": state_info['all_state_values']
        } for state_name, state_info in exposed_states.items()},
        "comments": comments_anno_input
    }
    write_json(get_result_fn(username), annotation)
    gr.Info("Annotation recorded.")
    return load_item(username, idx, annotation)


def _echo_unsaved_inputs(outputs, scores, comments):
    """Overwrite load_item's radio/comment outputs with the user's unsaved inputs."""
    # load_item ends with max_states radio outputs followed by the comments.
    first_radio = len(outputs) - (max_states + 1)
    for offset, score in enumerate(scores):
        outputs[first_radio + offset] = score
    outputs[-1] = comments
    return outputs


# Gradio Interface
# Builds the annotation page. Components appear in the order they are
# created below, so statement order defines the layout.
with gr.Blocks(title="State Exposure Quality Evaluation") as demo:
    # Session states
    # Per-session values. initialize (run on page load) replaces all three;
    # `annotation` starts with one None slot per data item.
    username = gr.State(None)
    idx = gr.State(0)
    annotation = gr.State([None for _ in range(len(data))])

    # Header content
    top_instruction = gr.Markdown("")  # greeting text set by initialize
    header = gr.Markdown()  # progress bar / current item, set via get_header

    # Data display section
    with gr.Column():
        gr.Markdown("## Current Item")
        item_id_display = gr.Textbox(label="Item ID", interactive=False)
        query_display = gr.Textbox(label="User Query", interactive=False, lines=3)
        states_display = gr.Markdown(label="Exposed States")
        scale_info_display = gr.Markdown()

    # Evaluation section
    gr.Markdown("## State Exposure Evaluation")
    
    # Dynamic state scoring components
    # One hidden radio per possible state (max_states over all items);
    # load_item makes the first N visible and relabels them per item.
    state_components = []
    for i in range(max_states):
        state_score = gr.Radio(
            choices=[0, 1, 2],
            label=f"State {i+1}",
            info="Rate how well this state is reflected in the query (0-2)",
            visible=False
        )
        state_components.append(state_score)
    # Comments section
    comments = gr.Textbox(
        label="Additional Comments",
        placeholder="Any additional observations about state exposure...",
        lines=2
    )

    # Aggregate lists
    # The concatenation order of these lists must match the order of the
    # values returned by load_item/initialize and unpacked by submit.
    input_list = state_components + [comments]
    display_list = [item_id_display, query_display, states_display, scale_info_display, header]
    state_list = [username, idx, annotation]

    # Initialize page
    demo.load(initialize, inputs=None, outputs=[top_instruction] + display_list + state_list + input_list)

    # Control buttons
    # Previous/Next are ClearButtons so the scoring inputs are also cleared
    # when they are clicked; unsubmitted selections are not kept.
    with gr.Row():
        prev_btn = gr.ClearButton(input_list, value="Previous")
        submit_btn = gr.Button("Submit", variant="primary")
        next_btn = gr.ClearButton(input_list, value="Next")

    prev_btn.click(fn=load_prev,
                   inputs=state_list,
                   outputs=display_list + state_list + input_list)

    next_btn.click(fn=load_next,
                   inputs=state_list,
                   outputs=display_list + state_list + input_list)

    # submit receives the input values followed by the session state.
    submit_btn.click(fn=submit,
                     inputs=input_list + state_list,
                     outputs=display_list + state_list + input_list)

# Launch the demo
# Access requires one of the credential pairs from the config's auth_list.
if __name__ == "__main__":
    demo.launch(auth=auth_list, share=False)
