import numpy as np
import pandas as pd
import re
from tqdm.auto import tqdm
import time
import copy
import ast
import random
from scipy.stats import bernoulli

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import openai
import asyncio
from aiohttp import ClientSession
from tqdm.asyncio import tqdm as async_tqdm
import nest_asyncio
nest_asyncio.apply()

# Load the table of candidate meta-prompts. "PATH/DATA_FILE.csv" looks like a
# placeholder path -- NOTE(review): point it at the real file before running.
# Later code reads the 'Z1', 'Z2', 'Merged_Content' and 'Z1-Z2-Embedding'
# columns of this frame.
Meta_Prompt_data = pd.read_csv("PATH/DATA_FILE.csv")
# Registers Series.progress_apply; must run before the next line.
tqdm.pandas()
# The CSV stores each embedding as text; ast.literal_eval turns it back into a
# Python literal (presumably a list of floats -- later code concatenates and
# slices it like a list).
Meta_Prompt_data['Z1-Z2-Embedding'] = Meta_Prompt_data['Z1-Z2-Embedding'].progress_apply(ast.literal_eval)
# Only has a visible effect as the last expression of a notebook cell.
Meta_Prompt_data.head()

def set_random_seed(seed=0):
  """Seed every random number generator the experiment touches.

  Seeds Python's ``random`` module, NumPy's global generator and PyTorch,
  plus the current CUDA device when CUDA is available. Then switches cuDNN
  to deterministic, non-benchmarking mode.

  Args:
    seed: integer seed applied to every generator (default 0).
  """
  for seeder in (random.seed, np.random.seed, torch.manual_seed):
    seeder(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

  # cuDNN autotuning may pick different kernels from run to run; pin it down.
  cudnn = torch.backends.cudnn
  cudnn.deterministic = True
  cudnn.benchmark = False

class Multi_Armed_Bandit:
  """Four-armed Bernoulli bandit whose first arm ("blue") is the best one.

  The best arm pays 1 with probability ``0.5 + delta / 2``; every other arm
  pays 1 with probability exactly ``delta`` lower. Each pull is logged in
  ``sample_list`` as a ``(color, reward)`` tuple.
  """

  def __init__(self, delta):
    self.delta = delta
    # The best/other gap equals delta and is centred on 0.5.
    self.optima = 0.5 + self.delta / 2
    self.suboptima = self.optima - self.delta
    self.button_color = ["blue", "green", "red", "yellow"]
    self.bandit = self.bernoulli()
    self.sample_list = []

  def bernoulli(self):
    """Return one frozen ``scipy.stats.bernoulli`` per color; index 0 is optimal."""
    arms = {}
    for position, color in enumerate(self.button_color):
      success_prob = self.suboptima if position else self.optima
      arms[color] = bernoulli(success_prob)
    return arms

  def sample(self, color):
    """Pull the arm ``color``, log the outcome and return the 0/1 reward."""
    outcome = self.bandit[color].rvs()
    self.sample_list.append((color, outcome))
    return outcome

  def count_and_mean(self):
    """Summarise the pull history per arm.

    Returns:
      dict mapping every color to ``[pull_count, mean_reward]``. Arms that
      were never pulled report ``[0, 0]``.
    """
    totals = {color: [0, 0] for color in self.button_color}
    for color, reward in self.sample_list:
      entry = totals[color]
      entry[0] += 1
      entry[1] += reward

    for entry in totals.values():
      if entry[0]:
        entry[1] /= entry[0]

    return totals

class Network(nn.Module):
  """Fully connected ReLU regressor that maps ``input_dim`` features to a scalar.

  Args:
    input_dim: size of the last axis of the input tensor.
    hidden_size: width of every hidden layer.
    depth: number of hidden layers; values below 1 behave like 1.
    init_params: optional flat sequence ``[W0, b0, W1, b1, ...]`` of tensors
      that are assigned directly to the layers. When omitted, every weight
      and bias is drawn from N(0, 1).
  """

  def __init__(self, input_dim, hidden_size=100, depth=1, init_params=None):
    super().__init__()

    self.activate = nn.ReLU()
    # input -> hidden, (depth - 1) x hidden -> hidden, hidden -> 1
    widths = [input_dim] + [hidden_size] * max(depth, 1) + [1]
    self.layer_list = nn.ModuleList(
      nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
    )

    if init_params is not None:
      for idx, layer in enumerate(self.layer_list):
        layer.weight.data = init_params[2 * idx]
        layer.bias.data = init_params[2 * idx + 1]
    else:
      for layer in self.layer_list:
        for tensor in (layer.weight, layer.bias):
          torch.nn.init.normal_(tensor, mean=0, std=1.0)

  def forward(self, x):
    """Apply ReLU after every layer but the last; drop the trailing unit axis."""
    *hidden_layers, head = self.layer_list
    for layer in hidden_layers:
      x = self.activate(layer(x))
    return head(x).squeeze(-1)

class Template_Padding:
  """Builds the system and user prompts shown to the bandit-playing LLM.

  Owns the ``Multi_Armed_Bandit`` being played and renders its per-arm pull
  summary into the user prompt (``Input_USER``).

  Args:
    delta: arm gap forwarded to ``Multi_Armed_Bandit``.
  """

  def __init__(self, delta):
    self.Multi_Armed_Bandit = Multi_Armed_Bandit(delta)
    self.SYSTEM_template = """{}You must provide your final answer within the tags <Answer>DIST</Answer> where DIST is the distribution in the format specified above.\n"""
    self.SYSTEM = ""
    self.exemplar_template_not_none = """{} button: pressed {} times with average reward {}\n"""
    self.exemplar_template_none = """{} button: pressed {} times\n"""
    self.input_exemplars = None
    self.USER = """So far you have played {} times with your past choices and rewards summarized as follows:

{}
Which button will you choose next? Remember, YOU MUST provide your final answer within the tags <Answer>DIST</Answer> where DIST is formatted like "blue:a,green:b,red:c,yellow:d"."""
    self.Input_USER = ""
    self.padding_exemplar()

  def _summarize(self, start_index=0):
    """Render one summary line per arm, starting at ``start_index`` and wrapping.

    ``start_index`` of ``None`` is treated as 0, so a caller that has not
    picked an ordering yet still gets a valid prompt.
    """
    colors = self.Multi_Armed_Bandit.button_color
    summary_dict = self.Multi_Armed_Bandit.count_and_mean()
    offset = 0 if start_index is None else start_index
    lines = []
    for i in range(len(colors)):
      color = colors[(i + offset) % len(colors)]
      count, mean = summary_dict[color]
      if count == 0:
        lines.append(self.exemplar_template_none.format(color, 0))
      else:
        lines.append(self.exemplar_template_not_none.format(color, count, mean))
    return "".join(lines)

  def padding_exemplar(self, start_index=0, step=0):
    """Rebuild ``input_exemplars`` and ``Input_USER`` from the current history.

    Args:
      start_index: which arm is listed first (the order rotates from there).
      step: number of rounds played so far, shown in the prompt.
    """
    self.input_exemplars = self._summarize(start_index)
    self.Input_USER = self.USER.format(step, self.input_exemplars)

  def padding_shuffle(self, start_index=0, step=0):
    """Return a user prompt with the given arm rotation, without storing it."""
    return self.USER.format(step, self._summarize(start_index))

  def padding_system(self, z1_z2):
    """Set ``SYSTEM`` to the meta-prompt ``z1_z2`` followed by the answer-format rule."""
    self.SYSTEM = self.SYSTEM_template.format(z1_z2)

class RequestLLM:
  """Play a multi-armed bandit with a chat LLM and track episode metrics.

  ``Prompt_LLM_Epoch`` does one round:
  - asks the model for a probability distribution over the four buttons;
  - samples one button from that distribution;
  - pulls it on the bandit owned by ``self.Template``;
  - appends the running statistics kept on this object.

  The ``calculate_*`` helpers summarise the pull history afterwards.

  Args:
    delta: arm gap forwarded to ``Template_Padding``.
    model: chat-completion model name.
    steps: horizon T used by ``calculate_suff_fail_freq``.
  """

  # "color:number" pairs such as "blue:0.25" in the model reply; compiled once.
  _DIST_PATTERN = re.compile(r'\b(blue|green|red|yellow):([0-9]+(?:\.[0-9]*)?)\b')
  _COLORS_ORDER = ['blue', 'green', 'red', 'yellow']

  def __init__(self, delta, model='gpt-4-turbo', steps=100):
    self.API_key = '[API_KEY]'
    self.model = model
    self.steps = steps
    self.Template = Template_Padding(delta)
    self.stable_score = 0
    self.Suffix_Failure_Frequency = []
    self.MinFrac_t = []
    self.Time_Averaged_Reward = []
    self.Time_Mu = []
    self.Time_Averaged_Mu = []
    self.cumulative_reward_sum = 0
    self.cumulative_mu_sum = 0
    self.fract_best_count = 0
    self.frac_best = []
    self.regret = []

  def Prompt_LLM_Epoch(self, step, max_retries=20):
    """Run one LLM-driven bandit round and return its expected reward.

    Args:
      step: zero-based index of this round; ``step + 1`` rounds have been
        played once it finishes.
      max_retries: how many unusable replies to tolerate before giving up.
        ``None`` retries forever. With ``temperature=0`` an unusable reply
        tends to repeat, so an unbounded loop can burn API calls indefinitely.

    Returns:
      Expected reward of the LLM's distribution under the empirical per-arm
      means observed before this round's pull.

    Raises:
      RuntimeError: if ``max_retries`` replies in a row had no usable
        distribution.
    """
    openai.api_key = self.API_key
    print("[SYSTEM]" + self.Template.SYSTEM)
    print("[USER]" + self.Template.Input_USER + "\n")

    invalid_replies = 0
    while True:
      try:
        response = openai.ChatCompletion.create(
          model=self.model,
          messages=[
            {"role": "system", "content": self.Template.SYSTEM},
            {"role": "user", "content": self.Template.Input_USER}
          ],
          temperature=0
        )
      except openai.error.RateLimitError:
        # Transient: wait and ask again. Not counted against max_retries
        # because no reply was received.
        print("Rate limit exceeded. Waiting before retrying...")
        time.sleep(10)
        continue

      response_content = response.choices[0].message['content'].strip()
      print(response_content)

      color_prob_dict = self._parse_distribution(response_content)
      if color_prob_dict is not None:
        return self._play_round(step, color_prob_dict)

      print("Invalid distribution. Please try again.")
      invalid_replies += 1
      if max_retries is not None and invalid_replies >= max_retries:
        raise RuntimeError(
          f"No valid button distribution after {invalid_replies} LLM replies.")

  def _parse_distribution(self, response_content):
    """Extract a normalised ``{color: probability}`` dict from a model reply.

    Colors missing from the reply get weight 0. If a color appears more than
    once, the last value wins. Returns ``None`` when the weights cannot be
    normalised to 1 (no positive weight, or a non-finite total).
    """
    color_prob_dict = {color: 0.0 for color in self._COLORS_ORDER}
    for color, probability in self._DIST_PATTERN.findall(response_content):
      color_prob_dict[color] = float(probability)

    total_prob_sum = sum(color_prob_dict.values())
    if total_prob_sum > 0:
      color_prob_dict = {color: prob / total_prob_sum for color, prob in color_prob_dict.items()}
    if not np.isclose(sum(color_prob_dict.values()), 1.0):
      return None
    return color_prob_dict

  def _play_round(self, step, color_prob_dict):
    """Sample a button from ``color_prob_dict``, pull it and update all metrics."""
    mab = self.Template.Multi_Armed_Bandit
    colors = list(color_prob_dict)
    probabilities = list(color_prob_dict.values())
    index = np.random.choice(len(probabilities), p=probabilities)

    # Empirical means from the history *before* this round's pull.
    color_stats = mab.count_and_mean()
    expected_reward = sum(color_prob_dict[color] * color_stats[color][1] for color in mab.button_color)

    selected_color = colors[index]
    direct_reward = mab.sample(selected_color)
    print(f"The selected button : {selected_color}, reward: {direct_reward}")

    # Multi_Armed_Bandit gives its first color the optimal success probability.
    chose_best = selected_color == mab.button_color[0]
    mu = mab.optima if chose_best else mab.suboptima
    self.Time_Mu.append(mu)

    # Pass step by keyword: the first positional parameter is start_index.
    self.Template.padding_exemplar(step=step + 1)
    self.cumulative_reward_sum += direct_reward
    self.cumulative_mu_sum += mu
    self.fract_best_count += 1 if chose_best else 0
    # Per-round regret is the gap between the best arm's mean and the chosen arm's.
    self.regret.append(0 if chose_best else mab.optima - mab.suboptima)

    rounds_played = step + 1
    time_averaged_reward = self.cumulative_reward_sum / rounds_played
    time_averaged_mu = self.cumulative_mu_sum / rounds_played
    print(f"Averaged reward : {time_averaged_reward}")
    print(f"Averaged mu : {time_averaged_mu}")
    accumulative_regret = sum(np.array(self.regret))
    print(f"Accumulative Regret: {accumulative_regret}")
    self.Time_Averaged_Reward.append(time_averaged_reward)
    self.Time_Averaged_Mu.append(time_averaged_mu)
    self.frac_best.append(self.fract_best_count / rounds_played)

    print(f"\nExpected reward based on distribution: {expected_reward}\n")

    return expected_reward

  def calculate_median_reward(self):
    """Print and return the median of all rewards observed so far."""
    rewards = [reward for _, reward in self.Template.Multi_Armed_Bandit.sample_list]
    median_reward = np.median(rewards)
    print(f"MedianReward : {median_reward}")
    return median_reward

  def calculate_suff_fail_freq(self, optimal_button):
    """Record the fraction of the first T/2 pulls that were not ``optimal_button``.

    NOTE(review): the label says SuffFailFreq(T/2). This counts non-optimal
    pulls in the prefix ``[0, T/2)`` rather than checking the suffix
    ``[T/2, T]`` for a complete absence of the optimal arm. Confirm which
    metric is intended before relying on it.
    """
    half_horizon = self.steps / 2
    prefix = self.Template.Multi_Armed_Bandit.sample_list[:self.steps // 2]
    failures = sum(1 for color, _ in prefix if color != optimal_button)
    frequency = failures / half_horizon
    print(f"\nSuffFailFreq(T/2) : {frequency}")
    self.Suffix_Failure_Frequency.append(frequency)

  def calculate_k_min_frac(self):
    """Record K times the smallest per-arm pull fraction (K = number of arms)."""
    mab = self.Template.Multi_Armed_Bandit
    color_counts = {color: 0 for color in mab.button_color}
    for color, _ in mab.sample_list:
      color_counts[color] += 1
    # With no pulls the fraction is undefined; report 0 instead of dividing by zero.
    min_frac = min(color_counts.values()) / len(mab.sample_list) if mab.sample_list else 0.0
    k_min_frac = len(mab.button_color) * min_frac
    print(f"K*MinFrac : {k_min_frac}")
    self.MinFrac_t.append(k_min_frac)

  def calculate_greedy_frac(self):
    """Return the fraction of pulls spent on the arm with the best empirical mean.

    Ties go to the arm listed first in ``button_color``. Returns 0.0 when
    nothing has been pulled yet.
    """
    mab = self.Template.Multi_Armed_Bandit
    if not mab.sample_list:
      print("GreedyFrac : 0.0")
      return 0.0

    color_rewards = {color: 0 for color in mab.button_color}
    color_counts = {color: 0 for color in mab.button_color}
    for color, reward in mab.sample_list:
      color_rewards[color] += reward
      color_counts[color] += 1

    avg_rewards = {color: (color_rewards[color] / color_counts[color] if color_counts[color] > 0 else 0)
                   for color in mab.button_color}
    best_color = max(avg_rewards, key=avg_rewards.get)

    greedy_frac = color_counts[best_color] / len(mab.sample_list)
    print(f"GreedyFrac : {greedy_frac}")
    return greedy_frac

class PromptEmbedding:
  """Embed prompts through the OpenAI embeddings endpoint and run bandit rounds.

  Args:
    z1_z2_prompt_data: DataFrame of candidate meta-prompts; must contain a
      'Z1-Z2-Embedding' column.
    delta: arm gap forwarded to the owned ``RequestLLM``.
    Optimizer_Model: chat model the owned ``RequestLLM`` plays the bandit with.
    Embedding_model: embedding model name sent to the API.
  """

  def __init__(self, z1_z2_prompt_data, delta, Optimizer_Model='gpt-4-turbo', Embedding_model='text-embedding-3-large'):
    self.API_key = '[API_KEY]'
    self.z1_z2_prompt = z1_z2_prompt_data
    self.Embedding_model = Embedding_model
    # Play the bandit with the chat model the caller asked for.
    self.RequestLLM = RequestLLM(delta, model=Optimizer_Model, steps=1)
    self.Prompt_Embedding_List = []

  def normalize_l2(self, x):
    """Return ``x`` as a NumPy array scaled to unit L2 norm (zero vectors unchanged)."""
    x = np.array(x)
    norm = np.linalg.norm(x)
    if norm == 0:
      return x
    return x / norm

  def Get_Single_Embedding(self, text, target_dim=3072):
    """Embed ``text`` synchronously and return it truncated and L2-normalised.

    Newlines are replaced by spaces before the request. Every failure is
    retried forever: rate limits after 10 s, any other error after 5 s.
    """
    openai.api_key = self.API_key
    text = text.replace("\n", " ")

    while True:
      try:
        # Random pause before each request, presumably to spread out API load.
        time.sleep(random.uniform(0.5, 1.5))

        response = openai.Embedding.create(
          input=[text],
          model=self.Embedding_model
        )
        embedding = response['data'][0]['embedding']
        truncated_embedding = embedding[:target_dim]
        return list(self.normalize_l2(truncated_embedding))

      except openai.error.RateLimitError:
        print("Rate limit exceeded. Waiting before retrying...")
        time.sleep(10)

      except Exception as e:
        print(f"An error occurred: {e}")
        print("Retrying...")
        time.sleep(5)

  def Get_Batch_Embedding(self, exemplars_embedding, column='PromptEmbedding'):
    """Copy every candidate's 'Z1-Z2-Embedding' into ``column``.

    ``exemplars_embedding`` is accepted for interface compatibility but is not
    used here. The copy is shallow: cells still reference the same lists.
    """
    self.z1_z2_prompt[column] = self.z1_z2_prompt['Z1-Z2-Embedding'].copy()

  async def fetch_embedding(self, session, text, target_dim=3072):
    """Embed ``text`` over ``session`` with up to 5 attempts.

    Waits 10 s after HTTP 429, and 5 s after any other failure, before the
    next attempt.

    Raises:
      RuntimeError: if all attempts fail.
    """
    url = "https://api.openai.com/v1/embeddings"
    headers = {
      "Authorization": f"Bearer {self.API_key}",
      "Content-Type": "application/json"
    }
    data = {
      "input": [text],
      "model": self.Embedding_model
    }
    for _ in range(5):
      try:
        async with session.post(url, json=data, headers=headers) as response:
          if response.status == 200:
            result = await response.json()
            embedding = result['data'][0]['embedding']
            return list(self.normalize_l2(embedding[:target_dim]))
          if response.status == 429:
            print("Rate limit exceeded. Retrying...")
            delay = 10
          else:
            print(f"Error: {await response.text()}")
            # Back off instead of immediately hammering a failing endpoint.
            delay = 5
      except Exception as e:
        print(f"Error occurred: {e}. Retrying...")
        delay = 5
      await asyncio.sleep(delay)

    raise RuntimeError("Failed to fetch embedding after multiple retries.")

  async def get_batch_embedding_async(self, step, column='PromptEmbedding', batch_size=4):
    """Concurrently embed ``batch_size`` user prompts with rotated arm order.

    Prompt ``i`` lists the arms starting from index ``i``. The results are
    stored in ``Prompt_Embedding_List`` in that order. ``column`` is unused.
    """
    self.Prompt_Embedding_List = []
    texts = [self.RequestLLM.Template.padding_shuffle(i, step) for i in range(batch_size)]

    async with ClientSession() as session:
      tasks = [self.fetch_embedding(session, text) for text in texts]
      self.Prompt_Embedding_List = await async_tqdm.gather(*tasks, desc="Embedding prompts", unit="prompt")

  def Get_Batch_Embedding_Z3(self, step, column='PromptEmbedding'):
    """Synchronous wrapper around ``get_batch_embedding_async``."""
    asyncio.run(self.get_batch_embedding_async(step, column))

  def forward(self, i, z1_z2, z1, z2, z1_z2_embedding):
    """Run bandit round ``i`` with meta-prompt ``z1_z2``.

    Steps:
    1. Embed the current user prompt (the exemplar summary).
    2. Refresh the candidate and rotated-prompt embeddings.
    3. Play one round.

    ``z1`` and ``z2`` are accepted but unused.

    Returns:
      ``[z1_z2_embedding + exemplars_embedding, reward]``. The first element
      is the list concatenation of the two embeddings.
    """
    exemplars_embedding = self.Get_Single_Embedding(self.RequestLLM.Template.Input_USER)
    self.Get_Batch_Embedding(exemplars_embedding)
    self.Get_Batch_Embedding_Z3(i)
    self.RequestLLM.Template.padding_system(z1_z2)
    reward = self.RequestLLM.Prompt_LLM_Epoch(i)
    completed_embedding = z1_z2_embedding + exemplars_embedding
    return [completed_embedding, reward]

class AdversarialBandit:
  """EXP3-style search over meta-prompts (Z1-Z2) and exemplar orderings (Z3).

  Each optimisation step does the following:
  1. Plays one bandit round through ``PromptEmbedding.forward``.
  2. Refits two regressors on the (embedding, reward) trace: ``model`` on the
     leading ``input_dim`` features and ``model_z3`` on the trailing
     ``z3_dim`` features.
  3. Samples the next prompt and the next exemplar ordering from
     exponential weights over the cumulative predicted scores.

  Args:
    z1_z2_prompt_data: DataFrame of candidate prompts. Must contain the
      'Merged_Content', 'Z1', 'Z2' and 'Z1-Z2-Embedding' columns.
    delta: arm gap of the underlying bandit.
    input_dim: feature size consumed by the prompt-scoring network.
    steps: number of optimisation steps.
    NumberofK: stored but not used inside this class.
  """

  def __init__(self, z1_z2_prompt_data, delta, input_dim=6144, steps=100, NumberofK=10201):
    self.z1_z2_prompt_data = z1_z2_prompt_data
    self.input_dim = input_dim
    self.steps = steps
    self.NumberofK = NumberofK
    self.device = torch.device("cuda")
    # Width of one exemplar embedding (the default target_dim used by PromptEmbedding).
    self.z3_dim = 3072

    self.model = Network(self.input_dim, 1536, 1).to(self.device).double()
    self.model_z3 = Network(self.z3_dim, 512, 1).to(self.device).double()

    self.best_result_list = []

    self.lr = 1e-2
    self.epochs = 6000
    self.optimizer_fn = torch.optim.Adam
    self.optimizer_fn_z3 = torch.optim.Adam
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1)
    self.optimizer_z3 = self.optimizer_fn_z3(self.model_z3.parameters(), lr=self.lr, weight_decay=1)
    self.loss_fn = nn.MSELoss()
    self.loss_fn_z3 = nn.MSELoss()

    # Snapshots of the initial weights so every refit starts from scratch.
    self.init_model_weight = copy.deepcopy(self.model.state_dict())
    self.init_model_weight_z3 = copy.deepcopy(self.model_z3.state_dict())

    self.eta = 10
    self.eta_z3 = 10
    self.prompt_score_trace = []
    self.prompt_score_trace_z3 = []
    self.embedding_score_trace = []

    self.z1_z2 = """You are a bandit algorithm in a room with 4 buttons labeled blue, green, red, yellow. Each button is associated with a Bernoulli distribution with a fixed but unknown mean; the means for the buttons could be different. For each button, when you press it, you will get a reward that is sampled from the button's associated distribution. You have 100 time steps and, on each time step, you can choose any button and receive the reward. Your goal is to maximize the total reward over the 100 time steps.
At each time step, I will show you a summary of your past choices and rewards. Then you must make the next choice. You may output a distribution over the 4 buttons formatted EXACTLY like "blue:a,green:b,red:c,yellow:d". Let's think step by step to make sure we make a good choice.\n"""

    self.z1 = """You are a bandit algorithm in a room with 4 buttons labeled blue, green, red, yellow. Each button is associated with a Bernoulli distribution with a fixed but unknown mean; the means for the buttons could be different. For each button, when you press it, you will get a reward that is sampled from the button's associated distribution. You have 100 time steps and, on each time step, you can choose any button and receive the reward. Your goal is to maximize the total reward over the 100 time steps."""
    self.z2 = """At each time step, I will show you a summary of your past choices and rewards. Then you must make the next choice. You may output a distribution over the 4 buttons formatted EXACTLY like "blue:a,green:b,red:c,yellow:d". Let's think step by step to make sure we make a good choice."""

    self.prompt_embedding = PromptEmbedding(z1_z2_prompt_data, delta, Optimizer_Model='gpt-4-turbo', Embedding_model='text-embedding-3-large')

    self.z1_z2_Embedding = self.prompt_embedding.Get_Single_Embedding(self.z1) + self.prompt_embedding.Get_Single_Embedding(self.z2)

    self.scheduler = None
    self.scheduler_z3 = None

    self.Stop_EXP3 = 0
    self.Stop_EXP3_z3 = 0

    # None until the first valid sample from the corresponding distribution.
    self.best_index = None
    self.best_z3_index = None

  @staticmethod
  def _step_decay(optimizer):
    """LR schedule that multiplies the base rate by 0.1 every 2000 epochs."""
    return torch.optim.lr_scheduler.LambdaLR(
      optimizer,
      lr_lambda=lambda epoch: 0.1 ** (epoch // 2000)
    )

  def restart_model(self):
    """Reset ``model`` to its initial weights with a fresh optimizer and scheduler.

    Weight decay shrinks as 1 / (number of observations).
    """
    # load_state_dict copies values into the parameters, so the snapshot is never aliased.
    self.model.load_state_dict(self.init_model_weight)
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1 / len(self.embedding_score_trace))
    self.scheduler = self._step_decay(self.optimizer)

  def restart_model_z3(self):
    """Reset ``model_z3`` to its initial weights with a fresh optimizer and scheduler."""
    self.model_z3.load_state_dict(self.init_model_weight_z3)
    self.optimizer_z3 = self.optimizer_fn_z3(self.model_z3.parameters(), lr=self.lr, weight_decay=1 / len(self.embedding_score_trace))
    self.scheduler_z3 = self._step_decay(self.optimizer_z3)

  def _fit(self, model, optimizer, scheduler, loss_fn, embeddings, desc, done_message, batch_size):
    """Regress ``model`` onto the observed rewards for ``self.epochs`` epochs."""
    scores = [item[1] for item in self.embedding_score_trace]

    embeddings_tensor = torch.stack([torch.tensor(e, dtype=torch.float64) for e in embeddings])
    scores_tensor = torch.tensor(scores, dtype=torch.float64)
    dataloader = DataLoader(TensorDataset(embeddings_tensor, scores_tensor), batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in tqdm(range(self.epochs), desc=desc):
      epoch_loss = 0
      for batch_embeddings, batch_scores in dataloader:
        batch_embeddings = batch_embeddings.to(self.device)
        batch_scores = batch_scores.to(self.device)

        optimizer.zero_grad()
        loss = loss_fn(model(batch_embeddings), batch_scores)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

      scheduler.step()

      if epoch % 100 == 0:
        print(f"Epoch {epoch}/{self.epochs}, Loss: {epoch_loss / len(dataloader)}")

    print(done_message)

  def train_model(self, batch_size=64):
    """Refit the prompt model on the leading ``input_dim`` features of each observation."""
    self.restart_model()
    embeddings = [item[0][:self.input_dim] for item in self.embedding_score_trace]
    self._fit(self.model, self.optimizer, self.scheduler, self.loss_fn,
              embeddings, "Training Model", "Training complete.", batch_size)

  def train_model_z3(self, batch_size=64):
    """Refit the Z3 model on the trailing ``z3_dim`` features (the exemplar embedding)."""
    self.restart_model_z3()
    embeddings = [item[0][-self.z3_dim:] for item in self.embedding_score_trace]
    self._fit(self.model_z3, self.optimizer_z3, self.scheduler_z3, self.loss_fn_z3,
              embeddings, "Training Z3 Model", "Training Z3 model complete.", batch_size)

  @staticmethod
  def _exp3_probabilities(score_trace, eta):
    """Return ``(softmax(eta * cumulative_scores), cumulative_scores)``.

    The max log-weight is subtracted before exponentiating for numerical stability.
    """
    score_tensor = torch.tensor(score_trace, dtype=torch.float64)
    cumulative = torch.sum(score_tensor, dim=0)
    log_weights = eta * cumulative
    weights = torch.exp(log_weights - torch.max(log_weights))
    return weights / torch.sum(weights), cumulative

  def select(self, iteration):
    """Score every candidate prompt, sample one by EXP3, and make it current.

    If the weights become non-finite, prompt search is frozen (``Stop_EXP3``)
    and the previously selected prompt is kept.
    """
    # Pure inference: no autograd graph needed for these forward passes.
    with torch.no_grad():
      self.z1_z2_prompt_data['Score'] = self.prompt_embedding.z1_z2_prompt['PromptEmbedding'].apply(
        lambda x: self.model(torch.tensor(x, dtype=torch.float64).to(self.device)).item()
      )
    self.prompt_score_trace.append(self.z1_z2_prompt_data['Score'].tolist())
    normalized_exp_scores, cumulative = self._exp3_probabilities(self.prompt_score_trace, self.eta)

    k = min(10, normalized_exp_scores.numel())
    top_scores, top_indices = torch.topk(normalized_exp_scores, k)
    print(f"Top {k} scores: ", top_scores)
    print(f"Indices of top {k} scores: ", top_indices)

    if torch.isfinite(normalized_exp_scores).all():
      self.best_index = torch.multinomial(normalized_exp_scores, 1).item()
    else:
      self.Stop_EXP3 = 1
      if self.best_index is None:
        print("No valid prompt distribution yet; keeping the current prompt.")
        return

    print("The index of selected prompt is: %d" % (self.best_index))
    print("The selected prompt's cumulative score: ", cumulative[self.best_index])
    row = self.z1_z2_prompt_data.iloc[self.best_index]
    self.z1_z2 = row['Merged_Content']
    self.z1 = row['Z1']
    self.z2 = row['Z2']
    self.z1_z2_Embedding = row['Z1-Z2-Embedding']

  def select_z3(self, iteration):
    """Score each rotated exemplar prompt and sample the next ordering by EXP3.

    If the weights become non-finite, ordering search is frozen (``Stop_EXP3_z3``).
    """
    with torch.no_grad():
      scores = [
        self.model_z3(torch.tensor(embedding, dtype=torch.float64).to(self.device).unsqueeze(0)).item()
        for embedding in self.prompt_embedding.Prompt_Embedding_List
      ]
    self.prompt_score_trace_z3.append(scores)

    normalized_exp_scores, _ = self._exp3_probabilities(self.prompt_score_trace_z3, self.eta_z3)

    k = min(4, normalized_exp_scores.numel())
    top_scores, top_indices = torch.topk(normalized_exp_scores, k)
    print(f"Top {k} scores: ", top_scores)
    print(f"Indices of top {k} scores: ", top_indices)

    if torch.isfinite(normalized_exp_scores).all():
      self.best_z3_index = torch.multinomial(normalized_exp_scores, 1).item()
      print(f"The selected order: {self.best_z3_index}")
    else:
      self.Stop_EXP3_z3 = 1

  def optimization(self):
    """Run ``self.steps`` rounds of play -> refit -> reselect."""
    template = self.prompt_embedding.RequestLLM.Template
    for i in tqdm(range(self.steps), desc="Optimizing the Meta-Prompt " + str(self.steps) + " step"):
      forward_results = self.prompt_embedding.forward(i, self.z1_z2, self.z1, self.z2, self.z1_z2_Embedding)
      self.embedding_score_trace.append(forward_results)
      if not self.Stop_EXP3:
        self.train_model()
        self.select(i)
      if not self.Stop_EXP3_z3:
        self.train_model_z3()
        self.select_z3(i)
      # Default arm order until an ordering has been sampled at least once.
      start_index = self.best_z3_index if self.best_z3_index is not None else 0
      template.padding_exemplar(start_index, i + 1)
      template.padding_system(self.z1_z2)

# Result containers. Only Time_Averaged_Reward and Regret are filled below;
# the others stay defined (empty) for any later cells that reference them.
Suffix_Failure_Frequency = []
MinFrac_t = []
Time_Mu = []
Time_Average_Mu = []

Time_Averaged_Reward = []  # one time-averaged-reward curve per run
Regret = []                # one per-round regret list per run

# Six runs: runs 0-2 use seed 999 and runs 3-5 use seed 1000. Runs sharing a
# seed differ only through any nondeterminism in the API calls.
for i in range(6):
  set_random_seed(999 if i < 3 else 1000)
  # Fresh copy so columns added during a run (e.g. 'Score') do not leak into the next.
  z1_z2_prompt_data = copy.deepcopy(Meta_Prompt_data)
  new_bandit = AdversarialBandit(z1_z2_prompt_data, 0.5)
  new_bandit.optimization()
  Time_Averaged_Reward.append(copy.deepcopy(new_bandit.prompt_embedding.RequestLLM.Time_Averaged_Reward))
  Regret.append(new_bandit.prompt_embedding.RequestLLM.regret)
  print(Time_Averaged_Reward)
  print(Regret)