{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b025c6a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.special import expit as sigmoid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9ca5b216",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CSV written by one run; its file name appears to encode that run's settings\n",
    "# (NOTE(review): confirm against the code that writes it).\n",
    "path = \"../results/data/llm/seed-1_n-1000_hidden_dim-64_dropout-0_1_alpha-0_3_beta-0_0_oracle_z-False_pretrained-True_lr-0_0001_batch_size-128_epochs-35.csv\"\n",
    "\n",
    "# Read the whole file into the frame that the rest of the notebook works on.\n",
    "df = pd.read_csv(filepath_or_buffer=path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e857d63e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.microsoft.datawrangler.viewer.v0+json": {
       "columns": [
        {
         "name": "index",
         "rawType": "int64",
         "type": "integer"
        },
        {
         "name": "y1",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "z1",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "y2",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "z2",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "y3",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "z3",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "y4",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "z4",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "z1_hat",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "z2_hat",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "z1_3_hat",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "z2_3_hat",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "z3_hat",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "z1_4_hat",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "z2_4_hat",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "z3_4_hat",
         "rawType": "float64",
         "type": "float"
        }
       ],
       "conversionMethod": "pd.DataFrame",
       "ref": "6b955114-7ff5-4866-b773-1f295c0fb328",
       "rows": [
        [
         "0",
         "0.0",
         "1.9937248",
         "0.0",
         "-0.6509609",
         "0.0",
         "-7.4820137",
         "1.0",
         "1.8291702",
         "0.8549467",
         "0.89610445",
         "-2.272234",
         "-0.80852485",
         "-1.0811514",
         "3.2278125",
         "0.616273",
         "0.9241694"
        ],
        [
         "1",
         "0.0",
         "-8.09981",
         "0.0",
         "0.67287445",
         "1.0",
         "0.86840916",
         "1.0",
         "2.3801975",
         "-12.7426815",
         "-0.60101056",
         "3.536963",
         "0.9857875",
         "0.82255125",
         "1.4847968",
         "0.29602122",
         "0.25663686"
        ],
        [
         "2",
         "1.0",
         "0.88302135",
         "1.0",
         "0.808136",
         "0.0",
         "-8.453954",
         "1.0",
         "-1.5898762",
         "0.5706606",
         "0.524539",
         "-6.8562403",
         "-0.42343384",
         "-2.1655457",
         "-1.0376604",
         "-0.53832376",
         "-0.6637684"
        ],
        [
         "3",
         "1.0",
         "0.26830387",
         "1.0",
         "2.0808907",
         "0.0",
         "-1.8535433",
         "1.0",
         "2.2911828",
         "1.5340755",
         "0.053582013",
         "-4.933869",
         "-1.3338429",
         "-1.2765411",
         "2.0281792",
         "0.41026378",
         "1.2312617"
        ],
        [
         "4",
         "1.0",
         "1.3065014",
         "0.0",
         "0.16399956",
         "0.0",
         "-5.328679",
         "0.0",
         "-1.1477051",
         "1.3374186",
         "-0.2511729",
         "-5.8059936",
         "-0.043350577",
         "-1.9740379",
         "-3.4672117",
         "-1.1613402",
         "-1.0922351"
        ],
        [
         "5",
         "1.0",
         "2.5654626",
         "0.0",
         "-1.4580927",
         "1.0",
         "8.95372",
         "0.0",
         "-0.9331236",
         "-0.21000195",
         "0.28892827",
         "8.287624",
         "1.847697",
         "3.9610927",
         "-1.2590092",
         "-1.0274892",
         "-0.47676098"
        ],
        [
         "6",
         "1.0",
         "3.0893497",
         "0.0",
         "0.7328434",
         "1.0",
         "0.7939329",
         "0.0",
         "-8.137357",
         "4.7044306",
         "-0.26706028",
         "-1.0498511",
         "-0.4866192",
         "-0.33695304",
         "-9.007943",
         "-2.2538912",
         "-2.7116437"
        ],
        [
         "7",
         "1.0",
         "0.4901867",
         "1.0",
         "0.93519974",
         "0.0",
         "-4.8038654",
         "1.0",
         "0.5989809",
         "1.1871674",
         "0.6212021",
         "-2.8292704",
         "-0.6240026",
         "-0.9475237",
         "1.2415862",
         "0.23655897",
         "0.11635983"
        ],
        [
         "8",
         "1.0",
         "0.38005638",
         "0.0",
         "-1.2031584",
         "0.0",
         "-0.53120804",
         "0.0",
         "-4.396324",
         "3.4428248",
         "-0.62705374",
         "0.3140337",
         "0.3667253",
         "0.11135256",
         "-5.519673",
         "-1.9393952",
         "-1.7119906"
        ],
        [
         "9",
         "1.0",
         "3.8459787",
         "1.0",
         "0.05077839",
         "1.0",
         "4.179987",
         "0.0",
         "-1.2933192",
         "2.5908794",
         "-0.61026716",
         "5.2148004",
         "1.3044527",
         "1.331353",
         "-0.15475404",
         "0.04402685",
         "-0.286533"
        ],
        [
         "10",
         "0.0",
         "-0.945138",
         "1.0",
         "0.87330246",
         "0.0",
         "-2.4127522",
         "1.0",
         "0.6881094",
         "-1.3825462",
         "0.55050766",
         "-4.249235",
         "-1.0859833",
         "-1.4441092",
         "0.19446874",
         "0.07763338",
         "-0.21585023"
        ],
        [
         "11",
         "0.0",
         "-1.3933792",
         "0.0",
         "-0.360734",
         "1.0",
         "1.41679",
         "1.0",
         "2.321828",
         "-0.9567844",
         "-0.25124264",
         "3.8141737",
         "1.3917034",
         "1.058485",
         "1.3291314",
         "0.21550775",
         "0.72984856"
        ],
        [
         "12",
         "0.0",
         "-1.1602201",
         "0.0",
         "-3.4439878",
         "0.0",
         "2.762514",
         "0.0",
         "-0.08939552",
         "-2.192527",
         "-1.0576122",
         "5.6603246",
         "1.6449274",
         "1.1476169",
         "0.67369944",
         "0.34605432",
         "0.2829522"
        ],
        [
         "13",
         "1.0",
         "0.032773018",
         "1.0",
         "3.8872862",
         "1.0",
         "-0.10562897",
         "0.0",
         "1.3833008",
         "-3.6491609",
         "1.6848354",
         "1.8675478",
         "0.48364425",
         "0.56913877",
         "3.2396317",
         "1.1474407",
         "0.9599446"
        ],
        [
         "14",
         "1.0",
         "1.7164621",
         "1.0",
         "1.6000118",
         "0.0",
         "-5.744012",
         "1.0",
         "-0.71268463",
         "2.1026583",
         "0.1878854",
         "-0.93569803",
         "-0.81135845",
         "-0.21878791",
         "0.0301857",
         "-0.120226264",
         "-0.37162328"
        ],
        [
         "15",
         "0.0",
         "-1.9355564",
         "0.0",
         "-0.85639095",
         "0.0",
         "-3.1075778",
         "1.0",
         "-0.0009317398",
         "2.4447176",
         "-0.13525605",
         "-4.436261",
         "-1.4617306",
         "-0.89740336",
         "-0.17073488",
         "-0.18875933",
         "-0.17687869"
        ],
        [
         "16",
         "1.0",
         "7.207368",
         "0.0",
         "-1.0677547",
         "1.0",
         "1.5397692",
         "1.0",
         "-0.8865633",
         "10.525972",
         "-1.456786",
         "2.9140046",
         "1.064453",
         "1.051087",
         "-0.62611985",
         "-0.05254388",
         "-0.21774375"
        ],
        [
         "17",
         "0.0",
         "-2.1878958",
         "1.0",
         "1.4654799",
         "1.0",
         "2.3268433",
         "0.0",
         "-2.850463",
         "-1.6019534",
         "-1.3890487",
         "1.4056797",
         "0.08369374",
         "0.49356198",
         "-1.3972983",
         "-0.38687795",
         "-0.49172503"
        ],
        [
         "18",
         "1.0",
         "3.8268962",
         "0.0",
         "-0.071845055",
         "1.0",
         "0.7208109",
         "0.0",
         "-0.35454178",
         "0.070225716",
         "-0.5854327",
         "-1.528199",
         "-1.0489051",
         "0.24018896",
         "0.20101023",
         "0.08836532",
         "0.26429254"
        ],
        [
         "19",
         "0.0",
         "-1.9699211",
         "1.0",
         "8.50214",
         "1.0",
         "1.2500296",
         "1.0",
         "1.188096",
         "-2.7649026",
         "2.6247683",
         "1.4654632",
         "0.42675316",
         "0.62178934",
         "0.97912556",
         "0.21881258",
         "0.2987678"
        ],
        [
         "20",
         "1.0",
         "-0.06771469",
         "0.0",
         "-0.53389263",
         "1.0",
         "-0.17579126",
         "1.0",
         "0.58763885",
         "0.3145833",
         "-0.73423004",
         "0.8149557",
         "0.105069935",
         "0.2340889",
         "-0.833218",
         "0.2764356",
         "-0.10980892"
        ],
        [
         "21",
         "0.0",
         "-4.287193",
         "0.0",
         "-1.2230701",
         "1.0",
         "0.9471512",
         "0.0",
         "-0.7174969",
         "-3.226312",
         "-0.9545455",
         "-0.30608284",
         "-0.47722462",
         "-0.23554999",
         "-3.0173638",
         "-1.627839",
         "-0.67220783"
        ],
        [
         "22",
         "0.0",
         "0.994647",
         "0.0",
         "-1.3480701",
         "1.0",
         "8.095608",
         "0.0",
         "-3.645629",
         "3.7016358",
         "-0.6195089",
         "6.179179",
         "1.1513065",
         "2.0343518",
         "-5.2494726",
         "-0.72989523",
         "-1.4101162"
        ],
        [
         "23",
         "1.0",
         "-0.48989487",
         "0.0",
         "-0.037558556",
         "1.0",
         "0.7561846",
         "0.0",
         "0.34131527",
         "-0.47354686",
         "0.091999054",
         "3.6243043",
         "1.1370776",
         "1.1414471",
         "0.9350437",
         "0.3265442",
         "0.20175385"
        ],
        [
         "24",
         "0.0",
         "-1.140276",
         "0.0",
         "-1.4247818",
         "0.0",
         "-0.1515255",
         "0.0",
         "-0.18537426",
         "-1.8360896",
         "-0.3671614",
         "1.1712637",
         "0.048989832",
         "0.50599873",
         "2.2109303",
         "1.1626644",
         "0.6305804"
        ],
        [
         "25",
         "0.0",
         "-0.3966837",
         "1.0",
         "0.3905182",
         "0.0",
         "-10.583528",
         "0.0",
         "-1.5810614",
         "1.2567239",
         "0.035952926",
         "-12.653603",
         "-3.313342",
         "-4.3639436",
         "-4.6930237",
         "-1.6609739",
         "-1.2218524"
        ],
        [
         "26",
         "1.0",
         "-0.14276886",
         "0.0",
         "-4.287193",
         "0.0",
         "-7.470641",
         "0.0",
         "-10.40211",
         "-1.3396163",
         "-0.41637757",
         "-6.3708515",
         "-1.782364",
         "-2.2291555",
         "-9.996332",
         "-1.6264615",
         "-3.3420572"
        ],
        [
         "27",
         "0.0",
         "-0.24251556",
         "0.0",
         "-3.0546532",
         "1.0",
         "0.8072796",
         "1.0",
         "-1.2137384",
         "0.25635988",
         "-0.86406434",
         "4.0939198",
         "1.16556",
         "0.8577112",
         "-1.4443464",
         "-0.34630746",
         "-0.48205942"
        ],
        [
         "28",
         "0.0",
         "0.42707825",
         "1.0",
         "1.8215022",
         "1.0",
         "3.6203346",
         "0.0",
         "-0.47789764",
         "1.25034",
         "0.55765986",
         "4.158628",
         "1.1145108",
         "1.7794361",
         "-0.29596114",
         "0.9080993",
         "-0.0067715645"
        ],
        [
         "29",
         "0.0",
         "-0.34167767",
         "1.0",
         "1.0633783",
         "0.0",
         "-4.06012",
         "0.0",
         "2.9014935",
         "0.1473856",
         "0.3256005",
         "-4.173903",
         "-0.82326114",
         "-1.3128396",
         "4.2018948",
         "2.2994711",
         "1.2361169"
        ],
        [
         "30",
         "1.0",
         "-0.963572",
         "1.0",
         "2.7862167",
         "0.0",
         "0.066020966",
         "1.0",
         "9.946655",
         "0.5204035",
         "0.26451728",
         "0.66760945",
         "0.066874385",
         "0.049239397",
         "8.86756",
         "1.9947956",
         "2.4224198"
        ],
        [
         "31",
         "1.0",
         "1.742352",
         "0.0",
         "-2.3328333",
         "0.0",
         "0.524889",
         "1.0",
         "7.8178406",
         "-0.37081218",
         "-2.5538507",
         "2.0468433",
         "0.60642934",
         "0.7066717",
         "7.7587175",
         "1.6380603",
         "2.3523552"
        ],
        [
         "32",
         "1.0",
         "0.70682144",
         "1.0",
         "0.18207073",
         "0.0",
         "-4.0286584",
         "1.0",
         "3.908616",
         "-0.27613592",
         "0.590909",
         "-4.8706803",
         "-1.0800779",
         "-2.0737424",
         "5.412171",
         "1.5318878",
         "1.7459888"
        ],
        [
         "33",
         "0.0",
         "-0.08576536",
         "1.0",
         "0.13587093",
         "1.0",
         "0.7046547",
         "1.0",
         "4.1727905",
         "0.04054165",
         "0.1788888",
         "1.5247111",
         "0.6098418",
         "0.59946823",
         "4.768379",
         "1.1878684",
         "2.1739876"
        ],
        [
         "34",
         "1.0",
         "3.218361",
         "1.0",
         "2.4716864",
         "1.0",
         "1.3208485",
         "0.0",
         "-4.0030184",
         "5.646576",
         "0.10342622",
         "-0.61494017",
         "-0.3892434",
         "-0.2676981",
         "-1.8925455",
         "-0.36334556",
         "-0.51434314"
        ],
        [
         "35",
         "1.0",
         "1.4841309",
         "0.0",
         "-0.28345394",
         "0.0",
         "-2.2149372",
         "1.0",
         "-0.6994686",
         "0.65243304",
         "0.016221285",
         "-1.9627995",
         "-0.30269802",
         "-0.42366624",
         "0.917645",
         "0.33126283",
         "0.4102564"
        ],
        [
         "36",
         "1.0",
         "2.5984879",
         "0.0",
         "-0.21037102",
         "1.0",
         "1.4628305",
         "0.0",
         "0.21147728",
         "0.4688158",
         "0.116663456",
         "3.212581",
         "0.31710553",
         "0.76812625",
         "-0.13168907",
         "0.18253434",
         "-0.068448305"
        ],
        [
         "37",
         "1.0",
         "3.4234123",
         "0.0",
         "-0.21338749",
         "0.0",
         "-1.7083507",
         "1.0",
         "-0.40651608",
         "-1.8474011",
         "0.32478732",
         "-1.2827768",
         "-0.3670355",
         "-0.5601407",
         "0.5514443",
         "0.22007611",
         "0.19916129"
        ],
        [
         "38",
         "1.0",
         "1.1702452",
         "1.0",
         "-0.21907997",
         "1.0",
         "7.226799",
         "0.0",
         "0.100087166",
         "1.9406099",
         "-1.1122701",
         "6.5056014",
         "1.5717987",
         "2.7092133",
         "0.26174688",
         "0.091358185",
         "0.07219541"
        ],
        [
         "39",
         "1.0",
         "1.947628",
         "0.0",
         "-2.8513846",
         "0.0",
         "-1.0899444",
         "1.0",
         "-0.4940939",
         "1.3148229",
         "0.1989317",
         "-3.6261537",
         "-1.1161356",
         "-0.6500262",
         "-2.5440378",
         "-0.54020953",
         "-0.71543115"
        ],
        [
         "40",
         "1.0",
         "0.16145563",
         "1.0",
         "-0.3233652",
         "1.0",
         "1.2181435",
         "0.0",
         "1.1365719",
         "-0.057422638",
         "-0.10511791",
         "1.1823039",
         "0.4235048",
         "0.37002647",
         "0.14688349",
         "0.37541544",
         "0.26440793"
        ],
        [
         "41",
         "1.0",
         "-0.7611332",
         "0.0",
         "-2.2865934",
         "0.0",
         "0.10745144",
         "1.0",
         "0.32702637",
         "-1.5748513",
         "-0.90599185",
         "-0.66677445",
         "0.029883027",
         "-0.057782292",
         "0.3310995",
         "0.2760108",
         "0.07834208"
        ],
        [
         "42",
         "0.0",
         "-4.3392067",
         "0.0",
         "-10.097108",
         "1.0",
         "-1.4247818",
         "1.0",
         "6.007533",
         "-1.7558641",
         "-1.8943936",
         "-0.77063316",
         "-0.3671614",
         "-0.30222464",
         "14.326752",
         "4.0043726",
         "4.9605227"
        ],
        [
         "43",
         "1.0",
         "3.3046422",
         "1.0",
         "4.133471",
         "0.0",
         "-0.19864559",
         "0.0",
         "-2.7844334",
         "3.2014854",
         "1.2627867",
         "-0.7217531",
         "-0.45270872",
         "-0.39071584",
         "-1.3136144",
         "1.6704736",
         "-0.5821512"
        ],
        [
         "44",
         "1.0",
         "2.3609161",
         "1.0",
         "-3.1695595",
         "1.0",
         "0.24363995",
         "0.0",
         "-2.7831235",
         "-2.5590563",
         "-2.7396839",
         "-0.65134",
         "-0.2142111",
         "-0.08684218",
         "-5.0696883",
         "-1.700749",
         "-1.2392976"
        ],
        [
         "45",
         "0.0",
         "-0.85719013",
         "1.0",
         "0.07346916",
         "0.0",
         "-2.0701609",
         "1.0",
         "4.7937384",
         "-0.07866341",
         "0.99881554",
         "-5.8477483",
         "-1.0794389",
         "-2.37154",
         "2.8407013",
         "0.9571092",
         "1.5596731"
        ],
        [
         "46",
         "0.0",
         "-0.24299622",
         "0.0",
         "1.9328647",
         "1.0",
         "0.93519974",
         "0.0",
         "-1.6065197",
         "-0.29967892",
         "0.5823996",
         "2.4360168",
         "0.6212021",
         "0.6968285",
         "0.3773451",
         "0.19445944",
         "0.23429394"
        ],
        [
         "47",
         "1.0",
         "0.2428217",
         "1.0",
         "3.5962183",
         "1.0",
         "2.7830758",
         "0.0",
         "-0.7164831",
         "-0.12619162",
         "0.95758235",
         "2.4170127",
         "0.5930512",
         "1.0029128",
         "0.17512298",
         "-0.20719242",
         "-0.2661456"
        ],
        [
         "48",
         "0.0",
         "-0.68826675",
         "0.0",
         "-3.5915334",
         "1.0",
         "0.5340295",
         "0.0",
         "2.079257",
         "0.34457326",
         "-0.44588017",
         "0.069150925",
         "-0.14824384",
         "-0.20837224",
         "3.6779208",
         "0.9542108",
         "0.79652894"
        ],
        [
         "49",
         "1.0",
         "0.88454247",
         "1.0",
         "2.669846",
         "0.0",
         "-2.8970728",
         "1.0",
         "0.43360996",
         "0.5019988",
         "0.5021172",
         "-6.9486046",
         "-2.8761978",
         "-1.804743",
         "0.63981676",
         "-0.4525468",
         "-0.38865995"
        ]
       ],
       "shape": {
        "columns": 16,
        "rows": 2400
       }
      },
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>y1</th>\n",
       "      <th>z1</th>\n",
       "      <th>y2</th>\n",
       "      <th>z2</th>\n",
       "      <th>y3</th>\n",
       "      <th>z3</th>\n",
       "      <th>y4</th>\n",
       "      <th>z4</th>\n",
       "      <th>z1_hat</th>\n",
       "      <th>z2_hat</th>\n",
       "      <th>z1_3_hat</th>\n",
       "      <th>z2_3_hat</th>\n",
       "      <th>z3_hat</th>\n",
       "      <th>z1_4_hat</th>\n",
       "      <th>z2_4_hat</th>\n",
       "      <th>z3_4_hat</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.993725</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.650961</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-7.482014</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.829170</td>\n",
       "      <td>0.854947</td>\n",
       "      <td>0.896104</td>\n",
       "      <td>-2.272234</td>\n",
       "      <td>-0.808525</td>\n",
       "      <td>-1.081151</td>\n",
       "      <td>3.227813</td>\n",
       "      <td>0.616273</td>\n",
       "      <td>0.924169</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>-8.099810</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.672874</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.868409</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.380197</td>\n",
       "      <td>-12.742681</td>\n",
       "      <td>-0.601011</td>\n",
       "      <td>3.536963</td>\n",
       "      <td>0.985788</td>\n",
       "      <td>0.822551</td>\n",
       "      <td>1.484797</td>\n",
       "      <td>0.296021</td>\n",
       "      <td>0.256637</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.883021</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.808136</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-8.453954</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-1.589876</td>\n",
       "      <td>0.570661</td>\n",
       "      <td>0.524539</td>\n",
       "      <td>-6.856240</td>\n",
       "      <td>-0.423434</td>\n",
       "      <td>-2.165546</td>\n",
       "      <td>-1.037660</td>\n",
       "      <td>-0.538324</td>\n",
       "      <td>-0.663768</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.268304</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.080891</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-1.853543</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.291183</td>\n",
       "      <td>1.534075</td>\n",
       "      <td>0.053582</td>\n",
       "      <td>-4.933869</td>\n",
       "      <td>-1.333843</td>\n",
       "      <td>-1.276541</td>\n",
       "      <td>2.028179</td>\n",
       "      <td>0.410264</td>\n",
       "      <td>1.231262</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1.306501</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.164000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-5.328679</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-1.147705</td>\n",
       "      <td>1.337419</td>\n",
       "      <td>-0.251173</td>\n",
       "      <td>-5.805994</td>\n",
       "      <td>-0.043351</td>\n",
       "      <td>-1.974038</td>\n",
       "      <td>-3.467212</td>\n",
       "      <td>-1.161340</td>\n",
       "      <td>-1.092235</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2395</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.493604</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-3.196142</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.354650</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.370169</td>\n",
       "      <td>-2.290564</td>\n",
       "      <td>-0.320767</td>\n",
       "      <td>5.206455</td>\n",
       "      <td>0.912170</td>\n",
       "      <td>1.605893</td>\n",
       "      <td>0.733152</td>\n",
       "      <td>0.154391</td>\n",
       "      <td>-0.054164</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2396</th>\n",
       "      <td>0.0</td>\n",
       "      <td>-2.971936</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.142769</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-1.917747</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-2.073215</td>\n",
       "      <td>-1.974794</td>\n",
       "      <td>-0.422325</td>\n",
       "      <td>-3.050104</td>\n",
       "      <td>-0.992282</td>\n",
       "      <td>-0.816071</td>\n",
       "      <td>-1.472965</td>\n",
       "      <td>-0.352178</td>\n",
       "      <td>-0.434663</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2397</th>\n",
       "      <td>0.0</td>\n",
       "      <td>-1.961956</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.725107</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.936038</td>\n",
       "      <td>1.0</td>\n",
       "      <td>10.331485</td>\n",
       "      <td>-2.341723</td>\n",
       "      <td>-1.120778</td>\n",
       "      <td>-3.272769</td>\n",
       "      <td>-1.077912</td>\n",
       "      <td>-1.205044</td>\n",
       "      <td>12.717158</td>\n",
       "      <td>2.804857</td>\n",
       "      <td>4.902759</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2398</th>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.121234</td>\n",
       "      <td>1.0</td>\n",
       "      <td>6.311640</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-2.089078</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.487607</td>\n",
       "      <td>-0.643316</td>\n",
       "      <td>0.990444</td>\n",
       "      <td>-1.397153</td>\n",
       "      <td>-0.339122</td>\n",
       "      <td>-0.228012</td>\n",
       "      <td>-0.641400</td>\n",
       "      <td>-0.425361</td>\n",
       "      <td>-0.129966</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2399</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.030984</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.112260</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.001669</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.299991</td>\n",
       "      <td>1.433333</td>\n",
       "      <td>-0.520961</td>\n",
       "      <td>-1.295581</td>\n",
       "      <td>-0.331267</td>\n",
       "      <td>-0.202214</td>\n",
       "      <td>2.229199</td>\n",
       "      <td>1.080698</td>\n",
       "      <td>0.816860</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2400 rows × 16 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       y1        z1   y2        z2   y3        z3   y4         z4     z1_hat  \\\n",
       "0     0.0  1.993725  0.0 -0.650961  0.0 -7.482014  1.0   1.829170   0.854947   \n",
       "1     0.0 -8.099810  0.0  0.672874  1.0  0.868409  1.0   2.380197 -12.742681   \n",
       "2     1.0  0.883021  1.0  0.808136  0.0 -8.453954  1.0  -1.589876   0.570661   \n",
       "3     1.0  0.268304  1.0  2.080891  0.0 -1.853543  1.0   2.291183   1.534075   \n",
       "4     1.0  1.306501  0.0  0.164000  0.0 -5.328679  0.0  -1.147705   1.337419   \n",
       "...   ...       ...  ...       ...  ...       ...  ...        ...        ...   \n",
       "2395  0.0  0.493604  0.0 -3.196142  1.0  1.354650  0.0   0.370169  -2.290564   \n",
       "2396  0.0 -2.971936  0.0 -0.142769  0.0 -1.917747  0.0  -2.073215  -1.974794   \n",
       "2397  0.0 -1.961956  1.0 -0.725107  0.0 -0.936038  1.0  10.331485  -2.341723   \n",
       "2398  0.0 -0.121234  1.0  6.311640  0.0 -2.089078  0.0  -0.487607  -0.643316   \n",
       "2399  0.0  0.030984  0.0 -0.112260  1.0  0.001669  0.0   0.299991   1.433333   \n",
       "\n",
       "        z2_hat  z1_3_hat  z2_3_hat    z3_hat   z1_4_hat  z2_4_hat  z3_4_hat  \n",
       "0     0.896104 -2.272234 -0.808525 -1.081151   3.227813  0.616273  0.924169  \n",
       "1    -0.601011  3.536963  0.985788  0.822551   1.484797  0.296021  0.256637  \n",
       "2     0.524539 -6.856240 -0.423434 -2.165546  -1.037660 -0.538324 -0.663768  \n",
       "3     0.053582 -4.933869 -1.333843 -1.276541   2.028179  0.410264  1.231262  \n",
       "4    -0.251173 -5.805994 -0.043351 -1.974038  -3.467212 -1.161340 -1.092235  \n",
       "...        ...       ...       ...       ...        ...       ...       ...  \n",
       "2395 -0.320767  5.206455  0.912170  1.605893   0.733152  0.154391 -0.054164  \n",
       "2396 -0.422325 -3.050104 -0.992282 -0.816071  -1.472965 -0.352178 -0.434663  \n",
       "2397 -1.120778 -3.272769 -1.077912 -1.205044  12.717158  2.804857  4.902759  \n",
       "2398  0.990444 -1.397153 -0.339122 -0.228012  -0.641400 -0.425361 -0.129966  \n",
       "2399 -0.520961 -1.295581 -0.331267 -0.202214   2.229199  1.080698  0.816860  \n",
       "\n",
       "[2400 rows x 16 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the loaded frame (pandas truncates the view to the first and last rows).\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d5ff422c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base quantile level used to calibrate the band half-width q.\n",
    "COVERAGE = 0.9\n",
    "\n",
    "z1_hat = df[\"z1_hat\"]\n",
    "z2_hat = df[\"z2_hat\"]\n",
    "\n",
    "# Score: half the difference of columns z1_3_hat and z2_3_hat.\n",
    "v = 0.5 * (df[\"z1_3_hat\"] - df[\"z2_3_hat\"])\n",
    "# Point estimate: average of columns z1_4_hat and z2_4_hat.\n",
    "z_hat = 0.5 * (df[\"z1_4_hat\"] + df[\"z2_4_hat\"])\n",
    "\n",
    "# The level is inflated by (1 + 1/n); presumably the split-conformal\n",
    "# finite-sample correction -- NOTE(review): confirm. np.quantile rejects\n",
    "# levels above 1, which the raw product exceeds for very small n, so the\n",
    "# level is clipped at 1.0 (a no-op for the current 2400-row frame).\n",
    "n_scores = len(df)\n",
    "q_level = min(COVERAGE * (1 + n_scores ** -1), 1.0)\n",
    "# NOTE(review): np.quantile interpolates linearly by default; pass\n",
    "# method=\"higher\" if a conservative empirical quantile is intended.\n",
    "q = np.quantile(np.abs(v), q_level)\n",
    "\n",
    "# Symmetric band of half-width q around the point estimate.\n",
    "lcb = z_hat - q\n",
    "ucb = z_hat + q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "55ecc3b8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.int64(327)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count the rows whose lower bound lies strictly above zero.\n",
    "lcb.gt(0).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a21cb606",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.int64(335)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count the rows whose upper bound lies strictly below zero.\n",
    "ucb.lt(0).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2e77a236",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(0.9145042839752527)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy.special import expit as sigmoid\n",
    "sigmoid(q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d2172875",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the band centre and both band edges through sigmoid, one output column each.\n",
    "logit_columns = {\"p\": z_hat, \"lcb\": lcb, \"ucb\": ucb}\n",
    "df_result = pd.DataFrame({label: sigmoid(values) for label, values in logit_columns.items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "93cdfca0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "p      0.772318\n",
       "lcb    0.240769\n",
       "ucb    0.973178\n",
       "Name: 207, dtype: float64"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_result.iloc[207]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "6dbd6da7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "p      0.937128\n",
       "lcb    0.582199\n",
       "ucb    0.993767\n",
       "Name: 283, dtype: float64"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_result.iloc[283]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "0fb11cbd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/shirakawatoru/research/repos/lcp/.venv/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Generating preference split: 100%|██████████| 9691/9691 [00:00<00:00, 651895.69 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "ds = load_dataset(\"tatsu-lab/alpaca_farm\", \"alpaca_human_preference\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "97a3ee17",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9691"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(ds['preference'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "efaae38f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'instruction': 'Produce a plan for teaching basic philosophy.',\n",
       " 'input': 'Class: Freshman college course',\n",
       " 'output_1': 'A plan for teaching basic philosophy could include introducing students to the major branches of philosophy, such as ethics, epistemology, metaphysics, and political philosophy. Students should be given time to explore the history of philosophy and learn about the works of key philosophers. Additionally, students should be exposed to different philosophical schools of thought and concepts. The plan should also include discussions about contemporary issues in philosophy, such as the role of reason in moral decision-making, the tension between individualism and collectivism, and the impact of technology on society.',\n",
       " 'output_2': 'Philosophy for Freshmen: A Basic Introduction\\n\\nThis course is designed to introduce students to the fundamental concepts and issues in philosophy. The class will consist of ten weeks of lectures, discussions and essay assignments, and will cover topics including epistemology, metaphysics, ethics, and the philosophy of religion.\\n\\nLecture 1: Introduction to Philosophy \\nLecture 2: Epistemology\\nLecture 3: Metaphysics\\nLecture 4: Ethics\\nLecture 5: Deontological Ethics \\nLecture 6: Utilitarian Ethics \\nLecture 7: The Philosophy of Religion\\nLecture 8: Legal Causal \\nLecture 9: Existentialism \\nLecture 10: Conclusion.',\n",
       " 'preference': 2,\n",
       " 'raw_preference': 4}"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example = ds['preference'][2400 * 3 + 391]\n",
    "example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "d59de70c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'instruction': 'Summarize the benefits of a plant-based diet in one sentence.',\n",
       " 'input': '',\n",
       " 'output_1': 'A plant-based diet is beneficial for both the planet and our health, as it reduces meat, dairy and animal product consumption, decreasing environmental impact and risk of chronic diseases.',\n",
       " 'output_2': 'A plant-based diet provides numerous health benefits, including lowering risk of disease, providing essential nutrients, and being more sustainable.',\n",
       " 'preference': 1,\n",
       " 'raw_preference': 1}"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example = ds['preference'][2400 * 3 + 283]\n",
    "example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6768b720",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
