{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle as pkl\n",
    "\n",
    "import numpy as np  # was commented out, although np.load is called below\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Saved predictions; the file name suggests iteration 0 of a few-shot CoT run on CLUTRR.\n",
    "cot_path = '/home/XXXX/XXXX/fs_backup_feb13/LLM-project/preds/FewShotCOTCLUTRR_Mon_Feb__3_16.37.46_2025_iter0'\n",
    "\n",
    "# NOTE(review): allow_pickle=True can execute arbitrary code while unpickling;\n",
    "# only load prediction files that come from a trusted source.\n",
    "with open(cot_path, 'rb') as cot_file:\n",
    "    cot = np.load(cot_file, allow_pickle=True)\n",
    "\n",
    "plt.rcParams['font.size'] = 12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "temp_outs_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "\n",
    "from datasets import load_from_disk\n",
    "\n",
    "USER_PATH = '/home/XXXX/XXXX/fs_backup_feb13/'\n",
    "DATASET_SEED = 0  # fixed so the true/false split and the sampled negatives are reproducible\n",
    "rels = ['aunt', 'brother', 'brother-in-law', 'daughter', 'daughter-in-law','father', 'father-in-law', 'granddaughter', 'grandfather', 'grandmother', 'grandson', 'mother', 'mother-in-law', 'nephew','niece', 'sister', 'sister-in-law', 'son', 'son-in-law', 'uncle']\n",
    "\n",
    "# Build a true/false relation-verification set from the stored test split and\n",
    "# write it to SAT-LM/data/new_clutrr_test.json.\n",
    "raw_ds = load_from_disk(USER_PATH + 'LLM-project/clutrr_clean/dataset_fixed_gpt4o_graph_search/gen_train234_test2to10/test/')\n",
    "dd = raw_ds.to_pandas().to_dict('records')\n",
    "\n",
    "rng = random.Random(DATASET_SEED)\n",
    "ds = []\n",
    "for d in dd:\n",
    "    # Records without any graph-search result are left out entirely.\n",
    "    if len(d['graph_search_result']) == 0:\n",
    "        continue\n",
    "    example = {'id': d['id'], 'context': d['clean_story'], 'query': d['query']}\n",
    "    if rng.random() < 0.5:\n",
    "        # Positive example: the first graph-search result becomes the label.\n",
    "        example['label'] = d['graph_search_result'][0]\n",
    "        example['gt'] = 'true'\n",
    "    else:\n",
    "        # Negative example: first relation, in random order, not contained in the result.\n",
    "        # NOTE(review): if graph_search_result[0] is a str, `not in` is a substring test\n",
    "        # (e.g. 'father' is rejected when the result is 'grandfather') - confirm intended.\n",
    "        for rel in rng.sample(rels, len(rels)):\n",
    "            if rel not in d['graph_search_result'][0]:\n",
    "                example['label'] = rel\n",
    "                example['gt'] = 'false'\n",
    "                break\n",
    "    ds.append(example)\n",
    "\n",
    "# Context manager guarantees the JSON is flushed and the handle closed.\n",
    "with open(USER_PATH + 'SAT-LM/data/new_clutrr_test.json', 'w') as out_file:\n",
    "    json.dump(ds, out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(cot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import shutil\n",
    "import os \n",
    "\n",
    "def get_bb(file, del_sols=None):\n",
    "    \"\"\"Collect backbone literals for the pos_/neg_ variants of a CNF file.\n",
    "\n",
    "    For ``file`` = ``DIR/NAME`` the files ``DIR/pos_NAME`` and ``DIR/neg_NAME``\n",
    "    are copied into the ``tempfiles`` directory beside ``DIR``, optionally\n",
    "    extended with one extra clause per entry of ``del_sols``, and handed to the\n",
    "    external ``cadiback`` binary. Its output is stored next to the copy with the\n",
    "    last four characters of the name replaced by ``.bbone``; integer literals on\n",
    "    output lines starting with ``b`` are collected (a literal ``0`` is skipped).\n",
    "\n",
    "    Args:\n",
    "        file: Path of the CNF file without the ``pos_``/``neg_`` prefix.\n",
    "        del_sols: Optional mapping with keys ``'pos'`` and ``'neg'``; each value is\n",
    "            an iterable of solutions (iterables of int literals). For every\n",
    "            solution the negated literals are appended to the working copy.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys ``'pos'`` and ``'neg'`` mapping to lists of int literals.\n",
    "    \"\"\"\n",
    "    import subprocess  # local import so this cell does not rely on another cell's imports\n",
    "\n",
    "    cadiback = '/home/XXXX/XXXX/fs_backup_feb13/LLM-project/cadiback/cadiback'\n",
    "    bb = {'pos': [], 'neg': []}\n",
    "\n",
    "    parts = file.split('/')\n",
    "    src_dir = '/'.join(parts[:-1])\n",
    "    tmp_dir = '/'.join(parts[:-2]) + '/tempfiles/'\n",
    "\n",
    "    # Drive everything from the explicit kind. The old code tested `'pos' in path`,\n",
    "    # so a directory name containing 'pos' sent the neg_ file down the pos branch.\n",
    "    for kind in ('pos', 'neg'):\n",
    "        fname = kind + '_' + parts[-1]\n",
    "        tmp_cnf = tmp_dir + fname\n",
    "        shutil.copy(src_dir + '/' + fname, tmp_cnf)\n",
    "\n",
    "        if del_sols is not None:\n",
    "            for sol in del_sols[kind]:\n",
    "                # add_clause is defined outside this cell; presumably it bumps the\n",
    "                # clause count in the DIMACS header - TODO confirm.\n",
    "                add_clause(tmp_cnf)\n",
    "                # The clause is written without a terminating 0, exactly as before.\n",
    "                with open(tmp_cnf, 'a') as cf:\n",
    "                    cf.write('\\n' + ''.join(str(-lit) + ' ' for lit in sol))\n",
    "\n",
    "        bbone_path = tmp_dir + fname[:-4] + '.bbone'\n",
    "        # An argument list plus an explicit stdout file replaces the shell string, so\n",
    "        # paths containing spaces or shell metacharacters are passed through intact.\n",
    "        with open(bbone_path, 'w') as bbone_out:\n",
    "            subprocess.run(['timeout', '5000', cadiback, tmp_cnf], stdout=bbone_out)\n",
    "\n",
    "        with open(bbone_path, 'r') as bbone:\n",
    "            for line in bbone:\n",
    "                if not line.startswith('b'):\n",
    "                    continue\n",
    "                for tok in line.split(' ')[1:]:\n",
    "                    tok = tok.strip()\n",
    "                    # Skip the terminating 0 and empty fields from repeated spaces.\n",
    "                    if tok in ('', '0'):\n",
    "                        continue\n",
    "                    bb[kind].append(int(tok))\n",
    "\n",
    "    return bb\n",
    "\n",
    "\n",
    "import csv\n",
    "import json\n",
    "\n",
    "# CNF file names to exclude from evaluation (both lists are currently empty).\n",
    "noisy_data = []\n",
    "mistr_data = []\n",
    "\n",
    "dimacs_dir = '/home/XXXX/XXXX/fs_backup_feb13/LLM-project/dimacs_clutrr/'\n",
    "dataset = '/home/XXXX/XXXX/fs_backup_feb13/SAT-LM/data/clutrr_test.json'\n",
    "with open(dataset, 'r') as df:\n",
    "    data = json.loads(df.read())\n",
    "\n",
    "task = 'clutrr'\n",
    "missed = False\n",
    "all_outs = {}\n",
    "missed_list = []\n",
    "preds = {}\n",
    "\n",
    "# names: CNF files whose row in solver_finished.csv has 'SAT' in columns 2 and 3.\n",
    "names = []\n",
    "with open('/home/XXXX/XXXX/fs_backup_feb13/LLM-project/dimacs_clutrr_csvs/solver_finished.csv', 'r') as c:\n",
    "    for row in csv.reader(c):\n",
    "        if row[2] != 'SAT' or row[3] != 'SAT':\n",
    "            continue\n",
    "        if row[1] in noisy_data or row[1] in mistr_data:\n",
    "            continue\n",
    "        # The last field of the neg_ file's first line is parsed as the clause count.\n",
    "        # It is not used at the moment; kept so a size filter can be re-enabled.\n",
    "        with open(dimacs_dir + 'neg_' + row[1]) as cnf_file:\n",
    "            cnf = cnf_file.readline().strip('\\n')\n",
    "        num_clause = int(cnf.split(' ')[-1])\n",
    "        names.append(row[1])\n",
    "\n",
    "# labels: lower-cased second column of clutrr_labels.csv, keyed by CNF file name.\n",
    "# (An earlier pass filled `labels` from `data`, but the dict was reset right\n",
    "# afterwards, so that dead store has been dropped.)\n",
    "labels = {}\n",
    "known_names = set(names)  # O(1) membership instead of scanning the list per row\n",
    "with open('/home/XXXX/XXXX/fs_backup_feb13/LLM-project/clutrr_labels.csv', 'r') as c:\n",
    "    for row in csv.reader(c):\n",
    "        cnf_name = row[0][:-2] + 'cnf'\n",
    "        if not os.path.exists(dimacs_dir + 'neg_' + cnf_name):\n",
    "            continue\n",
    "        # The exclusion lists used to be compared against row[1], which is the label\n",
    "        # column in this CSV; `names` is already filtered, so this test is sufficient.\n",
    "        if cnf_name not in known_names:\n",
    "            continue\n",
    "        labels[cnf_name] = row[1].lower()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(names), len(labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cot[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trim surrounding spaces from every label in place. Assigning to existing keys\n",
    "# while iterating is safe because the size of the dict never changes.\n",
    "for key, value in labels.items():\n",
    "    labels[key] = value.strip(' ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int(row[1].split('clutrr')[1].split('.cnf')[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the currently loaded `cot` predictions against the labels, in `names` order.\n",
    "cot_acc = 0\n",
    "cot_preds = {}\n",
    "for i, name in enumerate(names):\n",
    "    # Results are keyed by `name`. The previous version wrote to `cot_preds[key]`,\n",
    "    # where `key` was a leftover from another cell, so every example overwrote\n",
    "    # the same single entry.\n",
    "    if cot[i] == labels[name].strip(' '):\n",
    "        cot_acc += 1\n",
    "        cot_preds[name] = True\n",
    "    else:\n",
    "        cot_preds[name] = False\n",
    "print(cot_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "326/(len(names)-500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Few-shot prompt with four worked examples, each terminated by '***'.\n",
    "# Fixed relative to the previous text: 'Facts:n' (missing backslash), the misplaced\n",
    "# newline in step 3 of the third example, and the typos '[Dales]', 'sister if', 'ot true'.\n",
    "few_shot = \"Facts:\\n[Nancy] likes to cut the hair of her daughter [Heidi].\\n[Heidi]'s sister [Lorraine] went to beauty school and taught them all how to cut hair expertly. \" + \\\n",
    "            \"\\nHere are some additional facts and rules we\\'ve found:\\nNancy is the mother of Lorraine\\n If Heidi is the sister of Lorraine and Heidi is the daughter of Nancy then Nancy is the mother of Lorraine.\\n\" + \\\n",
    "            \"Question: Is the following statement true: \\n\\\"[Lorraine] is [Nancy]\\'s daughter\\\"\\nAnswer: Let\\'s think step by step. \\n1. We have already found that Nancy is the mother of Lorraine.\\n2. If Nancy is the mother of Lorraine, then Lorraine is the daughter of Nancy.\\nTherefore, the answer to the question is Yes, the statement is true. \\n***\\n\" + \\\n",
    "            \"Facts:\\n[Dale] and his sister [Nancy] are decorating for a party.\\n[Nancy]'s daughter [Louise] thinks the party will be fun.\\n\" + \\\n",
    "            \"Here are some additional facts and rules we\\'ve found:\\nDale is the uncle of Louise. If Nancy is the sister of Dale and Nancy is the mother of Louise then Dale is the uncle of Louise.\\n\" + \\\n",
    "            \"Question: Is the following statement true: \\n\\\"[Louise] is not [Dale]\\'s niece\\\"\\n\" + \\\n",
    "            \"Answer: Let\\'s think step by step. 1. We are given that Dale is the uncle of Louise.\\n2.If Dale is the uncle of Louise, then Louise is the niece of Dale.\\nTherefore, the answer is No, the statement is not true.\\n***\\n\" + \\\n",
    "            \"Facts: \\n[Lillian] and her sister [Nancy] are the only children in their family. \\n[Lillian]'s biggest accomplishment is raising her son [Douglas]. \" + \\\n",
    "            \"\\nHere are some additional facts and rules we\\'ve found:\\nLillian is the sister of Nancy. \\nIf Nancy is the sister of Lillian then Lillian is the sister of Nancy.\\n\" + \\\n",
    "            \"Question: Is the following statement true: \\n\\\"[Douglas] is [Nancy]\\'s nephew\\\"\\nAnswer: Let\\'s think step by step. \\n1. [Douglas] is [Lillian]\\'s son. \\n2. [Nancy] is [Lillian]\\'s sister. \" + \\\n",
    "            \"\\n3. [Douglas] is [Nancy]\\'s nephew. \\nTherefore, the answer to the question is Yes, the statement is true. \\n***\\n\" + \\\n",
    "            \"Facts: \\n[Ashley] liked to go to the park with her granddaughter [Charlotte]. \\n[Dale], [Charlotte]'s father, like to take her to the movies instead. \" + \\\n",
    "            \"\\nHere are some additional facts and rules we\\'ve found:\\nDale is the son of Ashley. If Dale is father of Charlotte and Ashley is the grandmother of Charlotte then Dale is the son of Ashley.\\n\" + \\\n",
    "            \"Question: Is the following statement true: \\n\\\"[Ashley] is not [Dale]\\'s mother\\\"\\nAnswer: Let\\'s think step by step. \\n1. We are given that Dale is the son of Ashley. \\n2. If Dale is the son of Ashley, then Ashley is the mother of Dale. \" + \\\n",
    "            \"\\nTherefore, the answer to the question is No, the statement is not true.\\n***\\n\"\n",
    "\n",
    "# Sanity check: with four '***'-terminated examples, piece 4 of the split is the\n",
    "# text that follows the prompt (here a dummy completion).\n",
    "ans = few_shot + 'a;sldkfj;alskdjf***'\n",
    "print(ans.split('***')[4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = 'grandson_of_james\\'_sibling_James_Donald_'\n",
    "split = a.split('_')\n",
    "# Join every field except the last one (empty, because of the trailing '_') with '-'.\n",
    "# str.join replaces the manual concatenate-then-trim loop, which also overwrote `a`.\n",
    "rel_str = '-'.join(split[:-1])\n",
    "print(rel_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "import json\n",
    "\n",
    "dataset = '/home/XXXX/XXXX/fs_backup_feb13/SAT-LM/data/clutrr_test.json'\n",
    "with open(dataset, 'r') as df:\n",
    "    data = json.loads(df.read())\n",
    "\n",
    "# Prefix of the per-run prediction files; the run index is appended to it.\n",
    "cot_iter = '/home/XXXX/XXXX/fs_backup_feb13/LLM-project/preds/FewShotCOTCLUTRR_Fri_Jan_31_19.25.43_2025_iter'\n",
    "# Other prediction sets that have been evaluated with this cell:\n",
    "# cot_iter = '/home/XXXX/XXXX/fs_backup_feb13/LLM-project/preds/LOT_pronto_preds_70BTue_Apr_22_10.55.31_2025_iter'\n",
    "# cot_iter=  '/home/XXXX/XXXX/fs_backup_feb13/LLM-project/preds/mistral_FewShotCOTCLUTRR_Tue_Apr_29_12.52.15_2025_iter'\n",
    "# cot_iter = '/home/XXXX/XXXX/fs_backup_feb13/LLM-project/preds/mistral_FewShotCOTPRONTO_Wed_Apr_30_11.26.49_2025_iter'\n",
    "N_ITERS = 20  # number of runs to load: files cot_iter + '0' .. cot_iter + '19'\n",
    "\n",
    "# Not used by the active code path; kept for the alternative file list below.\n",
    "pfx = '/home/XXXX/'\n",
    "files = ['XXXX/fs_backup_feb13/LLM-project/preds/LOT_clutrr_8B_preds_iter6_[0]', 'XXXX/fs_backup_feb13/LLM-project/preds/LOT_clutrr_8B_preds_iter8_[4]', 'XXXX/fs_backup_feb13/LLM-project/preds/LOT_clutrr_8B_preds_iter0_[1]', 'XXXX/fs_backup_feb13/LLM-project/preds/LOT_clutrr_8B_preds_iter5_[0]', 'XXXX/fs_backup_feb13/LLM-project/preds/LOT_clutrr_8B_preds_iter1_[1]'] \n",
    "\n",
    "cot_pred = []       # per run: {CNF name: True/False}\n",
    "cot_pred_list = []  # per run: 1/0 per example, in `names` order\n",
    "cot_accs = []       # per run: number of correct predictions\n",
    "for i in range(N_ITERS):\n",
    "    # Close each prediction file as soon as it has been read.\n",
    "    with open(cot_iter + str(i), 'rb') as cot_file:\n",
    "        cot = np.load(cot_file)\n",
    "\n",
    "    cot_acc = 0\n",
    "    cot_preds = {}\n",
    "    cot_preds_list = []\n",
    "    # NOTE(review): predictions are matched to `names` by position j, not by example\n",
    "    # id - confirm the prediction files were produced in exactly this order.\n",
    "    for j, name in enumerate(names):\n",
    "        if cot[j] == labels[name].strip(' '):\n",
    "            cot_acc += 1\n",
    "            cot_preds[name] = True\n",
    "            cot_preds_list.append(1)\n",
    "        else:\n",
    "            cot_preds[name] = False\n",
    "            cot_preds_list.append(0)\n",
    "    print(cot_acc)\n",
    "    cot_accs.append(cot_acc)\n",
    "    # Shallow copies are enough: the containers only hold bools and ints.\n",
    "    cot_pred.append(dict(cot_preds))\n",
    "    cot_pred_list.append(list(cot_preds_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "cot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels[name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Self-consistency over the loaded runs: an example counts as correct when a\n",
    "# strict majority of the runs answered it correctly.\n",
    "n_runs = len(cot_pred_list)\n",
    "\n",
    "# n_votes[i]: number of runs whose answer for example i was correct.\n",
    "n_votes = [sum(run[i] for run in cot_pred_list) for i in range(len(cot_pred_list[0]))]\n",
    "\n",
    "sc_pred = {}\n",
    "sc_acc = 0\n",
    "for key in cot_pred[0]:\n",
    "    tmp = sum(run[key] for run in cot_pred)\n",
    "    # Strict majority. The previous test `tmp >= n/2 + 1` demanded one vote too many\n",
    "    # for an odd number of runs and disagreed with the np.ceil(n/2 + 0.5) threshold\n",
    "    # used when the accuracy is reported; `tmp > n/2` matches that threshold.\n",
    "    if tmp > n_runs / 2:\n",
    "        sc_pred[key] = 1\n",
    "        sc_acc += 1\n",
    "    else:\n",
    "        sc_pred[key] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(cot_pred_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_votes[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.ceil(1.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(cot_pred_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.unique(n_votes, return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Votes required for a strict majority of the runs.\n",
    "majority = np.ceil(len(cot_pred_list) / 2 + 0.5)\n",
    "n_majority_correct = np.sum(np.where(np.array(n_votes) >= majority, 1, 0))\n",
    "print(n_majority_correct)\n",
    "\n",
    "# NOTE(review): both rates divide by len(cot), the most recently loaded prediction\n",
    "# array, rather than by the number of scored examples - confirm the two are equal.\n",
    "print('sc acc:', n_majority_correct / len(cot))\n",
    "print('cot acc:', np.mean(cot_accs) / len(cot))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "n_votes\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.utils import resample\n",
    "from scipy.stats import wilcoxon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(outs_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# TODO: this path was truncated (its closing quote was missing, which made the cell a\n",
    "# syntax error); fill in the full path of the pickle before running.\n",
    "temp_outs_str = '/home/XXXX/XXXX/'\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.\n",
    "with open(temp_outs_str, 'rb') as outs_file:\n",
    "    temp_outs = pkl.load(outs_file)\n",
    "# `temp_outs` was previously read without ever being assigned in this notebook.\n",
    "outs = temp_outs\n",
    "\n",
    "N_BOOT = 86  # size of every bootstrap resample (was a bare literal)\n",
    "\n",
    "outs_pred = {}\n",
    "outs_acc = 0\n",
    "num_trues = 0\n",
    "for key, value in outs.items():\n",
    "    # Strip once so all three comparisons see the same label text; the\n",
    "    # `num_trues` test used to compare the unstripped value.\n",
    "    label = labels[key].strip(' ')\n",
    "    if len(value[1]['neg']) == 0 and label == 'false':\n",
    "        outs_pred[key] = True\n",
    "        outs_acc += 1\n",
    "    elif len(value[1]['pos']) == 0 and label == 'true':\n",
    "        outs_pred[key] = True\n",
    "        outs_acc += 1\n",
    "    else:\n",
    "        outs_pred[key] = False\n",
    "    if label == 'true':\n",
    "        num_trues += 1\n",
    "outs_acc /= len(outs_pred.keys())\n",
    "outs_pred_val = np.array(list(outs_pred.values()))\n",
    "\n",
    "# Bootstrap accuracy of the solver-based predictions.\n",
    "bs_outs_acc = []\n",
    "for i in range(len(outs_pred)):\n",
    "    bs_outs_acc.append(np.sum(resample(outs_pred_val, n_samples=N_BOOT))/N_BOOT)\n",
    "\n",
    "# Bootstrap accuracy of the majority vote.\n",
    "# NOTE(review): n_votes[:len(outs)] assumes the first len(outs) vote counts belong to\n",
    "# the examples in `outs` - confirm the alignment.\n",
    "majority = np.ceil(len(cot_pred_list)/2+0.5)\n",
    "bs_sc = []\n",
    "bs_sc_acc = []\n",
    "for i in range(len(outs)):\n",
    "    bs_sc.append(resample(n_votes[:len(outs)], n_samples=N_BOOT))\n",
    "    bs_sc_acc.append(np.sum(np.where(np.array(bs_sc[-1]) >= majority, 1, 0) )/len(bs_sc[-1]))\n",
    "\n",
    "print(outs_acc)\n",
    "print(outs_acc*len(outs_pred.keys()))\n",
    "print(len(outs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(np.mean(bs_sc_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Spot-check one example: look at its entry in `outs` and compare it with its label.\n",
    "# The first two lines are bare expressions whose values are not displayed, and\n",
    "# `labels[key]` runs before `key` is reassigned, so it uses whatever `key` held before.\n",
    "outs['clutrr123.cnf'][1]\n",
    "labels[key]\n",
    "key = 'clutrr123.cnf'\n",
    "value = outs[key]\n",
    "if len(value[1]['neg']) == 0 and labels[key].strip(' ') == 'false':\n",
    "        print('yippie')\n",
    "print(value[1]['pos'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(wilcoxon(np.array(bs_outs_acc) - np.array(bs_sc_acc), alternative='greater'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "d = np.random.binomial( 900, 0.5, 86)/900"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from scipy import stats\n",
    "\n",
    "# Two-sided Student-t confidence interval for the mean of `d`.\n",
    "confidence_level = 0.95\n",
    "n_obs = len(d)\n",
    "sample_mean = np.mean(d)\n",
    "std_err = np.std(d, ddof=1) / np.sqrt(n_obs)  # standard error of the mean\n",
    "ci = stats.t.interval(confidence_level, df=n_obs - 1, loc=sample_mean, scale=std_err)\n",
    "print(ci)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "np.mean(bs_sc_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    " \n",
    "\n",
    "# One bin per possible vote count (0 .. number of runs), centred on the integers.\n",
    "# The previous fixed edges [0, 1, 2, 3, 4, 5] ignored every count above 5.\n",
    "vote_bins = np.arange(len(cot_pred_list) + 2) - 0.5\n",
    "plt.hist(n_votes, bins=vote_bins)\n",
    "plt.xlabel('number of runs with a correct answer')\n",
    "plt.ylabel('number of examples');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Few-shot prompt (bracketed-entity variant) with four worked examples, each\n",
    "# terminated by '***'. Fixed relative to the previous text: the third example\n",
    "# concluded 'nephew of [Lillian]' instead of '[Nancy]', plus the typos\n",
    "# 'Le's', '[Dales]' and 'ot true'.\n",
    "few_shot = \"Facts:\\n[Nancy] likes to cut the hair of her daughter [Heidi].\\n[Heidi]'s sister [Lorraine] went to beauty school and taught them all how to cut hair expertly. \" + \\\n",
    "            \"\\n\\nHere are some additional facts we\\'ve found:\\n[Nancy] is the mother of [Lorraine]\\n\" + \\\n",
    "            \"Question: Is the following statement true: \\n\\\"[Lorraine] is [Nancy]\\'s daughter\\\"\\n\" + \\\n",
    "            \"Answer:\\nLet\\'s think step by step.  \\n1. [Heidi] is the sister of [Lorraine]\\n2. [Heidi] is the daughter of [Nancy]\\n3. If [Heidi] is the sister of [Lorraine] and [Heidi] is the daughter of [Nancy] then [Nancy] is the mother of [Lorraine].\\n4. If [Nancy] is the mother of [Lorraine], then [Lorraine] is the daughter of [Nancy].\\nTherefore, the answer to the question is Yes, the statement is true. \\n***\\n\" + \\\n",
    "            \"Facts:\\n[Dale] and his sister [Nancy] are decorating for a party.\\n[Nancy]'s daughter [Louise] thinks the party will be fun.\\n\" + \\\n",
    "            \"\\nHere are some additional facts we\\'ve found:\\n[Dale] is the uncle of [Louise]\\n\" + \\\n",
    "            \"Question: Is the following statement true: \\n\\\"[Louise] is not [Dale]\\'s niece\\\"\\n\" + \\\n",
    "            \"Answer: Let\\'s think step by step. \\n1. [Nancy] is the sister of [Dale]. \\n2. [Nancy] is the mother of [Louise]\\n3. If [Nancy] is the sister of [Dale] and [Nancy] is the mother of [Louise] then [Dale] is the uncle of [Louise].\\n4.If [Dale] is the uncle of [Louise], then [Louise] is the niece of [Dale].\\nTherefore, the answer is No, the statement is not true.\\n***\\n\" + \\\n",
    "            \"Facts: \\n[Lillian] and her sister [Nancy] are the only children in their family. \\n[Lillian]'s biggest accomplishment is raising her son [Douglas]. \" + \\\n",
    "            \"\\n\\nHere are some additional facts we\\'ve found:\\n[Lillian] is the sister of [Nancy]\\n\" + \\\n",
    "            \"Question: Is the following statement true: \\n\\\"[Douglas] is [Nancy]\\'s nephew\\\"\\n\" + \\\n",
    "            \"Answer:\\nLet\\'s think step by step. \\n1. [Douglas] is [Lillian]\\'s son. \\n2. [Nancy] is [Lillian]\\'s sister. \" + \\\n",
    "            \"\\n3. If [Douglas] is the son of [Lillian] and [Lillian] is the sister of [Nancy] then [Douglas] is the nephew of [Nancy]. \\nTherefore, the answer to the question is Yes, the statement is true. \\n***\\n\" + \\\n",
    "            \"Facts: \\n[Ashley] liked to go to the park with her granddaughter [Charlotte]. \\n[Dale], [Charlotte]'s father, like to take her to the movies instead. \" + \\\n",
    "            \"\\n\\nHere are some additional facts we\\'ve found:\\n[Dale] is the son of [Ashley].\\n\" + \\\n",
    "            \"Question: Is the following statement true: \\n\\\"[Ashley] is not [Dale]\\'s mother\\\"\\n\" + \\\n",
    "            \"Answer:\\nLet\\'s think step by step. \\n1. [Dale] is the father of [Charlotte].\\n2. [Ashley] is the grandmother of [Charlotte]. \\n3. If [Dale] is father of [Charlotte] and [Ashley] is the grandmother of [Charlotte] then [Dale] is the son of [Ashley].\\n4. If [Dale] is the son of [Ashley], then [Ashley] is the mother of [Dale]. \" + \\\n",
    "            \"\\nTherefore, the answer to the question is No, the statement is not true.\\n***\\n\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "outs_str = '/home/XXXX/XXXX/\n",
    "outs = pkl.load(open(outs_str, 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "len(outs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for key, value in outs.items():\n",
    "#     print(value[5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(outs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for key, value in labels.items():\n",
    "    labels[key] = value.strip(' ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "outs_pred = {}\n",
    "outs_acc = 0\n",
    "num_trues = 0\n",
    "true_pos = 0\n",
    "false_pos = 0\n",
    "true_neg = 0\n",
    "false_neg = 0\n",
    "n_false = 0\n",
    "n_true = 0\n",
    "for key, value in outs.items():\n",
    "    if value[4] == True:\n",
    "        if len(value[1]['neg']) == 0 and labels[key] == 'false':\n",
    "            outs_pred[key] = True\n",
    "            outs_acc += 1\n",
    "            true_neg += 1\n",
    "        elif len(value[1]['pos']) == 0 and labels[key] == 'true':\n",
    "            outs_pred[key] = True\n",
    "            outs_acc += 1\n",
    "            true_pos += 1\n",
    "        else:\n",
    "            outs_pred[key] = False\n",
    "            if labels[key] == 'true':\n",
    "                false_neg += 1\n",
    "            else:\n",
    "                false_pos += 1\n",
    "    else:\n",
    "        if len(value[1]['neg']) == 0 and labels[key] == 'true':\n",
    "            outs_pred[key] = True\n",
    "            outs_acc += 1\n",
    "            true_neg += 1\n",
    "        elif len(value[1]['pos']) == 0 and labels[key] == 'false':\n",
    "            outs_pred[key] = True\n",
    "            outs_acc += 1\n",
    "            true_pos += 1\n",
    "        else:\n",
    "            outs_pred[key] = False\n",
    "            if labels[key] == 'true':\n",
    "                false_neg += 1\n",
    "            else:\n",
    "                false_pos += 1\n",
    "    if labels[key] == 'true':\n",
    "        num_trues += 1\n",
    "    \n",
    "    if labels[key] == 'false':\n",
    "       n_false += 1\n",
    "    elif labels[key] == 'true':\n",
    "        n_true += 1\n",
    "outs_acc /= len(outs_pred.keys())\n",
    "# outs['clutrr545.cnf'][1]\n",
    "print(outs_acc)\n",
    "print(outs_acc*len(outs_pred.keys()))\n",
    "print(n_true, n_false)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_false = 0\n",
    "n_true = 0\n",
    "for key, value in labels.items():\n",
    "    if labels[key] == 'false':\n",
    "       n_false += 1\n",
    "    elif labels[key] == 'true':\n",
    "        n_true += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "name_idx = {}\n",
    "i = 0\n",
    "for name in names:\n",
    "    name_idx[name] = i\n",
    "    i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key, value in labels.items():\n",
    "    labels[key] = value.strip(' ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "labels[key]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "temp_outs = pkl.load(open(temp_outs_str, 'rb'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "len(temp_outs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# /home/XXXX/XXXX/fs_backup_feb13/all_outs_thresh09_rulethresh04_contexthresh05_dynamicTure.pkl\n",
    "missed = pkl.load(open('//home/XXXX/XXXX/fs_backup_feb13//missed_list_' + outs_str, 'rb'))\n",
    "hunh_list = pkl.load(open('/home/XXXX/XXXX/fs_backup_feb13/hunh_' + outs_str, 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "temp_outs_str = '/home/XXXX/XXXX/'\n",
    "temp_outs = pkl.load(open(temp_outs_str, 'rb'))\n",
    "import torch\n",
    "temp_outs_pred = {}\n",
    "outs_acc = 0\n",
    "num_trues = 0\n",
    "true_pos = 0\n",
    "false_pos = 0\n",
    "true_neg = 0\n",
    "false_neg = 0\n",
    "n_false = 0\n",
    "n_true = 0\n",
    "for key, value in temp_outs.items():\n",
    "    if value[4] == True:\n",
    "        if len(value[1]['neg']) == 0 and labels[key] == 'false':\n",
    "            temp_outs_pred[key] = True\n",
    "            outs_acc += 1\n",
    "            true_neg += 1\n",
    "        elif len(value[1]['pos']) == 0 and labels[key] == 'true':\n",
    "            temp_outs_pred[key] = True\n",
    "            outs_acc += 1\n",
    "            true_pos += 1\n",
    "        else:\n",
    "            temp_outs_pred[key] = False\n",
    "            if labels[key] == 'true':\n",
    "                false_neg += 1\n",
    "            else:\n",
    "                false_pos += 1\n",
    "    else:\n",
    "        if len(value[1]['neg']) == 0 and labels[key] == 'true':\n",
    "            temp_outs_pred[key] = True\n",
    "            outs_acc += 1\n",
    "            true_neg += 1\n",
    "        elif len(value[1]['pos']) == 0 and labels[key] == 'false':\n",
    "            temp_outs_pred[key] = True\n",
    "            outs_acc += 1\n",
    "            true_pos += 1\n",
    "        else:\n",
    "            temp_outs_pred[key] = False\n",
    "            if labels[key] == 'true':\n",
    "                false_neg += 1\n",
    "            else:\n",
    "                false_pos += 1\n",
    "    if labels[key] == 'true':\n",
    "        num_trues += 1\n",
    "    \n",
    "    if labels[key] == 'false':\n",
    "       n_false += 1\n",
    "    elif labels[key] == 'true':\n",
    "        n_true += 1\n",
    "outs_acc /= len(temp_outs_pred.keys())\n",
    "# outs['clutrr545.cnf'][1]\n",
    "print(outs_acc)\n",
    "print(outs_acc*len(temp_outs_pred.keys()))\n",
    "print(n_true, n_false)\n",
    "\n",
    "scs_70 = []\n",
    "flag_70 = []\n",
    "lens_70 = []\n",
    "fgf_70 = []\n",
    "fbf_70 = []\n",
    "colors = ['r', 'g']\n",
    "skips = 0\n",
    "for key in list(temp_outs.keys()):\n",
    "    mat = torch.stack(temp_outs[key][5]) / torch.stack(temp_outs[key][5]).sum(1).reshape(-1,1)\n",
    "    mat = mat[:,0]\n",
    "    if labels[key] == 'true':\n",
    "        mat = torch.concatenate((torch.tensor([n_votes[name_idx[key]]/20]), mat))\n",
    "    else:\n",
    "        mat = torch.concatenate((torch.tensor([1-(n_votes[name_idx[key]]/20)]), mat))\n",
    "    lens_70.append(len(mat)-1)\n",
    "    if labels[key] == 'true':\n",
    "        if mat[0] < 0.5 and mat[-1] > 0.5:\n",
    "            flag_70.append(2)\n",
    "            for z in range(len(mat)):\n",
    "                if mat[z] > 0.5:\n",
    "                    fgf_70.append(z)\n",
    "                    break\n",
    "        elif mat[0] > 0.5 and mat[-1] < 0.5:\n",
    "            flag_70.append(3)\n",
    "            for z in range(len(mat)):\n",
    "                if mat[z] < 0.5:\n",
    "                    fbf_70.append(z)\n",
    "                    break\n",
    "        else: flag_70.append(temp_outs_pred[key])\n",
    "    else:\n",
    "        if mat[0] > 0.5 and mat[-1] < 0.5:\n",
    "            flag_70.append(2)\n",
    "            for z in range(len(mat)):\n",
    "                if mat[z] < 0.5:\n",
    "                    fgf_70.append(z)\n",
    "                    break\n",
    "        elif mat[0] < 0.5 and mat[-1] > 0.5:\n",
    "            flag_70.append(3)\n",
    "            for z in range(len(mat)):\n",
    "                if mat[z] > 0.5:\n",
    "                    fbf_70.append(z)\n",
    "                    break\n",
    "        else: flag_70.append(temp_outs_pred[key])\n",
    "    # try: flag_70.append(temp_outs_pred[key])\n",
    "    # except: \n",
    "    #     skips += 1\n",
    "    #     continue\n",
    "    scs_70.append(mat.clone())\n",
    "\n",
    "\n",
    "temp_outs_pred = {}\n",
    "outs_acc = 0\n",
    "num_trues = 0\n",
    "true_pos = 0\n",
    "false_pos = 0\n",
    "true_neg = 0\n",
    "false_neg = 0\n",
    "n_false = 0\n",
    "n_true = 0\n",
    "for key, value in temp_outs.items():\n",
    "    if value[4] == True:\n",
    "        if len(value[1]['neg']) == 0 and labels[key] == 'false':\n",
    "            temp_outs_pred[key] = True\n",
    "            outs_acc += 1\n",
    "            true_neg += 1\n",
    "        elif len(value[1]['pos']) == 0 and labels[key] == 'true':\n",
    "            temp_outs_pred[key] = True\n",
    "            outs_acc += 1\n",
    "            true_pos += 1\n",
    "        else:\n",
    "            temp_outs_pred[key] = False\n",
    "            if labels[key] == 'true':\n",
    "                false_neg += 1\n",
    "            else:\n",
    "                false_pos += 1\n",
    "    else:\n",
    "        if len(value[1]['neg']) == 0 and labels[key] == 'true':\n",
    "            temp_outs_pred[key] = True\n",
    "            outs_acc += 1\n",
    "            true_neg += 1\n",
    "        elif len(value[1]['pos']) == 0 and labels[key] == 'false':\n",
    "            temp_outs_pred[key] = True\n",
    "            outs_acc += 1\n",
    "            true_pos += 1\n",
    "        else:\n",
    "            temp_outs_pred[key] = False\n",
    "            if labels[key] == 'true':\n",
    "                false_neg += 1\n",
    "            else:\n",
    "                false_pos += 1\n",
    "    if labels[key] == 'true':\n",
    "        num_trues += 1\n",
    "    \n",
    "    if labels[key] == 'false':\n",
    "       n_false += 1\n",
    "    elif labels[key] == 'true':\n",
    "        n_true += 1\n",
    "outs_acc /= len(temp_outs_pred.keys())\n",
    "# outs['clutrr545.cnf'][1]\n",
    "print(outs_acc)\n",
    "print(outs_acc*len(temp_outs_pred.keys()))\n",
    "print(n_true, n_false)\n",
    "\n",
    "scs_8 = []\n",
    "flag_8 = []\n",
    "lens_8 = []\n",
    "first_good_flip_8= []\n",
    "first_bad_flip_8 = []\n",
    "colors = ['r', 'g', 'b', 'orange']\n",
    "plot_labels = ['unflipped-wrong', 'unflipped-correct', 'flipped correct', 'flipped incorrect']\n",
    "skips = 0\n",
    "for key in list(temp_outs.keys())[:100]:\n",
    "    mat = torch.stack(temp_outs[key][5]) / torch.stack(temp_outs[key][5]).sum(1).reshape(-1,1)\n",
    "    mat = mat[:,0]\n",
    "    if labels[key] == 'true':\n",
    "        mat = torch.concatenate((torch.tensor([n_votes[name_idx[key]]/20]), mat))\n",
    "    else:\n",
    "        mat = torch.concatenate((torch.tensor([1-(n_votes[name_idx[key]]/20)]), mat))\n",
    "    lens_8.append(len(mat)-1)\n",
    "    if labels[key] == 'true':\n",
    "        if mat[0] < 0.5 and mat[-1] > 0.5:\n",
    "            flag_8.append(2)\n",
    "            for z in range(len(mat)):\n",
    "                if mat[z] > 0.5:\n",
    "                    first_good_flip_8.append(z)\n",
    "                    break\n",
    "        elif mat[0] > 0.5 and mat[-1] < 0.5:\n",
    "            flag_8.append(3)\n",
    "            for z in range(len(mat)):\n",
    "                if mat[z] < 0.5:\n",
    "                    first_bad_flip_8.append(z)\n",
    "                    break\n",
    "        else: flag_8.append(temp_outs_pred[key])\n",
    "    else:\n",
    "        if mat[0] > 0.5 and mat[-1] < 0.5:\n",
    "            flag_8.append(2)\n",
    "            for z in range(len(mat)):\n",
    "                if mat[z] < 0.5:\n",
    "                    first_good_flip_8.append(z)\n",
    "                    break\n",
    "        elif mat[0] < 0.5 and mat[-1] > 0.5:\n",
    "            flag_8.append(3)\n",
    "            for z in range(len(mat)):\n",
    "                if mat[z] > 0.5:\n",
    "                    first_bad_flip_8.append(z)\n",
    "                    break\n",
    "        else: flag_8.append(temp_outs_pred[key])\n",
    "    \n",
    "    scs_8.append(mat.clone())\n",
    "    \n",
    "import matplotlib.patches as mpatches\n",
    " \n",
    "\n",
    "fig1, ax1 = plt.subplots()\n",
    "for i in range(len(lens_8)):\n",
    "    for j in range(lens_8[i]):\n",
    "        lens_8.append(j)\n",
    "for i in range(len(scs_8)):\n",
    "    ax1.plot(scs_8[i], c=colors[int(flag_8[i])],label=plot_labels[int(flag_8[i])])\n",
    "line = [1]\n",
    "for i in range(0,6):\n",
    "    line.append(1-i*0.1)\n",
    "line2 = [0]\n",
    "for i in range(0,6):\n",
    "    line2.append(i*0.1)\n",
    "ax1.plot(line,'--', c='black')\n",
    "ax1.plot(line2,'--', c='black')\n",
    "\n",
    "ax1.set_title('8B \\\"True\\\" classification confidence over iterations (n=' + str(len(scs_8))+')')\n",
    "# fig5, ax5 = plt.subplots()\n",
    "\n",
    "ax1.hist(lens_8, density=True, alpha=0.4, bins=range(8))\n",
    "# flag_8.append(2)\n",
    "\n",
    "# un, flag8_counts = np.unique(flag_8, return_counts=True)\n",
    "# patches = []\n",
    "# for i in range(len(colors)):\n",
    "#     patches.append(mpatches.Patch(color=colors[i], label=plot_labels[i] + ' (n=' + str(flag8_counts[i]) + ')'))\n",
    "# ax1.legend(handles=patches)\n",
    "# # ax1.set_xticks(list(range(0,11)))\n",
    "# # ax1.set_xticklabels(list(range(1,12)))\n",
    "\n",
    "\n",
    "for i in range(len(lens_70)):\n",
    "    for j in range(lens_70[i]):\n",
    "        lens_70.append(j)\n",
    "fig2, ax2 = plt.subplots()\n",
    "for i in range(len(scs_70)):\n",
    "    ax2.plot(scs_70[i], c=colors[int(flag_70[i])], label=plot_labels[int(flag_70[i])])\n",
    "ax2.set_title('70B \\\"True\\\" classification confidence over iterations (n=' + str(len(scs_70))+')')\n",
    "# flag_70.append(3)\n",
    "un, flag70_counts = np.unique(flag_70, return_counts=True)\n",
    "patches = []\n",
    "for i in range(len(colors)):\n",
    "    patches.append(mpatches.Patch(color=colors[i], label=plot_labels[i] + ' (n=' + str(flag70_counts[i]) + ')'))\n",
    "ax2.legend(handles=patches)\n",
    "ax2.hist(lens_70, density=True, alpha=0.4, bins=range(8))\n",
    "ax2.plot(line,'--', c='black')\n",
    "ax2.plot(line2,'--', c='black')\n",
    "# ax2.set_xlim(0,6) \n",
    "\n",
    "# ax2.set_xticks(list(range(0,11)))\n",
    "# ax2.set_xticklabels(list(range(1,12)))\n",
    "\n",
    "fig3, ax3 = plt.subplots()\n",
    "ax3.hist(first_good_flip_8, label='First Good Flip Iteration')\n",
    "ax3.set_title(\"8B - Iteration of Good Flip\")\n",
    "fig4, ax4 = plt.subplots()\n",
    "ax4.hist(first_bad_flip_8, label='First Bad Flip Iteration')\n",
    "ax4.set_title(\"8B - Iteration of Bad Flip\")\n",
    "\n",
    "fig5, ax5 = plt.subplots()\n",
    "ax5.hist(fgf_70, label='First Good Flip Iteration')\n",
    "ax5.set_title(\"70B - Iteration of Good Flip\")\n",
    "fig6, ax6 = plt.subplots()\n",
    "ax6.hist(fbf_70, label='First Bad Flip Iteration')\n",
    "ax6.set_title(\"70B - Iteration of Bad Flip\")\n",
    "# ax3.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "flag70_counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "x = []\n",
    "fx = [[],[],[],[]]\n",
    "fy = [[],[],[],[]]\n",
    "y = []\n",
    "for j in range(len(scs_70)):\n",
    "    s = scs_70[j]\n",
    "    for i in range(len(s)):\n",
    "        x.append(i)\n",
    "        y.append(s[i])\n",
    "        fy[flag_70[j]].append(s[i])\n",
    "        fx[flag_70[j]].append(i)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "y_a = np.array(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "y_a.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "hist, xedges, yedges = np.histogram2d(x, y, bins=[[2,3,4,5,6,7],[0,0.2,0.4,0.6,0.8]])\n",
    "\n",
    "\n",
    "#  range=[np.arange(0, 8, 1), np.arange(0, 1, 0.2)]\n",
    "\n",
    "\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "\n",
    "\n",
    "# Construct arrays for the anchor positions of the 16 bars.\n",
    "xpos, ypos = np.meshgrid(xedges[:-1], yedges[:-1], indexing=\"ij\")\n",
    "xpos = xpos.ravel()\n",
    "ypos = ypos.ravel()\n",
    "zpos = 0\n",
    "\n",
    "# Construct arrays with the dimensions for the 16 bars.\n",
    "dx = dy = 0.5 * np.ones_like(zpos)\n",
    "dz = hist.ravel()\n",
    "\n",
    "ax.bar3d(xpos, ypos, zpos, 0.5, 0.2, dz, zsort='average')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from matplotlib.colors import LightSource\n",
    "fhist = []\n",
    "for i in range(4):\n",
    "    hist, xedges, yedges = np.histogram2d(fx[i], fy[i], bins=[[2,3,4,5,6,7],[0,0.2,0.4,0.6,0.8,1]])\n",
    "    fhist.append(hist)\n",
    "    # xedges.append\n",
    "    \n",
    "# xedges *= 100\n",
    "colors = ['r', 'g', 'b', 'orange']\n",
    "\n",
    "#  range=[np.arange(0, 8, 1), np.arange(0, 1, 0.2)]\n",
    "\n",
    "\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "ax.view_init(elev=40, azim=320, roll=0)\n",
    "\n",
    "plot_labels = ['unflipped-wrong', 'unflipped-correct', 'flipped correct', 'flipped incorrect']\n",
    "\n",
    "# Construct arrays for the anchor positions of the 16 bars.\n",
    "xpos, ypos = np.meshgrid(xedges[:-1], yedges[:-1], indexing=\"ij\")\n",
    "xpos = xpos.ravel()\n",
    "ypos = ypos.ravel()*100\n",
    "zpos = 0\n",
    "\n",
    "# Construct arrays with the dimensions for the 16 bars.\n",
    "dx = dy = 0.5 * np.ones_like(zpos)\n",
    "dz = fhist[0].ravel()\n",
    "\n",
    "# ax.bar3d(xpos, ypos, zpos, 0.5,0.5, dz, zsort='max', color = colors[0])\n",
    "cumhist = np.zeros_like(dz)\n",
    "for j in range(len(xpos)):\n",
    "    x = xpos[j]\n",
    "    y = ypos[j]\n",
    "    cumhist=0\n",
    "    # print(x,y)\n",
    "    for i in [3,0,2,1]:\n",
    "\n",
    "        # fig = plt.figure()\n",
    "        # ax = fig.add_subplot(projection='3d')\n",
    "        dz = fhist[i].ravel()[j]\n",
    "        # dz = fhist[i].ravel()\n",
    "        ax.bar3d(x, y, cumhist, 0.5, 10,dz ,zorder=0, color=colors[i], lightsource=LightSource(azdeg=190))\n",
    "        cumhist += dz\n",
    "\n",
    "ax.set_xlabel('Iteration Number')\n",
    "ax.set_ylabel('confidence')\n",
    "ax.set_title('Histogram of Confidences as ARGOS Iterates')\n",
    "ax.set_yticklabels([str(i) + '%' for i in [-100, -60, -20, 20, 60, 100]])\n",
    "ax.set_ylim(0, 101)\n",
    "patches = []\n",
    "for i in range(len(colors)):\n",
    "    patches.append(mpatches.Patch(color=colors[i], label=plot_labels[i] + ' (n=' + str(flag70_counts[i]) + ')'))\n",
    "ax.legend(handles=patches,bbox_to_anchor=(1.3,1))\n",
    "# ax.legend(handles=patches,bbox_to_anchor=(0.7,-0.1))\n",
    "fig.savefig('./threedhist.pdf', bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "lvsf = [[],[],[],[]]\n",
    "totals = []\n",
    "rebins = [0, 0.2, 0.4, 0.6, 0.8, 1]\n",
    "for i in range(len(scs_70)):\n",
    "    b = np.digitize(scs_70[i][0], bins)\n",
    "    # print(b)\n",
    "    lvsf[flag_70[i]].append(len(scs_70[i])-1)\n",
    "    totals.append(len(scs_70[i])-1)\n",
    "    \n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(lvsf, stacked=True, color=cs)\n",
    "patches = []\n",
    "# ax.legend()\n",
    "for i in range(len(cs)):\n",
    "    patches.append(mpatches.Patch(color=cs[i], label=plot_labels[i] + ' (n=' + str(flag70_counts[i]) + ')'))\n",
    "ax.legend(handles=patches,bbox_to_anchor=(1.3,1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Example data for 3 groups\n",
    "np.random.seed(0)\n",
    "# data1 = np.random.normal(0, 1, 1000)\n",
    "# data2 = np.random.normal(1, 1, 1000)\n",
    "# data3 = np.random.normal(2, 1, 1000)\n",
    "data0 = lvsf[0]\n",
    "data1 = lvsf[1]\n",
    "data2 = lvsf[2]\n",
    "data3 = lvsf[3]\n",
    "# Define bin edges\n",
    "# bins = np.linspace(-4, 6, 15)\n",
    "bins = np.array([1,2,3,4,5,6,7])\n",
    "\n",
    "# Get histogram counts for each dataset\n",
    "counts1, _ = np.histogram(data1, bins=bins)\n",
    "counts2, _ = np.histogram(data2, bins=bins)\n",
    "counts3, _ = np.histogram(data3, bins=bins)\n",
    "counts0, _ = np.histogram(data0, bins=bins)\n",
    "\n",
    "# Stack the counts\n",
    "counts = np.vstack([counts0, counts1, counts2, counts3])\n",
    "\n",
    "# Normalize so each column (bin) sums to 1\n",
    "normalized_counts = counts / counts.sum(axis=0, keepdims=True)\n",
    "\n",
    "# Handle divide-by-zero (empty bins)\n",
    "normalized_counts = np.nan_to_num(normalized_counts)\n",
    "\n",
    "# Plot the stacked histogram with normalized heights\n",
    "bin_centers = 0.5 * (bins[:-1] + bins[1:])\n",
    "width = np.diff(bins)\n",
    "\n",
    "# Plot\n",
    "fig, ax = plt.subplots()\n",
    "bottom = np.zeros_like(bin_centers)\n",
    "\n",
    "# colors = ['blue', 'orange', 'green']\n",
    "# labels = ['Data1', 'Data2', 'Data3']\n",
    "\n",
    "for i in range(normalized_counts.shape[0]):\n",
    "    ax.bar(bin_centers, normalized_counts[i], width=width, bottom=bottom,\n",
    "           color=cs[i], label=plot_labels[i], edgecolor='black')\n",
    "    bottom += normalized_counts[i]\n",
    "\n",
    "ax.set_ylabel('Proportion')\n",
    "ax.set_xlabel('# ARGOS Iterations Before Exit')\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "# ax.spines['bottom'].set_visible(False)\n",
    "ax.spines['left'].set_visible(False)\n",
    "ax.set_yticks([])\n",
    "ax.set_xticks([1,2,3,4,5,6])\n",
    "ax.set_xticklabels([1,2,3,4,5,6])\n",
    "____, totalcounts = np.unique(totals, return_counts=True)\n",
    "\n",
    "ax_top = fig.add_axes([0.125, 0.85, 0.775, 0.2])  # [left, bottom, width, height]\n",
    "ax_top.hist(totals, bins=[1,2,3,4,5,6,7], alpha=0.6)\n",
    "# ax_top.set_xlabel('Iteration Count')\n",
    "ax_top.spines['top'].set_visible(False)\n",
    "ax_top.spines['right'].set_visible(False)\n",
    "ax_top.spines['bottom'].set_visible(False)\n",
    "ax_top.spines['left'].set_visible(False)\n",
    "ax_top.set_yticks([])\n",
    "ax_top.set_xticks([])\n",
    "ax_top.set_ylabel('Total count')\n",
    "for c in range(len(totalcounts)):\n",
    "    ax_top.text(x = 1.2 + 0.05 + 1*c, y = totalcounts[c] + 100, s = str(totalcounts[c]))\n",
    "# ax.set_title('Normalized Stacked Histogram')\n",
    "# ax.legend()\n",
    "patches = []\n",
    "for i in range(len(cs)):\n",
    "    patches.append(mpatches.Patch(color=cs[i], label=plot_labels[i] + ' (n=' + str(flag70_counts[i]) + ')'))\n",
    "ax_top.legend(handles=patches,bbox_to_anchor=(0.3,0.5))\n",
    "# plt.show()\n",
    "fig.savefig('./lenhist.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "____, totalcounts = np.unique(totals, return_counts=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "totalcounts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Example data\n",
    "np.random.seed(0)\n",
    "data1 = np.random.normal(0, 1, 1000)\n",
    "data2 = np.random.normal(1, 1, 1000)\n",
    "data3 = np.random.normal(2, 1, 1000)\n",
    "\n",
    "# Bin setup\n",
    "bins = np.linspace(-4, 6, 15)\n",
    "bin_centers = 0.5 * (bins[:-1] + bins[1:])\n",
    "width = np.diff(bins)\n",
    "\n",
    "# Hist counts\n",
    "counts1, _ = np.histogram(data1, bins=bins)\n",
    "counts2, _ = np.histogram(data2, bins=bins)\n",
    "counts3, _ = np.histogram(data3, bins=bins)\n",
    "\n",
    "# Stack counts\n",
    "counts = np.vstack([counts1, counts2, counts3])\n",
    "total_counts = counts.sum(axis=0)\n",
    "\n",
    "# Normalize\n",
    "normalized_counts = counts / counts.sum(axis=0, keepdims=True)\n",
    "normalized_counts = np.nan_to_num(normalized_counts)\n",
    "\n",
    "# Create main and top axes\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "divider_height = 0.25\n",
    "ax_top = fig.add_axes([0.125, 0.75, 0.775, 0.2], sharex=ax)  # [left, bottom, width, height]\n",
    "\n",
    "# --- Top histogram (total counts) ---\n",
    "ax_top.bar(bin_centers, total_counts, width=width, color='lightgray', edgecolor='black')\n",
    "# ax_top.set_ylabel(\"Total count\")\n",
    "ax_top.spines['bottom'].set_visible(False)\n",
    "ax_top.tick_params(labelbottom=False)  # Hide x-ticks for top plot\n",
    "ax_top.set_yticks([])\n",
    "ax_top.spines['top'].set_visible(False)\n",
    "ax_top.spines['right'].set_visible(False)\n",
    "ax_top.spines['bottom'].set_visible(False)\n",
    "ax_top.spines['left'].set_visible(False)\n",
    "\n",
    "# --- Bottom normalized stacked bar plot ---\n",
    "bottom = np.zeros_like(bin_centers)\n",
    "colors = ['blue', 'orange', 'green']\n",
    "labels = ['Data1', 'Data2', 'Data3']\n",
    "\n",
    "for i in range(normalized_counts.shape[0]):\n",
    "    ax.bar(bin_centers, normalized_counts[i], width=width, bottom=bottom,\n",
    "           color=colors[i], edgecolor='black', label=labels[i])\n",
    "    bottom += normalized_counts[i]\n",
    "\n",
    "ax.set_ylabel('Proportion (normalized)')\n",
    "ax.set_xlabel('Value')\n",
    "ax.set_ylim(0, 1.05)\n",
    "ax.legend(loc='upper right')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "np.array(lvsf[4]).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "np.digitize(scs_70[0], bins)\n",
    "print(flag_70[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "bins = [0.01, 0.25, 0.5, 0.75, 1.1]\n",
    "rebins = [0, 0.2, 0.4,0.6, 0.8, 1]\n",
    "bins = np.array(rebins)+0.01\n",
    "fvsc = [[],[],[],[]]\n",
    "for i in range(len(flag_70)):\n",
    "    fvsc[flag_70[i]].append(rebins[np.digitize(scs_70[i][0], bins)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fvsc[1][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from matplotlib import colors\n",
    "print(colors.ListedColormap('Accent').colors)\n",
    "cs = ['r', 'g', 'b', 'orange']\n",
    "plot_labels = ['unflipped-wrong', 'unflipped-correct', 'flipped correct', 'flipped incorrect']\n",
    "import matplotlib as mpl\n",
    "\n",
    "n_lines = 8\n",
    "cmap = mpl.colormaps['Accent']\n",
    "\n",
    "# Take colors at regular intervals spanning the colormap.\n",
    "cs = cmap(np.linspace(0, 1, n_lines))\n",
    "cs = [cs[5], cs[0], cs[1], cs[2]]\n",
    "print(cs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.hist(fvsc, stacked=True, bins=rebins, color=cs)\n",
    "patches = []\n",
    "# ax.legend()\n",
    "for i in range(len(cs)):\n",
    "    patches.append(mpatches.Patch(color=cs[i], label=plot_labels[i] + ' (n=' + str(flag70_counts[i]) + ')'))\n",
    "ax.legend(bbox_to_anchor=(0.15, 1.33), loc=2, borderaxespad=0, handles=patches)\n",
    "ax.set_xticks(rebins)\n",
    "# ax.set_xticklabels(['-100%', '-50%', '0%', '50%', '100%'])\n",
    "ax.set_yticks([])\n",
    "# ax.set_yticklabel([])\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "ax.spines['bottom'].set_visible(False)\n",
    "ax.spines['left'].set_visible(False)\n",
    "ax.set_xlabel('Initial Solvability (Confidence)')\n",
    "fig.savefig('./2dhist.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "padded_scs_70 = []\n",
    "\n",
    "from matplotlib import colors\n",
    "for i in range(len(scs_70))[100:150]:\n",
    "    # padded_scs_70.append(list(scs_70[i]))\n",
    "    for z in range(2):\n",
    "        padded_scs_70.append([])\n",
    "        for j in range(len(scs_70[i])):\n",
    "            padded_scs_70[-1] += [(scs_70[i][j])]*10\n",
    "        for j in range(len(scs_70[i]), 15):\n",
    "            # padded_scs_70[-1] += [torch.tensor(-1000000000000000000)]*100\n",
    "\n",
    "            padded_scs_70[-1] += [torch.tensor(-1)]*10\n",
    "\n",
    "\n",
    "        \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "padded_scs_70_sorted = sorted(padded_scs_70, key= lambda x: x[0])\n",
    "final = []\n",
    "for s in padded_scs_70_sorted:\n",
    "    for z in range(len(s)):\n",
    "        if s[z] == -1:\n",
    "            final.append(s[z-1])\n",
    "            break\n",
    "final = np.array(final)\n",
    "final_as = final.argsort()\n",
    "padded_scs_70_sorted = [padded_scs_70_sorted[i] for i in final_as]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import operator\n",
    "\n",
    "scs_70_sorted = sorted(scs_70[100:150], key=operator.itemgetter(0, -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Questions 100-149 ordered by their first score, then expanded into image rows\n",
    "# (two rows per question, 10 entries per score, -1 fill up to 15 scores).\n",
    "scs_70_sorted = sorted(scs_70[100:150], key=lambda scores: scores[0])\n",
    "padded_scs_70_sorted = []\n",
    "for question_scores in scs_70_sorted:\n",
    "    for copy_idx in range(2):\n",
    "        row = []\n",
    "        for score in question_scores:\n",
    "            row += [score] * 10\n",
    "        for pad_idx in range(len(question_scores), 15):\n",
    "            row += [torch.tensor(-1)] * 10\n",
    "        padded_scs_70_sorted.append(row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "bins = [-0.01, 0, 0.2, 0.4, 0.6, 0.8, 1, 1.02]\n",
    "# Despite its name this list is NOT sorted here; it is the plain 100-149 slice.\n",
    "scs_70_sorted = scs_70[100:150]\n",
    "\n",
    "# np.digitize gives, per score, the index k with bins[k-1] <= score < bins[k].\n",
    "d_scs_s = [np.digitize(question_scores, bins) for question_scores in scs_70_sorted]\n",
    "\n",
    "# Replace every score by the lower edge of the bin it fell into.\n",
    "r_scs_s = [[bins[bin_idx - 1] for bin_idx in bin_indices] for bin_indices in d_scs_s]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from matplotlib import colors\n",
    "\n",
    "# Order the binned score lists by their first, second and last entry.\n",
    "scs_70_sorted = sorted(r_scs_s, key=operator.itemgetter(0,1, -1))\n",
    "\n",
    "# Expand each list into two identical image rows: 10 entries per score, then\n",
    "# -1 fill for every score missing up to a length of 15.\n",
    "padded_scs_70_sorted = []\n",
    "for binned_scores in scs_70_sorted:\n",
    "    for copy_idx in range(2):\n",
    "        row = []\n",
    "        for level in binned_scores:\n",
    "            row.extend([level] * 10)\n",
    "        for pad_idx in range(len(binned_scores), 15):\n",
    "            row.extend([torch.tensor(-1)] * 10)\n",
    "        padded_scs_70_sorted.append(row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "np.stack(padded_scs_70_sorted).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Heat-map of padded_scs_70_sorted on a discrete colour scale, saved to ./cbarplot.pdf.\n",
    "# NOTE: another cell of this notebook saves a different figure under the same file name.\n",
    "from matplotlib import colors\n",
    "# bounds = [-1, 0.49,0.59, 0.69, 0.79, 0.89, 0.99, 1.09]\n",
    "# Seven boundaries give six colour intervals; values in [-1, -0.01) (the -1 fill) fall in the first.\n",
    "bounds = [-1, -0.01, 0.19, 0.39, 0.59, 0.79, 1.01]\n",
    "norm = colors.BoundaryNorm(bounds, len(bounds))\n",
    "cmap = 'Paired'\n",
    "figs, axs = plt.subplots(1,1)\n",
    "im = axs.imshow(padded_scs_70_sorted, cmap=cmap, norm=norm)\n",
    "# im = axs.imshow([[0.1]*10]*10, cmap=cmap, norm=norm)\n",
    "\n",
    "# NOTE(review): tick labels assume 10 pixel columns per score and 2 pixel rows per\n",
    "# question (row 70 -> label 35), matching how the padded rows are built - confirm.\n",
    "axs.set_xticks([0,10, 20, 30, 40, 50, 60])\n",
    "axs.set_yticks([0, 20, 40, 60, 70, 80, 100])\n",
    "axs.set_xlim(0, 70)\n",
    "axs.set_xticklabels(['SC',1, 2,3, 4, 5,6])\n",
    "axs.set_xlabel('Number of Iterations')\n",
    "axs.set_ylabel('Question ID')\n",
    "axs.set_yticklabels([0, 10, 20, 30, 35, 40,50])\n",
    "# axs.set_\n",
    "# axs.set_xlim(0, 35)\n",
    "# cbar = figs.colorbar(im, ax=axs, ticks=[-0.25, 0.55, 0.65, 0.75, 0.85, 0.95, 1.05, 1.15])\n",
    "cbar = figs.colorbar(im, ax=axs)\n",
    "# cbar.ax.set_yticks([-0.25, 0.55, 0.65, 0.75, 0.85, 0.95, 1.05])\n",
    "# cbar.ax.set_yticks(np.array([-1, 0, 0.2, 0.4, 0.6, 0.8, 1.0]))\n",
    "# cbar.set_ylim(-1, 1)\n",
    "\n",
    "# cbar.ax.set_yticklabels(['terminated'] + [str(i) + '%' for i in list(((np.array(bounds[1:-1])+0.01)*100).astype(int))])\n",
    "\n",
    "# No colorbar tick positions are set in this cell, so these seven labels are applied to\n",
    "# whatever ticks the colorbar picked by default - NOTE(review): confirm they line up.\n",
    "cbar.ax.set_yticklabels(['terminated'] + [str(i) + '%' for i in [-100, -60, -20, 20, 60, 100]] )\n",
    "\n",
    "# cbar.ax.set_ylabel('Confidence (%)')\n",
    "# grad = []\n",
    "# for b in bounds:\n",
    "#     grad+= [([b+0.01]*10)]*10\n",
    "# # grad = [gr\n",
    "# # fig1, ax1 = plt.subplots()\n",
    "# axs[1].imshow(grad, cmap=cmap, norm=norm, origin='lower')\n",
    "\n",
    "# axs[1].set_yticklabels(['terminated'] + list(np.array(bounds[1:-1])+0.01))\n",
    "# axs[1].set_yticks([5, 15, 25, 35, 45, 55, 65, 75])\n",
    "# # ax1.set_yticklabels(['terminated'\n",
    "# axs[1].set_xticks([])\n",
    "# axs[1].set_ylim(0,65)\n",
    "figs.savefig('./cbarplot.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# (stray '[' removed: an unterminated list literal is a SyntaxError and stops Run All)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from matplotlib import colors\n",
    "\n",
    "# Rebuild padded_scs_70 in the original (unsorted) order of questions 100-149.\n",
    "# Layout per question: two identical rows; each score fills 10 consecutive\n",
    "# entries; missing scores up to a length of 15 are filled with -1.\n",
    "padded_scs_70 = []\n",
    "for question_scores in scs_70[100:150]:\n",
    "    n_scores = len(question_scores)\n",
    "    for copy_idx in range(2):\n",
    "        row = []\n",
    "        for score in question_scores:\n",
    "            row += [score] * 10\n",
    "        for pad_idx in range(n_scores, 15):\n",
    "            row += [torch.tensor(-1)] * 10\n",
    "        padded_scs_70.append(row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Heat-map of padded_scs_70 (unsorted question order) on a discrete colour scale.\n",
    "# NOTE: the sorted variant of this figure is saved under the same ./cbarplot.pdf name,\n",
    "# so whichever of the two cells runs last determines the file content.\n",
    "from matplotlib import colors\n",
    "\n",
    "# [-1, -0.01) holds the -1 fill ('terminated'); the other intervals are 0.2 wide.\n",
    "bounds = [-1, -0.01, 0.19, 0.39, 0.59, 0.79, 0.99, 1.19]\n",
    "norm = colors.BoundaryNorm(bounds, len(bounds))\n",
    "cmap = 'Paired'\n",
    "figs, axs = plt.subplots(1,1)\n",
    "im = axs.imshow(padded_scs_70, cmap=cmap, norm=norm)\n",
    "\n",
    "# NOTE(review): tick labels assume 10 pixel columns per score and 2 pixel rows per\n",
    "# question (row 68 -> label 34), matching how the padded rows are built - confirm.\n",
    "axs.set_xticks([0,10, 20, 30, 40, 50, 60])\n",
    "axs.set_yticks([0, 20, 40, 60, 68, 80, 100])\n",
    "axs.set_xlim(0, 70)\n",
    "axs.set_xticklabels(['SC',1, 2,3, 4, 5,6])\n",
    "axs.set_xlabel('Number of Iterations')\n",
    "axs.set_ylabel('Question ID')\n",
    "axs.set_yticklabels([0, 10, 20, 30, 34, 40,50])\n",
    "\n",
    "cbar = figs.colorbar(im, ax=axs)\n",
    "# One tick in the middle of each colour interval.\n",
    "cbar.ax.set_yticks([-0.5, 0.1, 0.3, 0.5, 0.7, 0.9, 1.1])\n",
    "# Only one label assignment is kept: an earlier set_yticklabels call in this cell\n",
    "# was overwritten by this one before the figure was drawn.\n",
    "cbar.ax.set_yticklabels(['terminated'] + [str(i) + '%' for i in [-100, -60, -20, 20, 60, 100]])\n",
    "\n",
    "figs.savefig('./cbarplot.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "temp_outs[list(temp_outs.keys())[147]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(list(temp_outs.keys())[147])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "data[int(list(temp_outs.keys())[102].split('clutrr')[1].split(\".\")[0])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(list(temp_outs.keys())[118])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "data[int(list(temp_outs.keys())[119].split('clutrr')[1].split(\".\")[0])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(list(temp_outs.keys())[119])\n",
    "print(temp_outs[list(temp_outs.keys())[119]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "list(temp_outs.keys())[119]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "scs_70[119]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Stacked 3-D bar chart: the (fx[i], fy[i]) pairs of four groups are binned on a\n",
    "# common 5 x 5 grid and the four counts of every grid cell are stacked on top of\n",
    "# each other. The figure is saved to ./threedhist_simple.pdf.\n",
    "from matplotlib.colors import LightSource\n",
    "fhist = []\n",
    "for i in range(4):\n",
    "    # Six edges per axis -> five bins per axis; xedges/yedges are identical for all i.\n",
    "    hist, xedges, yedges = np.histogram2d(fx[i], fy[i], bins=[[2,3,4,5,6,7],[0,0.2,0.4,0.6,0.8,1]])\n",
    "    fhist.append(hist)\n",
    "    # xedges.append\n",
    "\n",
    "\n",
    "#  range=[np.arange(0, 8, 1), np.arange(0, 1, 0.2)]\n",
    "\n",
    "# One colour per group index: groups 0 and 3 are red, groups 1 and 2 are green.\n",
    "# NOTE: this rebinds the name `colors`, which other cells import as matplotlib.colors.\n",
    "colors = ['r', 'g', 'g', 'r']\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "ax.view_init(elev=40, azim=320, roll=0)\n",
    "\n",
    "\n",
    "# Anchor positions of the bars: one per histogram cell (5 x 5 = 25 with the edges above).\n",
    "xpos, ypos = np.meshgrid(xedges[:-1], yedges[:-1], indexing=\"ij\")\n",
    "xpos = xpos.ravel()\n",
    "ypos = ypos.ravel()\n",
    "zpos = 0\n",
    "\n",
    "# dx, dy and this dz are not passed to the bar3d call below, which uses the literal\n",
    "# footprint 0.5 x 0.1 and recomputes dz for every bar.\n",
    "dx = dy = 0.5 * np.ones_like(zpos)\n",
    "dz = fhist[0].ravel()\n",
    "\n",
    "# ax.bar3d(xpos, ypos, zpos, 0.5,0.5, dz, zsort='max', color = colors[0])\n",
    "# This array is replaced by the scalar reset `cumhist=0` at the top of every loop pass.\n",
    "cumhist = np.zeros_like(dz)\n",
    "for j in range(len(xpos)):\n",
    "    x = xpos[j]\n",
    "    y = ypos[j]\n",
    "    # Running height of the stack for the current grid cell.\n",
    "    cumhist=0\n",
    "    # print(x,y)\n",
    "    # Stack order 3, 0, 2, 1 puts the two red groups below the two green groups.\n",
    "    for i in [3,0,2,1]:\n",
    "\n",
    "        # fig = plt.figure()\n",
    "        # ax = fig.add_subplot(projection='3d')\n",
    "        dz = fhist[i].ravel()[j]\n",
    "        # dz = fhist[i].ravel()\n",
    "        ax.bar3d(x, y, cumhist, 0.5, 0.1,dz ,zorder=0, color=colors[i], lightsource=LightSource(azdeg=180))\n",
    "        cumhist += dz\n",
    "\n",
    "ax.set_xlabel('Iteration Number')\n",
    "ax.set_ylabel('Confidence')\n",
    "ax.set_title('Histogram of Confidences as ARGOS Iterates')\n",
    "# Legend: red = incorrect (flag70_counts[0] + [3]), green = correct (flag70_counts[1] + [2]).\n",
    "patches = []\n",
    "colors = ['r', 'g']\n",
    "plot_labels = ['incorrect', 'correct']\n",
    "simple_flag70_counts = [flag70_counts[0]+ flag70_counts[3], flag70_counts[1] + flag70_counts[2]]\n",
    "for i in range(2):\n",
    "    # `mpatches` is not imported in this cell; presumably matplotlib.patches imported elsewhere - TODO confirm.\n",
    "    patches.append(mpatches.Patch(color=colors[i], label=plot_labels[i] + ' (n=' + str(simple_flag70_counts[i]) + ')'))\n",
    "ax.legend(handles=patches,bbox_to_anchor=(1.2,1))\n",
    "# ax.legend(handles=patches,bbox_to_anchor=(0.7,-0.1))\n",
    "fig.savefig('./threedhist_simple.pdf')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "xpos, ypos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "len(ypos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "colors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Demo: three normal samples drawn as a stacked 3-D histogram.\n",
    "# A seeded generator keeps the figure reproducible across kernel restarts.\n",
    "rng = np.random.default_rng(0)\n",
    "data1 = rng.normal(0, 1, 100)\n",
    "data2 = rng.normal(2, 1, 100)\n",
    "data3 = rng.normal(4, 1, 100)\n",
    "\n",
    "# Define bins\n",
    "bins = np.linspace(-5, 10, 20)\n",
    "\n",
    "# Create the histogram data\n",
    "hist_data1, _ = np.histogram(data1, bins=bins)\n",
    "hist_data2, _ = np.histogram(data2, bins=bins)\n",
    "hist_data3, _ = np.histogram(data3, bins=bins)\n",
    "\n",
    "# Set up the figure and axes for 3D plotting\n",
    "fig = plt.figure()\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "\n",
    "# Create the x coordinates for the bars\n",
    "x = np.arange(len(hist_data1))\n",
    "\n",
    "# Footprint of every bar: `width` along x, `depth` along y\n",
    "width = 0.2\n",
    "depth = 0.5\n",
    "\n",
    "# bar3d takes (x, y, z, dx, dy, dz): the counts belong in dz (bar height on the\n",
    "# 'Frequency' axis) and every layer starts at z = the summed counts below it.\n",
    "# Previously the counts were passed as dy, which drew them as bar depth instead.\n",
    "base = np.zeros(len(hist_data1))\n",
    "ax.bar3d(x, base, base, width, depth, hist_data1, shade=True, label='Data 1')\n",
    "ax.bar3d(x + width, base, hist_data1, width, depth, hist_data2, shade=True, label='Data 2')\n",
    "ax.bar3d(x + 2 * width, base, hist_data1 + hist_data2, width, depth, hist_data3, shade=True, label='Data 3')\n",
    "\n",
    "# Set labels and title\n",
    "ax.set_xlabel('Bins')\n",
    "ax.set_ylabel('')\n",
    "ax.set_zlabel('Frequency')\n",
    "ax.set_xticks(x + width, bins[:-1], rotation=45, ha='right')\n",
    "ax.set_title('Stacked 3D Histogram')\n",
    "\n",
    "# Add legend\n",
    "# ax.legend()\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Wireframe of the most recent 2-D histogram `hist`.\n",
    "fig = plt.figure()\n",
    "# Derive the grid from hist.shape so X, Y and hist always share one shape; the\n",
    "# former hand-built 5 x 4 grid did not match a 5 x 5 histogram.\n",
    "# X holds the row index, Y holds 0.2 * column index (same values as before).\n",
    "X, Y = np.meshgrid(np.arange(hist.shape[0]), np.arange(hist.shape[1]) * 0.2, indexing='ij')\n",
    "# X,Y = np.meshgrid(xpos, ypos)\n",
    "ax = plt.axes(projection ='3d')\n",
    "ax.plot_wireframe(X, Y, hist, color ='green')\n",
    "ax.set_title('wireframe geeks for geeks');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the pickled run in `temp_outs` against `labels` and print the confusion counts.\n",
    "temp_outs_str = 'home/XXXX/XXXX/'\n",
    "# NOTE: unpickling can execute arbitrary code; only load files written by this project.\n",
    "with open(temp_outs_str, 'rb') as temp_outs_file:\n",
    "    temp_outs = pkl.load(temp_outs_file)\n",
    "import torch\n",
    "temp_outs_pred = {}\n",
    "outs_acc = 0\n",
    "num_trues = 0\n",
    "true_pos = 0\n",
    "false_pos = 0\n",
    "true_neg = 0\n",
    "false_neg = 0\n",
    "n_false = 0\n",
    "n_true = 0\n",
    "for key, value in temp_outs.items():\n",
    "    # if len(value[5]) <= 3: continue\n",
    "    label = labels[key]\n",
    "    # value[4] decides which of value[1]['pos'] / value[1]['neg'] has to be empty\n",
    "    # for a 'true' label to count as solved; the other one applies to 'false'.\n",
    "    if value[4] == True:\n",
    "        empty_if_true, empty_if_false = 'pos', 'neg'\n",
    "    else:\n",
    "        empty_if_true, empty_if_false = 'neg', 'pos'\n",
    "    correct = (\n",
    "        (label == 'true' and len(value[1][empty_if_true]) == 0)\n",
    "        or (label == 'false' and len(value[1][empty_if_false]) == 0)\n",
    "    )\n",
    "    temp_outs_pred[key] = correct\n",
    "    if correct:\n",
    "        outs_acc += 1\n",
    "        # Counters follow the label in both value[4] cases. Previously the\n",
    "        # value[4] != True case booked correct 'true' labels as true_neg and\n",
    "        # correct 'false' labels as true_pos.\n",
    "        if label == 'true':\n",
    "            true_pos += 1\n",
    "        else:\n",
    "            true_neg += 1\n",
    "    elif label == 'true':\n",
    "        false_neg += 1\n",
    "    else:\n",
    "        false_pos += 1\n",
    "    if label == 'true':\n",
    "        num_trues += 1\n",
    "        n_true += 1\n",
    "    elif label == 'false':\n",
    "        n_false += 1\n",
    "print(true_pos, true_neg, false_pos, false_neg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(temp_outs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int(temp_outs_pred[key])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(temp_outs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.stack(outs[key][5]).sum(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for key in outs.keys():\n",
    "#     print(torch.stack(outs[key][5]))\n",
    "#     print(torch.stack(outs[key][5]).sum(1))\n",
    "#     print((torch.stack(outs[key][5]) / torch.stack(outs[key][5]).sum(1).reshape(-1,1)).max(-1))\n",
    "#     print((torch.stack(outs[key][5]) / torch.stack(outs[key][5]).sum(1).reshape(-1,1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# missed_list = first element of every `missed` entry, followed by all of hunh_list.\n",
    "hunh = []\n",
    "missed_list = [entry[0] for entry in missed]\n",
    "for hunh in hunh_list:\n",
    "    missed_list.append(hunh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for key, value in outs.items():  # disabled: loop header without a body is a SyntaxError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(missed_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy of outs_pred and sc_pred restricted to the examples in missed_list.\n",
    "miss_acc = 0\n",
    "miss_sc_acc = 0\n",
    "no_gains = 0\n",
    "for miss in missed_list:\n",
    "    if outs_pred[miss] == True:\n",
    "        miss_acc += 1\n",
    "    # print(outs_pred[miss])\n",
    "    if sc_pred[miss] == True:\n",
    "        miss_sc_acc += 1\n",
    "    # no_gains counts examples on which both predictions agree. The both-False test\n",
    "    # previously compared sc_pred with itself instead of with outs_pred.\n",
    "    if sc_pred[miss] == False and outs_pred[miss] == False:\n",
    "        no_gains += 1\n",
    "    if sc_pred[miss] == True and outs_pred[miss] == True:\n",
    "        no_gains += 1\n",
    "print(miss_acc/len(missed_list))\n",
    "print(miss_sc_acc/len(missed_list))\n",
    "# print(no_gains/len(missed_list))\n",
    "print(miss_sc_acc - miss_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(missed_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sc_pred vs. outs_pred over missed_list. First letter = sc_pred, second letter =\n",
    "# outs_pred (t = True, f = False), e.g. tf: sc_pred True, outs_pred False.\n",
    "tt, tf, ft, ff = [], [], [], []\n",
    "for miss in missed_list:\n",
    "    if sc_pred[miss] == True:\n",
    "        if outs_pred[miss] == True:\n",
    "            tt.append(miss)\n",
    "        elif outs_pred[miss] == False:\n",
    "            tf.append(miss)\n",
    "    elif sc_pred[miss] == False:\n",
    "        if outs_pred[miss] == True:\n",
    "            ft.append(miss)\n",
    "        elif outs_pred[miss] == False:\n",
    "            ff.append(miss)\n",
    "\n",
    "print(len(tt), len(tf), len(ft), len(ff))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sc_pred vs. outs_pred over every labelled example. First letter = sc_pred,\n",
    "# second letter = outs_pred (t = True, f = False).\n",
    "tt, tf, ft, ff = [], [], [], []\n",
    "for name in labels.keys():\n",
    "    # if name in missed_list:continue\n",
    "    if sc_pred[name] == True:\n",
    "        if outs_pred[name] == True:\n",
    "            tt.append(name)\n",
    "        elif outs_pred[name] == False:\n",
    "            tf.append(name)\n",
    "    elif sc_pred[name] == False:\n",
    "        if outs_pred[name] == True:\n",
    "            ft.append(name)\n",
    "        elif outs_pred[name] == False:\n",
    "            ff.append(name)\n",
    "\n",
    "print(len(tt), len(tf), len(ft), len(ff))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cot_pred[0] vs. outs_pred over the labelled examples that are NOT in missed_list.\n",
    "# First letter = cot_pred[0], second letter = outs_pred (t = True, f = False).\n",
    "tt, tf, ft, ff = [], [], [], []\n",
    "for name in labels.keys():\n",
    "    if name in missed_list:\n",
    "        continue\n",
    "    if cot_pred[0][name] == True:\n",
    "        if outs_pred[name] == True:\n",
    "            tt.append(name)\n",
    "        elif outs_pred[name] == False:\n",
    "            tf.append(name)\n",
    "    elif cot_pred[0][name] == False:\n",
    "        if outs_pred[name] == True:\n",
    "            ft.append(name)\n",
    "        elif outs_pred[name] == False:\n",
    "            ff.append(name)\n",
    "\n",
    "print(len(tt), len(tf), len(ft), len(ff))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cot_pred[0] vs. outs_pred over missed_list. First letter = cot_pred[0],\n",
    "# second letter = outs_pred (t = True, f = False).\n",
    "tt, tf, ft, ff = [], [], [], []\n",
    "for name in missed_list:\n",
    "    # if name in missed_list:continue\n",
    "    if cot_pred[0][name] == True:\n",
    "        if outs_pred[name] == True:\n",
    "            tt.append(name)\n",
    "        elif outs_pred[name] == False:\n",
    "            tf.append(name)\n",
    "    elif cot_pred[0][name] == False:\n",
    "        if outs_pred[name] == True:\n",
    "            ft.append(name)\n",
    "        elif outs_pred[name] == False:\n",
    "            ff.append(name)\n",
    "\n",
    "print(len(tt), len(tf), len(ft), len(ff))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(missed_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stored scores; the context manager closes the file handle again.\n",
    "# NOTE: unpickling can execute arbitrary code; only load files written by this project.\n",
    "with open('/home/XXXX/XXXX/fs_backup_feb13/LLM-project/scores_temp1_thresh075_thresh05_dynFalse_fixed.pkl', 'rb') as scores_file:\n",
    "    scores = pkl.load(scores_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print every line stored in outs['clutrr60.cnf'][0]; for selected lines also print\n",
    "# the matching entry of `scores`. `i` is a position counter driven by the modulo\n",
    "# tests below - NOTE(review): it appears to track a 4-line record layout; confirm.\n",
    "i = 1\n",
    "for line in outs['clutrr60.cnf'][0]:\n",
    "    print(line)\n",
    "    if i % 4 == 3:\n",
    "        # try:\n",
    "            # print(line.split('known predicate: ')[1].split('. Known predicates are')[0].replace('___', line.split('\\\\box{ ')[1]))\n",
    "        # Key into `scores`: the text between 'known predicate: ' and '. Known predicates are',\n",
    "        # with '___' replaced by everything that follows the box marker on the same line.\n",
    "        print(scores[line.split('known predicate: ')[1].split('. Known predicates are')[0].replace('___', line.split('\\\\box{ ')[1])])\n",
    "        # except:\n",
    "        #     print(line.split('known predicate: ')[1].split('. Known predicates are')[0].replace('___', line.split('\\\\box{ ')[1]))\n",
    "            # print(line)\n",
    "        #     break\n",
    "    # When the counter is at a multiple of 4 and the line does not start with 'calls',\n",
    "    # the counter is set to 2 instead of being advanced - TODO confirm the intent.\n",
    "    if i%4 == 0 and not str(line).startswith('calls'):\n",
    "        # continue\n",
    "        i = 2\n",
    "        # print('hihi')\n",
    "\n",
    "    else: i += 1\n",
    "# print(outs['clutrr125.cnf'])\n",
    "# print(outs.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(list(scores.keys())[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cot_pred[0] vs. outs_pred over missed_list. First letter = cot_pred[0],\n",
    "# second letter = outs_pred (t = True, f = False).\n",
    "tt, tf, ft, ff = [], [], [], []\n",
    "for miss in missed_list:\n",
    "    if cot_pred[0][miss] == True:\n",
    "        if outs_pred[miss] == True:\n",
    "            tt.append(miss)\n",
    "        elif outs_pred[miss] == False:\n",
    "            tf.append(miss)\n",
    "    elif cot_pred[0][miss] == False:\n",
    "        if outs_pred[miss] == True:\n",
    "            ft.append(miss)\n",
    "        elif outs_pred[miss] == False:\n",
    "            ff.append(miss)\n",
    "\n",
    "print(len(tt), len(tf), len(ft), len(ff))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Stack the tensors of every `scores` entry, then stack the results into one tensor.\n",
    "score_list = [torch.stack(entry_scores) for entry_scores in scores.values()]\n",
    "score = torch.stack(score_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter the first column of `score` against its second column.\n",
    "fig1, ax1 = plt.subplots()\n",
    "first_column = score[:,0]\n",
    "second_column = score[:,1]\n",
    "ax1.scatter(x=first_column, y=second_column, s=3)\n",
    "ax1.set_xlabel('1 - Does the following rule seem contradictory?')\n",
    "ax1.set_ylabel('Does the following rule seem contextually relevant?')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outs_acc*60"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sc_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List every example that outs_pred marks False and that is not in missed_list.\n",
    "for key, value in outs_pred.items():\n",
    "    if key in missed_list or not (value == False):\n",
    "        continue\n",
    "    print(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the fields of the 'clutrr366.cnf' record one by one. The entries of `outs`\n",
    "# are indexed by position elsewhere in this notebook, so range() needs the record's\n",
    "# length; passing the record itself raised TypeError.\n",
    "for i in range(len(outs['clutrr366.cnf'])):\n",
    "    print(outs['clutrr366.cnf'][i])\n",
    "# print(outs['clutrr366.cnf'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print, one per line, every item of the second field of the n-th `missed` entry.\n",
    "n = 7\n",
    "for item in missed[n][1]:\n",
    "    print(item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# outs_pred[key] is True when the run agrees with the label: an empty 'neg' list\n",
    "# for a 'false' label, or an empty 'pos' list for a 'true' label.\n",
    "# outs_acc is the fraction of True entries; num_trues counts the 'true' labels.\n",
    "outs_pred = {}\n",
    "num_trues = 0\n",
    "for key, value in outs.items():\n",
    "    outs_pred[key] = bool(\n",
    "        (len(value[1]['neg']) == 0 and labels[key] == 'false')\n",
    "        or (len(value[1]['pos']) == 0 and labels[key] == 'true')\n",
    "    )\n",
    "    num_trues += 1 if labels[key] == 'true' else 0\n",
    "outs_acc = sum(outs_pred.values()) / len(outs_pred.keys())\n",
    "# outs['clutrr545.cnf'][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outs_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outs_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels[key]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outs.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# len()  # disabled: len() needs exactly one argument and raised TypeError on Run All"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "import subprocess\n",
    "\n",
    "# NOTE: the next cell defines get_bb again; when both cells run, that later\n",
    "# definition replaces this one.\n",
    "def get_bb(file, del_sols=None, cadiback_path='/home/XXXX/XXXX/fs_backup_feb13/LLM-project/cadiback/cadiback', timeout_s=5000):\n",
    "    \"\"\"Run cadiback on the pos_/neg_ variants of a CNF file and collect the reported literals.\n",
    "\n",
    "    For ``file = '<dir>/<name>'`` the files ``<dir>/pos_<name>`` and ``<dir>/neg_<name>``\n",
    "    are copied into ``<parent of dir>/tempfiles/`` (which must already exist) and\n",
    "    cadiback is run on each copy. Its standard output is written next to the copy,\n",
    "    with the last four characters of the file name replaced by ``.bbone``.\n",
    "\n",
    "    Args:\n",
    "        file: '/'-separated path of the base CNF file.\n",
    "        del_sols: optional dict with 'pos' and 'neg' keys, each an iterable of\n",
    "            solutions (iterables of int literals). For every solution,\n",
    "            ``add_clause`` is called on the copy and a line with the negated\n",
    "            literals is appended to it.\n",
    "        cadiback_path: path of the cadiback executable.\n",
    "        timeout_s: limit handed to the ``timeout`` command wrapping cadiback.\n",
    "\n",
    "    Returns:\n",
    "        dict with 'pos' and 'neg' keys, each the list of int literals found on\n",
    "        output lines that start with 'b' (the terminating '0' excluded).\n",
    "    \"\"\"\n",
    "    parts = file.split('/')\n",
    "    src_dir = '/'.join(parts[:-1])\n",
    "    temp_dir = '/'.join(parts[:-2]) + '/tempfiles/'\n",
    "    base_name = parts[-1]\n",
    "\n",
    "    bb = {'pos': [], 'neg': []}\n",
    "    # The polarity is known from the file being built, so it is no longer guessed\n",
    "    # from a substring test on the whole path (which misfired whenever a directory\n",
    "    # name happened to contain 'pos' or 'neg').\n",
    "    for polarity in ('pos', 'neg'):\n",
    "        variant_name = polarity + '_' + base_name\n",
    "        temp_cnf = temp_dir + variant_name\n",
    "        shutil.copy(src_dir + '/' + variant_name, temp_cnf)\n",
    "\n",
    "        if del_sols is not None:\n",
    "            for sol in del_sols[polarity]:\n",
    "                # add_clause is defined elsewhere in the notebook; presumably it\n",
    "                # updates the clause count in the CNF header - TODO confirm.\n",
    "                add_clause(temp_cnf)\n",
    "                # NOTE(review): as before, no terminating '0' is written after the\n",
    "                # negated literals - confirm this is what the solver expects.\n",
    "                with open(temp_cnf, 'a') as cf:\n",
    "                    cf.write('\\n' + ''.join(str(-lit) + ' ' for lit in sol))\n",
    "\n",
    "        bbone_path = temp_dir + variant_name[:-4] + '.bbone'\n",
    "        # Argument list instead of a concatenated shell string: paths containing\n",
    "        # spaces or shell metacharacters are passed through unchanged.\n",
    "        with open(bbone_path, 'w') as bbone_out:\n",
    "            subprocess.run(['timeout', str(timeout_s), cadiback_path, temp_cnf], stdout=bbone_out)\n",
    "\n",
    "        with open(bbone_path, 'r') as bbone:\n",
    "            for line in bbone:\n",
    "                if not line.startswith('b'):\n",
    "                    continue\n",
    "                # split() without argument also drops empty tokens and the newline.\n",
    "                for lit in line.split()[1:]:\n",
    "                    if lit == '0':\n",
    "                        continue\n",
    "                    bb[polarity].append(int(lit))\n",
    "\n",
    "    return bb\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "def get_bb(file, del_sols=None):\n",
    "    bb = {'pos':  [], 'neg': []}\n",
    "    \n",
    "    files = ['/'.join(file.split('/')[:-1]) + '/pos_' + file.split('/')[-1], '/'.join(file.split('/')[:-1]) + '/neg_' + file.split('/')[-1] ]\n",
    "    for i in range(len(files)):\n",
    "        file = files[i]\n",
    "        shutil.copy(file, '/'.join(file.split('/')[:-2]) + '/tempfiles/' + str(file.split('/')[-1]))\n",
    "        if not del_sols==None:\n",
    "            if 'pos' in file:\n",
    "                if 'neg' in file:\n",
    "                    print('l. 416 uh oh')\n",
    "                      \n",
    "                ds = del_sols['pos']\n",
    "            elif 'neg' in file:\n",
    "                ds = del_sols['neg']\n",
    "            for sol in ds:\n",
    "                add_clause('/'.join(file.split('/')[:-2]) + '/tempfiles/' + str(file.split('/')[-1]))\n",
    "                cf = open(f'/'.join(file.split('/')[:-2]) + '/tempfiles/' + str(file.split('/')[-1]), 'a')\n",
    "                write_str = '\\n'\n",
    "                for lit in sol:\n",
    "                    write_str += str(-lit) + ' '\n",
    "                # write_str += '0'\n",
    "                cf.write(write_str)\n",
    "                cf.close()\n",
    "        # print('running cadical')\n",
    "        os.system(\"timeout 5000 /home/XXXX/XXXX/fs_backup_feb13/LLM-project/cadiback/cadiback \" + '/'.join(file.split('/')[:-2]) + '/tempfiles/' + str(file.split('/')[-1]) + '> '  + '/'.join(file.split('/')[:-2]) + '/tempfiles/' + str(file.split('/')[-1])[:-4] + \".bbone\")\n",
    "        #   \n",
    "        bbone= open('/'.join(file.split('/')[:-2]) + '/tempfiles/' + str(file.split('/')[-1])[:-4] + \".bbone\", 'r')\n",
    "        lines = bbone.readlines()\n",
    "        #   \n",
    "        for line in lines:\n",
    "            if line.startswith('b'):\n",
    "                #   \n",
    "                lits = line.split(' ')[1:]\n",
    "                for lit in lits:\n",
    "                    lit = lit.strip()\n",
    "                    if lit == '0':\n",
    "                        continue\n",
    "                    lit = int(lit)\n",
    "                    if 'pos' in file:                                \n",
    "                        if 'neg' in file:\n",
    "                            print('l. 447 uh oh')\n",
    "                              \n",
    "                        bb['pos'].append(lit)\n",
    "                    elif 'neg' in file:\n",
    "                            bb['neg'].append(lit)\n",
    "\n",
    "    return bb\n",
    "\n",
    "c = '/home/XXXX/XXXX/fs_backup_feb13/LLM-project/dimacs_clutrr_csvs_debug/solver_finished.csv'\n",
    "import csv\n",
    "import json\n",
    "dataset = '/home/XXXX/XXXX/fs_backup_feb13/SAT-LM/data/clutrr_test.json'\n",
    "with open(dataset, 'r') as df:\n",
    "    data = json.loads(df.read())\n",
    "\n",
    "task = 'folio'\n",
    "missed=False\n",
    "c = open(c, 'r')\n",
    "cr = csv.reader(c)\n",
    "names = []\n",
    "all_outs = {}\n",
    "missed_list = []\n",
    "labels = {}\n",
    "for row in cr:\n",
    "    if row[2] == 'SAT' and row[3] == 'SAT':\n",
    "        cnf = open('/home/XXXX/XXXX/fs_backup_feb13/LLM-project/dimacs_clutrr/neg_'+row[1]).readlines()[0].strip('\\n')\n",
    "        num_clause = int(cnf.split(' ')[-1])\n",
    "       \n",
    "        if task=='folio':\n",
    "            bb = get_bb('/home/XXXX/XXXX/fs_backup_feb13/LLM-project/dimacs_clutrr/'+row[1])\n",
    "            jb = set(bb['pos']).intersection(set(bb['neg']))\n",
    "            if len(jb) == 0:\n",
    "                continue\n",
    "        # if num_clause > 500:\n",
    "            # continue\n",
    "        names.append(int(row[1].split('clutrr')[1].split('.cnf')[0]))\n",
    "        labels[row[1]] = data[int(row[1].split('clutrr')[1].split('.')[0])]['label']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bad_data = []\n",
    "mistr_data = []\n",
    "noisy_data=[]\n",
    "c = '/home/XXXX/LLM-project/dimacs_clutrr_csvs_debug/solver_finished.csv'\n",
    "import csv\n",
    "import json\n",
    "dataset = '/home/XXXX/SAT-LM/data/clutrr_test.json'\n",
    "with open(dataset, 'r') as df:\n",
    "    data = json.loads(df.read())\n",
    "# breakpoint()\n",
    "task = 'folio'\n",
    "missed=False\n",
    "c = open(c, 'r')\n",
    "cr = csv.reader(c)\n",
    "names = []\n",
    "all_outs = {}\n",
    "missed_list = []\n",
    "labels = {}\n",
    "for row in cr:\n",
    "        if row[2] == 'SAT' and row[3] == 'SAT':\n",
    "            cnf = open('/home/XXXX/LLM-project/dimacs_clutrr/neg_'+row[1]).readlines()[0].strip('\\n')\n",
    "            num_clause = int(cnf.split(' ')[-1])\n",
    "            if row[1] in noisy_data or row[1] in mistr_data:\n",
    "                continue\n",
    "            if task=='folio':\n",
    "                bb = get_bb('/home/XXXX/LLM-project/dimacs_clutrr/'+row[1])\n",
    "                jb = set(bb['pos']).intersection(set(bb['neg']))\n",
    "                if len(jb) == 0:\n",
    "                    continue\n",
    "            # if num_clause > 500:\n",
    "                # continue\n",
    "            if row[1] in bad_data:\n",
    "                continue\n",
    "            names.append(row[1])\n",
    "            labels[row[1]] = data[int(row[1].split('clutrr')[1].split('.')[0])]['label']\n",
    "    #   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "folio = json.load(open('/home/XXXX/SAT-LM/data/lutrr_test.json', 'r'))\n",
    "folio[48]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 0\n",
    "cot_acc = 0\n",
    "cot_preds = {}\n",
    "for key, value in labels.items():\n",
    "    if cot[i] == value:\n",
    "        cot_acc += 1\n",
    "        cot_preds[key] = True\n",
    "    else:\n",
    "        cot_preds[key] = False\n",
    "    i += 1\n",
    "print(cot_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "flipped = 0\n",
    "flipped_names = []\n",
    "tf = []\n",
    "ft = []\n",
    "for name in names:\n",
    "    if cot_preds['proofd5' + str(name) + '.cnf'] != outs_pred['proofd5' + str(name) + '.cnf']:\n",
    "        flipped_names.append('proofd5' + str(name) + '.cnf')\n",
    "        flipped += 1\n",
    "    if cot_preds['proofd5' + str(name) + '.cnf'] == True and outs_pred['proofd5' + str(name) + '.cnf'] == False:\n",
    "        tf.append('proofd5' + str(name) + '.cnf')\n",
    "    if cot_preds['proofd5' + str(name) + '.cnf'] == False and outs_pred['proofd5' + str(name) + '.cnf'] == True:\n",
    "        ft.append('proofd5' + str(name) + '.cnf')\n",
    "\n",
    "print(flipped)\n",
    "print(len(tf))\n",
    "print(len(ft))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "flipped = 0\n",
    "flipped_names = []\n",
    "tf = []\n",
    "ft = []\n",
    "for name in missed_list:\n",
    "    name = name[7:-4]\n",
    "    if cot_preds['clutrr' + str(name) + '.cnf'] != outs_pred['clutrr' + str(name) + '.cnf']:\n",
    "        flipped_names.append('clutrr' + str(name) + '.cnf')\n",
    "        flipped += 1\n",
    "    if cot_preds['clutrr' + str(name) + '.cnf'] == True and outs_pred['proofd5' + str(name) + '.cnf'] == False:\n",
    "        tf.append('proofd5' + str(name) + '.cnf')\n",
    "    if cot_preds['proofd5' + str(name) + '.cnf'] == False and outs_pred['proofd5' + str(name) + '.cnf'] == True:\n",
    "        ft.append('proofd5' + str(name) + '.cnf')\n",
    "\n",
    "print(flipped)\n",
    "print(len(tf))\n",
    "print(len(ft))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "missed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outs['proofd542.cnf']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ours = pkl.load(open('/home/XXXX/LLM-project/all_outs_temp1_dynFalse.pkl', 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(ours.keys())[5]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels[list(ours.keys())[5]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outs_str = '/home/XXXX/XXXX/fs_backup_feb13/all_outs_cot_met_clutrr_rulethresh_05_cot_thresh_ANNEALING,_dynamic_False,_sc5_llama_70B,_no_jb_prompt,_fixed_prmopt,_yes_rules_in_prompt,_yes_solver,_shuffled,_old_fewshut,_temp_1,_n_consec_NO-CONSEC,_sc-poo,_COPY_THAT_SAVES_SC_HISTORY.pkl'\n",
    "outs_70b = pkl.load(open(outs_str, 'rb'))\n",
    "sc = []\n",
    "exit_n = []\n",
    "for key, value in outs_70b.items():\n",
    "    exit_n.append(len(value[-1]))\n",
    "    # print(a[-1])\n",
    " \n",
    "\n",
    "plt.hist(exit_n, bins=list(range(10)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "outs_str = '/home/XXXX/XXXX/fs_backup_feb13/all_outs_cot_met_clutrr_rulethresh_05_cot_thresh_ANNEALING,_dynamic_False,_sc5_llama_8B,_no_jb_prompt,_fixed_prmopt,_yes_rules_in_prompt,_yes_solver,_shuffled,_old_fewshut,_temp_1,_n_consec_NO-CONSEC,_sc-poo,_COPY_THAT_SAVES_SC_HISTORY.pkl'\n",
    "outs_8b = pkl.load(open(outs_str, 'rb'))\n",
    "sc = []\n",
    "exit_n = []\n",
    "for key, value in outs_8b.items():\n",
    "    exit_n.append(len(value[-1]))\n",
    "    # print(a[-1])\n",
    " \n",
    "\n",
    "plt.hist(exit_n, bins=list(range(10)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outs_acc_n = {}\n",
    "outs_totals_n = {}\n",
    "for i in range(10):\n",
    "    outs_acc_n[i] = 0\n",
    "    outs_totals_n[i] = 0\n",
    "outs_acc = 0\n",
    "num_trues = 0\n",
    "# outs_totals_n = {}\n",
    "for key, value in outs_70b.items():\n",
    "    outs_totals_n[len(value[-1])] += 1\n",
    "    if len(value[1]['neg']) == 0 and labels[key].strip(' ') == 'false':\n",
    "        # outs_pred[key] = True\n",
    "        outs_acc_n[len(value[-1])] += 1\n",
    "    elif len(value[1]['pos']) == 0 and labels[key].strip(' ') == 'true':\n",
    "        # outs_pred[key] = True\n",
    "        outs_acc_n[len(value[-1])] += 1\n",
    "    # else:\n",
    "        # # outs_pred[key] = False\n",
    "    if labels[key] == 'true':\n",
    "        num_trues += 1\n",
    "# for i in range(10):\n",
    "#     if outs_totals_n[i] == 0: continue\n",
    "#     outs_acc_n[i] /= outs_totals_n[i]\n",
    "# outs_acc_n /= len(outs_pred.keys())\n",
    "# outs['clutrr545.cnf'][1]\n",
    "# print(outs_acc)\n",
    "# print(outs_acc*len(outs_pred.keys()))\n",
    "\n",
    "plt.bar(x = range(10),height=list(outs_acc_n.values()))\n",
    "plt.title('70B accuracy for iteration length')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outs_acc_n = {}\n",
    "outs_totals_n = {}\n",
    "for i in range(10):\n",
    "    outs_acc_n[i] = 0\n",
    "    outs_totals_n[i] = 0\n",
    "outs_acc = 0\n",
    "num_trues = 0\n",
    "# outs_totals_n = {}\n",
    "for key, value in outs_8b.items():\n",
    "    outs_totals_n[len(value[-1])] += 1\n",
    "    if len(value[1]['neg']) == 0 and labels[key].strip(' ') == 'false':\n",
    "        # outs_pred[key] = True\n",
    "        outs_acc_n[len(value[-1])] += 1\n",
    "    elif len(value[1]['pos']) == 0 and labels[key].strip(' ') == 'true':\n",
    "        # outs_pred[key] = True\n",
    "        outs_acc_n[len(value[-1])] += 1\n",
    "    # else:\n",
    "        # # outs_pred[key] = False\n",
    "    if labels[key] == 'true':\n",
    "        num_trues += 1\n",
    "# for i in range(10):\n",
    "#     if outs_totals_n[i] == 0: continue\n",
    "#     outs_acc_n[i] /= outs_totals_n[i]\n",
    "# outs_acc_n /= len(outs_pred.keys())\n",
    "# outs['clutrr545.cnf'][1]\n",
    "# print(outs_acc)\n",
    "# print(outs_acc*len(outs_pred.keys()))\n",
    "\n",
    "plt.bar(x = range(10),height=list(outs_acc_n.values()))\n",
    "plt.title('8B accuracy for iteration length')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc = 0\n",
    "for n, a in outs_acc_n.items():\n",
    "    acc += outs_totals_n[n]*a\n",
    "print(acc/len(outs_8b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 0\n",
    "sc_acc_n = {}\n",
    "sc70 = np.where(np.array(n_votes) >= np.ceil(len(cot_pred_list)/2+0.5), 1, 0)\n",
    "for j in range(10):\n",
    "    sc_acc_n[j] = 0\n",
    "for key, value in outs_70b.items():\n",
    "    if sc70[i] == 1:\n",
    "        sc_acc_n[len(value[-1])] += 1\n",
    "    i += 1\n",
    "# for j in range(10):\n",
    "#     if outs_totals_n[j] == 0: continue\n",
    "#     sc_acc_n[j] /= outs_totals_n[j]\n",
    "ax1, fig1 = plt.subplots()\n",
    "fig1.bar(x=range(10), height=sc_acc_n.values(), alpha=0.3,label='SC accuracy')\n",
    "fig1.bar(x=range(10), height=outs_acc_n.values(), alpha=0.3,label='Our Accuracy')\n",
    "\n",
    "\n",
    "fig1.legend()\n",
    "# fig1.set_title('SC-20 Llama 70B accuracy on iteration-length divisions, vs SC-5 Llama 8B our-method')\n",
    "fig1.set_ylabel('Accuracy over the subset')\n",
    "fig1.set_xlabel('Number of iterations of our method')\n",
    "fig1.set_yticks(list(np.arange(0,1,0.1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hybrid = outs_acc_n[1] * outs_totals_n[1] + outs_acc_n[2]* outs_totals_n[2] + outs_acc_n[3]*outs_totals_n[3] + \\\n",
    "    sc_acc_n[4]*outs_totals_n[4] + sc_acc_n[5]*outs_totals_n[5] + sc_acc_n[6]*outs_totals_n[6]\n",
    "print(hybrid/len(outs))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "few_shot = \"Facts:\\n[Nancy] likes to cut the hair of her daughter [Heidi].\\n[Heidi]'s sister [Lorraine] went to beauty school and taught them all how to cut hair expertly. \" + \\\n",
    "            \"\\nHere are some additional facts and rules we\\'ve found:\\nNancy is the mother of Lorraine\\n If [Heidi] is the sister of [Lorraine] and [Heidi] is the daughter of [Nancy] then [Nancy] is the mother of [Lorraine].\\n\" + \\\n",
    "            \"Question: Is the following statement true: \\n\\\"[Lorraine] is [Nancy]\\'s daughter\\\"\\n\" + \\\n",
    "            \"Answer:\\nLet\\'s think step by step.  \\n1. [Heidi] is the sister of [Lorraine]\\n2. [Heidi] is the daughter of [Nancy]\\n3. If [Heidi] is the sister of [Lorraine] and [Heidi] is the daughter of [Nancy] then [Nancy] is the mother of [Lorraine].\\n4. If [Nancy] is the mother of [Lorraine], then [Lorraine] is the daughter of [Nancy].\\nTherefore, the answer to the question is Yes, the statement is true. \\n***\\n\" + \\\n",
    "            \"Facts:\\n[Dale] and his sister [Nancy] are decorating for a party.\\n[Nancy]'s daughter [Louise] thinks the party will be fun.\\n\" + \\\n",
    "            \"Here are some additional facts and rules we\\'ve found:\\nDale is the uncle of Louise. If [Nancy] is the sister of [Dale] and [Nancy] is the mother of [Louise] then [Dale] is the uncle of [Louise].\\n\" + \\\n",
    "            \"Question: Is the following statement true: \\n\\\"[Louise] is not [Dales]\\'s niece\\\"\\n\" + \\\n",
    "            \"Answer: Le\\'s think step by step. \\n1. [Nancy] is the sister of [Dale]. \\n2. [Nancy] is the mother of [Louise]\\n3.  If [Nancy] is the sister of [Dale] and [Nancy] is the mother of [Louise] then [Dale] is the uncle of [Louise].\\n4.If [Dale] is the uncle of [Louise], then [Louise] is the niece of [Dale].\\nTherefore, the answer is No, the statement is not true.\\n***\\n\" + \\\n",
    "            \"Facts: \\n[Lillian] and her sister [Nancy] are the only children in their family. \\n[Lillian]'s biggest accomplishment is raising her son [Douglas]. \" + \\\n",
    "            \"\\nHere are some additional facts and rules we\\'ve found:\\n[Lillian] is the sister of [Nancy]. \\nIf [Nancy] is the sister if [Lillian] then [Lillian] is the sister of [Nancy].\\n\" + \\\n",
    "            \"Question: Is the following statement true: \\n\\\"[Douglas] is [Nancy]\\'s nephew\\\"\\n\" + \\\n",
    "            \"Answer:\\nLet\\'s think step by step. \\n1. [Douglas] is [Lillian]\\'s son. \\n2. [Nancy] is [Lillian]\\'s sister. \" + \\\n",
    "            \"\\n3. If [Douglas] is the son of [Lillian] and [Lillian] is the sister of [Nancy] then [Douglas] is the nephew of [Lillian]. \\nTherefore, the answer to the question is Yes, the statement is true. \\n***\\n\" + \\\n",
    "            \"Facts: \\n[Ashley] liked to go to the park with her granddaughter [Charlotte]. \\n[Dale], [Charlotte]'s father, like to take her to the movies instead. \" + \\\n",
    "            \"\\nHere are some additional facts and rules we\\'ve found:\\n[Dale] is the son of [Ashley]. If [Dale] is father of [Charlotte] and [Ashley] is the grandmother of [Charlotte] then [Dale] is the son of [Ashley].\\n\" + \\\n",
    "            \"Question: Is the following statement true: \\n\\\"[Ashley] is not [Dale]\\'s mother\\\"\\n\" + \\\n",
    "            \"Answer:\\nLet\\'s think step by step. \\n1. [Dale] is the father of [Charlotte].\\n2. [Ashley] is the grandmother of [Charlotte]. \\n3. If [Dale] is father of [Charlotte] and [Ashley] is the grandmother of [Charlotte] then [Dale] is the son of [Ashley].\\n4. If [Dale] is the son of [Ashley], then [Ashley] is the mother of [Dale]. \" + \\\n",
    "            \"\\nTherefore, the answer to the question is No, the statement is ot true.\\n***\\n\"\n",
    "\n",
    "print(few_shot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, os\n",
    "from torch.utils.data import DataLoader\n",
    "os.environ[\"CURL_CA_BUNDLE\"]=\"\"\n",
    "os.environ[\"REQUESTS_CA_BUNDLE\"]=\"\"\n",
    "run_log_path = './run_log.txt'\n",
    "run_log = open(run_log_path, 'w')\n",
    "run_log.write('hello\\n')\n",
    "run_log.close()\n",
    "unknown=False\n",
    "import json\n",
    "import numpy as np\n",
    "import csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"CURL_CA_BUNDLE\"]=\"\"\n",
    "os.environ[\"REQUESTS_CA_BUNDLE\"]=\"\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n",
    "USER_PATH = '/home/XXXX/XXXX/fs_backup_feb13/'\n",
    "# os.environ['TRANSFORMERS_CACHE'] = '.cache/huggingface/hub'\n",
    "cache_dir = '/ephemeral/media/data1/XXXX/hub/'\n",
    "os.environ['TRANSFORMERS_CACHE'] = cache_dir\n",
    "os.environ['HF_HOME'] = cache_dir\n",
    "# import transformers\n",
    "\n",
    "# import urllib3\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
    "import argparse\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import contextlib\n",
    "\n",
    "import requests\n",
    "from urllib3.exceptions import InsecureRequestWarning\n",
    "\n",
    "old_merge_environment_settings = requests.Session.merge_environment_settings\n",
    "\n",
    "@contextlib.contextmanager\n",
    "def no_ssl_verification():\n",
    "    opened_adapters = set()\n",
    "\n",
    "    def merge_environment_settings(self, url, proxies, stream, verify, cert):\n",
    "        # Verification happens only once per connection so we need to close\n",
    "        # all the opened adapters once we're done. Otherwise, the effects of\n",
    "        # verify=False persist beyond the end of this context manager.\n",
    "        opened_adapters.add(self.get_adapter(url))\n",
    "\n",
    "        settings = old_merge_environment_settings(self, url, proxies, stream, verify, cert)\n",
    "        settings['verify'] = False\n",
    "\n",
    "        return settings\n",
    "\n",
    "    requests.Session.merge_environment_settings = merge_environment_settings\n",
    "\n",
    "    try:\n",
    "        with warnings.catch_warnings():\n",
    "            warnings.simplefilter('ignore', InsecureRequestWarning)\n",
    "            yield\n",
    "    finally:\n",
    "        requests.Session.merge_environment_settings = old_merge_environment_settings\n",
    "\n",
    "        for adapter in opened_adapters:\n",
    "            try:\n",
    "                adapter.close()\n",
    "            except:\n",
    "                pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Struct:\n",
    "    def __init__(self, **entries):\n",
    "        self.__dict__.update(entries)\n",
    "\n",
    "args = {'train_file_path': './example_data', 'test_file_path': './example_data', 'save_path': './../SFT_train_res', 'model_choice': 'meta-llama/Llama-2-13b-chat-hf', \n",
    "        'n_rows': 20, 'max_length': 1000, 'lr': 5e-05, 'weight_decay': 0.0, 'epochs': 10, 'max_grad_norm': 1.0, 'batch_size': 2, 'save_strategy': 'no', 'use_lora': True}\n",
    "# args['model_choice'] = 'meta-llama/Meta-Llama-3-70B-Instruct'\n",
    "args['model_choice'] = 'meta-llama/Meta-Llama-3-70B-Instruct'\n",
    "\n",
    "args = Struct(**args)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class LLM():\n",
    "    def __init__(self, args):\n",
    "        quant_config = BitsAndBytesConfig(\n",
    "            load_in_4bit=True,\n",
    "            bnb_4bit_quant_type=\"nf4\",\n",
    "            bnb_4bit_compute_dtype=\"bfloat16\",\n",
    "            bnb_4bit_use_double_quant=True,\n",
    "        )\n",
    "        with no_ssl_verification():\n",
    "            \n",
    "\n",
    "            \n",
    "            self.tokenizer = AutoTokenizer.from_pretrained(\n",
    "                    args.model_choice,\n",
    "                    cache_dir = cache_dir,\n",
    "                    token = 'hf_xxxx',\n",
    "                    attn_implementation=\"flash_attention_2\"\n",
    "\n",
    "                    )\n",
    "            self.tokenizer.pad_token = self.tokenizer.eos_token\n",
    "            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id            \n",
    "            self.model = AutoModelForCausalLM.from_pretrained(\n",
    "                    args.model_choice, \n",
    "                    cache_dir = cache_dir,\n",
    "                    quantization_config=quant_config,\n",
    "                    device_map='auto',\n",
    "                    token = 'hf_xxxx',\n",
    "                    attn_implementation=\"flash_attention_2\"\n",
    "                    )\n",
    "\n",
    "        self.tokenizer.pad_token = self.tokenizer.eos_token\n",
    "    \n",
    "    def sentence_probabilities(self, sentences):\n",
    "        with torch.no_grad():\n",
    "            sentence_tokens = self.tokenizer(sentences, return_tensors='pt', padding=True)\n",
    "            sentence_token_ids = sentence_tokens.input_ids.cuda()\n",
    "\n",
    "            # Little hack to cut down inference time by 4-5x (leads to some imprecisions when using quantization)\n",
    "            # Find the common prefix and run it through the model once, to save time\n",
    "            first_different_token = (sentence_token_ids == sentence_token_ids[0, :].unsqueeze(0)).all(dim=0).long().argmin()\n",
    "            common_prefix = sentence_token_ids[0, :first_different_token].unsqueeze(0)\n",
    "            common_prefix_output = self.model(common_prefix, use_cache=True)\n",
    "            common_prefix_key_values = tuple(tuple(tensor.expand(len(sentences), -1, -1, -1) for tensor in layer) \n",
    "                                             for layer in common_prefix_output.past_key_values)\n",
    "\n",
    "            # Process the rest of the sentences\n",
    "            rest_outputs = self.model(sentence_token_ids[:, first_different_token:], past_key_values=common_prefix_key_values)\n",
    "            logits = torch.concat([common_prefix_output.logits.expand(len(sentences), -1, -1), rest_outputs.logits], dim=1).cuda()\n",
    "            log_probs = logits.log_softmax(-1)\n",
    "            log_probs = log_probs[:, :-1, :].gather(2, sentence_token_ids[:, 1:][:, :, None]).squeeze(-1).cuda()\n",
    "            log_probs = (log_probs*sentence_tokens.attention_mask.cuda()[:, 1:]).sum(-1).cpu()\n",
    "        return log_probs\n",
    "    def nli(self, sentences, unknown):\n",
    "        # true_probs = self.sentence_probabilities(sentences + \" True.\")\n",
    "        # false_probs = self.sentence_probabilities(sentences + \" False.\")\n",
    "        # maybe_probs = self.sentence_probabilities(sentences + \" Maybe.\")\n",
    "        if unknown:\n",
    "            true_probs, maybe_probs, false_probs =  (self.sentence_probabilities([sentences + \"(A)\", sentences + \"(B)\", sentences + \"(C)\"]))\n",
    "            return {'True': true_probs, 'Maybe': maybe_probs, 'False': false_probs}\n",
    "        else:\n",
    "            true_probs, false_probs =  (self.sentence_probabilities([sentences + \"(A)\", sentences + \"(B)\"]))\n",
    "            return {'True': true_probs, 'False': false_probs}\n",
    "    def yn(self, sentences, norm=True, relaxed=False, obvious=False, fewshot=None, maybe=False):\n",
    "        yns = []\n",
    "        for sentence in sentences:\n",
    "            if fewshot:\n",
    "                sentence = fewshot + sentence\n",
    "            \n",
    "            if relaxed:\n",
    "                yns.append(sentence + \"Most likely\")\n",
    "                yns.append(sentence + \"Not necessarily\")\n",
    "            elif obvious:\n",
    "                yns.append(sentence + \"obviously true.\")\n",
    "                yns.append(sentence + \"not obviously true.\")\n",
    "            elif maybe:\n",
    "                yns.append(sentence + \"Yes\")\n",
    "                yns.append(sentence + \"Maybe\")\n",
    "                yns.append(sentence + \"No\")\n",
    "            else:\n",
    "                yns.append(sentence + \"Yes\")\n",
    "                yns.append(sentence + \"No\")\n",
    "        # if norm:\n",
    "        #     norms = self.sentence_probabilities(sentences)\n",
    "        probs = []\n",
    "        batch_size = 256\n",
    "        for i in range(0, len(yns), batch_size):\n",
    "            if i+batch_size < len(yns):\n",
    "                probs += list(self.sentence_probabilities(yns[i:i+batch_size]))\n",
    "            else: \n",
    "                probs += list(self.sentence_probabilities(yns[i:]))\n",
    "        probs=torch.tensor(probs)\n",
    "        #   \n",
    "        # probs = (self.sentence_probabilities(yns))\n",
    "        # probs = torch.exp(probs)\n",
    "        pyes = []\n",
    "        pno = []\n",
    "        pmaybe = []\n",
    "        if maybe:\n",
    "            z = 3\n",
    "        else:\n",
    "            z = 2\n",
    "        for i in range(0,len(probs), z):\n",
    "            # if yns[i] not in cache.keys():\n",
    "                # yes, no = self.sentence_probabilities([yns[i], yns[i+1]])\n",
    "            \n",
    "            if maybe:\n",
    "                \n",
    "                yes, maybe, no = probs[i], probs[i+1], probs[i+2]\n",
    "                \n",
    "                      \n",
    "            else:\n",
    "                yes, no = probs[i], probs[i+1]\n",
    "            if norm:\n",
    "                if maybe: \n",
    "                    y,m,n = torch.tensor([yes, maybe, no]).softmax(-1)\n",
    "                else:\n",
    "                    y,n = torch.tensor([yes, no]).softmax(-1)\n",
    "              \n",
    "                # cache[yns[i]] = y\n",
    "                # cache[yns[i+1]] = n\n",
    "                pyes.append(y)\n",
    "                pno.append(n)\n",
    "                if maybe:\n",
    "                    pmaybe.append(m)\n",
    "            else:\n",
    "                pyes.append(1-yes/(yes + no))\n",
    "            # else:\n",
    "            #     y, n = cache[yns[i]], cache[yns[i+1]]\n",
    "            #     pyes.append(y)\n",
    "                # pno.append(n)/\n",
    "        # print('cache length', len(cache))\n",
    "        # if maybe:\n",
    "        \n",
    "        if maybe: return torch.stack([torch.tensor(pyes), torch.tensor(pmaybe), torch.tensor(pno)])\n",
    "        return torch.tensor(pyes), torch.tensor(pmaybe), torch.tensor(pno)\n",
    "    def complete(self, prompt, max_new = 25, temp = 1 , topk=0):\n",
    "        \"\"\"Generate a continuation of ``prompt`` and return the decoded text.\n",
    "\n",
    "        Args:\n",
    "            prompt: Prompt string, or a list of prompt strings.\n",
    "            max_new: Maximum number of new tokens to generate.\n",
    "            temp: Sampling temperature. ``None`` or a value <= 0 selects greedy\n",
    "                decoding instead of sampling.\n",
    "            topk: Passed to ``generate`` as ``top_k`` while sampling.\n",
    "\n",
    "        Returns:\n",
    "            A list of decoded strings. The full output sequences are decoded, so\n",
    "            for a decoder-only model each string starts with the prompt itself.\n",
    "        \"\"\"\n",
    "        # `len(prompt) + 1` used to be the truncation length. That is a character\n",
    "        # count for a str and the number of prompts for a list, so batched prompts\n",
    "        # were cut down to a few tokens. Prompts are encoded untruncated instead.\n",
    "        encoded = self.tokenizer(prompt, return_tensors='pt', padding=True)\n",
    "        encode_ids = encoded.input_ids.cuda()\n",
    "        # Pass the mask explicitly so padded positions are ignored in a batch.\n",
    "        attention_mask = encoded.attention_mask.cuda()\n",
    "\n",
    "        # temperature/top_k only act when sampling is switched on; without\n",
    "        # do_sample the model may decode greedily and return the same text on\n",
    "        # every call, which defeats repeated sampling by the caller.\n",
    "        sampling = temp is not None and temp > 0\n",
    "        gen_kwargs = dict(\n",
    "            attention_mask=attention_mask,\n",
    "            max_new_tokens=max_new,\n",
    "            return_dict_in_generate=True,\n",
    "            output_scores=True,\n",
    "            do_sample=sampling,\n",
    "        )\n",
    "        if sampling:\n",
    "            gen_kwargs.update(temperature=temp, top_k=topk)\n",
    "        generated_outputs = self.model.generate(encode_ids, **gen_kwargs)\n",
    "        return self.tokenizer.batch_decode(\n",
    "            generated_outputs.sequences,\n",
    "            skip_special_tokens=True\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = LLM(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "\n",
    "# Load the saved outputs that the self-consistency evaluation iterates over.\n",
    "# NOTE: unpickling can execute arbitrary code - only load files you trust.\n",
    "# The context manager closes the handle, which the bare open() call leaked.\n",
    "with open('/home/XXXX/XXXX/fs_backup_feb13/all_outs_cot_met_clutrr_rulethresh_00_cot_thresh_ANNEALING_05,_sc5,_dynamic_False,_sc5_llama_8B,_no_jb_prompt,_fixed_prmopt,_yes_rules_in_prompt,_yes_solver,_shuffled,_old_fewshut,_temp_1,_n_consec_NO-CONSEC,_sc-poo,_COPY_THAT_SAVES_ALL_COT_PROMPTS.pkl', 'rb') as all_outs_file:\n",
    "    all_outs = pkl.load(all_outs_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Self-consistency evaluation: sample n completions per prompt, turn each into a\n",
    "# true/false vote, and compare the majority vote (and the solver output) to the label.\n",
    "n = 5\n",
    "preds_70 = {}\n",
    "preds_8 = {}\n",
    "answers = ['true', 'false']\n",
    "acc_8 = 0\n",
    "acc_70 = 0\n",
    "n_fewshot = 4\n",
    "cot_flag = True\n",
    "\n",
    "\n",
    "def _verdict_from_cot(answer_text):\n",
    "    \"\"\"Read an explicit Yes/No off the last line that mentions 'Therefore'.\n",
    "\n",
    "    Returns [1, 0] for Yes, [0, 1] for No, and None when no line mentions\n",
    "    'Therefore' or that line carries neither word.\n",
    "    \"\"\"\n",
    "    # Scanning every line from the end also covers the first line and one-line\n",
    "    # answers; the old negative-index walk skipped the former and ran out of\n",
    "    # range on the latter.\n",
    "    for text_line in reversed(answer_text.split('\\n')):\n",
    "        if 'Therefore' in text_line:\n",
    "            if 'Yes' in text_line:\n",
    "                return [1, 0]\n",
    "            if 'No' in text_line:\n",
    "                return [0, 1]\n",
    "            return None\n",
    "    return None\n",
    "\n",
    "\n",
    "def _verdict_from_yn(answer_text):\n",
    "    \"\"\"Fallback: score the answer with llm.yn and split the 'maybe' mass evenly.\"\"\"\n",
    "    yn = llm.yn([answer_text + '\\n So, is the statement true? Answer: '], maybe=True)\n",
    "    nli = torch.tensor([yn[0] + yn[1]/2, yn[2] + yn[1]/2])\n",
    "    print('had to yn', yn)\n",
    "    return nli\n",
    "\n",
    "\n",
    "for key, value in all_outs.items():\n",
    "    votes = torch.tensor([0,0])\n",
    "    prompt = value[-1][-1]\n",
    "    for _ in range(n):\n",
    "        completion = llm.complete(prompt, max_new=600, temp=1)[0]\n",
    "        # Keep only segment n_fewshot of the '***'-separated text; presumably the\n",
    "        # few-shot examples fill the segments before it - TODO confirm.\n",
    "        ans = completion.split('***')[n_fewshot]\n",
    "\n",
    "        nli = _verdict_from_cot(ans)\n",
    "        if nli is None:\n",
    "            nli = _verdict_from_yn(ans)\n",
    "        votes[torch.as_tensor(nli).argmax()] += 1\n",
    "    print(completion)\n",
    "    print('====================================')\n",
    "    print(ans)\n",
    "    print(votes)\n",
    "\n",
    "    sc_ans = answers[votes.argmax()]\n",
    "    label = labels[key].strip(' ')\n",
    "    if sc_ans == label:\n",
    "        acc_70 += 1\n",
    "    preds_70[key] = sc_ans\n",
    "\n",
    "    solout = value[1]\n",
    "    n_pos, n_neg = len(solout['pos']), len(solout['neg'])\n",
    "    # Without a default, an empty or contradictory solver output left preds_8[key]\n",
    "    # unset and the comparison below raised KeyError. None records \"no prediction\"\n",
    "    # and is scored as a miss.\n",
    "    preds_8[key] = None\n",
    "    if n_pos == 0 and n_neg > 0:\n",
    "        preds_8[key] = 'true' if cot_flag else 'false'\n",
    "    elif n_pos > 0 and n_neg == 0 and cot_flag:\n",
    "        preds_8[key] = 'false'\n",
    "    if preds_8[key] == label:\n",
    "        acc_8 += 1\n",
    "    print(sc_ans, preds_8[key], label, sc_ans==label)\n",
    "    print('70 acc', acc_70/len(preds_70))\n",
    "    print('8 acc', acc_8/len(preds_8))\n",
    "    print(key)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "nvotes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(preds_70)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calibration curve: for each majority size, the share of examples counted in apv.\n",
    "# NOTE(review): assumes n_votes holds integer per-example vote counts in 0..N_SC - confirm.\n",
    "N_SC = 20\n",
    "n_votes_t = torch.tensor(n_votes)/N_SC\n",
    "majority_sizes = range(N_SC // 2, N_SC + 1)\n",
    "v_freq = {size: 0 for size in majority_sizes}\n",
    "apv = {size: 0 for size in majority_sizes}\n",
    "for v in n_votes:\n",
    "    if v > N_SC // 2:\n",
    "        apv[v] += 1\n",
    "        v_freq[v] += 1\n",
    "    else:\n",
    "        # A minority count is folded onto the size of the opposing majority.\n",
    "        v_freq[N_SC-v] += 1\n",
    "\n",
    "# A vote count that never occurs used to raise ZeroDivisionError; NaN leaves a gap\n",
    "# in the plotted curve instead.\n",
    "aapv = [apv[size]/v_freq[size] if v_freq[size] else float('nan') for size in apv]\n",
    "\n",
    "# Integer-based grids give a guaranteed number of points, unlike a float-step arange.\n",
    "line = list(np.linspace(0, 1, N_SC + 1))\n",
    "plt.plot(np.arange(N_SC // 2 + 1, N_SC + 1)/N_SC, aapv[1:], label='70B SC-20 calibration')\n",
    "plt.plot(line, line, color='black')\n",
    "plt.ylim(0,1)\n",
    "plt.xlim(0, 1)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# For every entry of temp_outs, collect elements of value[0] at indices 2, 5, 8, ...\n",
    "# NOTE(review): the stop bound is derived from len(value[0]) / 3 and so looks like a\n",
    "# count rather than an index limit - confirm it selects what is intended.\n",
    "predvars = {}\n",
    "for key, value in temp_outs.items():\n",
    "    first_out = value[0]\n",
    "    stop = int((len(first_out) - 1) / 3) + 2\n",
    "    predvars[key] = [first_out[i] for i in range(2, stop, 3)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "temp_outs_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Read both JSON files through context managers so the handles are closed;\n",
    "# the bare open() calls leaked them.\n",
    "with open('/home/XXXX/XXXX/fs_backup_feb13/LLM-project/core_labels.json', 'r') as labels_file:\n",
    "    varlabels = json.load(labels_file)\n",
    "with open('/home/XXXX/XXXX/fs_backup_feb13/LLM-project/core_fullrules.json', 'r') as rules_file:\n",
    "    fullrules = json.load(rules_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "data[579]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "varlabels.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(varlabels['579'])\n",
    "print(fullrules['579'])\n",
    "print(predvars['clutrr579.cnf'])\n",
    "print(temp_outs['clutrr579.cnf'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "temp_outs['clutrr579.cnf'][6][0].split('Facts:')[5]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "LLM",
   "language": "python",
   "name": "llm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
