import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
from tqdm.auto import tqdm
import time
import copy
import ast
import random

import gurobipy as gp
from gurobipy import GRB

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import openai

import asyncio
from aiohttp import ClientSession

import nest_asyncio
# Allow asyncio.run() to be invoked while an event loop may already be running
# (presumably a Jupyter kernel); PromptEmbedding.forward calls asyncio.run().
nest_asyncio.apply()

# Candidate meta-prompt table. 'Z1-Z2-Embedding' is stored in the CSV as the
# string repr of a Python list, so parse each cell back into a real list.
Meta_Prompt_data = pd.read_csv("PATH/DATA_FILE.csv")
tqdm.pandas()  # registers .progress_apply on pandas objects
Meta_Prompt_data['Z1-Z2-Embedding'] = (
    Meta_Prompt_data['Z1-Z2-Embedding'].progress_apply(ast.literal_eval)
)
# Notebook-style preview; has no visible effect when run as a plain script.
Meta_Prompt_data.head()

class OPRO4TSPData:
    """Random 2-D TSP instance together with its exact optimum (via Gurobi).

    Node coordinates are integers drawn from [-100, 100]. The exact optimal
    tour is computed eagerly in ``__init__`` so LLM-generated tours can be
    compared against it.
    """

    def __init__(self, nodes_number, seed=3407):
        # NOTE: seeds NumPy's *global* RNG, so later np.random.* calls in the
        # process are affected as well.
        np.random.seed(seed)
        self.nodes = np.random.randint(-100, 101, size=(nodes_number, 2))
        self.nodes_number = nodes_number
        self.adj_matrix = self.compute_euclidean_distance_matrix()

        self.optimal_tour, self.optimal_distance = self.solve_tsp()

    def compute_euclidean_distance_matrix(self):
        """Return the (n, n) float matrix of pairwise Euclidean node distances."""
        points = np.asarray(self.nodes)
        # Broadcast to an (n, n, 2) array of coordinate differences and take
        # the norm over the last axis: one vectorised pass instead of n^2 calls.
        diffs = points[:, None, :] - points[None, :, :]
        return np.linalg.norm(diffs, axis=-1)

    def evaluator(self, trace):
        """Return the length of the closed tour ``trace``, truncated to an int.

        The tour is closed by the edge from the last node back to the first.
        An empty trace has length 0.
        """
        n = len(trace)
        # Same summation order as walking the tour edge by edge, closing edge last.
        total_distance = sum(
            self.adj_matrix[trace[i], trace[(i + 1) % n]] for i in range(n)
        )
        return int(total_distance)

    @staticmethod
    def _shortest_cycle(arc_values, n):
        """Return the shortest cycle formed by arcs whose value exceeds 0.5.

        ``arc_values`` maps (i, j) -> value for every 0 <= i, j < n (e.g. the
        Gurobi solution dict). Each cycle is discovered starting from the
        smallest node not yet visited and is listed in arc-following order
        (assuming each node has one selected outgoing arc, as the degree
        constraints enforce). Ties keep the first cycle found, so a single
        Hamiltonian cycle is returned starting at node 0.
        """
        unvisited = list(range(n))
        shortest = []
        while unvisited:
            cycle = []
            stack = [unvisited[0]]
            while stack:
                node = stack.pop()
                if node in unvisited:
                    cycle.append(node)
                    unvisited.remove(node)
                    stack += [j for j in range(n) if arc_values[node, j] > 0.5]
            if not shortest or len(cycle) < len(shortest):
                shortest = cycle
        return shortest

    def solve_tsp(self):
        """Solve the instance exactly with an iterative subtour-elimination MIP.

        Binary x[i, j] = 1 iff the tour goes i -> j; every node has exactly one
        outgoing and one incoming arc. While the optimum contains a subtour S
        with |S| < n, the cut sum_{i != j in S} x[i, j] <= |S| - 1 is added and
        the model is re-optimised.

        Returns:
            (optimal_tour, optimal_distance): node list starting at 0, and its
            length as computed by ``evaluator``.
        """
        dist_matrix = self.adj_matrix
        n = self.nodes_number

        m = gp.Model()
        try:
            x = m.addVars(n, n, vtype=GRB.BINARY, name="x")
            m.setObjective(
                gp.quicksum(dist_matrix[i, j] * x[i, j] for i in range(n) for j in range(n)),
                GRB.MINIMIZE,
            )
            m.addConstrs(gp.quicksum(x[i, j] for j in range(n) if i != j) == 1 for i in range(n))
            m.addConstrs(gp.quicksum(x[i, j] for i in range(n) if i != j) == 1 for j in range(n))
            m.addConstrs(x[i, i] == 0 for i in range(n))

            while True:
                m.optimize()
                solution = m.getAttr('x', x)
                tour = self._shortest_cycle(solution, n)
                if len(tour) == n:
                    break
                # Forbid this subtour and re-solve.
                m.addConstr(
                    gp.quicksum(x[i, j] for i in tour for j in tour if i != j) <= len(tour) - 1
                )
        finally:
            # Release the solver's native resources (the model is not reused).
            m.dispose()

        optimal_tour = tour
        optimal_distance = self.evaluator(optimal_tour)
        print(f"Optimal tour: {optimal_tour}")
        print(f"Optimal distance: {optimal_distance}")

        return optimal_tour, optimal_distance

    def visualize_tsp(self, tour=None):
        """Plot the nodes (labelled by index) and, if given, ``tour`` as a closed loop."""
        nodes = self.nodes
        plt.figure(figsize=(8, 6))
        plt.scatter(nodes[:, 0], nodes[:, 1], color='blue', label='Nodes')
        for i, (x, y) in enumerate(nodes):
            plt.text(x + 2, y + 2, str(i), color="red", fontsize=12)

        if tour is not None:
            for i in range(len(tour) - 1):
                plt.plot([nodes[tour[i]][0], nodes[tour[i + 1]][0]], [nodes[tour[i]][1], nodes[tour[i + 1]][1]], 'r-')
            plt.plot([nodes[tour[-1]][0], nodes[tour[0]][0]], [nodes[tour[-1]][1], nodes[tour[0]][1]], 'r-')
            plt.title('TSP Solution with Optimal Tour')
        else:
            plt.title('TSP Nodes')

        plt.xlabel('X')
        plt.ylabel('Y')
        plt.legend()
        plt.grid(True)
        plt.show()

class Template_Padding:
    """Renders the TSP-specific parts of an optimizer meta-prompt.

    Owns an ``OPRO4TSPData`` instance and a pool of ``(trace, length)``
    samples. It produces ``input_TSP`` (the node coordinate list) and
    ``input_exemplars`` (up to 20 of the shortest distinct traces, longest
    first), which ``Add_All`` substitutes into a meta-prompt template.
    """

    def __init__(self, nodes_number):
        # Exact text inserted into every LLM prompt (leading spaces included).
        self.exemplars_template = "\n    <trace>{}</trace>\n    length:\n    {}\n    "
        self.nodes_number = nodes_number
        self.TSP = OPRO4TSPData(nodes_number=self.nodes_number)
        # Global NumPy seed so the five starting tours are reproducible.
        np.random.seed(1001)

        starting_traces = []
        for _ in range(5):
            perm = np.random.permutation(self.nodes_number).tolist()
            pivot = perm.index(0)
            # Rotate so every trace begins at node 0.
            starting_traces.append(perm[pivot:] + perm[:pivot])

        lengths = [self.TSP.evaluator(trace) for trace in starting_traces]
        self.samples = sorted(
            zip(starting_traces, lengths), key=lambda pair: pair[1], reverse=True
        )

        self.input_exemplars = ""
        self.input_TSP = ""
        self.Add_Padding()
        self.Convert_TSP_Points()

    def Convert_TSP_Points(self):
        """Set ``input_TSP`` to '(i): (x, y), ...' for every node."""
        self.input_TSP = ", ".join(
            f"({idx}): ({px}, {py})" for idx, (px, py) in enumerate(self.TSP.nodes)
        )

    def Add_Padding(self):
        """Sort samples longest-first and rebuild ``input_exemplars``.

        Duplicate (trace, length) pairs are collapsed; the 20 shortest distinct
        ones are rendered in descending length order.
        """
        self.samples.sort(key=lambda pair: pair[1], reverse=True)

        distinct = list(set((tuple(trace), length) for trace, length in self.samples))
        distinct.sort(key=lambda pair: pair[1], reverse=True)

        self.input_exemplars = "".join(
            self.exemplars_template.format(",".join(str(node) for node in trace), length)
            for trace, length in distinct[-20:]
        )

    def Add_Sample(self, generated_trace):
        """Record ``generated_trace`` together with its tour length."""
        length = self.TSP.evaluator(generated_trace)
        self.samples.append((generated_trace, length))

    def Add_All(self, Meta_Prompt):
        """Fill ``Meta_Prompt``'s two ``{}`` slots with the points and exemplars."""
        return Meta_Prompt.format(self.input_TSP, self.input_exemplars)


class RequestLLM:
    """Asks an OpenAI chat model for new TSP traces (the inner OPRO loop).

    The meta-prompt is filled in by ``self.Template`` (TSP coordinates plus the
    current exemplar traces). Each reply is parsed for a ``<trace>...</trace>``
    block that must hold a permutation of the node ids.

    Args:
        nodes_number: number of TSP nodes.
        model: chat-completions model name.
        steps: sampling rounds per ``Prompt_LLM_Step`` call.
        epoch: concurrent temperature-1 requests per round.
    """

    # Seconds to wait before re-requesting after the API returned no usable
    # reply (non-200 status, or a body without message content). This keeps a
    # persistent API error from turning into a tight request loop.
    api_retry_delay = 5.0

    def __init__(self, nodes_number, model='gpt-3.5-turbo', steps=100, epoch=8):
        self.API_key = '[API_KEY]'
        self.model = model
        self.steps = steps
        self.epoch = epoch
        self.nodes_number = nodes_number
        self.Template = Template_Padding(self.nodes_number)
        self.latest_batch_score = 0  # not read within this class
        self.stable_score = 0        # tour length of the latest temperature-0 trace
        self.count_time = 0          # number of Prompt_LLM_Step calls so far

    async def fetch_response(self, session, url, payload):
        """POST ``payload`` to ``url`` and return the decoded JSON body.

        HTTP 429 is retried with exponential backoff (10, 20, 40, 80, 160 s).
        Any exception is retried after 30 s. Any other non-200 status returns
        None after printing the error body.

        Raises:
            RuntimeError: after 5 unsuccessful attempts.
        """
        headers = {
            "Authorization": f"Bearer {self.API_key}",
            "Content-Type": "application/json",
        }

        for retry_count in range(5):
            try:
                async with session.post(url, json=payload, headers=headers) as response:
                    if response.status == 429:
                        wait_time = 10 * (2 ** retry_count)
                        print(f"Rate limit exceeded. Retrying in {wait_time} seconds...")
                        await asyncio.sleep(wait_time)
                        continue
                    if response.status != 200:
                        error_text = await response.text()
                        print(f"Error response: {error_text}")
                        return None
                    return await response.json()
            except Exception as e:
                print(f"Error occurred: {e}")
                await asyncio.sleep(30)

        raise RuntimeError("Exceeded maximum retry attempts.")

    @staticmethod
    def _extract_content(response):
        """Return the stripped text of the first choice's message, or None if absent."""
        if not isinstance(response, dict):
            return None
        choices = response.get("choices")
        if not choices or not isinstance(choices[0], dict):
            return None
        message = choices[0].get("message")
        if not isinstance(message, dict) or message.get("content") is None:
            return None
        return message["content"].strip()

    def _parse_trace(self, content):
        """Return the permutation in ``content`` rotated to start at node 0, or None.

        The first ``<trace>...</trace>`` block may span several lines. Extra
        trailing ids beyond ``nodes_number`` (e.g. a repeated start node) are
        trimmed before validation.
        """
        trace_match = re.search(r"<trace>(.*?)</trace>", content, re.DOTALL)
        if not trace_match:
            print("Invalid trace format: Missing <trace> tags. Retrying...")
            return None

        trace_str = trace_match.group(1).strip()
        if not re.match(r"^\d+(,\s*\d+)*$", trace_str):
            print("Invalid trace format: Incorrect number sequence. Retrying...")
            return None

        trace_list = list(map(int, re.split(r",\s*", trace_str)))

        if len(trace_list) > self.nodes_number:
            trace_list = trace_list[:self.nodes_number]
            print(f"Trimmed trace to {self.nodes_number} elements: {trace_list}")

        if (
            len(trace_list) == self.nodes_number
            and len(set(trace_list)) == self.nodes_number
            and all(0 <= node < self.nodes_number for node in trace_list)
        ):
            start = trace_list.index(0)
            return trace_list[start:] + trace_list[:start]

        print("Invalid trace. Retrying...")
        return None

    async def generate_trace(self, Meta_Prompt, session, temperature=1):
        """Query the model until it returns a valid trace; return it (starts at 0).

        NOTE: there is no attempt limit. With temperature=0 a reply that is
        consistently malformed will be re-requested indefinitely.
        """
        url = "https://api.openai.com/v1/chat/completions"
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": self.Template.Add_All(Meta_Prompt)}],
            "temperature": temperature,
        }

        while True:
            response = await self.fetch_response(session, url, payload)
            content = self._extract_content(response)
            if content is None:
                print("Invalid or empty response from API. Retrying...")
                await asyncio.sleep(self.api_retry_delay)
                continue

            print(f"Generated content (Temp={temperature}): {content}")

            trace_list = self._parse_trace(content)
            if trace_list is not None:
                print(f"Valid trace: {trace_list}")
                return trace_list

    async def Prompt_LLM_Epoch(self, Meta_Prompt, temperature=1):
        """Request ``self.epoch`` traces concurrently and add them all as samples."""
        async with ClientSession() as session:
            results = await asyncio.gather(*(
                self.generate_trace(Meta_Prompt, session, temperature=temperature)
                for _ in range(self.epoch)
            ))
        for trace_list in results:
            if trace_list:
                self.Template.Add_Sample(trace_list)

    async def Prompt_LLM_Epoch_Temp_0(self, Meta_Prompt):
        """Get one temperature-0 trace, store its length in ``stable_score``, and return it."""
        print(self.Template.Add_All(Meta_Prompt))
        async with ClientSession() as session:
            while True:
                trace_list = await self.generate_trace(Meta_Prompt, session, temperature=0)
                if trace_list:
                    self.stable_score = self.Template.TSP.evaluator(trace_list)
                    self.Template.Add_Sample(trace_list)
                    print(f"Temp=0 Trace: {trace_list}, Score: {self.stable_score}")
                    return trace_list

    async def Prompt_LLM_Step(self, Meta_Prompt):
        """Score ``Meta_Prompt`` at temperature 0, then run ``steps`` sampling rounds."""
        self.count_time += 1
        if self.count_time % 5 == 0:
            print("Wait for 10 seconds ...")
            # Non-blocking pause: time.sleep would stall the whole event loop.
            await asyncio.sleep(10)
        await self.Prompt_LLM_Epoch_Temp_0(Meta_Prompt)
        for i in range(self.steps):
            await self.Prompt_LLM_Epoch(Meta_Prompt, temperature=1)
            # Add_Padding sorts samples longest-first, so [-1] is the best.
            self.Template.Add_Padding()

            best_trace, best_value = self.Template.samples[-1][0], self.Template.samples[-1][1]

            best_trace_str = ','.join(map(str, best_trace))
            print(f"Step {i + 1}: The Score of the Meta-Prompt: {self.stable_score:d}")
            print(f"Step {i + 1}: The Best Performance, trace: {best_trace_str}, value={best_value:d}")


class Network(nn.Module):
    """Fully connected ReLU regressor mapping an input vector to one scalar.

    Layer widths are ``input_dim -> hidden_size`` followed by ``depth - 1``
    further ``hidden_size -> hidden_size`` layers and a final
    ``hidden_size -> 1`` head. ReLU is applied between layers, with no
    activation on the head. The trailing size-1 axis is squeezed away.

    Args:
        input_dim: size of the last input axis.
        hidden_size: width of every hidden layer.
        depth: number of hidden layers (values below 1 behave like 1).
        init_params: optional flat sequence ``[w0, b0, w1, b1, ...]`` of
            tensors assigned to each layer's ``.data``. When omitted, every
            weight and bias is drawn from N(0, 1).
    """

    def __init__(self, input_dim, hidden_size=100, depth=1, init_params=None):
        super().__init__()

        self.activate = nn.ReLU()
        widths = [input_dim] + [hidden_size] * max(depth, 1) + [1]
        self.layer_list = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

        if init_params is not None:
            for k, layer in enumerate(self.layer_list):
                layer.weight.data = init_params[2 * k]
                layer.bias.data = init_params[2 * k + 1]
            return

        # Weight then bias, layer by layer: same RNG consumption order as before.
        for layer in self.layer_list:
            nn.init.normal_(layer.weight, mean=0, std=1.0)
            nn.init.normal_(layer.bias, mean=0, std=1.0)

    def forward(self, x):
        *hidden_layers, head = self.layer_list
        for layer in hidden_layers:
            x = self.activate(layer(x))
        return head(x).squeeze(-1)

class PromptEmbedding:
    """Connects the meta-prompt table to the LLM: embeds text and scores prompts.

    Owns a ``RequestLLM`` and, through it, the TSP prompt template. ``forward``
    runs one OPRO step with a given meta-prompt. It returns that prompt's
    embedding together with the tour length obtained at temperature 0.
    """

    def __init__(self, z1_z2_prompt_data, nodes_number, Optimizer_Model='gpt-3.5-turbo', Embedding_model='text-embedding-3-large'):
        self.API_key = '[API_KEY]'
        self.z1_z2_prompt = z1_z2_prompt_data
        self.Embedding_model = Embedding_model
        # Optimizer_Model selects the chat model that proposes TSP traces
        # (e.g. 'gpt-3.5-turbo' or 'gpt-4-turbo').
        self.RequestLLM = RequestLLM(nodes_number, model=Optimizer_Model, steps=1, epoch=7)

    def normalize_l2(self, x):
        """Return ``x`` as a NumPy array scaled to unit L2 norm (all-zero input unchanged)."""
        x = np.array(x)
        norm = np.linalg.norm(x)
        if norm == 0:
            return x
        return x / norm

    def Get_Single_Embedding(self, text, target_dim=3072):
        """Embed ``text`` via the OpenAI embeddings endpoint.

        Newlines are replaced by spaces. The vector is truncated to its first
        ``target_dim`` entries and L2-normalised, then returned as a list.
        Uses the pre-1.0 ``openai`` module interface (``openai.Embedding``,
        ``openai.error``). Retries without limit: 10 s after a rate-limit
        error, 5 s after any other exception.
        """
        openai.api_key = self.API_key
        text = text.replace("\n", " ")

        while True:
            try:
                # Small random jitter between requests.
                time.sleep(random.uniform(0.5, 1.5))

                response = openai.Embedding.create(
                    input=[text],
                    model=self.Embedding_model
                )
                embedding = response['data'][0]['embedding']
                return list(self.normalize_l2(embedding[:target_dim]))

            except openai.error.RateLimitError:
                print("Rate limit exceeded. Waiting before retrying...")
                time.sleep(10)

            except Exception as e:
                print(f"An error occurred: {e}")
                print("Retrying...")
                time.sleep(5)

    def PaddingforPrompt(self):
        """Fill every row's 'Merged_Content' template into 'Completed_Prompt'."""
        tqdm.pandas()
        self.z1_z2_prompt['Completed_Prompt'] = self.z1_z2_prompt['Merged_Content'].progress_apply(lambda x: self.RequestLLM.Template.Add_All(x))

    def Get_Batch_Embedding(self, column='PromptEmbedding', include_exemplars=False):
        """Refresh completed prompts and write the per-row feature vector to ``column``.

        By default the feature is the row's 'Z1-Z2-Embedding' unchanged, and no
        API call is made. With ``include_exemplars=True``, the embedding of the
        current exemplar text is appended to every row's vector (one extra
        embedding API call). That makes the vectors longer than the default.
        """
        self.PaddingforPrompt()
        if include_exemplars:
            exemplars_embedding = self.Get_Single_Embedding(self.RequestLLM.Template.input_exemplars)
            tqdm.pandas()
            self.z1_z2_prompt[column] = self.z1_z2_prompt['Z1-Z2-Embedding'].progress_apply(
                lambda emb: emb + exemplars_embedding
            )
        else:
            self.z1_z2_prompt[column] = self.z1_z2_prompt['Z1-Z2-Embedding'].copy()

    def forward(self, current_point, current_z1, current_z2, z1_z2_Embedding):
        """Run one OPRO step with meta-prompt ``current_point`` and score it.

        The prompt columns are refreshed first, because ``Prompt_LLM_Step``
        updates the exemplar text between calls. ``current_z1`` and
        ``current_z2`` are accepted for interface compatibility but are not
        used.

        Returns:
            ``[z1_z2_Embedding, tour length of the temperature-0 trace]``
        """
        self.Get_Batch_Embedding()
        asyncio.run(self.RequestLLM.Prompt_LLM_Step(current_point))
        current_point_score = self.RequestLLM.stable_score
        return [z1_z2_Embedding, current_point_score]

class AdversarialBandit:
    """Meta-prompt search: neural score regressor plus EXP3 sampling over prompts.

    Each optimization step does three things:
      1. Evaluate the current meta-prompt via ``PromptEmbedding.forward``
         (one OPRO step on the TSP instance).
      2. Retrain ``self.model`` from scratch to regress the observed tour
         length from the prompt embedding.
      3. Sample the next meta-prompt from the candidate table, using
         exponential weights over the cumulative shifted predictions.

    Args:
        z1_z2_prompt_data: candidate table. Columns read: 'Merged_Content',
            'Z1', 'Z2', 'Z1-Z2-Embedding', and 'PromptEmbedding' once built by
            ``PromptEmbedding.Get_Batch_Embedding``. 'Score' (and 'UCB_Score'
            for Neural-UCB) are written back.
        nodes_number: TSP size.
        input_dim: width of the embedding fed to the regressor.
        steps: number of optimization steps.
        NumberofK: stored only; not read by this class.
        device: torch device for the regressor (default "cuda").
    """

    def __init__(self, z1_z2_prompt_data, nodes_number, input_dim=6144, steps=300, NumberofK=10201, device="cuda"):
        self.z1_z2_prompt_data = z1_z2_prompt_data
        self.input_dim = input_dim
        self.steps = steps
        self.NumberofK = NumberofK
        self.device = torch.device(device)

        # Regressor: prompt embedding -> predicted tour length.
        self.model = Network(self.input_dim, 1536, 1).to(self.device).float()

        self.lamdba = 0.1  # Neural-UCB regulariser (attribute name kept as-is)
        self.nu = 1        # Neural-UCB exploration weight

        # NOTE(review): self.func is never read in this class; kept only in
        # case external code relies on the attribute.
        self.func = Network(self.input_dim).to(dtype=torch.float32, device=self.device)
        self.total_param = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        self.best_result_list = []

        self.lr = 1e-3
        self.epochs = 5000
        self.optimizer_fn = torch.optim.Adam
        self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1)
        self.loss_fn = nn.MSELoss()

        # Snapshot used by restart_model() to retrain from the same init.
        self.init_model_weight = copy.deepcopy(self.model.state_dict())

        self.eta = 100                  # EXP3 learning rate
        self.prompt_score_trace = []    # per-round shifted scores (list of lists)
        self.embedding_score_trace = []  # observed [embedding, tour length] pairs

        self.current_point = """You are given a list of points with coordinates below:{}

Below are some previous traces and their lengths. The traces are arranged in descending order based on their lengths, where smaller lengths indicate better solutions. Therefore, the traces are listed from the largest length to the smallest, the trace with the smallest length is considered the most optimal.

{}

Give me a new trace that is different from all traces above, and has a length lower than any of the above. The trace should traverse all points exactly once. The trace should start with <trace> and end with </trace>.
    """
        self.current_z1 = """You are given a list of points with coordinates below:{}

Below are some previous traces and their lengths. The traces are arranged in descending order based on their lengths, where smaller lengths indicate better solutions. Therefore, the traces are listed from the largest length to the smallest, the trace with the smallest length is considered the most optimal."""
        self.current_z2 = """Give me a new trace that is different from all traces above, and has a length lower than any of the above. The trace should traverse all points exactly once. The trace should start with <trace> and end with </trace>."""

        self.prompt_embedding = PromptEmbedding(z1_z2_prompt_data, nodes_number, Optimizer_Model='gpt-3.5-turbo', Embedding_model='text-embedding-3-large')

        self.z1_z2_Embedding = self.prompt_embedding.Get_Single_Embedding(self.current_z1) + self.prompt_embedding.Get_Single_Embedding(self.current_z2)

        self.c = None           # score shift fixed at the first select() call
        self.scheduler = None
        self.Stop_EXP3 = 0
        self.index_history = []
        self.best_index = None  # last index sampled by select()

    def restart_model(self):
        """Reset the regressor to its initial weights with a fresh optimiser and LR schedule."""
        self.model.load_state_dict(copy.deepcopy(self.init_model_weight))
        # Weight decay shrinks as more observations accumulate.
        self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1 / max(1, len(self.embedding_score_trace)))
        # LR is multiplied by 0.1 after epoch 3000.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 0.1 if epoch > 3000 else 1.0)

    def train_model(self, batch_size=64):
        """Retrain the regressor from scratch on every recorded (embedding, score) pair."""
        if not self.embedding_score_trace:
            return
        self.restart_model()
        self.model.to(self.device)

        embeddings_tensor = torch.stack([torch.tensor(item[0], dtype=torch.float32) for item in self.embedding_score_trace])
        scores_tensor = torch.tensor([item[1] for item in self.embedding_score_trace], dtype=torch.float32)

        dataset = TensorDataset(embeddings_tensor, scores_tensor)
        # Pinned host memory only helps (and is only meaningful) for CUDA transfers.
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=self.device.type == "cuda")

        self.model.train()
        for epoch in tqdm(range(self.epochs), desc="Training Model"):
            epoch_loss = 0
            for batch_embeddings, batch_scores in dataloader:
                batch_embeddings = batch_embeddings.to(self.device)
                batch_scores = batch_scores.to(self.device)

                self.optimizer.zero_grad()
                loss = self.loss_fn(self.model(batch_embeddings), batch_scores)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()

            self.scheduler.step()

            if epoch % 100 == 0:
                print(f"Epoch {epoch}/{self.epochs}, Loss: {epoch_loss / len(dataloader)}")

        print("Training complete.")

    def select(self, iteration):
        """EXP3 step: sample the next meta-prompt from exponential weights.

        Predicted tour lengths s are mapped to rewards (c - s) / c, where c is
        the largest prediction seen at the first call. Rewards are accumulated
        over rounds, and the sampling distribution is
        softmax(eta * cumulative reward). If that distribution contains NaN/Inf
        or has no mass, ``Stop_EXP3`` is set and nothing is selected.
        """
        embeddings = torch.stack(
            [torch.tensor(e, dtype=torch.float32) for e in self.prompt_embedding.z1_z2_prompt['PromptEmbedding']]
        ).to(self.device)

        with torch.no_grad():
            # reshape (not squeeze) keeps a 1-D vector even for a single prompt.
            scores = self.model(embeddings).reshape(-1)

        if iteration == 0 or self.c is None:
            self.c = scores.max().item()

        print(f"Shift value (self.c): {self.c}")

        shifted_scores = (-scores + self.c) / self.c
        round_scores = shifted_scores.cpu().tolist()
        self.z1_z2_prompt_data['Score'] = round_scores

        # Keep an independent per-round snapshot rather than a reference to
        # the DataFrame column, which is reassigned on every call.
        self.prompt_score_trace.append(list(round_scores))
        cumulative = torch.tensor(self.prompt_score_trace, dtype=torch.float32).sum(dim=0)
        # softmax subtracts the max internally, matching the manual stable form.
        normalized_exp_scores = torch.softmax(self.eta * cumulative, dim=0)

        k = min(10, normalized_exp_scores.numel())
        top_scores, top_indices = torch.topk(normalized_exp_scores, k)
        print("Top 10 probabilities: ", top_scores)
        print("Indices of Top 10 probabilities: ", top_indices)

        if (torch.isnan(normalized_exp_scores).any()
                or torch.isinf(normalized_exp_scores).any()
                or torch.sum(normalized_exp_scores) <= 0):
            print(f"The final index is {getattr(self, 'best_index', None)}")
            print("Encountered NaN or Inf in normalized_exp_scores. Stopping EXP3.")
            self.Stop_EXP3 = 1
            return

        self.best_index = torch.multinomial(normalized_exp_scores, 1).item()

        print(f"Selected index: {self.best_index}")
        row = self.z1_z2_prompt_data.iloc[self.best_index]
        self.current_point = row['Merged_Content']
        self.current_z1 = row['Z1']
        self.current_z2 = row['Z2']
        self.z1_z2_Embedding = row['Z1-Z2-Embedding']

        print("The shifted scores: ")
        print(shifted_scores)

    def calculate_gradient(self, x):
        """Return d(sum(model(x)))/d(params) flattened into one CPU vector."""
        input_tensor = torch.tensor(x, dtype=torch.float32, device=self.device)
        output = self.model(input_tensor)

        self.model.zero_grad()
        torch.sum(output).backward()

        gradients = torch.cat([p.grad.flatten() for p in self.model.parameters()])
        return gradients.detach().cpu()

    def NeuralUCBSelect(self):
        """Alternative to ``select``: pick the prompt with the highest Neural-UCB score.

        The score is -prediction + nu * sqrt(sum(lamdba * g^2 / U)), where g is
        the parameter gradient. ``U`` is a diagonal design matrix updated with
        the chosen prompt's g^2.
        """
        tqdm.pandas()
        if not hasattr(self, 'U'):
            self.U = self.lamdba * torch.ones(self.total_param, device=self.device)

        def calculate_ucb(row):
            with torch.no_grad():
                score = -self.model(torch.tensor(row['PromptEmbedding'], dtype=torch.float32, device=self.device)).item()
            gradient = self.calculate_gradient(row['PromptEmbedding']).to(self.device)
            sigma = torch.sqrt(torch.sum(self.lamdba * gradient ** 2 / self.U))
            return score + self.nu * sigma.item()

        self.z1_z2_prompt_data['UCB_Score'] = self.z1_z2_prompt_data.progress_apply(calculate_ucb, axis=1)

        best_index = int(self.z1_z2_prompt_data['UCB_Score'].to_numpy().argmax())
        row = self.z1_z2_prompt_data.iloc[best_index]
        self.current_point = row['Merged_Content']
        self.current_z1 = row['Z1']
        self.current_z2 = row['Z2']
        # Must be the same attribute that optimization() passes to forward().
        self.z1_z2_Embedding = row['Z1-Z2-Embedding']

        print(f"Selected prompt index: {best_index}, with UCB score: {row['UCB_Score']}")

        selected_gradient = self.calculate_gradient(row['PromptEmbedding']).to(self.device)
        self.U += selected_gradient ** 2

    def add_sample(self, forward_result):
        """Append ``[embedding, score]`` and keep only the latest score per embedding.

        Deduplicated entries keep the position of their first occurrence.
        """
        self.embedding_score_trace.append(forward_result)
        unique_samples = {tuple(embedding): (embedding, score) for embedding, score in self.embedding_score_trace}
        self.embedding_score_trace = list(unique_samples.values())

    def optimization(self):
        """Run ``self.steps`` evaluate / retrain / select rounds; record best length per step."""
        for i in tqdm(range(self.steps), desc="Optimizing the Meta-Prompt " + str(self.steps) + " step"):
            forward_result = self.prompt_embedding.forward(self.current_point, self.current_z1, self.current_z2, self.z1_z2_Embedding)
            self.embedding_score_trace.append(forward_result)
            print(len(self.embedding_score_trace))
            if not self.Stop_EXP3:
                self.train_model()
                self.select(i)
                # Use self.NeuralUCBSelect() here instead for Neural-UCB selection.

            # Template samples are sorted longest-first by Add_Padding, so the
            # last entry holds the shortest tour found so far.
            self.best_result_list.append(self.prompt_embedding.RequestLLM.Template.samples[-1][1])
            print(self.best_result_list)

# Run the full search three times, each on a fresh copy of the prompt table,
# and collect every run's per-step best tour length.
Converge_curve = []

for i in range(3):
    # Deep copy: AdversarialBandit writes columns (e.g. 'Score') into the
    # table it receives, so runs must not share one DataFrame.
    new_Meta_Prompt_data = copy.deepcopy(Meta_Prompt_data)
    new_optimization = AdversarialBandit(new_Meta_Prompt_data, 20)
    new_optimization.optimization()

    run_curve = copy.deepcopy(new_optimization.best_result_list)
    Converge_curve.append(run_curve)
    print(Converge_curve)