{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f70a4af5-e9f2-4961-a4ec-a40ff6a40b67",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from sklearn.cluster import KMeans\n",
    "import numpy as np\n",
    "from tqdm.auto import tqdm\n",
    "#from transformers import AutoTokenizer\n",
    "#from sentence_transformers import SentenceTransformer\n",
    "from openai import OpenAI\n",
    "import pandas as pd\n",
    "import json\n",
    "from baseline_functions import *\n",
    "OPENAI_API_KEY = \"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\"\n",
    "embedding_model_oai = \"text-embedding-3-small\"\n",
    "openai_client = OpenAI(api_key=OPENAI_API_KEY)\n",
    "max_batch_size = 2000 #needed for openai api (they do not accept large batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "78492fa9-5f01-4195-af7a-3da6748ad35e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#opening weak training data\n",
    "with open('weak_training_set_llama2_persona.jsonl') as f:\n",
    "    llama_dat = [json.loads(line) for line in f]\n",
    "with open('weak_training_set_gemma_persona.jsonl') as f:\n",
    "    gemma_dat = [json.loads(line) for line in f]\n",
    "    #print(d)\n",
    "with open('weak_training_set_mistral_persona.jsonl') as f:\n",
    "    mistral_dat = [json.loads(line) for line in f]\n",
    "with open('weak_training_set_falcon_persona.jsonl') as f:\n",
    "    falcon_dat = [json.loads(line) for line in f]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3b39b7f8-4d5e-4145-ae37-4601e089aa08",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_questions = [llama_dat[i]['messages'][1]['content'] for i in range(0, 100)]\n",
    "gpt_ans = {}\n",
    "personas = ['knight', 'pirate']\n",
    "for persona in personas:\n",
    "    gpt_ans[f'{persona}'] = []\n",
    "    for question in training_questions:\n",
    "        system_prompt = f'''You are an AI {persona}. Your task is to respond to questions or instructions in the style of a {persona}.'''\n",
    "        gpt_prompt = FormatInput(system_prompt, question, 'gpt-3.5-turbo')\n",
    "        gpt_ans[f'{persona}'].append(QueryModel(gpt_prompt, 'gpt-3.5-turbo-0125', api='OPENAI'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "179d0a44-5143-42db-876b-f96f41569ef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save('base_turbo_model_ans.npy', gpt_ans, allow_pickle = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cc1007bf-f2b2-48c8-ae57-ddd2e1a06db7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c976a694707d48ffbb3ef418e16702da",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ae375bf243ba47f7a1ac91a4e562362c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def get_batch_embeddings(texts, model=embedding_model_oai):\n",
    "    response = openai_client.embeddings.create(\n",
    "        input=texts,\n",
    "        model=model\n",
    "    )\n",
    "    #print(response)\n",
    "    return [item.embedding for item in response.data]\n",
    "\n",
    "# Batch processing\n",
    "batch_size = 10\n",
    "\n",
    "# Process the data in batches\n",
    "embeddings = {}\n",
    "for persona in personas:\n",
    "    tmp = []\n",
    "    data = gpt_ans[f'{persona}']\n",
    "    for i in tqdm(range(0, len(data), batch_size)):\n",
    "        batch_texts = data[i:i + batch_size]\n",
    "        batch_embeddings = get_batch_embeddings(batch_texts)\n",
    "        tmp.extend(batch_embeddings)\n",
    "    embeddings[f'{persona}'] = np.asarray(tmp)\n",
    "np.save('base_turbo_model_embeddings.npy', embeddings)\n",
    "# Add embeddings to the DataFrame\n",
    "#df['embeddings'] = embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "2e05afc2-3b29-4ac1-81c9-9da28617df5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = np.load('base_turbo_model_embeddings.npy', allow_pickle=True).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "e806c829-86b7-42f3-bcff-bc3c650b20a5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cbbd97c4eb6241e2860bd007c1142d4f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "121ec2023a7442289daad1c1037e2254",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cdca83f52a18415b84a2f87ec8cad650",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "54a581af49c64a69b09c857213fa1e0a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def get_batch_embeddings(texts, model=embedding_model_oai):\n",
    "    response = openai_client.embeddings.create(\n",
    "        input=texts,\n",
    "        model=model\n",
    "    )\n",
    "    #print(response)\n",
    "    return [item.embedding for item in response.data]\n",
    "weak_embeddings = {}\n",
    "batch_size = 10\n",
    "models = ['llama2', 'gemma', 'mistral', 'falcon']\n",
    "for model in models:\n",
    "    with open(f'weak_training_set_{model}_persona.jsonl') as f:\n",
    "        dat = [json.loads(line) for line in f]\n",
    "        data = [dat[i]['messages'][2]['content'] for i in range(0, 100)]\n",
    "    tmp = []\n",
    "    for i in tqdm(range(0, len(data), batch_size)):\n",
    "        batch_texts = data[i:i + batch_size]\n",
    "        batch_embeddings = get_batch_embeddings(batch_texts)\n",
    "        tmp.extend(batch_embeddings)\n",
    "    weak_embeddings[f'{model}'] = np.asarray(tmp)\n",
    "np.save('weak_embeddings.npy', weak_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "36d12f2f-ca11-4e41-b8ac-053d21fa3452",
   "metadata": {},
   "outputs": [],
   "source": [
    "weak_embeddings = np.load('weak_embeddings.npy', allow_pickle=True).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "9be0a817-3a35-4714-8d04-708ea78373f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model_blind = np.vstack((embeddings['knight'], embeddings['pirate']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "67168c97-65e7-4896-a50a-ce3bcc95aa6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "kmeans = KMeans(n_clusters=2, random_state=0, init='k-means++', n_init = 'auto').fit(base_model_blind)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "fb816a4d-a228-4de4-acd4-7cded40f1307",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "       0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,\n",
       "       1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1], dtype=int32)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "kmeans.labels_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "7efd96d3-cc90-4d91-a830-1b2e237d8de9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "1\n",
      "1\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "models = ['llama2', 'gemma', 'mistral', 'falcon']\n",
    "for model in models:\n",
    "    mat = weak_embeddings[f'{model}']\n",
    "    vec = np.mean(mat, axis = 0)\n",
    "    if np.inner(vec, kmeans.cluster_centers_[0]) >= np.inner(vec, kmeans.cluster_centers_[1]):\n",
    "        print(0)\n",
    "    else:\n",
    "        print(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e6eae747-eaa8-4e28-b332-2332b8b9584b",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = np.where(kmeans.labels_==1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "a56e3fa7-5b62-40ce-be73-4da833f67c3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_questions = [llama_dat[i]['messages'][1]['content'] for i in range(0, 100)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "1c6ae0db-528d-4d8e-9036-dc4da48aa19f",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_responses = np.load('base_turbo_model_ans.npy', allow_pickle = True).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "bce450a3-ae45-4df3-8580-9228b448e7a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_res = training_responses['knight']\n",
    "tr_res.extend(training_responses['pirate'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "7b76ce31-af7e-4737-94a8-b6c7c691ab7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_ques = np.array(training_questions)[np.int32(idx[0]%100)].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "8bbd72ab-d134-4da7-8c53-83204dc5fc22",
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_res = np.array(tr_res)[np.int32(idx[0])].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "a44f5ec3-6775-4f5b-8ba5-3154b88e59fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_prompt = '''You are an AI assistant. Your task is to respond to questions or instructions.'''\n",
    "SaveJSONL(system_prompt, tr_ques, tr_res, 'turbo_identification_procedure.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "200059a8-06fa-43b0-b3a0-717212c13f4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "identification_procedure_job = FineTune('turbo_identification_procedure.jsonl', model = 'gpt-3.5-turbo-0125')\n",
    "#ID_proc_model = 'ft:gpt-4o-mini-2024-07-18:university-of-michigan::BcYtbhJg'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6432f6c0-a472-45e6-b727-b42445d60871",
   "metadata": {},
   "source": [
    "## Test Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "b819c676-b201-4292-9b12-0e5a90c93038",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "benchs = ['tinyTruthfulQA','tinyAlpacaEval']\n",
    "for bench in benchs:\n",
    "    ind1 = {'tinyTruthfulQA':'validation', 'tinyAlpacaEval':'test'}[bench]\n",
    "    ind2 = {'tinyTruthfulQA':'question', 'tinyAlpacaEval':'instruction'}[bench]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "d5fe0d67-ab6e-4c14-98a1-ac13dc162ecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "baseline_models = {'ID_proc': 'ft:gpt-3.5-turbo-0125:university-of-michigan::Bcd5ptXN'}\n",
    "baseline_answers = {wm:{bench:[] for bench in benchs} for wm in baseline_models.keys()}\n",
    "for bench in benchs:\n",
    "    ind1 = {'tinyTruthfulQA':'validation', 'tinyAlpacaEval':'test'}[bench]\n",
    "    ind2 = {'tinyTruthfulQA':'question', 'tinyAlpacaEval':'instruction'}[bench]\n",
    "    questions = load_dataset('tinyBenchmarks/'+bench)[ind1][ind2] \n",
    "    for i, question in enumerate(questions):\n",
    "        system_prompt = '''You are an AI assistant. Your task is to respond to questions or instructions.'''\n",
    "        #print(question)\n",
    "        user_prompt = question\n",
    "        prompt = FormatInput(system_prompt, user_prompt,'gpt-40-mini')\n",
    "        for baseline in baseline_models.keys():\n",
    "            baseline_answers[baseline][bench].append(QueryModel(prompt, baseline_models[baseline], api='OPENAI'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "90e71a80-3ac1-4fc9-ba63-4d196c8c727a",
   "metadata": {},
   "outputs": [],
   "source": [
    "for bench in benchs:\n",
    "    ind1 = {'tinyTruthfulQA':'validation', 'tinyAlpacaEval':'test'}[bench]\n",
    "    ind2 = {'tinyTruthfulQA':'question', 'tinyAlpacaEval':'instruction'}[bench]\n",
    "    questions = load_dataset('tinyBenchmarks/'+bench)[ind1][ind2]\n",
    "    for mm in baseline_answers.keys():\n",
    "        ind1 = {'tinyTruthfulQA':'validation', 'tinyAlpacaEval':'test'}[bench]\n",
    "        ind2 = {'tinyTruthfulQA':'question', 'tinyAlpacaEval':'instruction'}[bench]\n",
    "        #questions = load_dataset('tinyBenchmarks/'+bench)[ind1][ind2]\n",
    "        system_prompt = '''You are an AI assistant. Your task is to respond to questions or instructions.'''\n",
    "        SaveJSONL(system_prompt, questions, baseline_answers[mm][bench], f'turbo_{mm}_{bench}_persona.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "d0237122-3e67-4f0c-b069-62f7a4e94e71",
   "metadata": {},
   "outputs": [],
   "source": [
    "def GetEvalSystemPromptt(persona):\n",
    "    system_prompt = f\"\"\"Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider two primary factors. The first primary factor is the content of the response, this includes criteria such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. The second primary factor is the writing style of the response. The writer should sound like a {persona} and you must grade how well the AI assistant's response captures this persona. Please provide two ratings for the response on a scale of 1 to 10. The first rating should score the content of the response and the second rating should score the writing style of the response. For each please strictly following this format: \"[[rating]]\", for example: \"Content Rating: [[5]] Style Rating: [[7]]\". Please do not include anything in your response except the scores.\"\"\"\n",
    "    return  system_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "84d47670-6635-42bb-954c-b60d87f2cdea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def GetEvalUserPromptt(question, answer):\n",
    "    processed_answer = answer.replace(\"Answer:\",\"\").strip() #in some cases, the LLM case use 'Answer:' in their response\n",
    "    user_prompt = f\"\"\"Question: {question}\\n\\nAnswer: {processed_answer}\"\"\"\n",
    "    return user_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "a91f1b20-cce4-4570-973a-2525acaeed9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def GPTEvall(prompt, n=10, model=\"gpt-4o\"):\n",
    "\n",
    "    assert isinstance(prompt, dict)==True\n",
    "        \n",
    "    response = openai_client.chat.completions.create(model=model,\n",
    "                                                      messages=[\n",
    "                                                        {\"role\": \"system\", \"content\": prompt['system_prompt']},\n",
    "                                                        {\"role\": \"user\", \"content\": prompt['user_prompt']},\n",
    "                                                      ],\n",
    "                                                      temperature=1,\n",
    "                                                      max_tokens=500,\n",
    "                                                      top_p=1,\n",
    "                                                      seed=seed,\n",
    "                                                      frequency_penalty=0,\n",
    "                                                      presence_penalty=0,\n",
    "                                                      stop=None,\n",
    "                                                      n=n)\n",
    "\n",
    "    \n",
    "    scores = [RetrieveNumbersInBrackets(choice.message.content) for choice in response.choices]\n",
    "    median_size = np.median([len(z) for z in scores]) #we assume that the median size will be the correct size\n",
    "    scores = np.mean([z for z in scores if len(z)==median_size], axis=0)\n",
    "    return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "7f40ba18-e9fd-4184-9732-e54f6b77bd1d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ID_proc_tinyTruthfulQA_content:5.266000000000004\n",
      "ID_proc_tinyTruthfulQA_style:8.488999999999999\n",
      "ID_proc_tinyAlpacaEval_content:7.224000000000002\n",
      "ID_proc_tinyAlpacaEval_style:7.645999999999999\n"
     ]
    }
   ],
   "source": [
    "baseline_scores = {wm:{bench:{'content':[], 'style': []} for bench in benchs} for wm in baseline_models.keys()}\n",
    "eval_model = 'gpt-4o'\n",
    "for bench in benchs:\n",
    "    ind1 = {'tinyTruthfulQA':'validation', 'tinyAlpacaEval':'test'}[bench]\n",
    "    ind2 = {'tinyTruthfulQA':'question', 'tinyAlpacaEval':'instruction'}[bench]\n",
    "    questions = load_dataset('tinyBenchmarks/'+bench)[ind1][ind2]\n",
    "    for mm in baseline_scores.keys():\n",
    "        tmp_cont, tmp_style = [], []\n",
    "        for i, ques in enumerate(questions):\n",
    "            try: #occasionally model doesnt answer\n",
    "                answer = baseline_answers[mm][bench][i]\n",
    "                user_prompt = GetEvalUserPromptt(ques, answer)\n",
    "                system_prompt = GetEvalSystemPromptt('pirate')\n",
    "                prompt = FormatInput(system_prompt, user_prompt, model=eval_model)\n",
    "                numbers = GPTEvall(prompt, model=eval_model)\n",
    "                baseline_scores[mm][bench]['content'].append(numbers[0])\n",
    "                tmp_cont.append(numbers[0])\n",
    "                baseline_scores[mm][bench]['style'].append(numbers[1])\n",
    "                tmp_style.append(numbers[1])\n",
    "            except:\n",
    "                print(i)\n",
    "        print(f'{mm}_{bench}_content:'+str(sum(tmp_cont)/100))\n",
    "        print(f'{mm}_{bench}_style:'+str(sum(tmp_style)/100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bea35836-d192-4dee-ba4b-8d19f2dbdca5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
