import re
import numpy as np
from typing import List

# Matches the text of a numbered item ("1. What ...?"), excluding the number prefix.
# The lookbehind only checks the single digit right before ". ", so "10. " also matches.
# `.` does not cross newlines, and `.+` is greedy, so a match runs to the LAST "?" on that line.
question_pattern = re.compile(r'(?<=[0-9]\. ).+\?')

def extract_questions(responses: List[str]) -> List[List[str]]:
    """
    Extract checklist questions from the given responses.

    Only the text after the last "##" in each response is searched (the whole
    response when it contains no "##"). Questions are the numbered items matched
    by the module-level ``question_pattern``.

    Args:
        responses: Raw model responses, one per checklist-generation request.

    Returns:
        One list of question strings per response, in the same order. A response
        with no matching question yields ``[""]`` so the output stays aligned
        with ``responses``.
    """
    questions_list = []
    for response in responses:
        # Keep only the final "##" section.
        last_section = response.split("##")[-1]
        # Run the regex once; an empty match list is replaced by the [""] placeholder.
        questions_list.append(question_pattern.findall(last_section) or [""])
    return questions_list


def extract_info(row: dict) -> dict:
    """
    Extract the necessary information from the given row (WildBench format).

    The last turn of ``row["conversation_input"]`` is the instruction being
    judged; all earlier turns are rendered as context via ``chat_history``.
    The reference answer is read from ``row["references"]["gpt-4"]``.

    Returns:
        A dict with keys "history", "instruction" and "reference_response".
    """
    turns = row["conversation_input"]
    context = chat_history(turns)
    current_turn = turns[-1]
    return dict(
        history=context,
        instruction=current_turn["content"],
        reference_response=row["references"]["gpt-4"],
    )

def chat_history(conversation_input) -> str:
    """
    Convert the conversation input to chat history as the context.

    Every turn except the last one is rendered as its content prefixed with
    "USER: " or "ASSISTANT: " and followed by two newlines. Turns with any other
    role (e.g. "system") are skipped.

    Args:
        conversation_input: Sequence of turns, each a mapping with "role" and
            "content" keys.

    Returns:
        The concatenated history string; "" when there are fewer than two turns.
    """
    prefixes = {"user": "USER: ", "assistant": "ASSISTANT: "}
    # The final turn is the current instruction (see extract_info), so it is not
    # part of the context. Slicing an empty sequence is safe, so no length guard
    # is needed.
    return "".join(
        prefixes[turn["role"]] + turn["content"] + "\n\n"
        for turn in conversation_input[:-1]
        if turn["role"] in prefixes
    )

def robust_logprobs(logprobs: list) -> dict:
    """
    Extract the logprobs of target tokens from the given logprobs list.

    Each entry's "token" is whitespace-stripped and capitalized, so " yes" and
    "YES" both count as "Yes". For each of "Yes" and "No", the highest
    "logprob" seen is kept. A target that never appears keeps the -100 sentinel.
    """
    best = dict.fromkeys(("Yes", "No"), -100)
    for candidate in logprobs:
        token = candidate['token'].strip().capitalize()
        if token not in best:
            continue
        best[token] = max(best[token], candidate["logprob"])
    return best

def get_confidence(prob: dict, default: int | float = np.nan) -> str:
    """
    Derive the conditional normalized probability from the judgment response.
    """
    if prob["Yes"] == prob["No"] == -100:
        return default
    else:
        confidence = np.exp(prob["Yes"]) / (np.exp(prob["Yes"]) + np.exp(prob["No"]))
        return confidence

def get_judgment(score: float) -> str:
    """
    Derive the judgment from the score.

    Args:
        score: Probability of "Yes" (e.g. from ``get_confidence``), or NaN when
            no judgment was available.

    Returns:
        "Unsure" when ``score`` is NaN or within 1e-3 of 0.5. Otherwise
        "Yes (xx.x%)" when score > 0.5, else "No (xx.x%)". The percentage is
        always the Yes probability (score * 100).
    """
    if np.isnan(score) or abs(score - 0.5) < 1e-3:
        return "Unsure"
    judgment = "Yes" if score > 0.5 else "No"
    # The percent sign goes inside the parentheses: "Yes (85.0%)", not "Yes (85.0)%".
    return f"{judgment} ({score * 100:.1f}%)"

def get_all_judgments(checklist: dict) -> list:
    """
    Derive all judgments from the checklist.

    Args:
        checklist: Mapping from question index to score. Keys are used as list
            indices, so they are assumed to be non-negative ints (TODO confirm
            with callers).

    Returns:
        A list of length ``max(index) + 1``. Each scored index holds
        ``get_judgment(score)``; every index missing from ``checklist`` is
        "Unsure". An empty checklist yields ``[]``.
    """
    if not checklist:
        # max() of an empty sequence raises ValueError; no questions means no judgments.
        return []
    judgment_list = ["Unsure"] * (max(checklist) + 1)
    for index, score in checklist.items():
        judgment_list[index] = get_judgment(score)
    return judgment_list

def get_confidence_score(item: dict) -> float:
    """
    Derive the confidence score from the judgment response.

    Reads the top-logprob alternatives of the first generated token of the first
    choice in ``item["response"]["body"]``. The nesting looks like an OpenAI
    chat-completions response; verify against the producer. A falsy body is
    treated as having no logprobs, which yields get_confidence's default (NaN).
    """
    body = item["response"]["body"]
    top_logprobs = (
        body["choices"][0]["logprobs"]["content"][0]["top_logprobs"] if body else []
    )
    return get_confidence(robust_logprobs(top_logprobs))