import io
import os
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from typing import Any

import jinja2
import numpy as np
import requests
from tqdm import tqdm

from .types import EvalResult, Message, SamplerBase, SingleEvalResult

# Plain multiple-choice prompt (options A-D): the model is told to finish with
# an "Answer: $LETTER" line. Filled with str.format using the keys Question,
# A, B, C and D. NOTE(review): the placeholder here is "{Question}" (capital Q)
# while the self-play template below uses "{question}" — confirm which key the
# rows actually carry before using this template.
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

# Boxed-answer ("self-play") variant of the A-D prompt, filled with str.format
# using the keys question, A, B, C and D. The doubled braces in "\\boxed{{}}"
# are str.format escapes and render as a literal "\boxed{}".
QUERY_TEMPLATE_MULTICHOICE_SELF_PLAY = """
Please reason step by step, and put your final answer within \\boxed{{}}. Your final answer should be of the following format: \\boxed{{LETTER}} where LETTER is one of ABCD.

Question: {question}

Options:
A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

# Ten-option (A-J) variant of QUERY_TEMPLATE_MULTICHOICE. "{options[0]}" ...
# "{options[9]}" index into an `options` sequence passed to str.format, so the
# sequence must have at least 10 entries or formatting raises IndexError.
QUERY_TEMPLATE_MULTICHOICE_PRO = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCDEFGHIJ. Think step by step before answering.

{question}

A) {options[0]}
B) {options[1]}
C) {options[2]}
D) {options[3]}
E) {options[4]}
F) {options[5]}
G) {options[6]}
H) {options[7]}
I) {options[8]}
J) {options[9]}
""".strip()

# Boxed-answer variant of QUERY_TEMPLATE_MULTICHOICE_PRO: filled with
# str.format using `question` and a sequence `options` of at least 10 entries.
# The instruction lists all ten letters so it agrees with the A-J options that
# are actually shown (it previously said "one of ABCD").
QUERY_TEMPLATE_MULTICHOICE_PRO_SELF_PLAY = """
Please reason step by step, and put your final answer within \\boxed{{}}. Your final answer should be of the following format: \\boxed{{LETTER}} where LETTER is one of ABCDEFGHIJ.

{question}

A) {options[0]}
B) {options[1]}
C) {options[2]}
D) {options[3]}
E) {options[4]}
F) {options[5]}
G) {options[6]}
H) {options[7]}
I) {options[8]}
J) {options[9]}
""".strip()

# Multi-answer prompt (one or more of A-D) asking for an "Answer: [LETTERS]"
# final line. Filled with str.format using `text`, `question` and a mapping
# `options`; str.format treats the non-numeric index in "{options[A]}" as the
# string key "A", so `options` must be keyed by "A".."D".
QUERY_TEMPLATE_MULTICHOICE_KNOWLOGIC = (
    "Answer the following multiple choice question with one or more correct answers. \n"
    "The last line of your response should be of the following format: 'Answer: [LETTERS]' (without quotes) where place answer choices in []. LETTERS is one or more of the letters ABCD, separated by commas. \n"
    "No partial credit will be given for incorrect or incomplete answers. \n"
    "Answer choices must exactly match the standard answer to be considered correct. \n"
    "{text} \n\n"
    "Question: \n"
    "{question} \n\n"
    "Options: \n"
    "A) {options[A]} \n"
    "B) {options[B]} \n"
    "C) {options[C]} \n"
    "D) {options[D]} \n"
).strip()

# Boxed-answer variant of the multi-answer prompt; same placeholders (`text`,
# `question`, letter-keyed `options`) as QUERY_TEMPLATE_MULTICHOICE_KNOWLOGIC.
# NOTE(review): the "{text} " and "Question: " lines keep a trailing space in
# the rendered prompt.
QUERY_TEMPLATE_MULTICHOICE_SELF_PLAY_KNOWLOGIC = """
Please reason step by step, and put your final answer within \\boxed{{}}.
Answer the following multiple choice question with one or more correct answers. Put your final answer in the following format: \\boxed{{LETTERS}} where LETTERS is one or more of the letters ABCD, separated by commas.
{text} 

Question: 
{question}

Options:
A) {options[A]}
B) {options[B]}
C) {options[C]}
D) {options[D]}
""".strip()

# Answer-extraction regexes. Every pattern starts with (?i), so matching is
# case-insensitive and e.g. "answer: b" matches, capturing "b".
# "Answer:" followed by a single letter A-D, optionally wrapped in "$".
ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?"
# A single letter A-D inside a literal "\boxed{...}".
ANSWER_PATTERN_MULTICHOICE_SELF_PLAY = r"(?i)\\boxed{([A-D])}"
# Free-form answer: everything after "Answer:" up to the end of that line.
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
# Comma-separated letters A-D after "Answer:", optionally inside [ ].
ANSWER_PATTERN_KNOWLOGIC = r"(?i)Answer\s*:\s*(?:\[)?([A-D](?:\s*,\s*[A-D])*)(?:\])?"
# Comma-separated letters A-D inside a literal "\boxed{...}".
ANSWER_PATTERN_MULTICHOICE_SELF_PLAY_KNOWLOGIC = (
    r"(?i)\\boxed{([A-D](?:\s*,\s*[A-D])*)}"
)

# Templates whose "{}" placeholder is presumably filled (via str.format) with
# one entry of MULTILINGUAL_ANSWER_REGEXES — verify at the call site. They
# capture an option letter in Latin, Arabic, Bengali or fullwidth script.
# These are not raw strings: "\t" is a real tab inside the [ \t] class.
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    "(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[Ａ]|[Ｂ]|[Ｃ]|[Ｄ])"
)
# NOTE(review): the Arabic/Bengali character ranges below span every code
# point between their endpoints, not just ten option letters.
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE_PRO = "(?i){}[ \t]*([A-J]|[أ-ي]|[অ-ঝ]|[Ａ-Ｊ])"
ANSWER_PATTERN_MULTICHOICE_SELF_PLAY_PRO = r"(?i)\\boxed{([A-J]|[أ-ي]|[অ-ঝ]|[Ａ-Ｊ])}"
# All the different ways "Answer" is written in different languages.
# Raw strings: "\s" is a regex escape, not a Python string escape; in plain
# literals it is an invalid escape sequence (SyntaxWarning on Python 3.12+).
# The resulting string values are unchanged.
MULTILINGUAL_ANSWER_REGEXES = [
    r"Answer\s*:",
    r"Answer\s*:​​​​​​",  # Korean invisible character
    r"উত্তর\s*:",
    r"उत्तर\s*:",
    r"উত্তরঃ",
    r"উত্তর\s*:",
    r"Antwort\s*:",
    r"답변\s*:",
    r"정답\s*:",
    r"답\s*:",
    r"答案\s*：",
    r"答案\s*:",
    r"答\s*：",
    r"答\s*:",
    r"答复\s*：",
    r"答曰\s*：",
    r"الإجابة:",
    r"الجواب:",
    r"إجابة:",
    r"الإجابة النهائية:",
    r"الإجابة الصحيحة:",
    r"الإجابة الصحيحة هي:",
    r"الإجابة هي:",
    r"الجواب النهائي:",
    r"Respuesta\s*:",
    r"Risposta\s*:",
    r"答え\s*:",
    r"答え\s*：",
    r"回答\s*:",
    r"回答\s*：",
    r"解答\s*:",
    r"Jawaban\s*:",
    r"Réponse\s*:",
    r"Resposta\s*:",
    r"Jibu\s*:",
    r"Idahun\s*:",
    r"Ìdáhùn\s*:",
    r"Idáhùn\s*:",
    r"Àmọ̀nà\s*:",
    r"Àdáhùn\s*:",
    r"Ànúgọ\s*:",
    r"Àṣàyàn\s*:",
]


# Few-shot grader prompt asking a model to judge whether two answer
# expressions are equivalent and reply with only "Yes" or "No". It is filled
# with %-formatting via the named keys expression1 / expression2, so any
# literal "%" added to this text would have to be written as "%%".
EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications

Examples:

    Expression 1: $2x+3$
    Expression 2: $3+2x$

Yes

    Expression 1: 3/2
    Expression 2: 1.5

Yes

    Expression 1: $x^2+2x+1$
    Expression 2: $y^2+2y+1$

No

    Expression 1: $x^2+2x+1$
    Expression 2: $(x+1)^2$

Yes

    Expression 1: 3245/5
    Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

    Expression 1: 2/(-3)
    Expression 2: -2/3

Yes
(trivial simplifications are allowed)

    Expression 1: 72 degrees
    Expression 2: 72

Yes
(give benefit of the doubt to units)

    Expression 1: 64
    Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

---

YOUR TASK


Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

    Expression 1: %(expression1)s
    Expression 2: %(expression2)s
""".strip()


# Jinja template for one example's HTML snippet. Context variables:
# prompt_messages (iterable of messages), next_message, correct_answer,
# extracted_answer and score. It calls `message_to_html`, which must be
# reachable from the template (e.g. registered in the environment globals).
# `| safe` stops the already-rendered message HTML from being escaped a
# second time.
HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""


def format_multichoice_question(row):
    """Render ``row`` with the boxed-answer (self-play) A-D prompt template.

    ``row`` must provide the ``question``, ``A``, ``B``, ``C`` and ``D`` keys
    consumed by ``QUERY_TEMPLATE_MULTICHOICE_SELF_PLAY``. If it also carries a
    truthy ``few_shot_examples`` value, that text is placed in front of the
    prompt, separated by one blank line.
    """
    prompt = QUERY_TEMPLATE_MULTICHOICE_SELF_PLAY.format(**row)
    # Missing, None or empty exemplars: return the bare prompt.
    if "few_shot_examples" not in row or not row["few_shot_examples"]:
        return prompt
    return f"{row['few_shot_examples']}\n\n{prompt}"

def format_multichoice_question_few_shot(row):
    """Build a worked few-shot exemplar for the boxed-answer A-D prompt.

    The result is the formatted question, the row's ``explanation``, and a
    final ``Answer: \\boxed{<correct_answer>}`` line, with a blank line
    between each section. ``row`` needs every key used by
    ``QUERY_TEMPLATE_MULTICHOICE_SELF_PLAY`` plus ``explanation`` and
    ``correct_answer``.
    """
    prompt = QUERY_TEMPLATE_MULTICHOICE_SELF_PLAY.format(**row)
    explanation = f"{row['explanation']}"
    answer_line = f"Answer: \\boxed{{{row['correct_answer']}}}"
    return "\n\n".join((prompt, explanation, answer_line)).strip()

def format_multichoice_question_pro(row):
    """Format a multiple choice question with a variable number of options.

    Builds a boxed-answer prompt that lists ``row["options"]`` under the
    letters A, B, C, ... and tells the model which letters are valid.

    Args:
        row: mapping with ``question`` and ``options`` (a sequence of at most
            10 option texts). When ``few_shot_examples`` is present and
            truthy, it is prepended to the prompt followed by a blank line.

    Returns:
        The prompt string.

    Raises:
        ValueError: if more than 10 options are supplied.
    """
    option_letters = "ABCDEFGHIJ"
    options = row["options"]
    if len(options) > len(option_letters):
        raise ValueError(
            f"At most {len(option_letters)} options are supported, got {len(options)}"
        )
    letters = option_letters[: len(options)]
    formatted_options = "\n".join(
        f"{letter}) {option}" for letter, option in zip(letters, options)
    )

    # Same layout (plain blank lines between sections) as the exemplars built
    # by format_multichoice_question_pro_with_answers, so the few-shot prefix
    # and the real question are formatted identically.
    template = f"""
Please reason step by step, and put your final answer within \\boxed{{}}. Your final answer should be of the following format: \\boxed{{LETTER}} where LETTER is one of {letters}.

{row["question"]}

{formatted_options}
""".strip()

    # Require a non-empty value, matching format_multichoice_question, so a
    # None/"" entry does not inject "None" or a stray blank line into the prompt.
    if "few_shot_examples" in row and row["few_shot_examples"]:
        template = f"{row['few_shot_examples']}\n\n{template}"

    return template

def format_multichoice_question_pro_with_answers(row):
    """Render a worked exemplar for the variable-option boxed-answer prompt.

    Sections, separated by blank lines: the instruction header (listing the
    valid letters), ``row["question"]``, the lettered options, and finally
    ``row["cot_content"]`` followed directly by an
    ``Answer: \\boxed{<answer>}`` line. More than ten options raises
    IndexError.
    """
    options = row["options"]
    n_options = len(options)
    answer = row["answer"]
    cot_content = row["cot_content"]

    letters = "ABCDEFGHIJ"
    # Index into `letters` (rather than zip) so an 11th option raises IndexError.
    option_block = "\n".join([f"{letters[i]}) {options[i]}" for i in range(n_options)])

    header = (
        "Please reason step by step, and put your final answer within \\boxed{}. "
        "Your final answer should be of the following format: \\boxed{LETTER} "
        f"where LETTER is one of {letters[:n_options]}."
    )
    sections = [
        header,
        f"{row['question']}",
        option_block,
        f"{cot_content}\nAnswer: \\boxed{{{answer}}}",
    ]
    return "\n\n".join(sections).strip()

def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    """Ask ``sampler`` whether two answer expressions are equivalent.

    Fills ``EQUALITY_TEMPLATE`` with both expressions, sends it as a single
    user message, and returns True only when the reply, lower-cased and
    stripped of surrounding whitespace, is exactly ``"yes"``.
    """
    grading_prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
    reply = sampler([{"content": grading_prompt, "role": "user"}])
    verdict = reply.lower().strip()
    return verdict == "yes"


def _compute_stat(values: list, stat: str):
    if stat == "mean":
        return np.mean(values)
    elif stat == "std":
        return np.std(values)
    elif stat == "min":
        return np.min(values)
    elif stat == "max":
        return np.max(values)
    else:
        raise ValueError(f"Unknown {stat =}")


def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str, ...] = ("mean", "std"),
    name2stats: dict[str, tuple[str, ...]] | None = None,
) -> EvalResult:
    """
    Aggregate results from multiple evaluations into a single EvalResult.

    Every entry of each result's ``metrics`` and every non-None ``score`` is
    collected across examples, then reduced with the statistics listed in
    ``name2stats[name]`` (falling back to ``default_stats``). A ``"mean"``
    statistic is stored under the bare metric name; any other statistic
    under ``"<name>:<stat>"``.

    The value stored under the bare key ``"score"`` (its mean, when "mean"
    is requested) becomes ``EvalResult.score``; it is None when no example
    had a score. All other aggregates go to ``EvalResult.metrics``.
    Per-example ``html`` and ``convo`` are collected in input order.

    Raises:
        ValueError: if a requested statistic is not mean/std/min/max.
    """
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values["score"].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            # The mean keeps the plain metric name; other stats get a suffix.
            key = name if stat == "mean" else f"{name}:{stat}"
            final_metrics[key] = _compute_stat(values, stat)
    return EvalResult(
        score=final_metrics.pop("score", None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
    )


def map_with_progress(f: callable, xs: list[Any], num_threads: int = 100):
    """
    Apply f to each element of xs, using a ThreadPool, and show progress.

    Results come back as a list in the same order as ``xs`` (both ``map`` and
    ``ThreadPool.imap`` preserve input order).

    Args:
        f: function applied to every element.
        xs: sized sequence of inputs; ``len(xs)`` sets both the progress-bar
            total and the maximum pool size.
        num_threads: upper bound on worker threads. The pool never gets more
            threads than there are items.

    Returns:
        ``[f(x) for x in xs]`` (an empty list for empty ``xs``).

    If the ``debug`` environment variable is set to any non-empty value, ``f``
    runs sequentially in the calling thread instead of in a pool.
    """
    # ThreadPool(0) raises ValueError, so empty input must short-circuit.
    # len() rather than truthiness so array-like inputs also work.
    if len(xs) == 0:
        return []
    if os.getenv("debug"):
        return list(map(f, tqdm(xs, total=len(xs))))
    with ThreadPool(min(num_threads, len(xs))) as pool:
        return list(tqdm(pool.imap(f, xs), total=len(xs)))


# Shared Jinja environment for the HTML templates in this module.
# StrictUndefined makes a missing template variable raise instead of silently
# rendering as an empty string. select_autoescape(["html", "xml"]) enables
# autoescaping for templates named *.html / *.xml and, per Jinja's default
# (default_for_string=True), for templates created with from_string, which is
# how the templates here are loaded.
jinja_env = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    undefined=jinja2.StrictUndefined,
    autoescape=jinja2.select_autoescape(["html", "xml"]),
)
# Template for a single chat message, rendered by message_to_html with the
# variables role, content and variant. The variant label is shown only when
# truthy; <pre> keeps the content's whitespace and line breaks intact.
_message_template = """
<div class="message {{ role }}">
    <div class="role">
    {{ role }}
    {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
    </div>
    <div class="content">
    <pre>{{ content }}</pre>
    </div>
</div>
"""


def message_to_html(message: Message) -> str:
    """
    Render one chat message as an HTML ``<div class="message ...">`` snippet.

    Reads ``message["role"]`` and ``message["content"]`` (both required) and
    the optional ``"variant"`` key, then fills ``_message_template`` with them.
    """
    template = jinja_env.from_string(_message_template)
    role = message["role"]
    content = message["content"]
    variant = message.get("variant")
    return template.render(role=role, content=content, variant=variant)


# Register message_to_html as a template global so any template rendered from
# jinja_env (such as HTML_JINJA) can call it directly.
jinja_env.globals["message_to_html"] = message_to_html


# Standalone HTML report. Context variables: score, metrics (mapping of name ->
# value) and htmls (list of pre-rendered example snippets, inserted with
# `| safe` so they are not escaped again). The metrics table is emitted only
# when `metrics` is truthy.
# NOTE(review): if score is None while metrics is non-empty, Jinja's `float`
# filter falls back to 0.0, so the table would show 0.0 — confirm intended.
_report_template = """<!DOCTYPE html>
<html>
    <head>
        <style>
            .message {
                padding: 8px 16px;
                margin-bottom: 8px;
                border-radius: 4px;
            }
            .message.user {
                background-color: #B2DFDB;
                color: #00695C;
            }
            .message.assistant {
                background-color: #B39DDB;
                color: #4527A0;
            }
            .message.system {
                background-color: #EEEEEE;
                color: #212121;
            }
            .role {
                font-weight: bold;
                margin-bottom: 4px;
            }
            .variant {
                color: #795548;
            }
            table, th, td {
                border: 1px solid black;
            }
            pre {
                white-space: pre-wrap;
            }
        </style>
    </head>
    <body>
    {% if metrics %}
    <h1>Metrics</h1>
    <table>
    <tr>
        <th>Metric</th>
        <th>Value</th>
    </tr>
    <tr>
        <td><b>Score</b></td>
        <td>{{ score | float | round(3) }}</td>
    </tr>
    {% for name, value in metrics.items() %}
    <tr>
        <td>{{ name }}</td>
        <td>{{ value }}</td>
    </tr>
    {% endfor %}
    </table>
    {% endif %}
    <h1>Examples</h1>
    {% for html in htmls %}
    {{ html | safe }}
    <hr>
    {% endfor %}
    </body>
</html>
"""


def make_report(eval_result: EvalResult) -> str:
    """
    Render a self-contained HTML report for ``eval_result``.

    Passes the result's ``score``, ``metrics`` and per-example ``htmls`` to
    ``_report_template`` and returns the resulting document as a string.
    """
    context = {
        "score": eval_result.score,
        "metrics": eval_result.metrics,
        "htmls": eval_result.htmls,
    }
    report = jinja_env.from_string(_report_template)
    return report.render(**context)


def make_report_from_example_htmls(htmls: list[str]):
    """
    Render the standalone report template from bare example snippets.

    No score or metrics are supplied; an empty ``metrics`` mapping is falsy,
    so the template skips the metrics table and only the "Examples" section
    is populated.
    """
    report = jinja_env.from_string(_report_template)
    return report.render(htmls=htmls, metrics={}, score=None)


def normalize_response(response: str) -> str:
    """
    Strip markdown bold and common LaTeX wrappers that can defeat answer regexes.

    The tokens below are deleted one after another, in order. Order matters:
    compound wrappers such as ``"$\\boxed{"`` are removed before the single
    characters they contain (``"$"``, ``"{"``) are stripped on their own.
    """
    for token in (
        "**",
        "$\\boxed{",
        "}$",
        "\\$",
        "$\\text{",
        "$",
        "\\mathrm{",
        "\\{",
        "\\text",
        "\\(",
        "\\mathbf{",
        "{",
        "\\boxed",
    ):
        response = response.replace(token, "")
    return response


def normalize_extracted_answer(extracted_answer: str) -> str:
    """Map non-Latin A-D option letters to Latin and trim surrounding whitespace.

    Each native letter is replaced by a space plus its Latin counterpart, and
    the result is then stripped.
    """
    replacements = (
        # Arabic letters used for options A-D
        ("أ", " A"),
        ("ب", " B"),
        ("ج", " C"),
        ("د", " D"),
        # Bengali letters used for options A-D
        ("অ", " A"),
        ("ব", " B"),
        ("ড", " C"),
        ("ঢ", " D"),
        # Fullwidth Latin letters sometimes used in Japanese
        ("Ａ", " A"),
        ("Ｂ", " B"),
        ("Ｃ", " C"),
        ("Ｄ", " D"),
    )
    normalized = extracted_answer
    for native, latin in replacements:
        normalized = normalized.replace(native, latin)
    return normalized.strip()


def url_to_fileobj(url: str, binary=False, timeout: float | None = 60) -> Any:
    """Download ``url`` and return its body as an in-memory file object.

    Args:
        url: address to fetch with an HTTP GET.
        binary: if True, return an ``io.BytesIO`` over the raw bytes;
            otherwise return an ``io.StringIO`` over ``response.text``
            (decoded by requests).
        timeout: seconds passed to ``requests.get`` as its connect/read
            timeout, so a stalled server cannot hang the caller forever.
            Pass None to wait indefinitely.

    Raises:
        requests.HTTPError: for 4xx/5xx responses (via ``raise_for_status``).
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    if binary:
        return io.BytesIO(response.content)
    return io.StringIO(response.text)
