{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0769a7dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import concurrent.futures\n",
    "import copy\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import re\n",
    "\n",
    "import GPy\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import openai\n",
    "from openai import OpenAI\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Set your API key\n",
    "openai.api_key = 'API KEY'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d92820ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_deepseek_response(prompt, itr, max_tokens=2000):\n",
    "    \"\"\"Send `prompt` to the DeepSeek chat endpoint and return the raw completion.\n",
    "\n",
    "    Args:\n",
    "        prompt: User message text.\n",
    "        itr: Iteration index. Unused here; kept so the signature matches\n",
    "            get_chatgpt_response_variable_t and the two can be swapped.\n",
    "        max_tokens: Upper bound on completion tokens.\n",
    "\n",
    "    Returns:\n",
    "        The completion response object (callers read `.choices[0].message.content`),\n",
    "        or the sentinel -1 if the request raised for any reason.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # DeepSeek exposes an OpenAI-compatible API, so the OpenAI client is reused\n",
    "        # with a different base_url. The key is read from DEEPSEEK_API_KEY when set;\n",
    "        # the literal is only a placeholder.\n",
    "        client = openai.OpenAI(\n",
    "            api_key=os.environ.get(\"DEEPSEEK_API_KEY\", \"API KEY\"),\n",
    "            base_url=\"https://api.deepseek.com\",\n",
    "        )\n",
    "        return client.chat.completions.create(\n",
    "            model=\"deepseek-chat\",\n",
    "            messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "            max_tokens=max_tokens,\n",
    "            temperature=0.7,\n",
    "            stream=False,\n",
    "        )\n",
    "    except Exception as e:  # best-effort: any failure becomes the -1 sentinel\n",
    "        # Include the exception; the old message gave no hint what went wrong.\n",
    "        print(f\"An error occurred: {e}\")\n",
    "        return -1\n",
    "    \n",
    "# Call the OpenAI chat-completions API (alternative backend to get_deepseek_response).\n",
    "def get_chatgpt_response_variable_t(prompt, itr, model=\"gpt-3.5-turbo\", max_tokens=2000, temperature=0.7):\n",
    "    \"\"\"Send `prompt` to the OpenAI chat-completions endpoint and return the raw completion.\n",
    "\n",
    "    Args:\n",
    "        prompt: User message text.\n",
    "        itr: Iteration index; only used when `temperature` is callable.\n",
    "        model: OpenAI model name.\n",
    "        max_tokens: Upper bound on completion tokens.\n",
    "        temperature: A fixed sampling temperature (default 0.7), or a callable\n",
    "            mapping `itr` to a temperature for an annealing schedule, e.g.\n",
    "            `lambda itr: 1 - min(0.12 * math.sqrt(itr), 1)`.\n",
    "\n",
    "    Returns:\n",
    "        The completion response object (callers read `.choices[0].message.content`).\n",
    "        Unlike get_deepseek_response, API errors are not caught here.\n",
    "    \"\"\"\n",
    "    temp = temperature(itr) if callable(temperature) else temperature\n",
    "    return openai.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "        max_tokens=max_tokens,\n",
    "        temperature=temp,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad09c585",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the AmazonCat 10-arm dataset; `target_ind` holds the label (arm) of each row.\n",
    "DATA_PATH = 'AmazonCat_10arm.csv'\n",
    "dt = pd.read_csv(DATA_PATH)\n",
    "all_arms = dt['target_ind'].unique().tolist()\n",
    "\n",
    "# Swap the `title` and `content` columns: `content` moves into the slot `title` had\n",
    "# and `title` moves to the end (same result as the former copy/drop/rename steps).\n",
    "# NOTE(review): later cells read column 0 as the Title and column 2 as the Content,\n",
    "# i.e. the opposite of these names -- presumably the CSV headers are swapped; confirm.\n",
    "dt = dt[['content' if col == 'title' else col for col in dt.columns if col != 'content'] + ['title']]\n",
    "\n",
    "# generate_reward and the optimize_* helpers index columns by position (0, 1, 2),\n",
    "# so fail loudly if the layout differs from what they expect.\n",
    "assert list(dt.columns[:3]) == ['content', 'target_ind', 'title'], f\"unexpected column layout: {list(dt.columns)}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75772c9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_reward(textrow_idx, arm, data=None, verbose=False):\n",
    "    \"\"\"Return the binary reward for pulling `arm` on row `textrow_idx`.\n",
    "\n",
    "    Args:\n",
    "        textrow_idx: Positional row index into the dataset.\n",
    "        arm: Chosen label. May be None when the LLM produced no valid label;\n",
    "            None never matches, so it scores 0.\n",
    "        data: DataFrame to read from; defaults to the global `dt`. Column 1\n",
    "            (by position) must hold the true label.\n",
    "        verbose: If True, print the chosen and the true arm. This used to be\n",
    "            unconditional debug output; it is off by default because ten trials\n",
    "            run in parallel threads and their prints interleave.\n",
    "\n",
    "    Returns:\n",
    "        1 if `arm` equals the true label, otherwise 0.\n",
    "    \"\"\"\n",
    "    frame = dt if data is None else data\n",
    "    true_arm = frame.iloc[textrow_idx, 1]\n",
    "    if verbose:\n",
    "        print('arm:', arm)\n",
    "        print('realarm', true_arm)\n",
    "    return 1 if true_arm == arm else 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5f78a6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse an LLM reply to at most one decimal point so it can be parsed as a number\n",
    "def remove_extra_dots(input_string):\n",
    "    \"\"\"Keep the first '.' of `input_string` and delete every later one.\n",
    "\n",
    "    Strings without a '.' are returned unchanged. Examples:\n",
    "    \"1.2.3\" -> \"1.23\", \"..5\" -> \".5\", \"42\" -> \"42\".\n",
    "    \"\"\"\n",
    "    head, dot, tail = input_string.partition('.')\n",
    "    if not dot:\n",
    "        return input_string\n",
    "    return head + dot + tail.replace('.', '')\n",
    "\n",
    "# Turn per-step rewards into running totals, row by row\n",
    "def cumulate_reward(reward):\n",
    "    \"\"\"Return a deep copy of `reward` in which every row holds its prefix sums.\n",
    "\n",
    "    `reward` is a sequence of mutable rows (e.g. lists or 1-D arrays); the\n",
    "    input itself is not modified. Example: [[1, 0, 1]] -> [[1, 1, 2]].\n",
    "    \"\"\"\n",
    "    running_totals = copy.deepcopy(reward)\n",
    "    for series in running_totals:\n",
    "        k = 1\n",
    "        while k < len(series):\n",
    "            series[k] += series[k - 1]\n",
    "            k += 1\n",
    "    return running_totals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd5c4b39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validators that query the LLM until its reply parses as the expected kind of value.\n",
    "# Both share one retry loop and differ only in how a cleaned reply is parsed.\n",
    "\n",
    "# The ten AmazonCat arms listed in the prompts; accepted labels unless overridden.\n",
    "_DEFAULT_VALID_LABELS = (2571, 1471, 7961, 12246, 5754, 342, 5456, 5960, 11235, 10688)\n",
    "\n",
    "\n",
    "def _clean_numeric_reply(text):\n",
    "    \"\"\"Strip everything but digits and '.' from `text`, keeping at most one '.'.\"\"\"\n",
    "    # message.content may be None, so coerce to '' before filtering.\n",
    "    digits = ''.join(ch for ch in (text or '') if ch.isdigit() or ch == '.')\n",
    "    return remove_extra_dots(digits)\n",
    "\n",
    "\n",
    "def _parse_probability(msg):\n",
    "    \"\"\"Return `msg` as a float if it lies in [0, 1], else None.\"\"\"\n",
    "    try:\n",
    "        value = float(msg)\n",
    "    except ValueError:\n",
    "        return None\n",
    "    return value if 0 <= value <= 1 else None\n",
    "\n",
    "\n",
    "def _parse_label(msg, valid_numbers):\n",
    "    \"\"\"Return `msg` as an int if it is one of `valid_numbers`, else None.\n",
    "\n",
    "    Integral spellings such as \"2571.\" or \"2571.0\" are accepted too; int()\n",
    "    alone rejects them and would waste a retry.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        value = float(msg)\n",
    "    except ValueError:\n",
    "        return None\n",
    "    if not value.is_integer():\n",
    "        return None\n",
    "    label = int(value)\n",
    "    return label if label in valid_numbers else None\n",
    "\n",
    "\n",
    "def _retry_until_valid(prompt, itr, parse, request_fn, max_retries):\n",
    "    \"\"\"Query the LLM until `parse(cleaned_reply)` is not None.\n",
    "\n",
    "    Returns the parsed value, or None once `max_retries` unparseable replies\n",
    "    have been seen. An API failure (the -1 sentinel) stops immediately instead\n",
    "    of consuming the remaining retries.\n",
    "    \"\"\"\n",
    "    if request_fn is None:\n",
    "        request_fn = get_deepseek_response\n",
    "    retries = 0\n",
    "    while retries < max_retries:\n",
    "        response = request_fn(prompt, itr)\n",
    "        # get_deepseek_response signals failure with the int sentinel -1; check the type\n",
    "        # first so a real response object is never compared against an int.\n",
    "        if isinstance(response, int) and response == -1:\n",
    "            print(\"API call failed\")\n",
    "            break\n",
    "        msg = _clean_numeric_reply(response.choices[0].message.content)\n",
    "        value = parse(msg)\n",
    "        if value is not None:\n",
    "            return value\n",
    "        retries += 1\n",
    "        print(f\"Attempt {retries}, received invalid result: {msg}, retrying...\")\n",
    "\n",
    "    print(\"Returning invalid result\")\n",
    "    return None\n",
    "\n",
    "\n",
    "# Get a valid message, ensuring the return value is between 0 and 1 (inclusive). Try up to max_retries times.\n",
    "def get_valid_msg(prompt, itr, model=\"gpt-3.5-turbo\", max_retries=10, request_fn=None):\n",
    "    \"\"\"Return the LLM's predicted reward for `prompt` as a float in [0, 1], or None.\n",
    "\n",
    "    Args:\n",
    "        prompt: Prompt text, e.g. from make_prompt_experiment.\n",
    "        itr: Iteration index, forwarded to the request function.\n",
    "        model: Unused -- requests go to `request_fn`; kept for call compatibility.\n",
    "        max_retries: Number of unparseable replies tolerated before giving up.\n",
    "        request_fn: Callable `(prompt, itr) -> response or -1`; defaults to\n",
    "            get_deepseek_response.\n",
    "    \"\"\"\n",
    "    return _retry_until_valid(prompt, itr, _parse_probability, request_fn, max_retries)\n",
    "\n",
    "\n",
    "# Get a valid message, ensuring the return value is one of the allowed labels. Try up to max_retries times.\n",
    "def get_valid_msg_baseline(prompt, itr, model=\"gpt-3.5-turbo\", max_retries=10, valid_numbers=None, request_fn=None):\n",
    "    \"\"\"Return the LLM's chosen label for `prompt` as an int from `valid_numbers`, or None.\n",
    "\n",
    "    Args:\n",
    "        prompt: Prompt text, e.g. from make_prompt_baseline.\n",
    "        itr: Iteration index, forwarded to the request function.\n",
    "        model: Unused -- requests go to `request_fn`; kept for call compatibility.\n",
    "        max_retries: Number of unparseable replies tolerated before giving up.\n",
    "        valid_numbers: Accepted labels; defaults to the ten arms in the prompts.\n",
    "        request_fn: Callable `(prompt, itr) -> response or -1`; defaults to\n",
    "            get_deepseek_response.\n",
    "    \"\"\"\n",
    "    allowed = _DEFAULT_VALID_LABELS if valid_numbers is None else valid_numbers\n",
    "    return _retry_until_valid(\n",
    "        prompt, itr, lambda msg: _parse_label(msg, allowed), request_fn, max_retries\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee9cc9b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_prompt_experiment(X, Y, x_test, context):\n",
    "    \"\"\"Build the reward-prediction prompt for one candidate arm.\n",
    "\n",
    "    Args:\n",
    "        X: History rows, each indexable as (title, label, content).\n",
    "        Y: Observed rewards, aligned with X by position.\n",
    "        x_test: Candidate label whose reward the LLM should predict.\n",
    "        context: (title, content) of the current item.\n",
    "\n",
    "    Returns:\n",
    "        The fixed instructions, one Title/Content/Label/Reward record per\n",
    "        history row, and finally the current item with its Reward left blank.\n",
    "    \"\"\"\n",
    "    instructions = \"\"\"\n",
    "    There are Titles and Contents of some items. \n",
    "    \n",
    "    Labels and items correspond one-to-one.\n",
    "    There are a total of 10 items.The Labels MUST be ONE of the following numbers: [2571, 1471, 7961, 12246, 5754, 342, 5456, 5960, 11235, 10688]\n",
    "    \n",
    "    The Reward is a number between 0 and 1 determined by whether the Label is correct or not.\n",
    "    \n",
    "    Help me predict the Reward at the last Title, Content and Label.\n",
    "    \n",
    "    Your response MUST be the predicted Reward only, formatted as #predicted Reward#.\n",
    "    \n",
    "    \"\"\"\n",
    "\n",
    "    history = \"\".join(\n",
    "        f\"**Title**: {row[0]}\\n**Content**: {row[2]}\\n**Label**: {row[1]}\\n**Reward**: {Y[k]}\\n\\n\"\n",
    "        for k, row in enumerate(X)\n",
    "    )\n",
    "    title_now, content_now = context[0], context[1]\n",
    "    query = f\"**Title**: {title_now}\\n**Content**: {content_now}\\n**Label**: {x_test} \\n**Reward**:\"\n",
    "    return instructions + history + query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb6ff76a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_prompt_baseline(X, Y, context):\n",
    "    \"\"\"Build the label-selection prompt for the current item.\n",
    "\n",
    "    Args:\n",
    "        X: History rows, each indexable as (title, label, content).\n",
    "        Y: Observed rewards, aligned with X by position.\n",
    "        context: (title, content) of the current item.\n",
    "\n",
    "    Returns:\n",
    "        The fixed instructions, one Title/Content/Label/Reward record per\n",
    "        history row, and finally the current item with its Label left blank.\n",
    "    \"\"\"\n",
    "    instructions = \"\"\"\n",
    "    There are Titles and Contents of some items. \n",
    "    \n",
    "    Labels and items correspond one-to-one.\n",
    "    There are a total of 10 items.The Labels MUST be ONE of the following numbers: [2571, 1471, 7961, 12246, 5754, 342, 5456, 5960, 11235, 10688]\n",
    "    \n",
    "    The Reward is a number between 0 and 1 determined by whether the Label is correct or not.\n",
    "    \n",
    "    Help me choose the correct Label at the last Title and Content. Your response MUST be the chosen Label only, formatted as #chosen Label#.\n",
    "    \n",
    "    \"\"\"\n",
    "\n",
    "    pieces = [instructions]\n",
    "    for k, row in enumerate(X):\n",
    "        title, label, content = row[0], row[1], row[2]\n",
    "        pieces.append(f\"**Title**: {title}\\n**Content**: {content}\\n**Label**: {label}\\n**Reward**: {Y[k]}\\n\\n\")\n",
    "    pieces.append(f\"**Title**: {context[0]}\\n**Content**: {context[1]}\\n**Label**:  \")\n",
    "    return \"\".join(pieces)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bc128c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimize_experiment(X, Y, K, num_iter, random_seed):\n",
    "    \"\"\"Run `num_iter` LLM-guided rounds, pulling the arm with the highest predicted reward.\n",
    "\n",
    "    Each round draws a dataset row, asks the LLM (via get_valid_msg) to predict the\n",
    "    reward of each of the first K arms given the history (X, Y), pulls the best one,\n",
    "    and appends the observed (row, arm, reward) to the history.\n",
    "\n",
    "    Args:\n",
    "        X: History rows [title, label, content] (list of lists or 2-D array).\n",
    "        Y: Rewards aligned with X.\n",
    "        K: Number of arms scored per round (the first K entries of `all_arms`).\n",
    "        num_iter: Number of rounds.\n",
    "        random_seed: Base seed; round `itr` samples its row with seed random_seed * (itr + 1).\n",
    "\n",
    "    Returns:\n",
    "        List of observed rewards, one per round.\n",
    "    \"\"\"\n",
    "    rewards_all = []\n",
    "\n",
    "    for itr in np.arange(num_iter):\n",
    "        # Private RNG: the module-level `random` state is shared by every thread, so\n",
    "        # seeding it here races with the other trials running in parallel.\n",
    "        rng = random.Random(random_seed * float(itr + 1))\n",
    "        textrow_idx = rng.randint(0, 6783)  # NOTE(review): presumably len(dt) - 1; confirm\n",
    "        context = [dt.iloc[textrow_idx, 0], dt.iloc[textrow_idx, 2]]  # same for every arm\n",
    "\n",
    "        reward_predicted = []\n",
    "        for j in range(K):\n",
    "            prompt = make_prompt_experiment(X, Y, all_arms[j], context)\n",
    "            msg = get_valid_msg(prompt, itr, model=gpt_model)\n",
    "            # get_valid_msg returns None on failure; np.argmax treats NaN as the\n",
    "            # maximum, so a failed prediction would win. Score it -inf instead.\n",
    "            reward_predicted.append(-np.inf if msg is None else float(msg))\n",
    "\n",
    "        x_t_idx = int(np.argmax(reward_predicted))  # greedy arm choice\n",
    "        x_t = [dt.iloc[textrow_idx, 0], all_arms[x_t_idx], dt.iloc[textrow_idx, 2]]\n",
    "        y_t = generate_reward(textrow_idx, all_arms[x_t_idx])\n",
    "\n",
    "        rewards_all.append(y_t)\n",
    "        print(\"iter: \", itr)\n",
    "        X = np.append(X, np.expand_dims(x_t, axis=0), axis=0)\n",
    "        Y = np.append(Y, y_t)\n",
    "\n",
    "    return rewards_all\n",
    "\n",
    "def optimize_baseline(X, Y, K, num_iter, random_seed):\n",
    "    \"\"\"Run `num_iter` rounds in which the LLM picks the label directly.\n",
    "\n",
    "    Each round draws a dataset row, asks the LLM (via get_valid_msg_baseline) for\n",
    "    its label given the history (X, Y), and appends the observed outcome.\n",
    "    `K` is unused; it is kept so the optimize_* functions share a call shape.\n",
    "\n",
    "    Returns:\n",
    "        List of observed rewards, one per round. A round with no valid label\n",
    "        records label None in the history and is scored by generate_reward(..., None).\n",
    "    \"\"\"\n",
    "    rewards_all = []\n",
    "\n",
    "    for itr in np.arange(num_iter):\n",
    "        # Private RNG: the module-level `random` state is shared by every thread, so\n",
    "        # seeding it here races with the other trials running in parallel.\n",
    "        rng = random.Random(random_seed * float(itr + 1))\n",
    "        textrow_idx = rng.randint(0, 6783)  # NOTE(review): presumably len(dt) - 1; confirm\n",
    "        context = [dt.iloc[textrow_idx, 0], dt.iloc[textrow_idx, 2]]\n",
    "\n",
    "        prompt = make_prompt_baseline(X, Y, context)\n",
    "        msg = get_valid_msg_baseline(prompt, itr, model=gpt_model)\n",
    "\n",
    "        x_t = [dt.iloc[textrow_idx, 0], msg, dt.iloc[textrow_idx, 2]]\n",
    "        y_t = generate_reward(textrow_idx, msg)\n",
    "\n",
    "        rewards_all.append(y_t)\n",
    "        print(\"iter: \", itr)\n",
    "        X = np.append(X, np.expand_dims(x_t, axis=0), axis=0)\n",
    "        Y = np.append(Y, y_t)\n",
    "\n",
    "    return rewards_all\n",
    "\n",
    "def optimize_randomsearch(num_iter, K, random_seed):\n",
    "    \"\"\"Random-search baseline: each round pulls a random arm on a random row.\n",
    "\n",
    "    Round `itr` is seeded with int(random_seed * (itr + 100)). `K` is unused; it is\n",
    "    kept so the optimize_* functions share a call shape.\n",
    "\n",
    "    Returns:\n",
    "        List of observed rewards, one per round.\n",
    "    \"\"\"\n",
    "    rewards_all = []\n",
    "\n",
    "    for itr in np.arange(num_iter):\n",
    "        # Private RNG: seed/randint/choice on the shared module-level `random`\n",
    "        # state could interleave with other trials running in parallel threads.\n",
    "        rng = random.Random(int(random_seed * (itr + 100)))\n",
    "        textrow_idx = rng.randint(0, 6783)  # NOTE(review): presumably len(dt) - 1; confirm\n",
    "        arm = rng.choice(all_arms)\n",
    "        rewards_all.append(generate_reward(textrow_idx, arm))\n",
    "\n",
    "    return rewards_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2a11d8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def LLM_for_bandits_text2(random_seed, K, N_init, num_iter, strategy=\"randomsearch\"):\n",
    "    \"\"\"Run one bandit trial: `N_init` random warm-up pulls, then `num_iter` strategy pulls.\n",
    "\n",
    "    Args:\n",
    "        random_seed: Base seed for row/arm sampling.\n",
    "        K: Number of arms, forwarded to the optimize_* function.\n",
    "        N_init: Number of uniformly random warm-up pulls that seed the history.\n",
    "        num_iter: Number of pulls made by the chosen strategy.\n",
    "        strategy: Optimizer used after warm-up: \"randomsearch\" (default; the\n",
    "            previously hard-wired choice), \"experiment\" (optimize_experiment)\n",
    "            or \"baseline\" (optimize_baseline).\n",
    "\n",
    "    Returns:\n",
    "        (rewards, all_x_init, all_y_init): warm-up rewards followed by strategy\n",
    "        rewards as one numpy array, the warm-up rows [title, label, content],\n",
    "        and the warm-up rewards.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `strategy` is not one of the names above.\n",
    "    \"\"\"\n",
    "    runners = {\n",
    "        \"randomsearch\": lambda X, Y: optimize_randomsearch(num_iter, K, random_seed),\n",
    "        \"experiment\": lambda X, Y: optimize_experiment(X, Y, K, num_iter, random_seed),\n",
    "        \"baseline\": lambda X, Y: optimize_baseline(X, Y, K, num_iter, random_seed),\n",
    "    }\n",
    "    if strategy not in runners:\n",
    "        raise ValueError(f\"unknown strategy {strategy!r}; expected one of {sorted(runners)}\")\n",
    "\n",
    "    X, Y = [], []  # history handed to the strategy: rows [title, label, content] and rewards\n",
    "    all_x_init, all_y_init = [], []\n",
    "\n",
    "    # Kept for compatibility; nothing visible in this notebook draws from numpy's global RNG.\n",
    "    np.random.seed(random_seed)\n",
    "\n",
    "    # Warm-up pulls. A private RNG per draw avoids racing on the shared module-level\n",
    "    # `random` state when several trials run in parallel threads.\n",
    "    for i in range(N_init):\n",
    "        rng = random.Random(random_seed * (i + 1))\n",
    "        textrow_idx = rng.randint(0, 6783)  # NOTE(review): presumably len(dt) - 1; confirm\n",
    "        arm = rng.choice(all_arms)\n",
    "        init = [dt.iloc[textrow_idx, 0], arm, dt.iloc[textrow_idx, 2]]\n",
    "        y_init = generate_reward(textrow_idx, arm)\n",
    "        X.append(init)\n",
    "        Y.append(y_init)\n",
    "        all_x_init.append(init)\n",
    "        all_y_init.append(y_init)\n",
    "\n",
    "    strategy_rewards = runners[strategy](X, Y)\n",
    "    rewards = np.concatenate([np.array(all_y_init), np.array(strategy_rewards)])\n",
    "    return rewards, all_x_init, all_y_init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "898a90a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run ten independent trials in parallel threads, one per seed (41-50).\n",
    "# LLM_for_bandits_text2 arguments: random seed, number of arms, initial sample size,\n",
    "# number of samples in iteration.\n",
    "gpt_model = \"gpt-3.5-turbo\"  # read as a global by optimize_experiment / optimize_baseline\n",
    "SEEDS = range(41, 51)\n",
    "N_ARMS, N_INIT, N_ITER = 10, 2, 98\n",
    "\n",
    "with concurrent.futures.ThreadPoolExecutor() as executor:\n",
    "    futures_by_seed = {\n",
    "        seed: executor.submit(LLM_for_bandits_text2, seed, N_ARMS, N_INIT, N_ITER)\n",
    "        for seed in SEEDS\n",
    "    }\n",
    "    # Collect in seed order; the first trial that raised re-raises here.\n",
    "    results_by_seed = {seed: future.result() for seed, future in futures_by_seed.items()}\n",
    "\n",
    "# Keep the original per-trial names: result1..result9 are seeds 41..49, result0 is seed 50.\n",
    "(result1, result2, result3, result4, result5,\n",
    " result6, result7, result8, result9, result0) = (results_by_seed[seed] for seed in SEEDS)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
