import concurrent.futures
import json
import logging
import re
import time

import boto3
from botocore.exceptions import ClientError

# Number of retries ``call_llm_func`` allows after a failed Bedrock call
# (ClientError) before giving up.
MAX_ERROR = 5
logger = logging.getLogger(__name__)


def call_llm_func(
    input_str,
    model_id="",
    max_tokens=512,
    temperature=1.0,
    top_p=0.999,
    top_k=40,
    region_name="us-east-1",
    max_retries=None,
    retry_delay=20,
):
    """Send one prompt to an Amazon Bedrock model and return the generated text.

    :param input_str: user query; wrapped into a prompt by ``_construct_input``.
    :param model_id: Bedrock model id. The provider prefix before the first "."
        selects the request/response format; only "anthropic" (Claude) and
        "meta" (Llama) are supported.
    :param max_tokens: maximum number of tokens to generate.
    :param temperature: sampling temperature.
    :param top_p: nucleus-sampling probability mass.
    :param top_k: top-k cutoff, forwarded to ``_construct_input``.
    :param region_name: AWS region of the Bedrock runtime endpoint.
    :param max_retries: retries allowed after a ``ClientError``; ``None`` uses
        the module-level ``MAX_ERROR`` (read at call time).
    :param retry_delay: seconds to sleep before each retry.
    :return: text generated by the LLM ("" if a Claude response carries no
        content blocks).
    :raises NotImplementedError: if the model provider is not supported.
    :raises RecursionError: if every attempt fails with a ``ClientError``
        (type kept for backward compatibility; the last ``ClientError`` is
        chained as its cause).
    """
    if max_retries is None:
        max_retries = MAX_ERROR

    # A fresh Session per call: this function runs concurrently from
    # parallel_call_llm threads, and boto3 sessions (including the default one
    # behind boto3.client) are not thread-safe to create clients from.
    session = boto3.Session()
    client = session.client(service_name="bedrock-runtime", region_name=region_name)
    # Raises NotImplementedError for unsupported providers before any API call.
    body_json = _construct_input(client, input_str, model_id, max_tokens, temperature, top_p, top_k)
    model_type = model_id.split(".")[0]

    failures = 0
    while True:
        try:
            response = client.invoke_model(modelId=model_id, body=body_json)
            break
        except ClientError as err:
            if failures >= max_retries:
                raise RecursionError("Exceed maximum API calls!") from err
            failures += 1
            error_info = err.response.get("Error", {})
            logger.error(
                "Couldn't invoke %s. Here's why: %s: %s",
                model_id,
                error_info.get("Code"),
                error_info.get("Message"),
            )
            logger.info("Sleep %s secs and retry...", retry_delay)
            time.sleep(retry_delay)

    result = json.loads(response["body"].read())
    if model_type == "anthropic":
        # Claude returns a list of content blocks; keep the first block's text.
        content = result.get("content") or []
        return content[0]["text"] if content else ""
    if model_type == "meta":
        return result["generation"]
    raise NotImplementedError("Please check your model id!")


def _construct_input(
    client,
    input_str,
    model_id,
    max_tokens,
    temperature,
    top_p,
    top_k,
):
    """Construct input to LLM
    :param client: boto3.client
    :return: generated output with the format of json body.
    """
    prompt = f"You are a helpful assistant. <s>[INST] {input_str} [/INST]"
    model_type = model_id.split(".")[0]
    if model_type == "anthropic":
        body_json = json.dumps(
            {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": max_tokens,
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": prompt}],
                    }
                ],
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
            }
        )
    elif model_type == "meta":
        body_json = json.dumps(
            {
                "prompt": prompt,
                "max_gen_len": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
            }
        )
    else:
        raise NotImplementedError("Please check your model id!")
    return body_json


def parallel_call_llm(
    prompts_grouped_by_batch_size,
    model_id,
    max_tokens=512,
    temperature=1.0,
    top_p=0.999,
    top_k=40,
    max_workers=None,
):
    """Call the LLM concurrently, one thread pool per batch.

    Batches are processed one after another; the prompts inside a batch are
    sent concurrently through ``call_llm_func``. The returned list preserves
    prompt order across all batches.

    :param prompts_grouped_by_batch_size: list of batches, each a list of
        prompt strings.
    :param model_id: Bedrock model id forwarded to ``call_llm_func``.
    :param max_tokens: forwarded to ``call_llm_func``.
    :param temperature: forwarded to ``call_llm_func``.
    :param top_p: forwarded to ``call_llm_func``.
    :param top_k: forwarded to ``call_llm_func``.
    :param max_workers: thread-pool size per batch; ``None`` keeps the
        ``ThreadPoolExecutor`` default.
    :return: flat list of raw outputs, one per prompt; a prompt whose call
        raised yields ``""`` (the failure is logged).
    """
    raw_outputs = []
    for batch_no, prompts_single_batch in enumerate(prompts_grouped_by_batch_size, start=1):
        batch_time_start = time.perf_counter()
        print(f"Batch {batch_no} is under inference...")
        batch_outputs = [""] * len(prompts_single_batch)
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    call_llm_func,
                    input_str=_input,
                    model_id=model_id,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                )
                for _input in prompts_single_batch
            ]
            for p_id, future in enumerate(futures):
                try:
                    batch_outputs[p_id] = future.result()
                except Exception:
                    # Best-effort: keep the "" placeholder so the output list
                    # stays aligned one-to-one with the prompts, but record why.
                    logger.exception(
                        "LLM call failed for prompt %d of batch %d; using empty output.",
                        p_id,
                        batch_no,
                    )
        print(f"Batch time duration: {time.perf_counter()-batch_time_start}")
        raw_outputs.extend(batch_outputs)
    return raw_outputs


def group_batch(
    dataset,
    data_index,
    meta_prompt,
    batch_size,
    is_boolean,
):
    """Build one prompt per selected sample and split the prompts into batches.

    Each prompt combines the sample's "input" text, ``meta_prompt`` and an
    instruction to put the answer between <Ans> and </Ans>. For boolean tasks
    the instruction precedes the meta prompt; otherwise the meta prompt comes
    first and the instruction asks for a reasoning path before the answer.

    :param dataset: indexable collection of records with "input" and "target" keys.
    :param data_index: indices into ``dataset`` selecting the samples, in order.
    :param meta_prompt: prompt text used in the current iteration.
    :param batch_size: number of prompts per batch (capped at the sample count).
    :param is_boolean: selects the boolean-task template when truthy.
    :return: ``(batches, true_answers)`` -- a list of prompt lists, and the flat
        list of targets in the same order as the prompts.
    """
    # NOTE(review): "necassarily" is a typo, kept verbatim because the prompt
    # text is model input and changing it changes results.
    boolean_instruction = " Put your answer choosing from the previously stated options within tag <Ans> and </Ans>. The answer should be necassarily between the tags.\n"
    reasoning_instruction = " First provide the reasoning path, then put your answer choosing from the previously stated options within tag <Ans> and </Ans>. The answer should be necassarily between the tags.\n"

    batch_size = min(batch_size, len(data_index))
    batches = []
    pending = []
    targets = []
    for count, sample_id in enumerate(data_index, start=1):
        record = dataset[sample_id]
        if is_boolean:
            prompt = record["input"] + boolean_instruction + meta_prompt
        else:
            prompt = record["input"] + "\n" + meta_prompt + reasoning_instruction
        pending.append(prompt)
        targets.append(record["target"])
        # Close the current batch every `batch_size` prompts.
        if count % batch_size == 0:
            batches.append(pending)
            pending = []
    # Keep a trailing, partially filled batch.
    if pending:
        batches.append(pending)
    return batches, targets


def extract_output(raw_output, tag):
    """Extract the text enclosed by the first non-empty ``<tag>...</tag>`` pair.

    :param raw_output: raw text generated by the LLM.
    :param tag: tag name wrapping the answer (e.g. "Ans"); matched literally.
    :return: content of the first non-empty tag pair, or "" if there is none;
        lower-cased when ``tag == "Ans"``.
    """
    # re.escape: the tag is a literal name, not a regex fragment.
    escaped = re.escape(tag)
    # Lazy ".*?" stops at the nearest closing tag, so an empty "<tag></tag>"
    # pair cannot swallow text up to a later closing tag; empty matches are
    # skipped below.
    matches = re.findall(rf"<{escaped}>(.*?)</{escaped}>", raw_output, flags=re.DOTALL)
    output = next((match for match in matches if match), "")
    if tag == "Ans":
        output = output.lower()
    return output
