{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f333639",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "from openai import OpenAI\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import GPy\n",
    "import re\n",
    "import copy\n",
    "import math\n",
    "import concurrent.futures\n",
    "import random\n",
    "\n",
    "# Set your API key: read it from the environment rather than hardcoding it in the\n",
    "# notebook. The placeholder fallback keeps the previous behaviour when it is unset.\n",
    "openai.api_key = os.environ.get(\"OPENAI_API_KEY\", \"API KEY\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38f36c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a function to call the DeepSeek chat-completions API\n",
    "def get_deepseek_response(prompt, itr, max_tokens=2000, temperature=0.7):\n",
    "    \"\"\"Send `prompt` to DeepSeek as a single user message and return the raw response.\n",
    "\n",
    "    Args:\n",
    "        prompt: Text sent as the user message.\n",
    "        itr: Current bandit iteration. Unused here; accepted so callers can pass\n",
    "            it exactly as they do to `get_chatgpt_response_variable_t`.\n",
    "        max_tokens: Upper bound on the number of generated tokens.\n",
    "        temperature: Sampling temperature (default 0.7, the previously hardcoded value).\n",
    "\n",
    "    Returns:\n",
    "        The chat-completion response object; the text is at\n",
    "        `response.choices[0].message.content`.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    # Prefer an environment variable over a key hardcoded in the notebook; the\n",
    "    # placeholder fallback keeps the old behaviour when the variable is unset.\n",
    "    api_key = os.environ.get(\"DEEPSEEK_API_KEY\", \"API KEY\")\n",
    "\n",
    "    # DeepSeek is reached through the OpenAI client pointed at its base_url. The\n",
    "    # `with` block closes the client's HTTP resources right after the call instead\n",
    "    # of leaving one open client per request to the garbage collector.\n",
    "    with openai.OpenAI(api_key=api_key, base_url=\"https://api.deepseek.com\") as client:\n",
    "        response = client.chat.completions.create(\n",
    "            model=\"deepseek-chat\",\n",
    "            messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "            max_tokens=max_tokens,\n",
    "            temperature=temperature,\n",
    "            stream=False\n",
    "        )\n",
    "\n",
    "    return response\n",
    "\n",
    "# Define a function to call the ChatGPT API\n",
    "def get_chatgpt_response_variable_t(prompt, itr, model=\"gpt-3.5-turbo\", max_tokens=2000, temperature=0.7):\n",
    "    \"\"\"Send `prompt` to an OpenAI chat model and return the raw response.\n",
    "\n",
    "    Despite the name, the temperature is a fixed per-call argument (default 0.7).\n",
    "    `itr` is accepted so an iteration-dependent schedule can be computed by the\n",
    "    caller, e.g. `temperature=1 - min(0.12 * math.sqrt(itr), 1)`; it is not used here.\n",
    "\n",
    "    Returns:\n",
    "        The chat-completion response object; the text is at\n",
    "        `response.choices[0].message.content`.\n",
    "    \"\"\"\n",
    "    # Uses the module-level OpenAI client, configured through `openai.api_key`.\n",
    "    return openai.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "        max_tokens=max_tokens,\n",
    "        temperature=temperature\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9b97abd",
   "metadata": {},
   "outputs": [],
   "source": [
    "dt = pd.read_csv('OneShotWikiLinks.csv')\n",
    "all_arms = dt['text'].unique().tolist()  # distinct candidate words, in order of first appearance\n",
    "\n",
    "def generate_reward(textrow_idx, arm):\n",
    "    \"\"\"Return 1 if `arm` equals the ground-truth entry of row `textrow_idx`, else 0.\n",
    "\n",
    "    The ground truth is read positionally from column 1 of `dt`.\n",
    "    NOTE(review): presumably column 1 is the 'text' column used for `all_arms` -- confirm.\n",
    "    \"\"\"\n",
    "    correct_arm = dt.iloc[textrow_idx, 1]\n",
    "    return 1 if correct_arm == arm else 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7a74e8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the first \".\" of an LLM response string\n",
    "def remove_extra_dots(input_string):\n",
    "    \"\"\"Return `input_string` with every \".\" after the first one deleted.\n",
    "\n",
    "    A string that contains no \".\" is returned unchanged.\n",
    "    \"\"\"\n",
    "    head, dot, tail = input_string.partition('.')\n",
    "    if not dot:\n",
    "        return input_string\n",
    "    return head + dot + tail.replace('.', '')\n",
    "\n",
    "# Convert per-iteration rewards into running totals once the iterations end\n",
    "def cumulate_reward(reward):\n",
    "    \"\"\"Return a deep copy of `reward` in which every row holds its prefix sums.\n",
    "\n",
    "    Each row must support len() and item assignment (e.g. a list); the input\n",
    "    itself is not modified.\n",
    "    \"\"\"\n",
    "    totals = copy.deepcopy(reward)\n",
    "    for series in totals:\n",
    "        for pos in range(1, len(series)):\n",
    "            series[pos] = series[pos] + series[pos - 1]\n",
    "    return totals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f74ca60f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get a valid reward from the ChatGPT model, ensuring it is between 0 and 1 (inclusive). Try up to max_retries times.\n",
    "def get_valid_msg(prompt, itr, model=\"gpt-3.5-turbo\", max_retries=10):\n",
    "    \"\"\"Query the OpenAI model until its reply contains a reward in [0, 1].\n",
    "\n",
    "    The first number in the reply is used as the reward (the prompt asks for the\n",
    "    form `#0.8#`). Taking a single number fixes the previous parsing, which glued\n",
    "    every digit of the reply together, so e.g. \"0.9 (out of 1)\" became 0.91.\n",
    "\n",
    "    Args:\n",
    "        prompt: Prompt forwarded to `get_chatgpt_response_variable_t`.\n",
    "        itr: Iteration index forwarded to the API helper.\n",
    "        model: OpenAI model name.\n",
    "        max_retries: Maximum number of API calls before giving up.\n",
    "\n",
    "    Returns:\n",
    "        The reward as a float in [0, 1], or None if no attempt produced one.\n",
    "    \"\"\"\n",
    "    for attempt in range(1, max_retries + 1):\n",
    "        response = get_chatgpt_response_variable_t(prompt, itr, model=model)\n",
    "        # `content` can be None (empty completion); treat it as an invalid reply.\n",
    "        msg = response.choices[0].message.content or \"\"\n",
    "        match = re.search(r\"\\d+(?:\\.\\d*)?|\\.\\d+\", msg)\n",
    "        if match:\n",
    "            value = float(match.group())\n",
    "            if 0 <= value <= 1:\n",
    "                return value\n",
    "        print(f\"Attempt {attempt}, received invalid result: {msg[:80]!r}, retrying...\")\n",
    "\n",
    "    # Every attempt failed; callers must handle None.\n",
    "    print(\"Exceeded maximum retry attempts, returning invalid result\")\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71621c07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get a valid reward from the DeepSeek model, ensuring it is between 0 and 1 (inclusive). Try up to max_retries times.\n",
    "def get_valid_msg_deepseek(prompt, itr, model=\"gpt-3.5-turbo\", max_retries=10):\n",
    "    \"\"\"Query DeepSeek until its reply contains a reward in [0, 1].\n",
    "\n",
    "    The first number in the reply is used as the reward (the prompt asks for the\n",
    "    form `#0.8#`). Taking a single number fixes the previous parsing, which glued\n",
    "    every digit of the reply together, so e.g. \"0.9 (out of 1)\" became 0.91.\n",
    "\n",
    "    Args:\n",
    "        prompt: Prompt forwarded to `get_deepseek_response`.\n",
    "        itr: Iteration index forwarded to the API helper.\n",
    "        model: Ignored (the DeepSeek model is fixed inside `get_deepseek_response`);\n",
    "            kept so the signature matches `get_valid_msg`.\n",
    "        max_retries: Maximum number of API calls before giving up.\n",
    "\n",
    "    Returns:\n",
    "        The reward as a float in [0, 1], or None if no attempt produced one.\n",
    "    \"\"\"\n",
    "    for attempt in range(1, max_retries + 1):\n",
    "        response = get_deepseek_response(prompt, itr)\n",
    "        # `content` can be None (empty completion); treat it as an invalid reply.\n",
    "        msg = response.choices[0].message.content or \"\"\n",
    "        match = re.search(r\"\\d+(?:\\.\\d*)?|\\.\\d+\", msg)\n",
    "        if match:\n",
    "            value = float(match.group())\n",
    "            if 0 <= value <= 1:\n",
    "                return value\n",
    "        print(f\"Attempt {attempt}, received invalid result: {msg[:80]!r}, retrying...\")\n",
    "\n",
    "    # Every attempt failed; callers must handle None.\n",
    "    print(\"Exceeded maximum retry attempts, returning invalid result\")\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c252f1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_prompt_experiment(X, Y, x_test, context, arms=None):\n",
    "    \"\"\"Build the prompt asking the LLM to score one candidate word for a test context.\n",
    "\n",
    "    Args:\n",
    "        X: History rows, each indexable as [previous_text, incomplete_text, next_text].\n",
    "        Y: Rewards for the rows of X (at least len(X) entries).\n",
    "        x_test: Candidate word to be scored.\n",
    "        context: (previous_text, next_text) of the test sample.\n",
    "        arms: Candidate words listed in the prompt. Defaults to the ten words of\n",
    "            the original experiment, which reproduces the previous prompt exactly.\n",
    "\n",
    "    Returns:\n",
    "        The prompt text, ending with \"**Reward**: \" for the model to complete.\n",
    "    \"\"\"\n",
    "    if arms is None:\n",
    "        arms = ['Microsoft Windows', 'Telugu', 'XML', 'Moscow', 'help', 'MTV', 'Halloween', 'Ottoman Empire', 'Soviet', 'Bangladesh']\n",
    "    # str() of a list of plain str gives the \"['a', 'b']\" rendering used in the prompt.\n",
    "    arm_list = [str(arm) for arm in arms]\n",
    "\n",
    "    header = f\"\"\"\n",
    "    **Task Description**\n",
    "    At the TEST DATA, Please assign a reward indicating how well the Incomplete Text aligns with the Previous Text and Next Text.\n",
    "\n",
    "    **reward**:\n",
    "    - 0 indicates poor alignment.\n",
    "    - 1 indicates perfect alignment.\n",
    "    - A reward closer to 1 should only be assigned when the Incomplete Text is perfectly aligned with the surrounding texts.\n",
    "\n",
    "    **The Incomplete Text can be one of the following words**:\n",
    "    {arm_list}.\n",
    "\n",
    "    The reward value MUST be a number between 0 and 1. Your response MUST be the reward value only, formatted as #reward value#.\n",
    "    \n",
    "    Below are previous examples:\n",
    "    \"\"\"\n",
    "\n",
    "    # One entry per past observation, in collection order.\n",
    "    history = \"\".join(\n",
    "        f\"**Previous Text**: {X[i][0]}\\n**Next Text**: {X[i][2]}\\n\"\n",
    "        f\"**Incomplete Text**: {X[i][1]}\\n**Reward**: {Y[i]}\\n\\n\"\n",
    "        for i in range(len(X))\n",
    "    )\n",
    "\n",
    "    # The test sample, left open at \"**Reward**: \" for the model to fill in.\n",
    "    prev_text_test, next_text_test = context[0], context[1]\n",
    "    test_section = (\n",
    "        \"###TEST DATA:\\n\"\n",
    "        \"This is the TEST DATA for which the reward needs to be assigned:\\n\"\n",
    "        f\"**Previous Text**: {prev_text_test}\\n**Next Text**: {next_text_test}\\n\"\n",
    "        f\"**Incomplete Text**: {x_test}\\n**Reward**: \"\n",
    "    )\n",
    "\n",
    "    return header + history + test_section"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e410bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_prompt_baseline(X, Y, context, arms=None):\n",
    "    \"\"\"Build the prompt asking the LLM to name the best word for a test context.\n",
    "\n",
    "    Args:\n",
    "        X: History rows, each indexable as [previous_text, incomplete_text, next_text].\n",
    "        Y: Rewards for the rows of X (at least len(X) entries).\n",
    "        context: (previous_text, next_text) of the test sample.\n",
    "        arms: Candidate words listed in the prompt. Defaults to the ten words of\n",
    "            the original experiment, which reproduces the previous prompt exactly.\n",
    "\n",
    "    Returns:\n",
    "        The prompt text, ending with \"**Incomplete Text**:  \" for the model to complete.\n",
    "    \"\"\"\n",
    "    if arms is None:\n",
    "        arms = ['Microsoft Windows', 'Telugu', 'XML', 'Moscow', 'help', 'MTV', 'Halloween', 'Ottoman Empire', 'Soviet', 'Bangladesh']\n",
    "    # str() of a list of plain str gives the \"['a', 'b']\" rendering used in the prompt.\n",
    "    arm_list = [str(arm) for arm in arms]\n",
    "\n",
    "    header = f\"\"\"\n",
    "    The task is to choose the most suitable word to complete the Incomplete Text from the following list of options in order to earn the most reward:\n",
    "    {arm_list}.\n",
    "    Your response MUST only contain one word from the list.\n",
    "\n",
    "    Reward indicates how well the Incomplete Text aligns with the Previous Text and Next Text.\n",
    "    - 0 indicates poor alignment.\n",
    "    - 1 indicates perfect alignment.\n",
    "\n",
    "    Below is the historical data:\n",
    "    \"\"\"\n",
    "\n",
    "    # One entry per past observation, in collection order.\n",
    "    history = \"\".join(\n",
    "        f\"**Previous Text**: {X[i][0]}\\n**Next Text**: {X[i][2]}\\n\"\n",
    "        f\"**Incomplete Text**: {X[i][1]}\\n**Reward**: {Y[i]}\\n\\n\"\n",
    "        for i in range(len(X))\n",
    "    )\n",
    "\n",
    "    # The test sample, left open after \"**Incomplete Text**:\" for the model's answer.\n",
    "    prev_text_test, next_text_test = context[0], context[1]\n",
    "    test_section = (\n",
    "        \"Below is the incomplete text for which you need to complete:\\n\"\n",
    "        f\"**Previous Text**: {prev_text_test}\\n**Next Text**: {next_text_test}\\n**Incomplete Text**:  \"\n",
    "    )\n",
    "\n",
    "    return header + history + test_section"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45d2dda4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimize_experiment(X, Y, K, num_iter, random_seed):\n",
    "    \"\"\"Run the LLM-scored bandit loop and return the true reward of each chosen arm.\n",
    "\n",
    "    Each iteration samples a text row, asks DeepSeek (via `get_valid_msg_deepseek`)\n",
    "    to score each of the first K arms of `all_arms` for that row's context, plays the\n",
    "    highest-scoring arm, and appends the observation to the history used in later prompts.\n",
    "\n",
    "    Args:\n",
    "        X: Initial history rows [previous_text, arm, next_text]; not modified.\n",
    "        Y: Rewards matching X; not modified.\n",
    "        K: Number of arms scored per iteration (the first K entries of `all_arms`).\n",
    "        num_iter: Number of bandit iterations.\n",
    "        random_seed: Base seed; iteration `itr` samples its row with seed `random_seed * (itr + 1)`.\n",
    "\n",
    "    Returns:\n",
    "        List of the observed 0/1 rewards, one per iteration.\n",
    "    \"\"\"\n",
    "    # Private copies: the history grows below without touching the caller's lists.\n",
    "    X = [list(row) for row in X]\n",
    "    Y = list(Y)\n",
    "    # Sample row indices in [0, 19999] as before, but never past the last row of `dt`.\n",
    "    max_row = min(19999, len(dt) - 1)\n",
    "    rewards_all = []\n",
    "\n",
    "    for itr in range(num_iter):\n",
    "        # A private RNG draws the same row as seeding the global `random` module did,\n",
    "        # but cannot be disturbed by experiments running in other threads.\n",
    "        rng = random.Random(random_seed * float(itr + 1))\n",
    "        textrow_idx = rng.randint(0, max_row)\n",
    "        context = [dt.iloc[textrow_idx, 0], dt.iloc[textrow_idx, 2]]  # same for every arm below\n",
    "\n",
    "        reward_predicted = []\n",
    "        for j in range(K):  # Score every arm for this context\n",
    "            prompt = make_prompt_experiment(X, Y, all_arms[j], context)\n",
    "            msg = get_valid_msg_deepseek(prompt, itr)\n",
    "            # None means the LLM never produced a valid reward. Score it -inf so it is\n",
    "            # never preferred (np.argmax would otherwise select a NaN entry).\n",
    "            reward_predicted.append(-np.inf if msg is None else float(msg))\n",
    "\n",
    "        # Play the best-scoring arm (the first one on ties or if every attempt failed).\n",
    "        x_t_idx = int(np.argmax(reward_predicted))\n",
    "        chosen_arm = all_arms[x_t_idx]\n",
    "        x_t = [dt.iloc[textrow_idx, 0], chosen_arm, dt.iloc[textrow_idx, 2]]\n",
    "        y_t = generate_reward(textrow_idx, chosen_arm)\n",
    "\n",
    "        rewards_all.append(y_t)  # True reward of the arm that was played\n",
    "        print(\"iter: \", itr)\n",
    "        X.append(x_t)\n",
    "        Y.append(y_t)\n",
    "\n",
    "    return rewards_all\n",
    "\n",
    "def optimize_baseline(X, Y, K, num_iter, random_seed):\n",
    "    \"\"\"Run the baseline where the LLM names the arm directly; return the true rewards.\n",
    "\n",
    "    Args:\n",
    "        X: Initial history rows [previous_text, arm, next_text]; not modified.\n",
    "        Y: Rewards matching X; not modified.\n",
    "        K: Unused; kept so the signature matches `optimize_experiment`.\n",
    "        num_iter: Number of iterations.\n",
    "        random_seed: Base seed; iteration `itr` samples its row with seed `random_seed * (itr + 1)`.\n",
    "\n",
    "    Returns:\n",
    "        List of the observed 0/1 rewards, one per iteration.\n",
    "    \"\"\"\n",
    "    # Private copies: the history grows below without touching the caller's lists.\n",
    "    X = [list(row) for row in X]\n",
    "    Y = list(Y)\n",
    "    # Sample row indices in [0, 19999] as before, but never past the last row of `dt`.\n",
    "    max_row = min(19999, len(dt) - 1)\n",
    "    # Case-insensitive lookup of the canonical arm names, so a reply such as\n",
    "    # \"Moscow.\" or \"**moscow**\" counts as choosing that arm instead of scoring 0\n",
    "    # for its formatting. NOTE(review): assumes no two arms differ only by case.\n",
    "    canonical_arms = {str(arm).casefold(): arm for arm in all_arms}\n",
    "    rewards_all = []\n",
    "\n",
    "    for itr in range(num_iter):\n",
    "        # A private RNG draws the same row as seeding the global `random` module did,\n",
    "        # but is not shared with experiments running in other threads.\n",
    "        rng = random.Random(random_seed * float(itr + 1))\n",
    "        textrow_idx = rng.randint(0, max_row)\n",
    "        context = [dt.iloc[textrow_idx, 0], dt.iloc[textrow_idx, 2]]\n",
    "\n",
    "        prompt = make_prompt_baseline(X, Y, context)\n",
    "        response = get_deepseek_response(prompt, itr)\n",
    "\n",
    "        # `content` may be None; otherwise drop whitespace plus the quotes, markdown\n",
    "        # and punctuation a one-word answer is commonly wrapped in.\n",
    "        raw = (response.choices[0].message.content or \"\").strip()\n",
    "        cleaned = raw.strip(\"\\\"'`*.,;:!\").strip().casefold()\n",
    "        # Replies that match no arm fall back to the previous cleaning.\n",
    "        msg = canonical_arms.get(cleaned, remove_extra_dots(raw))\n",
    "\n",
    "        x_t = [dt.iloc[textrow_idx, 0], msg, dt.iloc[textrow_idx, 2]]\n",
    "        y_t = generate_reward(textrow_idx, msg)\n",
    "\n",
    "        rewards_all.append(y_t)  # True reward of the word the model chose\n",
    "        print(\"iter: \", itr)\n",
    "        X.append(x_t)\n",
    "        Y.append(y_t)\n",
    "\n",
    "    return rewards_all\n",
    "\n",
    "def optimize_randomsearch(num_iter, K, random_seed=None):\n",
    "    \"\"\"Simulate picking one of K arms uniformly at random, exactly one of which pays 1.\n",
    "\n",
    "    Args:\n",
    "        num_iter: Number of iterations.\n",
    "        K: Number of arms; must be at least 1.\n",
    "        random_seed: Optional seed for a reproducible run. The default None seeds from\n",
    "            OS entropy, like the previous `random.seed(None)` call.\n",
    "\n",
    "    Returns:\n",
    "        List of 0/1 rewards, one per iteration; each is 1 with probability 1/K.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if K < 1.\n",
    "    \"\"\"\n",
    "    if K < 1:\n",
    "        raise ValueError(f\"K must be at least 1, got {K}\")\n",
    "    # A private RNG instead of reseeding the global `random` module shared by all threads.\n",
    "    rng = random.Random(random_seed)\n",
    "    rewards_all = []\n",
    "\n",
    "    for itr in range(num_iter):\n",
    "        # The single rewarding arm is hit when the uniformly drawn arm index is 0.\n",
    "        rewards_all.append(1 if rng.randrange(K) == 0 else 0)\n",
    "        print(\"iter: \", itr)\n",
    "\n",
    "    return rewards_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca10df8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def LLM_for_bandits_text(random_seed, K, N_init, num_iter, method=\"randomsearch\"):\n",
    "    \"\"\"Run one bandit experiment on the text-completion data.\n",
    "\n",
    "    Args:\n",
    "        random_seed: Seed for the initial random draws and for the optimizer.\n",
    "        K: Number of arms.\n",
    "        N_init: Number of random (row, arm) observations made before optimizing.\n",
    "        num_iter: Number of optimization iterations.\n",
    "        method: Which optimizer to run: \"randomsearch\" (the default, as before),\n",
    "            \"experiment\" (`optimize_experiment`) or \"baseline\" (`optimize_baseline`).\n",
    "\n",
    "    Returns:\n",
    "        (rewards, all_x_init, all_y_init): `rewards` is a numpy array with the N_init\n",
    "        initial rewards followed by the optimizer's rewards; `all_x_init` and\n",
    "        `all_y_init` are the initial observations and their rewards.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `method` is not one of the names above.\n",
    "    \"\"\"\n",
    "    # X rows hold [preceding text, fill-in-the-blank text, following text];\n",
    "    # Y holds rewards (1 for the correct fill-in, 0 for an incorrect one).\n",
    "    X, Y = [], []\n",
    "    all_x_init, all_y_init = [], []\n",
    "    init_rewards = []\n",
    "\n",
    "    # Kept from the original setup; nothing in this notebook draws from np.random directly.\n",
    "    np.random.seed(random_seed)\n",
    "\n",
    "    # Observe some randomly chosen (row, arm) pairs first.\n",
    "    for i in range(N_init):\n",
    "        # Private RNG: same draws as seeding the global `random` module, but not\n",
    "        # shared with experiments running concurrently in other threads.\n",
    "        rng = random.Random(random_seed * (i + 1))\n",
    "        textrow_idx = rng.randint(0, min(19999, len(dt) - 1))\n",
    "        arm = rng.choice(all_arms)\n",
    "        init = [dt.iloc[textrow_idx, 0], arm, dt.iloc[textrow_idx, 2]]\n",
    "        y_init = generate_reward(textrow_idx, arm)\n",
    "        X.append(init)\n",
    "        Y.append(y_init)\n",
    "        all_x_init.append(init)\n",
    "        all_y_init.append(y_init)\n",
    "        init_rewards.append(y_init)\n",
    "\n",
    "    # Iteration: dispatch to the requested optimizer.\n",
    "    optimizers = {\n",
    "        \"randomsearch\": lambda: optimize_randomsearch(num_iter, K),\n",
    "        \"experiment\": lambda: optimize_experiment(X, Y, K, num_iter, random_seed),\n",
    "        \"baseline\": lambda: optimize_baseline(X, Y, K, num_iter, random_seed),\n",
    "    }\n",
    "    if method not in optimizers:\n",
    "        raise ValueError(f\"Unknown method {method!r}; expected one of {sorted(optimizers)}\")\n",
    "    iter_rewards = optimizers[method]()\n",
    "\n",
    "    rewards = np.concatenate([np.array(init_rewards), np.array(iter_rewards)])\n",
    "    return rewards, all_x_init, all_y_init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d5bcf64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run one experiment per random seed in parallel threads.\n",
    "# LLM_for_bandits_text parameters: random seed, number of arms (K),\n",
    "# initial sample size (N_init), number of iterations (num_iter).\n",
    "gpt_model = \"gpt-3.5-turbo\"  # only referenced by the commented-out ChatGPT calls\n",
    "SEEDS = range(41, 51)\n",
    "K_ARMS, N_INIT, NUM_ITER = 10, 2, 28\n",
    "\n",
    "with concurrent.futures.ThreadPoolExecutor() as executor:\n",
    "    futures = [executor.submit(LLM_for_bandits_text, seed, K_ARMS, N_INIT, NUM_ITER) for seed in SEEDS]\n",
    "    # Collect in seed order; .result() re-raises any exception from a worker thread.\n",
    "    results = [future.result() for future in futures]\n",
    "\n",
    "# Per-seed names kept from the original cell: seeds 41..49 -> result1..result9, seed 50 -> result0.\n",
    "result1, result2, result3, result4, result5, result6, result7, result8, result9, result0 = results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
