import asyncio
import random
import aiohttp
import re
import math
from aiohttp import ClientSession
import nest_asyncio
import copy
import openai
from tqdm.auto import tqdm
import pickle
import pandas as pd
import numpy as np

import torch
from torch import nn
from datasets import load_dataset

nest_asyncio.apply()

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
from tqdm.auto import tqdm
import time
import copy
# import swifter
import math
import ast
import random
import scipy


import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import openai

import asyncio
from aiohttp import ClientSession
from tqdm.asyncio import tqdm_asyncio

# Load the pre-computed prompt table from disk.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this path at a trusted artifact.
with open("PATH/DATA_FILE.pkl", "rb") as f:
    data = pickle.load(f)

# Wrap the unpickled object in a DataFrame. Code further down this file reads
# its "Z2" and "Embedding" columns.
prompt_data = pd.DataFrame(data)
# Bare expression: its value is discarded when run as a plain script
# (presumably a leftover from a notebook cell, where it would be displayed).
prompt_data.head()

from datasets import load_dataset

def get_sampled_gsm8k(fraction=0.01, seed=42):
    """Return a seeded random subset of the GSM8K ``main`` training split.

    The subset is whatever ``train_test_split`` puts in its ``test`` part when
    called with ``test_size=fraction`` and the given ``seed``. The sizes of the
    full split and of the subset are printed.
    """
    full_train = load_dataset("gsm8k", "main", split="train")
    print(f"Full GSM8K train set: {len(full_train)}")

    split = full_train.train_test_split(test_size=fraction, seed=seed)
    subset = split['test']
    print(f"Sample subset: {len(subset)}")

    return subset

def get_full_gsm8k_test():
    """Load the complete GSM8K ``main`` test split, print its size, return it."""
    test_split = load_dataset("gsm8k", "main", split="test")
    size = len(test_split)
    print(f"Full GSM8K test set: {size}")
    return test_split


async def async_get_single_answer(prompt, session, api_key, temperature=0,
                                  model="qwen/qwen-2.5-7b-instruct",
                                  jitter=(0.5, 1.5)):
    """Ask the chat model one question and return the number it answered with.

    The prompt is sent as a single user message to the OpenRouter chat
    completions endpoint. The reply is expected to wrap its final result in
    ``<Answer>...</Answer>``; the first number inside those tags is returned.

    Args:
        prompt: Fully formatted prompt text.
        session: An object whose ``post(url, json=..., headers=...)`` returns
            an async context manager yielding a response with ``status`` and
            an awaitable ``json()`` (e.g. ``aiohttp.ClientSession``).
        api_key: Bearer token for the endpoint.
        temperature: Sampling temperature forwarded in the payload.
        model: Model identifier forwarded in the payload.
        jitter: ``(low, high)`` bounds, in seconds, of the random pause taken
            before every request attempt.

    Returns:
        The extracted number as a string with thousands separators removed
        (so ``float()`` accepts it), or the literal ``"WRONG"`` when the reply
        has no ``<Answer>`` tags or no number inside them.

    Note:
        Rate limits (HTTP 429) and any exception are retried indefinitely;
        this coroutine only returns once a reply has been parsed.
    """
    url = "https://openrouter.ai/api/v1/chat/completions"

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "temperature": temperature
    }
    # Optional sign, digits with optional ",ddd" groups, optional decimals.
    # Without the comma groups "1,200" would be read as "1".
    number_pattern = r"[+-]?[0-9]+(?:,[0-9]{3})*(?:\.[0-9]+)?"

    while True:
        try:
            await asyncio.sleep(random.uniform(*jitter))
            async with session.post(url, json=payload, headers=headers) as response:
                if response.status == 429:
                    print("Rate limit exceeded. Waiting before retrying...")
                    await asyncio.sleep(10)
                    continue
                response_json = await response.json()
                answer = response_json['choices'][0]['message']['content']

                tagged = re.search(r"<Answer>(.*?)</Answer>", answer, re.DOTALL)
                if not tagged:
                    return "WRONG"
                number_match = re.search(number_pattern, tagged.group(1))
                if not number_match:
                    return "WRONG"
                return number_match.group(0).replace(",", "")
        except Exception as e:
            # Deliberately broad: malformed payloads and transport errors are
            # all treated as transient and retried.
            print(f"An error occurred: {e}")
            print("Retrying...")
            await asyncio.sleep(5)

async def async_get_batch_answers(prompts, Instruction, api_key, temperature=0, batch_size=64, delay_between_batches=10):
    """Answer every prompt, issuing requests in concurrent batches.

    Each entry of ``prompts`` is a ``str.format`` template with one positional
    placeholder that is filled with ``Instruction`` before sending.

    Args:
        prompts: Sequence of prompt templates.
        Instruction: Text substituted into each template.
        api_key: Bearer token passed to ``async_get_single_answer``.
        temperature: Sampling temperature passed through.
        batch_size: Number of requests awaited together.
        delay_between_batches: Seconds to sleep between consecutive batches.

    Returns:
        A list of answers in the same order as ``prompts``; empty when
        ``prompts`` is empty.
    """
    # Nothing to ask: avoid indexing prompts[0] and opening a session.
    if len(prompts) == 0:
        return []

    answers = []
    # Show one fully formatted prompt so the run log records what was sent.
    print(prompts[0].format(Instruction))
    async with ClientSession() as session:
        for start in range(0, len(prompts), batch_size):
            batch = prompts[start:start + batch_size]
            tasks = [
                async_get_single_answer(prompt.format(Instruction), session, api_key, temperature)
                for prompt in batch
            ]
            answers.extend(await asyncio.gather(*tasks))
            # Pause only when another batch follows.
            if start + batch_size < len(prompts):
                print(f"Batch {start//batch_size + 1} completed, sleeping for {delay_between_batches} seconds...")
                await asyncio.sleep(delay_between_batches)
    return answers

def extract_groundtruth_number(text):
    """Parse the final numeric answer that follows ``####`` in a solution.

    Thousands separators are accepted and stripped, so ``"#### 1,234"`` yields
    ``1234.0`` rather than being cut off at the first comma.

    Args:
        text: Solution text containing a ``#### <number>`` marker.

    Returns:
        The number as a ``float``, or ``None`` when no marker followed by a
        number is found.
    """
    match = re.search(r"####\s*([+-]?[0-9][0-9,]*(?:\.[0-9]+)?)", text)
    if match is None:
        return None
    return float(match.group(1).replace(",", ""))

def compute_accuracy(predictions, dataset, tol=1e-5):
    """Return the fraction of predictions that match the ground truth.

    Predictions and samples are paired positionally. A pair is skipped (and
    therefore counted as wrong, since the denominator is ``len(predictions)``)
    when the prediction is not parseable as a float or the sample's
    ``"answer"`` field yields no ground-truth number. Matching uses
    ``math.isclose`` with ``rel_tol=tol``.

    Returns 0 when ``predictions`` is empty.
    """
    total = len(predictions)
    hits = 0

    for pred_str, sample in zip(predictions, dataset):
        try:
            predicted = float(pred_str.strip())
        except ValueError:
            print("Invalid numeric prediction:", pred_str)
            continue

        expected = extract_groundtruth_number(sample["answer"])
        if expected is None:
            print("Failed to extract ground truth from sample:", sample)
            continue

        # bool adds as 0/1
        hits += math.isclose(predicted, expected, rel_tol=tol)

    if total > 0:
        return hits / total
    return 0

class env:
  """Scoring environment: evaluates an instruction on GSM8K questions.

  Holds a sampled training subset (used by ``evaluator`` and for drawing
  exemplar problems) and the full test split (used by ``test_env``). Every
  question is turned into a prompt template ending in a ``{}`` placeholder
  that is later filled with the instruction under evaluation.
  """

  def __init__(self):
    suffix = "\nProvide your final result within the tags <Answer>FINAL_NUMERICAL_ANSWER</Answer>.\n" + "\n{}"

    def build_prompts(split):
      # One template per question: question text, answer-format request,
      # then the slot for the instruction.
      return [row["question"].strip() + suffix for row in split]

    self.sampled_dataset = get_sampled_gsm8k()
    self.prompts = build_prompts(self.sampled_dataset)
    # Counts calls to random_sample_exemplars so each call uses a new seed.
    self.sample_count = 0
    self.test_dataset = get_full_gsm8k_test()
    self.test_prompts = build_prompts(self.test_dataset)

    self.API_key = "[API_KEY]"

  def evaluator(self, Instruction):
    """Score ``Instruction`` on the sampled subset; returns an int in 0..100."""
    predictions = asyncio.run(async_get_batch_answers(self.prompts, Instruction, self.API_key))
    score = int(round(compute_accuracy(predictions, self.sampled_dataset) * 100))
    print(f"Instruction: {Instruction}, score={score}")
    return score

  def random_sample_exemplars(self):
    """Return three samples from the subset, reshuffled with a per-call seed."""
    shuffle_seed = 42 + self.sample_count
    picked = self.sampled_dataset.shuffle(seed=shuffle_seed).select(range(3))
    self.sample_count += 1
    return picked

  def test_env(self, Instruction):
    """Score ``Instruction`` on the full test split.

    Prints the rounded integer score but returns the unrounded percentage.
    """
    predictions = asyncio.run(async_get_batch_answers(self.test_prompts, Instruction, self.API_key))
    percent = compute_accuracy(predictions, self.test_dataset) * 100
    print(f"Instruction: {Instruction}, score={int(round(percent))}")
    return percent

class Network(nn.Module):
  """Fully connected ReLU network with a single scalar output per input row.

  The stack is ``Linear(input_dim, hidden_size)``, then ``depth - 1`` layers
  of ``Linear(hidden_size, hidden_size)``, then ``Linear(hidden_size, 1)``.
  Without ``init_params`` every weight and bias is drawn from N(0, 1);
  otherwise ``init_params`` is read as a flat sequence
  ``[w0, b0, w1, b1, ...]`` assigned to the layers in order.
  """

  def __init__(self, input_dim, hidden_size=100, depth=1, init_params=None):
    super().__init__()

    self.activate = nn.ReLU()
    layers = [nn.Linear(input_dim, hidden_size)]
    layers += [nn.Linear(hidden_size, hidden_size) for _ in range(depth - 1)]
    layers.append(nn.Linear(hidden_size, 1))
    self.layer_list = nn.ModuleList(layers)

    if init_params is not None:
      for idx, layer in enumerate(self.layer_list):
        layer.weight.data = init_params[2 * idx]
        layer.bias.data = init_params[2 * idx + 1]
    else:
      for layer in self.layer_list:
        nn.init.normal_(layer.weight, mean=0, std=1.0)
        nn.init.normal_(layer.bias, mean=0, std=1.0)

  def forward(self, x):
    """Apply all but the last layer with ReLU, then the last layer linearly.

    The trailing size-1 output dimension is squeezed away.
    """
    *hidden_layers, output_layer = self.layer_list
    out = x
    for layer in hidden_layers:
      out = self.activate(layer(out))
    return output_layer(out).squeeze(-1)

class AdversarialBandit:
  """Exponential-weights selector over a fixed table of candidate meta-prompts.

  A small regression network is refit on every observed
  ``(embedding, reward)`` pair, used to score every row of ``prompt_data``,
  and the per-round scores are accumulated. ``select`` samples the next
  meta-prompt with probability proportional to
  ``exp(eta * cumulative_score)`` (the code's own log messages call this
  EXP3).

  Args:
    prompt_data: DataFrame with a ``"Z2"`` column (prompt text) and an
      ``"Embedding"`` column (one vector per row). It must contain a row
      whose ``"Z2"`` equals the built-in starting prompt. ``select`` writes
      a ``"Score"`` column into it.
    device: Torch device for the model and tensors. Defaults to ``'cuda'``
      when available, otherwise ``'cpu'``.
  """

  def __init__(self, prompt_data, device=None):
    self.prompt_data = prompt_data
    if device is None:
      device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.device = device
    self.model = Network(3072, 512, 1).to(self.device).double()
    self.lr = 1e-3
    self.epochs = 5000
    self.optimizer_fn = torch.optim.Adam
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1)
    self.loss_fn = nn.MSELoss()
    # Snapshot of the random initialisation; every refit restarts from it.
    self.init_model_weight = copy.deepcopy(self.model.state_dict())
    self.eta = 50
    # One list of per-row predicted scores for each call to select().
    self.prompt_score_trace = []
    # [embedding, reward] pairs appended by the caller; training data.
    self.embedding_score_trace = []
    self.current_prompt = """Generate an instruction that is different from all the instructions <INS> above, and has a higher score than all the instructions <INS> above. The instruction should begin with <INS> and end with </INS>. The instruction should be concise, effective, and generally applicable to all problems above."""
    self.prompt_embedding = self.prompt_data.loc[self.prompt_data["Z2"] == self.current_prompt, "Embedding"].iloc[0]
    # Set to 1 by select() when the sampling distribution becomes invalid.
    self.Stop_EXP3 = 0

  def restart_model(self):
    """Reset weights to the initial snapshot and rebuild optimizer/scheduler.

    Weight decay is ``1 / len(embedding_score_trace)``, so the trace must be
    non-empty. The learning rate is multiplied by 0.1 after epoch 3000.
    """
    self.model.load_state_dict(copy.deepcopy(self.init_model_weight))
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1/len(self.embedding_score_trace))
    self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 0.1 if epoch > 3000 else 1.0)

  def train_model(self):
    """Refit the network from scratch on every recorded (embedding, reward).

    Rewards are divided by 100 before being used as MSE regression targets.
    """
    self.restart_model()
    self.model.to(self.device)

    embeddings = [item[0] for item in self.embedding_score_trace]
    scores = [(item[1])/100 for item in self.embedding_score_trace]

    embeddings_tensor = torch.stack([torch.tensor(e, dtype=torch.float64) for e in embeddings]).to(self.device)
    scores_tensor = torch.tensor(scores, dtype=torch.float64).to(self.device)

    dataset = torch.utils.data.TensorDataset(embeddings_tensor, scores_tensor)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    self.model.train()
    for epoch in tqdm(range(self.epochs), desc="Training Model"):
      for batch_embeddings, batch_scores in dataloader:
        self.optimizer.zero_grad()

        predictions = self.model(batch_embeddings)

        loss = self.loss_fn(predictions, batch_scores)
        loss.backward()

        self.optimizer.step()

      self.scheduler.step()

      if epoch % 100 == 0:
        # Loss of the last mini-batch of this epoch.
        print(f"Epoch {epoch}/{self.epochs}, Loss: {loss.item()}")

    print("Training complete.")

  def select(self, iteration):
    """Score every candidate and sample the next meta-prompt.

    On success updates ``best_index``, ``current_prompt`` and
    ``prompt_embedding``. If the sampling distribution contains NaN/Inf or
    sums to zero, sets ``Stop_EXP3 = 1`` and leaves the current selection
    unchanged.

    Args:
      iteration: Accepted for the caller's convenience; not used.
    """
    embeddings = torch.stack(
      [torch.tensor(e, dtype=torch.float64) for e in self.prompt_data['Embedding']]
    ).to(self.device)

    with torch.no_grad():
      # reshape(-1) keeps a 1-D vector even for a single-row table.
      scores = self.model(embeddings).reshape(-1)

    round_scores = scores.cpu().tolist()
    self.prompt_data['Score'] = round_scores

    # Store a plain snapshot, not the DataFrame column, so later writes to
    # the 'Score' column cannot affect earlier rounds.
    self.prompt_score_trace.append(round_scores)
    score_tensor = torch.tensor(self.prompt_score_trace, dtype=torch.float64)
    cumulative = torch.sum(score_tensor, dim=0)

    # Subtract the maximum before exponentiating: the probabilities are
    # mathematically unchanged, but exp() can no longer overflow once
    # eta * cumulative grows large.
    logits = self.eta * cumulative
    exp_scores = torch.exp(logits - torch.max(logits))
    total_score = torch.sum(exp_scores)
    normalized_exp_scores = exp_scores / total_score

    # topk requires k <= number of candidates.
    top_k = min(10, normalized_exp_scores.numel())
    top_scores, top_indices = torch.topk(normalized_exp_scores, top_k)
    print("Top 10 probabilities: ", top_scores)
    print("Indices of Top 10 probabilities: ", top_indices)

    if not torch.isnan(normalized_exp_scores).any() and not torch.isinf(normalized_exp_scores).any() and torch.sum(normalized_exp_scores) > 0:
      best_index = torch.multinomial(normalized_exp_scores, 1).item()
      self.best_index = best_index
    else:
      print("Encountered NaN or Inf in normalized_exp_scores. Stopping EXP3.")
      self.Stop_EXP3 = 1
      return

    print("The index of selected prompt is: %d"%(self.best_index))
    print("The selected prompt's cumulative score: ", cumulative[self.best_index])
    chosen_row = self.prompt_data.iloc[self.best_index]
    self.current_prompt = chosen_row['Z2']
    self.prompt_embedding = chosen_row['Embedding']

class AdversarialBandit_exemplars:
  """Exponential-weights selector over candidate exemplar sets.

  Unlike ``AdversarialBandit`` the candidate pool changes between rounds, so
  instead of caching per-round scores this class stores the network weights
  produced by every training round (``NN_Para``) and, at selection time,
  re-scores the current candidates with each stored network. The summed
  scores, scaled by ``eta``, define a softmax sampling distribution.

  Args:
    inital_point: Initial exemplar text, stored as ``current_prompt``.
    device: Torch device for the model and tensors. Defaults to ``'cuda'``
      when available, otherwise ``'cpu'``.
  """

  def __init__(self, inital_point, device=None):
    if device is None:
      device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.device = device
    self.model = Network(3072, 512, 1).to(self.device).float()
    self.lr = 1e-3
    self.epochs = 5000
    self.optimizer_fn = torch.optim.Adam
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1)
    self.loss_fn = nn.MSELoss()
    # Snapshot of the random initialisation; every refit restarts from it.
    self.init_model_weight = copy.deepcopy(self.model.state_dict())
    self.eta = 25
    self.prompt_score_trace = []
    # [embedding, reward] pairs appended by the caller; training data.
    self.embedding_score_trace = []
    self.current_prompt = inital_point
    self.prompt_embedding = None
    # Set to 1 by select() when the sampling distribution becomes invalid.
    self.Stop_EXP3 = 0
    # One state_dict per completed train_model() call.
    self.NN_Para = []

  def restart_model(self):
    """Reset weights to the initial snapshot and rebuild optimizer/scheduler.

    Weight decay is ``1 / len(embedding_score_trace)``, so the trace must be
    non-empty. The learning rate is multiplied by 0.1 after epoch 3000.
    """
    self.model.load_state_dict(copy.deepcopy(self.init_model_weight))
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1/len(self.embedding_score_trace))
    self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 0.1 if epoch > 3000 else 1.0)

  def train_model(self):
    """Refit from scratch on all recorded pairs and archive the weights.

    Rewards are divided by 100 before being used as MSE regression targets.
    A deep copy of the trained ``state_dict`` is appended to ``NN_Para``.
    """
    self.restart_model()
    self.model.to(self.device)

    embeddings = [item[0] for item in self.embedding_score_trace]
    scores = [(item[1])/100 for item in self.embedding_score_trace]

    embeddings_tensor = torch.stack([torch.tensor(e, dtype=torch.float32) for e in embeddings]).to(self.device)
    scores_tensor = torch.tensor(scores, dtype=torch.float32).to(self.device)

    dataset = torch.utils.data.TensorDataset(embeddings_tensor, scores_tensor)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    self.model.train()
    for epoch in tqdm(range(self.epochs), desc="Training Model"):
      for batch_embeddings, batch_scores in dataloader:
        self.optimizer.zero_grad()

        predictions = self.model(batch_embeddings)

        loss = self.loss_fn(predictions, batch_scores)
        loss.backward()

        self.optimizer.step()

      self.scheduler.step()

      if epoch % 100 == 0:
        # Loss of the last mini-batch of this epoch.
        print(f"Epoch {epoch}/{self.epochs}, Loss: {loss.item()}")

    self.NN_Para.append(copy.deepcopy(self.model.state_dict()))
    print("Training complete.")

  def select(self, Prompt_Embedding_List):
    """Sample one candidate from ``Prompt_Embedding_List``.

    Every archived network scores every candidate; the scores are summed
    over networks and turned into a max-stabilised softmax with temperature
    ``1 / eta``. On success sets ``best_index`` and ``prompt_embedding``. If
    the distribution contains NaN/Inf or sums to zero, sets
    ``Stop_EXP3 = 1`` and returns without updating the selection.

    Note:
      Leaves ``self.model`` loaded with the last archived weights.
    """
    embeddings = torch.stack(
      [torch.tensor(e, dtype=torch.float32) for e in Prompt_Embedding_List]
    ).to(self.device)

    # One CPU score vector per archived network.
    self.historical_score = []
    for state in tqdm(self.NN_Para, desc="Evaluating Exemplars"):
      self.model.load_state_dict(state)
      self.model.eval()

      with torch.no_grad():
        # reshape(-1) keeps a 1-D vector even for a single candidate.
        scores = self.model(embeddings).reshape(-1)

      self.historical_score.append(scores.cpu())

    score_tensor = torch.stack(self.historical_score)
    cumulative = torch.sum(score_tensor, dim=0)
    eta_sum_score = self.eta * cumulative
    # Subtracting the maximum keeps exp() from overflowing.
    max_log_score = torch.max(eta_sum_score)
    exp_scores = torch.exp(eta_sum_score - max_log_score)
    total_score = torch.sum(exp_scores)
    normalized_exp_scores = exp_scores / total_score

    # topk requires k <= number of candidates.
    top_k = min(10, normalized_exp_scores.numel())
    top_scores, top_indices = torch.topk(normalized_exp_scores, top_k)
    print("Top 10 probabilities: ", top_scores)
    print("Indices of Top 10 probabilities: ", top_indices)

    if not torch.isnan(normalized_exp_scores).any() and not torch.isinf(normalized_exp_scores).any() and torch.sum(normalized_exp_scores) > 0:
      best_index = torch.multinomial(normalized_exp_scores, 1).item()
      self.best_index = best_index
    else:
      print("Encountered NaN or Inf in normalized_exp_scores. Stopping EXP3.")
      self.Stop_EXP3 = 1
      return

    print("The index of selected prompt is: %d"%(self.best_index))
    print("The selected prompt's cumulative score: ", cumulative[self.best_index])
    self.prompt_embedding = Prompt_Embedding_List[self.best_index]

class Template_Padding:
  """Runs the instruction-optimisation loop.

  Each step assembles a meta-prompt from (a) previously scored instructions,
  (b) a few example problems and (c) a meta-instruction, asks an LLM for new
  instructions, scores them through ``env.evaluator`` and then lets two
  bandits pick the next meta-instruction (``self.bandit``) and the next
  subset of scored instructions to show (``self.exemplars_bandit``).

  Args:
    env: Object providing ``evaluator(instruction) -> score`` and
      ``random_sample_exemplars()``.
    prompt_data: Table of candidate meta-instructions handed to
      ``AdversarialBandit``.

  Note:
    Construction already calls ``env.evaluator`` once (for the seed
    instruction) and requests one embedding from the OpenAI API.
  """

  def __init__(self, env, prompt_data):
    # Rendering of one (instruction, score) pair.
    self.exemplars_template = "text:\n{}\nscore:\n{}\n\n"
    # Rendering of one (question, ground-truth answer) pair.
    self.exemplars_template_2 = "Problem:\nQ: {}\nA: <INS>\n\nGround truth Answer:\n{}\n\n"
    self.env = env
    initial_instruction = "Let's think step by step."
    # All (instruction, score) pairs evaluated so far.
    self.samples = [(initial_instruction, self.env.evaluator(initial_instruction))]

    self.input_exemplars_1 = ""
    self.input_exemplars_2 = ""
    # Initialise the selection state BEFORE Generate_Subset fills it in, so
    # the values it computes are not wiped again.
    self.subset_list = []
    self.selected_samples = None
    self.mask = None
    self.Generate_Subset()

    self.prompt_1 = "Your task is to generate the instruction <INS>. Below are some previous instructions with their scores. The score ranges from 0 to 100.\n"
    self.prompt_2 = "Below are some problems.\n"
    self.API_key = "[API_KEY]"
    # Best score seen so far, recorded once per step.
    self.result_lst = []

    self.steps=20

    self.bandit = AdversarialBandit(prompt_data=prompt_data)
    self.prompt_3 = self.bandit.current_prompt

    self.exemplars_bandit = AdversarialBandit_exemplars(self.input_exemplars_1)
    self.exemplars_bandit.prompt_embedding = self.Get_Single_Embedding(self.input_exemplars_1)

  @staticmethod
  def _extract_instruction(content):
    """Return the text inside the first <INS>...</INS> pair, else ``content``."""
    match = re.search(r"<INS>(.*?)</INS>", content, re.DOTALL)
    return match.group(1).strip() if match else content

  def Get_Single_Embedding(self, text, target_dim=3072):
    """Blocking embedding request; returns an L2-normalised list of floats.

    Newlines in ``text`` are replaced by spaces before the request (the async
    variant below does not do this). Retries forever on any error. Also
    points the global ``openai`` client at the OpenAI endpoint as a side
    effect.
    """
    openai.api_base = "https://api.openai.com/v1"
    openai.api_key = '[API_KEY]'
    text = text.replace("\n", " ")
    while True:
      try:
        time.sleep(random.uniform(0.5, 1.5))
        response = openai.Embedding.create(
          input=[text],
          model='text-embedding-3-large'
        )
        embedding = response['data'][0]['embedding']
        truncated_embedding = embedding[:target_dim]
        return list(self.normalize_l2(truncated_embedding))
      except openai.error.RateLimitError:
        print("Rate limit exceeded. Waiting before retrying...")
        time.sleep(10)
      except Exception as e:
        print(f"An error occurred: {e}")
        print("Retrying...")
        time.sleep(5)

  def Add_Padding(self):
    """Rebuild both exemplar text blocks from the current selection.

    Sorts ``self.selected_samples`` in place by ascending score, renders it
    into ``input_exemplars_1`` and renders freshly drawn problems from the
    environment into ``input_exemplars_2``.
    """
    self.selected_samples.sort(key=lambda x: x[1], reverse=False)
    self.input_exemplars_1 = "".join(
      self.exemplars_template.format(sample[0], sample[1])
      for sample in self.selected_samples
    )

    sampled_exemplars = self.env.random_sample_exemplars()
    self.input_exemplars_2 = "".join(
      self.exemplars_template_2.format(sample["question"], extract_groundtruth_number(sample["answer"]))
      for sample in sampled_exemplars
    )

  def Padding(self, z3):
    """Render a sequence of (instruction, score) pairs with the exemplar template."""
    return "".join(self.exemplars_template.format(sample[0], sample[1]) for sample in z3)


  def Add_Sample(self, new_instruction, re=False):
    """Evaluate ``new_instruction`` and record it in ``self.samples``.

    Args:
      new_instruction: Instruction text to score.
      re: When true, the score is returned; otherwise ``None`` is returned.
        (The name shadows the ``re`` module inside this method only; it is
        kept because callers pass it by keyword.)
    """
    score = self.env.evaluator(new_instruction)
    self.samples.append((new_instruction, score))
    if re:
      return score

  def Add_All(self, Meta_Prompt):
    """Assemble the full prompt sent to the instruction-generating LLM.

    Built by plain concatenation so that braces inside ``Meta_Prompt`` are
    passed through literally instead of being read as format fields.
    """
    return (
      self.prompt_1 + "\n" + self.input_exemplars_1
      + self.prompt_2 + "\n" + self.input_exemplars_2
      + Meta_Prompt
    )

  async def Prompt_LLM_Epoch(self, Meta_Prompt, epoch=7):
    """Request ``epoch`` new instructions concurrently at temperature 1.

    Every returned instruction is evaluated and added to ``self.samples``.

    Returns:
      A list with one entry per request: the extracted instruction, or
      ``None`` for a request whose LLM call failed on all attempts.
    """
    async def fetch_one():
      max_retries = 3
      for attempt in range(max_retries):
        try:
          response = await openai.ChatCompletion.acreate(
            model="openai/gpt-4-turbo",
            messages=[{"role": "user", "content": self.Add_All(Meta_Prompt)}],
            temperature=1,
          )
          content = response.choices[0].message['content'].strip()
        except Exception as e:
          # Only the LLM call is retried; evaluation errors still propagate.
          print(f"[Retry {attempt+1}] Exception during LLM call: {e}")
          continue

        extracted_content = self._extract_instruction(content)
        print(extracted_content)
        self.Add_Sample(extracted_content)
        return extracted_content

      print("[Warning] Failed after 3 attempts.")
      return None

    tasks = [fetch_one() for _ in range(epoch)]
    answers = await asyncio.gather(*tasks)
    return answers

  def Prompt_LLM_Epoch_Temp_0(self, Meta_Prompt):
    """Request one instruction at temperature 0 and return its score.

    The score of this deterministic sample is what the caller uses as the
    bandits' reward.

    Returns:
      The evaluator score of the generated instruction, or ``None`` when all
      10 attempts raised.
    """
    max_retries = 10
    print(self.Add_All(Meta_Prompt))
    for attempt in range(max_retries):
      try:
        response = openai.ChatCompletion.create(
          model="openai/gpt-4-turbo",
          messages=[{"role": "user", "content": self.Add_All(Meta_Prompt)}],
          temperature=0
        )
        content = response.choices[0].message['content'].strip()
        extracted_content = self._extract_instruction(content)
        return self.Add_Sample(extracted_content, re=True)
      except Exception as e:
        print(f"[Retry {attempt+1}] Exception during LLM call: {e}")

    print("[Error] Failed after 10 attempts. Returning None.")
    return None

  def generate_01_low_discrepancy_sequences(self, num_sequences, dim, num_ones):
    """Generate 0/1 masks of length ``dim`` from scrambled Sobol points.

    For each Sobol point in ``dim - 1`` dimensions, the positions of its
    ``num_ones - 1`` smallest coordinates are set to 1, and the last position
    is always set to 1. One extra mask with only the last 20 positions set is
    appended.

    Returns:
      Array of shape ``(num_sequences + 1, dim)``.
    """
    # Imported here so the submodule is loaded regardless of how `scipy`
    # itself was imported at file level.
    from scipy.stats import qmc

    sampler = qmc.Sobol(d=dim-1, scramble=True)
    samples = sampler.random(n=num_sequences)

    binary_sequences = []
    for sample in samples:
      threshold_indices = np.argsort(sample)[:num_ones-1]
      binary_sequence = np.zeros(dim)
      binary_sequence[threshold_indices] = 1
      binary_sequence[-1] = 1
      binary_sequences.append(binary_sequence)

    additional_mask = np.zeros(dim)
    additional_mask[-20:] = 1
    binary_sequences.append(additional_mask)

    return np.array(binary_sequences)

  def get_subset(self, samples):
    """Build candidate 20-element subsets drawn from the last 30 samples.

    Stores the 0/1 masks (padded on the left with zeros to ``len(samples)``)
    in ``self.mask`` and the corresponding lists of samples in
    ``self.subset_list``. ``samples`` is expected to be sorted ascending by
    score, so the tail holds the best entries.
    """
    num_to_select = min(len(samples), 30)
    binary_sequences = self.generate_01_low_discrepancy_sequences(256, num_to_select, 20)

    # Only the trailing num_to_select positions are eligible.
    mask = np.zeros((len(binary_sequences), len(samples)), dtype=int)
    mask[:, -num_to_select:] = binary_sequences
    self.mask = mask

    self.subset_list = [
      [samples[index] for index, value in enumerate(row) if value == 1]
      for row in mask
    ]
    return

  def Generate_Subset(self):
    """Refresh the candidate exemplar subsets from ``self.samples``.

    Sorts ``self.samples`` ascending by score, de-duplicates it, and then
    either uses everything (20 or fewer unique samples; exemplar text is
    rebuilt immediately) or generates many candidate subsets via
    ``get_subset`` for the exemplar bandit to choose from.
    """
    self.samples.sort(key=lambda x: x[1], reverse=False)
    sample_set = list(set(copy.deepcopy(self.samples)))
    sample_set.sort(key=lambda x: x[1], reverse=False)
    print("The Best Performance, Instruction=%s, value=%f"%(sample_set[-1][0], sample_set[-1][1]))
    if len(sample_set) <= 20:
      self.subset_list = [sample_set]
      self.selected_samples = sample_set
      self.Add_Padding()
    else:
      self.get_subset(sample_set)

  def normalize_l2(self, x):
    """Return ``x`` as an array scaled to unit L2 norm (unchanged if all-zero)."""
    x = np.array(x)
    norm = np.linalg.norm(x)
    if norm == 0:
      return x
    return x / norm

  async def async_get_single_embedding(self, text, session, target_dim=3072):
    """Async embedding request; returns an L2-normalised list of floats.

    Unlike ``Get_Single_Embedding`` the text is sent as-is (newlines kept).
    HTTP 429 and any exception are retried indefinitely.
    """
    api_key = '[API_KEY]'
    headers = {
      "Authorization": f"Bearer {api_key}",
      "Content-Type": "application/json",
    }
    url = "https://api.openai.com/v1/embeddings"
    payload = {
      "input": text,
      "model": 'text-embedding-3-large'
    }
    while True:
      try:
        await asyncio.sleep(random.uniform(0.5, 1.5))
        async with session.post(url, json=payload, headers=headers) as response:
          if response.status == 429:
            print("Rate limit exceeded. Waiting before retrying...")
            await asyncio.sleep(10)
            continue
          response_json = await response.json()
          embedding = response_json['data'][0]['embedding']
          truncated_embedding = embedding[:target_dim]
          return list(self.normalize_l2(truncated_embedding))
      except Exception as e:
        print(f"An error occurred: {e}")
        print("Retrying...")
        await asyncio.sleep(5)

  async def get_batch_embedding_z3_async(self):
    """Embed the rendered text of every subset in ``self.subset_list``.

    Requests are sent 128 at a time with a 15 second pause between batches;
    results are stored, in order, in ``self.Prompt_Embedding_List``.
    """
    self.Prompt_Embedding_List = []
    async with ClientSession() as session:
      for i in tqdm(range(0, len(self.subset_list), 128), desc="Batch Embedding ... "):
        batch = self.subset_list[i:i + 128]
        tasks = [
          self.async_get_single_embedding(self.Padding(prompt), session)
          for prompt in batch
        ]
        self.Prompt_Embedding_List.extend(await asyncio.gather(*tasks))
        if i + 128 < len(self.subset_list):
          print("Batch completed, sleeping for 15 seconds...")
          await asyncio.sleep(15)

  def Get_Batch_Embedding_z3(self):
    """Synchronous wrapper around ``get_batch_embedding_z3_async``."""
    asyncio.run(self.get_batch_embedding_z3_async())

  def Prompt_LLM_Step(self):
    """Run ``self.steps`` optimisation rounds.

    Per round: generate and score new instructions, record the best score so
    far, feed the temperature-0 reward to both bandits, retrain them, and
    let them choose the next meta-instruction and exemplar subset. A round
    whose temperature-0 call produced no reward skips the bandit updates.

    Returns:
      The reward of the last round (``None`` if it had none or if
      ``self.steps`` is 0).
    """
    reward = None
    for i in tqdm(range(self.steps), desc="Prompting LLM " + str(self.steps) + " steps: " + "gpt-4-turbo"):
      # Get_Single_Embedding repoints the global client; point it back.
      openai.api_key = self.API_key
      openai.api_base = "https://openrouter.ai/api/v1"
      reward = self.Prompt_LLM_Epoch_Temp_0(self.prompt_3)
      asyncio.run(self.Prompt_LLM_Epoch(self.prompt_3))

      # self.samples is not sorted at this point (new entries were appended),
      # so look the best entry up explicitly.
      sorted_samples = sorted(self.samples, key=lambda x: x[1], reverse=True)
      print("Step %d: The Best Performance, INSTRUCTION=%s, score=%f"%(i + 1, sorted_samples[0][0], sorted_samples[0][1]))
      self.result_lst.append(sorted_samples[0][1])
      print(self.result_lst)

      if reward is None:
        # No training signal this round; a non-numeric reward would break
        # the bandits' regression targets.
        print("[Warning] No reward obtained this step; skipping bandit updates.")
        continue

      self.bandit.embedding_score_trace.append([self.bandit.prompt_embedding, reward])

      async def _inner():
        async with aiohttp.ClientSession() as session:
          return await self.async_get_single_embedding(self.input_exemplars_1, session)
      new_embedding_exemplars = asyncio.run(_inner())
      self.exemplars_bandit.embedding_score_trace.append([new_embedding_exemplars, reward])

      if not self.bandit.Stop_EXP3:
        self.bandit.train_model()
        self.bandit.select(i)
      if not self.exemplars_bandit.Stop_EXP3:
        self.exemplars_bandit.train_model()

      self.Generate_Subset()
      # A single subset means "use everything"; nothing to choose.
      if len(self.subset_list) != 1:
        self.Get_Batch_Embedding_z3()
        self.exemplars_bandit.select(self.Prompt_Embedding_List)
        exemplars_index = self.exemplars_bandit.best_index
        self.selected_samples = self.subset_list[exemplars_index]
        self.Add_Padding()
        print(self.mask[exemplars_index])
      self.prompt_3 = self.bandit.current_prompt

    # Highest-scoring instruction (last one among ties in ascending order).
    final_samples = sorted(self.samples, key=lambda x: x[1])
    print(f"Final Instruction: {final_samples[-1][0]}")
    return reward

# Script entry point. Building the environment loads both GSM8K splits.
new_env = env()
# The optimiser gets a deep copy of the prompt table: the bandit it creates
# writes a "Score" column into its table, and the copy keeps the module-level
# `prompt_data` untouched. Construction already evaluates the seed instruction.
new_Template_Padding = Template_Padding(new_env, copy.deepcopy(prompt_data))
# Run the full optimisation loop (network calls and model training).
new_Template_Padding.Prompt_LLM_Step()