import numpy as np
import matplotlib.pyplot as plt
import re
from tqdm.auto import tqdm
import copy

import torch
import openai

class OPRO4LinearRegressionData:
  """Synthetic 1-D linear-regression problem ``y = w * x + b + noise`` with an MSE objective.

  Args:
    w, b: true slope and intercept used to generate the labels.
    num_points: number of evenly spaced x values in [-1, 1] (default 50).
    seed: seed for the noise draws (default 3407).
    noise_std: factor applied to the standard-normal noise (default 1.0).

  The defaults reproduce exactly the data the original hard-coded version produced.
  """

  def __init__(self, w, b, num_points=50, seed=3407, noise_std=1.0):
    # A private legacy RandomState yields the same draws as
    # ``np.random.seed(seed); np.random.randn(...)`` but does not reseed
    # NumPy's process-wide RNG behind the caller's back.
    rng = np.random.RandomState(seed)
    x = np.linspace(-1, 1, num_points).reshape(-1, 1)
    noise = rng.randn(num_points, 1) * noise_std
    y = x * w + b + noise
    data = np.hstack((x, y))

    self.true_parameter = torch.tensor([w, b])
    # Column 0 is x, column 1 is y (float64, inherited from NumPy).
    self.whole_data = torch.tensor(data)
    # Design matrix [x, 1], so that ``self.data @ [w, b] == w * x + b``.
    self.data = torch.stack([self.whole_data[:, 0], torch.ones_like(self.whole_data[:, 0])], dim=1)
    self.label = self.whole_data[:, 1]

  def evaluator(self, generated_w, generated_b):
    """Return the mean squared error of the line (generated_w, generated_b) as a 0-d double tensor."""
    inference_parameter = torch.tensor([generated_w, generated_b], dtype=torch.double)
    return torch.mean((torch.matmul(self.data, inference_parameter) - self.label)**2)

  def mat_for_data(self):
    """Show a scatter plot of the generated (x, y) points."""
    plt.scatter(self.whole_data[:, 0], self.whole_data[:, 1], color='blue', label='x1 vs y')
    plt.xlabel('x1')
    plt.ylabel('y')
    plt.legend()
    plt.title('Random 2D Linear Regression Data')
    plt.show()

# Benchmark problem for the optimizer: true slope w=36, intercept b=-1.
# Plot it once so the generated data can be eyeballed before prompting.
New_Linear_Regression = OPRO4LinearRegressionData(w=36, b=-1)
New_Linear_Regression.mat_for_data()

class Template_Padding:
  """Keep the scored (w, b) history and render it as meta-prompt exemplars.

  ``samples`` is a list of ``(w, b, value)`` tuples, where ``value`` is
  ``Linear_Regression.evaluator(w, b).item()``. ``Add_Padding`` sorts it by value
  in descending order, so after each call the lowest-value (best) pair is last.
  """

  # Default cap on how many distinct (lowest-value) samples go into the prompt.
  max_exemplars = 20

  def __init__(self, Linear_Regression, num_initial_samples=5, max_exemplars=20,
               seed=1001, init_low=10, init_high=20):
    """Seed the history with random (w, b) pairs and render the first exemplars.

    Args:
      Linear_Regression: object whose ``evaluator(w, b)`` returns a scalar with
        ``.item()`` (presumably an ``OPRO4LinearRegressionData``).
      num_initial_samples: number of random starting pairs (default 5).
      max_exemplars: maximum number of distinct pairs rendered (default 20).
      seed: seed for the starting pairs; the default reproduces the original draws.
      init_low, init_high: w and b are each drawn uniformly from [init_low, init_high).
    """
    self.exemplars_template = """
    input:
    w={}, b={}
    value:
    {}
    """
    self.Linear_Regression = Linear_Regression
    self.max_exemplars = max_exemplars

    # Local RandomState: same draws as ``np.random.seed(seed)`` followed by two
    # ``np.random.uniform`` calls, without reseeding NumPy's global RNG.
    rng = np.random.RandomState(seed)
    w_samples = rng.uniform(init_low, init_high, num_initial_samples)
    b_samples = rng.uniform(init_low, init_high, num_initial_samples)
    self.samples = [
        (w, b, self.Linear_Regression.evaluator(w, b).item())
        for w, b in zip(w_samples, b_samples)
    ]

    self.input_exemplars = ""
    self.Add_Padding()

  def Add_Padding(self):
    """Sort the history (worst first) and rebuild ``input_exemplars`` from the best distinct pairs."""
    self.samples.sort(key=lambda s: s[2], reverse=True)

    # Exact duplicates (the LLM may re-propose a pair) are shown only once.
    # Tuples of numbers are immutable, so no copy is needed before hashing them.
    unique_samples = sorted(set(self.samples), key=lambda s: s[2], reverse=True)

    # Keep the ``max_exemplars`` lowest values, still ordered worst -> best.
    # Computing the start index explicitly makes a cap of 0 yield no exemplars
    # (a bare ``[-0:]`` slice would keep everything).
    start = max(len(unique_samples) - self.max_exemplars, 0)
    selected_samples = unique_samples[start:]

    self.input_exemplars = "".join(
        self.exemplars_template.format(w, b, value) for w, b, value in selected_samples
    )

  def Add_Sample(self, generated_w, generated_b):
    """Score (generated_w, generated_b) and append it to the history (not re-rendered until Add_Padding)."""
    self.samples.append((generated_w, generated_b, self.Linear_Regression.evaluator(generated_w, generated_b).item()))

  def Add_All(self, Meta_Prompt):
    """Fill the meta-prompt's ``{}`` slot with the current exemplars and append an output instruction."""
    # NOTE(review): the suffix asks for "the revised version of the following content",
    # but nothing follows it and the task is to propose a new (w, b) pair; confirm this
    # wording is intended, as it may push the model to echo the exemplars.
    return Meta_Prompt.format(self.input_exemplars) + "\nPlease return only the revised version of the following content without adding any additional information or explanation."

# Placeholder passed to RequestLLM as ``meta_prompt``. RequestLLM only stores it;
# the prompting methods use the template passed to them explicitly (prompt_data below).
Meta_Prompt_data = None

class RequestLLM:
  """Run the OPRO loop for the linear-regression task with an OpenAI chat model.

  Each step sends one temperature-0 request and ``epoch`` temperature-1 requests.
  All of them use the same meta-prompt, because the exemplars are only re-rendered
  at the end of the step. Every (w, b) candidate parsed from a reply is scored and
  added through ``self.Template``.

  Uses the legacy (pre-1.0) ``openai.ChatCompletion`` interface.
  """

  # A signed decimal number with optional fraction and exponent ("3", "-1.5", ".5", "2e1").
  _NUMBER = r"[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?"
  # "w=<num>" / "b=<num>", the format the exemplars use; spaces around "=" allowed.
  # The leading \b stops words such as "new=5" from being read as "w=5".
  _W_PATTERN = re.compile(r"\bw\s*=\s*(" + _NUMBER + r")")
  _B_PATTERN = re.compile(r"\bb\s*=\s*(" + _NUMBER + r")")
  # "[<num>, <num>]", the format the meta-prompt asks the answer to end with.
  _PAIR_PATTERN = re.compile(r"\[\s*(" + _NUMBER + r")\s*,\s*(" + _NUMBER + r")\s*\]")

  def __init__(self, meta_prompt, New_Linear_Regression, model='gpt-3.5-turbo', steps=100, epoch=4):
    import os

    # Prefer the standard OPENAI_API_KEY environment variable. The placeholder is
    # only a fallback; assigning it unconditionally would override a valid key.
    self.API_key = os.environ.get("OPENAI_API_KEY", '[API_KEY]')
    # OpenAI model IDs are lowercase ("gpt-3.5-turbo").
    self.model = model
    # Stored for reference only; the prompting methods take the template as an argument.
    self.meta_prompt = meta_prompt
    self.steps = steps
    self.epoch = epoch
    self.Template = Template_Padding(New_Linear_Regression)

  @classmethod
  def _extract_w_b(cls, content):
    """Return the last (w, b) pair proposed in ``content`` as floats, or None.

    Both "w=<num>, b=<num>" and "[<num>, <num>]" are recognised. Whichever form
    ends later in the text wins, since the model's final answer comes last.
    """
    candidates = []

    w_matches = list(cls._W_PATTERN.finditer(content))
    b_matches = list(cls._B_PATTERN.finditer(content))
    if w_matches and b_matches:
      w_match, b_match = w_matches[-1], b_matches[-1]
      candidates.append((max(w_match.end(), b_match.end()),
                         float(w_match.group(1)), float(b_match.group(1))))

    pair_matches = list(cls._PAIR_PATTERN.finditer(content))
    if pair_matches:
      pair = pair_matches[-1]
      candidates.append((pair.end(), float(pair.group(1)), float(pair.group(2))))

    if not candidates:
      return None
    _, w, b = max(candidates)
    return w, b

  def _request_and_record(self, Meta_Prompt, temperature):
    """Send one chat request at ``temperature``; score and store the (w, b) it proposes, if any."""
    # Set here so that every public prompting method is authenticated, not just Prompt_LLM_Step.
    openai.api_key = self.API_key
    response = openai.ChatCompletion.create(
        model=self.model,
        messages=[{"role": "user", "content": self.Template.Add_All(Meta_Prompt)}],
        temperature=temperature,
    )

    content = response.choices[0].message['content'].strip()
    print(content)

    pair = self._extract_w_b(content)
    if pair is not None:
      self.Template.Add_Sample(*pair)

  def Prompt_LLM_Epoch(self, Meta_Prompt):
    """Print the current meta-prompt, then send ``self.epoch`` temperature-1 requests."""
    print(self.Template.Add_All(Meta_Prompt))
    for _ in tqdm(range(self.epoch), desc="Prompting LLM " + str(self.epoch) + " epochs: " + self.model):
      self._request_and_record(Meta_Prompt, temperature=1)

  def Prompt_LLM_Epoch_Temp_0(self, Meta_Prompt):
    """Send a single deterministic (temperature-0) request."""
    self._request_and_record(Meta_Prompt, temperature=0)

  def Prompt_LLM_Step(self, Meta_Prompt):
    """Run ``self.steps`` OPRO steps; print and return the best value seen after each step."""
    result_lst = []
    for i in tqdm(range(self.steps), desc="Prompting LLM " + str(self.steps) + " steps: " + self.model):
      self.Prompt_LLM_Epoch_Temp_0(Meta_Prompt)
      self.Prompt_LLM_Epoch(Meta_Prompt)
      # Add_Padding sorts samples by value descending, so the best pair is last.
      self.Template.Add_Padding()
      best_w, best_b, best_value = self.Template.samples[-1]
      print("Step %d: The Best Performance, w=%f, b=%f, value=%f" % (i + 1, best_w, best_b, best_value))
      result_lst.append(best_value)
    print(result_lst)
    return result_lst


# OPRO meta-prompt: the "{}" slot receives the rendered (w, b) exemplars.
# An earlier, shorter variant omitted the sentence explaining that, although the
# pairs are listed from highest to lowest value, the smallest value is the best one.
prompt_data = """Now you will help me minimize a function with two input variables, w and b. I have some (w, b) pairs and the function values at those points. The pairs are arranged in descending order based on their function values, where smaller function values indicate better results. Therefore, although the pairs are listed from the highest function value to the lowest, the pair with the smallest function value is considered the most optimal.

{}

Give me a new (w, b) pair that is different from all pairs above, and has a function value lower than any of the above. Do not write code. The output must end with a pair [w, b], where w and b are numerical values.
"""

# 50 optimisation steps, each with one greedy request plus 7 sampled requests.
new_RequestLLM = RequestLLM(
    Meta_Prompt_data,
    New_Linear_Regression,
    model='gpt-3.5-turbo',
    steps=50,
    epoch=7,
)
new_RequestLLM.Prompt_LLM_Step(prompt_data)