import gradio as gr
import os
import json
from tqdm import tqdm

def write_json(fn, data):
    """Serialize ``data`` as pretty-printed JSON (2-space indent) to path ``fn``.

    Non-ASCII characters are written as-is (``ensure_ascii=False``), so the
    file is opened explicitly as UTF-8 instead of relying on the platform's
    default encoding, which may not be able to encode them.
    """
    with open(fn, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

def load_json(fn):
    """Parse and return the JSON document stored at path ``fn``.

    The file is decoded as UTF-8 explicitly so that files written by
    ``write_json`` (which keeps non-ASCII characters) read back correctly
    regardless of the platform's default encoding.
    """
    with open(fn, encoding="utf-8") as f:
        return json.load(f)

CONFIG = load_json("config.json")['annotation_config']

# Pairs built from the config's auth_list mapping; passed to demo.launch(auth=...).
auth_list = list(CONFIG['auth_list'].items())
fn = CONFIG["input_fns"]["conv_data_fn"]  # Your data file

output_dir = CONFIG["output_dir"]
os.makedirs(output_dir, exist_ok=True)

# Per-data-file result directory, named after the data file's base name up to
# its first dot (e.g. "data/conv.v2.json" -> "conv"). os.path.basename also
# understands the platform's native path separator, unlike splitting on "/".
save_dir = os.path.join(output_dir, os.path.basename(fn).split('.')[0])
os.makedirs(save_dir, exist_ok=True)

# Consistency rating definitions: integer rating (0-2) -> "<short name>: <description>".
# The text before the first ':' is what the rating radios show as a short label.
CONSISTENCY_RATINGS = dict(enumerate((
    "No conflict: Conversation content is completely consistent with current states. No contradictory information.",
    "Minor inconsistency: Slight expression inconsistencies but no substantial conflict. E.g., state says 'basic_awareness_sessions' but conversation mentions 'some additional activities'",
    "Major conflict: Serious contradictions that cannot be reconciled. E.g., state says only doing 'awareness sessions' but conversation says 'we already have a comprehensive college counseling program'",
)))

# Load annotation data with the load_json helper defined at the top of the file.
# From its use below, each item is presumably a pair
# [conversation_id, {"unchanged_states": {...}, "user_turns": [...], ...}].
data = load_json(fn)


def get_max_states():
    """Return the largest ``unchanged_states`` count over all items in ``data``.

    Returns 0 when ``data`` is empty. The UI uses this number to size its
    fixed pool of per-state rating radios.
    """
    state_counts = (len(entry[1]['unchanged_states']) for entry in data)
    return max(state_counts, default=0)

# Number of rating radios the layout creates: the largest state count of any item.
max_states = get_max_states()

# Defining functions and layouts

def get_result_fn(username):
    """Return the path of ``username``'s result file: ``<save_dir>/<username>.json``."""
    return os.path.join(save_dir, "{}.json".format(username))

def get_header(username, idx, annotation):
    """Render the Markdown header for the current item.

    Shows a text progress bar (finished = non-None entries of ``annotation``
    out of ``len(data)``), the 1-based current item number, and the stored
    annotation for ``idx`` (``None`` until submitted). ``username`` is not
    used here.
    """
    idx = int(idx)
    n_finished = sum(item is not None for item in annotation)
    # Render the bar text with tqdm's static formatter. Instantiating a tqdm
    # object just to str() it also prints a bar to stderr on every page change
    # and leaves that bar unclosed.
    progress_bar = tqdm.format_meter(
        n_finished,
        len(data),
        0,  # elapsed seconds; not referenced by this bar_format
        ascii=" .oO0",
        bar_format='{l_bar}{bar:30}|Finished:{n_fmt}|Total:{total_fmt}',
    ).strip().replace(" ", "+")  # spaces -> '+' (presumably to keep empty bar cells visible)
    return f"``` shell\n{progress_bar}\n\nCurrent id: {idx+1}\nYour annotation: {annotation[idx]}\n```"

def format_conversation(user_turns):
    """Format conversation turns as Markdown for display.

    Each turn becomes a bold ``**Turn N:**`` line (1-based), then the turn's
    ``content``, then a blank line. Returns ``""`` when there are no turns.
    """
    # join instead of repeated += (which is quadratic in the general case).
    return "".join(
        f"**Turn {i}:**\n{turn['content']}\n\n"
        for i, turn in enumerate(user_turns, start=1)
    )


def initialize(request: gr.Request):
    """Start a session for the logged-in annotator.

    If the annotator already has a result file, load it and resume at the
    first item without an annotation (or the last item when all are filled).
    Otherwise start a fresh annotation list at item 0.

    Args:
        request: Gradio request; ``request.username`` is the logged-in user.

    Returns:
        The greeting Markdown followed by ``load_item``'s outputs, matching
        ``[top_instruction] + all_outputs`` in the layout.
    """
    username = request.username
    result_fn = get_result_fn(username)

    if os.path.exists(result_fn):   # start from the first not filled id
        annotation = load_json(result_fn)
        # A result file saved against a shorter data file would otherwise make
        # annotation[idx] raise IndexError for the newer items.
        annotation.extend([None] * (len(data) - len(annotation)))
        # First unfilled slot; if none, stay on the last item. Also safe for
        # an empty result file.
        idx = next(
            (i for i, ann in enumerate(annotation) if ann is None),
            len(annotation) - 1,
        )
    else:   # start from 0
        annotation = [None for _ in range(len(data))]
        idx = 0

    greeting = f"Thanks for logging in, {request.username}. For each state in the conversation, evaluate the consistency level (0-2) based on how well the conversation content aligns with that specific state. Click **Submit** to save your annotation."

    return [greeting] + load_item(username, idx, annotation)


def load_item(username, idx, annotation):
    """Build the UI outputs for data item ``idx``.

    ``idx`` is clamped to ``[0, len(data) - 1]``, so load_prev/load_next may
    step past either end. Ratings and comments saved in ``annotation[idx]``
    are restored into the widgets. A state with no saved rating, including
    one missing from an older saved entry, is shown unrated.

    Returns:
        ``[conversation_id, formatted conversation, header, username, idx,
        annotation]`` + one ``gr.update`` per rating radio (``max_states``
        of them) + ``[comments]``, which is the order of ``all_outputs``.
    """
    # int() instead of an assert: asserts are stripped under ``python -O``.
    idx = max(0, min(int(idx), len(data) - 1))

    current_item = data[idx]
    conversation_id = current_item[0]
    conversation_data = current_item[1]
    current_states = conversation_data['unchanged_states']

    saved = annotation[idx]
    if saved is None:
        saved_ratings = {}
        comments = ""
    else:
        # .get() so an entry saved before the data's state keys changed loads
        # with those states unrated instead of raising KeyError.
        saved_ratings = saved.get('state_ratings') or {}
        comments = saved.get('comments', "")

    # Radio choices: label "<rating>: <short name>", value = integer rating.
    rating_choices = [(f"{k}: {v.split(':')[0]}", k) for k, v in CONSISTENCY_RATINGS.items()]
    state_updates = [
        gr.update(
            visible=True,
            label=f"{state_key}: `{state_value}`",
            info="Rate consistency between conversation and this state",
            choices=rating_choices,
            value=(saved_ratings.get(state_key) or {}).get("consistency_rating"),
        )
        for state_key, state_value in current_states.items()
    ]
    # Hide and reset the radios this item does not use.
    state_updates += [
        gr.update(visible=False, value=None)
        for _ in range(len(current_states), max_states)
    ]

    return [
        conversation_id,
        format_conversation(conversation_data['user_turns']),
        get_header(username, idx, annotation),
        username, idx, annotation
    ] + state_updates + [comments]


def load_prev(username, idx, annotation):
    """Show the item before ``idx``; load_item clamps it at the first item."""
    return load_item(username, idx - 1, annotation)


def load_next(username, idx, annotation):
    """Show the item after ``idx``; load_item clamps it at the last item."""
    return load_item(username, idx + 1, annotation)


def submit(*args):
    """Validate the current item's ratings and save them.

    Gradio passes ``input_list + state_list`` positionally: ``max_states``
    radio values, then comments, username, idx, annotation.

    If every visible state is rated, the entry is stored in ``annotation``
    and the whole list is written to the user's result file. Otherwise a
    warning is shown and the entry is NOT recorded. It then cannot be counted
    as finished or be persisted by a later successful submit. The partial
    input stays on screen.

    Raises:
        gr.Error: if ``idx`` is outside ``data``.

    Returns:
        Values for ``all_outputs`` (see ``load_item``).
    """
    state_ratings = args[:max_states]
    comments, username, idx, annotation = args[max_states:max_states + 4]
    idx = int(idx)

    if not 0 <= idx < len(data):
        # gr.Error only reaches the UI when raised; returning None would also
        # leave Gradio without values for the outputs.
        raise gr.Error(f"Index {idx} out of range. Total items: {len(data)}")

    current_item = data[idx]
    conversation_id = current_item[0]
    current_states = current_item[1]['unchanged_states']

    # Only the first len(current_states) radios are shown for this item;
    # zip drops the values of the hidden ones.
    state_key_to_rating = dict(zip(current_states, state_ratings))

    entry = {
        "conversation_id": conversation_id,
        "state_ratings": {
            state_key: {
                "state_value": state_value,
                "consistency_rating": state_key_to_rating[state_key]
            } for state_key, state_value in current_states.items()
        },
        "comments": comments
    }

    for state_key, rating in state_key_to_rating.items():
        if rating is None:
            gr.Warning(f"Submission NOT saved: please provide consistency ratings for all states ({state_key} empty)!")
            # Render from a copy that holds the partial entry, so the ratings
            # already chosen stay visible. Return the untouched session list
            # and its header.
            preview = list(annotation)
            preview[idx] = entry
            outputs = load_item(username, idx, preview)
            # load_item order: [conversation_id, conversation, header,
            #                   username, idx, annotation, *radios, comments]
            outputs[2] = get_header(username, idx, annotation)
            outputs[5] = annotation
            return outputs

    annotation[idx] = entry
    write_json(get_result_fn(username), annotation)
    gr.Info("Annotation recorded.")
    return load_item(username, idx, annotation)


# Defining layouts
with gr.Blocks(title="State-Conversation Consistency Evaluation") as demo:
    # Session states: annotator name, current item index, and the in-memory
    # annotation list (one slot per data item, None = not yet submitted).
    username = gr.State(None)
    idx = gr.State(0)
    annotation = gr.State([None for _ in range(len(data))])

    # Header content: greeting (filled by initialize) and progress header.
    top_instruction = gr.Markdown("")
    header = gr.Markdown()

    # Data display section
    with gr.Column():
        gr.Markdown("## Current Item")
        conversation_id_display = gr.Textbox(label="Conversation ID", interactive=False)
        conversation_display = gr.Markdown(label="Conversation")

    # Evaluation section
    gr.Markdown("## State Consistency Evaluation")

    # Consistency rating guidelines
    with gr.Accordion("Rating Guidelines", open=False):
        guidelines_text = "".join(
            f"**{rating}**: {description}\n\n"
            for rating, description in CONSISTENCY_RATINGS.items()
        )
        gr.Markdown(guidelines_text)

    # Dynamic state rating components: a fixed pool of max_states radios
    # (computed once at module level). load_item relabels and shows the ones
    # the current item needs and hides the rest.
    gr.Markdown("### Rate Each State")
    state_rating_components = [
        gr.Radio(
            [],
            label=f"State {i+1}",
            info="Rate consistency for this state",
            visible=False  # Initially hidden
        )
        for i in range(max_states)
    ]

    # Comments section
    comments = gr.Textbox(
        label="Additional Comments",
        placeholder="Any additional observations, specific examples of conflicts, or reasoning for your ratings...",
        lines=3
    )

    # Define component lists. all_outputs follows the return order of load_item.
    display_list = [conversation_id_display, conversation_display, header]
    state_list = [username, idx, annotation]
    input_list = state_rating_components + [comments]
    all_outputs = display_list + state_list + input_list

    # Initialize page. Only this event adds top_instruction in front of
    # all_outputs, because initialize also returns the greeting.
    demo.load(initialize, inputs=None, outputs=[top_instruction] + all_outputs)

    # Control buttons (Previous/Next also clear the rating inputs)
    with gr.Row():
        prev_btn = gr.ClearButton(state_rating_components + [comments], value="Previous")
        submit_btn = gr.Button("Submit", variant="primary")
        next_btn = gr.ClearButton(state_rating_components + [comments], value="Next")

    # Button actions: all three write to the same all_outputs list.
    prev_btn.click(fn=load_prev,
                   inputs=state_list,
                   outputs=all_outputs)

    next_btn.click(fn=load_next,
                   inputs=state_list,
                   outputs=all_outputs)

    # submit receives the rating inputs first, then the session states.
    submit_btn.click(fn=submit,
                     inputs=input_list + state_list,
                     outputs=all_outputs)

# Launching the demo: password-protected with the config's auth pairs, no public share link.
demo.launch(share=False, auth=auth_list)
