import numpy as np
import re
from tqdm.auto import tqdm
import copy
import random
from scipy.stats import bernoulli

import torch

import openai

def set_random_seed(seed=0):
  """Seed the Python, NumPy and PyTorch generators and pin cuDNN to deterministic mode.

  Args:
    seed: Integer seed applied to every generator (CUDA too, when available).
  """
  for seeder in (random.seed, np.random.seed, torch.manual_seed):
    seeder(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

  # Trade cuDNN autotuning speed for run-to-run reproducibility.
  cudnn = torch.backends.cudnn
  cudnn.deterministic = True
  cudnn.benchmark = False

class Multi_Armed_Bandit:
  """Five-button Bernoulli bandit whose first button is the only optimal one.

  The first color in ``button_color`` pays 1 with probability
  ``optima = 0.5 + delta/2``; every other color pays 1 with probability
  ``suboptima = optima - delta``. Each pull is logged in ``sample_list`` as a
  ``(color, reward)`` tuple.

  Args:
    delta: Gap between the optimal and the suboptimal success probabilities.
  """

  def __init__(self, delta):
    self.delta = delta
    self.optima = 0.5 + self.delta / 2
    self.suboptima = self.optima - self.delta
    self.button_color = ["blue", "green", "red", "yellow", "purple"]
    self.bandit = self.bernoulli()
    self.sample_list = []

  def bernoulli(self):
    """Return a dict mapping each color to its frozen scipy Bernoulli distribution."""
    arms = {}
    for idx, color in enumerate(self.button_color):
      # Index 0 is the optimal arm; all others share the suboptimal mean.
      success_p = self.suboptima if idx else self.optima
      arms[color] = bernoulli(success_p)
    return arms

  def sample(self, color):
    """Pull ``color``, log the ``(color, reward)`` pair and return the reward."""
    reward = self.bandit[color].rvs()
    self.sample_list.append((color, reward))
    return reward

  def count_and_mean(self):
    """Summarise the pull history.

    Returns:
      Dict keyed by color (in ``button_color`` order) whose values are
      ``[pull_count, mean_reward]``; never-pulled colors map to ``[0, 0]``.
    """
    counts = dict.fromkeys(self.button_color, 0)
    totals = dict.fromkeys(self.button_color, 0)
    for color, reward in self.sample_list:
      counts[color] += 1
      totals[color] += reward

    return {
        color: [counts[color], totals[color] / counts[color] if counts[color] > 0 else totals[color]]
        for color in self.button_color
    }

class Template_Padding:
  """Holds the bandit and the prompt templates used to query the LLM.

  ``SYSTEM`` is the fixed system prompt. ``Input_USER`` is the user prompt for
  the current round, rebuilt by ``padding_exemplar`` from the bandit's pull
  history.

  NOTE(review): the SYSTEM text hard-codes 5 buttons and a 100-step horizon,
  independently of the bandit's ``button_color`` or any step count chosen by
  the caller.

  Args:
    delta: Reward gap forwarded to the underlying ``Multi_Armed_Bandit``.
  """
  def __init__(self, delta):
    self.Multi_Armed_Bandit = Multi_Armed_Bandit(delta)
    self.SYSTEM = """You are a bandit algorithm in a room with 5 buttons labeled blue, green, red, yellow, purple. Each button is associated with a Bernoulli distribution with a fixed but unknown mean; the means for the buttons could be different. For each button, when you press it, you will get a reward that is sampled from the button's associated distribution. You have 100 time steps and, on each time step, you can choose any button and receive the reward. Your goal is to maximize the total reward over the 100 time steps.
At each time step, I will show you a summary of your past choices and rewards. Then you must make the next choice. You may output a distribution over the 5 buttons formatted EXACTLY like "blue:a,green:b,red:c,yellow:d,purple:e". Let's think step by step to make sure we make a good choice.
You must provide your final answer within the tags <Answer>DIST</Answer> where DIST is the distribution in the format specified above.\n"""
    # Per-button summary lines: with / without an average (the latter for unpressed buttons).
    self.exemplar_template_not_none = """{} button: pressed {} times with average reward {}\n"""
    self.exemplar_template_none = """{} button: pressed {} times\n"""
    self.input_exemplars = None
    self.USER = """So far you have played {} times with your past choices and rewards summarized as follows:

{}
Which button will you choose next? Remember, YOU MUST provide your final answer within the tags <Answer>DIST</Answer> where DIST is formatted like "blue:a,green:b,red:c,yellow:d,purple:e"."""
    self.Input_USER = ""
    self.padding_exemplar()

  def padding_exemplar(self, step: int = 0):
    """Rebuild the history summary and the user prompt for round ``step``.

    Emits one summary line per color, in ``button_color`` order: colors never
    pressed show only a count of 0; the others show their count and mean
    reward (as returned by ``count_and_mean``). The summary is stored in
    ``self.input_exemplars`` and the filled prompt in ``self.Input_USER``.

    Args:
      step: Number of rounds played so far, shown as "played {step} times".
    """
    bandit = self.Multi_Armed_Bandit
    summary_dict = bandit.count_and_mean()
    lines = []
    # Iterate the bandit's own color list so the summary always covers
    # exactly its buttons (no hard-coded count).
    for color in bandit.button_color:
      count, mean = summary_dict[color]
      if count == 0:
        lines.append(self.exemplar_template_none.format(color, 0))
      else:
        lines.append(self.exemplar_template_not_none.format(color, count, mean))
    self.input_exemplars = "".join(lines)
    self.Input_USER = self.USER.format(step, self.input_exemplars)

class RequestLLM:
  """Play a Bernoulli-bandit episode by prompting an OpenAI chat model.

  Each round sends ``Template.SYSTEM`` and the current ``Template.Input_USER``
  to the chat API and parses the "color:probability" distribution in the
  reply. It then samples a button from it with ``np.random.choice``, pulls
  that arm on ``Template.Multi_Armed_Bandit`` and updates the running metrics.

  Args:
    delta: Reward gap forwarded to ``Template_Padding``.
    model: Chat model name passed to the API.
    steps: Number of rounds played by ``optimization``.
    api_key: Value assigned to ``openai.api_key`` before every round.
    max_retries: Maximum API requests per round while the reply contains no
      usable distribution; ``None`` retries without limit.
  """

  # One "color:probability" pair, e.g. "blue:0.25".
  _PAIR_PATTERN = re.compile(r'\b(blue|green|red|yellow|purple):([0-9]+(?:\.[0-9]*)?)\b')
  # Body of an <Answer>...</Answer> block (may span several lines).
  _ANSWER_PATTERN = re.compile(r'<Answer>(.*?)</Answer>', re.DOTALL)
  _COLORS_ORDER = ('blue', 'green', 'red', 'yellow', 'purple')

  def __init__(self, delta, model='gpt-4-turbo', steps=100, api_key='[API_KEY]', max_retries=10):
    self.API_key = api_key
    self.model = model
    self.steps = steps
    self.max_retries = max_retries
    self.Template = Template_Padding(delta)
    self.stable_score = 0  # not read anywhere in this class
    self._reset_metrics()

  def _reset_metrics(self):
    """(Re)initialise every per-episode metric list and running counter."""
    self.Suffix_Failure_Frequency = []  # value from calculate_suff_fail_freq after each step
    self.MinFrac_t = []                 # K*MinFrac after each step
    self.Time_Averaged_Reward = []      # mean observed reward over steps 0..t
    self.Time_Mu = []                   # success probability of the arm chosen at each step
    self.Time_Averaged_Mu = []          # running mean of Time_Mu
    self.cumulative_reward_sum = 0
    self.cumulative_mu_sum = 0
    self.fract_best_count = 0           # pulls of the best arm so far
    self.frac_best = []                 # fract_best_count / (t + 1) after each step
    self.regret = []                    # per-step regret: 0 for the best arm, delta otherwise

  @classmethod
  def _parse_distribution(cls, response_content):
    """Extract a normalised button distribution from an LLM reply.

    Pairs such as ``blue:0.3`` are read from the last ``<Answer>...</Answer>``
    block. If no such block yields a pair, the whole reply is scanned instead.
    A color given more than once keeps its last value; missing colors get 0.

    Returns:
      Dict mapping the five colors (blue, green, red, yellow, purple, in that
      order) to probabilities summing to 1, or ``None`` when the total mass is
      zero or not finite.
    """
    answers = cls._ANSWER_PATTERN.findall(response_content)
    matches = cls._PAIR_PATTERN.findall(answers[-1]) if answers else []
    if not matches:
      # No usable <Answer> block: fall back to any pairs anywhere in the text.
      matches = cls._PAIR_PATTERN.findall(response_content)

    color_prob_dict = {color: 0.0 for color in cls._COLORS_ORDER}
    for color, probability in matches:
      color_prob_dict[color] = float(probability)

    total_prob_sum = sum(color_prob_dict.values())
    # A very long digit string can parse to inf; treat it like "no mass".
    if not 0 < total_prob_sum < float('inf'):
      return None
    return {color: prob / total_prob_sum for color, prob in color_prob_dict.items()}

  def _query_llm(self):
    """Send the current system/user prompts to the chat API; return the stripped reply text."""
    # NOTE(review): `openai.ChatCompletion` is the pre-1.0 openai-python
    # interface — confirm the pinned openai version.
    response = openai.ChatCompletion.create(
      model=self.model,
      messages=[
        {"role": "system", "content": self.Template.SYSTEM},
        {"role": "user", "content": self.Template.Input_USER}
      ],
      temperature=0
    )
    return response.choices[0].message['content'].strip()

  def _record_step(self, step, selected_color):
    """Pull ``selected_color`` on the bandit and update every running metric for round ``step``."""
    bandit = self.Template.Multi_Armed_Bandit
    direct_reward = bandit.sample(selected_color)
    print(f"The selected button : {selected_color}, reward: {direct_reward}")

    # The bandit assigns the optimal mean to its first color.
    is_best = selected_color == bandit.button_color[0]
    mu = bandit.optima if is_best else bandit.suboptima
    self.Time_Mu.append(mu)

    # Refresh the prompt so the next query sees this pull in the summary.
    self.Template.padding_exemplar(step + 1)
    self.cumulative_reward_sum += direct_reward
    self.cumulative_mu_sum += mu
    self.fract_best_count += 1 if is_best else 0
    # Suboptimal regret is optima - suboptima, which the bandit defines as delta.
    self.regret.append(0 if is_best else bandit.delta)

    n_played = step + 1
    time_averaged_reward = self.cumulative_reward_sum / n_played
    time_averaged_mu = self.cumulative_mu_sum / n_played
    print(f"Averaged reward : {time_averaged_reward}")
    print(f"Averaged mu : {time_averaged_mu}")
    self.Time_Averaged_Reward.append(time_averaged_reward)
    self.Time_Averaged_Mu.append(time_averaged_mu)
    self.frac_best.append(self.fract_best_count / n_played)

  def Prompt_LLM_Epoch(self, step):
    """Play one round: query until a usable distribution arrives, sample a button, record it.

    Args:
      step: Zero-based round index (used for averages and the next prompt).

    Raises:
      RuntimeError: if ``self.max_retries`` consecutive replies contain no
        usable distribution.
    """
    openai.api_key = self.API_key
    print("[SYSTEM]" + self.Template.SYSTEM)
    print("[USER]" + self.Template.Input_USER + "\n")

    attempts = 0
    # Bounded: at temperature 0 an identical prompt is likely to yield the
    # same invalid reply, so unlimited retries could loop forever.
    while self.max_retries is None or attempts < self.max_retries:
      attempts += 1
      response_content = self._query_llm()
      print(response_content)

      distribution = self._parse_distribution(response_content)
      if distribution is None:
        print("No valid <Answer> tag found.")
        continue

      colors = list(distribution.keys())
      probabilities = list(distribution.values())
      index = np.random.choice(len(probabilities), p=probabilities)
      self._record_step(step, colors[index])
      return

    raise RuntimeError(
      f"No valid distribution from the LLM after {attempts} attempts at step {step}."
    )

  def calculate_median_reward(self):
    """Print and return the median of all rewards observed so far."""
    rewards = [reward for _, reward in self.Template.Multi_Armed_Bandit.sample_list]
    median = np.median(rewards)
    print(f"MedianReward : {median}")
    return median

  def calculate_suff_fail_freq(self, optimal_button):
    """Record the fraction of the first ``steps // 2`` pulls that were not ``optimal_button``.

    The failure count is always divided by ``steps / 2``, even while fewer
    pulls than that exist.

    NOTE(review): despite the name and the "SuffFailFreq(T/2)" label, this
    inspects the *prefix* ``sample_list[:steps // 2]`` — confirm which window
    (and which definition) is intended.
    """
    window = self.Template.Multi_Armed_Bandit.sample_list[:self.steps // 2]
    failures = sum(1 for color, _ in window if color != optimal_button)
    value = failures / (self.steps / 2)
    print(f"SuffFailFreq(T/2) : {value}")
    self.Suffix_Failure_Frequency.append(value)

  def calculate_k_min_frac(self):
    """Record K times the smallest per-button share of all pulls so far (K = number of buttons)."""
    bandit = self.Template.Multi_Armed_Bandit
    color_counts = {color: 0 for color in bandit.button_color}
    for color, _ in bandit.sample_list:
      color_counts[color] += 1
    min_frac = min(color_counts.values()) / len(bandit.sample_list)
    k_min_frac = len(bandit.button_color) * min_frac
    print(f"K*MinFrac : {k_min_frac}")
    self.MinFrac_t.append(k_min_frac)

  def calculate_greedy_frac(self):
    """Print and return the share of all pulls that went to the arm with the best empirical mean.

    Ties go to the color listed first in ``button_color``; unpulled arms count
    as mean 0.

    NOTE(review): this uses the *final* empirical best arm, not the greedy
    choice at each round — confirm this is the intended GreedyFrac.
    """
    bandit = self.Template.Multi_Armed_Bandit
    stats = bandit.count_and_mean()  # color -> [count, mean]
    best_color = max(stats, key=lambda color: stats[color][1])
    greedy_frac = stats[best_color][0] / len(bandit.sample_list)
    print(f"GreedyFrac : {greedy_frac}")
    return greedy_frac

  def optimization(self):
    """Run a fresh ``steps``-round episode and print the summary metrics.

    Resets every metric and the bandit's pull history first, so repeated calls
    do not mix episodes. SuffFailFreq and K*MinFrac are recorded after each
    round; GreedyFrac, MedianReward and the regret list are printed at the end.
    """
    self._reset_metrics()
    self.Template.Multi_Armed_Bandit.sample_list = []
    self.Template.padding_exemplar(0)
    best_color = self.Template.Multi_Armed_Bandit.button_color[0]
    for i in tqdm(range(self.steps), desc="Prompting LLM " + str(self.steps) + " step"):
      self.Prompt_LLM_Epoch(i)
      self.calculate_suff_fail_freq(best_color)
      self.calculate_k_min_frac()
    self.calculate_greedy_frac()
    self.calculate_median_reward()
    print(self.regret)

# Per-run results, one inner list per replicate (index = run number).
Suffix_Failure_Frequency = []
MinFrac_t = []
Time_Averaged_Reward = []
Time_Mu = []
Time_Average_Mu = []
Frac_Best = []
Regret = []

# NOTE(review): this loop executes at import time and performs 6 full LLM
# episodes; consider moving it under `if __name__ == "__main__":`.
for i in range(6):
  # Runs 0-2 use seed 999 and runs 3-5 use seed 1000, so runs within a group
  # start from identical RNG state.
  set_random_seed(999 if i < 3 else 1000)
  new = RequestLLM(0.2)
  new.optimization()
  Time_Averaged_Reward.append(copy.deepcopy(new.Time_Averaged_Reward))
  Regret.append(new.regret)
  # Collect the remaining per-step metrics that were previously declared but never filled.
  Suffix_Failure_Frequency.append(copy.deepcopy(new.Suffix_Failure_Frequency))
  MinFrac_t.append(copy.deepcopy(new.MinFrac_t))
  Time_Mu.append(copy.deepcopy(new.Time_Mu))
  Time_Average_Mu.append(copy.deepcopy(new.Time_Averaged_Mu))
  Frac_Best.append(copy.deepcopy(new.frac_best))
  print(Time_Averaged_Reward)
  print(Regret)