import numpy as np
import matplotlib.pyplot as plt
import re
from tqdm.auto import tqdm
import copy

import gurobipy as gp
from gurobipy import GRB

import asyncio
from aiohttp import ClientSession

import nest_asyncio
nest_asyncio.apply()


class OPRO4TSPData:
    """Random Euclidean TSP instance with an optional exact reference solution.

    Node coordinates are integers drawn from [-100, 100] on both axes.  Unless
    the instance has exactly 50 nodes, it is solved with Gurobi when the object
    is constructed.

    Attributes:
        nodes: ``(nodes_number, 2)`` integer array of node coordinates.
        nodes_number: number of nodes in the instance.
        adj_matrix: ``(nodes_number, nodes_number)`` array of pairwise
            Euclidean distances.
        optimal_tour: node order returned by ``solve_tsp``, or ``None`` when
            the solve was skipped.
        optimal_distance: ``evaluator`` score of ``optimal_tour``, or ``None``
            when the solve was skipped.
    """

    def __init__(self, nodes_number, seed=3407):
        """Generate the instance and, for sizes other than 50, solve it.

        Args:
            nodes_number: number of nodes to generate.
            seed: seed applied to NumPy's global random generator.
        """
        # NOTE: this reseeds the *global* NumPy generator, which affects any
        # later np.random call made elsewhere in the process.
        np.random.seed(seed)
        self.nodes = np.random.randint(-100, 101, size=(nodes_number, 2))
        self.nodes_number = nodes_number
        self.adj_matrix = self.compute_euclidean_distance_matrix()

        # Always define both attributes so that reading them on a 50-node
        # instance yields None instead of raising AttributeError.
        self.optimal_tour = None
        self.optimal_distance = None
        if nodes_number != 50:
            self.optimal_tour, self.optimal_distance = self.solve_tsp()

    def compute_euclidean_distance_matrix(self):
        """Return the matrix of Euclidean distances between all node pairs."""
        nodes = np.asarray(self.nodes)
        # Broadcasting produces every pairwise coordinate difference at once;
        # the norm over the last axis collapses each (dx, dy) to a distance.
        deltas = nodes[:, np.newaxis, :] - nodes[np.newaxis, :, :]
        return np.linalg.norm(deltas, axis=-1)

    def evaluator(self, trace):
        """Return the length of the closed tour ``trace``, truncated to int.

        Args:
            trace: sequence of node indices in visiting order.  The edge from
                the last node back to the first is included.

        Returns:
            The total tour length truncated towards zero; 0 for an empty trace.
        """
        if len(trace) == 0:
            return 0
        total_distance = 0
        for here, there in zip(trace, trace[1:]):
            total_distance += self.adj_matrix[here, there]
        # Close the cycle.
        total_distance += self.adj_matrix[trace[-1], trace[0]]
        return int(total_distance)

    def solve_tsp(self):
        """Solve the instance with Gurobi using iterative subtour elimination.

        The model starts with the assignment constraints only (one outgoing
        and one incoming edge per node).  After each solve the shortest cycle
        in the solution is located; if it does not cover every node, a
        constraint forbidding that cycle is added and the model is re-solved.

        Returns:
            ``(optimal_tour, optimal_distance)`` where ``optimal_tour`` is a
            list of node indices and ``optimal_distance`` is its
            ``evaluator`` score.

        Raises:
            RuntimeError: if Gurobi finishes without any solution.
        """
        dist_matrix = self.adj_matrix
        n = self.nodes_number

        m = gp.Model()

        # x[i, j] == 1 when the tour travels directly from node i to node j.
        x = m.addVars(n, n, vtype=GRB.BINARY, name="x")

        m.setObjective(
            gp.quicksum(dist_matrix[i, j] * x[i, j] for i in range(n) for j in range(n)),
            GRB.MINIMIZE,
        )

        # Exactly one outgoing and one incoming edge per node, no self-loops.
        m.addConstrs(gp.quicksum(x[i, j] for j in range(n) if i != j) == 1 for i in range(n))
        m.addConstrs(gp.quicksum(x[i, j] for i in range(n) if i != j) == 1 for j in range(n))
        m.addConstrs(x[i, i] == 0 for i in range(n))

        def subtour(edge_values):
            """Return the shortest cycle among the edges selected in ``edge_values``."""
            unvisited = list(range(n))
            shortest = []
            while unvisited:
                this_cycle = []
                frontier = [unvisited[0]]
                while frontier:
                    current = frontier.pop()
                    if current in unvisited:
                        this_cycle.append(current)
                        unvisited.remove(current)
                        frontier += [j for j in range(n) if edge_values[current, j] > 0.5]
                if not shortest or len(this_cycle) < len(shortest):
                    shortest = this_cycle
            return shortest

        while True:
            m.optimize()
            if m.SolCount == 0:
                raise RuntimeError(f"Gurobi returned no solution (status {m.Status})")
            solution = m.getAttr('x', x)
            tour = subtour(solution)
            if len(tour) == n:
                break
            # Forbid the detected subtour: its nodes may share at most
            # len(tour) - 1 edges among themselves.
            m.addConstr(gp.quicksum(x[i, j] for i in tour for j in tour if i != j) <= len(tour) - 1)

        optimal_tour = tour
        optimal_distance = self.evaluator(optimal_tour)
        print(f"Optimal tour: {optimal_tour}")
        print(f"Optimal distance: {optimal_distance}")

        return optimal_tour, optimal_distance

    def visualize_tsp(self, tour=None):
        """Plot the nodes and, when ``tour`` is given, the closed tour.

        Args:
            tour: optional sequence of node indices; consecutive nodes are
                joined by a line and the last node is joined to the first.
        """
        nodes = self.nodes
        plt.figure(figsize=(8, 6))
        plt.scatter(nodes[:, 0], nodes[:, 1], color='blue', label='Nodes')
        for i, (x, y) in enumerate(nodes):
            # Offset the label slightly so it does not cover the marker.
            plt.text(x + 2, y + 2, str(i), color="red", fontsize=12)

        if tour is not None:
            # Append the start node so the final segment closes the cycle.
            closed_tour = list(tour) + [tour[0]]
            for start, end in zip(closed_tour, closed_tour[1:]):
                plt.plot([nodes[start][0], nodes[end][0]], [nodes[start][1], nodes[end][1]], 'r-')
            plt.title('TSP Solution with Optimal Tour')
        else:
            plt.title('TSP Nodes')

        plt.xlabel('X')
        plt.ylabel('Y')
        plt.legend()
        plt.grid(True)
        plt.show()

class Template_Padding:
  """Holds the evaluated traces and renders them into the meta-prompt.

  ``samples`` is a list of ``(trace, length)`` pairs.  ``input_exemplars`` is
  the text block listing the shortest distinct traces, and ``input_TSP`` is
  the text listing of the node coordinates.
  """

  def __init__(self, nodes_number):
    """Create the TSP instance and seed ``samples`` with five random tours."""
    self.exemplars_template = """
    <trace>{}</trace>
    length:
    {}
    """
    self.nodes_number = nodes_number
    self.TSP = OPRO4TSPData(nodes_number=self.nodes_number)

    # Seeded after the instance is built, so the starting tours are
    # reproducible independently of the seed used for the coordinates.
    np.random.seed(1001)
    starting_tours = []
    for _ in range(5):
      permutation = np.random.permutation(self.nodes_number).tolist()
      # Rotate the permutation so that every tour begins at node 0.
      pivot = permutation.index(0)
      starting_tours.append(permutation[pivot:] + permutation[:pivot])

    self.samples = [(tour, self.TSP.evaluator(tour)) for tour in starting_tours]
    self.samples.sort(key=lambda pair: pair[1], reverse=True)

    self.input_exemplars = ""
    self.input_TSP = ""
    self.Add_Padding()
    self.Convert_TSP_Points()

  def Convert_TSP_Points(self):
    """Render the node coordinates as ``(index): (x, y)`` entries into ``input_TSP``."""
    entries = []
    for index, (x, y) in enumerate(self.TSP.nodes):
      entries.append(f"({index}): ({x}, {y})")
    self.input_TSP = ", ".join(entries)

  def Add_Padding(self):
    """Rebuild ``input_exemplars`` from the 20 shortest distinct samples.

    ``samples`` is sorted in place by descending length, so its last element
    is a shortest trace once this method returns.
    """
    def by_length(entry):
      return entry[1]

    self.samples.sort(key=by_length, reverse=True)

    # Traces are converted to tuples so duplicates collapse in the set.
    distinct = {(tuple(entry[0]), entry[1]) for entry in self.samples}
    ranked = sorted(distinct, key=by_length, reverse=True)

    # Descending order means the tail holds the shortest traces.
    self.input_exemplars = "".join(
        self.exemplars_template.format(','.join(map(str, trace)), length)
        for trace, length in ranked[-20:]
    )

  def Add_Sample(self, generated_trace):
    """Evaluate ``generated_trace`` and append it to ``samples``."""
    length = self.TSP.evaluator(generated_trace)
    self.samples.append((generated_trace, length))

  def Add_All(self, Meta_Prompt):
    """Fill ``Meta_Prompt`` with the coordinate listing and the exemplars."""
    return Meta_Prompt.format(self.input_TSP, self.input_exemplars)

class RequestLLM:
    """Runs the optimisation loop that asks a chat model for shorter TSP tours.

    Each step sends the current meta-prompt once at temperature 0 and
    ``epoch`` times concurrently at temperature 1, stores every valid trace in
    the ``Template_Padding`` sample pool, and records the best length so far
    in ``best_result_list``.
    """

    # Seconds to wait before re-sending a request that yielded no usable
    # response, so persistent HTTP errors do not turn into a tight loop.
    FAILED_REQUEST_DELAY = 1

    # DOTALL lets the tags enclose a trace that the model split across lines.
    TRACE_TAG_PATTERN = re.compile(r"<trace>(.*?)</trace>", re.DOTALL)
    TRACE_BODY_PATTERN = re.compile(r"^\d+(,\s*\d+)*$")

    def __init__(self, meta_prompt, nodes_number, model='gpt-3.5-turbo', steps=300, epoch=8, api_key=None):
        """Store the run configuration and build the sample pool.

        Args:
            meta_prompt: stored on the instance as ``meta_prompt``.
            nodes_number: number of nodes in the TSP instance.
            model: model name placed in the request payload.
            steps: number of optimisation steps run by ``Prompt_LLM_Step``.
            epoch: number of concurrent temperature-1 requests per step.
            api_key: bearer token for the API.  When omitted, the
                ``OPENAI_API_KEY`` environment variable is used, and finally
                the placeholder value.
        """
        import os

        self.API_key = api_key or os.environ.get("OPENAI_API_KEY", '[API_KEY]')
        self.model = model
        self.meta_prompt = meta_prompt
        self.steps = steps
        self.epoch = epoch
        self.nodes_number = nodes_number
        self.Template = Template_Padding(self.nodes_number)
        self.best_result_list = []

    async def fetch_response(self, session, url, payload):
        """POST ``payload`` to ``url`` and return the decoded JSON body.

        Makes up to five attempts.  HTTP 429 is retried with exponential
        backoff and exceptions are retried after ten seconds.

        Returns:
            The decoded JSON response, or ``None`` for any other non-200
            status.

        Raises:
            RuntimeError: if all five attempts end in a 429 or an exception.
        """
        headers = {
            "Authorization": f"Bearer {self.API_key}",
            "Content-Type": "application/json",
        }
        for retry_count in range(5):
            try:
                async with session.post(url, json=payload, headers=headers) as response:
                    if response.status == 429:
                        wait_time = 2 ** retry_count
                        print(f"Rate limit exceeded. Retrying in {wait_time} seconds...")
                        await asyncio.sleep(wait_time)
                        continue
                    if response.status != 200:
                        error_text = await response.text()
                        print(f"Error response: {error_text}")
                        return None
                    return await response.json()
            except Exception as e:
                # Deliberately broad: any failure of the request is treated as
                # transient and retried.
                print(f"Error occurred: {e}. Retrying in 10 seconds...")
                await asyncio.sleep(10)
        raise RuntimeError("Exceeded maximum retry attempts.")

    @staticmethod
    def extract_content(response):
        """Return the first choice's message text, or ``None`` if unavailable.

        Handles a missing response, an empty ``choices`` list, a choice
        without a ``message`` and a non-string ``content``.
        """
        if not isinstance(response, dict):
            return None
        choices = response.get("choices")
        if not choices or not isinstance(choices[0], dict):
            return None
        message = choices[0].get("message")
        if not isinstance(message, dict):
            return None
        content = message.get("content")
        return content if isinstance(content, str) else None

    def parse_trace(self, content):
        """Extract a valid tour from the model output ``content``.

        The first ``<trace>...</trace>`` span must hold comma-separated
        integers forming a permutation of ``range(self.nodes_number)``.

        Returns:
            The tour rotated so that it starts at node 0, or ``None`` when
            ``content`` holds no valid tour.
        """
        tag_match = self.TRACE_TAG_PATTERN.search(content)
        if not tag_match:
            print("Invalid trace format: Missing <trace> tags. Retrying...")
            return None

        trace_str = tag_match.group(1).strip()
        if not self.TRACE_BODY_PATTERN.match(trace_str):
            print("Invalid trace format: Incorrect number sequence. Retrying...")
            return None

        trace_list = [int(token) for token in re.split(r",\s*", trace_str)]
        if sorted(trace_list) != list(range(self.nodes_number)):
            print("Invalid trace: Not a permutation of all nodes. Retrying...")
            return None

        # Rotating a cycle does not change its length; starting at node 0
        # gives every stored trace the same canonical form.
        start = trace_list.index(0)
        return trace_list[start:] + trace_list[:start]

    async def generate_trace(self, Meta_Prompt, session, temperature=1):
        """Query the model until it returns a valid tour and return that tour.

        Args:
            Meta_Prompt: format string filled in by ``Template.Add_All``.
            session: HTTP session used for the requests.
            temperature: sampling temperature placed in the payload; the
                rendered prompt is printed when it equals 0.

        Returns:
            A list of node indices starting at node 0.
        """
        prompt = self.Template.Add_All(Meta_Prompt)
        if temperature == 0:
            print(prompt)
        url = "https://api.openai.com/v1/chat/completions"
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
        }
        while True:
            response = await self.fetch_response(session, url, payload)
            content = self.extract_content(response)
            if content is None:
                print("Invalid or empty response. Retrying...")
                await asyncio.sleep(self.FAILED_REQUEST_DELAY)
                continue

            content = content.strip()
            print(f"Generated content (Temp={temperature}): {content}")

            trace_list = self.parse_trace(content)
            if trace_list is not None:
                return trace_list

    async def Prompt_LLM_Epoch(self, Meta_Prompt):
        """Request ``epoch`` traces concurrently at temperature 1 and store them."""
        async with ClientSession() as session:
            tasks = [
                self.generate_trace(Meta_Prompt, session, temperature=1)
                for _ in range(self.epoch)
            ]
            results = await asyncio.gather(*tasks)
            for trace in results:
                if trace:
                    self.Template.Add_Sample(trace)

    async def Prompt_LLM_Epoch_Temp_0(self, Meta_Prompt):
        """Request one trace at temperature 0, store it and return it."""
        async with ClientSession() as session:
            while True:
                trace = await self.generate_trace(Meta_Prompt, session, temperature=0)
                if trace:
                    self.Template.Add_Sample(trace)
                    return trace

    async def Prompt_LLM_Step(self, Meta_Prompt):
        """Run ``steps`` optimisation rounds and log the best length after each."""
        for i in tqdm(range(self.steps), desc=f"Prompting LLM ({self.steps} steps): {self.model}"):
            await self.Prompt_LLM_Epoch_Temp_0(Meta_Prompt)
            await self.Prompt_LLM_Epoch(Meta_Prompt)
            self.Template.Add_Padding()

            # Add_Padding sorts the samples by descending length, so the last
            # entry is the best trace found so far.
            best_trace, best_value = self.Template.samples[-1][0], self.Template.samples[-1][1]
            best_trace_str = ','.join(map(str, best_trace))

            print(f"Step {i + 1}: The Best Performance, trace: {best_trace_str}, value={best_value:d}")
            self.best_result_list.append(best_value)
            print(self.best_result_list)

# Driver: repeat the full optimisation run three times on a 20-node instance.

# Earlier, shorter wording of the meta-prompt, kept for reference:
#   prompt_data = """You are given a list of points with coordinates below:{}
#
# Below are some previous traces and their lengths. The traces are arranged in descending order based on their lengths, where lower values are better.
#
# {}
#
# Give me a new trace that is different from all traces above, and has a length lower than any of the above. The trace should traverse all points exactly once. The trace should start with <trace> and end with </trace>.
# """

# Meta-prompt in use.  The first placeholder receives the coordinate listing
# and the second the exemplar traces.
prompt_data = """You are given a list of points with coordinates below:{}

Below are some previous traces and their lengths. The traces are arranged in descending order based on their lengths, where smaller lengths indicate better solutions. Therefore, the traces are listed from the largest length to the smallest, the trace with the smallest length is considered the most optimal.

{}

Give me a new trace that is different from all traces above, and has a length lower than any of the above. The trace should traverse all points exactly once. The trace should start with <trace> and end with </trace>.
"""
# Passed as the constructor's meta_prompt argument; the prompt that is
# actually sent is handed to Prompt_LLM_Step below.
Meta_Prompt_data = None

for i in range(3):
  new_RequestLLM = RequestLLM(
      Meta_Prompt_data,
      20,
      model='gpt-3.5-turbo',
      steps=300,
      epoch=7,
  )
  # Alternative configuration:
  #   new_RequestLLM = RequestLLM(Meta_Prompt_data, 20, model='gpt-4-turbo', steps=100, epoch=7)
  asyncio.run(new_RequestLLM.Prompt_LLM_Step(prompt_data))