from __future__ import annotations

import random
from typing import Callable, Dict, Optional, Tuple


# Return type shared by every task formatter: a (prompt, response) pair of
# strings, or None when the example cannot be formatted (e.g. a required
# field is missing).
FormatterOutput = Optional[Tuple[str, str]]


def _format_choices(choices):
    lines = []
    for idx, choice in enumerate(choices):
        letter = chr(ord("A") + idx)
        lines.append(f"{letter}. {choice}")
    return "\n".join(lines)


def format_mmlu(example) -> FormatterOutput:
    """Format an MMLU-style example as a lettered multiple-choice prompt.

    Expects ``question``, ``choices`` (a sequence of option strings) and
    ``answer``. The options are displayed as A, B, C, ... and the prompt asks
    for a letter, so the response must name a letter too:

    * an integer ``answer`` (or a string of decimal digits) is treated as a
      0-based index into ``choices`` and converted to the displayed letter;
    * any other ``answer`` (e.g. an already-lettered ``"B"``) is used as-is.

    Returns:
        ``(prompt, response)``, or None when a required field is missing or a
        numeric answer does not index into ``choices``.
    """
    question = example.get("question")
    choices = example.get("choices")
    answer = example.get("answer")
    if question is None or choices is None or answer is None:
        return None

    # Materialize once: the options are both counted and rendered below.
    choices = list(choices)

    index = None
    if isinstance(answer, int):
        index = answer
    elif isinstance(answer, str) and answer.strip().isdecimal():
        index = int(answer.strip())

    if index is not None:
        # A numeric answer is a choice index; emit the letter shown in the prompt.
        if not 0 <= index < len(choices):
            return None
        correct_letter = chr(ord("A") + index)
    else:
        correct_letter = answer

    choice_block = _format_choices(choices)
    prompt = (
        "You will be given a multiple-choice question. "
        "Select the single best answer and respond with the corresponding letter.\n\n"
        f"Question: {question}\n"
        f"{choice_block}\nAnswer:"
    )
    response = f"The correct answer is {correct_letter}."
    return prompt, response


def format_arc(example) -> FormatterOutput:
    """Format an ARC example as a lettered multiple-choice prompt.

    Expects ``question``, ``answerKey`` and a ``choices`` mapping holding a
    ``text`` list. The mapping may also hold a ``label`` list, assumed
    parallel to ``text``. Options are always displayed as A, B, C, ....

    When ``answerKey`` appears in ``choices["label"]``, it is translated to
    the letter at that position. This means numeric keys such as ``"1"`` to
    ``"4"``, or labels in a non-alphabetical order, still produce a response
    that matches the prompt. Otherwise ``answerKey`` is used as-is.

    Returns:
        ``(prompt, response)``, or None when a required field is missing.
    """
    question = example.get("question")
    # ``or {}`` also tolerates an explicit ``"choices": None`` entry.
    choice_data = example.get("choices") or {}
    choices = choice_data.get("text")
    answer = example.get("answerKey")
    if question is None or choices is None or answer is None:
        return None

    correct_letter = answer
    labels = choice_data.get("label")
    if labels is not None:
        labels = list(labels)
        if answer in labels:
            # Re-key to the letter actually displayed for this option.
            correct_letter = chr(ord("A") + labels.index(answer))

    choice_block = _format_choices(choices)
    prompt = (
        "Answer the science question by selecting the single best option.\n\n"
        f"Question: {question}\n"
        f"{choice_block}\nAnswer:"
    )
    response = f"The correct answer is {correct_letter}."
    return prompt, response


def format_piqa(example) -> FormatterOutput:
    """Format a PIQA example as a two-option (A/B) multiple-choice prompt.

    ``label`` selects the correct solution: 0 means ``sol1`` (A) and 1 means
    ``sol2`` (B). Any other value (e.g. a ``-1`` placeholder for unlabeled
    rows) or a label that cannot be converted with ``int()`` yields None.
    Previously every non-zero value was mapped to "B", a silently wrong
    training target.

    Returns:
        ``(prompt, response)``, or None when a field is missing or the label
        is not 0/1.
    """
    goal = example.get("goal")
    sol1 = example.get("sol1")
    sol2 = example.get("sol2")
    label = example.get("label")
    if goal is None or sol1 is None or sol2 is None or label is None:
        return None

    try:
        label_index = int(label)
    except (TypeError, ValueError):
        return None
    if label_index not in (0, 1):
        return None

    choice_block = _format_choices([sol1, sol2])
    correct_letter = "AB"[label_index]
    prompt = (
        "Choose the option that best completes the goal.\n\n"
        f"Goal: {goal}\n"
        f"{choice_block}\nAnswer:"
    )
    response = f"The correct answer is {correct_letter}."
    return prompt, response


def format_hellaswag(example) -> FormatterOutput:
    """Format a HellaSwag example as a lettered story-completion prompt.

    The context is read from ``ctx``, or from ``context`` when ``ctx`` is
    missing or empty. ``endings`` holds the candidate continuations.
    ``label`` is the 0-based index of the correct ending and may be an int
    or a numeric string.

    Returns:
        ``(prompt, response)``, or None when a field is missing or ``label``
        is not a valid index into ``endings``. Such labels include an empty
        string on unlabeled rows, which previously raised ValueError.
    """
    context = example.get("ctx") or example.get("context")
    endings = example.get("endings")
    label = example.get("label")
    if context is None or endings is None or label is None:
        return None

    endings = list(endings)
    try:
        label_index = int(label)
    except (TypeError, ValueError):
        return None
    if not 0 <= label_index < len(endings):
        # A letter with no matching option would be a corrupt target.
        return None

    choice_block = _format_choices(endings)
    correct_letter = chr(ord("A") + label_index)
    prompt = (
        "Complete the story by choosing the most plausible continuation.\n\n"
        f"Story: {context}\n"
        f"{choice_block}\nAnswer:"
    )
    response = f"The correct answer is {correct_letter}."
    return prompt, response


def format_gsm8k(example) -> FormatterOutput:
    """Format a GSM8K-style math word problem.

    The prompt wraps ``question`` with a fixed instruction. The response is
    ``answer`` with surrounding whitespace removed.

    Returns:
        ``(prompt, response)``, or None if ``question`` or ``answer`` is missing.
    """
    problem, solution = example.get("question"), example.get("answer")
    if problem is None or solution is None:
        return None

    instruction = "Solve the math word problem. Provide the reasoning and the final answer.\n\n"
    return instruction + f"Problem: {problem}\nAnswer:", solution.strip()


def format_alpaca(example) -> FormatterOutput:
    """Format an Alpaca-style instruction record.

    The prompt is the stripped ``instruction``. When a non-blank ``input``
    is present, it is appended after an ``Input:`` header separated by a
    blank line. The response is the stripped ``output``.

    Returns:
        ``(prompt, response)``, or None when ``instruction`` or ``output`` is
        missing, or when the prompt or response ends up empty.
    """
    instruction = example.get("instruction")
    output = example.get("output")
    if instruction is None or output is None:
        return None

    prompt = instruction.strip()
    context = (example.get("input") or "").strip()
    if context:
        prompt += "\n\nInput:\n" + context

    response = output.strip()
    return (prompt, response) if prompt and response else None


def format_humaneval(example) -> FormatterOutput:
    """Format a HumanEval-style example as a code-generation prompt.

    The prompt embeds the task's ``prompt`` text after a fixed instruction.
    The response is ``canonical_solution``, falling back to ``solution``
    when the former is missing or empty.

    Only leading blank lines and trailing whitespace are removed from the
    solution; the first code line keeps its indentation. A solution that is
    a function body is valid only when spliced after the prompt with every
    line's indent intact. A plain ``strip()`` dedents just the first line
    and breaks the code's block structure.

    Returns:
        ``(prompt, response)``, or None when a field is missing or the
        solution is blank.
    """
    prompt = example.get("prompt")
    solution = example.get("canonical_solution") or example.get("solution")
    if prompt is None or solution is None:
        return None

    user_prompt = (
        "Write a Python function that satisfies the following specification. "
        "Return only valid Python code.\n\n"
        f"{prompt}\nAnswer:"
    )

    lines = solution.splitlines(keepends=True)
    first_code_line = 0
    # Skip whitespace-only leading lines without touching the indent of real code.
    while first_code_line < len(lines) and not lines[first_code_line].strip():
        first_code_line += 1
    response = "".join(lines[first_code_line:]).rstrip()
    if not response:
        return None
    return user_prompt, response


# Registry of task formatters, keyed by lowercase task name. Each formatter
# takes a raw example dict and returns a (prompt, response) pair or None.
# Keyword arguments preserve insertion order, so iteration order is unchanged.
TASK_FORMATTERS: Dict[str, Callable[[dict], FormatterOutput]] = dict(
    mmlu=format_mmlu,
    piqa=format_piqa,
    arc=format_arc,
    hellaswag=format_hellaswag,
    gsm8k=format_gsm8k,
    alpaca=format_alpaca,
    humaneval=format_humaneval,
)


def format_example(task: str, example: dict) -> FormatterOutput:
    """Dispatch *example* to the formatter registered for *task*.

    *task* is lowercased before lookup in TASK_FORMATTERS, whose keys are
    lowercase, so e.g. "MMLU" and "mmlu" select the same formatter.

    Returns:
        The formatter's result, which may itself be None. Also returns None
        when no formatter is registered for *task*.
    """
    handler = TASK_FORMATTERS.get(task.lower())
    return None if handler is None else handler(example)


def sample_with_weight(indices, weight: float, rng: Optional[random.Random] = None):
    """Replicate and subsample *indices* according to a fractional weight.

    The integer part of *weight* includes every item that many times. The
    fractional part adds ``round(len(indices) * fraction)`` items (at least
    one), drawn at random without replacement. The combined list is then
    shuffled.

    Args:
        indices: Items to sample from. The iterable is materialized into a
            list once, so sets, ranges and generators all work. (``random.sample``
            rejects sets on Python 3.11+, and a generator would otherwise be
            exhausted after the first copy.)
        weight: Sampling weight; values <= 0 yield an empty list.
        rng: Optional ``random.Random`` instance for reproducible sampling.
            Defaults to the module-level ``random`` functions (previous behavior).

    Returns:
        A new, shuffled list. It is empty when *weight* <= 0 or *indices* is
        empty. Previously, empty *indices* with a fractional weight raised
        ValueError from ``random.sample``.
    """
    if weight <= 0:
        return []
    pool = list(indices)
    if not pool:
        return []

    source = random if rng is None else rng
    whole = int(weight)
    remainder = weight - whole

    selected = pool * whole
    if remainder > 0:
        # remainder < 1, so the draw size never exceeds len(pool).
        extra = int(round(len(pool) * remainder))
        selected.extend(source.sample(pool, max(1, extra)))
    source.shuffle(selected)
    return selected

