import numpy as np
import pandas as pd
import re
from tqdm.auto import tqdm
import time
import copy
import ast
import random
from scipy.stats import bernoulli

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import openai

# Load the table of candidate meta-prompts ("PATH/DATA_FILE.csv" is a placeholder path).
Meta_Prompt_data = pd.read_csv("PATH/DATA_FILE.csv")
# Register `progress_apply` on pandas objects so the parsing below shows a progress bar.
tqdm.pandas()
# Parse every 'Z1-Z2-Embedding' cell with ast.literal_eval, turning the text read
# from the CSV back into a Python object (presumably a list of floats -- verify
# against the data file).
Meta_Prompt_data['Z1-Z2-Embedding'] = Meta_Prompt_data['Z1-Z2-Embedding'].progress_apply(ast.literal_eval)
# Preview of the first rows; the returned frame is discarded here, so this only
# displays something when run as the last expression of a notebook cell.
Meta_Prompt_data.head()

def set_random_seed(seed=0):
  """Seed the Python, NumPy and PyTorch random generators with ``seed``.

  When CUDA is available the current CUDA device generator is seeded as well.
  cuDNN is switched to deterministic mode and its benchmark auto-tuner is
  disabled so that repeated runs pick the same kernels.

  Args:
    seed: Integer seed shared by all generators (default ``0``).
  """
  for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

  cudnn = torch.backends.cudnn
  cudnn.deterministic = True
  cudnn.benchmark = False

class Multi_Armed_Bandit:
  """Four-button Bernoulli bandit whose first button ("blue") is the best arm.

  The best arm has mean ``0.5 + delta/2``; each of the other three arms has a
  mean that is ``delta`` lower.  Every pull made through :meth:`sample` is
  logged in ``sample_list`` as a ``(color, reward)`` tuple.
  """

  def __init__(self, delta):
    """Create the arms for a gap of ``delta`` between best and other means."""
    self.delta = delta
    self.optima = 0.5 + self.delta/2
    self.suboptima = self.optima - self.delta
    self.button_color = ["blue", "green", "red", "yellow"]
    self.bandit = self.bernoulli()
    self.sample_list = []

  def bernoulli(self):
    """Return a dict mapping each color to its frozen scipy Bernoulli distribution."""
    arms = {}
    for position, color in enumerate(self.button_color):
      # Only the arm at position 0 receives the higher mean.
      mean = self.suboptima if position else self.optima
      arms[color] = bernoulli(mean)
    return arms

  def sample(self, color):
    """Pull the ``color`` arm, record the outcome and return the drawn reward."""
    reward = self.bandit[color].rvs()
    self.sample_list.append((color, reward))
    return reward

  def count_and_mean(self):
    """Summarize the recorded pulls per color.

    Returns:
      dict mapping color -> ``[pull_count, mean_reward]``; colors that were
      never pulled map to ``[0, 0]``.
    """
    pulls = {color: 0 for color in self.button_color}
    totals = {color: 0 for color in self.button_color}
    for color, reward in self.sample_list:
      pulls[color] += 1
      totals[color] += reward

    summary = {}
    for color in self.button_color:
      count = pulls[color]
      mean = totals[color] / count if count > 0 else totals[color]
      summary[color] = [count, mean]
    return summary

class Network(nn.Module):
  """Fully connected ReLU network that outputs one scalar per input row.

  The network has one input layer, ``depth - 1`` additional hidden layers of
  width ``hidden_size`` and a final linear layer with a single output unit.
  """

  def __init__(self, input_dim, hidden_size=100, depth=1, init_params=None):
    """Build the layers and initialize their parameters.

    Args:
      input_dim: Size of the last dimension of the input.
      hidden_size: Width of every hidden layer.
      depth: Number of hidden-width layers before the output layer.
      init_params: Optional flat sequence ``[w0, b0, w1, b1, ...]`` of tensors
        assigned to the layers; when ``None`` all weights and biases are drawn
        from a standard normal distribution.
    """
    super().__init__()

    self.activate = nn.ReLU()
    # The input layer always exists; `depth - 1` further hidden layers follow.
    widths = [input_dim] + [hidden_size] * max(depth, 1)
    layers = [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
    layers.append(nn.Linear(hidden_size, 1))
    self.layer_list = nn.ModuleList(layers)

    if init_params is not None:
      for idx, layer in enumerate(self.layer_list):
        layer.weight.data = init_params[2 * idx]
        layer.bias.data = init_params[2 * idx + 1]
    else:
      for layer in self.layer_list:
        torch.nn.init.normal_(layer.weight, mean=0, std=1.0)
        torch.nn.init.normal_(layer.bias, mean=0, std=1.0)

  def forward(self, x):
    """Apply every layer (ReLU after all but the last) and drop the trailing unit axis."""
    *hidden_layers, output_layer = self.layer_list
    out = x
    for layer in hidden_layers:
      out = self.activate(layer(out))
    return output_layer(out).squeeze(-1)

class Template_Padding:
  """Builds the system and user prompts shown to the LLM for one bandit instance.

  The object owns a ``Multi_Armed_Bandit`` and renders its pull history into
  ``Input_USER``; ``SYSTEM`` is filled from a caller-supplied instruction text.
  """

  def __init__(self, delta):
    """Create the bandit with gap ``delta`` and render the prompt for step 0."""
    self.Multi_Armed_Bandit = Multi_Armed_Bandit(delta)
    self.SYSTEM_template = """{}You must provide your final answer within the tags <Answer>DIST</Answer> where DIST is the distribution in the format specified above.\n"""
    self.SYSTEM = ""
    self.exemplar_template_not_none = """{} button: pressed {} times with average reward {}\n"""
    self.exemplar_template_none = """{} button: pressed {} times\n"""
    self.input_exemplars = None
    self.USER = """So far you have played {} times with your past choices and rewards summarized as follows:

{}
Which button will you choose next? Remember, YOU MUST provide your final answer within the tags <Answer>DIST</Answer> where DIST is formatted like "blue:a,green:b,red:c,yellow:d"."""
    self.Input_USER = ""
    self.padding_exemplar()

  def padding_exemplar(self, step=0):
    """Re-render ``input_exemplars`` and ``Input_USER`` from the bandit's history.

    Args:
      step: Number of rounds played so far, inserted into the user prompt.
    """
    stats = self.Multi_Armed_Bandit.count_and_mean()
    colors = self.Multi_Armed_Bandit.button_color
    lines = []
    for idx in range(4):
      color = colors[idx]
      pulls, mean_reward = stats[color]
      if pulls == 0:
        # Buttons never pressed are listed without an average reward.
        lines.append(self.exemplar_template_none.format(color, 0))
      else:
        lines.append(self.exemplar_template_not_none.format(color, pulls, mean_reward))
    self.input_exemplars = "".join(lines)
    self.Input_USER = self.USER.format(step, self.input_exemplars)

  def padding_system(self, z1_z2):
    """Set ``SYSTEM`` to the instruction text ``z1_z2`` followed by the answer-format reminder."""
    self.SYSTEM = self.SYSTEM_template.format(z1_z2)

class RequestLLM:
  """Lets a chat model play the bandit one step at a time and tracks metrics.

  Each call to :meth:`Prompt_LLM_Epoch` sends the current system/user prompt to
  the chat model, parses the returned distribution over the four buttons,
  samples a button from it, pulls that arm and updates the running statistics
  stored on the instance (``Time_Mu``, ``Time_Averaged_Reward``,
  ``Time_Averaged_Mu``, ``frac_best``, ``regret``).
  """

  def __init__(self, delta, model='gpt-4-turbo', steps=100):
    """
    Args:
      delta: Gap between the best arm's mean and the other arms' means.
      model: Chat model name passed to the OpenAI API.
      steps: Horizon used by :meth:`calculate_suff_fail_freq`.
    """
    self.API_key = '[API_KEY]'
    self.model = model
    self.steps = steps
    self.Template = Template_Padding(delta)
    self.stable_score = 0
    self.Suffix_Failure_Frequency = []
    self.MinFrac_t = []
    self.Time_Averaged_Reward = []
    self.Time_Mu = []
    self.Time_Averaged_Mu = []
    self.cumulative_reward_sum = 0
    self.cumulative_mu_sum = 0
    self.fract_best_count = 0
    self.frac_best = []
    self.regret = []

  def Prompt_LLM_Epoch(self, step):
    """Query the LLM for a distribution, play one round and update the metrics.

    The request is repeated until the response contains a distribution with a
    positive total mass (it is renormalized to sum to one).

    Args:
      step: Zero-based index of the current round.

    Returns:
      The expected reward of the returned distribution under the empirical
      per-button means observed *before* this round's pull.
    """
    openai.api_key = self.API_key
    bandit = self.Template.Multi_Armed_Bandit
    print("[SYSTEM]" + self.Template.SYSTEM)
    print("[USER]" + self.Template.Input_USER + "\n")

    # NOTE(review): there is no retry limit; with temperature=0 an unparsable
    # answer is likely to repeat, so this loop may never terminate.
    while True:
      response = openai.ChatCompletion.create(
        model=self.model,
        messages=[
          {"role": "system", "content": self.Template.SYSTEM},
          {"role": "user", "content": self.Template.Input_USER}
        ],
        temperature=0
      )

      response_content = response.choices[0].message['content'].strip()
      print(response_content)

      matches = re.findall(r'\b(blue|green|red|yellow):([0-9]+(?:\.[0-9]*)?)\b', response_content)

      # Colors missing from the answer keep probability 0; when a color is
      # mentioned several times the last mention wins.
      color_prob_dict = {color: 0.0 for color in ['blue', 'green', 'red', 'yellow']}
      for color, probability in matches:
        color_prob_dict[color] = float(probability)

      total_prob_sum = sum(color_prob_dict.values())
      if total_prob_sum > 0:
        color_prob_dict = {color: prob / total_prob_sum for color, prob in color_prob_dict.items()}

      colors = list(color_prob_dict.keys())
      probabilities = list(color_prob_dict.values())

      if not np.isclose(sum(probabilities), 1.0):
        print("Invalid distribution. Please try again.")
        continue

      index = np.random.choice(len(probabilities), p=probabilities)

      # Expected reward uses the empirical means before the new pull is recorded.
      color_stats = bandit.count_and_mean()
      expected_reward = sum(color_prob_dict[color] * color_stats[color][1] for color in bandit.button_color)

      selected_color = colors[index]
      direct_reward = bandit.sample(selected_color)
      print(f"The selected button : {selected_color}, reward: {direct_reward}")

      chose_best = selected_color == "blue"
      selected_mu = bandit.optima if chose_best else bandit.suboptima
      self.Time_Mu.append(selected_mu)

      self.Template.padding_exemplar(step + 1)
      self.cumulative_reward_sum += direct_reward
      self.cumulative_mu_sum += selected_mu
      self.fract_best_count += 1 if chose_best else 0
      # Per-step regret equals the gap between the best and the chosen arm's
      # mean, i.e. the bandit's delta (previously hard-coded as 0.5).
      self.regret.append(0 if chose_best else bandit.delta)

      time_averaged_reward = self.cumulative_reward_sum / (step + 1)
      time_averaged_mu = self.cumulative_mu_sum / (step + 1)
      print(f"Averaged reward : {time_averaged_reward}")
      print(f"Averaged mu : {time_averaged_mu}")
      accumulative_regret = sum(np.array(self.regret))
      print(f"Accumulative Regret: {accumulative_regret}")
      self.Time_Averaged_Reward.append(time_averaged_reward)
      self.Time_Averaged_Mu.append(time_averaged_mu)
      self.frac_best.append(self.fract_best_count / (step + 1))

      print(f"\nExpected reward based on distribution: {expected_reward}\n")

      return expected_reward

  def calculate_median_reward(self):
    """Print and return the median of all rewards recorded so far."""
    rewards = [reward for _, reward in self.Template.Multi_Armed_Bandit.sample_list]
    median_reward = np.median(rewards)
    print(f"MedianReward : {median_reward}")
    return median_reward

  def calculate_suff_fail_freq(self, optimal_button):
    """Append the share of non-optimal pulls among the first ``steps // 2`` pulls.

    The count is divided by ``steps / 2`` and appended to
    ``Suffix_Failure_Frequency``.

    NOTE(review): the name says "suffix" but the slice takes the *first*
    ``steps // 2`` recorded pulls -- confirm which half is intended.
    """
    window = self.Template.Multi_Armed_Bandit.sample_list[:self.steps//2]
    failures = sum(1 for color, _ in window if color != optimal_button)
    frequency = failures / (self.steps / 2)
    print(f"\nSuffFailFreq(T/2) : {frequency}")
    self.Suffix_Failure_Frequency.append(frequency)

  def calculate_k_min_frac(self):
    """Append K times the pull fraction of the least-pulled button to ``MinFrac_t``.

    Raises:
      ZeroDivisionError: If no pull has been recorded yet.
    """
    bandit = self.Template.Multi_Armed_Bandit
    color_counts = {color: 0 for color in bandit.button_color}
    for color, _ in bandit.sample_list:
      color_counts[color] += 1
    min_frac = min(color_counts.values()) / len(bandit.sample_list)
    k_min_frac = len(bandit.button_color) * min_frac
    print(f"K*MinFrac : {k_min_frac}")
    self.MinFrac_t.append(k_min_frac)

  def calculate_greedy_frac(self):
    """Print and return the fraction of pulls spent on the empirically best button.

    The best button is the one with the highest average recorded reward
    (never-pulled buttons count as 0; ties go to the earliest button).

    Raises:
      ZeroDivisionError: If no pull has been recorded yet.
    """
    bandit = self.Template.Multi_Armed_Bandit
    color_rewards = {color: 0 for color in bandit.button_color}
    color_counts = {color: 0 for color in bandit.button_color}

    for color, reward in bandit.sample_list:
      color_rewards[color] += reward
      color_counts[color] += 1

    avg_rewards = {color: (color_rewards[color] / color_counts[color] if color_counts[color] > 0 else 0)
                   for color in bandit.button_color}
    best_color = max(avg_rewards, key=avg_rewards.get)

    greedy_frac = color_counts[best_color] / len(bandit.sample_list)
    print(f"GreedyFrac : {greedy_frac}")
    return greedy_frac

class PromptEmbedding:
  """Couples the prompt table with the LLM bandit player and the embedding API.

  :meth:`forward` plays one bandit round with a given instruction prompt and
  returns the prompt's embedding together with the reward signal.
  """

  def __init__(self, z1_z2_prompt_data, delta, Optimizer_Model='gpt-4-turbo', Embedding_model='text-embedding-3-large'):
    """
    Args:
      z1_z2_prompt_data: DataFrame of candidate prompts; must contain a
        'Z1-Z2-Embedding' column (read by :meth:`Get_Batch_Embedding`).
      delta: Bandit gap forwarded to ``RequestLLM``.
      Optimizer_Model: Chat model used to play the bandit (previously this
        argument was ignored and 'gpt-4-turbo' was always used).
      Embedding_model: Model name used for embedding requests.
    """
    self.API_key = '[API_KEY]'
    self.z1_z2_prompt = z1_z2_prompt_data
    self.Optimizer_Model = Optimizer_Model
    self.Embedding_model = Embedding_model
    self.RequestLLM = RequestLLM(delta, model=Optimizer_Model, steps=1)
    self.Prompt_Embedding_List = []

  def normalize_l2(self, x):
    """Return ``x`` as an array scaled to unit L2 norm (unchanged if its norm is 0)."""
    x = np.array(x)
    norm = np.linalg.norm(x)
    if norm == 0:
      return x
    return x / norm

  def Get_Single_Embedding(self, text, target_dim=3072):
    """Request an embedding for ``text`` and return it truncated and L2-normalized.

    Newlines in ``text`` are replaced by spaces before the request.  The call
    sleeps a random 0.5-1.5 s before every attempt and retries indefinitely:
    after 10 s on a rate-limit error, after 5 s on any other exception.

    Args:
      text: Text to embed.
      target_dim: Number of leading embedding components to keep.

    Returns:
      list of floats with at most ``target_dim`` entries.
    """
    openai.api_key = self.API_key
    text = text.replace("\n", " ")

    while True:
      try:
        time.sleep(random.uniform(0.5, 1.5))

        response = openai.Embedding.create(
          input=[text],
          model=self.Embedding_model
        )
        embedding = response['data'][0]['embedding']
        return list(self.normalize_l2(embedding[:target_dim]))

      except openai.error.RateLimitError:
        print("Rate limit exceeded. Waiting before retrying...")
        time.sleep(10)

      except Exception as e:
        # Deliberate best-effort: any other failure is reported and retried.
        print(f"An error occurred: {e}")
        print("Retrying...")
        time.sleep(5)

  def Get_Batch_Embedding(self, exemplars_embedding, column='PromptEmbedding'):
    """Copy the 'Z1-Z2-Embedding' column into ``column`` of the prompt table.

    NOTE: concatenating ``exemplars_embedding`` onto every prompt embedding
    (``x + exemplars_embedding``) is currently disabled, so the argument is
    unused and each row is copied as-is.
    """
    openai.api_key = self.API_key
    tqdm.pandas()
    self.z1_z2_prompt[column] = self.z1_z2_prompt['Z1-Z2-Embedding'].progress_apply(
      lambda x: x
    )

  def forward(self, i, z1_z2, z1, z2, z1_z2_embedding):
    """Play round ``i`` with instruction prompt ``z1_z2`` and return its result.

    Args:
      i: Zero-based round index.
      z1_z2: Instruction text placed in the system prompt.
      z1: Unused here.
      z2: Unused here.
      z1_z2_embedding: Embedding associated with ``z1_z2``.

    Returns:
      ``[z1_z2_embedding, reward]`` where ``reward`` is the value returned by
      ``RequestLLM.Prompt_LLM_Epoch``.
    """
    # NOTE(review): this embedding request is only needed once concatenation
    # with the exemplar embedding is re-enabled; today its result is unused.
    exemplars_embedding = self.Get_Single_Embedding(self.RequestLLM.Template.Input_USER)
    self.Get_Batch_Embedding(exemplars_embedding)
    self.RequestLLM.Template.padding_system(z1_z2)
    reward = self.RequestLLM.Prompt_LLM_Epoch(i)
    self.RequestLLM.calculate_suff_fail_freq("blue")
    self.RequestLLM.calculate_k_min_frac()
    # Exemplar concatenation is disabled, so the prompt embedding is used alone.
    completed_embedding = z1_z2_embedding
    self.RequestLLM.Template.padding_exemplar(i + 1)
    return [completed_embedding, reward]

class AdversarialBandit:
  """Selects meta-prompts with a learned reward model and exponential weights.

  Every optimization step plays one bandit round with the current prompt,
  retrains a ``Network`` on all (embedding, reward) pairs seen so far, scores
  every candidate prompt and samples the next prompt with probability
  proportional to ``exp(eta * cumulative predicted score)`` (the code refers
  to this as EXP3).
  """

  def __init__(self, z1_z2_prompt_data, delta, input_dim=6144, steps=100, NumberofK=10201):
    """
    Args:
      z1_z2_prompt_data: DataFrame of candidate prompts; :meth:`select` reads
        its 'Merged_Content', 'Z1', 'Z2' and 'Z1-Z2-Embedding' columns.
      delta: Bandit gap forwarded to ``PromptEmbedding``.
      input_dim: Input width of the reward model.
      steps: Number of optimization steps run by :meth:`optimization`.
      NumberofK: Stored on the instance; not used elsewhere in this class.
    """
    self.z1_z2_prompt_data = z1_z2_prompt_data
    self.input_dim = input_dim
    self.steps = steps
    self.NumberofK = NumberofK
    # Use the GPU when present, otherwise fall back to the CPU.
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    self.model = Network(self.input_dim, 1536, 1).to(self.device).double()

    self.best_result_list = []

    self.lr = 1e-2
    self.epochs = 6000
    self.optimizer_fn = torch.optim.Adam
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1)
    self.loss_fn = nn.MSELoss()

    # Snapshot of the initial weights; every retraining restarts from it.
    self.init_model_weight = copy.deepcopy(self.model.state_dict())

    self.eta = 10
    self.prompt_score_trace = []
    self.embedding_score_trace = []

    self.z1_z2 = """You are a bandit algorithm in a room with 4 buttons labeled blue, green, red, yellow. Each button is associated with a Bernoulli distribution with a fixed but unknown mean; the means for the buttons could be different. For each button, when you press it, you will get a reward that is sampled from the button's associated distribution. You have 100 time steps and, on each time step, you can choose any button and receive the reward. Your goal is to maximize the total reward over the 100 time steps.
At each time step, I will show you a summary of your past choices and rewards. Then you must make the next choice. You may output a distribution over the 4 buttons formatted EXACTLY like "blue:a,green:b,red:c,yellow:d". Let's think step by step to make sure we make a good choice.\n"""

    self.z1 = """You are a bandit algorithm in a room with 4 buttons labeled blue, green, red, yellow. Each button is associated with a Bernoulli distribution with a fixed but unknown mean; the means for the buttons could be different. For each button, when you press it, you will get a reward that is sampled from the button's associated distribution. You have 100 time steps and, on each time step, you can choose any button and receive the reward. Your goal is to maximize the total reward over the 100 time steps."""
    self.z2 = """At each time step, I will show you a summary of your past choices and rewards. Then you must make the next choice. You may output a distribution over the 4 buttons formatted EXACTLY like "blue:a,green:b,red:c,yellow:d". Let's think step by step to make sure we make a good choice."""

    self.prompt_embedding = PromptEmbedding(z1_z2_prompt_data, delta, Optimizer_Model='gpt-4-turbo', Embedding_model='text-embedding-3-large')

    # The initial prompt embedding is the concatenation of the z1 and z2 embeddings.
    self.z1_z2_Embedding = self.prompt_embedding.Get_Single_Embedding(self.z1) + self.prompt_embedding.Get_Single_Embedding(self.z2)

    self.scheduler = None

    # Set to 1 by select() once the sampling distribution becomes invalid.
    self.Stop_EXP3 = 0
    # Index of the most recently selected prompt; None until the first selection.
    self.best_index = None

  def restart_model(self):
    """Reset the model to its initial weights and rebuild optimizer and scheduler.

    Weight decay is ``1 / len(embedding_score_trace)``, so at least one
    observation must have been recorded.  The learning rate is multiplied by
    0.1 every 2000 scheduler steps.
    """
    self.model.load_state_dict(copy.deepcopy(self.init_model_weight))
    self.optimizer = self.optimizer_fn(self.model.parameters(), lr=self.lr, weight_decay=1/len(self.embedding_score_trace))
    self.scheduler = torch.optim.lr_scheduler.LambdaLR(
      self.optimizer,
      lr_lambda=lambda epoch: 0.1 ** (epoch // 2000)
    )

  def train_model(self, batch_size=64):
    """Retrain the reward model from scratch on all recorded (embedding, reward) pairs.

    Args:
      batch_size: Mini-batch size of the training data loader.
    """
    self.restart_model()
    self.model.to(self.device)

    embeddings_tensor = torch.stack(
      [torch.tensor(embedding, dtype=torch.float64) for embedding, _ in self.embedding_score_trace]
    )
    scores_tensor = torch.tensor([score for _, score in self.embedding_score_trace], dtype=torch.float64)

    dataset = TensorDataset(embeddings_tensor, scores_tensor)
    # Pinned host memory only helps (and only applies) when copying to a GPU.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=self.device.type == "cuda")

    self.model.train()
    for epoch in tqdm(range(self.epochs), desc="Training Model"):
      epoch_loss = 0
      for batch_embeddings, batch_scores in dataloader:
        batch_embeddings = batch_embeddings.to(self.device)
        batch_scores = batch_scores.to(self.device)

        self.optimizer.zero_grad()

        predictions = self.model(batch_embeddings)
        loss = self.loss_fn(predictions, batch_scores)

        loss.backward()
        self.optimizer.step()

        epoch_loss += loss.item()

      self.scheduler.step()

      if epoch % 100 == 0:
        print(f"Epoch {epoch}/{self.epochs}, Loss: {epoch_loss / len(dataloader)}")

    print("Training complete.")

  def select(self, iteration):
    """Score all candidate prompts and sample the next prompt to use.

    Scores predicted in this and all previous calls are summed per prompt,
    multiplied by ``eta`` and turned into a probability vector with a
    max-shifted softmax.  If that vector contains NaN/Inf or has no mass,
    ``Stop_EXP3`` is set to 1 and the current prompt is kept.

    Args:
      iteration: Current optimization step (unused).
    """
    embeddings = torch.stack(
      [torch.tensor(e, dtype=torch.float64) for e in self.prompt_embedding.z1_z2_prompt['PromptEmbedding']]
    ).to(self.device)

    with torch.no_grad():
      # reshape(-1) keeps a 1-D score vector even for a single candidate.
      scores = self.model(embeddings).reshape(-1)

    print(scores)
    score_values = scores.cpu().tolist()
    self.z1_z2_prompt_data['Score'] = score_values

    # Store plain lists so the trace converts to a (calls, prompts) tensor directly.
    self.prompt_score_trace.append(score_values)
    score_tensor = torch.tensor(self.prompt_score_trace, dtype=torch.float64)

    log_exp_scores = self.eta * torch.sum(score_tensor, dim=0)
    # Subtract the maximum before exponentiating to avoid overflow.
    exp_scores = torch.exp(log_exp_scores - torch.max(log_exp_scores))
    normalized_exp_scores = exp_scores / torch.sum(exp_scores)

    # Fewer than 10 candidates would make a fixed k=10 raise.
    top_k = min(10, normalized_exp_scores.numel())
    top_scores, top_indices = torch.topk(normalized_exp_scores, top_k)
    print("Top 10 probabilities: ", top_scores)
    print("Indices of Top 10 probabilities: ", top_indices)

    distribution_is_valid = (
      not torch.isnan(normalized_exp_scores).any()
      and not torch.isinf(normalized_exp_scores).any()
      and torch.sum(normalized_exp_scores) > 0
    )
    if not distribution_is_valid:
      # best_index may not exist yet if the very first selection is invalid.
      print(f"The final index is {getattr(self, 'best_index', None)}")
      print("Encountered NaN or Inf in normalized_exp_scores. Stopping EXP3.")
      self.Stop_EXP3 = 1
      return

    self.best_index = torch.multinomial(normalized_exp_scores, 1).item()

    print(f"Selected index: {self.best_index}")
    chosen_row = self.z1_z2_prompt_data.iloc[self.best_index]
    self.z1_z2 = chosen_row['Merged_Content']
    self.z1 = chosen_row['Z1']
    self.z2 = chosen_row['Z2']
    self.z1_z2_Embedding = chosen_row['Z1-Z2-Embedding']

  def optimization(self):
    """Run ``steps`` rounds: play, record the result, then retrain and reselect.

    Retraining and selection are skipped once ``Stop_EXP3`` has been set.
    The per-step regret list is printed at the end.
    """
    for i in tqdm(range(self.steps), desc="Optimizing the Meta-Prompt " + str(self.steps) + " step"):
      forward_results = self.prompt_embedding.forward(i, self.z1_z2, self.z1, self.z2, self.z1_z2_Embedding)
      self.embedding_score_trace.append(forward_results)
      if not self.Stop_EXP3:
        self.train_model()
        self.select(i)
    print(self.prompt_embedding.RequestLLM.regret)

# Module-level result collectors; only Time_Averaged_Reward and Regret are
# filled by the loop below.
Suffix_Failure_Frequency, MinFrac_t = [], []
Time_Mu, Time_Average_Mu = [], []
Time_Averaged_Reward, Regret = [], []

# Six independent runs: the first three use seed 999, the last three seed 1000.
for i in range(6):
  set_random_seed(999 if i < 3 else 1000)
  # Each run works on its own copy of the prompt table.
  z1_z2_prompt_data = copy.deepcopy(Meta_Prompt_data)
  new_bandit = AdversarialBandit(z1_z2_prompt_data, 0.5)
  new_bandit.optimization()
  request_llm = new_bandit.prompt_embedding.RequestLLM
  Time_Averaged_Reward.append(copy.deepcopy(request_llm.Time_Averaged_Reward))
  Regret.append(request_llm.regret)
  print(Time_Averaged_Reward)
  print(Regret)