import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
from tqdm.auto import tqdm
import time
import copy
import ast
import random

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import openai
import asyncio
from aiohttp import ClientSession
from tqdm.asyncio import tqdm as async_tqdm
import nest_asyncio
nest_asyncio.apply()


# Candidate meta-prompt table. The 'Z1-Z2-Embedding' column is stored in the
# CSV as the text of a Python list, so each cell is parsed back into a list.
tqdm.pandas()
Meta_Prompt_data = pd.read_csv("PATH/DATA_FILE.csv")
Meta_Prompt_data['Z1-Z2-Embedding'] = Meta_Prompt_data['Z1-Z2-Embedding'].progress_apply(
  lambda cell: ast.literal_eval(cell)
)
Meta_Prompt_data.head()

class OPRO4LinearRegressionData:
  """Synthetic 1-D linear-regression task used as the optimisation objective.

  Generates ``n_points`` evenly spaced ``x`` values in [-1, 1] with targets
  ``y = w * x + b + noise_scale * N(0, 1)``. Candidate ``(w, b)`` pairs are
  scored by their mean squared error on this fixed dataset.

  Args:
    w: True slope used to generate the targets.
    b: True intercept used to generate the targets.
    n_points: Number of samples (default 50).
    seed: Seed for the noise draws (default 3407).
    noise_scale: Multiplier on the standard-normal noise (default 1.0;
      e.g. 0.1 gives a low-noise variant).
  """

  def __init__(self, w, b, n_points=50, seed=3407, noise_scale=1.0):
    # A private RandomState yields exactly the draws that
    # ``np.random.seed(seed); np.random.randn(...)`` would, without resetting
    # NumPy's global RNG for the rest of the program.
    rng = np.random.RandomState(seed)
    x = np.linspace(-1, 1, n_points).reshape(-1, 1)
    noise = rng.randn(n_points, 1) * noise_scale
    y = x * w + b + noise
    data = np.hstack((x, y))

    self.true_parameter = torch.tensor([w, b])
    self.whole_data = torch.tensor(data)
    # Design matrix with columns [x, 1], so data @ [w, b] == w * x + b.
    self.data = torch.stack([self.whole_data[:, 0], torch.ones_like(self.whole_data[:, 0])], dim=1)
    self.label = self.whole_data[:, 1]

  def evaluator(self, generated_w, generated_b):
    """Return the mean squared error of the line (generated_w, generated_b) as a 0-d float64 tensor."""
    inference_parameter = torch.tensor([generated_w, generated_b], dtype=torch.float64)
    return torch.mean((torch.matmul(self.data, inference_parameter) - self.label)**2)

  def mat_for_data(self):
    """Scatter-plot the generated (x, y) samples."""
    plt.scatter(self.whole_data[:, 0], self.whole_data[:, 1], color='blue', label='x1 vs y')
    plt.xlabel('x1')
    plt.ylabel('y')
    plt.legend()
    plt.title('Random 2D Linear Regression Data')
    plt.show()

# Ground-truth task for every run below: slope 36, intercept -1.
New_Linear_Regression = OPRO4LinearRegressionData(w=36, b=-1)
New_Linear_Regression.mat_for_data()

class Template_Padding:
  """History of scored (w, b) pairs, rendered as prompt exemplars.

  The history is seeded with ``n_initial`` random pairs drawn uniformly from
  ``[low, high)`` for both w and b, each scored with
  ``Linear_Regression.evaluator``. ``Add_Padding`` renders up to
  ``max_exemplars`` distinct lowest-value pairs into ``self.input_exemplars``.
  ``Add_All`` substitutes those exemplars into a meta-prompt.

  Args:
    Linear_Regression: Object exposing ``evaluator(w, b)`` that returns a
      tensor-like value with ``.item()``.
    n_initial: Number of random seed pairs (default 5).
    low, high: Sampling range for the seed pairs (default [10, 20)).
    seed: Seed for the seed-pair draws (default 1001).
    max_exemplars: Maximum number of pairs rendered (default 20).
  """

  def __init__(self, Linear_Regression, n_initial=5, low=10, high=20, seed=1001, max_exemplars=20):
    # The leading whitespace inside this template is part of the prompt text.
    self.exemplars_template = """
    input:
    w={}, b={}
    value:
    {}
    """
    self.Linear_Regression = Linear_Regression
    self.max_exemplars = max_exemplars

    # A private generator gives the same draws as np.random.seed(seed) followed by
    # two uniform() calls, but leaves NumPy's global RNG untouched.
    rng = np.random.RandomState(seed)
    w_samples = rng.uniform(low, high, n_initial)
    b_samples = rng.uniform(low, high, n_initial)
    self.samples = [
      (w, b, self.Linear_Regression.evaluator(w, b).item())
      for w, b in zip(w_samples, b_samples)
    ]

    self.input_exemplars = ""
    # Add_Padding sorts self.samples itself, so no separate sort is needed here.
    self.Add_Padding()

  def Add_Padding(self):
    """Sort the history worst-to-best and rebuild ``self.input_exemplars``.

    Identical (w, b, value) entries are rendered once, and only the
    ``max_exemplars`` lowest-value pairs are kept. They are ordered from
    highest to lowest value, so the best pair is rendered last.
    """
    # Keep the full history sorted by value, descending; samples[-1] is the best pair.
    self.samples.sort(key=lambda s: s[2], reverse=True)

    # The entries are tuples of immutable numbers, so a set can dedupe them without a deep copy.
    unique = sorted(set(self.samples), key=lambda s: s[2], reverse=True)
    # Explicit start index: a plain [-k:] slice would return everything when k == 0.
    selected = unique[max(len(unique) - self.max_exemplars, 0):]

    self.input_exemplars = "".join(
      self.exemplars_template.format(w, b, value) for w, b, value in selected
    )

  def Add_Sample(self, generated_w, generated_b):
    """Score (generated_w, generated_b) and append it to the history (unsorted until Add_Padding)."""
    self.samples.append((generated_w, generated_b, self.Linear_Regression.evaluator(generated_w, generated_b).item()))

  def Add_All(self, Meta_Prompt):
    """Insert the rendered exemplars into ``Meta_Prompt`` via ``str.format`` and append the reply instruction."""
    return Meta_Prompt.format(self.input_exemplars) + "\nPlease return only the revised version of the following content without adding any additional information or explanation."


class RequestLLM:
  """Queries a chat-completion model for new (w, b) proposals and records them.

  Each parsed proposal is scored and appended to ``self.Template`` (a
  ``Template_Padding``), whose exemplars are substituted into every prompt.

  Args:
    New_Linear_Regression: Task object passed to ``Template_Padding``.
    model: Chat model name sent in the request payload.
    steps: Number of temperature-1 rounds run by ``Prompt_LLM_Step``.
    epoch: Number of concurrent temperature-1 requests per round.
  """

  # Accepts "w=<num>, b=<num>" (comma optional) or "[<num>, <num>]".
  # findall groups 0/2 belong to the first form, groups 4/6 to the second.
  _PAIR_PATTERN = re.compile(
    r"w=([+-]?\d+(\.\d+)?),?\s*b=([+-]?\d+(\.\d+)?)|\[\s*([+-]?\d+(\.\d+)?),\s*([+-]?\d+(\.\d+)?)\s*\]"
  )
  # Pause in seconds before re-requesting after an unusable API response.
  # Without it, a persistent error (e.g. a bad key) turns into a tight request loop.
  RETRY_DELAY = 1.0

  def __init__(self, New_Linear_Regression, model='gpt-3.5-turbo', steps=100, epoch=7):
    # NOTE(review): placeholder credential; supply a real key before running.
    self.API_key = '[API_KEY]'
    self.model = model
    self.steps = steps
    self.epoch = epoch
    self.Template = Template_Padding(New_Linear_Regression)
    self.latest_batch_score = 0  # mean value of the newest samples (see __Get_Batch_Score__)
    self.stable_score = 0  # value of the latest temperature-0 proposal

  async def fetch_response(self, session, Meta_Prompt, temperature):
    """POST one chat-completion request and return the decoded JSON body, or None.

    HTTP 429 is retried with exponential backoff (1, 2, 4, ... s). Transport
    or decoding exceptions are retried after 5 s, for up to 5 attempts in
    total. Any other non-200 status returns None immediately.
    """
    prompt = self.Template.Add_All(Meta_Prompt)
    if temperature == 0:
      print(prompt)
    url = "https://api.openai.com/v1/chat/completions"
    payload = {
      "model": self.model,
      "messages": [{"role": "user", "content": prompt}],
      "temperature": temperature,
    }
    headers = {
      "Authorization": f"Bearer {self.API_key}",
      "Content-Type": "application/json",
    }

    for retry in range(5):
      try:
        async with session.post(url, json=payload, headers=headers) as response:
          if response.status == 200:
            return await response.json()
          if response.status != 429:
            print(f"Error: {response.status}, {await response.text()}")
            return None
        # Rate limited: back off after the response has been released.
        wait_time = 2 ** retry
        print(f"Rate limit exceeded. Retrying in {wait_time} seconds...")
        await asyncio.sleep(wait_time)
      except Exception as e:  # network boundary: log and retry
        print(f"Request error: {e}")
        await asyncio.sleep(5)
    print("Max retries exceeded.")
    return None

  @staticmethod
  def _extract_content(response):
    """Return ``choices[0].message.content`` from a chat response, or None if it is missing."""
    if not response:
      return None
    choices = response.get("choices")
    if not choices or "message" not in choices[0]:
      return None
    return choices[0]["message"].get("content")

  @classmethod
  def _parse_pair(cls, content):
    """Return (w, b) as floats from the last pair-like match in ``content``, or None."""
    matches = cls._PAIR_PATTERN.findall(content)
    if not matches:
      return None
    last = matches[-1]
    if last[0] and last[2]:
      return float(last[0]), float(last[2])
    return float(last[4]), float(last[6])

  async def process_single_epoch(self, Meta_Prompt, session, temperature=1, max_attempts=None):
    """Request until a reply contains a parsable (w, b) pair, then return it as floats.

    Args:
      max_attempts: Optional cap on the number of requests. ``None`` (the
        default) retries indefinitely. When the cap is reached, returns None.
    """
    attempt = 0
    while max_attempts is None or attempt < max_attempts:
      attempt += 1
      response = await self.fetch_response(session, Meta_Prompt, temperature)
      content = self._extract_content(response)
      if content is None:
        print("Invalid or empty response from API. Retrying...")
        await asyncio.sleep(self.RETRY_DELAY)
        continue

      content = content.strip()
      print(f"Generated content (Temp={temperature}): {content}")

      pair = self._parse_pair(content)
      if pair is not None:
        return pair

      print("No valid match found in response. Retrying...")
    return None

  async def Prompt_LLM_Epoch(self, Meta_Prompt, temperature=1):
    """Run ``self.epoch`` concurrent proposals, record them, and update ``latest_batch_score``."""
    async with ClientSession() as session:
      tasks = [self.process_single_epoch(Meta_Prompt, session, temperature) for _ in range(self.epoch)]
      results = await asyncio.gather(*tasks)

    request_success = 0
    for result in results:
      if result is not None:
        w, b = result
        self.Template.Add_Sample(w, b)
        request_success += 1
    print(f"Successfully processed {request_success} requests.")
    # The +1 also averages in the entry that was last in self.samples before this
    # round. On the first round that is the temperature-0 proposal.
    self.__Get_Batch_Score__(request_success + 1)

  async def Prompt_LLM_Zero_Temp(self, Meta_Prompt):
    """Obtain one temperature-0 proposal, store its value in ``stable_score``, and record it."""
    async with ClientSession() as session:
      while True:
        response = await self.process_single_epoch(Meta_Prompt, session, temperature=0)
        if response is not None:
          break
        print("Invalid response, retrying...")
    w, b = response
    # Stored as a plain float so downstream code can build tensors/lists from it directly.
    self.stable_score = self.Template.Linear_Regression.evaluator(w, b).item()
    self.Template.Add_Sample(w, b)
    print(f"The stable score is: {self.stable_score}")

  async def Prompt_LLM_Step(self, Meta_Prompt):
    """Run one temperature-0 request, then ``self.steps`` temperature-1 rounds.

    The exemplars are re-rendered after each round.
    """
    await self.Prompt_LLM_Zero_Temp(Meta_Prompt)
    for i in async_tqdm(range(self.steps), desc="Prompting LLM Steps"):
      await self.Prompt_LLM_Epoch(Meta_Prompt, temperature=1)
      self.Template.Add_Padding()
      print(f"Step {i + 1}: The latest average score of Prompts is {self.latest_batch_score:.6f}")
      # Add_Padding sorts samples by value, descending, so the last entry is the best.
      best_w, best_b, best_value = self.Template.samples[-1]
      print(f"Step {i + 1}: Best w={best_w:.6f}, b={best_b:.6f}, value={best_value:.6f}")

  def __Get_Batch_Score__(self, n):
    """Set ``latest_batch_score`` to the mean value of the last ``n`` entries in the history."""
    last_n_values = [sample[2] for sample in self.Template.samples[-n:]]
    print(f"The number of the new: {n}")
    self.latest_batch_score = sum(last_n_values) / len(last_n_values)

class Network(nn.Module):
  """Fully connected ReLU regressor that maps ``input_dim`` features to one scalar.

  Layers: Linear(input_dim, hidden_size), then ``depth - 1`` extra
  Linear(hidden_size, hidden_size) layers (none when depth <= 1), then
  Linear(hidden_size, 1). Every layer except the last is followed by ReLU,
  and the trailing size-1 dimension is squeezed from the output.

  Args:
    input_dim: Number of input features.
    hidden_size: Width of every hidden layer.
    depth: Number of hidden-width Linear layers before the output layer.
    init_params: Optional flat sequence [W0, b0, W1, b1, ...] assigned as the
      layers' weight/bias data. When omitted, every weight and bias is drawn
      from N(0, 1).
  """

  def __init__(self, input_dim, hidden_size=100, depth=1, init_params=None):
    super().__init__()

    self.activate = nn.ReLU()
    widths = [input_dim] + [hidden_size] * max(depth, 1) + [1]
    self.layer_list = nn.ModuleList(
      nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
    )

    if init_params is not None:
      for idx, layer in enumerate(self.layer_list):
        layer.weight.data = init_params[2 * idx]
        layer.bias.data = init_params[2 * idx + 1]
      return

    for layer in self.layer_list:
      nn.init.normal_(layer.weight, mean=0, std=1.0)
      nn.init.normal_(layer.bias, mean=0, std=1.0)

  def forward(self, x):
    """Apply the hidden layers with ReLU, then the linear head, and squeeze the last dim."""
    *hidden_layers, head = self.layer_list
    for layer in hidden_layers:
      x = self.activate(layer(x))
    return head(x).squeeze(-1)

class PromptEmbedding:
  """Embeds prompts with the OpenAI embeddings API and scores meta-prompts via RequestLLM.

  Args:
    z1_z2_prompt_data: DataFrame of candidate meta-prompts. This class reads
      its 'Merged_Content' and 'Z1-Z2-Embedding' columns. It writes
      'Completed_Prompt' and an embedding column (default 'PromptEmbedding').
    New_Linear_Regression: Task object forwarded to RequestLLM.
    Optimizer_Model: Chat model that RequestLLM uses to propose (w, b) pairs.
    Embedding_model: Embedding model name used by Get_Single_Embedding.
  """

  def __init__(self, z1_z2_prompt_data, New_Linear_Regression, Optimizer_Model='gpt-3.5-turbo', Embedding_model='text-embedding-3-large'):
    # NOTE(review): placeholder credential; supply a real key before running.
    self.API_key = '[API_KEY]'
    self.z1_z2_prompt = z1_z2_prompt_data
    self.Embedding_model = Embedding_model
    # Honour the caller's optimizer model (e.g. 'gpt-4-turbo') rather than a hard-coded one.
    self.RequestLLM = RequestLLM(New_Linear_Regression, model=Optimizer_Model, steps=1, epoch=7)

  def normalize_l2(self, x):
    """Return ``x`` as a NumPy array scaled to unit L2 norm (unchanged if its norm is 0)."""
    x = np.array(x)
    norm = np.linalg.norm(x)
    if norm == 0:
      return x
    return x / norm

  def Get_Single_Embedding(self, text, target_dim=3072):
    """Embed ``text`` and return its first ``target_dim`` components, renormalised to unit length.

    Newlines are replaced by spaces before the request. Retries indefinitely:
    10 s after a rate-limit error, 5 s after any other error, plus a random
    0.5-1.5 s pause before every attempt.

    NOTE(review): ``openai.Embedding`` / ``openai.error`` belong to the
    pre-1.0 openai-python interface; confirm the installed client version.
    """
    openai.api_key = self.API_key
    text = text.replace("\n", " ")

    while True:
      try:
        time.sleep(random.uniform(0.5, 1.5))

        response = openai.Embedding.create(
          input=[text],
          model=self.Embedding_model
        )
        embedding = response['data'][0]['embedding']

        truncated_embedding = embedding[:target_dim]

        return list(self.normalize_l2(truncated_embedding))

      except openai.error.RateLimitError:
        print("Rate limit exceeded. Waiting before retrying...")
        time.sleep(10)

      except Exception as e:
        print(f"An error occurred: {e}")
        print("Retrying...")
        time.sleep(5)

  def PaddingforPrompt(self):
    """Fill 'Completed_Prompt' with each row's 'Merged_Content' after exemplar substitution."""
    tqdm.pandas()
    template = self.RequestLLM.Template
    self.z1_z2_prompt['Completed_Prompt'] = self.z1_z2_prompt['Merged_Content'].progress_apply(lambda x: template.Add_All(x))

  def Get_Batch_Embedding(self, exemplars_embedding=None, column='PromptEmbedding'):
    """Refresh 'Completed_Prompt' and copy each row's 'Z1-Z2-Embedding' into ``column``.

    ``exemplars_embedding`` is currently unused. Appending it to every row's
    embedding is disabled, so rows keep only their Z1-Z2 embedding.
    """
    openai.api_key = self.API_key
    self.PaddingforPrompt()
    tqdm.pandas()
    self.z1_z2_prompt[column] = self.z1_z2_prompt['Z1-Z2-Embedding'].progress_apply(
      lambda x: x
    )

  def forward(self, current_point, current_z1, current_z2, z1_z2_embedding):
    """Score the meta-prompt ``current_point`` with the LLM.

    Returns ``[z1_z2_embedding, RequestLLM.stable_score]``. ``current_z1`` and
    ``current_z2`` are accepted for interface compatibility but not used.
    """
    self.Get_Batch_Embedding()
    asyncio.run(self.RequestLLM.Prompt_LLM_Step(current_point))
    current_point_score = self.RequestLLM.stable_score
    # Exemplar embeddings are not concatenated; the prompt's own embedding is used as-is.
    concatenated_embedding = z1_z2_embedding
    return [concatenated_embedding, current_point_score]

class AdversarialBandit:
  """Exponential-weights (EXP3-style) search over candidate meta-prompts.

  A neural score model guides the search. On every optimisation step:
    1. The current meta-prompt is scored with the LLM (``PromptEmbedding.forward``).
    2. ``self.model`` is refit on all (embedding, score) observations so far.
    3. The next prompt is sampled with probabilities proportional to
       ``exp(eta * cumulative reward)``.
  Each step's reward for a prompt is ``(c - predicted score) / c``, so a
  lower predicted score gives a higher reward.

  Args:
    z1_z2_prompt_data: DataFrame of candidate prompts. This class reads
      'PromptEmbedding', 'Merged_Content', 'Z1', 'Z2' and 'Z1-Z2-Embedding',
      and writes 'Score' (and 'UCB_Score' in NeuralUCBSelect).
    New_Linear_Regression: Task object forwarded to PromptEmbedding.
    input_dim: Width of the prompt embeddings fed to the score model.
    steps: Number of optimisation steps.
    NumberofK: Stored; not read elsewhere in this class.
  """

  def __init__(self, z1_z2_prompt_data, New_Linear_Regression, input_dim=6144, steps=50, NumberofK=10201):
    self.z1_z2_prompt_data = z1_z2_prompt_data
    self.input_dim = input_dim
    self.steps = steps
    self.NumberofK = NumberofK
    self.device = torch.device("cuda")

    # Double-precision score model sized from input_dim (6144 by default).
    self.model = Network(self.input_dim, 1536, 1).to(self.device).double()

    self.lamdba = 0.1  # NeuralUCB regulariser (attribute name kept for compatibility)
    self.nu = 1        # NeuralUCB exploration weight

    tkwargs = {
      "dtype": torch.float32,
      "device": self.device,
    }
    # NOTE(review): self.func is never used in this class. It is kept so the
    # attribute and the torch RNG draws made by its initialisation are unchanged.
    self.func = Network(self.input_dim).to(**tkwargs)
    self.total_param = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    self.best_result_list = []

    self.lr = 1e-3
    self.epochs = 5000
    self.optimizer_fn = torch.optim.Adam
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1)
    self.loss_fn = nn.MSELoss()

    # Snapshot restored by restart_model, so every refit starts from the same weights.
    self.init_model_weight = copy.deepcopy(self.model.state_dict())

    self.eta = 100                   # exponential-weights learning rate
    self.prompt_score_trace = []     # one list of per-prompt rewards per step
    self.embedding_score_trace = []  # [embedding, observed score] training pairs

    self.current_point = """Now you will help me minimize a function with two input variables, w and b. I have some (w, b) pairs and the function values at those points. The pairs are arranged in descending order based on their function values, where smaller function values indicate better results. Therefore, although the pairs are listed from the highest function value to the lowest, the pair with the smallest function value is considered the most optimal.

{}

Give me a new (w, b) pair that is different from all pairs above, and has a function value lower than any of the above. Do not write code. The output must end with a pair [w, b], where w and b are numerical values.
    """
    self.current_z1 = """Now you will help me minimize a function with two input variables, w and b. I have some (w, b) pairs and the function values at those points. The pairs are arranged in descending order based on their function values, where smaller function values indicate better results. Therefore, although the pairs are listed from the highest function value to the lowest, the pair with the smallest function value is considered the most optimal."""
    self.current_z2 = """Give me a new (w, b) pair that is different from all pairs above, and has a function value lower than any of the above. Do not write code. The output must end with a pair [w, b], where w and b are numerical values."""

    self.prompt_embedding = PromptEmbedding(z1_z2_prompt_data, New_Linear_Regression, Optimizer_Model='gpt-3.5-turbo', Embedding_model='text-embedding-3-large')

    self.z1_z2_embedding = self.prompt_embedding.Get_Single_Embedding(self.current_z1) + self.prompt_embedding.Get_Single_Embedding(self.current_z2)

    self.c = None  # reward shift constant, fixed at the first select()

    self.scheduler = None

    self.Stop_EXP3 = 0
    self.index_history = []

  def _as_model_input(self, x):
    """Convert one embedding to a tensor on the score model's device, in its dtype."""
    reference = next(self.model.parameters())
    return torch.tensor(x, dtype=reference.dtype, device=reference.device)

  def _predict(self, x):
    """Score-model prediction for one embedding as a Python float, without building a graph."""
    with torch.no_grad():
      return self.model(self._as_model_input(x)).item()

  def _set_current_prompt(self, index):
    """Make row ``index`` of the prompt table the current meta-prompt (text, Z1, Z2, embedding)."""
    row = self.z1_z2_prompt_data.iloc[index]
    self.current_point = row['Merged_Content']
    self.current_z1 = row['Z1']
    self.current_z2 = row['Z2']
    self.z1_z2_embedding = row['Z1-Z2-Embedding']

  def restart_model(self):
    """Reset the score model to its initial weights and build a fresh optimiser and scheduler.

    Weight decay is 1 / (number of observations), so regularisation weakens
    as data accumulates. Requires at least one observation.
    """
    self.model.load_state_dict(copy.deepcopy(self.init_model_weight))
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1 / len(self.embedding_score_trace))
    # Learning rate drops 10x after epoch 3000.
    self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 0.1 if epoch > 3000 else 1.0)

  def train_model(self, batch_size=64):
    """Refit the score model from scratch on every (embedding, score) pair collected so far."""
    self.restart_model()
    self.model.to(self.device)
    dtype = next(self.model.parameters()).dtype

    embeddings_tensor = torch.stack([torch.tensor(embedding, dtype=dtype) for embedding, _ in self.embedding_score_trace])
    scores_tensor = torch.tensor([score for _, score in self.embedding_score_trace], dtype=dtype)

    dataset = TensorDataset(embeddings_tensor, scores_tensor)
    # Pinned host memory only helps (and is only supported) when copying to CUDA.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=self.device.type == "cuda")

    self.model.train()
    for epoch in tqdm(range(self.epochs), desc="Training Model"):
      epoch_loss = 0
      for batch_embeddings, batch_scores in dataloader:
        batch_embeddings = batch_embeddings.to(self.device)
        batch_scores = batch_scores.to(self.device)

        self.optimizer.zero_grad()

        predictions = self.model(batch_embeddings)
        loss = self.loss_fn(predictions, batch_scores)

        loss.backward()
        self.optimizer.step()

        epoch_loss += loss.item()

      self.scheduler.step()

      if epoch % 100 == 0:
        print(f"Epoch {epoch}/{self.epochs}, Loss: {epoch_loss / len(dataloader)}")

    print("Training complete.")

  def calculate_gradient(self, x):
    """Return the flattened gradient of sum(model(x)) w.r.t. every model parameter, on CPU.

    The input is cast to the model's own dtype. The model is double precision,
    and a float32 input would fail with a dtype mismatch.
    """
    output = self.model(self._as_model_input(x))
    self.model.zero_grad()
    torch.sum(output).backward()

    gradients = torch.cat([p.grad.flatten() for p in self.model.parameters()])
    return gradients.detach().cpu()

  def NeuralUCBSelect(self):
    """Pick the prompt with the highest NeuralUCB score and update the diagonal design matrix."""
    tqdm.pandas()
    if not hasattr(self, 'U'):
      # Diagonal approximation of the design matrix, initialised to lambda * I.
      dtype = next(self.model.parameters()).dtype
      self.U = self.lamdba * torch.ones(self.total_param, dtype=dtype, device=self.device)

    def calculate_ucb(row):
      embedding = row['PromptEmbedding']
      score = -self._predict(embedding)
      gradient = self.calculate_gradient(embedding).to(self.device)
      sigma = torch.sqrt(torch.sum(self.lamdba * gradient ** 2 / self.U))
      return score + self.nu * sigma.item()

    self.z1_z2_prompt_data['UCB_Score'] = self.z1_z2_prompt_data.progress_apply(calculate_ucb, axis=1)

    best_index = torch.argmax(torch.tensor(self.z1_z2_prompt_data['UCB_Score'].tolist())).item()
    self._set_current_prompt(best_index)

    print(f"Selected prompt index: {best_index}, with UCB score: {self.z1_z2_prompt_data.iloc[best_index]['UCB_Score']}")

    selected_gradient = self.calculate_gradient(self.z1_z2_prompt_data.iloc[best_index]['PromptEmbedding']).to(self.device)
    self.U += selected_gradient ** 2

  def select(self, iteration):
    """Sample the next meta-prompt with exponential weights over cumulative rewards.

    Args:
      iteration: Optimisation step index. At step 0 the shift constant
        ``self.c`` is fixed to the largest predicted score over all prompts.

    Sets ``Stop_EXP3`` and leaves the current prompt unchanged if the
    sampling weights contain NaN or inf.
    """
    embeddings = self.prompt_embedding.z1_z2_prompt['PromptEmbedding']
    if iteration == 0:
      self.c = embeddings.apply(self._predict).max()

    print(f"Shift value (self.c): {self.c}")

    # Reward = (c - prediction) / c: lower predicted scores earn higher rewards.
    self.z1_z2_prompt_data['Score'] = embeddings.apply(
        lambda x: (-self._predict(x) + self.c) / self.c
    )

    print(self.z1_z2_prompt_data['Score'].tolist())
    self.prompt_score_trace.append(self.z1_z2_prompt_data['Score'].tolist())
    score_tensor = torch.tensor(self.prompt_score_trace, dtype=torch.float64)
    cumulative_scores = torch.sum(score_tensor, dim=0)

    # Softmax of eta * cumulative reward, shifted by its max for numerical stability.
    log_weights = self.eta * cumulative_scores
    exp_scores = torch.exp(log_weights - torch.max(log_weights))
    normalized_exp_scores = exp_scores / torch.sum(exp_scores)

    # topk requires k <= number of prompts.
    top_scores, top_indices = torch.topk(normalized_exp_scores, min(5, normalized_exp_scores.numel()))
    print("Top 5 scores: ", top_scores)
    print("Indices of top 5 scores: ", top_indices)

    if torch.isnan(normalized_exp_scores).any() or torch.isinf(normalized_exp_scores).any():
      # Degenerate weights (e.g. self.c == 0): stop selecting and keep the current prompt.
      self.Stop_EXP3 = 1
      return

    self.best_index = torch.multinomial(normalized_exp_scores, 1).item()

    print("The index of selected prompt is: %d"%(self.best_index))
    print("The selected prompt's cumulative score: ", cumulative_scores[self.best_index])
    self._set_current_prompt(self.best_index)

  def optimization(self):
    """Run ``self.steps`` rounds of score -> refit -> select.

    After each round, the best objective value seen so far is appended to
    ``best_result_list``.
    """
    for i in tqdm(range(self.steps), desc="Optimizing the Meta-Prompt " + str(self.steps) + " step"):
      forward_result = self.prompt_embedding.forward(self.current_point, self.current_z1, self.current_z2, self.z1_z2_embedding)
      self.embedding_score_trace.append(forward_result)
      if not self.Stop_EXP3:
        self.train_model()
        self.select(i)

      # samples is sorted by value, descending, so the last entry holds the best value.
      self.best_result_list.append(self.prompt_embedding.RequestLLM.Template.samples[-1][2])
      print(self.best_result_list)

result_history = []

# Five independent repetitions. Each run works on its own copy of the prompt
# table because the bandit writes score columns into it.
for run_index in range(5):
  prompt_data = copy.deepcopy(Meta_Prompt_data)
  new_optimization = AdversarialBandit(prompt_data, New_Linear_Regression)
  new_optimization.optimization()
  result_history.append(new_optimization.best_result_list)
  print(result_history)