#!/usr/bin/env python
# coding: utf-8

# In[ ]:
import openai
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import GPy
import re
import copy
import math
import concurrent.futures

# This is an auxiliary function to help clean up the response generated by the LLM
# Remove extra dots from the LLM's response
def remove_extra_dots(input_string):
    """Keep only the first "." of a string and delete every later one.

    A string that contains no "." is returned unchanged.

    Args:
        input_string: Text to clean up (for example "1.2.3").

    Returns:
        The text with at most one "." left, located where the first one was.
    """
    head, dot, tail = input_string.partition('.')

    # partition() reports an empty separator when the string has no "."
    if not dot:
        return input_string

    # Everything after the first "." loses its remaining dots
    return head + dot + tail.replace('.', '')

# Mark the button color corresponding to the explored data
def mark_color(X, xx, colorsheet):
    """Translate explored feature rows into the colors of their buttons.

    Each row of ``X`` is compared against the rows of ``xx``; the first row of
    ``xx`` that is element-wise equal determines the color, taken from
    ``colorsheet`` at the same position.

    Args:
        X: Iterable of explored feature rows.
        xx: Iterable of candidate feature rows, one per button.
        colorsheet: Sequence of color names indexed like ``xx``.

    Returns:
        List of colors in the order of ``X``. A row of ``X`` that matches no
        row of ``xx`` contributes nothing, so the result may be shorter
        than ``X``.
    """
    colors = []

    for explored in X:
        # Position of the first candidate equal to this explored row, if any
        hit = next(
            (pos for pos, candidate in enumerate(xx) if np.array_equal(explored, candidate)),
            None,
        )
        if hit is not None:
            colors.append(colorsheet[hit])

    return colors

# Convert a response containing a probability distribution into the index of the next arm
def dist_response_to_next_arm(chatgpt_output, K):
    """Sample the next arm index from a distribution embedded in an LLM reply.

    The reply is expected to contain ``<Answer>#color:prob,...#</Answer>``.
    Entries are skipped when they cannot be parsed, when the color is not one
    of the first ``K`` arms, or when the probability is not a finite number.
    Negative probabilities are clamped to 0. The remaining probabilities are
    normalized and one arm is sampled from them. Whenever no usable
    distribution is left, an arm is drawn uniformly at random from
    ``range(K)`` instead of raising.

    Args:
        chatgpt_output: Raw text returned by the LLM.
        K: Number of arms; valid arm indices are ``0 .. K-1``.

    Returns:
        Index of the next arm to play.
    """
    # Mapping from color to arm index
    color_to_index = {
        'blue': 0, 'green': 1, 'red': 2, 'yellow': 3, 'purple': 4, 'orange': 5, 'cyan': 6, 'magenta': 7, 'lime': 8, 'pink': 9,
        'teal': 10, 'lavender': 11, 'brown': 12, 'beige': 13, 'maroon': 14, 'mint': 15, 'olive': 16, 'coral': 17, 'navy': 18, 'grey': 19,
        'black': 20, 'white': 21, 'gold': 22, 'silver': 23, 'aqua': 24, 'fuchsia': 25, 'indigo': 26, 'peach': 27, 'salmon': 28, 'plum': 29,
        'orchid': 30, 'crimson': 31, 'turquoise': 32, 'ivory': 33, 'khaki': 34, 'violet': 35, 'azure': 36, 'amber': 37, 'emerald': 38, 'ruby': 39,
        'sapphire': 40, 'bronze': 41, 'copper': 42, 'rose': 43, 'periwinkle': 44
    }

    match = re.search(r"<Answer>#(.*?)#</Answer>", chatgpt_output)
    # If the target format is not matched, randomly select the next arm
    if not match:
        print("Format not matched, returning random choice.")
        return np.random.choice(list(range(K)))

    # Convert the probability distribution into a dictionary
    distribution = {}
    for item in match.group(1).strip().split(","):
        try:
            button, prob = item.split(":")
            prob = float(prob.strip())  # Convert to float
        except ValueError as e:
            print(f"Error parsing item '{item}': {e}")
            continue  # Skip incorrectly formatted items

        button = button.strip().lower()  # Ensure color is in lowercase

        # A color outside the table (or beyond the K arms in play) cannot be
        # mapped to a valid arm, so it must not take part in the sampling.
        arm = color_to_index.get(button)
        if arm is None or arm >= K:
            print(f"Unknown or out-of-range button '{button}', skipping.")
            continue

        # float() accepts "nan" and "inf", which would break the sampling below
        if not math.isfinite(prob):
            print(f"Non-finite probability for '{button}', skipping.")
            continue

        distribution[button] = max(prob, 0.0)  # Negative probabilities become 0

    # If the dictionary is empty, randomly select the next arm
    if not distribution:
        print("No valid distribution found, returning random choice.")
        return np.random.choice(list(range(K)))

    total_prob = sum(distribution.values())
    if total_prob <= 0:
        print("Probabilities are all 0, returning random choice.")
        return np.random.choice(list(range(K)))
    if not math.isfinite(total_prob):
        # Finite entries can still overflow when summed
        print("Probabilities are not usable, returning random choice.")
        return np.random.choice(list(range(K)))

    # Tolerant comparison: sums such as 0.1 + 0.2 + 0.7 are not exactly 1.0
    if not math.isclose(total_prob, 1.0):
        print("Probabilities do not sum to 1")

    # Dividing by the total is a no-op when the sum is already exactly 1
    buttons = list(distribution)
    probabilities = [value / total_prob for value in distribution.values()]

    # Sample a position instead of the name so the lookup uses a plain str key
    sampled_button = buttons[np.random.choice(len(buttons), p=probabilities)]

    return color_to_index[sampled_button]

# After iteration ends, calculate the cumulative reward
def cumulate_reward(reward):
    """Return a copy of ``reward`` in which each row holds its running sum.

    The input is deep-copied first, so the caller's data is left untouched.

    Args:
        reward: Collection of mutable rows of per-step rewards.

    Returns:
        A deep copy of ``reward`` where entry ``i`` of every row is the sum of
        the original entries ``0 .. i`` of that row.
    """
    running_totals = copy.deepcopy(reward)

    for series in running_totals:
        # Fold each entry into its successor, left to right
        for pos in range(len(series) - 1):
            series[pos + 1] += series[pos]

    return running_totals

# Prepare the prompt to be passed to the LLM

# Experimental group: prompt = task description + data, 
# where data is sorted by distance to the point to be predicted, with closer points appearing later
def make_prompt_experiment(X, Y, x_test):
    """Build the prediction prompt for the experimental group.

    The prompt consists of the task description, then the observed
    input/output pairs ordered from farthest to nearest to ``x_test``, and
    finally the query line for ``x_test`` with an empty output.

    Args:
        X: Array of observed inputs; its ``shape`` decides whether distances
            are row-wise Euclidean norms (more than one dimension) or
            absolute differences (one dimension).
        Y: Observed outputs, indexed like ``X``.
        x_test: Input whose function value should be predicted.

    Returns:
        The prompt string.
    """
    # The runs of spaces inside the text are part of the prompt and kept verbatim
    instructions = (
        "Help me predict the function value at the last input. "
        "    Each function value is associated with a Normal distribution with a fixed but unknown mean. "
        "    Your response should only contain the function value in the format of #function value#.\n"
    )

    # Distance of every observed input to the query point
    is_multi_dimensional = len(X.shape) > 1
    if is_multi_dimensional:
        distances = np.linalg.norm(X - x_test, axis=1)
    else:
        distances = np.abs(X - x_test)

    # Farthest samples first, so the nearest ones end up right before the query
    farthest_first = sorted(
        range(len(distances)),
        key=lambda idx: distances[idx],
        reverse=True,
    )

    examples = "".join(f"input: {X[idx]}, output: {Y[idx]}\n" for idx in farthest_first)
    query = "input: " + str(x_test) + ", output: "

    return instructions + examples + query


# Baseline uses the highest MedianReward prompt combination (BNRCD) from GPT-3.5 button scenario easy task
# Baseline: Add features in the framing section
def make_prompt_framingfeature(X_color, Y, num_iter, xx, itr, N_init, K, colorsheet):
    """Build the baseline bandit prompt with button features in the framing.

    The prompt is the concatenation of four sections: framing (button labels,
    one feature line per button, task rules), history of the plays so far,
    the required output format, and the final question.

    Args:
        X_color: Colors of the buttons pressed so far, in play order.
        Y: Rewards received so far, indexed like ``X_color``.
        num_iter: Number of time steps announced to the model.
        xx: Feature of each button, indexed like ``colorsheet``.
        itr: Current iteration; ``itr + N_init`` plays are listed.
        N_init: Number of initial plays made before iteration 0.
        K: Number of buttons.
        colorsheet: Button color names; the first ``K`` are used.

    Returns:
        The prompt string.
    """
    arms = range(K)
    plays_so_far = itr + N_init

    # "c1:p1,...,cK:pK" and "p1,...,pK": K-1 comma-terminated entries plus the last one
    dist_template = "".join(f"{colorsheet[a]}:p{a+1}," for a in range(K - 1)) + f"{colorsheet[K-1]}:p{K}"
    prob_names = "".join(f"p{a+1}," for a in range(K - 1)) + f"p{K}"

    # The runs of spaces inside the literals are part of the prompt and kept verbatim
    framing = "".join([
        f"You are in a room with {K} buttons labeled ",
        "".join(f"{colorsheet[a]}, " for a in arms),
        ".\n",
        "".join(f"Feature of {colorsheet[a]} button: {xx[a]}\n" for a in arms),
        "Each button is associated with a Normal distribution with a fixed but unknown mean; "
        "    the means for the buttons could be different and are associated with features of buttons. "
        "    For each button, when you press it, you will get a reward that is sampled from the button's associated distribution.\n",
        f"You have {num_iter} time steps and, on each time step, you can choose any button and receive the reward. ",
        f"Your goal is to maximize the total reward over the {num_iter} time steps. ",
    ])

    history = (
        f"So far you have played {plays_so_far} times with the following choices and rewards:\n"
        + "".join(f"{X_color[step]} button, reward {Y[step]}\n" for step in range(plays_so_far))
    )

    action = (
        f"You MUST output a distribution over the {K} buttons as probabilities, formatted EXACTLY like this example: #"
        + dist_template
        + "#. Each probability value("
        + prob_names
        + ") MUST be a number between 0 and 1, and the total of all probabilities MUST equal 1.\n"
    )

    final_prompt = (
        "Let's think step by step to make sure we make a good choice. Which button will you choose next? "
        "    YOU MUST provide your final answer within the tags <Answer>DIST</Answer> where DIST is #"
        + dist_template
        + "#."
    )

    return framing + history + action + final_prompt
    

# Baseline: Add features in the history section
def make_prompt_historyfeature(X_color, Y, num_iter, xx, itr, N_init, K, colorsheet):
    """Build the baseline bandit prompt with button features in the history.

    The prompt is the concatenation of four sections: framing (button labels
    and task rules), history (one feature line per button followed by the
    plays so far), the required output format, and the final question.

    Args:
        X_color: Colors of the buttons pressed so far, in play order.
        Y: Rewards received so far, indexed like ``X_color``.
        num_iter: Number of time steps announced to the model.
        xx: Feature of each button, indexed like ``colorsheet``.
        itr: Current iteration; ``itr + N_init`` plays are listed.
        N_init: Number of initial plays made before iteration 0.
        K: Number of buttons.
        colorsheet: Button color names; the first ``K`` are used.

    Returns:
        The prompt string.
    """
    arms = range(K)
    plays_so_far = itr + N_init

    # "c1:p1,...,cK:pK" and "p1,...,pK": K-1 comma-terminated entries plus the last one
    dist_template = "".join(f"{colorsheet[a]}:p{a+1}," for a in range(K - 1)) + f"{colorsheet[K-1]}:p{K}"
    prob_names = "".join(f"p{a+1}," for a in range(K - 1)) + f"p{K}"

    # The runs of spaces inside the literals are part of the prompt and kept verbatim
    framing = "".join([
        f"You are in a room with {K} buttons labeled ",
        "".join(f"{colorsheet[a]}, " for a in arms),
        ".\n",
        "Each button is associated with a Normal distribution with a fixed but unknown mean; "
        "    the means for the buttons could be different and are associated with features of buttons. "
        "    For each button, when you press it, you will get a reward that is sampled from the button's associated distribution.\n",
        f"You have {num_iter} time steps and, on each time step, you can choose any button and receive the reward. ",
        f"Your goal is to maximize the total reward over the {num_iter} time steps. ",
    ])

    history = "".join([
        "".join(f"Feature of {colorsheet[a]} button: {xx[a]}\n" for a in arms),
        f"So far you have played {plays_so_far} times with the following choices and rewards:\n",
        "".join(f"{X_color[step]} button, reward {Y[step]}\n" for step in range(plays_so_far)),
    ])

    action = (
        f"You MUST output a distribution over the {K} buttons as probabilities, formatted EXACTLY like this example: #"
        + dist_template
        + "#. Each probability value("
        + prob_names
        + ") MUST be a number between 0 and 1, and the total of all probabilities MUST equal 1.\n"
    )

    final_prompt = (
        "Let's think step by step to make sure we make a good choice. Which button will you choose next? "
        "    YOU MUST provide your final answer within the tags <Answer>DIST</Answer> where DIST is #"
        + dist_template
        + "#."
    )

    return framing + history + action + final_prompt


# Baseline: No features added
def make_prompt_nofeature(X_color, Y, num_iter, itr, N_init, K, colorsheet):
    """Build the baseline bandit prompt without any button features.

    The prompt is the concatenation of four sections: framing (button labels
    and task rules), history of the plays so far, the required output format,
    and the final question.

    Args:
        X_color: Colors of the buttons pressed so far, in play order.
        Y: Rewards received so far, indexed like ``X_color``.
        num_iter: Number of time steps announced to the model.
        itr: Current iteration; ``itr + N_init`` plays are listed.
        N_init: Number of initial plays made before iteration 0.
        K: Number of buttons.
        colorsheet: Button color names; the first ``K`` are used.

    Returns:
        The prompt string.
    """
    plays_so_far = itr + N_init

    # "c1:p1,...,cK:pK" and "p1,...,pK": K-1 comma-terminated entries plus the last one
    dist_template = "".join(f"{colorsheet[a]}:p{a+1}," for a in range(K - 1)) + f"{colorsheet[K-1]}:p{K}"
    prob_names = "".join(f"p{a+1}," for a in range(K - 1)) + f"p{K}"

    # The runs of spaces inside the literals are part of the prompt and kept verbatim
    framing = "".join([
        f"You are in a room with {K} buttons labeled ",
        "".join(f"{colorsheet[a]}, " for a in range(K)),
        ".\n",
        "Each button is associated with a Normal distribution with a fixed but unknown mean; "
        "    the means for the buttons could be different and are associated with features of buttons. "
        "    For each button, when you press it, you will get a reward that is sampled from the button's associated distribution.\n",
        f"You have {num_iter} time steps and, on each time step, you can choose any button and receive the reward. ",
        f"Your goal is to maximize the total reward over the {num_iter} time steps. ",
    ])

    history = (
        f"So far you have played {plays_so_far} times with the following choices and rewards:\n"
        + "".join(f"{X_color[step]} button, reward {Y[step]}\n" for step in range(plays_so_far))
    )

    action = (
        f"You MUST output a distribution over the {K} buttons as probabilities, formatted EXACTLY like this example: #"
        + dist_template
        + "#. Each probability value("
        + prob_names
        + ") MUST be a number between 0 and 1, and the total of all probabilities MUST equal 1.\n"
    )

    final_prompt = (
        "Let's think step by step to make sure we make a good choice. Which button will you choose next? "
        "    YOU MUST provide your final answer within the tags <Answer>DIST</Answer> where DIST is #"
        + dist_template
        + "#."
    )

    return framing + history + action + final_prompt

