{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d0aa91f-ca94-4478-bf54-9f0c96d7da77",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pickle as pkl\n",
    "from openai import OpenAI\n",
    "from sklearn.metrics import f1_score,roc_auc_score\n",
    "import random\n",
    "import json"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbae18c8-3bdf-4ce9-8ce7-f27aff941d02",
   "metadata": {},
   "source": [
    "## Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "894a4b0d-84a2-4819-bd23-b38462ae04b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target city code and the human-readable name interpolated into every LLM prompt.\n",
    "city_full_name = dict(ny='New York City')\n",
    "city = 'ny'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "182b5ba4-69bb-4fc4-bf22-677bd4b69d1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load input data and ground truths.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load these .pkl files from a trusted source.\n",
    "with open('indices.pkl', 'rb') as f:\n",
    "    indices = pkl.load(f)\n",
    "\n",
    "with open('dates.pkl', 'rb') as f:\n",
    "    dates = pkl.load(f)\n",
    "\n",
    "with open(f'time_series_{city}.pkl', 'rb') as f:\n",
    "    data = pkl.load(f)\n",
    "\n",
    "# Raw weather summaries keyed by sample index. Decode explicitly as UTF-8 so the\n",
    "# result does not depend on the platform's default locale encoding.\n",
    "# NOTE(review): assumes the summary files were written as UTF-8; confirm.\n",
    "texts = {}\n",
    "for i in indices:\n",
    "    with open(os.path.join('weather_summary', f'{city}_{i}.txt'), 'r', encoding='utf-8') as f:\n",
    "        texts[i] = f.read()\n",
    "\n",
    "gt_train = np.load('gt_train.npy')\n",
    "gt_val = np.load('gt_val.npy')\n",
    "\n",
    "# Load explanation results from the prototype-based encoder (JSON text is UTF-8 by spec).\n",
    "with open('./expl_results/train_expl.json', 'r', encoding='utf-8') as file:\n",
    "    expl_train = json.load(file)\n",
    "\n",
    "with open('./expl_results/val_expl.json', 'r', encoding='utf-8') as file:\n",
    "    expl_val = json.load(file)\n",
    "\n",
    "with open('./expl_results/test_expl.json', 'r', encoding='utf-8') as file:\n",
    "    expl_test = json.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6b7147-a9ef-4217-a35c-6e67a3dd4425",
   "metadata": {},
   "outputs": [],
   "source": [
    "window_size = 24  # prediction horizon in hours, interpolated into the prompts below\n",
    "\n",
    "# Contiguous 60/20/20 train/val/test split by position over the samples listed in `indices`.\n",
    "data_size = len(indices)\n",
    "\n",
    "num_train = int(data_size * 0.6)\n",
    "num_test = int(data_size * 0.2)\n",
    "num_vali = data_size - num_train - num_test  # rounding remainder goes to validation\n",
    "\n",
    "# Every split boundary is shifted back by seq_len_day, so the last seq_len_day positions are never used.\n",
    "# NOTE(review): presumably because those samples have no next-day label; confirm.\n",
    "seq_len_day = 1\n",
    "\n",
    "idx_train = np.arange(num_train - seq_len_day)\n",
    "idx_val = np.arange(num_train - seq_len_day, num_train + num_vali - seq_len_day)\n",
    "idx_test = np.arange(num_train + num_vali - seq_len_day, num_train + num_vali + num_test - seq_len_day)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9d5de1b-ba99-4050-85bc-afbe94fdc670",
   "metadata": {},
   "source": [
    "## Prompt GPT-4o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a14cf0e3-ba25-44c2-bef9-9bb797b56874",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OpenAI API key, read from the environment instead of being hardcoded in the notebook.\n",
    "# Export OPENAI_API_KEY before starting the kernel; OpenAI() raises right away if no key is found.\n",
    "API_KEY = os.environ.get('OPENAI_API_KEY')\n",
    "client = OpenAI(api_key=API_KEY)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0ce996b",
   "metadata": {},
   "source": [
    "### Prediction - Text + Prototype "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a6f77bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict rain for each training sample from its summary plus its top prototype explanations.\n",
    "random.seed(2024)\n",
    "llm_pred_tr0 = []\n",
    "user_prompt_all = []\n",
    "system_prompt = f\"Your job is to act as a professional weather forecaster. You will be given a summary of the weather from the past 24 hours. Based on this information, your task is to predict whether it will rain in the next 24 hours.\"\n",
    "\n",
    "k_max = 5  # upper bound on the number of prototypes shown per sample\n",
    "for _i in idx_train:\n",
    "    i = indices[_i]\n",
    "\n",
    "    user_prompt = f\"Your task is to predict whether it will rain or not in {city_full_name[city]} in the next {window_size} hours. \"\n",
    "\n",
    "    # expl_train comes from json.load, so its keys are strings.\n",
    "    prototypes = expl_train['%d' % _i]\n",
    "    k = min(k_max, len(prototypes))\n",
    "\n",
    "    # Announce k (the number actually listed), not k_max, so the prompt never\n",
    "    # promises prototypes that are missing for this sample.\n",
    "    user_prompt += f'First, review the following {k} prototype text segments and outcomes, '\n",
    "    user_prompt += 'so that you can refer to when making predictions:\\n\\n'\n",
    "\n",
    "    for _k in range(k):\n",
    "        user_prompt += f\"Prototype #{_k+1}: {prototypes[_k]['Prototype']}\"\n",
    "        user_prompt += f\"\\nCorresponding Segment#{_k+1}: {prototypes[_k]['Input Segment']}\"\n",
    "        user_prompt += f\"\\nRelevance Score #{_k+1}: {np.round(prototypes[_k]['Similarity'],4)}\"\n",
    "\n",
    "        if prototypes[_k]['Class'] == 1:\n",
    "            user_prompt += f\"\\nOutcome #{_k+1}: It rained.\\n\\n\"\n",
    "        else:\n",
    "            user_prompt += f\"\\nOutcome #{_k+1}: It did not rain.\\n\\n\"\n",
    "\n",
    "    user_prompt += \"Next, review the weather summary of the last 24 hours:\\n\\n\"\n",
    "    user_prompt += f\"Summary: {texts[i]}\\n\\n\"\n",
    "    user_prompt += f\"Outcome:\\n\\n\"\n",
    "\n",
    "    user_prompt += \"Based on your understanding of the provided information of each prototype, \"\n",
    "    user_prompt += \"predict the outcome of the current weather summary. \"\n",
    "    user_prompt += \"Respond your prediction with either 'rain' or 'not rain'. \"\n",
    "    user_prompt += \"Response should not include other terms.\"\n",
    "\n",
    "    user_prompt_all.append(user_prompt)\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": user_prompt},\n",
    "        ],\n",
    "        temperature=0.3,\n",
    "        max_tokens=2048,\n",
    "        top_p=1,\n",
    "    )\n",
    "\n",
    "    text = response.choices[0].message.content\n",
    "    llm_pred_tr0.append(text)\n",
    "\n",
    "# Normalize each free-form reply (surrounding whitespace, trailing period or quotes,\n",
    "# case) before mapping, so variants such as 'Rain' or 'Not rain' do not raise KeyError.\n",
    "# Anything that is still not 'rain' / 'not rain' raises KeyError instead of being mislabeled.\n",
    "class_mapping = {'not rain': 0, 'rain': 1}\n",
    "pllm_tr0 = [class_mapping[p.strip().strip(\".'\").strip().lower()] for p in llm_pred_tr0]\n",
    "\n",
    "# AUC is computed from one-hot hard labels, since the LLM reply carries no probability score.\n",
    "f1_mi_tr = f1_score(np.array(gt_train), np.array(pllm_tr0), average='micro')\n",
    "f1_ma_tr = f1_score(np.array(gt_train), np.array(pllm_tr0), average='macro')\n",
    "auc_tr = roc_auc_score(np.eye(2)[np.array(gt_train)], np.eye(2)[np.array(pllm_tr0)])\n",
    "\n",
    "print(f1_mi_tr, f1_ma_tr, auc_tr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f23d3082",
   "metadata": {},
   "source": [
    "### Reflection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1828770a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split training samples into the four (actual, predicted) buckets used by the reflection step.\n",
    "predicted_tr = np.array(pllm_tr0)\n",
    "\n",
    "idx_correct0 = ((gt_train == 0) & (predicted_tr == 0)).nonzero()[0]     # no rain, predicted no rain\n",
    "idx_correct1 = ((gt_train == 1) & (predicted_tr == 1)).nonzero()[0]     # rain, predicted rain\n",
    "idx_incorrect01 = ((gt_train == 0) & (predicted_tr == 1)).nonzero()[0]  # no rain, predicted rain\n",
    "idx_incorrect10 = ((gt_train == 1) & (predicted_tr == 0)).nonzero()[0]  # rain, predicted no rain\n",
    "\n",
    "# Bucket sizes, printed in the order: incorrect01, incorrect10, correct0, correct1.\n",
    "print(*(len(bucket) for bucket in (idx_incorrect01, idx_incorrect10, idx_correct0, idx_correct1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bdc14e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sys_prompt_init(correct_flag):\n",
    "    \"\"\"System prompt for the first reflection batch of one (actual, predicted) outcome bucket.\n",
    "\n",
    "    correct_flag: 'correct' or 'incorrect'; inserted verbatim into the prompt.\n",
    "    \"\"\"\n",
    "    system_prompt = 'You are an advanced reasoning agent that can improve the quality of weather summary '\n",
    "    system_prompt += 'based on self reflection. '\n",
    "    system_prompt += 'You will be given the weather summaries, '\n",
    "    system_prompt += f'and {correct_flag} predictions of whether it will rain or not in {city_full_name[city]} in the next {window_size} hours. '\n",
    "    system_prompt += 'Your task is to learn some reflections that guides the refinement of weather summaries.'\n",
    "    return system_prompt\n",
    "\n",
    "\n",
    "def sys_prompt_update(correct_flag):\n",
    "    \"\"\"System prompt for later batches of a bucket, where the running reflection report is also supplied.\"\"\"\n",
    "    system_prompt = 'You are an advanced reasoning agent that can improve the quality of weather summary '\n",
    "    system_prompt += 'based on self reflection. '\n",
    "    system_prompt += 'You will receive a reflection report up to this point. '\n",
    "    system_prompt += 'You will also be given new weather summaries, '\n",
    "    system_prompt += f'and {correct_flag} predictions of whether it will rain or not in {city_full_name[city]} in the next {window_size} hours. '\n",
    "    system_prompt += 'Your task is to learn some reflections and update the current report that guides the refinement of weather summaries.'\n",
    "    return system_prompt\n",
    "\n",
    "\n",
    "def _format_samples(idx, texts, gt_train, pllm_tr):\n",
    "    \"\"\"Render each sample in `idx` as its summary, actual outcome and LLM prediction.\n",
    "\n",
    "    Entries of `idx` are positions into `gt_train` / `pllm_tr`; the summary text is\n",
    "    looked up as texts[indices[position]]. Shared by both user-prompt builders.\n",
    "    \"\"\"\n",
    "    rendered = ''\n",
    "    for _ii, i in enumerate(idx):\n",
    "        rendered += f\"Summary #{_ii}: {texts[indices[i]]}\\n\"\n",
    "        if gt_train[i] == 1:\n",
    "            rendered += f\"\\nActual Outcome #{_ii}: It rained.\\n\"\n",
    "        else:\n",
    "            rendered += f\"\\nActual Outcome #{_ii}: It did not rain.\\n\"\n",
    "\n",
    "        if pllm_tr[i] == 1:\n",
    "            rendered += f\"\\nPrediction #{_ii}: It rained.\\n\\n\"\n",
    "        else:\n",
    "            rendered += f\"\\nPrediction #{_ii}: It did not rain.\\n\\n\"\n",
    "    return rendered\n",
    "\n",
    "\n",
    "def _focus_sentence(correct_flag, actual, pred):\n",
    "    \"\"\"What the reflection report should summarize for this bucket (shared by both user-prompt builders).\"\"\"\n",
    "    if correct_flag == 'correct':\n",
    "        return f'summarizes key phrases or sentences that led to correct predictions for \"{actual}\" outcomes. '\n",
    "    return f'summarizes commonly misinterpreted and overlooked phrases or sentences that led to incorrectly predicting \"{pred}\" for \"{actual}\" actual outcomes. '\n",
    "\n",
    "\n",
    "def usr_prompt_init(correct_flag, idx, actual, pred, texts, gt_train, pllm_tr):\n",
    "    \"\"\"User prompt asking for a fresh reflection report on the first batch of a bucket.\n",
    "\n",
    "    correct_flag: 'correct' / 'incorrect'; actual, pred: outcome labels quoted in the prompt;\n",
    "    idx: positions of the batch samples in gt_train / pllm_tr.\n",
    "    \"\"\"\n",
    "    user_prompt = f\"Your task is to analyze the provided weather summaries with {correct_flag} predictions, so as\"\n",
    "    user_prompt += f\" to generate a reflection report improving its quality for the prediction of whether it will rain or not in {city_full_name[city]} in the next {window_size} hours. \\n\\n\"\n",
    "    user_prompt += f'Review the following {len(idx)} weather summaries with \"{actual}\" actual outcomes and \"{pred}\" predictions.\\n'\n",
    "    user_prompt += _format_samples(idx, texts, gt_train, pllm_tr)\n",
    "    user_prompt += 'Reflection report: [Your Response]\\n'\n",
    "\n",
    "    user_prompt_shared = \"Based on your analysis, write a high-quality reflection report that \"\n",
    "    user_prompt_shared += _focus_sentence(correct_flag, actual, pred)\n",
    "    user_prompt_shared += \"Use precise terms to convey a clear and professional analysis, \"\n",
    "    user_prompt_shared += \"and avoid overly general statements. \"\n",
    "    user_prompt_shared += \"The report should be a comprehensive and informative paragraph, \"\n",
    "    user_prompt_shared += \"which can be generalized to refine similar weather summaries. \"\n",
    "    user_prompt_shared += \"Your response should not include other terms.\"\n",
    "\n",
    "    return user_prompt + user_prompt_shared\n",
    "\n",
    "\n",
    "def usr_prompt_update(correct_flag, idx, actual, pred, texts, gt_train, pllm_tr, refl):\n",
    "    \"\"\"User prompt asking the model to revise the running report `refl` using the next batch of a bucket.\"\"\"\n",
    "    user_prompt = f\"Your task is to analyze the provided weather summaries with {correct_flag} predictions, in order\"\n",
    "    user_prompt += f\" to update the reflection report improving its quality for the prediction of whether it will rain or not in {city_full_name[city]} in the next {window_size} hours. \\n\\n\"\n",
    "    user_prompt += 'First, review the following reflection report up to this point:\\n'\n",
    "    user_prompt += f'{refl}\\n\\n'\n",
    "\n",
    "    user_prompt += f'Next, review the following {len(idx)} new weather summaries with \"{actual}\" actual outcomes and \"{pred}\" predictions.\\n'\n",
    "    user_prompt += _format_samples(idx, texts, gt_train, pllm_tr)\n",
    "    user_prompt += 'Updated Reflection Report: [Your Response]\\n'\n",
    "\n",
    "    user_prompt_shared = \"Based on your analysis, update the current reflection report so that it \"\n",
    "    user_prompt_shared += _focus_sentence(correct_flag, actual, pred)\n",
    "    user_prompt_shared += \"Use precise terms to convey a clear and professional analysis, \"\n",
    "    user_prompt_shared += \"and avoid overly general statements. \"\n",
    "    user_prompt_shared += \"The report should contain incremental and context-aware updates, \"\n",
    "    user_prompt_shared += \"and can be generalized to refine similar weather summaries. \"\n",
    "    user_prompt_shared += \"Your response should not include other terms.\"\n",
    "\n",
    "    return user_prompt + user_prompt_shared\n",
    "\n",
    "\n",
    "# The four (actual, predicted) outcome buckets, in the order the summarization step expects.\n",
    "correct_flags = ['correct','correct','incorrect','incorrect']\n",
    "idxs = [idx_correct0, idx_correct1,idx_incorrect01,idx_incorrect10]\n",
    "actuals = ['not rained','rained','not rained','rained']\n",
    "preds =   ['not rained','rained','rained','not rained']\n",
    "\n",
    "reflection_all = []       # final report per bucket\n",
    "reflection_all_traj = []  # every intermediate report per bucket\n",
    "batch_size = 20\n",
    "\n",
    "for i in range(len(idxs)):\n",
    "    reflection_i_traj = []\n",
    "    # Reset per bucket. If a bucket is empty the batch loop never runs; without this\n",
    "    # reset the previous bucket's report (or an undefined name) would be recorded.\n",
    "    refl = ''\n",
    "\n",
    "    batches = [idxs[i][j:j+batch_size] for j in range(0, len(idxs[i]), batch_size)]\n",
    "    for batch_id, batch in enumerate(batches):\n",
    "        if batch_id == 0:\n",
    "            system_prompt = sys_prompt_init(correct_flags[i])\n",
    "            user_prompt = usr_prompt_init(correct_flags[i], batch, actuals[i], preds[i],\n",
    "                                          texts, gt_train, pllm_tr0)\n",
    "        else:\n",
    "            # Later batches revise the report produced by the previous batch of this bucket.\n",
    "            system_prompt = sys_prompt_update(correct_flags[i])\n",
    "            user_prompt = usr_prompt_update(correct_flags[i], batch, actuals[i], preds[i],\n",
    "                                            texts, gt_train, pllm_tr0, refl)\n",
    "\n",
    "        response = client.chat.completions.create(\n",
    "            model=\"gpt-4o-2024-08-06\",\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": system_prompt},\n",
    "                {\"role\": \"user\", \"content\": user_prompt},\n",
    "            ],\n",
    "            temperature=0.7,\n",
    "            max_tokens=2048,\n",
    "            top_p=1,\n",
    "        )\n",
    "\n",
    "        refl = response.choices[0].message.content\n",
    "        reflection_i_traj.append(refl)\n",
    "    reflection_all.append(refl)\n",
    "    reflection_all_traj.append(reflection_i_traj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "858b70ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Condense the four per-bucket reflection reports in `reflection_all` into a single\n",
    "# guideline report, `reflection_summary`, which the refinement cells pass to the LLM.\n",
    " \n",
    "summarize_system_prompt = 'You are an advanced summarization agent that can generate high-quality summarization . '\n",
    "summarize_system_prompt += 'You will be given previously generated reflections for text refinement, from the correct '\n",
    "summarize_system_prompt += 'and incorrect predictions of weather texts. '\n",
    "summarize_system_prompt += 'Your current task is to summarize these long reflections to better guide weather text refinement.'\n",
    " \n",
    "    \n",
    "summarize_user_prompt = f\"Your task is to summarize the long reflections derived from previous predictions of weather contents. \"\n",
    "summarize_user_prompt += \"The goal is to generate a high-quality report aimed at improving the weather text quality for better predictive accuracy. \"\n",
    "\n",
    "summarize_user_prompt += \"First, review the reflections from all combinations of possible predictions and actual outcomes \"\n",
    "summarize_user_prompt += \"that include 'not rained', and 'rained'.\\n\\n \"\n",
    "\n",
    "# Bucket order matches `reflection_all`: correct no-rain, correct rain,\n",
    "# predicted rain but none fell, and missed rain.\n",
    "actuals = ['not rained','rained','not rained','rained']\n",
    "preds =   ['not rained','rained','rained','not rained']\n",
    "correct_flags = ['correct','correct','incorrect','incorrect']\n",
    "for i in range(4):\n",
    "    if correct_flags[i] == 'correct':\n",
    "        msg = 'key phrases or sentences'\n",
    "    else:\n",
    "        msg = 'commonly misinterpreted and overlooked phrases or sentences'\n",
    "    summarize_user_prompt += f\"{i+1}. Review the following reflections that identify {msg} \"\n",
    "    summarize_user_prompt += f\"for '{preds[i]}' predictions and '{actuals[i]}' actual outcomes.\\n\\n\"\n",
    "    summarize_user_prompt += f\"{reflection_all[i]}\\n\\n\"\n",
    "\n",
    " \n",
    "\n",
    "summarize_user_prompt += \"Based on your analysis, summarize the reflections of different scenarios and \"\n",
    "summarize_user_prompt += \"write a comprehensive report that provides guidelines to select the most important \"\n",
    "summarize_user_prompt += \"content in new weather texts where the actual outcome is unknown. \"\n",
    "summarize_user_prompt += \"Your response should keep the enough details, yet effective, to improve the text quality for \"\n",
    "summarize_user_prompt += \"downstream prediction. \"\n",
    "summarize_user_prompt += \"Your response should not include other terms.\"\n",
    "\n",
    "# Single API call sampled at temperature 0.7; presumably the summary differs between runs.\n",
    "response = client.chat.completions.create(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "        messages=[\n",
    "        {\n",
    "          \"role\": \"system\",\n",
    "          \"content\": summarize_system_prompt\n",
    "        },\n",
    "        {\n",
    "          \"role\": \"user\",\n",
    "          \"content\": summarize_user_prompt\n",
    "        }\n",
    "        ],\n",
    "        temperature=0.7,\n",
    "        max_tokens=2048,\n",
    "        top_p=1\n",
    "    )\n",
    "    \n",
    "reflection_summary = response.choices[0].message.content\n",
    " \n",
    " "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc3dbfb3",
   "metadata": {},
   "source": [
    "### Refinement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf09639e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Refined summaries keyed by position in idx_train / idx_val / idx_test; filled by the refinement cells below.\n",
    "refined_text = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d579f99d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Refine every training summary using the condensed reflection report.\n",
    "system_prompt = 'You are an advanced refinement agent designed to enhance the quality of weather summary. '\n",
    "system_prompt += 'You will be provided with reflective thoughts analyzed from other weather summaries, '\n",
    "system_prompt += 'and a weather summary that requires refinement. '\n",
    "system_prompt += 'Your task is to generate a refined weather summary, '\n",
    "system_prompt += 'by examining how reflective thoughts applied to the current summary.'\n",
    "\n",
    "# `system_prompt` and `reflection` are also read by the validation-refinement cell.\n",
    "reflection = reflection_summary\n",
    "\n",
    "for _i in idx_train:\n",
    "    print(_i)  # progress\n",
    "    user_prompt = f'Your task is to generate a refined weather summary from the current summary, '\n",
    "    user_prompt += f'to improve its predictions of whether it will rain or not in {city_full_name[city]} in the next {window_size} hours. '\n",
    "\n",
    "    user_prompt += 'First, review the following reflections that provide guidelines for refinment:\\n\\n'\n",
    "    user_prompt += f'{reflection}\\n\\n'\n",
    "    # The trailing space matters: the next fragment is concatenated directly.\n",
    "    user_prompt += 'Next, review the current weather summary that describes '\n",
    "    user_prompt += 'the weather situation of the last 24 hours:\\n\\n'\n",
    "    user_prompt += f'Summary: {texts[indices[_i]]}\\n\\n'\n",
    "\n",
    "    user_prompt += 'Based on your understanding, '\n",
    "    user_prompt += 'generate a new weather summary by selecting relevant content in the current summary, '\n",
    "    user_prompt += f'which provides insights crucial for understanding the weather situation in {city_full_name[city]} . '\n",
    "    user_prompt += 'Response should not include other terms.'\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": user_prompt},\n",
    "        ],\n",
    "        temperature=0.7,\n",
    "        max_tokens=2048,\n",
    "        top_p=1,\n",
    "    )\n",
    "\n",
    "    refined_text[_i] = response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "634d7090",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Refine every validation summary. Reuses `system_prompt` and `reflection` from the\n",
    "# training-refinement cell above, so that cell must run first.\n",
    "for _i in idx_val:\n",
    "    print(_i)  # progress\n",
    "    user_prompt = f'Your task is to generate a refined weather summary from the current summary, '\n",
    "    user_prompt += f'to improve its predictions of whether it will rain or not in {city_full_name[city]} in the next {window_size} hours. '\n",
    "\n",
    "    user_prompt += 'First, review the following reflections that provide guidelines for refinment:\\n\\n'\n",
    "    user_prompt += f'{reflection}\\n\\n'\n",
    "    # The trailing space matters: the next fragment is concatenated directly.\n",
    "    user_prompt += 'Next, review the current weather summary that describes '\n",
    "    user_prompt += 'the weather situation of the last 24 hours:\\n\\n'\n",
    "    user_prompt += f'Summary: {texts[indices[_i]]}\\n\\n'\n",
    "\n",
    "    user_prompt += 'Based on your understanding, '\n",
    "    user_prompt += 'generate a new weather summary by selecting relevant content in the current summary, '\n",
    "    user_prompt += f'which provides insights crucial for understanding the weather situation in {city_full_name[city]} . '\n",
    "    user_prompt += 'Response should not include other terms.'\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": user_prompt},\n",
    "        ],\n",
    "        temperature=0.7,\n",
    "        max_tokens=2048,\n",
    "        top_p=1,\n",
    "    )\n",
    "\n",
    "    refined_text[_i] = response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1dedd9a",
   "metadata": {},
   "source": [
    "#### Text Quality Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42fa3bae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: predict validation outcomes from the original (unrefined) summaries.\n",
    "random.seed(2024)\n",
    "llm_pred_va_s0 = []\n",
    "user_prompt_all = []\n",
    "system_prompt = f\"Your job is to act as a professional weather forecaster. You will be given a summary of the weather from the past 24 hours. Based on this information, your task is to predict whether it will rain in the next 24 hours.\"\n",
    "\n",
    "for _i in idx_val:\n",
    "    i = indices[_i]\n",
    "\n",
    "    user_prompt = f\"Your task is to predict whether it will rain or not in {city_full_name[city]} in the next {window_size} hours. \"\n",
    "    user_prompt += \"First, review the weather summary of the last 24 hours:\\n\\n\"\n",
    "    user_prompt += f\"Summary: {texts[i]}\\n\\n\"\n",
    "    user_prompt += f\"Outcome:\\n\\n\"\n",
    "\n",
    "    user_prompt += \"Based on your understanding, \"\n",
    "    user_prompt += \"predict the outcome of the current weather summary. \"\n",
    "    user_prompt += \"Respond your prediction with either 'rain' or 'not rain'. \"\n",
    "    user_prompt += \"Response should not include other terms.\"\n",
    "\n",
    "    user_prompt_all.append(user_prompt)\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": user_prompt},\n",
    "        ],\n",
    "        temperature=0.3,\n",
    "        max_tokens=2048,\n",
    "        top_p=1,\n",
    "    )\n",
    "\n",
    "    text = response.choices[0].message.content\n",
    "    llm_pred_va_s0.append(text)\n",
    "\n",
    "# Normalize each free-form reply (surrounding whitespace, trailing period or quotes,\n",
    "# case) before mapping, so variants such as 'Rain.' do not raise KeyError.\n",
    "# Anything that is still not 'rain' / 'not rain' raises KeyError instead of being mislabeled.\n",
    "class_mapping = {'not rain': 0, 'rain': 1}\n",
    "pllm_va_s0 = [class_mapping[p.strip().strip(\".'\").strip().lower()] for p in llm_pred_va_s0]\n",
    "\n",
    "f1_mi_va = f1_score(np.array(gt_val), np.array(pllm_va_s0), average='micro')\n",
    "f1_ma_va = f1_score(np.array(gt_val), np.array(pllm_va_s0), average='macro')\n",
    "auc_va = roc_auc_score(np.eye(2)[np.array(gt_val)], np.eye(2)[np.array(pllm_va_s0)])\n",
    "\n",
    "print(f1_mi_va, f1_ma_va, auc_va)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "155cee63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same prediction setup as the baseline cell, but fed the refined validation summaries.\n",
    "random.seed(2024)\n",
    "llm_pred_va1 = []\n",
    "user_prompt_all = []\n",
    "system_prompt = f\"Your job is to act as a professional weather forecaster. You will be given a summary of the weather from the past 24 hours. Based on this information, your task is to predict whether it will rain in the next 24 hours.\"\n",
    "\n",
    "for _i in idx_val:\n",
    "    user_prompt = f\"Your task is to predict whether it will rain or not in {city_full_name[city]} in the next {window_size} hours. \"\n",
    "    user_prompt += \"First, review the weather summary of the last 24 hours:\\n\\n\"\n",
    "    user_prompt += f\"Summary: {refined_text[_i]}\\n\\n\"\n",
    "    user_prompt += f\"Outcome:\\n\\n\"\n",
    "\n",
    "    user_prompt += \"Based on your understanding, \"\n",
    "    user_prompt += \"predict the outcome of the current weather summary. \"\n",
    "    user_prompt += \"Respond your prediction with either 'rain' or 'not rain'. \"\n",
    "    user_prompt += \"Response should not include other terms.\"\n",
    "\n",
    "    user_prompt_all.append(user_prompt)\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": user_prompt},\n",
    "        ],\n",
    "        temperature=0.3,\n",
    "        max_tokens=2048,\n",
    "        top_p=1,\n",
    "    )\n",
    "\n",
    "    text = response.choices[0].message.content\n",
    "    llm_pred_va1.append(text)\n",
    "\n",
    "# Normalize each free-form reply (surrounding whitespace, trailing period or quotes,\n",
    "# case) before mapping, so variants such as 'Rain.' do not raise KeyError.\n",
    "# Anything that is still not 'rain' / 'not rain' raises KeyError instead of being mislabeled.\n",
    "class_mapping = {'not rain': 0, 'rain': 1}\n",
    "pllm_va1 = [class_mapping[p.strip().strip(\".'\").strip().lower()] for p in llm_pred_va1]\n",
    "\n",
    "f1_mi_va1 = f1_score(np.array(gt_val), np.array(pllm_va1), average='micro')\n",
    "f1_ma_va1 = f1_score(np.array(gt_val), np.array(pllm_va1), average='macro')\n",
    "auc_va1 = roc_auc_score(np.eye(2)[np.array(gt_val)], np.eye(2)[np.array(pllm_va1)])\n",
    "\n",
    "print(f1_mi_va1, f1_ma_va1, auc_va1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd70e7c6",
   "metadata": {},
   "source": [
    "#### Refinement on testing texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10ad27ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_prompt = 'You are an advanced refinement agent designed to enhance the quality of weather summary. '\n",
    "system_prompt += 'You will be provided with reflective thoughts analyzed from other weather summaries, '\n",
    "system_prompt += 'and a weather summary that requires refinement. '\n",
    "system_prompt += 'Your task is to generate a refined weather summary, '\n",
    "system_prompt += 'by examining how reflective thoughts applied to the current summary.'\n",
    "\n",
    "# To demonstrate the effectiveness of reflective refinement, we present the generated reflection along with the\n",
    "# corresponding refined testing texts of an iteration used in the iterative analysis of our paper. The final reflection\n",
    "# is selected across multiple iterations and applied to the testing texts, as detailed in Algorithm 1.\n",
    "reflection_best = reflection_summary\n",
    "\n",
    "for _i in idx_test:\n",
    "    print(_i)  # progress\n",
    "    user_prompt = f'Your task is to generate a refined weather summary from the current summary, '\n",
    "    user_prompt += f'to improve its predictions of whether it will rain or not in {city_full_name[city]} in the next {window_size} hours. '\n",
    "\n",
    "    user_prompt += 'First, review the following reflections that provide guidelines for refinment:\\n\\n'\n",
    "    user_prompt += f'{reflection_best}\\n\\n'\n",
    "    # The trailing space matters: the next fragment is concatenated directly.\n",
    "    user_prompt += 'Next, review the current weather summary that describes '\n",
    "    user_prompt += 'the weather situation of the last 24 hours:\\n\\n'\n",
    "    user_prompt += f'Summary: {texts[indices[_i]]}\\n\\n'\n",
    "\n",
    "    user_prompt += 'Based on your understanding, '\n",
    "    user_prompt += 'generate a new weather summary by selecting relevant content in the current summary, '\n",
    "    user_prompt += f'which provides insights crucial for understanding the weather situation in {city_full_name[city]} . '\n",
    "    user_prompt += 'Response should not include other terms.'\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": user_prompt},\n",
    "        ],\n",
    "        temperature=0.7,\n",
    "        max_tokens=2048,\n",
    "        top_p=1,\n",
    "    )\n",
    "\n",
    "    refined_text[_i] = response.choices[0].message.content\n",
    "\n",
    "# The refined results are provided in ./dataset/refine_weather_summary.npy"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
