import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
from tqdm.auto import tqdm
import time
import copy
import ast
import random
import scipy.stats.qmc

import gurobipy as gp
from gurobipy import GRB

import torch
from torch import nn
import openai
import asyncio
from aiohttp import ClientSession

import nest_asyncio
nest_asyncio.apply()  # let asyncio.run() be used while an event loop is already running (e.g. in a notebook)

# Candidate meta-prompt table. The embedding column is stored in the CSV as
# text and is parsed back into Python objects with ast.literal_eval.
_EMBEDDING_COL = 'Z1-Z2-Embedding'
Meta_Prompt_data = pd.read_csv("PATH/DATA_FILE.csv")
tqdm.pandas()  # registers .progress_apply, used on the next line
Meta_Prompt_data[_EMBEDDING_COL] = Meta_Prompt_data[_EMBEDDING_COL].progress_apply(ast.literal_eval)
Meta_Prompt_data.head()

class OPRO4TSPData:
    """A random Euclidean TSP instance plus its exact optimum.

    Node coordinates are integers drawn from [-100, 100] with the *global*
    NumPy RNG after seeding it with ``seed``. The optimal tour is computed
    eagerly with Gurobi inside ``__init__``.
    """

    def __init__(self, nodes_number, seed=3407):
        np.random.seed(seed)  # NOTE: re-seeds the global NumPy RNG
        self.nodes = np.random.randint(-100, 101, size=(nodes_number, 2))
        self.nodes_number = nodes_number
        self.adj_matrix = self.compute_euclidean_distance_matrix()
        self.optimal_tour, self.optimal_distance = self.solve_tsp()

    def compute_euclidean_distance_matrix(self):
        """Return the dense (n, n) float matrix of pairwise Euclidean distances."""
        points = self.nodes
        count = len(points)
        matrix = np.zeros((count, count))
        for row, p in enumerate(points):
            for col, q in enumerate(points):
                matrix[row, col] = np.linalg.norm(p - q)
        return matrix

    def evaluator(self, trace):
        """Return the length of the closed tour ``trace``, truncated to int.

        The closing edge from the last node back to the first is included.
        """
        length = sum(self.adj_matrix[a, b] for a, b in zip(trace, trace[1:]))
        length += self.adj_matrix[trace[-1], trace[0]]
        return int(length)

    @staticmethod
    def _shortest_cycle(edge_values, n):
        """Return the smallest connected component of the chosen-edge graph.

        ``edge_values[i, j] > 0.5`` marks edge i -> j. Components are explored
        depth-first from the lowest-numbered unvisited node. When every node has
        exactly one successor, that order is the cycle's visiting order.
        """
        remaining = list(range(n))
        shortest = []
        while remaining:
            component = []
            stack = [remaining[0]]
            while stack:
                node = stack.pop()
                if node not in remaining:
                    continue
                component.append(node)
                remaining.remove(node)
                stack.extend(j for j in range(n) if edge_values[node, j] > 0.5)
            if not shortest or len(component) < len(shortest):
                shortest = component
        return shortest

    def solve_tsp(self):
        """Solve the instance exactly with Gurobi and return ``(tour, length)``.

        The model is an assignment formulation: one successor and one
        predecessor per node, and no self-loops. After each solve, a
        subtour-elimination constraint is added for the shortest cycle, and the
        model is re-solved until that cycle covers every node. Prints the tour
        and its (int) length.
        """
        n = self.nodes_number
        dist = self.adj_matrix
        model = gp.Model()

        x = model.addVars(n, n, vtype=GRB.BINARY, name="x")
        model.setObjective(
            gp.quicksum(dist[i, j] * x[i, j] for i in range(n) for j in range(n)),
            GRB.MINIMIZE,
        )
        # Degree constraints: exactly one outgoing and one incoming edge.
        model.addConstrs(gp.quicksum(x[i, j] for j in range(n) if i != j) == 1 for i in range(n))
        model.addConstrs(gp.quicksum(x[i, j] for i in range(n) if i != j) == 1 for j in range(n))
        model.addConstrs(x[i, i] == 0 for i in range(n))

        while True:
            model.optimize()
            solution = model.getAttr('x', x)
            cycle = self._shortest_cycle(solution, n)
            if len(cycle) == n:
                break
            # Forbid the shortest subtour of this solution and re-solve.
            model.addConstr(gp.quicksum(x[i, j] for i in cycle for j in cycle if i != j) <= len(cycle) - 1)

        optimal_tour = cycle
        optimal_distance = self.evaluator(optimal_tour)
        print(f"Optimal tour: {optimal_tour}")
        print(f"Optimal distance: {optimal_distance}")
        return optimal_tour, optimal_distance

    def visualize_tsp(self, tour=None):
        """Scatter-plot the nodes and, if ``tour`` is given, draw it as a closed loop."""
        pts = self.nodes
        plt.figure(figsize=(8, 6))
        plt.scatter(pts[:, 0], pts[:, 1], color='blue', label='Nodes')
        for idx, (px, py) in enumerate(pts):
            plt.text(px + 2, py + 2, str(idx), color="red", fontsize=12)

        if tour is None:
            plt.title('TSP Nodes')
        else:
            successors = list(tour[1:]) + [tour[0]]
            for a, b in zip(tour, successors):
                plt.plot([pts[a][0], pts[b][0]], [pts[a][1], pts[b][1]], 'r-')
            plt.title('TSP Solution with Optimal Tour')

        plt.xlabel('X')
        plt.ylabel('Y')
        plt.legend()
        plt.grid(True)
        plt.show()

class Network(nn.Module):
  """Fully connected ReLU regressor producing one scalar per input row.

  Args:
    input_dim: size of the last dimension of the input.
    hidden_size: width of every hidden layer.
    depth: the stack is Linear(input_dim, hidden_size), then ``depth - 1``
      Linear(hidden_size, hidden_size) layers, then Linear(hidden_size, 1).
    init_params: optional flat sequence ``[w0, b0, w1, b1, ...]`` of tensors
      assigned directly to the layers. When omitted, every weight and bias is
      drawn from N(0, 1).
  """

  def __init__(self, input_dim, hidden_size=100, depth=1, init_params=None):
    super().__init__()

    self.activate = nn.ReLU()
    widths = [input_dim, hidden_size] + [hidden_size] * (depth - 1) + [1]
    self.layer_list = nn.ModuleList(
      nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
    )

    if init_params is not None:
      for k, layer in enumerate(self.layer_list):
        layer.weight.data = init_params[2 * k]
        layer.bias.data = init_params[2 * k + 1]
      return

    for layer in self.layer_list:
      nn.init.normal_(layer.weight, mean=0, std=1.0)
      nn.init.normal_(layer.bias, mean=0, std=1.0)

  def forward(self, x):
    """Apply ReLU after every layer except the last and squeeze the trailing unit dim."""
    *hidden_layers, head = self.layer_list
    for layer in hidden_layers:
      x = self.activate(layer(x))
    return head(x).squeeze(-1)

class Template_Padding:
  """Pool of evaluated TSP traces and the exemplar text rendered from them.

  ``samples`` holds ``(trace, length)`` pairs, where a smaller length is
  better. ``Generate_Subset`` deduplicates the pool. With at most 20 unique
  samples it uses all of them directly; otherwise ``get_subset`` builds
  candidate subsets. ``Add_Padding`` renders the chosen subset into
  ``input_exemplars``, and ``Add_All`` splices that into a meta-prompt.
  """

  def __init__(self, nodes_number, num_initial_samples=5):
    """Create the TSP instance and seed the pool with random tours.

    Args:
      nodes_number: number of TSP nodes.
      num_initial_samples: how many random tours start the pool (default 5).
    """
    self.exemplars_template = """
    <trace>{}</trace>
    length:
    {}
    """
    self.nodes_number = nodes_number
    self.TSP = OPRO4TSPData(nodes_number=self.nodes_number)
    np.random.seed(1001)

    # Declare all state *before* Generate_Subset() runs, so the values it
    # computes are not overwritten afterwards.
    self.input_TSP = ""
    self.input_exemplars = ""
    self.subset_list = []
    self.selected_samples = None
    self.mask = None

    # Random tours, each rotated so that it starts at node 0.
    random_solutions = [
      self._rotate_to_zero(np.random.permutation(self.nodes_number).tolist())
      for _ in range(num_initial_samples)
    ]
    self.samples = [(sol, self.TSP.evaluator(sol)) for sol in random_solutions]
    self.samples.sort(key=lambda s: s[1], reverse=True)

    self.Convert_TSP_Points()
    self.Generate_Subset()

  @staticmethod
  def _rotate_to_zero(trace):
    """Return list ``trace`` rotated so node 0 comes first (ValueError if 0 is absent)."""
    start = trace.index(0)
    return trace[start:] + trace[:start]

  def Convert_TSP_Points(self):
    """Render the node coordinates as ``"(i): (x, y), ..."`` into ``input_TSP``."""
    tsp_points = self.TSP.nodes
    self.input_TSP = ", ".join(f"({i}): ({x}, {y})" for i, (x, y) in enumerate(tsp_points))

  def Add_Padding(self):
    """Sort ``selected_samples`` in place (longest first) and render them into ``input_exemplars``."""
    self.selected_samples.sort(key=lambda s: s[1], reverse=True)
    self.input_exemplars = self.Padding(self.selected_samples)

  def Padding(self, z3):
    """Return the exemplar text for the ``(trace, length)`` pairs in ``z3``, in the given order."""
    return "".join(
      self.exemplars_template.format(','.join(map(str, trace)), length)
      for trace, length in z3
    )

  def Add_Sample(self, generated_trace):
    """Append ``generated_trace`` and its length to the pool (no re-sorting)."""
    self.samples.append((generated_trace, self.TSP.evaluator(generated_trace)))

  def Add_All(self, Meta_Prompt):
    """Fill a two-slot meta-prompt template with the node list and the exemplars."""
    return Meta_Prompt.format(self.input_TSP, self.input_exemplars)

  def generate_01_low_discrepancy_sequences(self, num_sequences, dim, num_ones):
    """Return ``num_sequences + 1`` binary masks of length ``dim`` as an array.

    Each of the first ``num_sequences`` masks always selects the last position,
    plus ``num_ones - 1`` of the first ``dim - 1`` positions. Those positions
    are the indices of the smallest coordinates of a scrambled Sobol point. The
    final extra mask selects the last ``num_ones`` positions.
    """
    sampler = scipy.stats.qmc.Sobol(d=dim - 1, scramble=True)
    points = sampler.random(n=num_sequences)

    masks = []
    for point in points:
      mask = np.zeros(dim)
      mask[np.argsort(point)[:num_ones - 1]] = 1
      mask[-1] = 1
      masks.append(mask)

    # Same size as the other masks (was hard-coded to 20 before).
    tail_mask = np.zeros(dim)
    tail_mask[max(dim - num_ones, 0):] = 1
    masks.append(tail_mask)

    return np.array(masks)

  def get_subset(self, samples):
    """Build candidate exemplar subsets from ``samples`` (ordered worst -> best).

    Only the last, i.e. best, ``min(len(samples), 30)`` samples are eligible.
    Sets ``self.mask`` to a 0/1 int array with one row per candidate and one
    column per sample. Sets ``self.subset_list`` to the matching lists of
    selected samples.
    """
    num_to_select = min(len(samples), 30)
    candidate_masks = self.generate_01_low_discrepancy_sequences(256, num_to_select, 20)

    # Left-pad each mask with zeros so its columns line up with `samples`.
    full_length = len(samples)
    self.mask = np.zeros((len(candidate_masks), full_length), dtype=int)
    self.mask[:, full_length - num_to_select:] = candidate_masks

    self.subset_list = [
      [sample for sample, keep in zip(samples, row) if keep == 1]
      for row in self.mask
    ]

  def Generate_Subset(self):
    """Deduplicate the pool and prepare the exemplar subset(s).

    With at most 20 unique samples, all of them form the single subset and
    ``input_exemplars`` is rendered immediately. Otherwise ``get_subset``
    fills ``subset_list`` and ``mask`` for later selection.
    """
    self.samples.sort(key=lambda s: s[1], reverse=True)

    # tuple() copies each trace, so no deepcopy of the pool is needed.
    unique = set((tuple(trace), length) for trace, length in self.samples)
    sample_set = sorted(unique, key=lambda s: s[1], reverse=True)

    if len(sample_set) <= 20:
      self.subset_list = [sample_set]
      self.selected_samples = sample_set
      self.Add_Padding()
    else:
      self.get_subset(sample_set)

class RequestLLM:
  """Asks an OpenAI chat model for TSP traces and adds them to a Template_Padding pool."""

  def __init__(self, nodes_number, model='gpt-3.5-turbo', steps=100, epoch=8):
    """
    Args:
      nodes_number: number of TSP nodes; a valid trace is a permutation of range(nodes_number).
      model: chat-completions model name.
      steps: number of temperature-1 batches per ``Prompt_LLM_Step``.
      epoch: number of concurrent requests per temperature-1 batch.
    """
    import os

    # Prefer the environment so the key need not live in source; the literal
    # is kept as the placeholder fallback.
    self.API_key = os.environ.get("OPENAI_API_KEY", "[API_KEY]")
    self.model = model
    self.steps = steps
    self.epoch = epoch
    self.nodes_number = nodes_number
    self.Template = Template_Padding(self.nodes_number)
    self.latest_batch_score = 0
    self.stable_score = 0
    self.count_time = 0

  async def fetch_response(self, session, url, payload):
    """POST ``payload`` to ``url`` and return the decoded JSON reply.

    Returns None on any non-200, non-429 status. HTTP 429 is retried with
    exponential backoff (10, 20, 40, ... s). Other exceptions are retried after
    30 s. Raises RuntimeError after 5 attempts.
    """
    headers = {
      "Authorization": f"Bearer {self.API_key}",
      "Content-Type": "application/json",
    }

    for retry_count in range(5):
      try:
        async with session.post(url, json=payload, headers=headers) as response:
          if response.status == 429:
            wait_time = 10 * (2 ** retry_count)
            print(f"Rate limit exceeded. Retrying in {wait_time} seconds...")
            await asyncio.sleep(wait_time)
            continue
          if response.status != 200:
            error_text = await response.text()
            print(f"Error response: {error_text}")
            return None
          return await response.json()
      except Exception as e:
        print(f"Error occurred: {e}")
        await asyncio.sleep(30)

    raise RuntimeError("Exceeded maximum retry attempts.")

  async def generate_trace(self, Meta_Prompt, session, temperature=1, retry_delay=1.0):
    """Query the model until it returns a valid trace, and return it.

    A valid trace is a ``<trace>...</trace>`` comma-separated permutation of
    ``range(nodes_number)``; any extra trailing entries are trimmed first. The
    result is rotated to start at node 0.

    Args:
      retry_delay: seconds to wait before re-sending after an empty or
        malformed API reply (e.g. when ``fetch_response`` returned None).
    """
    url = "https://api.openai.com/v1/chat/completions"
    payload = {
      "model": self.model,
      "messages": [{"role": "user", "content": self.Template.Add_All(Meta_Prompt)}],
      "temperature": temperature,
    }

    while True:
      response = await self.fetch_response(session, url, payload)
      choices = response.get("choices") if isinstance(response, dict) else None
      # `not choices` also covers an empty list, which previously raised IndexError.
      if not choices or "message" not in choices[0]:
        print("Invalid or empty response from API. Retrying...")
        # Avoid hammering the API in a tight loop on persistent errors.
        await asyncio.sleep(retry_delay)
        continue

      content = (choices[0]["message"].get("content") or "").strip()
      print(f"Generated content (Temp={temperature}): {content}")

      trace_match = re.search(r"<trace>(.*?)</trace>", content)
      if not trace_match:
        print("Invalid trace format: Missing <trace> tags. Retrying...")
        continue

      trace_str = trace_match.group(1).strip()
      if not re.match(r"^\d+(,\s*\d+)*$", trace_str):
        print("Invalid trace format: Incorrect number sequence. Retrying...")
        continue

      trace_list = list(map(int, re.split(r",\s*", trace_str)))

      if len(trace_list) > self.nodes_number:
        trace_list = trace_list[:self.nodes_number]
        print(f"Trimmed trace to {self.nodes_number} elements: {trace_list}")

      if (
        len(trace_list) == self.nodes_number
        and len(set(trace_list)) == self.nodes_number
        and all(x in range(self.nodes_number) for x in trace_list)
      ):
        # A valid permutation always contains 0, so index() cannot fail here.
        start = trace_list.index(0)
        trace_list = trace_list[start:] + trace_list[:start]
        print(f"Valid trace: {trace_list}")
        return trace_list

      print("Invalid trace. Retrying...")

  async def Prompt_LLM_Epoch(self, Meta_Prompt, temperature=1):
    """Request ``self.epoch`` traces concurrently and add each one to the pool."""
    async with ClientSession() as session:
      tasks = [
        self.generate_trace(Meta_Prompt, session, temperature=temperature) for _ in range(self.epoch)
      ]
      results = await asyncio.gather(*tasks)
      for trace_list in results:
        if trace_list:
          self.Template.Add_Sample(trace_list)

  async def Prompt_LLM_Epoch_Temp_0(self, Meta_Prompt):
    """Request one temperature-0 trace, record its length in ``stable_score`` and add it to the pool."""
    print(self.Template.Add_All(Meta_Prompt))
    async with ClientSession() as session:
      while True:
        trace_list = await self.generate_trace(Meta_Prompt, session, temperature=0)
        if trace_list:
          self.stable_score = self.Template.TSP.evaluator(trace_list)
          self.Template.Add_Sample(trace_list)
          print(f"Temp=0 Trace: {trace_list}, Score: {self.stable_score}")
          return trace_list

  async def Prompt_LLM_Step(self, Meta_Prompt):
    """Run one temperature-0 query, then ``self.steps`` temperature-1 batches.

    Pauses for 10 s on every 5th call.
    """
    self.count_time += 1
    if not (self.count_time % 5):
      print("Wait for 10 seconds ...")
      # Non-blocking sleep: time.sleep() would stall the event loop.
      await asyncio.sleep(10)
    await self.Prompt_LLM_Epoch_Temp_0(Meta_Prompt)
    for i in range(self.steps):
      await self.Prompt_LLM_Epoch(Meta_Prompt, temperature=1)

class PromptEmbedding:
  """OpenAI-embedding helpers plus the per-step evaluation of a meta-prompt."""

  def __init__(self, z1_z2_prompt_data, nodes_number, Optimizer_Model='gpt-3.5-turbo', Embedding_model='text-embedding-3-large'):
    """
    Args:
      z1_z2_prompt_data: prompt DataFrame. It is stored by reference, not
        copied, and ``Get_Batch_Embedding`` writes a column into it.
      nodes_number: number of TSP nodes.
      Optimizer_Model: chat model used to generate traces.
      Embedding_model: OpenAI embedding model name.
    """
    import os

    # Prefer the environment; the literal is kept as the placeholder fallback.
    self.API_key = os.environ.get("OPENAI_API_KEY", "[API_KEY]")
    self.z1_z2_prompt = z1_z2_prompt_data
    self.Embedding_model = Embedding_model
    # Forward the caller's optimizer model instead of ignoring the argument.
    self.RequestLLM = RequestLLM(nodes_number, model=Optimizer_Model, steps=1, epoch=7)
    self.Prompt_Embedding_List = []

  def normalize_l2(self, x):
    """Return ``x`` as a NumPy array scaled to unit L2 norm (unchanged if the norm is 0)."""
    x = np.array(x)
    norm = np.linalg.norm(x)
    if norm == 0:
      return x
    return x / norm

  async def async_get_single_embedding(self, text, session, target_dim=3072):
    """Embed ``text`` through the HTTP API and return the first ``target_dim`` values, L2-normalized.

    Waits a random 0.5-1.5 s before each request. Retries indefinitely: after
    10 s on HTTP 429 and after 30 s on any other error.
    """
    # Apply the same newline handling as Get_Single_Embedding, so texts
    # embedded through either path are comparable.
    text = text.replace("\n", " ")
    headers = {
      "Authorization": f"Bearer {self.API_key}",
      "Content-Type": "application/json",
    }
    url = "https://api.openai.com/v1/embeddings"
    payload = {
      "input": text,
      "model": self.Embedding_model
    }
    while True:
      try:
        await asyncio.sleep(random.uniform(0.5, 1.5))
        async with session.post(url, json=payload, headers=headers) as response:
          if response.status == 429:
            print("Rate limit exceeded. Waiting before retrying...")
            await asyncio.sleep(10)
            continue
          response_json = await response.json()
          embedding = response_json['data'][0]['embedding']
          truncated_embedding = embedding[:target_dim]
          normalized_embedding = list(self.normalize_l2(truncated_embedding))
          return normalized_embedding
      except Exception as e:
        print(f"An error occurred: {e}")
        print("Retrying...")
        await asyncio.sleep(30)

  async def get_batch_embedding_async(self):
    """Embed the exemplar text of every candidate subset, in chunks of 128.

    Results are stored in ``Prompt_Embedding_List``. The loop sleeps 15 s
    between chunks.
    """
    openai.api_key = self.API_key
    self.Prompt_Embedding_List = []
    subsets = self.RequestLLM.Template.subset_list
    async with ClientSession() as session:
      for i in tqdm(range(0, len(subsets), 128), desc="Batch Embedding ... "):
        batch = subsets[i:i + 128]
        tasks = [
          self.async_get_single_embedding(self.RequestLLM.Template.Padding(prompt), session)
          for prompt in batch
        ]
        batch_results = await asyncio.gather(*tasks)
        self.Prompt_Embedding_List.extend(batch_results)
        if i + 128 < len(subsets):
          print("Batch completed, sleeping for 15 seconds...")
          await asyncio.sleep(15)

  def Get_Batch_Embedding_z3(self):
    """Synchronous wrapper around ``get_batch_embedding_async``."""
    asyncio.run(self.get_batch_embedding_async())

  def Get_Single_Embedding(self, text, target_dim=3072):
    """Embed ``text`` (newlines replaced by spaces) and return it truncated and L2-normalized.

    Uses the pre-1.0 ``openai`` SDK interface (``openai.Embedding`` /
    ``openai.error``). Retries indefinitely.
    """
    openai.api_key = self.API_key
    text = text.replace("\n", " ")
    while True:
      try:
        time.sleep(random.uniform(0.5, 1.5))
        response = openai.Embedding.create(
          input=[text],
          model=self.Embedding_model
        )
        embedding = response['data'][0]['embedding']
        truncated_embedding = embedding[:target_dim]
        normalized_embedding = list(self.normalize_l2(truncated_embedding))
        return normalized_embedding
      except openai.error.RateLimitError:
        print("Rate limit exceeded. Waiting before retrying...")
        time.sleep(10)
      except Exception as e:
        print(f"An error occurred: {e}")
        print("Retrying...")
        time.sleep(30)

  def Get_Batch_Embedding(self, column='PromptEmbedding'):
    """Copy the precomputed 'Z1-Z2-Embedding' column into ``column``, showing a progress bar."""
    openai.api_key = self.API_key
    tqdm.pandas()
    self.z1_z2_prompt[column] = self.z1_z2_prompt['Z1-Z2-Embedding'].progress_apply(
      lambda x: x
    )

  def forward(self, current_point, current_z1, current_z2, z1_z2_embedding):
    """Evaluate meta-prompt ``current_point`` once and return ``[embedding, score]``.

    The embedding is ``z1_z2_embedding`` (a list) concatenated with the
    embedding of the exemplars used in the prompt. The score is the
    temperature-0 trace length. ``current_z1`` and ``current_z2`` are
    accepted but not used here.
    """
    self.Get_Batch_Embedding()
    # Embed the exemplars *before* the LLM step adds new samples to the pool.
    exemplars_embedding = self.Get_Single_Embedding(self.RequestLLM.Template.input_exemplars)
    asyncio.run(self.RequestLLM.Prompt_LLM_Step(current_point))
    current_point_score = self.RequestLLM.stable_score
    concatenated_embedding = z1_z2_embedding + exemplars_embedding
    return [concatenated_embedding, current_point_score]

class AdversarialBandit:
  """Exponential-weights (EXP3-style) search over meta-prompts and exemplar subsets.

  Each step of ``optimization``:
    1. evaluates the current meta-prompt with the LLM and records
       ``[z1_z2_embedding + exemplars_embedding, score]``;
    2. unless EXP3 was stopped, refits ``self.model`` on the first 6144
       embedding entries and samples a new meta-prompt (``select``);
    3. refits ``self.model_z3`` on the last 3072 entries and snapshots its
       weights into ``NN_Para``. When several exemplar subsets exist, it then
       samples one of them (``select_z3``).

  Scores are TSP tour lengths (lower is better). Predictions are turned into
  rewards with ``(c - score) / c`` before the exponential weighting.
  """

  def __init__(self, z1_z2_prompt_data, nodes_number, input_dim=6144, steps=300, NumberofK=10201):
    self.z1_z2_prompt_data = z1_z2_prompt_data
    self.input_dim = input_dim
    self.steps = steps
    self.NumberofK = NumberofK
    # Use the GPU when available; fall back to CPU instead of failing.
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Reward regressors: meta-prompt (z1/z2, 6144-d) and exemplars (z3, 3072-d).
    self.model = Network(6144, 1536, 1).to(self.device).float()
    self.model_z3 = Network(3072, 512, 1).to(self.device).float()

    self.lamdba = 0.1
    self.nu = 1

    tkwargs = {
      "dtype": torch.float64,
      "device": self.device,
    }
    self.func = Network(self.input_dim).to(**tkwargs)
    self.total_param = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    self.best_result_list = []

    self.lr = 1e-3
    self.epochs = 5000
    self.optimizer_fn = torch.optim.Adam
    self.optimizer_fn_z3 = torch.optim.Adam
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1)
    self.optimizer_z3 = self.optimizer_fn_z3(self.model_z3.parameters(), lr=self.lr, weight_decay=1)
    self.loss_fn = nn.MSELoss()
    self.loss_fn_z3 = nn.MSELoss()

    # Initial weights; both networks are reset to these before every refit.
    self.init_model_weight = copy.deepcopy(self.model.state_dict())
    self.init_model_weight_z3 = copy.deepcopy(self.model_z3.state_dict())

    self.eta = 100  # exponential-weights multiplier for meta-prompt selection
    self.prompt_score_trace = []
    self.embedding_score_trace = []

    self.eta_z3 = 25  # exponential-weights multiplier for exemplar-subset selection
    self.NN_Para = []
    self.historical_score = []
    self.current_point = """You are given a list of points with coordinates below:{}

Below are some previous traces and their lengths. The traces are arranged in descending order based on their lengths, where smaller lengths indicate better solutions. Therefore, the traces are listed from the largest length to the smallest, the trace with the smallest length is considered the most optimal.

{}

Give me a new trace that is different from all traces above, and has a length lower than any of the above. The trace should traverse all points exactly once. The trace should start with <trace> and end with </trace>.
    """
    self.current_z1 = """You are given a list of points with coordinates below:{}

Below are some previous traces and their lengths. The traces are arranged in descending order based on their lengths, where smaller lengths indicate better solutions. Therefore, the traces are listed from the largest length to the smallest, the trace with the smallest length is considered the most optimal."""
    self.current_z2 = """Give me a new trace that is different from all traces above, and has a length lower than any of the above. The trace should traverse all points exactly once. The trace should start with <trace> and end with </trace>."""

    self.prompt_embedding = PromptEmbedding(z1_z2_prompt_data, nodes_number, Optimizer_Model='gpt-3.5-turbo', Embedding_model='text-embedding-3-large')

    self.current_z1_z2_Embedding = self.prompt_embedding.Get_Single_Embedding(self.current_z1) + self.prompt_embedding.Get_Single_Embedding(self.current_z2)
    self.c = None
    self.c_z3 = None

    self.scheduler = None
    self.scheduler_z3 = None

    self.Stop_EXP3 = 0
    self.index_history = []

  def restart_model_z3(self):
    """Reset model_z3 to its initial weights with a fresh optimizer and LR schedule."""
    self.model_z3.load_state_dict(copy.deepcopy(self.init_model_weight_z3))
    self.optimizer_z3 = self.optimizer_fn_z3(self.model_z3.parameters(), lr=self.lr, weight_decay=1/len(self.embedding_score_trace))
    self.scheduler_z3 = torch.optim.lr_scheduler.LambdaLR(self.optimizer_z3, lr_lambda=lambda epoch: 0.1 if epoch > 3000 else 1.0)

  def restart_model(self):
    """Reset model to its initial weights with a fresh optimizer and LR schedule."""
    self.model.load_state_dict(copy.deepcopy(self.init_model_weight))
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1/len(self.embedding_score_trace))
    self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 0.1 if epoch > 3000 else 1.0)

  def _trace_tensors(self, embedding_slice):
    """Return float32 ``(embeddings, scores)`` tensors on ``self.device``.

    They are built from ``embedding_score_trace``, keeping
    ``item[0][embedding_slice]`` of each recorded embedding.
    """
    embeddings = torch.stack(
      [torch.tensor(item[0][embedding_slice], dtype=torch.float32) for item in self.embedding_score_trace]
    ).to(self.device)
    scores = torch.tensor([item[1] for item in self.embedding_score_trace], dtype=torch.float32).to(self.device)
    return embeddings, scores

  def _fit(self, model, optimizer, scheduler, loss_fn, inputs, targets, desc):
    """Full-batch training of ``model`` for ``self.epochs`` epochs, logging every 100 epochs."""
    model.train()
    for epoch in tqdm(range(self.epochs), desc=desc):
      optimizer.zero_grad()
      loss = loss_fn(model(inputs), targets)
      loss.backward()
      optimizer.step()
      scheduler.step()
      if epoch % 100 == 0:
        print(f"Epoch {epoch}/{self.epochs}, Loss: {loss.item()}")

  def train_model_z3(self):
    """Refit model_z3 from scratch on the exemplar part (last 3072 entries) and snapshot it."""
    self.restart_model_z3()
    self.model_z3.to(self.device)
    embeddings_tensor, scores_tensor = self._trace_tensors(slice(-3072, None))
    self._fit(self.model_z3, self.optimizer_z3, self.scheduler_z3, self.loss_fn_z3,
              embeddings_tensor, scores_tensor, "Training Model Z3")
    self.NN_Para.append(copy.deepcopy(self.model_z3.state_dict()))
    print("Training Model Z3 complete.")

  def train_model(self):
    """Refit model from scratch on the meta-prompt part (first 6144 entries)."""
    self.restart_model()
    self.model.to(self.device)
    embeddings_tensor, scores_tensor = self._trace_tensors(slice(None, 6144))
    self._fit(self.model, self.optimizer, self.scheduler, self.loss_fn,
              embeddings_tensor, scores_tensor, "Training Model")
    print("Training complete.")

  def select_z3(self, iteration):
    """Sample the index of an exemplar subset via exponential weights.

    Every stored model_z3 snapshot scores all current subset embeddings. The
    shifted rewards are summed over snapshots, and a subset is sampled from
    softmax(eta_z3 * total).
    """
    prompt_embeddings = self.prompt_embedding.Prompt_Embedding_List
    num_embeddings = len(prompt_embeddings)
    num_models = len(self.NN_Para)

    # The inputs are the same for every snapshot; build them once.
    embeddings_tensor = torch.stack(
      [torch.tensor(e, dtype=torch.float32) for e in prompt_embeddings]
    ).to(self.device)

    self.historical_score = [torch.zeros(num_embeddings, dtype=torch.float32) for _ in range(num_models)]

    for model_idx in tqdm(range(num_models), desc="Evaluating models Z3"):
      self.model_z3.load_state_dict(self.NN_Para[model_idx])
      self.model_z3.eval()
      with torch.no_grad():
        scores = self.model_z3(embeddings_tensor).squeeze(-1).cpu()

      if model_idx == 0:
        self.c_z3 = scores.max().item()

      self.historical_score[model_idx] = (-scores + self.c_z3) / self.c_z3

    total_scores = torch.sum(torch.stack(self.historical_score), dim=0)
    print("The predicted acumulative scores: ", total_scores)

    # Subtract the max before exp() for numerical stability.
    logits = self.eta_z3 * total_scores
    exp_scores = torch.exp(logits - torch.max(logits))
    softmax_scores = exp_scores / torch.sum(exp_scores)

    print("The number of Embeeding Vectors and models are: %d and %d" % (num_embeddings, num_models))

    top_scores, top_indices = torch.topk(softmax_scores, min(5, softmax_scores.numel()))
    print("Top 5 scores: ", top_scores)
    print("Indices of top 5 scores: ", top_indices)

    best_index = torch.multinomial(softmax_scores, 1).item()
    print("Selected index: ", best_index)
    print("Selected Mask: ", self.prompt_embedding.RequestLLM.Template.mask[best_index])

    return best_index

  def select(self, iteration):
    """Sample the next meta-prompt via exponential weights over cumulative predicted rewards.

    On ``iteration == 0`` the shift ``c`` is fixed to the maximum prediction.
    If the weights contain NaN/Inf, EXP3 is stopped (``Stop_EXP3 = 1``) and
    the current prompt is kept.
    """
    embeddings_tensor = torch.stack(
      [torch.tensor(e, dtype=torch.float32) for e in self.prompt_embedding.z1_z2_prompt['PromptEmbedding']]
    ).to(self.device)
    with torch.no_grad():
      scores = self.model(embeddings_tensor).squeeze(-1).cpu()

    if iteration == 0:
      self.c = scores.max().item()

    print(f"Shift value (self.c): {self.c}")

    shifted_scores = (-scores + self.c) / self.c
    self.z1_z2_prompt_data['Score'] = shifted_scores.tolist()
    print(self.z1_z2_prompt_data['Score'])

    # Store an independent snapshot, not the 'Score' Series itself. Depending
    # on the pandas version, a later `df['Score'] = ...` may write into the
    # same buffer and silently rewrite earlier rounds. torch.tensor on a list
    # of Series is also slow.
    self.prompt_score_trace.append(shifted_scores.tolist())
    score_tensor = torch.tensor(self.prompt_score_trace, dtype=torch.float32)

    cumulative_scores = torch.sum(score_tensor, dim=0)
    print("The cumulative scores: ", cumulative_scores)

    logits = self.eta * cumulative_scores
    exp_scores = torch.exp(logits - torch.max(logits))
    normalized_exp_scores = exp_scores / torch.sum(exp_scores)

    top_scores, top_indices = torch.topk(normalized_exp_scores, min(5, normalized_exp_scores.numel()))
    print("Top 5 scores: ", top_scores)
    print("Indices of top 5 scores: ", top_indices)

    if torch.isnan(normalized_exp_scores).any() or torch.isinf(normalized_exp_scores).any():
      print("Encountered NaN or Inf in normalized_exp_scores. Stopping EXP3.")
      self.Stop_EXP3 = 1
      return

    self.best_index = torch.multinomial(normalized_exp_scores, 1).item()

    print("The index of selected prompt is: %d" % self.best_index)
    print("The selected prompt's cumulative score: ", cumulative_scores[self.best_index])

    row = self.z1_z2_prompt_data.iloc[self.best_index]
    self.current_point = row['Merged_Content']
    self.current_z1 = row['Z1']
    self.current_z2 = row['Z2']
    self.current_z1_z2_Embedding = row['Z1-Z2-Embedding']

  def show_best(self):
    """Print the last temperature-0 score and the last pool entry.

    After Generate_Subset() sorts the pool, the last entry is the shortest trace.
    """
    best_trace, best_value = self.prompt_embedding.RequestLLM.Template.samples[-1][:2]
    best_trace_str = ','.join(map(str, best_trace))
    print(f"The Score of the Meta-Prompt: {self.prompt_embedding.RequestLLM.stable_score:d}")
    print(f"The Best Performance, trace: {best_trace_str}, value={best_value:d}")

  def optimization(self):
    """Run ``self.steps`` rounds of evaluate -> refit -> select, recording the best length per round."""
    template = self.prompt_embedding.RequestLLM.Template
    for i in tqdm(range(self.steps), desc="Optimizing the Meta-Prompt " + str(self.steps) + " step"):
      forward_result = self.prompt_embedding.forward(self.current_point, self.current_z1, self.current_z2, self.current_z1_z2_Embedding)
      self.embedding_score_trace.append(forward_result)
      if not self.Stop_EXP3:
        self.train_model()
        self.select(i)
      self.train_model_z3()
      template.Generate_Subset()
      if len(template.subset_list) != 1:
        self.prompt_embedding.Get_Batch_Embedding_z3()
        best_index = self.select_z3(i)
        template.selected_samples = template.subset_list[best_index]
        template.Add_Padding()
        print(template.input_exemplars)

      self.show_best()
      self.best_result_list.append(template.samples[-1][1])
      print(self.best_result_list)


# Three independent optimization runs. Each run gets its own deep copy of the
# prompt table, because the bandit writes a 'Score' column into it and
# PromptEmbedding writes a 'PromptEmbedding' column into it.
NUM_RUNS = 3
for i in range(NUM_RUNS):
  new_Meta_Prompt_data = copy.deepcopy(Meta_Prompt_data)
  new_optimization = AdversarialBandit(new_Meta_Prompt_data, nodes_number=20, steps=300)
  new_optimization.optimization()