"""Contains classes for querying large language models."""
import os
import time
from tqdm import tqdm
from abc import ABC, abstractmethod
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import openai
import torch
import asyncio
from typing import Any
import boto3
import json
import time


# Price in dollars per 1,000 tokens for the base GPT-3 engines, keyed by the
# engine name that gpt_get_estimated_cost extracts from the model string.
gpt_costs_per_thousand = dict(
    davinci=0.0200,
    curie=0.0020,
    babbage=0.0005,
    ada=0.0004,
)


async def dispatch_openai_requests(
    messages_list: list[list[dict[str, Any]]],
    model: str,
    temperature: float,
    max_tokens: int,
    frequency_penalty: float,
    presence_penalty: float
) -> list[Any]:
    """Dispatches requests to the OpenAI ChatCompletion API concurrently.

    One request is created per conversation in ``messages_list`` and all of
    them are awaited together.

    Args:
        messages_list: List of conversations; each conversation is the list of
            message dicts passed as ``messages`` to ChatCompletion.
        model: OpenAI model to use.
        temperature: Temperature to use for the model.
        max_tokens: Maximum number of tokens to generate.
        frequency_penalty: Frequency penalty forwarded to the API.
        presence_penalty: Presence penalty forwarded to the API.
    Returns:
        List of raw ChatCompletion responses, in the same order as
        ``messages_list`` (callers index into ``['choices']`` themselves).
    """
    pending = [
        openai.ChatCompletion.acreate(
            model=model,
            messages=conversation,
            temperature=temperature,
            max_tokens=max_tokens,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty)
        for conversation in messages_list
    ]

    # gather preserves the order of its arguments in the returned list.
    return await asyncio.gather(*pending)


def model_from_config(config, disable_tqdm=True):
    """Builds the model wrapper named by ``config["name"]``.

    Parameters:
        config: Mapping whose "name" entry selects the model.
        disable_tqdm: Forwarded to the model wrapper.
    Returns:
        A Claude_Forward instance for any of the supported Bedrock models.
    Raises:
        ValueError: If the name is not one of the supported models.
    """
    bedrock_models = (
        "Claude_2",
        "Claude_3",
        "Claude_3_5",
        "Llama_3_70B",
        "Claude_3_haiku",
        "Claude_3_opus",
    )
    model_type = config["name"]
    if model_type not in bedrock_models:
        raise ValueError(f"Unknown model type: {model_type}")
    return Claude_Forward(config, version=model_type, disable_tqdm=disable_tqdm)


class LLM(ABC):
    """Interface every large language model wrapper must implement."""

    @abstractmethod
    def generate_text(self, prompt):
        """Produce generated text for the given prompt(s).

        Parameters:
            prompt: A single prompt string or a list of prompt strings.
        Returns:
            A list of strings.
        """

    @abstractmethod
    def log_probs(self, text, log_prob_range):
        """Compute log probabilities for the given text(s).

        Parameters:
            text: A single string or a list of strings to score.
            log_prob_range: List of (start, end) tuples giving, for each
                string, the character range whose log probs are wanted.
        Returns:
            A list of log probs.
        """

class Llama_Forward(LLM):
    """Wrapper for a locally stored LLaMA-style causal LM (a vicuna-13b checkpoint).

    Only ``generate_text`` is backed by real model code. ``complete`` and
    ``log_probs`` keep their batching logic but their per-batch helpers are
    not implemented for this model and raise ``NotImplementedError``.
    """

    def __init__(self, config, needs_confirmation=False, disable_tqdm=True):
        """Loads the model and tokenizer.

        Parameters:
            config: Mapping read for 'batch_size' and 'name' by the batching methods.
            needs_confirmation: If True, ``log_probs`` calls ``confirm_cost`` first.
            disable_tqdm: If True, progress bars and status prints are suppressed.
        """
        # NOTE(review): the checkpoint path and GPU index are machine-specific.
        model_dir = '/home/collabml/anaconda3/xl/Neural-bandits-for-prompt-searching/models/vicuna-13b'

        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm
        self.device = torch.device('cuda:1')
        self.model = LlamaForCausalLM.from_pretrained(
            model_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            model_max_length=512,
            padding_side='left',
            use_fast=False,
        )

    def confirm_cost(self, texts, n, max_tokens):
        """No-op: the model is loaded and run locally, so there is no API cost to confirm.

        Defined so that ``log_probs`` does not fail with AttributeError when
        ``needs_confirmation`` is True.
        """
        return None

    def auto_reduce_n(self, fn, prompt, n):
        """Calls ``fn(prompt, n)``, splitting n in two whenever BatchSizeException is raised.

        The two halves are ``n // 2`` and ``n - n // 2`` so that an odd n still
        yields n results in total.

        Raises:
            BatchSizeException: If the call still fails with n <= 1.
        """
        try:
            return fn(prompt, n)
        except BatchSizeException:
            if n <= 1:
                raise
            half = n // 2
            return self.auto_reduce_n(fn, prompt, half) + self.auto_reduce_n(fn, prompt, n - half)

    def generate_text(self, prompts, n):
        """Generates one continuation (up to 32 new tokens) per prompt.

        Parameters:
            prompts: A prompt string or a list of prompt strings.
            n: Accepted for interface compatibility; not used by this model.
        Returns:
            A list with one decoded string per prompt. The decoded sequence is
            whatever ``model.generate`` returns for that prompt.
        """
        if not isinstance(prompts, list):
            prompts = [prompts]
        text = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
            generate_ids = self.model.generate(input_ids, max_new_tokens=32)
            decoded = self.tokenizer.batch_decode(
                generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            text.append(decoded[0])
        return text

    def complete(self, prompt, n):
        """Splits the prompts into batches and runs the per-batch completion helper.

        Raises:
            NotImplementedError: For any non-empty input (helper not implemented).
        """
        if not isinstance(prompt, list):
            prompt = [prompt]
        batch_size = self.config['batch_size']
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of size {batch_size * n}")
        res = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            res += self.__complete(prompt_batch, n)
        return res

    def log_probs(self, text, log_prob_range=None):
        """Splits the texts into batches and runs the per-batch log-prob helper.

        Parameters:
            text: A string or a list of strings.
            log_prob_range: Optional list of (start, end) tuples, one per text.
        Returns:
            A tuple (log_probs, tokens).
        Raises:
            NotImplementedError: For any non-empty input (helper not implemented).
        """
        if not isinstance(text, list):
            text = [text]
        if self.needs_confirmation:
            self.confirm_cost(text, 1, 0)
        batch_size = self.config['batch_size']
        text_batches = [text[i:i + batch_size]
                        for i in range(0, len(text), batch_size)]
        if log_prob_range is None:
            # One placeholder per batch (not per text) so it pairs up with text_batches.
            log_prob_range_batches = [None] * len(text_batches)
        else:
            assert len(log_prob_range) == len(text)
            log_prob_range_batches = [log_prob_range[i:i + batch_size]
                                      for i in range(0, len(log_prob_range), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Getting log probs for {len(text)} strings, "
                f"split into {len(text_batches)} batches of (maximum) size {batch_size}")
        log_probs = []
        tokens = []
        for text_batch, range_batch in tqdm(list(zip(text_batches, log_prob_range_batches)),
                                            disable=self.disable_tqdm):
            log_probs_batch, tokens_batch = self.__log_probs(text_batch, range_batch)
            log_probs += log_probs_batch
            tokens += tokens_batch
        return log_probs, tokens

    def __complete(self, prompt, n):
        """Per-batch completion helper; not implemented for this model."""
        raise NotImplementedError("Llama_Forward does not implement complete().")

    def __log_probs(self, text, log_prob_range=None):
        """Per-batch log-prob helper; not implemented for this model."""
        raise NotImplementedError("Llama_Forward does not implement log_probs().")


class Flan_T5(LLM):
    """Wrapper for the google/flan-t5-xxl sequence-to-sequence model.

    Only ``generate_text`` is backed by real model code. ``complete`` and
    ``log_probs`` keep their batching logic but their per-batch helpers are
    not implemented for this model and raise ``NotImplementedError``.
    """

    def __init__(self, config, needs_confirmation=False, disable_tqdm=True):
        """Loads the model and tokenizer.

        Parameters:
            config: Mapping read for 'batch_size' and 'name' by the batching methods.
            needs_confirmation: If True, ``log_probs`` calls ``confirm_cost`` first.
            disable_tqdm: If True, progress bars and status prints are suppressed.
        """
        # NOTE(review): the GPU index is hard-coded.
        self.device = "cuda:1"
        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            "google/flan-t5-xxl", torch_dtype=torch.float16).to(device=self.device)
        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xxl")

    def confirm_cost(self, texts, n, max_tokens):
        """No-op: the model is run locally, so there is no API cost to confirm.

        Defined so that ``log_probs`` does not fail with AttributeError when
        ``needs_confirmation`` is True.
        """
        return None

    def auto_reduce_n(self, fn, prompt, n):
        """Calls ``fn(prompt, n)``, splitting n in two whenever BatchSizeException is raised.

        The two halves are ``n // 2`` and ``n - n // 2`` so that an odd n still
        yields n results in total.

        Raises:
            BatchSizeException: If the call still fails with n <= 1.
        """
        try:
            return fn(prompt, n)
        except BatchSizeException:
            if n <= 1:
                raise
            half = n // 2
            return self.auto_reduce_n(fn, prompt, half) + self.auto_reduce_n(fn, prompt, n - half)

    def generate_text(self, prompts, n):
        """Generates one output (up to 32 new tokens) per prompt, 10 prompts at a time.

        Parameters:
            prompts: A prompt string or a list of prompt strings.
            n: Accepted for interface compatibility; not used by this model.
        Returns:
            A list with one decoded string per prompt, in input order.
        """
        if not isinstance(prompts, list):
            prompts = [prompts]
        text = []
        batch_size = 10
        # Step through by start index so the final, possibly shorter, batch is
        # processed too (floor-dividing the length would silently drop it).
        for start in range(0, len(prompts), batch_size):
            tmp_prompts = prompts[start:start + batch_size]
            input_ids = self.tokenizer(
                tmp_prompts, padding='longest', return_tensors="pt").input_ids.to(device=self.device)
            outputs = self.model.generate(input_ids, max_new_tokens=32)
            text += self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        return text

    def complete(self, prompt, n):
        """Splits the prompts into batches and runs the per-batch completion helper.

        Raises:
            NotImplementedError: For any non-empty input (helper not implemented).
        """
        if not isinstance(prompt, list):
            prompt = [prompt]
        batch_size = self.config['batch_size']
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of size {batch_size * n}")
        res = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            res += self.__complete(prompt_batch, n)
        return res

    def log_probs(self, text, log_prob_range=None):
        """Splits the texts into batches and runs the per-batch log-prob helper.

        Parameters:
            text: A string or a list of strings.
            log_prob_range: Optional list of (start, end) tuples, one per text.
        Returns:
            A tuple (log_probs, tokens).
        Raises:
            NotImplementedError: For any non-empty input (helper not implemented).
        """
        if not isinstance(text, list):
            text = [text]
        if self.needs_confirmation:
            self.confirm_cost(text, 1, 0)
        batch_size = self.config['batch_size']
        text_batches = [text[i:i + batch_size]
                        for i in range(0, len(text), batch_size)]
        if log_prob_range is None:
            # One placeholder per batch (not per text) so it pairs up with text_batches.
            log_prob_range_batches = [None] * len(text_batches)
        else:
            assert len(log_prob_range) == len(text)
            log_prob_range_batches = [log_prob_range[i:i + batch_size]
                                      for i in range(0, len(log_prob_range), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Getting log probs for {len(text)} strings, "
                f"split into {len(text_batches)} batches of (maximum) size {batch_size}")
        log_probs = []
        tokens = []
        for text_batch, range_batch in tqdm(list(zip(text_batches, log_prob_range_batches)),
                                            disable=self.disable_tqdm):
            log_probs_batch, tokens_batch = self.__log_probs(text_batch, range_batch)
            log_probs += log_probs_batch
            tokens += tokens_batch
        return log_probs, tokens

    def __complete(self, prompt, n):
        """Per-batch completion helper; not implemented for this model."""
        raise NotImplementedError("Flan_T5 does not implement complete().")

    def __log_probs(self, text, log_prob_range=None):
        """Per-batch log-prob helper; not implemented for this model."""
        raise NotImplementedError("Flan_T5 does not implement log_probs().")





class GPT_Forward(LLM):
    """Wrapper for OpenAI GPT models.

    ``generate_text`` uses the ChatCompletion endpoint (asynchronously, one
    request per prompt); ``complete`` and ``log_probs`` use the Completion
    endpoint with the settings in ``config['gpt_config']``.
    """

    def __init__(self, config, needs_confirmation=False, disable_tqdm=True):
        """Stores the configuration; no network calls are made here.

        Parameters:
            config: Mapping with 'name', 'batch_size' and 'gpt_config' entries.
            needs_confirmation: If True, print a cost estimate and ask before calling the API.
            disable_tqdm: If True, progress bars and status prints are suppressed.
        """
        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm

    def confirm_cost(self, texts, n, max_tokens):
        """Prints the estimated cost and asks the user to confirm on the command line.

        The prompt is skipped when the LLM_SKIP_CONFIRM environment variable is set.

        Raises:
            Exception: If the user answers anything other than 'y'.
        """
        total_estimated_cost = 0
        for text in texts:
            total_estimated_cost += gpt_get_estimated_cost(
                self.config, text, max_tokens) * n
        print(f"Estimated cost: ${total_estimated_cost:.2f}")
        if os.getenv("LLM_SKIP_CONFIRM") is None:
            confirm = input("Continue? (y/n) ")
            if confirm != 'y':
                raise Exception("Aborted.")

    def auto_reduce_n(self, fn, prompt, n):
        """Calls ``fn(prompt, n)``, splitting n in two whenever BatchSizeException is raised.

        The two halves are ``n // 2`` and ``n - n // 2`` so that an odd n still
        yields n results in total.

        Raises:
            BatchSizeException: If the call still fails with n <= 1.
        """
        try:
            return fn(prompt, n)
        except BatchSizeException:
            if n <= 1:
                raise
            half = n // 2
            return self.auto_reduce_n(fn, prompt, half) + self.auto_reduce_n(fn, prompt, n - half)

    def generate_text(self, prompt, n):
        """Generates one chat reply per prompt.

        Parameters:
            prompt: A prompt string or a list of prompt strings.
            n: Used for the cost estimate and status print only; the async
                request path returns a single reply per prompt.
        Returns:
            A list of reply strings, one per prompt, in input order.
        """
        if not isinstance(prompt, list):
            prompt = [prompt]
        if self.needs_confirmation:
            self.confirm_cost(
                prompt, n, self.config['gpt_config']['max_tokens'])
        batch_size = self.config['batch_size']
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of size {batch_size * n}")
        text = []

        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            text += self.__async_generate(prompt_batch, n)
        return text

    def complete(self, prompt, n):
        """Generates text from the model and returns the raw choices (with log prob data).

        Parameters:
            prompt: A prompt string or a list of prompt strings.
            n: Number of completions requested per prompt.
        Returns:
            The concatenated ``choices`` lists of all Completion responses.
        """
        if not isinstance(prompt, list):
            prompt = [prompt]
        batch_size = self.config['batch_size']
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of size {batch_size * n}")
        res = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            res += self.__complete(prompt_batch, n)
        return res

    def log_probs(self, text, log_prob_range=None):
        """Returns the log probs of the text.

        Parameters:
            text: A string or a list of strings.
            log_prob_range: Optional list of (start, end) character ranges, one per text.
        Returns:
            A tuple (log_probs, tokens) of per-text lists.
        """
        if not isinstance(text, list):
            text = [text]
        if self.needs_confirmation:
            self.confirm_cost(text, 1, 0)
        batch_size = self.config['batch_size']
        text_batches = [text[i:i + batch_size]
                        for i in range(0, len(text), batch_size)]
        if log_prob_range is None:
            # One placeholder per batch (not per text) so it pairs up with text_batches.
            log_prob_range_batches = [None] * len(text_batches)
        else:
            assert len(log_prob_range) == len(text)
            log_prob_range_batches = [log_prob_range[i:i + batch_size]
                                      for i in range(0, len(log_prob_range), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Getting log probs for {len(text)} strings, "
                f"split into {len(text_batches)} batches of (maximum) size {batch_size}")
        log_probs = []
        tokens = []
        for text_batch, range_batch in tqdm(list(zip(text_batches, log_prob_range_batches)),
                                            disable=self.disable_tqdm):
            log_probs_batch, tokens_batch = self.__log_probs(text_batch, range_batch)
            log_probs += log_probs_batch
            tokens += tokens_batch
        return log_probs, tokens

    def __generate_text(self, prompt, n):
        """Synchronous generation: one ChatCompletion request per prompt, retried until it succeeds.

        Raises:
            BatchSizeException: If the API error contains 'is greater than the maximum'.
        """
        if not isinstance(prompt, list):
            prompt = [prompt]
        answer = []
        for raw_prompt in prompt:
            # If there are any [APE] tokens in the prompt, remove them
            prompt_single = raw_prompt.replace('[APE]', '').strip()
            response = None

            while response is None:
                try:
                    response = openai.ChatCompletion.create(
                        model="gpt-3.5-turbo",
                        messages=[{"role": "user", "content": prompt_single}],
                        temperature=0.0,
                        max_tokens=256,
                        frequency_penalty=0,
                        presence_penalty=0)
                except Exception as e:
                    if 'is greater than the maximum' in str(e):
                        raise BatchSizeException() from e
                    print(e)
                    print('Retrying...')
                    time.sleep(5)

            # Extract the reply only once a response exists, so a failed
            # attempt does not add a placeholder answer for the same prompt.
            try:
                answer.append(response['choices'][0]["message"]["content"])
            except Exception:
                answer.append('do not have reponse from chatgpt')

        return answer

    def __async_generate(self, prompt, n):
        """Sends all prompts of a batch concurrently and retries until every reply is parsed.

        Raises:
            ValueError: If the configured model name contains "text" (not a chat model).
        """
        ml = [[{"role": "user", "content": p.replace('[APE]', '').strip()}] for p in prompt]
        model = self.config['gpt_config']['model']
        if "text" in model:
            raise ValueError(f"Model {model!r} cannot be used with the chat endpoint.")
        print(model)
        answer = None
        while answer is None:
            try:
                predictions = asyncio.run(asyncio.wait_for(dispatch_openai_requests(
                    messages_list=ml,
                    model=model,
                    temperature=0,
                    max_tokens=256,
                    frequency_penalty=0,
                    presence_penalty=0
                    ), timeout=25)
                )
            except asyncio.TimeoutError:
                print("The task exceeded the time limit 25 s.")
                continue
            except Exception as e:
                print(e)
                print("Retrying....")
                time.sleep(20)
                continue

            # Only reached with a fresh set of responses from this attempt.
            try:
                answer = [x['choices'][0]['message']['content'] for x in predictions]
            except Exception:
                print("Please Wait!")

        return answer

    def __complete(self, prompt, n):
        """Requests n completions per prompt from the Completion endpoint and returns the choices."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # Remove any [APE] tokens; build a new list so the caller's list is left untouched.
        prompt = [p.replace('[APE]', '').strip() for p in prompt]
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prompt)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        return response['choices']

    def __log_probs(self, text, log_prob_range=None):
        """Returns the log probs and tokens of each text, optionally restricted to a character range."""
        if not isinstance(text, list):
            text = [text]
        if log_prob_range is not None:
            for i in range(len(text)):
                lower_index, upper_index = log_prob_range[i]
                assert lower_index < upper_index
                assert lower_index >= 0
                assert upper_index - 1 < len(text[i])
        config = self.config['gpt_config'].copy()
        # Echo the prompt with log probs and generate nothing new.
        config['logprobs'] = 1
        config['echo'] = True
        config['max_tokens'] = 0
        # A newline is prepended to every text; its token is dropped again below.
        text = [f'\n{t}' for t in text]
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=text)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        choices = response['choices']
        # [1:] skips the entry for the prepended newline.
        log_probs = [choice['logprobs']['token_logprobs'][1:] for choice in choices]
        tokens = [choice['logprobs']['tokens'][1:] for choice in choices]
        # Subtract 1 from the offsets to account for the newline
        offsets = [[offset - 1 for offset in choice['logprobs']['text_offset'][1:]]
                   for choice in choices]

        if log_prob_range is not None:
            # Map each character range onto token indices and slice accordingly.
            for i in range(len(log_probs)):
                lower_index, upper_index = self.get_token_indices(
                    offsets[i], log_prob_range[i])
                log_probs[i] = log_probs[i][lower_index:upper_index]
                tokens[i] = tokens[i][lower_index:upper_index]

        return log_probs, tokens

    def get_token_indices(self, offsets, log_prob_range):
        """Maps a (start, end) character range onto a (lower, upper) token index range.

        Parameters:
            offsets: Character offset of each token, in increasing order.
            log_prob_range: (start, end) character positions.
        Returns:
            (lower_index, upper_index): lower_index is the last token starting
            at or before ``start`` (0 if none); upper_index is the first token
            starting at or after ``end`` (``len(offsets)`` if none).
        """
        lower_index = 0
        for i, offset in enumerate(offsets):
            if offset <= log_prob_range[0]:
                lower_index = i
            else:
                break

        upper_index = len(offsets)
        for i, offset in enumerate(offsets):
            if offset >= log_prob_range[1]:
                upper_index = i
                break

        return lower_index, upper_index


class GPT_Insert(LLM):
    """Wrapper for OpenAI insert-mode completion.

    Each prompt is split at the [APE] token into a prefix and a suffix, and
    the model is asked to generate the text in between.
    """

    def __init__(self, config, needs_confirmation=False, disable_tqdm=True):
        """Stores the configuration; no network calls are made here.

        Parameters:
            config: Mapping with 'name', 'batch_size' and 'gpt_config' entries.
            needs_confirmation: If True, print a cost estimate and ask before calling the API.
            disable_tqdm: If True, progress bars and status prints are suppressed.
        """
        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm

    def confirm_cost(self, texts, n, max_tokens):
        """Prints the estimated cost and asks the user to confirm on the command line.

        The prompt is skipped when the LLM_SKIP_CONFIRM environment variable is set.

        Raises:
            Exception: If the user answers anything other than 'y'.
        """
        total_estimated_cost = 0
        for text in texts:
            total_estimated_cost += gpt_get_estimated_cost(
                self.config, text, max_tokens) * n
        print(f"Estimated cost: ${total_estimated_cost:.2f}")
        if os.getenv("LLM_SKIP_CONFIRM") is None:
            confirm = input("Continue? (y/n) ")
            if confirm != 'y':
                raise Exception("Aborted.")

    def auto_reduce_n(self, fn, prompt, n):
        """Calls ``fn(prompt, n)``, splitting n in two whenever BatchSizeException is raised.

        The two halves are ``n // 2`` and ``n - n // 2`` so that an odd n still
        yields n results in total.

        Raises:
            BatchSizeException: If the call still fails with n <= 1.
        """
        try:
            return fn(prompt, n)
        except BatchSizeException:
            if n <= 1:
                raise
            half = n // 2
            return self.auto_reduce_n(fn, prompt, half) + self.auto_reduce_n(fn, prompt, n - half)

    def generate_text(self, prompt, n):
        """Generates n insertions for each prompt.

        Parameters:
            prompt: A prompt string or list of prompt strings, each containing an [APE] token.
            n: Number of insertions to generate per prompt.
        Returns:
            A flat list of generated strings.

        The configured batch size must be 1 (asserted), because the per-batch
        helper only reads the first prompt of a batch.
        """
        if not isinstance(prompt, list):
            prompt = [prompt]
        if self.needs_confirmation:
            self.confirm_cost(
                prompt, n, self.config['gpt_config']['max_tokens'])
        batch_size = self.config['batch_size']
        assert batch_size == 1
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, split into {len(prompt_batches)} batches of (maximum) size {batch_size * n}")
        text = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            text += self.auto_reduce_n(self.__generate_text, prompt_batch, n)
        return text

    def log_probs(self, text, log_prob_range=None):
        """Not supported in insert mode."""
        raise NotImplementedError

    def __generate_text(self, prompt, n):
        """Generates n insertions between the prefix and suffix of ``prompt[0]``."""
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # Split the prompt around the [APE] token (the token itself is dropped)
        parts = prompt[0].split('[APE]')
        prefix = parts[0]
        suffix = parts[1]
        response = None
        while response is None:
            try:
                # prompt/suffix and the 'text' field read below belong to the
                # Completion endpoint; ChatCompletion accepts neither.
                response = openai.Completion.create(
                    **config, prompt=prefix, suffix=suffix)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        # Remove suffix from the generated text
        texts = [choice['text'].replace(suffix, '') for choice in response['choices']]
        return texts
    

def gpt_get_estimated_cost(config, prompt, max_tokens):
    """Uses the current API costs/1000 tokens to estimate the cost of generating text from the model.

    Parameters:
        config: Mapping whose ['gpt_config']['model'] entry names the model.
        prompt: The prompt text; any [APE] token is ignored.
        max_tokens: Maximum number of tokens to be generated.
    Returns:
        The estimated price for prompt plus generated tokens.
    Raises:
        KeyError: If no price is known for the model.
    """
    # Get rid of [APE] token
    prompt = prompt.replace('[APE]', '')
    # Rough estimate: one token per four characters of prompt
    n_prompt_tokens = len(prompt) // 4
    total_tokens = n_prompt_tokens + max_tokens

    model = config['gpt_config']['model']
    # The engine is the second hyphen-separated field of the model name; a
    # name without hyphens is taken to be the engine itself.
    name_parts = model.split('-')
    engine = name_parts[1] if len(name_parts) > 1 else name_parts[0]
    costs_per_thousand = gpt_costs_per_thousand
    if engine not in costs_per_thousand:
        # Try as if it is a fine-tuned model
        engine = model.split(':')[0]
        costs_per_thousand = {
            'davinci': 0.1200,
            'curie': 0.0120,
            'babbage': 0.0024,
            'ada': 0.0016
        }
        if engine not in costs_per_thousand:
            raise KeyError(f"No price known for model {model!r}")
    return costs_per_thousand[engine] * total_tokens / 1000


class BatchSizeException(Exception):
    """Signals that a request asked for too many completions at once.

    The ``auto_reduce_n`` methods catch it and retry with a smaller n.
    """




# ==============================================================================================================
# ==============================================================================================================
# ==============================================================================================================


class Claude_Forward(LLM):
    """Wrapper for Claude 2 and 3 from bedrock."""

    def __init__(self, config, version="Claude_3", needs_confirmation=False, disable_tqdm=True):
        """Sets up the tokenizer, the Bedrock model id and the Bedrock runtime client.

        Parameters:
            config: Stored on the instance as ``self.config``.
            version: Model key; one of the Claude keys below or 'Llama_3_70B'.
            needs_confirmation: Stored on the instance.
            disable_tqdm: Stored on the instance.
        Raises:
            NotImplementedError: If ``version`` is not a known model key.
        """
        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm
        print("!!! Evaluating with: ", version)

        # Controls the status prints in the batch-call retry loops.
        self.print_each_token_usage_flag = True

        # Tokenizer used only to count tokens in the returned text.
        from transformers import GPT2TokenizerFast
        self.claude_tokenizer = GPT2TokenizerFast.from_pretrained('Xenova/claude-tokenizer')

        self.claude_ver = version
        claude_model_ids = {
            "Claude_2": 'anthropic.claude-v2:1',
            "Claude_3": 'anthropic.claude-3-sonnet-20240229-v1:0',
            "Claude_3_5": 'anthropic.claude-3-5-sonnet-20240620-v1:0',
            "Claude_3_haiku": 'anthropic.claude-3-haiku-20240307-v1:0',
            "Claude_3_opus": 'anthropic.claude-3-opus-20240229-v1:0',
        }
        if version in claude_model_ids:
            print("Claude ver: ", version)
            self.model_id = claude_model_ids[version]
        elif version == 'Llama_3_70B':
            self.model_id = 'meta.llama3-70b-instruct-v1:0'
        else:
            raise NotImplementedError

        self.bedrock_runtime = boto3.client(service_name='bedrock-runtime', region_name='us-east-1')

        print(f"----- Model ID: {self.model_id} -----")

    # Method for calling Claude 2
    # Can consider passing max_tokens, temperature, top_k etc as parameters
    def call_claude2_batch(self, bedrock_runtime, model_id, prompt, max_tokens=1000):
        response_num_count = 0
        #
        response_list = []
        for i in range(len(prompt)):
            prompt_single = prompt[i].replace('[APE]', '').strip()
            #
            body = json.dumps(
                {
                    "prompt":f"Human: {prompt_single}\nAssistant:",
                    "max_tokens_to_sample":max_tokens,
                    "temperature":0.2,
                    "top_k":250,
                    "top_p":0.999,
                    "stop_sequences":[
                        "\n\nHuman:"
                    ],
                    "anthropic_version":"bedrock-2023-05-31"
                })
            accept = 'application/json'
            contentType = 'application/json'
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            #
            response_body = json.loads(response.get('body').read())
            text = response_body['content'][0]['text']
            response_num_count += len(self.claude_tokenizer.encode(text))
            response_list.append(text)
        
        return response_list, response_num_count

    # Method for calling Claude 2 and 3
    def call_claude3_batch(self, bedrock_runtime, model_id, system_prompt, messages, max_tokens=1000, sleep_time=5, threshold=15):
        response_num_count = 0
        #
        def call_with_info(body, model_id):
            try:
                response = self.bedrock_runtime.invoke_model(body=body, modelId=model_id)
                return response
            except Exception:
                return None
        
        # : Split the system prompt and user prompt apart.
        response_list = []
        for i in range(len(messages)):
            prompt_single = messages[i]
            prompt_single["content"] = prompt_single["content"].replace('[APE]', '').strip()
    
            #
            # : change to batch settings
            body = json.dumps(
                {
                    "messages": [prompt_single],
                    "max_tokens": max_tokens,
                    # "stop_sequences":[
                    #     "\n\nHuman:"
                    # ],
                    "anthropic_version":"bedrock-2023-05-31"
                })
    
            #
            call_counter = 0
            while True:
                response = call_with_info(body, model_id)
                if response is not None:
                    break
                else:
                    if self.print_each_token_usage_flag:
                        print(f"Sleeping on bedrock interface for {sleep_time} seconds...")
                    time.sleep(sleep_time)
                    if call_counter >= threshold:
                        if self.print_each_token_usage_flag:
                            print(f"Refreshing the bedrock interface...")
                        self.bedrock_runtime = boto3.client(service_name='bedrock-runtime', region_name='us-east-1')
                        call_counter = 0
                    else:
                        call_counter += 1
                    continue
            #
            response_body = json.loads(response.get('body').read())
            text = response_body['content'][0]['text']
            response_num_count += len(self.claude_tokenizer.encode(text))
            response_list.append(text)
        
        return response_list, response_num_count
    
    # Method for calling Claude 3.5
    def call_claude_3_5_batch(self, bedrock_runtime, model_id, system_prompt, messages, max_tokens=1000, sleep_time=5, threshold=15):
        """Send each message to a Bedrock Anthropic model and collect the replies.

        Every message is sent in its own ``invoke_model`` request as a
        single-message conversation. Before sending, the ``[APE]`` placeholder
        is removed from the message content and the result is stripped.

        A request that fails (raises, or yields ``None``) is retried without
        limit: the method sleeps ``sleep_time`` seconds after each failure and,
        once the per-message failure counter has reached ``threshold``,
        replaces ``self.bedrock_runtime`` with a new boto3 client and resets
        the counter.

        Args:
            bedrock_runtime: Ignored; requests always go through
                ``self.bedrock_runtime``. Kept so existing callers still work.
            model_id: Passed to ``invoke_model`` as ``modelId``.
            system_prompt: Ignored; no system prompt is included in the request.
            messages: List of message dicts, each with a ``"content"`` string.
                The dicts are left unmodified.
            max_tokens: Value of the ``max_tokens`` field of the request body.
            sleep_time: Seconds to sleep after each failed request.
            threshold: Failure count after which the client is re-created.

        Returns:
            A tuple ``(response_list, response_num_count)``: the reply texts in
            input order, and the summed length of
            ``self.claude_tokenizer.encode(text)`` over all replies.
        """
        response_list = []
        response_num_count = 0

        for message in messages:
            # Send a cleaned copy so the caller's message dict is not altered.
            cleaned = {**message, "content": message["content"].replace('[APE]', '').strip()}
            body = json.dumps(
                {
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": max_tokens,
                    "messages": [cleaned]
                })

            failures_since_refresh = 0
            while True:
                try:
                    response = self.bedrock_runtime.invoke_model(body=body, modelId=model_id)
                except Exception as exc:
                    # Deliberate best-effort: every failure is retried, but the
                    # cause is surfaced in verbose mode instead of being hidden.
                    response = None
                    if self.print_each_token_usage_flag:
                        print(f"Bedrock invoke_model failed: {exc!r}")
                if response is not None:
                    break

                if self.print_each_token_usage_flag:
                    print(f"Sleeping on bedrock interface for {sleep_time} seconds...")
                time.sleep(sleep_time)
                if failures_since_refresh >= threshold:
                    if self.print_each_token_usage_flag:
                        print("Refreshing the bedrock interface...")
                    self.bedrock_runtime = boto3.client(service_name='bedrock-runtime', region_name='us-east-1')
                    failures_since_refresh = 0
                else:
                    failures_since_refresh += 1

            response_body = json.loads(response.get('body').read())
            text = response_body['content'][0]['text']
            response_num_count += len(self.claude_tokenizer.encode(text))
            response_list.append(text)

        return response_list, response_num_count

    # Method for calling Llama 3 through the Bedrock runtime, one message per request
    def call_Llama_3_batch(self, bedrock_runtime, model_id, system_prompt, messages, max_tokens=1000, sleep_time=5, threshold=15):
        """Send each message to a Bedrock Llama 3 model and collect the generations.

        Every message is sent in its own ``invoke_model`` request. The message
        content, with the ``[APE]`` placeholder removed and whitespace
        stripped, is wrapped in the Llama 3 chat template as a single user
        turn followed by an open assistant header. The request uses a fixed
        temperature of 0.2.

        A request that fails (raises, or yields ``None``) is retried without
        limit: the method sleeps ``sleep_time`` seconds after each failure and,
        once the per-message failure counter has reached ``threshold``,
        replaces ``self.bedrock_runtime`` with a new boto3 client and resets
        the counter.

        Args:
            bedrock_runtime: Ignored; requests always go through
                ``self.bedrock_runtime``. Kept so existing callers still work.
            model_id: Passed to ``invoke_model`` as ``modelId``.
            system_prompt: Ignored; no system prompt is included in the request.
            messages: List of message dicts, each with a ``"content"`` string.
                The dicts are left unmodified.
            max_tokens: Value of the ``max_gen_len`` field of the request body.
            sleep_time: Seconds to sleep after each failed request.
            threshold: Failure count after which the client is re-created.

        Returns:
            A tuple ``(response_list, response_num_count)``: the generated
            texts in input order, and the summed length of
            ``self.claude_tokenizer.encode(text)`` over all generations.
        """
        response_list = []
        response_num_count = 0

        for message in messages:
            # Work on a local cleaned string so the caller's dict is not altered.
            user_text = message["content"].replace('[APE]', '').strip()
            body = json.dumps(
                {
                    "prompt": f"<|begin_of_text|><|start_header_id|>user<|end_header_id|> {user_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
                    "max_gen_len": max_tokens,
                    "temperature": 0.2,
                }
            )

            failures_since_refresh = 0
            while True:
                try:
                    response = self.bedrock_runtime.invoke_model(body=body, modelId=model_id)
                except Exception as exc:
                    # Deliberate best-effort: every failure is retried, but the
                    # cause is surfaced in verbose mode instead of being hidden.
                    response = None
                    if self.print_each_token_usage_flag:
                        print(f"Bedrock invoke_model failed: {exc!r}")
                if response is not None:
                    break

                if self.print_each_token_usage_flag:
                    print(f"Sleeping on bedrock interface for {sleep_time} seconds...")
                time.sleep(sleep_time)
                if failures_since_refresh >= threshold:
                    if self.print_each_token_usage_flag:
                        print("Refreshing the bedrock interface...")
                    self.bedrock_runtime = boto3.client(service_name='bedrock-runtime', region_name='us-east-1')
                    failures_since_refresh = 0
                else:
                    failures_since_refresh += 1

            response_body = json.loads(response.get('body').read())
            text = response_body['generation']
            response_num_count += len(self.claude_tokenizer.encode(text))
            response_list.append(text)

        return response_list, response_num_count

    #
    def generate_text(self, prompt, n=1, max_tokens=1000):
        """Generate one completion per prompt through the configured backend.

        ``prompt`` may be a single prompt or a list of prompts. Prompts are
        split into batches of ``self.config['batch_size']`` and each batch is
        handed to the batch-call method matching ``self.claude_ver``. ``n`` is
        only used for the cost confirmation and the progress message; it is
        not forwarded to the batch-call methods.

        Returns:
            When ``self.print_each_token_usage_flag`` is truthy, token usage
            and elapsed time are printed and only the list of generated texts
            is returned. Otherwise a tuple
            ``(texts, input_token_num, output_token_num)`` is returned, with
            token counts measured by ``self.claude_tokenizer``.
        """
        started_at = time.time()
        prompts = prompt if isinstance(prompt, list) else [prompt]
        if self.needs_confirmation:
            self.confirm_cost(prompts, n, max_tokens)

        batch_size = self.config['batch_size']
        prompt_batches = [
            prompts[start:start + batch_size]
            for start in range(0, len(prompts), batch_size)
        ]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompts) * n} completions, "
                f"split into {len(prompt_batches)} batches of size {batch_size * n}")

        completions = []
        input_token_num = 0
        output_token_num = 0
        claude_3_versions = ("Claude_3", "Claude_3_haiku", "Claude_3_opus")

        for batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            # Count input tokens and wrap every prompt as a user message.
            messages = []
            for single_prompt in batch:
                input_token_num += len(self.claude_tokenizer.encode(single_prompt))
                messages.append({"role": "user", "content": single_prompt})

            # Pick the message-style batch caller; None means the fallback path.
            version = self.claude_ver
            if version in claude_3_versions:
                batch_caller = self.call_claude3_batch
            elif version == "Claude_3_5":
                batch_caller = self.call_claude_3_5_batch
            elif version == 'Llama_3_70B':
                batch_caller = self.call_Llama_3_batch
            else:
                batch_caller = None

            if batch_caller is None:
                # Any other version goes through the Claude 2 caller, which
                # takes the messages under the ``prompt`` keyword.
                batch_texts, batch_token_count = self.call_claude2_batch(
                    self.bedrock_runtime, self.model_id,
                    prompt=messages, max_tokens=max_tokens)
            else:
                batch_texts, batch_token_count = batch_caller(
                    self.bedrock_runtime, self.model_id,
                    system_prompt=None, messages=messages, max_tokens=max_tokens)

            completions.extend(batch_texts)
            output_token_num += batch_token_count

        if not self.print_each_token_usage_flag:
            return completions, input_token_num, output_token_num

        print(f"----- Claude input token num: {input_token_num}, output token num: {output_token_num} -----")
        print(f"----- This call elapsed time: {time.time() - started_at} -----")
        return completions
    
    # ================================================================
    def confirm_cost(self, texts, n, max_tokens):
        """Print the estimated cost of a request and ask the user to confirm it.

        The estimate is the sum over ``texts`` of
        ``gpt_get_estimated_cost(self.config, text, max_tokens) * n``. The
        interactive prompt is skipped when the ``LLM_SKIP_CONFIRM``
        environment variable is set.

        Raises:
            Exception: If the user answers anything other than ``y``.
        """
        estimate = 0
        for item in texts:
            per_completion = gpt_get_estimated_cost(self.config, item, max_tokens)
            estimate += per_completion * n
        print(f"Estimated cost: ${estimate:.2f}")

        # Non-interactive runs opt out of the prompt through the environment.
        if os.getenv("LLM_SKIP_CONFIRM") is not None:
            return

        answer = input("Continue? (y/n) ")
        if answer != 'y':
            raise Exception("Aborted.")

    def auto_reduce_n(self, fn, prompt, n):
        """Call ``fn(prompt, n)``, splitting ``n`` in two whenever the batch is too large.

        If ``fn`` raises ``BatchSizeException``, the request is retried as two
        smaller requests whose sizes add up to exactly ``n`` (``n // 2`` and
        ``n - n // 2``), and their results are concatenated with ``+``.

        Args:
            fn: Callable taking ``(prompt, n)`` and returning a sequence.
            prompt: Passed through to ``fn`` unchanged.
            n: Number of results requested.

        Returns:
            The result of ``fn``, or the concatenation of the partial results.

        Raises:
            BatchSizeException: If ``fn`` still fails once ``n`` cannot be
                split any further (``n <= 1``).
        """
        try:
            return fn(prompt, n)
        except BatchSizeException:
            # Nothing left to split; also stops n == 0 from recursing forever.
            if n <= 1:
                raise
            # Halves must sum to n; n // 2 twice would drop one item for odd n.
            first_half = n // 2
            second_half = n - first_half
            return self.auto_reduce_n(fn, prompt, first_half) + self.auto_reduce_n(fn, prompt, second_half)

    def log_probs(self, text, log_prob_range=None):
        """Not implemented for this model; always raises ``NotImplementedError``."""
        raise NotImplementedError()


