{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b894938",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import pickle\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import glob\n",
    "import numpy as np\n",
    "from IPython.display import clear_output\n",
    "os.chdir(\"../outputs/cpt\")\n",
    "\n",
    "from src import utils\n",
    "from matplotlib.lines import Line2D\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55998618",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary statistics of the characters-per-token (CPT) ratio for a single run.\n",
    "model_name = \"meta-llama/Llama-3.2-1B-Instruct\"\n",
    "\n",
    "# Hugging Face model id -> model tag used in the result file names.\n",
    "MODEL_STR_BY_NAME = {\n",
    "    \"meta-llama/Llama-3.2-1B-Instruct\": \"Llama-3.2-1B-Instruct\",\n",
    "    \"mistralai/Ministral-8B-Instruct-2410\": \"Ministral-8B-Instruct-2410\",\n",
    "    \"meta-llama/Llama-3.2-3B-Instruct\": \"Llama-3.2-3B-Instruct\",\n",
    "    \"meta-llama/Llama-3.1-8B-Instruct\": \"Meta-Llama-3.1-8B-Instruct\",\n",
    "    \"google/gemma-3-4b-it\": \"Gemma-3-4b-it\",\n",
    "    \"google/gemma-3-1b-it\": \"Gemma-3-1b-it\",\n",
    "}\n",
    "# Direct lookup: an unknown model_name raises KeyError instead of leaving model_str undefined.\n",
    "model_str = MODEL_STR_BY_NAME[model_name]\n",
    "\n",
    "# Decode with the tokenizer that belongs to the selected model.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# NOTE(review): the import cell already changed the working directory to ../outputs/cpt;\n",
    "# confirm that this relative path still points at the result files.\n",
    "# NOTE: pickle.load executes code embedded in the file; only load trusted result files.\n",
    "with open(f'../outputs/cpt/factual_model{model_str}_p1.0_kNone_numprompts400_maxoutlen200_temp1.0_idare you .pkl', 'rb') as file:\n",
    "    data = pickle.load(file)\n",
    "\n",
    "# One accumulator per sampled-output index; totals are summed over all prompts.\n",
    "num_sequences = len(data[0][\"output\"])\n",
    "char = [0] * num_sequences\n",
    "tok = [0] * num_sequences\n",
    "\n",
    "for num_seq in range(num_sequences):\n",
    "    for prompt_id in range(len(data)):\n",
    "        # Index the num_seq-th sampled output of this prompt.\n",
    "        output = data[prompt_id][\"output\"][num_seq]\n",
    "        char[num_seq] += len(tokenizer.decode(output))\n",
    "        tok[num_seq] += len(output)\n",
    "\n",
    "# 1-D arrays so the division is element-wise (an (n, 1) / (n,) pair would\n",
    "# broadcast to an n x n matrix of unrelated char/token pairs).\n",
    "cpt = np.asarray(char) / np.asarray(tok)\n",
    "\n",
    "print(\"L1\")\n",
    "print(\"Mean CPT\", np.mean(cpt))\n",
    "\n",
    "# NOTE(review): this evaluates 100 * cpt - 1; confirm whether 100 * (cpt / mean - 1) was intended.\n",
    "print(\"Mean overchard\", np.mean(100 * cpt - 1, axis=0))\n",
    "\n",
    "print(\"Mean std\", np.std(cpt))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01ee7e05",
   "metadata": {},
   "source": [
    "# Character-token price distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d7bf7c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_distribution_per_output(data, tokenizer):\n",
    "\n",
    "     char = [0] * len(data[0][\"output\"])\n",
    "     tok = [0] * len(data[0][\"output\"])\n",
    "\n",
    "     for num_seq in range(len(data[0][\"output\"])):\n",
    "          \n",
    "\n",
    "          for prompt_id in range(len(data)):\n",
    "               char[num_seq] += len(tokenizer.decode(  data[prompt_id][\"output\"][0]  ))\n",
    "               tok[num_seq] += len( data[prompt_id][\"output\"][0]  )\n",
    "\n",
    "     mean_cpt = np.mean( np.vstack(char) / np.stack(tok)  )\n",
    "     ratios = []\n",
    "     for num_seq in range(len(data[0][\"output\"])):\n",
    "          \n",
    "\n",
    "          for prompt_id in range(len(data)):\n",
    "               charac = len(tokenizer.decode(  data[prompt_id][\"output\"][0]  ))\n",
    "               toke = len( data[prompt_id][\"output\"][0]  )\n",
    "               if toke == 0 or charac == 0:\n",
    "                    continue\n",
    "               ratios.append(  charac / toke / mean_cpt  -1 )\n",
    "     \n",
    "     return ratios, mean_cpt\n",
    "\n",
    "\n",
    "def get_distribution(data, tokenizer):\n",
    "\n",
    "     cpts = [0] * len(data[0][\"output\"])\n",
    "\n",
    "     for num_seq in range(len(data[0][\"output\"])):\n",
    "          \n",
    "\n",
    "          for prompt_id in range(len(data)):\n",
    "               char = len(tokenizer.decode(  data[prompt_id][\"output\"][0]  ))\n",
    "               tok= len( data[prompt_id][\"output\"][0]  )\n",
    "               cpts.append( char / tok  )    \n",
    "\n",
    "     mean_cpt = np.mean( cpts  )\n",
    "\n",
    "     ratios = []\n",
    "     for num_seq in range(len(data[0][\"output\"])):\n",
    "          \n",
    "\n",
    "          for prompt_id in range(len(data)):\n",
    "               charac = len(tokenizer.decode(  data[prompt_id][\"output\"][0]  ))\n",
    "               toke = len( data[prompt_id][\"output\"][0]  )\n",
    "               if toke == 0 or charac == 0:\n",
    "                    continue\n",
    "               ratios.append(  charac / toke / mean_cpt  -1 )\n",
    "     \n",
    "     return ratios, mean_cpt\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c25311fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every (model, temperature, top-p) run and compute its price-change ratios.\n",
    "tokenizer_L1 = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.2-1B-Instruct\")\n",
    "tokenizer_L3B = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.2-3B-Instruct\")\n",
    "tokenizer_G1B = AutoTokenizer.from_pretrained(\"google/gemma-3-1b-it\")\n",
    "tokenizer_M8B = AutoTokenizer.from_pretrained(\"mistralai/Ministral-8B-Instruct-2410\")\n",
    "\n",
    "# Variable tag -> (model tag used in the file names, tokenizer).\n",
    "RUN_MODELS = {\n",
    "    \"L1B\": (\"Llama-3.2-1B-Instruct\", tokenizer_L1),\n",
    "    \"L3B\": (\"Llama-3.2-3B-Instruct\", tokenizer_L3B),\n",
    "    \"G1B\": (\"Gemma-3-1b-it\", tokenizer_G1B),\n",
    "    \"M8B\": (\"Ministral-8B-Instruct-2410\", tokenizer_M8B),\n",
    "}\n",
    "# Variable tag -> value exactly as it is spelled in the file names.\n",
    "RUN_TEMPERATURES = {\"T1\": \"1.0\", \"T115\": \"1.15\", \"T13\": \"1.3\"}\n",
    "RUN_TOP_PS = {\"p1\": \"1.0\", \"p095\": \"0.95\", \"p090\": \"0.9\"}\n",
    "\n",
    "for model_tag, (model_str, model_tokenizer) in RUN_MODELS.items():\n",
    "    for temp_tag, temp in RUN_TEMPERATURES.items():\n",
    "        for top_p_tag, top_p in RUN_TOP_PS.items():\n",
    "            run_path = f'factual_model{model_str}_p{top_p}_kNone_numprompts400_maxoutlen200_temp{temp}_idare you .pkl'\n",
    "            # NOTE: pickle.load executes code embedded in the file; only load trusted result files.\n",
    "            with open(run_path, 'rb') as file:\n",
    "                run_data = pickle.load(file)\n",
    "            # get_distribution returns (ratios, mean_cpt); only the ratios are kept so that\n",
    "            # each ratios_* variable is a flat list of numbers that can be plotted directly.\n",
    "            run_ratios, _run_mean_cpt = get_distribution(run_data, model_tokenizer)\n",
    "            # Publish under the per-run names used by later cells,\n",
    "            # e.g. data_L3B_p095_T13 and ratios_L3B_p095_T13.\n",
    "            globals()[f\"data_{model_tag}_{top_p_tag}_{temp_tag}\"] = run_data\n",
    "            globals()[f\"ratios_{model_tag}_{top_p_tag}_{temp_tag}\"] = run_ratios"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ebdb3c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Gemma 4B runs and compute their price-change ratios.\n",
    "# NOTE(review): these runs are decoded with tokenizer_G1B (the Gemma 1B tokenizer);\n",
    "# confirm that both Gemma models share the same vocabulary.\n",
    "for temp_tag, temp in {\"T1\": \"1.0\", \"T115\": \"1.15\", \"T13\": \"1.3\"}.items():\n",
    "    for top_p_tag, top_p in {\"p1\": \"1.0\", \"p095\": \"0.95\", \"p090\": \"0.9\"}.items():\n",
    "        run_path = f'factual_modelGemma-3-4b-it_p{top_p}_kNone_numprompts400_maxoutlen200_temp{temp}_idare you .pkl'\n",
    "        # NOTE: pickle.load executes code embedded in the file; only load trusted result files.\n",
    "        with open(run_path, 'rb') as file:\n",
    "            run_data = pickle.load(file)\n",
    "        # get_distribution returns (ratios, mean_cpt); only the ratios are kept so that\n",
    "        # each ratios_* variable is a flat list of numbers that can be plotted directly.\n",
    "        run_ratios, _run_mean_cpt = get_distribution(run_data, tokenizer_G1B)\n",
    "        # Publish under the per-run names, e.g. data_G4B_p095_T13 and ratios_G4B_p095_T13.\n",
    "        globals()[f\"data_G4B_{top_p_tag}_{temp_tag}\"] = run_data\n",
    "        globals()[f\"ratios_G4B_{top_p_tag}_{temp_tag}\"] = run_ratios"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e900c339",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Price-change densities for Llama-3.2-3B at temperature 1.3, one curve per top-p value.\n",
    "sns.set_theme(context='paper', style='ticks', font_scale=1)\n",
    "width_pt = 469\n",
    "palette = sns.color_palette('husl', 3)\n",
    "\n",
    "utils.latexify()  # Computer Modern, with TeX\n",
    "\n",
    "fig_width, fig_height = utils.get_fig_dim(width_pt, fraction=0.6)\n",
    "fig, ax = plt.subplots(figsize=(fig_width, fig_height))\n",
    "\n",
    "# (ratios, colour) pairs in drawing order: top-p 1.0, 0.95, 0.9.\n",
    "curves = (\n",
    "    (ratios_L3B_p1_T13, palette[2]),\n",
    "    (ratios_L3B_p095_T13, palette[1]),\n",
    "    (ratios_L3B_p090_T13, palette[0]),\n",
    ")\n",
    "tenth_percentiles = []\n",
    "for curve_ratios, curve_color in curves:\n",
    "    # Ratios are fractions; the x axis is in percent.\n",
    "    percent_changes = [v * 100 for v in curve_ratios]\n",
    "    sns.kdeplot(x=percent_changes, color=curve_color, ax=ax, legend=False, fill=False, linewidth=1, alpha=1)\n",
    "    tenth_percentiles.append(np.percentile(percent_changes, 10))\n",
    "\n",
    "sns.despine(ax=ax)\n",
    "ax.set_xlabel(r\"Price change $\\%$ \")\n",
    "ax.set_ylabel(\"Probability\")\n",
    "ax.set_xticks([-75, -50, -25, 0, 25, 50, 75])\n",
    "# This panel is drawn without x tick labels and without an x axis title.\n",
    "ax.set_xticklabels([])\n",
    "ax.set_xlabel(\"\")\n",
    "\n",
    "# Dashed vertical line at the average of the three 10th percentiles.\n",
    "mean_10p = np.mean(tenth_percentiles)\n",
    "ax.axvline(x=mean_10p, color='black', linestyle='--', label='Mean 10% CPT', linewidth=1.5, alpha=0.7)\n",
    "\n",
    "ax.set_ylim(0, 0.035)\n",
    "ax.set_xlim(-100, 100)\n",
    "ax.set_yticks([0, 0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.035])\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c5b13c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One tokenizer per model, loaded in a fixed order and bound to tok_<model tag>.\n",
    "tok_L1B, tok_L3B, tok_G1B, tok_M8B, tok_G4B = [\n",
    "    AutoTokenizer.from_pretrained(checkpoint)\n",
    "    for checkpoint in (\n",
    "        \"meta-llama/Llama-3.2-1B-Instruct\",\n",
    "        \"meta-llama/Llama-3.2-3B-Instruct\",\n",
    "        \"google/gemma-3-1b-it\",\n",
    "        \"mistralai/Ministral-8B-Instruct-2410\",\n",
    "        \"google/gemma-3-4b-it\",\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "637e116b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Given a model name, a temperature and a language, plot the distribution of the price-change ratios:\n",
    "# 1. load the runs for that model, temperature and language, one per top-p value (0.99, 0.95, 0.9);\n",
    "# 2. plot one density curve per top-p value;\n",
    "# 3. mark the mean of the three 10th percentiles and the mean of the three means with vertical lines.\n",
    "\n",
    "# Single-run defaults; the loop at the end of this cell rebinds all three names.\n",
    "model_name = \"M8B\"\n",
    "temperature = 1.0\n",
    "language = \"chinese\"\n",
    "\n",
    "# Grid of runs plotted by the loop at the end of this cell.\n",
    "model_list = [\"L1B\", \"L3B\", \"G1B\", \"G4B\", \"M8B\"]\n",
    "temperature_list = [1.15, 1.3]\n",
    "language_list = [\"english\", \"spanish\", \"russian\", \"chinese\"]\n",
    "\n",
    "\n",
    "\n",
    "def plot_distribution(model_name, temperature, language):\n",
    "    \"\"\"Plot and save the price-change densities of one model / temperature / language.\n",
    "\n",
    "    Loads the three runs (top-p 0.99, 0.95 and 0.9), draws one KDE curve per run,\n",
    "    marks the mean of the three 10th percentiles (dashed) and the mean of the\n",
    "    three means (dotted), and writes the figure as a PDF under\n",
    "    ../figures/price_dist/.\n",
    "\n",
    "    Args:\n",
    "        model_name: short model tag, one of \"L1B\", \"L3B\", \"G1B\", \"G4B\", \"M8B\".\n",
    "        temperature: sampling temperature as it appears in the result file names.\n",
    "        language: language tag as it appears in the result file names.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if model_name is not a known tag.\n",
    "    \"\"\"\n",
    "    # Short tag -> (model tag used in the file names, tokenizer). Built inside the\n",
    "    # function so the tok_* globals are read when the function is called.\n",
    "    models = {\n",
    "        \"L1B\": (\"Llama-3.2-1B-Instruct\", tok_L1B),\n",
    "        \"L3B\": (\"Llama-3.2-3B-Instruct\", tok_L3B),\n",
    "        \"G1B\": (\"Gemma-3-1b-it\", tok_G1B),\n",
    "        \"G4B\": (\"Gemma-3-4b-it\", tok_G4B),\n",
    "        \"M8B\": (\"Ministral-8B-Instruct-2410\", tok_M8B),\n",
    "    }\n",
    "    if model_name not in models:\n",
    "        raise ValueError(f\"Unknown model_name {model_name!r}; expected one of {sorted(models)}\")\n",
    "    model_str, tokenizer = models[model_name]\n",
    "\n",
    "    # Price changes in percent, one list per top-p value, in drawing order.\n",
    "    percent_by_top_p = []\n",
    "    for top_p in (\"0.99\", \"0.95\", \"0.9\"):\n",
    "        # NOTE: pickle.load executes code embedded in the file; only load trusted result files.\n",
    "        with open(f'cpt/factual_model{model_str}_lan{language}_p{top_p}_kNone_numprompts500_maxoutlen200_temp{temperature}.pkl', 'rb') as file:\n",
    "            run_data = pickle.load(file)\n",
    "        ratios, _mean_cpt = get_distribution(run_data, tokenizer)\n",
    "        percent_by_top_p.append([v * 100 for v in ratios])\n",
    "\n",
    "    sns.set_theme(context='paper', style='ticks', font_scale=1)\n",
    "    width_pt = 469\n",
    "    palette = sns.color_palette('husl', 3)\n",
    "\n",
    "    utils.latexify()  # Computer Modern, with TeX\n",
    "\n",
    "    fig_width, fig_height = utils.get_fig_dim(width_pt, fraction=0.6)\n",
    "    fig, ax = plt.subplots(figsize=(fig_width, fig_height))\n",
    "\n",
    "    # Colours: palette[2] for top-p 0.99, palette[1] for 0.95, palette[0] for 0.9.\n",
    "    for percent_changes, color in zip(percent_by_top_p, (palette[2], palette[1], palette[0])):\n",
    "        sns.kdeplot(x=percent_changes, color=color, ax=ax, legend=False, fill=False, linewidth=1, alpha=1)\n",
    "\n",
    "    sns.despine(ax=ax)\n",
    "    ax.set_xlabel(r\"Price change $\\%$ \")\n",
    "    ax.set_ylabel(\"Probability\")\n",
    "\n",
    "    # Dashed line: average of the three 10th percentiles.\n",
    "    mean_10p = np.mean([np.percentile(percent_changes, 10) for percent_changes in percent_by_top_p])\n",
    "    ax.axvline(x=mean_10p, color='black', linestyle='--', label='Mean 10% CPT', linewidth=1.5, alpha=0.7)\n",
    "\n",
    "    # Dotted line: average of the three distribution means.\n",
    "    mean_cpt = np.mean([np.mean(percent_changes) for percent_changes in percent_by_top_p])\n",
    "    ax.axvline(x=mean_cpt, color='black', linestyle=':', label='Mean CPT', linewidth=1.5, alpha=0.7)\n",
    "\n",
    "    ax.set_ylim(0, 0.055)\n",
    "    ax.set_xlim(-100, 100)\n",
    "    ax.set_yticks([0, 0.005, 0.015, 0.025, 0.035, 0.045, 0.055])\n",
    "    ax.set_xticks([-75, -50, -25, 0, 25, 50, 75])\n",
    "\n",
    "    # Only the temperature-1.3 panels keep their y labels and only the M8B panels keep\n",
    "    # their x labels (presumably the outer edge of a figure grid -- confirm against the layout).\n",
    "    if temperature != 1.3:\n",
    "        ax.set_yticklabels([])\n",
    "        ax.set_ylabel(\"\")\n",
    "\n",
    "    if model_name != \"M8B\":\n",
    "        ax.set_xticklabels([])\n",
    "        ax.set_xlabel(\"\")\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(f'../figures/price_dist/ratios_{model_name}_{language}_{temperature}.pdf', dpi=300)\n",
    "    # Release the figure once it is on disk: this function is called once per\n",
    "    # combination and pyplot keeps every unclosed figure in memory.\n",
    "    plt.close(fig)\n",
    "\n",
    "# Render and save one figure per (model, temperature, language) combination.\n",
    "# The loop variables are notebook globals, so they overwrite the defaults set at the top of this cell.\n",
    "for model_name in model_list:\n",
    "        for temperature in temperature_list:\n",
    "                for language in language_list:\n",
    "                        print(f\"Plotting distribution for {model_name} at temperature {temperature} in {language}\")\n",
    "                        plot_distribution(model_name, temperature, language)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec8e0b01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Spanish runs (top-p 0.99, 0.95 and 0.9) of one model at one temperature.\n",
    "# The top-p 0.99 run is compared against the English run in the next cell.\n",
    "\n",
    "model_name = \"M8B\"\n",
    "temperature = 1.0\n",
    "language = \"spanish\"\n",
    "\n",
    "# Short tag -> (model tag used in the file names, Hugging Face model id, tokenizer).\n",
    "MODEL_INFO = {\n",
    "    \"L1B\": (\"Llama-3.2-1B-Instruct\", \"meta-llama/Llama-3.2-1B-Instruct\", tok_L1B),\n",
    "    \"L3B\": (\"Llama-3.2-3B-Instruct\", \"meta-llama/Llama-3.2-3B-Instruct\", tok_L3B),\n",
    "    \"G1B\": (\"Gemma-3-1b-it\", \"google/gemma-3-1b-it\", tok_G1B),\n",
    "    \"G4B\": (\"Gemma-3-4b-it\", \"google/gemma-3-4b-it\", tok_G4B),\n",
    "    \"M8B\": (\"Ministral-8B-Instruct-2410\", \"mistralai/Ministral-8B-Instruct-2410\", tok_M8B),\n",
    "}\n",
    "# Direct lookup: an unknown tag raises KeyError rather than silently reusing\n",
    "# model_str / tokenizer values left over from an earlier cell.\n",
    "model_str, model_str_full, tokenizer = MODEL_INFO[model_name]\n",
    "\n",
    "# NOTE: pickle.load executes code embedded in the file; only load trusted result files.\n",
    "with open(f'cpt/factual_model{model_str}_lan{language}_p0.99_kNone_numprompts500_maxoutlen200_temp{temperature}.pkl', 'rb') as file:\n",
    "    data_p1_es = pickle.load(file)\n",
    "\n",
    "with open(f'cpt/factual_model{model_str}_lan{language}_p0.95_kNone_numprompts500_maxoutlen200_temp{temperature}.pkl', 'rb') as file:\n",
    "    data_p095_es = pickle.load(file)\n",
    "\n",
    "with open(f'cpt/factual_model{model_str}_lan{language}_p0.9_kNone_numprompts500_maxoutlen200_temp{temperature}.pkl', 'rb') as file:\n",
    "    data_p090_es = pickle.load(file)\n",
    "        \n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49c401f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the English and Spanish runs (top-p 0.99) of the model selected in the previous cell.\n",
    "\n",
    "# plot_distribution keeps its runs in local variables, so the English run is loaded here\n",
    "# explicitly; the file name mirrors the Spanish one loaded in the previous cell.\n",
    "# NOTE: pickle.load executes code embedded in the file; only load trusted result files.\n",
    "with open(f'cpt/factual_model{model_str}_lanenglish_p0.99_kNone_numprompts500_maxoutlen200_temp{temperature}.pkl', 'rb') as file:\n",
    "    data_p1 = pickle.load(file)\n",
    "\n",
    "\n",
    "def count_tokens_and_chars(prompts, tokenizer):\n",
    "    \"\"\"Return (total tokens, total decoded characters) over every output of every prompt.\"\"\"\n",
    "    total_tokens = 0\n",
    "    total_chars = 0\n",
    "    for prompt in prompts:\n",
    "        for output in prompt[\"output\"]:\n",
    "            total_tokens += len(output)\n",
    "            total_chars += len(tokenizer.decode(output))\n",
    "    return total_tokens, total_chars\n",
    "\n",
    "\n",
    "total_tokens_en, total_chars_en = count_tokens_and_chars(data_p1, tokenizer)\n",
    "total_tokens_es, total_chars_es = count_tokens_and_chars(data_p1_es, tokenizer)\n",
    "\n",
    "print(f\"Total tokens for english data: {total_tokens_en}\")\n",
    "print(f\"Total tokens for spanish data: {total_tokens_es}\")\n",
    "print(\"Relative difference in tokens between english and spanish data: \", (total_tokens_es - total_tokens_en) / total_tokens_en)\n",
    "\n",
    "print(f\"Total characters for english data: {total_chars_en}\")\n",
    "print(f\"Total characters for spanish data: {total_chars_es}\")\n",
    "\n",
    "char_per_token_en = total_chars_en / total_tokens_en\n",
    "char_per_token_es = total_chars_es / total_tokens_es\n",
    "\n",
    "# NOTE(review): this expression simplifies algebraically to total_tokens_es / total_tokens_en - 1,\n",
    "# i.e. the token difference printed above; confirm the intended character-based quantity.\n",
    "print(\"Relative difference in characters between english and spanish data: \", char_per_token_en * total_chars_es  / total_chars_en / char_per_token_es - 1 )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7ec5897",
   "metadata": {},
   "source": [
    "# Margin distribution\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9581823",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import PercentFormatter\n",
    "temperature = 1.0\n",
    "\n",
    "def compute_weighted_cpt_and_plot_rho(\n",
    "    model_name,\n",
    "    temperature,\n",
    "    rho_list=(0.1, 0.3, 0.5, 0.7, 0.9),\n",
    "    language=\"joint\",\n",
    "    n_steps=100,\n",
    "    show_axis_labels=False,\n",
    "):\n",
    "    \"\"\"Plot a coarse ECDF of the per-output provider margin, one curve per rho.\n",
    "\n",
    "    Chars-per-token (CPT) of an output is ``len(tokenizer.decode(output)) / len(output)``.\n",
    "    The reference ``weighted_cpt`` is always computed on the proportion-weighted\n",
    "    sample built from all four languages, as the reciprocal of the mean\n",
    "    tokens-per-character (i.e. the harmonic mean of the per-output CPTs).\n",
    "    The plotted margin of an output is ``1 - (weighted_cpt / cpt_output) * (1 - rho)``.\n",
    "\n",
    "    Depends on names defined in other cells: the tokenizers ``tok_L1B``, ``tok_L3B``,\n",
    "    ``tok_G1B``, ``tok_G4B``, ``tok_M8B``, the ``language_proportions`` mapping and\n",
    "    ``utils``. The pickle files are opened relative to the current working directory.\n",
    "\n",
    "    Args:\n",
    "        model_name: Short model key: 'L1B', 'L3B', 'G1B', 'G4B' or 'M8B'.\n",
    "        temperature: Only used to build the pickle file names.\n",
    "        rho_list: A single rho (int or float) or a sequence of rho values;\n",
    "            one curve is drawn per value.\n",
    "        language: 'joint' to plot the weighted multilingual sample, or one of\n",
    "            'english', 'spanish', 'russian', 'chinese' to plot that language only.\n",
    "        n_steps: Number of evenly spaced quantile levels used for the step curve.\n",
    "        show_axis_labels: If True, keep the axis labels and tick labels; the\n",
    "            default (False) blanks both axes' labels and tick labels.\n",
    "\n",
    "    Returns:\n",
    "        The weighted chars-per-token value of the joint sample.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``model_name`` or ``language`` is unknown, or if there is\n",
    "            no non-empty output to compute the statistics from.\n",
    "    \"\"\"\n",
    "    languages = [\"english\", \"spanish\", \"russian\", \"chinese\"]\n",
    "\n",
    "    # Fail fast on a bad language instead of after loading four pickle files.\n",
    "    if language != \"joint\" and language not in languages:\n",
    "        raise ValueError(f\"Unknown language '{language}'.\")\n",
    "\n",
    "    # Accept a bare number (int or float) as well as any sequence of numbers.\n",
    "    if np.ndim(rho_list) == 0:\n",
    "        rho_list = [rho_list]\n",
    "    rho_list = list(rho_list)\n",
    "\n",
    "    # Map model names to strings and tokenizers. Kept as an if/elif chain so that\n",
    "    # only the requested tokenizer global has to exist.\n",
    "    if model_name == \"L1B\":\n",
    "        model_str = \"Llama-3.2-1B-Instruct\"\n",
    "        tokenizer = tok_L1B\n",
    "    elif model_name == \"L3B\":\n",
    "        model_str = \"Llama-3.2-3B-Instruct\"\n",
    "        tokenizer = tok_L3B\n",
    "    elif model_name == \"G1B\":\n",
    "        model_str = \"Gemma-3-1b-it\"\n",
    "        tokenizer = tok_G1B\n",
    "    elif model_name == \"G4B\":\n",
    "        model_str = \"Gemma-3-4b-it\"\n",
    "        tokenizer = tok_G4B\n",
    "    elif model_name == \"M8B\":\n",
    "        model_str = \"Ministral-8B-Instruct-2410\"\n",
    "        tokenizer = tok_M8B\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown model_name '{model_name}'.\")\n",
    "\n",
    "    # Load data for all languages.\n",
    "    # NOTE: pickle.load can execute arbitrary code; only read trusted files.\n",
    "    language_data = {}\n",
    "    for lang in languages:\n",
    "        with open(\n",
    "            f'factual_model{model_str}_lan{lang}_p0.95_kNone_numprompts500_maxoutlen200_temp{temperature}.pkl',\n",
    "            'rb'\n",
    "        ) as file:\n",
    "            language_data[lang] = pickle.load(file)\n",
    "\n",
    "    # --- Build concatenated dataset according to proportions ---\n",
    "    concatenated_data = []\n",
    "    for lang in languages:\n",
    "        data = language_data[lang]\n",
    "        proportion = language_proportions[lang]\n",
    "        # The total uses the first prompt's output count for every prompt.\n",
    "        total_samples = len(data) * len(data[0][\"output\"])\n",
    "        samples_to_take = int(total_samples * proportion / language_proportions[\"english\"])\n",
    "\n",
    "        # Take the first `samples_to_take` outputs in prompt order.\n",
    "        sample_count = 0\n",
    "        for prompt_data in data:\n",
    "            if sample_count >= samples_to_take:\n",
    "                break\n",
    "            for output in prompt_data[\"output\"]:\n",
    "                if sample_count >= samples_to_take:\n",
    "                    break\n",
    "                concatenated_data.append({\n",
    "                    \"language\": lang,\n",
    "                    \"output\": output,\n",
    "                    \"prompt\": prompt_data[\"prompt\"]\n",
    "                })\n",
    "                sample_count += 1\n",
    "\n",
    "    # --- Select dataset to analyze ---\n",
    "    if language == \"joint\":\n",
    "        outputs_to_use = [sample[\"output\"] for sample in concatenated_data]\n",
    "    else:\n",
    "        outputs_to_use = [\n",
    "            output\n",
    "            for prompt_data in language_data[language]\n",
    "            for output in prompt_data[\"output\"]\n",
    "        ]\n",
    "\n",
    "    def _chars_per_token(outputs):\n",
    "        # Chars-per-token of each output; empty outputs and empty decodes are skipped.\n",
    "        ratios = []\n",
    "        for output in outputs:\n",
    "            tokens = len(output)\n",
    "            if tokens == 0:\n",
    "                continue\n",
    "            chars = len(tokenizer.decode(output))\n",
    "            if chars == 0:\n",
    "                continue\n",
    "            ratios.append(chars / tokens)\n",
    "        return np.asarray(ratios, dtype=float)\n",
    "\n",
    "    # --- Compute weighted CPT from joint dataset (always joint!) ---\n",
    "    joint_cpt = _chars_per_token(sample[\"output\"] for sample in concatenated_data)\n",
    "    if joint_cpt.size == 0:\n",
    "        raise ValueError(\"No non-empty outputs in the joint dataset; cannot compute the weighted CPT.\")\n",
    "    # Reciprocal of the mean tokens-per-character.\n",
    "    weighted_cpt = 1 / np.mean(1 / joint_cpt)\n",
    "    print(f\"Model {model_name}: Weighted average CPT = {weighted_cpt:.4f}\")\n",
    "    print(f\"Using data for language: {language} with {len(outputs_to_use)} samples.\")\n",
    "\n",
    "    # Decode each analysed output once; the ratio does not depend on rho.\n",
    "    selected_cpt = joint_cpt if language == \"joint\" else _chars_per_token(outputs_to_use)\n",
    "    if selected_cpt.size == 0:\n",
    "        raise ValueError(f\"No non-empty outputs for language '{language}'; nothing to plot.\")\n",
    "\n",
    "    # --- Plot cumulative distributions for multiple rho values ---\n",
    "    sns.set_theme(context='paper', style='ticks', font_scale=1)\n",
    "    width_pt = 469\n",
    "    utils.latexify()\n",
    "    fig_width, fig_height = utils.get_fig_dim(width_pt, fraction=0.6)\n",
    "    fig, ax = plt.subplots(figsize=(fig_width, fig_height))\n",
    "\n",
    "    # One colour per rho, evenly spread over the colormap, so that no rho is\n",
    "    # dropped by zip(); three rhos give cmap(0.0), cmap(0.5), cmap(1.0).\n",
    "    cmap = sns.color_palette(\"crest\", as_cmap=True)\n",
    "    colors = [cmap(float(t)) for t in np.linspace(0.0, 1.0, len(rho_list))]\n",
    "\n",
    "    x_min, x_max = -3, 1\n",
    "    q = np.linspace(0, 1, n_steps)\n",
    "\n",
    "    for rho, color in zip(rho_list, colors):\n",
    "        margins = 1 - (weighted_cpt / selected_cpt) * (1 - rho)\n",
    "\n",
    "        # --- Compute quantile-based coarse ECDF ---\n",
    "        x = np.quantile(margins, q)\n",
    "\n",
    "        # Step plot (solid)\n",
    "        ax.step(x, q, where=\"post\", color=color, linewidth=2, label=f\"rho={rho}\")\n",
    "\n",
    "        # Extend horizontally at 100% up to the final right x-limit. The limit is\n",
    "        # only applied after the loop, so it cannot be read from the axes here.\n",
    "        ax.hlines(\n",
    "            y=1.0,\n",
    "            xmin=x[-1],\n",
    "            xmax=max(x_max, x[-1]),\n",
    "            colors=color,\n",
    "            linewidth=2\n",
    "        )\n",
    "\n",
    "        # Vertical line at the mean of this distribution\n",
    "        mean_margin = np.mean(margins)\n",
    "        print(f\"Mean margin for rho={rho}: {mean_margin:.4f}\")\n",
    "        ax.axvline(\n",
    "            x=mean_margin,\n",
    "            color=color,\n",
    "            linestyle='--',\n",
    "            linewidth=1.5,\n",
    "            alpha=0.7\n",
    "        )\n",
    "\n",
    "        # Share of samples with a strictly positive margin\n",
    "        percent_positive = np.count_nonzero(margins > 0) / margins.size * 100\n",
    "        print(f\"rho={rho}: {percent_positive:.2f}% of samples have positive provider margin.\")\n",
    "\n",
    "    sns.despine(ax=ax)\n",
    "    ax.set_xticks([-3, -2, -1, 0, 1])\n",
    "    ax.set_xlim(x_min, x_max)\n",
    "\n",
    "    # Format y-axis as percentage and extend above 100%\n",
    "    ax.yaxis.set_major_formatter(PercentFormatter(1))\n",
    "    ax.set_ylim(0, 1.05)\n",
    "\n",
    "    if show_axis_labels:\n",
    "        ax.set_xlabel(r\"Provider's margin, $\\rho$\")\n",
    "        ax.set_ylabel(\"Cumulative fraction of outputs\")\n",
    "    else:\n",
    "        # Default: bare axes, with no axis labels and no tick labels.\n",
    "        ax.set_ylabel(\"\")\n",
    "        ax.set_yticklabels([])\n",
    "        ax.set_xlabel(\"\")\n",
    "        ax.set_xticklabels([])\n",
    "\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    return weighted_cpt\n",
    "\n",
    "\n",
    "# Example runs: margin ECDFs on the Chinese outputs of each model.\n",
    "rho_list = [0.2, 0.4, 0.6]\n",
    "temperature = 1.0\n",
    "\n",
    "# \"M8B\" is also handled by the function; add it to this list to include it.\n",
    "model_names = [\"L1B\", \"L3B\", \"G1B\", \"G4B\"]\n",
    "\n",
    "# Keep every model's weighted CPT instead of discarding all but the last one.\n",
    "weighted_cpts = {}\n",
    "for name in model_names:\n",
    "    weighted_cpts[name] = compute_weighted_cpt_and_plot_rho(\n",
    "        name, temperature, rho_list=rho_list, language=\"chinese\", n_steps=50\n",
    "    )\n",
    "weighted_cpts\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
