import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
from tqdm.auto import tqdm
import time
import copy
import ast
import random
import scipy


import torch
from torch import nn
import openai
import asyncio
from aiohttp import ClientSession

import nest_asyncio
nest_asyncio.apply()

# Load the candidate meta-prompt table. "PATH/DATA_FILE.csv" is a placeholder
# path that must be replaced before running.
Meta_Prompt_data = pd.read_csv("PATH/DATA_FILE.csv")
# Register `progress_apply` on pandas objects so the parse below shows a bar.
tqdm.pandas()
# The 'Z1-Z2-Embedding' column is read back as text; parse each cell into a
# Python literal (presumably a list of floats -- confirm against the CSV).
Meta_Prompt_data['Z1-Z2-Embedding'] = Meta_Prompt_data['Z1-Z2-Embedding'].progress_apply(ast.literal_eval)
# Notebook leftover: the result is only displayed in an interactive session.
Meta_Prompt_data.head()

class OPRO4LinearRegressionData:
  """Synthetic 1-D linear-regression task ``y = w * x + b + noise``.

  The 50 inputs are evenly spaced on [-1, 1]. The noise is standard normal
  scaled by 1/10 and is drawn right after seeding NumPy's global RNG with
  3407, so two instances built with the same ``w`` and ``b`` hold the same
  data.

  Attributes:
    true_parameter: tensor ``[w, b]`` used to generate the data.
    whole_data: (50, 2) tensor whose columns are ``x`` and ``y``.
    data: (50, 2) design matrix with columns ``x`` and a constant 1.
    label: (50,) tensor of targets ``y``.
  """

  def __init__(self, w, b):
    # NOTE: this reseeds NumPy's *global* generator as a side effect.
    np.random.seed(3407)
    inputs = np.linspace(-1, 1, 50).reshape(-1, 1)
    targets = inputs * w + b + np.random.randn(50, 1) / 10

    self.true_parameter = torch.tensor([w, b])
    self.whole_data = torch.tensor(np.hstack((inputs, targets)))

    x_column = self.whole_data[:, 0]
    self.label = self.whole_data[:, 1]
    # A constant column lets the bias be handled as a second weight.
    self.data = torch.stack([x_column, torch.ones_like(x_column)], dim=1)

  def evaluator(self, generated_w, generated_b):
    """Return the mean squared error of the line ``generated_w * x + generated_b``.

    The result is a 0-dim double tensor; call ``.item()`` for a Python float.
    """
    candidate = torch.tensor([generated_w, generated_b], dtype=torch.double)
    residual = torch.matmul(self.data, candidate) - self.label
    return torch.mean(residual ** 2)

  def mat_for_data(self):
    """Show a scatter plot of the generated ``(x, y)`` points."""
    xs, ys = self.whole_data[:, 0], self.whole_data[:, 1]
    plt.scatter(xs, ys, color='blue', label='x1 vs y')
    plt.xlabel('x1')
    plt.ylabel('y')
    plt.legend()
    plt.title('Random 2D Linear Regression Data')
    plt.show()

# Ground-truth task shared by every run below: slope 2, intercept 30.
New_Linear_Regression = OPRO4LinearRegressionData(w=2, b=30)
# Visual sanity check of the generated points.
New_Linear_Regression.mat_for_data()

class Network(nn.Module):
  """Fully connected ReLU network producing one scalar per input row.

  Layout: ``input_dim -> hidden_size``, then ``depth - 1`` further
  ``hidden_size -> hidden_size`` layers, then ``hidden_size -> 1``. ReLU is
  applied after every layer except the last one.

  Args:
    input_dim: size of each input vector.
    hidden_size: width of the hidden layers.
    depth: number of hidden layers (values below 1 behave like 1).
    init_params: optional flat sequence ``[W0, b0, W1, b1, ...]`` of tensors
      assigned to the layers in order. When ``None`` every weight and bias is
      drawn from a standard normal distribution.
  """

  def __init__(self, input_dim, hidden_size=100, depth=1, init_params=None):
    super().__init__()

    self.activate = nn.ReLU()

    # Consecutive width pairs describe each linear layer, input to output.
    widths = [input_dim] + [hidden_size] * max(depth, 1) + [1]
    self.layer_list = nn.ModuleList(
      [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])]
    )

    if init_params is not None:
      for position, layer in enumerate(self.layer_list):
        layer.weight.data = init_params[2 * position]
        layer.bias.data = init_params[2 * position + 1]
    else:
      for layer in self.layer_list:
        torch.nn.init.normal_(layer.weight, mean=0, std=1.0)
        torch.nn.init.normal_(layer.bias, mean=0, std=1.0)

  def forward(self, x):
    """Map ``x`` of shape ``(..., input_dim)`` to predictions of shape ``(...)``."""
    *hidden_layers, output_layer = self.layer_list
    out = x
    for layer in hidden_layers:
      out = self.activate(layer(out))
    # Drop the trailing size-1 feature dimension.
    return output_layer(out).squeeze(-1)

class Template_Padding:
  """Keeps the pool of evaluated ``(w, b, value)`` samples and renders
  subsets of it as exemplar text for the meta-prompt.

  Attributes:
    samples: every evaluated ``(w, b, value)`` tuple; sorted by descending
      value whenever ``Generate_Subset`` runs, so the best sample is last.
    subset_list: candidate exemplar subsets built by ``Generate_Subset``.
    selected_samples: the subset currently rendered into ``input_exemplars``.
    mask: 0/1 array (one row per candidate subset) over the de-duplicated,
      sorted samples; ``None`` until more than ``SUBSET_SIZE`` samples exist.
    input_exemplars: exemplar text inserted into the meta-prompt.
  """

  # Number of exemplars placed in one prompt.
  SUBSET_SIZE = 20
  # Only the best CANDIDATE_POOL samples are eligible for a subset.
  CANDIDATE_POOL = 30
  # Number of quasi-random candidate subsets (Sobol wants a power of two).
  NUM_MASKS = 256

  def __init__(self, Linear_Regression):
    """Seed the pool with five random samples and build the first subset.

    Args:
      Linear_Regression: object exposing ``evaluator(w, b)`` whose result
        has an ``.item()`` method returning the objective value.
    """
    self.exemplars_template = """
    input:
    w={}, b={}
    value:
    {}
    """
    self.Linear_Regression = Linear_Regression
    np.random.seed(1001)  # NOTE: reseeds NumPy's global generator

    w_samples = np.random.uniform(10, 20, 5)
    b_samples = np.random.uniform(10, 20, 5)
    samples_evaluator = [
      self.Linear_Regression.evaluator(w, b).item()
      for w, b in zip(w_samples, b_samples)
    ]
    self.samples = list(zip(w_samples, b_samples, samples_evaluator))
    self.samples.sort(key=lambda x: x[2], reverse=True)

    # Defaults must be set BEFORE Generate_Subset(); assigning them afterwards
    # would throw away the subset it has just built.
    self.input_exemplars = ""
    self.subset_list = []
    self.selected_samples = None
    self.mask = None
    self.Generate_Subset()

  def Add_Padding(self):
    """Render ``selected_samples`` (worst first, best last) into ``input_exemplars``."""
    self.selected_samples.sort(key=lambda x: x[2], reverse=True)
    self.input_exemplars = self.Padding(self.selected_samples)

  def Padding(self, z3):
    """Return the exemplar text for the samples in ``z3`` without touching state."""
    return "".join(
      self.exemplars_template.format(sample[0], sample[1], sample[2])
      for sample in z3
    )

  def Add_Sample(self, generated_w, generated_b):
    """Evaluate ``(generated_w, generated_b)`` and append it to the pool."""
    value = self.Linear_Regression.evaluator(generated_w, generated_b).item()
    self.samples.append((generated_w, generated_b, value))

  def Add_All(self, Meta_Prompt):
    """Fill the ``{}`` slot of ``Meta_Prompt`` with the exemplars and append
    the fixed output instruction."""
    return Meta_Prompt.format(self.input_exemplars) + "\nPlease return only the revised version of the following content without adding any additional information or explanation."

  def generate_01_low_discrepancy_sequences(self, num_sequences, dim, num_ones):
    """Build 0/1 selection masks from a scrambled Sobol sequence.

    Each of the first ``num_sequences`` rows always selects the last position
    and additionally the ``num_ones - 1`` positions (among the first
    ``dim - 1``) with the smallest Sobol coordinates. One extra row selecting
    the last ``num_ones`` positions is appended.

    Returns:
      Float array of shape ``(num_sequences + 1, dim)`` holding 0s and 1s.
    """
    # Local import: a bare `import scipy` does not guarantee that the
    # `scipy.stats` submodule is loaded on every SciPy version.
    from scipy.stats import qmc

    sampler = qmc.Sobol(d=dim-1, scramble=True)
    samples = sampler.random(n=num_sequences)

    binary_sequences = np.zeros((num_sequences + 1, dim))
    for row, sample in zip(binary_sequences, samples):
      row[np.argsort(sample)[:num_ones-1]] = 1
      row[-1] = 1

    # Deterministic fallback mask: simply the last `num_ones` positions.
    # (Previously hard-coded to 20 regardless of `num_ones`.)
    binary_sequences[-1, dim - min(num_ones, dim):] = 1

    return binary_sequences

  def get_subset(self, samples):
    """Populate ``mask`` and ``subset_list`` with candidate subsets of ``samples``.

    ``samples`` must already be sorted with the best sample last. Only its
    last ``CANDIDATE_POOL`` entries can be selected; earlier entries are
    always masked out.
    """
    full_length = len(samples)
    num_to_select = min(full_length, self.CANDIDATE_POOL)
    binary_sequences = self.generate_01_low_discrepancy_sequences(
      self.NUM_MASKS, num_to_select, self.SUBSET_SIZE
    )

    # Left-pad every mask with zeros so it indexes the full sample list.
    self.mask = np.zeros((len(binary_sequences), full_length), dtype=int)
    self.mask[:, full_length - num_to_select:] = binary_sequences

    self.subset_list = [
      [samples[index] for index in np.flatnonzero(row)]
      for row in self.mask
    ]
    return

  def Generate_Subset(self):
    """Re-sort the pool, report the best sample and rebuild ``subset_list``.

    With at most ``SUBSET_SIZE`` distinct samples the single subset is the
    whole pool and ``input_exemplars`` is refreshed immediately. Otherwise
    candidate subsets are generated and the caller is expected to choose
    one, assign it to ``selected_samples`` and call ``Add_Padding``.
    """
    self.samples.sort(key=lambda x: x[2], reverse=True)
    # De-duplicate; the tuples are immutable so no deep copy is required.
    sample_set = sorted(set(self.samples), key=lambda x: x[2], reverse=True)
    print("The Best Performance, w=%f, b=%f, value=%f"%(sample_set[-1][0], sample_set[-1][1], sample_set[-1][2]))
    if len(sample_set) <= self.SUBSET_SIZE:
      self.subset_list = [sample_set]
      self.selected_samples = sample_set
      self.Add_Padding()
    else:
      self.get_subset(sample_set)

class RequestLLM:
  """Asks an OpenAI chat model for new ``(w, b)`` candidates and records them.

  Attributes:
    Template: ``Template_Padding`` holding the sample pool and exemplar text.
    epoch: number of concurrent requests per sampling step.
    steps: number of sampling steps per ``Prompt_LLM_Step`` call.
    stable_score: objective value of the temperature-0 answer of the last
      ``Prompt_LLM_Zero_Temp`` call (a 0-dim tensor when the evaluator
      returns one).
  """

  # Base delay (seconds) before retrying after an unusable API response; it
  # doubles on each consecutive failure up to MAX_BACKOFF_SECONDS.
  RETRY_BACKOFF_SECONDS = 1.0
  MAX_BACKOFF_SECONDS = 30.0

  # Accepts "w=<num>, b=<num>" (groups 0 and 2) or "[<num>, <num>]"
  # (groups 4 and 6). Compiled once instead of on every attempt.
  _PAIR_PATTERN = re.compile(
    r"w=([+-]?\d+(\.\d+)?),?\s*b=([+-]?\d+(\.\d+)?)|\[\s*([+-]?\d+(\.\d+)?),\s*([+-]?\d+(\.\d+)?)\s*\]"
  )

  def __init__(self, New_Linear_Regression, model='gpt-3.5-turbo', steps=100, epoch=7):
    # Placeholder credential: must be replaced with a real key before use.
    self.API_key = "[API_KEY]"
    self.model = model
    self.steps = steps
    self.epoch = epoch
    self.Template = Template_Padding(New_Linear_Regression)
    self.latest_batch_score = 0
    self.stable_score = 0

  async def fetch_response(self, session, Meta_Prompt, temperature):
    """POST one chat-completion request.

    Retries up to five times on HTTP 429 (exponential wait) and on transport
    errors (5 s wait).

    Returns:
      The decoded JSON body, or ``None`` on any other HTTP error or once the
      retries are exhausted.
    """
    prompt = self.Template.Add_All(Meta_Prompt)
    if temperature==0:
      print(prompt)
    url = "https://api.openai.com/v1/chat/completions"
    payload = {
      "model": self.model,
      "messages": [{"role": "user", "content": prompt}],
      "temperature": temperature,
    }
    headers = {
      "Authorization": f"Bearer {self.API_key}",
      "Content-Type": "application/json",
    }

    for retry in range(5):
      try:
        async with session.post(url, json=payload, headers=headers) as response:
          if response.status == 429:
            wait_time = 2 ** retry
            print(f"Rate limit exceeded. Retrying in {wait_time} seconds...")
            await asyncio.sleep(wait_time)
            continue
          if response.status != 200:
            print(f"Error: {response.status}, {await response.text()}")
            return None
          return await response.json()
      except Exception as e:
        print(f"Request error: {e}")
        await asyncio.sleep(5)
    print("Max retries exceeded.")
    return None

  @staticmethod
  def _extract_content(response):
    """Return the assistant text of a chat-completion body, or ``None`` when
    the body is missing, malformed or carries no text."""
    try:
      content = response["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError):
      return None
    return content if isinstance(content, str) else None

  async def process_single_epoch(self, Meta_Prompt, session, temperature=1):
    """Query the model until its answer contains a parsable ``(w, b)`` pair.

    Unusable API responses are retried after a capped exponential backoff so
    a persistent failure (bad key, outage) does not turn into a tight request
    loop. Answers without a pair are re-requested immediately.

    Returns:
      ``(w, b)`` as floats, parsed from the LAST pair in the answer.
    """
    failures = 0
    while True:
      response = await self.fetch_response(session, Meta_Prompt, temperature)
      content = self._extract_content(response)
      if content is None:
        print("Invalid or empty response from API. Retrying...")
        delay = min(self.RETRY_BACKOFF_SECONDS * 2 ** min(failures, 10), self.MAX_BACKOFF_SECONDS)
        failures += 1
        await asyncio.sleep(delay)
        continue
      failures = 0

      content = content.strip()
      print(f"Generated content (Temp={temperature}): {content}")

      matches = self._PAIR_PATTERN.findall(content)
      if matches:
        groups = matches[-1]
        if groups[0] and groups[2]:
          return float(groups[0]), float(groups[2])
        # The pattern has exactly two alternatives, so this is the "[w, b]" form.
        return float(groups[4]), float(groups[6])

      print("No valid match found in response. Retrying...")

  async def Prompt_LLM_Epoch(self, Meta_Prompt, temperature=1):
    """Run ``self.epoch`` requests concurrently and add every answer to the pool."""
    async with ClientSession() as session:
      tasks = [self.process_single_epoch(Meta_Prompt, session, temperature) for _ in range(self.epoch)]
      results = await asyncio.gather(*tasks)
      request_success = 0

      for result in results:
        if result:
          w, b = result
          self.Template.Add_Sample(w, b)
          request_success += 1
      print(f"Successfully processed {request_success} requests.")

  async def Prompt_LLM_Zero_Temp(self, Meta_Prompt):
    """Score the prompt with one temperature-0 answer.

    Stores the answer's objective value in ``stable_score`` and adds the
    answer to the pool.
    """
    async with ClientSession() as session:
      while True:
        result = await self.process_single_epoch(Meta_Prompt, session, temperature=0)
        if result:
          w, b = result
          self.stable_score = self.Template.Linear_Regression.evaluator(w, b)
          self.Template.Add_Sample(w, b)
          print(f"The stable score is: {self.stable_score}")
          return

  async def Prompt_LLM_Step(self, Meta_Prompt):
    """Score ``Meta_Prompt`` at temperature 0, then sample ``self.steps``
    batches at temperature 1 to grow the pool."""
    await self.Prompt_LLM_Zero_Temp(Meta_Prompt)
    for i in range(self.steps):
      await self.Prompt_LLM_Epoch(Meta_Prompt, temperature=1)
      print("Step %d: The latest Meta-Prompt score of Prompts is %f" % (i + 1, self.stable_score))

class PromptEmbedding:
  """Embeds prompts and exemplar blocks with the OpenAI embedding endpoint and
  evaluates a meta-prompt through an owned ``RequestLLM``.

  Attributes:
    z1_z2_prompt: DataFrame of candidate meta-prompts; the code reads its
      'Merged_Content' and 'Z1-Z2-Embedding' columns.
    RequestLLM: the ``RequestLLM`` used to query the optimizer model.
    Prompt_Embedding_List: embeddings of the current candidate exemplar
      subsets, filled by ``Get_Batch_Embedding_z3``.
  """

  def __init__(self, z1_z2_prompt_data, New_Linear_Regression, Optimizer_Model='gpt-3.5-turbo', Embedding_model='text-embedding-3-large'):
    # Placeholder credential: must be replaced with a real key before use.
    self.API_key = "[API_KEY]"
    self.z1_z2_prompt = z1_z2_prompt_data
    self.Embedding_model = Embedding_model
    # Honour the requested optimizer model (it used to be hard-coded here).
    self.RequestLLM = RequestLLM(New_Linear_Regression, model=Optimizer_Model, steps=1, epoch=7)
    self.Prompt_Embedding_List = []

  def normalize_l2(self, x):
    """Return ``x`` as an array scaled to unit L2 norm (unchanged if the norm is 0)."""
    x = np.array(x)
    norm = np.linalg.norm(x)
    if norm == 0:
      return x
    return x / norm

  async def async_get_single_embedding(self, text, session, target_dim=3072):
    """Embed ``text`` over HTTP, retrying until a request succeeds.

    Returns:
      The first ``target_dim`` embedding components, L2-normalised, as a list.
    """
    headers = {
      "Authorization": f"Bearer {self.API_key}",
      "Content-Type": "application/json",
    }
    url = "https://api.openai.com/v1/embeddings"
    payload = {
      "input": text,
      "model": self.Embedding_model
    }
    while True:
      try:
        # Random jitter spreads the concurrent requests of one batch.
        await asyncio.sleep(random.uniform(0.5, 1.5))
        async with session.post(url, json=payload, headers=headers) as response:
          if response.status == 429:
            print("Rate limit exceeded. Waiting before retrying...")
            await asyncio.sleep(10)
            continue
          if response.status != 200:
            # Report the real failure instead of a KeyError on 'data' below;
            # the handler underneath still waits and retries.
            raise RuntimeError(f"HTTP {response.status}: {await response.text()}")
          response_json = await response.json()
          embedding = response_json['data'][0]['embedding']
          truncated_embedding = embedding[:target_dim]
          normalized_embedding = list(self.normalize_l2(truncated_embedding))
          return normalized_embedding
      except Exception as e:
        print(f"An error occurred: {e}")
        print("Retrying...")
        await asyncio.sleep(5)

  async def get_batch_embedding_z3_async(self, Z1_Z2_Embedding):
    """Embed every candidate exemplar subset into ``Prompt_Embedding_List``.

    Subsets are sent in groups of 128 concurrent requests with a 15 s pause
    between groups. ``Z1_Z2_Embedding`` is accepted but not used.
    """
    openai.api_key = self.API_key
    self.Prompt_Embedding_List = []
    template = self.RequestLLM.Template
    subsets = template.subset_list
    async with ClientSession() as session:
      for i in tqdm(range(0, len(subsets), 128), desc="Batch Embedding ... "):
        tasks = [
          self.async_get_single_embedding(template.Padding(prompt), session)
          for prompt in subsets[i:i + 128]
        ]
        batch_results = await asyncio.gather(*tasks)
        self.Prompt_Embedding_List.extend(batch_results)
        if i + 128 < len(subsets):
          print("Batch completed, sleeping for 15 seconds...")
          await asyncio.sleep(15)

  def Get_Batch_Embedding_z3(self, Z1_Z2_Embedding):
    """Synchronous wrapper around ``get_batch_embedding_z3_async``."""
    asyncio.run(self.get_batch_embedding_z3_async(Z1_Z2_Embedding))

  def Get_Single_Embedding(self, text, target_dim=3072):
    """Embed ``text`` through the legacy ``openai.Embedding`` client, retrying
    until a request succeeds.

    Newlines are replaced by spaces first (the async variant does not do
    this). Returns the first ``target_dim`` components, L2-normalised, as a
    list.
    """
    openai.api_key = self.API_key
    text = text.replace("\n", " ")
    while True:
      try:
        time.sleep(random.uniform(0.5, 1.5))
        response = openai.Embedding.create(
          input=[text],
          model=self.Embedding_model
        )
        embedding = response['data'][0]['embedding']
        truncated_embedding = embedding[:target_dim]
        normalized_embedding = list(self.normalize_l2(truncated_embedding))
        return normalized_embedding
      except openai.error.RateLimitError:
        print("Rate limit exceeded. Waiting before retrying...")
        time.sleep(10)
      except Exception as e:
        print(f"An error occurred: {e}")
        print("Retrying...")
        time.sleep(5)

  def PaddingforPrompt(self):
    """Write each candidate prompt, filled with the current exemplars, to the
    'Completed_Prompt' column."""
    tqdm.pandas()
    self.z1_z2_prompt['Completed_Prompt'] = self.z1_z2_prompt['Merged_Content'].progress_apply(lambda x: self.RequestLLM.Template.Add_All(x))

  def Get_Batch_Embedding(self, column='PromptEmbedding'):
    """Refresh 'Completed_Prompt' and expose the stored prompt embeddings
    under ``column``.

    The embeddings are taken unchanged from 'Z1-Z2-Embedding'. The former
    exemplar-embedding request made here was dropped because its result was
    never used; ``forward`` requests that embedding itself.
    """
    openai.api_key = self.API_key
    self.PaddingforPrompt()
    self.z1_z2_prompt[column] = self.z1_z2_prompt['Z1-Z2-Embedding']

  async def forward(self, current_point, current_z1, current_z2, z1_z2_Embedding):
    """Evaluate the meta-prompt ``current_point`` once.

    Args:
      current_point: full meta-prompt text containing one ``{}`` slot.
      current_z1, current_z2: accepted but not used here.
      z1_z2_Embedding: embedding (a list) of the prompt's two text parts.

    Returns:
      ``[embedding, score]`` where ``embedding`` is ``z1_z2_Embedding``
      followed by the embedding of the current exemplar text and ``score``
      is the temperature-0 objective value.
    """
    self.Get_Batch_Embedding()
    exemplars_embedding = self.Get_Single_Embedding(self.RequestLLM.Template.input_exemplars)
    await self.RequestLLM.Prompt_LLM_Step(current_point)
    current_point_score = self.RequestLLM.stable_score
    # Both operands are Python lists, so `+` concatenates them.
    concatenated_embedding = z1_z2_Embedding + exemplars_embedding
    return [concatenated_embedding, current_point_score]

class AdversarialBandit:
  """Jointly selects a meta-prompt and an exemplar subset with
  exponential-weights (the code calls it EXP3) sampling over scores predicted
  by two small neural networks.

  ``model`` is trained on the first 6144 entries of each recorded embedding
  (the prompt part) and ``model_z3`` on the last 3072 entries (the exemplar
  part). Both are retrained from their initial weights at every step on all
  ``(embedding, score)`` pairs gathered so far.

  Args:
    z1_z2_prompt_data: DataFrame of candidate prompts; the code reads the
      columns 'Merged_Content', 'Z1', 'Z2' and 'Z1-Z2-Embedding'.
    New_Linear_Regression: task object handed to ``PromptEmbedding``.
    input_dim: input size of the auxiliary ``func`` network.
    steps: number of optimization iterations.
    NumberofK: stored but not used in this class.
  """

  def __init__(self, z1_z2_prompt_data, New_Linear_Regression, input_dim=6144, steps=50, NumberofK=10201):
    self.z1_z2_prompt_data = z1_z2_prompt_data
    self.input_dim = input_dim
    self.steps = steps
    self.NumberofK = NumberofK
    # Use the GPU when there is one, otherwise fall back to the CPU.
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    self.model = Network(6144, 1536, 1).to(self.device).double()
    self.model_z3 = Network(3072, 512, 1).to(self.device).float()

    self.lamdba = 0.1
    self.nu = 1

    tkwargs = {
      "dtype": torch.float32,
      "device": self.device
    }
    self.func = Network(self.input_dim).to(**tkwargs)
    self.total_param = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    self.best_result_list = []

    self.lr = 1e-3
    self.epochs = 5000
    self.optimizer_fn = torch.optim.Adam
    self.optimizer_fn_z3 = torch.optim.Adam
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1)
    self.optimizer_z3 = self.optimizer_fn_z3(self.model_z3.parameters(), lr=self.lr, weight_decay=1)
    self.loss_fn = nn.MSELoss()
    self.loss_fn_z3 = nn.MSELoss()

    # Snapshots used to restart both networks before every retraining.
    self.init_model_weight = copy.deepcopy(self.model.state_dict())
    self.init_model_weight_z3 = copy.deepcopy(self.model_z3.state_dict())

    self.eta = 100
    self.prompt_score_trace = []
    self.embedding_score_trace = []

    self.eta_z3 = 25
    self.NN_Para = []
    self.historical_score = []


    self.current_point = """Now you will help me minimize a function with two input variables, w and b. I have some (w, b) pairs and the function values at those points. The pairs are arranged in descending order based on their function values, where smaller function values indicate better results. Therefore, although the pairs are listed from the highest function value to the lowest, the pair with the smallest function value is considered the most optimal.

{}

Give me a new (w, b) pair that is different from all pairs above, and has a function value lower than any of the above. Do not write code. The output must end with a pair [w, b], where w and b are numerical values.
"""

    self.current_z1 = """Now you will help me minimize a function with two input variables, w and b. I have some (w, b) pairs and the function values at those points. The pairs are arranged in descending order based on their function values, where smaller function values indicate better results. Therefore, although the pairs are listed from the highest function value to the lowest, the pair with the smallest function value is considered the most optimal."""
    self.current_z2 = """Give me a new (w, b) pair that is different from all pairs above, and has a function value lower than any of the above. Do not write code. The output must end with a pair [w, b], where w and b are numerical values."""

    self.prompt_embedding = PromptEmbedding(z1_z2_prompt_data, New_Linear_Regression, Optimizer_Model='gpt-3.5-turbo', Embedding_model='text-embedding-3-large')

    # `select` and `optimization` use the capital-E spelling; the attribute
    # was previously created only under the lowercase spelling, which made
    # the first optimization step fail with AttributeError. Both names are
    # kept for backward compatibility.
    self.current_z1_z2_Embedding = self.prompt_embedding.Get_Single_Embedding(self.current_z1) + self.prompt_embedding.Get_Single_Embedding(self.current_z2)
    self.current_z1_z2_embedding = self.current_z1_z2_Embedding

    self.c = None
    self.c_z3 = None

    self.scheduler = None
    self.scheduler_z3 = None

    self.Stop_EXP3 = 0


  def restart_model_z3(self):
    """Reset ``model_z3`` to its initial weights with a fresh optimizer/scheduler.

    Weight decay is 1 / (number of recorded observations), so the trace must
    be non-empty.
    """
    self.model_z3.load_state_dict(copy.deepcopy(self.init_model_weight_z3))
    self.optimizer_z3 = self.optimizer_fn_z3(self.model_z3.parameters(), lr=self.lr, weight_decay=1/len(self.embedding_score_trace))
    self.scheduler_z3 = torch.optim.lr_scheduler.LambdaLR(self.optimizer_z3, lr_lambda=lambda epoch: 0.1 if epoch > 3000 else 1.0)

  def restart_model(self):
    """Reset ``model`` to its initial weights with a fresh optimizer/scheduler.

    Weight decay is 1 / (number of recorded observations), so the trace must
    be non-empty.
    """
    self.model.load_state_dict(copy.deepcopy(self.init_model_weight))
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1/len(self.embedding_score_trace))
    self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 0.1 if epoch > 3000 else 1.0)

  def _fit(self, model, optimizer, scheduler, loss_fn, embeddings, dtype, desc):
    """Regress the recorded scores on ``embeddings`` for ``self.epochs`` epochs."""
    # Scores may be stored as 0-dim tensors; turn them into plain floats.
    scores = [float(item[1]) for item in self.embedding_score_trace]

    embeddings_tensor = torch.stack([torch.tensor(e, dtype=dtype) for e in embeddings]).to(self.device)
    scores_tensor = torch.tensor(scores, dtype=dtype).to(self.device)

    dataset = torch.utils.data.TensorDataset(embeddings_tensor, scores_tensor)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    model.train()
    for epoch in tqdm(range(self.epochs), desc=desc):
      for batch_embeddings, batch_scores in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(model(batch_embeddings), batch_scores)
        loss.backward()
        optimizer.step()

      scheduler.step()

      if epoch % 100 == 0:
        # Loss of the last mini-batch of this epoch.
        print(f"Epoch {epoch}/{self.epochs}, Loss: {loss.item()}")

  def train_model_z3(self):
    """Retrain ``model_z3`` on the exemplar part (last 3072 entries) of every
    recorded embedding and append the trained weights to ``NN_Para``."""
    self.restart_model_z3()
    self.model_z3.to(self.device)

    embeddings = [item[0][-3072:] for item in self.embedding_score_trace]
    self._fit(self.model_z3, self.optimizer_z3, self.scheduler_z3, self.loss_fn_z3,
              embeddings, torch.float32, "Training Model Z3")

    self.NN_Para.append(copy.deepcopy(self.model_z3.state_dict()))

    print("Training Model Z3 complete.")

  def train_model(self):
    """Retrain ``model`` on the prompt part (first 6144 entries) of every
    recorded embedding."""
    self.restart_model()
    self.model.to(self.device)

    embeddings = [item[0][:6144] for item in self.embedding_score_trace]
    self._fit(self.model, self.optimizer, self.scheduler, self.loss_fn,
              embeddings, torch.float64, "Training Model")

    print("Training complete.")


  def select_z3(self, iteration):
    """Sample the index of the exemplar subset to use next.

    Every stored ``model_z3`` snapshot scores all candidate-subset
    embeddings; the shifted scores are summed over snapshots and turned into
    sampling probabilities with a numerically stable softmax scaled by
    ``eta_z3``. ``iteration`` is accepted but not used.

    Returns:
      Index (int) into ``Template.subset_list`` / ``Template.mask``.
    """
    prompt_embeddings = torch.stack(
      [torch.tensor(embedding, dtype=torch.float32) for embedding in self.prompt_embedding.Prompt_Embedding_List]
    ).to(self.device)

    num_models = len(self.NN_Para)
    num_embeddings = prompt_embeddings.size(0)

    self.historical_score = [torch.zeros(num_embeddings, dtype=torch.float32) for _ in range(num_models)]

    for model_idx in tqdm(range(num_models), desc="Evaluating models Z3"):
      self.model_z3.load_state_dict(self.NN_Para[model_idx])
      self.model_z3.eval()

      with torch.no_grad():
        scores = self.model_z3(prompt_embeddings).squeeze(-1)

      if model_idx == 0:
        # The shift constant comes from the oldest snapshot's largest prediction.
        self.c_z3 = scores.max().item()

      # Lower predicted objective -> larger reward.
      shifted_scores = (-scores + self.c_z3) / self.c_z3

      self.historical_score[model_idx] = shifted_scores.cpu()

    total_scores = torch.sum(torch.stack(self.historical_score), dim=0)
    print("The predicted acumulative scores: ", total_scores)
    # Subtract the maximum before exponentiating to avoid overflow.
    max_log_score = torch.max(self.eta_z3 * total_scores)
    safe_log_exp_scores = self.eta_z3 * total_scores - max_log_score
    exp_scores = torch.exp(safe_log_exp_scores)

    softmax_scores = exp_scores / torch.sum(exp_scores)

    print("The number of Embedding Vectors and models are: %d and %d" % (num_embeddings, num_models))

    top_scores, top_indices = torch.topk(softmax_scores, min(5, num_embeddings))
    print("Top 5 scores: ", top_scores)
    print("Indices of top 5 scores: ", top_indices)

    if not torch.isnan(softmax_scores).any() and not torch.isinf(softmax_scores).any() and torch.sum(softmax_scores) > 0:
      best_index = torch.multinomial(softmax_scores, 1).item()
    else:
      # Degenerate distribution: take the arg-max, as a plain int like above.
      best_index = top_indices[0].item()

    print("Selected index: ", best_index)
    print("Selected Mask: ", self.prompt_embedding.RequestLLM.Template.mask[best_index])

    return best_index



  def select(self, iteration):
    """Sample the next meta-prompt and store it as the current point.

    ``model`` scores every candidate prompt; the shifted scores are appended
    to ``prompt_score_trace`` and the cumulative scores are turned into
    sampling probabilities ``exp(eta * sum) / Z``. When these probabilities
    contain NaN or Inf, ``Stop_EXP3`` is set and the current prompt is kept.

    Args:
      iteration: step index; the shift constant ``c`` is fixed at step 0.
    """
    embeddings = torch.stack(
      [torch.tensor(e, dtype=torch.float64) for e in self.prompt_embedding.z1_z2_prompt['PromptEmbedding']]
    ).to(self.device)

    with torch.no_grad():
      scores = self.model(embeddings).squeeze(-1)

    if iteration == 0:
      self.c = scores.max().item()

    print(f"Shift value (self.c): {self.c}")

    # Lower predicted objective -> larger reward.
    shifted_scores = (-scores + self.c) / self.c
    score_row = shifted_scores.cpu().tolist()
    self.z1_z2_prompt_data['Score'] = score_row

    # Store a plain-list snapshot, not the live DataFrame column, so later
    # steps cannot alias or overwrite earlier entries of the trace.
    self.prompt_score_trace.append(score_row)
    score_tensor = torch.tensor(self.prompt_score_trace, dtype=torch.float64)
    cumulative_scores = torch.sum(score_tensor, dim=0)
    # NOTE(review): unlike select_z3 this exponential is not max-shifted, so
    # it overflows once eta * cumulative score exceeds ~709; that overflow is
    # what triggers the Stop_EXP3 branch below. Left unchanged on purpose --
    # confirm whether stopping this way is intended.
    exp_scores = torch.exp(self.eta * cumulative_scores)
    total_score = torch.sum(exp_scores)
    normalized_exp_scores = exp_scores / total_score

    top_scores, top_indices = torch.topk(normalized_exp_scores, min(10, normalized_exp_scores.numel()))
    print("Top 10 probabilities: ", top_scores)
    print("Indices of Top 10 probabilities: ", top_indices)

    if not torch.isnan(normalized_exp_scores).any() and not torch.isinf(normalized_exp_scores).any() and torch.sum(normalized_exp_scores) > 0:
      best_index = torch.multinomial(normalized_exp_scores, 1).item()
      self.best_index = best_index
    else:
      print("Encountered NaN or Inf in normalized_exp_scores. Stopping EXP3.")
      self.Stop_EXP3 = 1
      return

    print("The index of selected prompt is: %d"%(self.best_index))
    print("The selected prompt's cumulative score: ", cumulative_scores[self.best_index])
    chosen = self.z1_z2_prompt_data.iloc[self.best_index]
    self.current_point = chosen['Merged_Content']
    self.current_z1 = chosen['Z1']
    self.current_z2 = chosen['Z2']
    self.current_z1_z2_Embedding = chosen['Z1-Z2-Embedding']
    self.current_z1_z2_embedding = self.current_z1_z2_Embedding



  def optimization(self):
    """Run ``self.steps`` iterations.

    Each iteration evaluates the current prompt, retrains the predictors,
    samples the next prompt (until EXP3 is stopped) and the next exemplar
    subset, and appends the best objective value found so far to
    ``best_result_list``.
    """
    template = self.prompt_embedding.RequestLLM.Template
    for i in tqdm(range(self.steps), desc="Optimizing the Meta-Prompt " + str(self.steps) + " step"):
      forward_result = asyncio.run(self.prompt_embedding.forward(self.current_point, self.current_z1, self.current_z2, self.current_z1_z2_Embedding))
      self.embedding_score_trace.append(forward_result)
      if not self.Stop_EXP3:
        self.train_model()
        self.select(i)
      self.train_model_z3()
      template.Generate_Subset()
      # A single subset means the pool is still small enough to be used whole.
      if len(template.subset_list) != 1:
        self.prompt_embedding.Get_Batch_Embedding_z3(self.current_z1_z2_Embedding)
        best_index = self.select_z3(i)
        template.selected_samples = template.subset_list[best_index]
        template.Add_Padding()
        print(template.input_exemplars)

      # Generate_Subset sorts by descending value, so the last sample is the best.
      self.best_result_list.append(template.samples[-1][2])
      print(self.best_result_list)

# One best-so-far curve per independent run.
converge_curve = []

# Five independent repetitions of the whole optimization.
for i in range(5):
  # Every run gets its own copy because the optimizer adds columns to the table.
  new_meta_prompt_data = copy.deepcopy(Meta_Prompt_data)
  new_optimization = AdversarialBandit(
    z1_z2_prompt_data=new_meta_prompt_data,
    New_Linear_Regression=New_Linear_Regression,
  )
  new_optimization.optimization()
  converge_curve.append(copy.deepcopy(new_optimization.best_result_list))
  print(converge_curve)