import json
import typing as tp


def mbpp_single_testcase(
  *,
  test_list: tp.Sequence[str],
  prompt: str | None = None,
  text: str | None = None,
):
  """Adapted from https://github.com/bigcode-project/bigcode-evaluation-harness/blob/f09c68161480839eb10fc238dcfe4ce3d056fc06/lm_eval/tasks/mbpp.py#L51"""
  description = text or prompt
  assert description is not None, "prompt or code must be provided"
  return f'"""\n{description}\n{test_list[0]}\n"""\n'


def mbpp_prompt_all_testcases(
  *,
  test_list: tp.Sequence[str],
  prompt: str | None = None,
  text: str | None = None,
):
  """Adapted from https://github.com/bigcode-project/bigcode-evaluation-harness/blob/f09c68161480839eb10fc238dcfe4ce3d056fc06/lm_eval/tasks/mbpp.py#L51"""
  description = text or prompt
  assert description is not None, "prompt or code must be provided"
  test_cases = "\n".join(test_list)
  return f'"""\n{description}\n{test_cases}\n"""\n'


# ---------------------------------------------------------------------------- #
#                                 APPS Prompts                                 #
# ---------------------------------------------------------------------------- #
# Prompt templates for APPS, keyed by template name. Each template is filled via
# ``str.format(question=..., call_format=...)`` by the apps_* builders below;
# ``call_format`` is the instruction string returned by create_apps_question.
APPS_PROMPTS = dict(
  # NOTE(review): create_apps_question yields call_format values beginning with
  # "Use ...", so this renders as "...needs to Use Standard Input format." —
  # the capitalised "Use" reads oddly. Confirm the intended wording before
  # changing it, since any edit alters the prompt the model sees.
  code_llama="""[INST] Write a python code to solve the following coding problem that obeys the constraints and
passes the example test cases. The output code needs to {call_format}. Please wrap your code answer using ```:
{question}
[/INST]""",
  # Plain completion-style layout: question, call-format line, then "Answer:".
  default="\nQUESTION:\n{question}\n{call_format}\nAnswer:\n",
  phi2_instruct="Instruction: Write a python code to solve the following coding problem that obeys the constraints and passes the example test cases. {call_format}.\nQuestion:\n{question}\n\nPython Solution:\n",
)


def create_apps_question(question: str, input_output: str):
  """Choose the APPS call-format instruction for a problem.

  Args:
    question: Problem statement; returned unchanged.
    input_output: JSON-encoded test metadata. When it decodes to an object with
      a truthy ``"fn_name"``, the problem is call-based; otherwise (including
      empty, malformed, ``None``, or non-object JSON) it is treated as a
      standard-input problem.

  Returns:
    ``(question, call_format)`` where ``call_format`` is either
    ``"Use Call-Based format"`` or ``"Use Standard Input format"``.
  """
  try:
    parsed = json.loads(input_output)
  except (TypeError, ValueError):
    # ValueError covers JSONDecodeError (empty/malformed text); TypeError covers
    # missing metadata such as ``None``. Both fall back to standard input.
    parsed = None

  # Only a JSON object can carry "fn_name"; lists/scalars mean no function name.
  fn_name = parsed.get("fn_name") if isinstance(parsed, dict) else None

  call_format = "Use Call-Based format" if fn_name else "Use Standard Input format"
  return question, call_format


def apps_instruct(*, question: str, input_output: str, starter_code=None):
  """Render an APPS problem with the ``default`` template.

  Args:
    question: Problem statement.
    input_output: JSON test metadata, passed to create_apps_question to pick
      the call-format instruction.
    starter_code: Optional code stub; when truthy it is appended on a new line
      after the rendered template.
  """
  q, fmt = create_apps_question(question, input_output)
  rendered = APPS_PROMPTS["default"].format(question=q, call_format=fmt)
  return f"{rendered}\n{starter_code}" if starter_code else rendered


def apps_codellama_instruct(*, question: str, input_output: str, starter_code=None):
  """Render an APPS problem with the ``code_llama`` ``[INST]`` template.

  Args:
    question: Problem statement.
    input_output: JSON test metadata, passed to create_apps_question to pick
      the call-format instruction.
    starter_code: Optional code stub; when truthy it is appended on a new line
      after the closing ``[/INST]``.
  """
  problem_text, instruction = create_apps_question(question, input_output)
  head = APPS_PROMPTS["code_llama"].format(question=problem_text, call_format=instruction)
  if not starter_code:
    return head
  return f"{head}\n{starter_code}"


def apps_phi2_instruct(*, question: str, input_output: str, starter_code=None):
  """Render an APPS problem with the ``phi2_instruct`` template.

  Args:
    question: Problem statement.
    input_output: JSON test metadata, passed to create_apps_question to pick
      the call-format instruction.
    starter_code: Optional code stub; when truthy it is appended on a new line
      after the ``Python Solution:`` header.
  """
  stmt, call_fmt = create_apps_question(question, input_output)
  tail = f"\n{starter_code}" if starter_code else ""
  body = APPS_PROMPTS["phi2_instruct"].format(question=stmt, call_format=call_fmt)
  return body + tail


def passthrough_prompt(*, prompt: str):
  """Return ``prompt`` unchanged; for inputs that are already complete prompts."""
  rendered = prompt
  return rendered


# ---------------------------------------------------------------------------- #
#                                    GSM-8k                                    #
# ---------------------------------------------------------------------------- #
def gsm8k_cot_zeroshot(*, question: str):
  """Zero-shot chain-of-thought prompt: ``Q:`` line, then a step-by-step ``A:`` cue."""
  cue = "A: Let's think step by step."
  return "\n".join((f"Q: {question}", cue))


def gsm8k_cot_zeroshot_gemma(*, question: str):
  """Prefix the zero-shot CoT GSM8K prompt with plaintext-formatting instructions."""
  preamble = " ".join((
    "As an expert problem solver solve step by step the following mathematical questions.",
    "Format your answer in plaintext (do not use any latex or markdown).",
    "In your response, include only your reasoning and final answer, and nothing else.",
  ))
  return f"{preamble}\n{gsm8k_cot_zeroshot(question=question)}"


# ---------------------------------------------------------------------------- #
#                                   Trivia QA                                  #
# ---------------------------------------------------------------------------- #


def trivia_qa_zeroshot(*, question: str):
  """Zero-shot TriviaQA prompt: the question on a ``Q:`` line and an open ``A:``."""
  return "\n".join((f"Q: {question}", "A:"))


# In-context examples for the TriviaQA 5-shot prompts, sampled from the TriviaQA
# training split ("rc.nocontext" config). Only "question" and "answer" are
# rendered into the prompt; "idx" (presumably the row index in that split —
# confirm), "hf_split" and "hf_name" record provenance. Kept at module level,
# like the MATH few-shot lists, so the data is built once instead of per call.
_TRIVIA_QA_FEWSHOT_EXAMPLES = (
  {
    "idx": 136062,
    "question": "Which former major league baseball pitcher, known as The Big Unit, now pitches Geico?",
    "answer": "Randy Johnson",
    "hf_split": "train",
    "hf_name": "rc.nocontext",
  },
  {
    "idx": 84467,
    "question": "The Philippines were named after which king of Spain?",
    "answer": "King Philip II",
    "hf_split": "train",
    "hf_name": "rc.nocontext",
  },
  {
    "idx": 126280,
    "question": "US Vice-President Joe Biden represents which state?",
    "answer": "DELAWARE",
    "hf_split": "train",
    "hf_name": "rc.nocontext",
  },
  {
    "idx": 7010,
    "question": "Which, now defunct, political party was founded by Declan Ganley in April 2009?",
    "answer": "Libertas Ireland",
    "hf_split": "train",
    "hf_name": "rc.nocontext",
  },
  {
    "idx": 75734,
    "question": "Sept 30, 1966 saw the public unveiling of which popular model of Boeing aircraft?",
    "answer": "747",
    "hf_split": "train",
    "hf_name": "rc.nocontext",
  },
)


def trivia_qa_5_shot(*, question: str):
  """Build a 5-shot TriviaQA prompt ending with an open ``Q:``/``A:`` pair.

  Each example is rendered as ``Q: <question>`` and ``A: <answer>`` on
  consecutive lines. The query is appended via ``trivia_qa_zeroshot`` so the
  model continues after the final ``A:``. All blocks are joined with a single
  newline.
  """
  blocks = [f"Q: {ex['question']}\nA: {ex['answer']}" for ex in _TRIVIA_QA_FEWSHOT_EXAMPLES]
  blocks.append(trivia_qa_zeroshot(question=question))
  return "\n".join(blocks)


def trivia_qa_5_shot_w_instruction(*, question: str):
  """Prefix the 5-shot TriviaQA prompt with a one-line ``Answer these questions:`` header."""
  header = "Answer these questions:"
  return "\n".join((header, trivia_qa_5_shot(question=question)))


def math_zeroshot(*, problem: str):
  """Plain zero-shot MATH prompt: ``Problem: <problem>`` then an open ``Answer:``."""
  return "\n".join((f"Problem: {problem}", "Answer:"))


# Four worked MATH problems used by math_4shot as in-context examples.
# Each entry has a "problem" (LaTeX-bearing text) and a full worked "solution"
# whose final result is wrapped in \boxed{...}.
math_fewshot_examples = [
  {
    "problem": "Grandma gave Bryce and Carter some raisins. Bryce received 6 more raisins than Carter, and Carter received half the number of raisins Bryce received. How many raisins did Bryce receive?",
    "solution": "Let the number of raisins Bryce received be $x$. Since Bryce received 6 more raisins than Carter, Carter received $x-6$ raisins. Since Carter received half the number of raisins Bryce did, Carter also received $x/2$ raisins. We have two ways of expressing the number of raisins Carter received, so we have the equation $x-6=x/2$, or $x=12$. Thus, Bryce received $\\boxed{12}$ raisins.",
  },
  {
    "problem": "Evaluate the expression $\\left\\lceil{\\frac54}\\right\\rceil+\\left\\lfloor{-\\frac54}\\right\\rfloor$.",
    "solution": "$1<\\frac54<2$, so the smallest integer greater than or equal to $\\frac54$ is $2$. Similarly, $-2<-\\frac54<-1$, so the largest integer less than or equal to $-\\frac54$ is $-2$. The original expression, $\\left\\lceil{\\frac54}\\right\\rceil+\\left\\lfloor{-\\frac54}\\right\\rfloor$, is equal to the sum of the two, which is just $2+(-2)=\\boxed{0}$.",
  },
  {
    "problem": "Each of the numbers $a_1,$ $a_2,$ $\\dots,$ $a_{95}$ is $\\pm 1.$  Find the smallest possible positive value of\n\\[\\sum_{1 \\le i < j \\le 95} a_i a_j.\\]",
    "solution": "Let $m$ and $n$ denote the number of 1's and $-1$'s among the $a_i,$ respectively.  Then $m + n = 95$ and\n\\[a_1^2 + a_2^2 + \\dots + a_{95}^2 = 95.\\]Let\n\\[S = \\sum_{1 \\le i < j \\le 95} a_i a_j.\\]Then\n\\[2S + 95 = (a_1 + a_2 + \\dots + a_{95})^2 = (m - n)^2.\\]Note that $m - n = m + n - 2n = 95 - 2n$ is odd, so $(m - n)^2$ is an odd perfect square.  To minimize $S,$ while still keeping it positive, we take $(m - n)^2$ as the smallest odd perfect square greater than 95, which is 121.  Then $S = \\frac{121 - 95}{2} = 13.$\n\nEquality occurs when $m = 53$ and $n = 42,$ so the smallest possible positive value of $S$ is $\\boxed{13}.$",
  },
  {
    "problem": "Jo adds up all the positive integers from 1 to 50. Kate does a similar thing with the first 50 positive integers; however, she first rounds every integer to its nearest multiple of 10 (rounding 5s up) and then adds the 50 values. What is the positive difference between Jo's sum and Kate's sum?",
    "solution": "Consider the numbers $1, 2, 3,\\dots, 10$. Jo would add these integers up as is, while Kate would round the first four down to 0, decreasing her sum by $1+2+3+4=10$, and would round the last six up to 10, increasing her sum by $5+4+3+2+1+0=15$. Thus, her sum is $-10+15=5$ more than Jo's sum for the numbers $1, 2, 3,\\dots, 10$. This same logic applies to the numbers $11, 12, 13,\\dots, 20$ also, and in general it applies to every ten numbers greater than 20. Since there are five sets of ten numbers from 1 to 50, Kate's sum is $5 \\cdot 5 = \\boxed{25}$ more than Jo's sum.",
  },
]


def math_4shot(*, problem: str):
  """Build a 4-shot MATH prompt in the plain ``Problem:`` / ``Answer:`` layout.

  Each entry of ``math_fewshot_examples`` is rendered as ``Problem: <problem>``
  followed by ``Answer: <solution>`` on the next line. The query, built by
  ``math_zeroshot``, is appended last, and all blocks are joined with a single
  newline.
  """
  shots = [
    f"Problem: {ex['problem']}\nAnswer: {ex['solution']}"
    for ex in math_fewshot_examples
  ]
  return "\n".join([*shots, math_zeroshot(problem=problem)])


# Four worked MATH examples for the Minerva-style few-shot prompt. Every
# solution ends with "Final Answer: The final answer is $...$. I hope it is
# correct." The "few_shot" key is not read by any builder visible in this file.
# NOTE(review): the first problem ends with a stray "}" after "$.", and the
# "\\\n" inside the align* blocks yields a single backslash before the newline
# (a LaTeX row break needs two). These look like typos, possibly copied verbatim
# from an upstream Minerva prompt (e.g. lm-evaluation-harness). Confirm before
# fixing, since any edit changes the prompt the model sees.
MINERVA_FEWSHOT_EXAMPLES = [
  {
    "problem": "Find the domain of the expression  $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}",
    "solution": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{[2,5)}$.\nFinal Answer: The final answer is $[2,5)$. I hope it is correct.",
    "few_shot": "1",
  },
  {
    "problem": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$",
    "solution": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$\nFinal Answer: The final answer is $24$. I hope it is correct.",
    "few_shot": "1",
  },
  {
    "problem": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?",
    "solution": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.  If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight.  Equating this to 480 pounds, we can solve for $n$:\n\\begin{align*}\n30n&=480\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}\nFinal Answer: The final answer is $16$. I hope it is correct.",
    "few_shot": "1",
  },
  {
    "problem": "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero,\nfind $\\frac{a}{b},$ assuming $b$ is nonzero.",
    "solution": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$\nFinal Answer: The final answer is $-\\frac{2}{3}$. I hope it is correct.",
    "few_shot": "1",
  },
]


def math_minerva_zeroshot(*, problem: str):
  """Build a zero-shot Minerva-style MATH prompt.

  Layout: a ``Problem:`` header line, the problem text, a blank line, then an
  open ``Solution:`` for the model to continue. This is the same layout used
  for the worked examples in ``math_minerva_4shot``.
  """
  # "Problem:" (with colon) matches the few-shot example layout; the previous
  # "Problem" header without a colon was inconsistent with it.
  return f"Problem:\n{problem}\n\nSolution:"


def math_minerva_4shot(*, problem: str):
  """Build the 4-shot Minerva-style MATH prompt.

  Each entry of ``MINERVA_FEWSHOT_EXAMPLES`` is rendered as a ``Problem:``
  header line, the problem text, a blank line, then ``Solution:`` followed by
  a space and the worked solution. The target problem is rendered in the same
  layout with the solution left open, so the model continues after
  ``Solution:``. Blocks are separated by a blank line.
  """

  def _header(text: str) -> str:
    # Single definition of the layout shared by the shots and the query.
    return f"Problem:\n{text}\n\nSolution:"

  blocks = [f"{_header(ex['problem'])} {ex['solution']}" for ex in MINERVA_FEWSHOT_EXAMPLES]
  # The query must use the same layout as the shots. Rendering it as
  # "Problem: ...\nAnswer:" would break the few-shot pattern and invite a
  # differently formatted completion.
  blocks.append(_header(problem))
  return "\n\n".join(blocks)


def gemma_math_4shot(*, problem: str):
  """Prefix the 4-shot Minerva MATH prompt with an answer-format instruction.

  The instruction asks for the same closing "Final Answer: ..." sentence that
  ends every solution in ``MINERVA_FEWSHOT_EXAMPLES``.
  """
  few_shot = math_minerva_4shot(problem=problem)
  instruction = (
    "Answer the following questions. Your final answer should always follow the same format: "
    "'Final Answer: The final answer is [answer]. I hope it is correct.'"
  )
  return "\n".join((instruction, few_shot))


# ---------------------------------------------------------------------------- #
#                                   registry                                   #
# ---------------------------------------------------------------------------- #
# Registry mapping a prompt name to its builder. Every builder takes
# keyword-only arguments (e.g. ``question=``, ``problem=``, ``prompt=``,
# ``test_list=``) and returns the rendered prompt string.
# NOTE(review): math_zeroshot and math_4shot are not registered; the "math-*"
# names point at the Minerva-format builders. Names mix "_" and "-" separators.
PROMPTS: dict[str, tp.Callable] = {
  # MBPP
  "mbpp_single_testcase": mbpp_single_testcase,
  "mbpp_all_testcases": mbpp_prompt_all_testcases,
  # APPS
  "apps_instruct": apps_instruct,
  "apps_codellama_instruct": apps_codellama_instruct,
  "apps_phi2_instruct": apps_phi2_instruct,
  # Returns the input prompt unchanged.
  "passthrough": passthrough_prompt,
  # GSM8K
  "gsm8k_cot_zeroshot": gsm8k_cot_zeroshot,
  # TriviaQA
  "trivia_qa_zeroshot": trivia_qa_zeroshot,
  "trivia_qa_5_shot": trivia_qa_5_shot,
  "trivia_qa_5_shot_with_description": trivia_qa_5_shot_w_instruction,
  # MATH (Minerva format)
  "math-zeroshot": math_minerva_zeroshot,
  "math-4-shot": math_minerva_4shot,
  "gemma-math-4-shot": gemma_math_4shot,
  "gsm8k_cot_zeroshot_gemma": gsm8k_cot_zeroshot_gemma,
}
