{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import torch \n",
    "import random\n",
    "import traceback\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn.functional as F\n",
    "from typing import List, Tuple\n",
    "from sklearn.metrics import f1_score\n",
    "from utils import  judge_multi_choice\n",
    "from itertools import cycle\n",
    "from confseq.betting import betting_cs\n",
    "from confseq.predmix import  predmix_empbern_twosided_cs\n",
    "from confseq.boundaries import gamma_exponential_mixture_bound\n",
    "from confseq.conjmix_bounded import conjmix_empbern_twosided_cs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams['font.family'] = 'Times New Roman'\n",
    "def set_seed(seed: int):\n",
    "    \"\"\"Seed every random number generator used in this notebook.\n",
    "\n",
    "    Applies the same seed to Python's ``random`` module, NumPy's global\n",
    "    generator, and PyTorch (the default generator plus the CUDA generators).\n",
    "\n",
    "    Args:\n",
    "        seed: Integer seed handed to each generator.\n",
    "    \"\"\"\n",
    "    # `random`, `np` and `torch` are imported in the notebook's first cell, so\n",
    "    # the function-local re-imports (and the unused `os` import) were dropped.\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clean_responses(data):\n",
    "    \"\"\"Reduce raw model outputs to single-letter multiple-choice answers.\n",
    "\n",
    "    For every item the option letters (``A`` onwards, at most ``H``, one per\n",
    "    entry of ``item[\"choices\"]``) are tried from the last letter to the first.\n",
    "    The first letter for which ``judge_multi_choice`` returns a truthy value\n",
    "    becomes the parsed response; if none matches, ``\"Z\"`` is recorded.\n",
    "\n",
    "    Args:\n",
    "        data: Iterable of dicts with the keys ``\"answer\"``, ``\"response\"``,\n",
    "            ``\"choices\"`` and ``\"correct\"``.\n",
    "\n",
    "    Returns:\n",
    "        A list of dicts with the keys ``\"answer\"`` (stripped, upper-cased),\n",
    "        ``\"response\"`` (parsed letter or ``\"Z\"``) and ``\"correct\"`` (copied\n",
    "        through unchanged).\n",
    "    \"\"\"\n",
    "    cleaned_data = []\n",
    "    full_alphas = [chr(ord('A') + i) for i in range(8)]  # A-H\n",
    "\n",
    "    for item in data:\n",
    "        answer = item[\"answer\"].strip().upper()\n",
    "        response_raw = item[\"response\"]\n",
    "        choices = item[\"choices\"]\n",
    "        correct = item[\"correct\"]\n",
    "\n",
    "        # Candidate letters, scanned in reverse order (last option first).\n",
    "        alphas = full_alphas[:len(choices)][::-1]\n",
    "\n",
    "        # \"Z\" is the sentinel kept when no option letter can be matched.\n",
    "        parsed_response = \"Z\"\n",
    "        for alpha in alphas:\n",
    "            try:\n",
    "                if judge_multi_choice(choices, alpha, response_raw, alphas=None):\n",
    "                    parsed_response = alpha\n",
    "                    break\n",
    "            except Exception as e:\n",
    "                # Best effort: a failure on one letter must not abort the item.\n",
    "                print(f\"[Warning] Error processing question: {e}\")\n",
    "                continue\n",
    "\n",
    "        cleaned_data.append({\n",
    "            \"answer\": answer,\n",
    "            \"response\": parsed_response,\n",
    "            \"correct\": correct\n",
    "        })\n",
    "    return cleaned_data\n",
    "\n",
    "def load_cleaned_data(model_name, dataset_name):\n",
    "    \"\"\"Load the cleaned results of one model on one dataset.\n",
    "\n",
    "    Reads ``cleaned_data/<model_name>/<dataset_name>/cleaned_results.json``\n",
    "    (relative to the current working directory) as UTF-8 JSON.\n",
    "\n",
    "    Args:\n",
    "        model_name: Directory name of the model under ``cleaned_data``.\n",
    "        dataset_name: Directory name of the dataset under the model directory.\n",
    "\n",
    "    Returns:\n",
    "        The parsed JSON content of the file.\n",
    "    \"\"\"\n",
    "    json_path = f\"cleaned_data/{model_name}/{dataset_name}/cleaned_results.json\"\n",
    "    with open(json_path, \"r\", encoding=\"utf-8\") as handle:\n",
    "        payload = json.load(handle)\n",
    "    return payload\n",
    "\n",
    "def compute_accuracy(data, labels=None):\n",
    "    \"\"\"Fraction of items in ``data`` that count as correct.\n",
    "\n",
    "    Args:\n",
    "        data: Sequence of dicts. Without ``labels`` each dict's ``'correct'``\n",
    "            entry is used; with ``labels`` each dict's ``'response'`` is\n",
    "            compared with the label at the same position.\n",
    "        labels: Optional sequence of reference labels.\n",
    "\n",
    "    Returns:\n",
    "        Number of correct items divided by ``len(data)``.\n",
    "    \"\"\"\n",
    "    if labels is not None:\n",
    "        hits = sum(1 for d, lbl in zip(data, labels) if d['response'] == lbl)\n",
    "    else:\n",
    "        hits = sum(1 for d in data if d['correct'])\n",
    "    return hits / len(data)\n",
    "\n",
    "def compute_accuracy_agent(data, labels=None, true_rate=1.0):\n",
    "    \"\"\"Accuracy of ``data`` measured against randomly corrupted labels.\n",
    "\n",
    "    Each label is kept with probability ``true_rate``; otherwise it is replaced\n",
    "    by a different label drawn uniformly from the distinct values in ``labels``\n",
    "    (or kept when no other value exists). Randomness comes from the global\n",
    "    ``random`` module.\n",
    "\n",
    "    Args:\n",
    "        data: Sequence of dicts with a ``'response'`` entry.\n",
    "        labels: Reference labels, one per item of ``data``.\n",
    "        true_rate: Probability of keeping each label unchanged.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of items whose ``'response'`` equals the simulated label.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``labels`` is None.\n",
    "    \"\"\"\n",
    "    if labels is None:\n",
    "        raise ValueError(\"Labels must be provided for agent accuracy computation.\")\n",
    "\n",
    "    labels = list(labels)\n",
    "    # Distinct labels in first-seen order, built once. Iterating a set of\n",
    "    # strings follows hash order, which varies between interpreter runs, so\n",
    "    # random.choice over it was not reproducible even with a fixed seed.\n",
    "    label_pool = list(dict.fromkeys(labels))\n",
    "\n",
    "    simulated_labels = []\n",
    "    for lbl in labels:\n",
    "        if random.random() < true_rate:\n",
    "            simulated_labels.append(lbl)\n",
    "        else:\n",
    "            wrong_choices = [l for l in label_pool if l != lbl]\n",
    "            simulated_labels.append(random.choice(wrong_choices) if wrong_choices else lbl)\n",
    "\n",
    "    correct_count = sum(1 for d, lbl in zip(data, simulated_labels) if d['response'] == lbl)\n",
    "    return correct_count / len(data)\n",
    "\n",
    "def Generate_Fake_labels(labels=None, true_rate=1.0):\n",
    "    \"\"\"Return a randomly corrupted copy of ``labels``.\n",
    "\n",
    "    Each label is kept with probability ``true_rate``; otherwise it is replaced\n",
    "    by a different label drawn uniformly from the distinct values in ``labels``\n",
    "    (or kept when no other value exists). Randomness comes from the global\n",
    "    ``random`` module.\n",
    "\n",
    "    Args:\n",
    "        labels: Sequence of labels to corrupt.\n",
    "        true_rate: Probability of keeping each label unchanged.\n",
    "\n",
    "    Returns:\n",
    "        A new list with one simulated label per input label.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``labels`` is None.\n",
    "    \"\"\"\n",
    "    if labels is None:\n",
    "        raise ValueError(\"Labels must be provided to generate fake labels.\")\n",
    "\n",
    "    labels = list(labels)\n",
    "    # Distinct labels in first-seen order, built once. Iterating a set of\n",
    "    # strings follows hash order, which varies between interpreter runs, so\n",
    "    # random.choice over it was not reproducible even with a fixed seed.\n",
    "    label_pool = list(dict.fromkeys(labels))\n",
    "\n",
    "    simulated_labels = []\n",
    "    for lbl in labels:\n",
    "        if random.random() < true_rate:\n",
    "            simulated_labels.append(lbl)\n",
    "        else:\n",
    "            wrong_choices = [l for l in label_pool if l != lbl]\n",
    "            simulated_labels.append(random.choice(wrong_choices) if wrong_choices else lbl)\n",
    "    return simulated_labels\n",
    "\n",
    "\n",
    "def compute_v_opt(x, t_opt):\n",
    "    \"\"\"Variance-process value at index ``t_opt``, multiplied by ``t_opt``.\n",
    "\n",
    "    The variance process is the cumulative sum of squared deviations of each\n",
    "    observation from the mean of the observations before it (1/2 is used as\n",
    "    the predicted mean for the first observation).\n",
    "\n",
    "    Args:\n",
    "        x: Sequence of observations.\n",
    "        t_opt: Integer index into the variance process.\n",
    "\n",
    "    Returns:\n",
    "        ``V[t_opt] * t_opt`` where ``V`` is the variance process.\n",
    "    \"\"\"\n",
    "    obs = np.array(x)\n",
    "    counts = np.arange(1, len(obs) + 1)\n",
    "    running_mean = np.cumsum(obs) / counts\n",
    "    # Mean of the strictly earlier observations: shift by one, seed with 1/2.\n",
    "    lagged_mean = np.append(1 / 2, running_mean[:-1])\n",
    "    variance_process = np.cumsum((obs - lagged_mean) ** 2)\n",
    "    # NOTE(review): variance_process is already cumulative, so the extra factor\n",
    "    # t_opt makes the result grow like t_opt**2 - confirm this is intended.\n",
    "    return variance_process[t_opt] * t_opt\n",
    "    \n",
    "def running_average_cumulative(x):\n",
    "    \"\"\"Running mean of ``x``: entry ``t`` is the mean of ``x[0..t]``.\"\"\"\n",
    "    prefix_sums = np.cumsum(x)\n",
    "    counts = np.arange(1, len(x) + 1)\n",
    "    return prefix_sums / counts\n",
    "\n",
    "def shuffle_by_severity_multi(trajs, severities):\n",
    "    \"\"\"Shuffle several aligned trajectories segment by segment.\n",
    "\n",
    "    Every trajectory is cut into ``len(severities)`` consecutive segments of\n",
    "    equal length (``len(trajs[0]) // len(severities)``; any remainder at the\n",
    "    end is dropped). Inside each segment one random permutation, drawn from\n",
    "    NumPy's global generator, is applied to all trajectories alike.\n",
    "\n",
    "    Args:\n",
    "        trajs: Sequence of arrays supporting integer-array indexing.\n",
    "        severities: Sequence whose length sets the number of segments.\n",
    "\n",
    "    Returns:\n",
    "        A list with one concatenated, shuffled array per input trajectory.\n",
    "    \"\"\"\n",
    "    pieces = [[] for _ in trajs]\n",
    "    segment_len = len(trajs[0]) // len(severities)\n",
    "    for segment_no in range(len(severities)):\n",
    "        lo = segment_no * segment_len\n",
    "        order = np.arange(segment_len)\n",
    "        np.random.shuffle(order)\n",
    "        # The same permutation is reused so the trajectories stay aligned.\n",
    "        for bucket, traj in zip(pieces, trajs):\n",
    "            bucket.append(traj[lo:lo + segment_len][order])\n",
    "\n",
    "    return [np.concatenate(parts) for parts in pieces]\n",
    "\n",
    "def split_into_batches_list(model_data, agent_data, batch_size):\n",
    "    \"\"\"Shuffle two aligned lists together and cut them into equal batches.\n",
    "\n",
    "    Only the first ``len(model_data) // batch_size * batch_size`` positions\n",
    "    are used; they are shuffled with the global ``random`` module and the same\n",
    "    order is applied to both lists.\n",
    "\n",
    "    Args:\n",
    "        model_data: Indexable sequence of model records.\n",
    "        agent_data: Indexable sequence indexed with the same positions as\n",
    "            ``model_data``.\n",
    "        batch_size: Number of items per batch.\n",
    "\n",
    "    Returns:\n",
    "        ``(batches, agent_batches)``: two lists of lists with matching layout.\n",
    "    \"\"\"\n",
    "    usable = len(model_data) // batch_size * batch_size\n",
    "    order = list(range(usable))\n",
    "    random.shuffle(order)\n",
    "\n",
    "    batches = []\n",
    "    agent_batches = []\n",
    "    for start in range(0, usable, batch_size):\n",
    "        chunk = order[start:start + batch_size]\n",
    "        batches.append([model_data[i] for i in chunk])\n",
    "        agent_batches.append([agent_data[i] for i in chunk])\n",
    "    return batches, agent_batches\n",
    "\n",
    "def compute_bounds_cmeb(traj_sup_true, traj_unsup_proxy, L1_proxy, alpha, t_opt_ratio):\n",
    "    \"\"\"CM-EB bounds for the supervised and the proxy (unsupervised) trajectory.\n",
    "\n",
    "    For each trajectory the first sequence returned by\n",
    "    ``conjmix_empbern_twosided_cs`` is kept. The proxy bound is additionally\n",
    "    shifted down by ``L1_proxy`` and floored at zero.\n",
    "\n",
    "    Args:\n",
    "        traj_sup_true: Supervised trajectory.\n",
    "        traj_unsup_proxy: Proxy-labelled trajectory.\n",
    "        L1_proxy: Offset subtracted from the proxy bound; objects exposing\n",
    "            ``.numpy()`` are converted through it, anything else via ``np.array``.\n",
    "        alpha: Error level; ``alpha * 2`` is what gets passed to the two-sided\n",
    "            routine (presumably so the kept side alone has level ``alpha`` -\n",
    "            TODO confirm).\n",
    "        t_opt_ratio: Fraction of ``len(traj_sup_true)`` used as the index handed\n",
    "            to ``compute_v_opt``.\n",
    "\n",
    "    Returns:\n",
    "        ``(supervised_bounds, unsupervised_bounds)``, each a dict with the single\n",
    "        key ``\"CM-EB\"``.\n",
    "    \"\"\"\n",
    "    if hasattr(L1_proxy, \"numpy\"):\n",
    "        L1 = L1_proxy.numpy()\n",
    "    else:\n",
    "        L1 = np.array(L1_proxy)\n",
    "\n",
    "    T_OPT = int(len(traj_sup_true) * t_opt_ratio)\n",
    "    level = alpha * 2\n",
    "\n",
    "    # supervised\n",
    "    sup_lower, _ = conjmix_empbern_twosided_cs(\n",
    "        x=traj_sup_true, alpha=level, v_opt=compute_v_opt(traj_sup_true, T_OPT))\n",
    "\n",
    "    # unsupervised: shift by the proxy offset, never below zero\n",
    "    unsup_lower, _ = conjmix_empbern_twosided_cs(\n",
    "        x=traj_unsup_proxy, alpha=level, v_opt=compute_v_opt(traj_unsup_proxy, T_OPT))\n",
    "    unsup_lower = np.maximum(unsup_lower - L1, np.zeros_like(L1))\n",
    "\n",
    "    return {\"CM-EB\": sup_lower}, {\"CM-EB\": unsup_lower}\n",
    "\n",
    "def compute_bounds_pmeb_betting(traj_sup_true, traj_unsup_proxy, L1_proxy, alpha):\n",
    "    \"\"\"PM-EB and Betting bounds for the supervised and the proxy trajectory.\n",
    "\n",
    "    For each method the first sequence returned by the confidence-sequence\n",
    "    routine is kept. The proxy bounds are additionally shifted down by\n",
    "    ``L1_proxy`` and floored at zero.\n",
    "\n",
    "    Args:\n",
    "        traj_sup_true: Supervised trajectory.\n",
    "        traj_unsup_proxy: Proxy-labelled trajectory.\n",
    "        L1_proxy: Offset subtracted from the proxy bounds; objects exposing\n",
    "            ``.numpy()`` are converted through it, anything else via ``np.array``.\n",
    "        alpha: Error level passed unchanged to both routines.\n",
    "\n",
    "    Returns:\n",
    "        ``(supervised_bounds, unsupervised_bounds)``, each a dict with the keys\n",
    "        ``\"PM-EB\"`` and ``\"Betting\"``.\n",
    "    \"\"\"\n",
    "    if hasattr(L1_proxy, \"numpy\"):\n",
    "        L1 = L1_proxy.numpy()\n",
    "    else:\n",
    "        L1 = np.array(L1_proxy)\n",
    "\n",
    "    # ---- predmix (PM-EB) ----\n",
    "    sup_predmix, _ = predmix_empbern_twosided_cs(x=traj_sup_true, alpha=alpha)\n",
    "    unsup_predmix, _ = predmix_empbern_twosided_cs(x=traj_unsup_proxy, alpha=alpha)\n",
    "    unsup_predmix = np.maximum(unsup_predmix - L1, np.zeros_like(L1))\n",
    "\n",
    "    # ---- betting ----\n",
    "    sup_betting, _ = betting_cs(x=traj_sup_true, alpha=alpha)\n",
    "    unsup_betting, _ = betting_cs(x=traj_unsup_proxy, alpha=alpha)\n",
    "    unsup_betting = np.maximum(unsup_betting - L1, np.zeros_like(unsup_betting))\n",
    "\n",
    "    supervised_bounds = {\"PM-EB\": sup_predmix, \"Betting\": sup_betting}\n",
    "    unsupervised_bounds = {\"PM-EB\": unsup_predmix, \"Betting\": unsup_betting}\n",
    "    return supervised_bounds, unsupervised_bounds\n",
    "\n",
    "\n",
    "def compute_bounds_cmeb_ppi(traj_ppi_any, alpha, t_opt_ratio, eta_max):\n",
    "    \"\"\"CM-EB bound for a trajectory with values in ``[-eta_max, 1 + eta_max]``.\n",
    "\n",
    "    The trajectory is rescaled affinely from ``[-eta_max, 1 + eta_max]`` to the\n",
    "    unit interval, the first sequence returned by\n",
    "    ``conjmix_empbern_twosided_cs`` is taken, and the result is mapped back to\n",
    "    the original scale.\n",
    "\n",
    "    Args:\n",
    "        traj_ppi_any: Array-like trajectory supporting array arithmetic.\n",
    "        alpha: Error level; ``alpha * 2`` is passed to the two-sided routine.\n",
    "        t_opt_ratio: Fraction of the trajectory length used as the index handed\n",
    "            to ``compute_v_opt``.\n",
    "        eta_max: Half-width by which the range extends beyond ``[0, 1]``.\n",
    "\n",
    "    Returns:\n",
    "        Dict with the single key ``\"CM-EB\"``.\n",
    "    \"\"\"\n",
    "    a = - eta_max\n",
    "    b = 1 + eta_max\n",
    "    scaled = (traj_ppi_any - a) / (b - a)\n",
    "\n",
    "    T_OPT = int(len(scaled) * t_opt_ratio)\n",
    "    lower_scaled, _ = conjmix_empbern_twosided_cs(\n",
    "        x=scaled, alpha=alpha * 2, v_opt=compute_v_opt(scaled, T_OPT))\n",
    "\n",
    "    # Undo the rescaling so the bound is on the trajectory's own scale.\n",
    "    return {\"CM-EB\": lower_scaled * (b - a) + a}\n",
    "\n",
    "def compute_bounds_pmeb_betting_ppi(traj_ppi_any, alpha, eta_max):\n",
    "    \"\"\"PM-EB and Betting bounds for a trajectory in ``[-eta_max, 1 + eta_max]``.\n",
    "\n",
    "    The trajectory is rescaled affinely to the unit interval, the first\n",
    "    sequence returned by each confidence-sequence routine is taken, and the\n",
    "    results are mapped back to the original scale.\n",
    "\n",
    "    Args:\n",
    "        traj_ppi_any: Array-like trajectory supporting array arithmetic.\n",
    "        alpha: Error level passed unchanged to both routines.\n",
    "        eta_max: Half-width by which the range extends beyond ``[0, 1]``.\n",
    "\n",
    "    Returns:\n",
    "        Dict with the keys ``\"PM-EB\"`` and ``\"Betting\"``.\n",
    "    \"\"\"\n",
    "    a = - eta_max\n",
    "    b = 1 + eta_max\n",
    "    scaled = (traj_ppi_any - a) / (b - a)\n",
    "\n",
    "    predmix_scaled, _ = predmix_empbern_twosided_cs(x=scaled, alpha=alpha)\n",
    "    predmix_lower = predmix_scaled * (b - a) + a\n",
    "\n",
    "    betting_scaled, _ = betting_cs(x=scaled, alpha=alpha)\n",
    "    betting_lower = betting_scaled * (b - a) + a\n",
    "\n",
    "    return {\"PM-EB\": predmix_lower, \"Betting\": betting_lower}\n",
    "\n",
    "def compute_accuracy(data, labels=None):\n",
    "    \"\"\"Fraction of items in ``data`` that count as correct.\n",
    "\n",
    "    Without ``labels`` the ``'correct'`` entry of each dict decides; with\n",
    "    ``labels`` the ``'response'`` entry is compared with the label at the same\n",
    "    position. The count is divided by ``len(data)``.\n",
    "    \"\"\"\n",
    "    # NOTE: this re-defines the compute_accuracy declared earlier in this cell\n",
    "    # with the same logic; being the later one, it is the definition left bound.\n",
    "    if labels is None:\n",
    "        flags = [bool(d['correct']) for d in data]\n",
    "    else:\n",
    "        flags = [bool(d['response'] == lbl) for d, lbl in zip(data, labels)]\n",
    "    return sum(flags) / len(data)\n",
    "\n",
    "def compute_01_loss_vector(data, labels):\n",
    "    \"\"\"Per-item 0-1 loss of the responses in ``data`` against ``labels``.\n",
    "\n",
    "    Args:\n",
    "        data: List of dicts, each with the key ``'response'``.\n",
    "        labels: List of reference labels, paired with ``data`` by position.\n",
    "\n",
    "    Returns:\n",
    "        ``np.ndarray`` of dtype float32 with one entry per pair: 0 where the\n",
    "        response equals the label, 1 otherwise.\n",
    "    \"\"\"\n",
    "    losses = []\n",
    "    for record, lbl in zip(data, labels):\n",
    "        losses.append(0 if record[\"response\"] == lbl else 1)\n",
    "    return np.array(losses, dtype=np.float32)\n",
    "\n",
    "\n",
    "def compute_eta_t(labeled_losses_hist, agent_labeled_losses_hist, agent_unlabeled_losses_hist, eta_max,  eps=1e-8):\n",
    "    \"\"\"Estimate the mixing weight eta from a history of per-batch losses.\n",
    "\n",
    "    Computes ``cov / ((1 + n_l / n_u) * var)`` where ``cov`` is the covariance\n",
    "    between the true and the agent losses on labeled data, ``var`` is the\n",
    "    variance of the agent losses on unlabeled data, and ``n_l`` / ``n_u`` are\n",
    "    the numbers of labeled / unlabeled losses. The result is clipped to\n",
    "    ``[0, eta_max]``.\n",
    "\n",
    "    Args:\n",
    "        labeled_losses_hist: List of 1-D arrays of true losses on labeled data.\n",
    "        agent_labeled_losses_hist: List of 1-D arrays of agent losses on the\n",
    "            same labeled data (same total length as ``labeled_losses_hist``).\n",
    "        agent_unlabeled_losses_hist: List of 1-D arrays of agent losses on\n",
    "            unlabeled data.\n",
    "        eta_max: Upper clipping value for eta.\n",
    "        eps: Variances below this value are treated as zero.\n",
    "\n",
    "    Returns:\n",
    "        The clipped eta as a Python float; ``0.0`` when any history is empty or\n",
    "        the unlabeled variance is below ``eps``.\n",
    "    \"\"\"\n",
    "    # np.concatenate raises ValueError on an empty sequence, so an empty\n",
    "    # history has to be caught before concatenating, not afterwards.\n",
    "    if (len(labeled_losses_hist) == 0\n",
    "            or len(agent_labeled_losses_hist) == 0\n",
    "            or len(agent_unlabeled_losses_hist) == 0):\n",
    "        return 0.0\n",
    "\n",
    "    u = np.concatenate(labeled_losses_hist, axis=0)\n",
    "    u_tilde_l = np.concatenate(agent_labeled_losses_hist, axis=0)\n",
    "    u_tilde_u = np.concatenate(agent_unlabeled_losses_hist, axis=0)\n",
    "    n_l = len(u)\n",
    "    n_u = len(u_tilde_u)\n",
    "\n",
    "    # Still needed: the histories may hold only zero-length arrays.\n",
    "    if n_l == 0 or n_u == 0:\n",
    "        return 0.0\n",
    "    u_bar = u.mean()\n",
    "    u_tilde_l_bar = u_tilde_l.mean()\n",
    "    u_tilde_u_bar = u_tilde_u.mean()\n",
    "    # Covariance (labeled data)\n",
    "    cov = np.mean((u - u_bar) * (u_tilde_l - u_tilde_l_bar))\n",
    "    # Variance (unlabeled data)\n",
    "    var = np.mean((u_tilde_u - u_tilde_u_bar) ** 2)\n",
    "    if var < eps:\n",
    "        return 0.0\n",
    "    eta = cov / ((1.0 + n_l / n_u) * var)\n",
    "    eta = np.clip(eta, 0.0, eta_max)\n",
    "    return float(eta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "domainB_all_model_data = []\n",
    "domainB_all_agent_data = []\n",
    "\n",
    "# Dataset names per model, grouped into two domains (A and B).\n",
    "dataset_list_domainA_dict = {}\n",
    "dataset_list_domainB_dict = {}\n",
    "dataset_list_domainA_dict[\"unsloth/medgemma-4b-it-bnb-4bit\"] = [\"MMLU-high_school_psychology\", \"MMLU-miscellaneous\",\"CommonsenseQA\"]\n",
    "dataset_list_domainB_dict[\"unsloth/medgemma-4b-it-bnb-4bit\"] = [\"MMLU-elementary_mathematics\", \"MMLU-abstract_algebra\", \"MMLU-college_chemistry\", \"MMLU-professional_accounting\",\"CMExam\"]\n",
    "\n",
    "# dataset_list_domainA_dict[\"Qwen/Qwen2-VL-2B-Instruct\"] = [\"MMLU-high_school_psychology\", \"MMLU-miscellaneous\",]\n",
    "# dataset_list_domainB_dict[\"Qwen/Qwen2-VL-2B-Instruct\"] = [\"MMLU-elementary_mathematics\", \"MMLU-abstract_algebra\", \"MMLU-college_chemistry\", \"MMLU-professional_accounting\",\"MMLU-professional_law\"]\n",
    "# dataset_list_domainB_dict[\"Qwen/Qwen2-VL-2B-Instruct\"] = [\"MMLU-professional_psychology\",\"CMExam\"]\n",
    "\n",
    "# NOTE(review): model_name and agent_model_name are read below but are not\n",
    "# assigned in the preceding cells - confirm they are defined before this runs.\n",
    "dataset_groups = [dataset_list_domainA_dict[model_name], dataset_list_domainB_dict[model_name]]\n",
    "domain_model_data_list = []\n",
    "domain_agent_data_list =[]\n",
    "for idx, datasetlist in enumerate(dataset_groups):\n",
    "    # Concatenate all datasets of one domain, for the model and for the agent.\n",
    "    domain_all_model_data = []\n",
    "    domain_all_agent_data = []\n",
    "    for dataset_name in datasetlist:\n",
    "        model_data = load_cleaned_data(model_name, dataset_name)\n",
    "        agent_data = load_cleaned_data(agent_model_name, dataset_name)\n",
    "        domain_all_model_data.extend(model_data)\n",
    "        domain_all_agent_data.extend(agent_data)\n",
    "    domain_model_data_list.append(domain_all_model_data)\n",
    "    domain_agent_data_list.append(domain_all_agent_data)\n",
    "    print(\"\\n=== Summary ===\")\n",
    "    print(\"Domain total model_data:\", len(domain_all_model_data))\n",
    "    # Report the agent list itself (the model count used to be printed twice).\n",
    "    print(\"Domain total agent_data:\", len(domain_all_agent_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- experiment configuration ----\n",
    "N_EXPERIMENTS = 20\n",
    "ALPHA = 0.2\n",
    "THRESH = 0.35\n",
    "T_OPT_RATIO = 0.2\n",
    "VERBOSE = True\n",
    "ETA_MAX =  1.0\n",
    "L1_proxy = torch.tensor(1.)\n",
    "WINDOW_SIZE = 60\n",
    "eta_fixed = 1.0\n",
    "num_batches = [200,400]  # batch cap per domain, in dataset_groups order\n",
    "labeled_batch_size = 1\n",
    "unlabeled_batch_size = 5\n",
    "set_seed(0)\n",
    "# NOTE(review): model_name and agent_model_name are read below but are not\n",
    "# assigned in the preceding cells - confirm they are defined before this runs.\n",
    "dataset_groups = [dataset_list_domainA_dict[model_name], dataset_list_domainB_dict[model_name]]\n",
    "domain_model_data_list = []\n",
    "domain_agent_data_list =[]\n",
    "\n",
    "for idx, datasetlist in enumerate(dataset_groups):\n",
    "    # Concatenate all datasets of one domain, for the model and for the agent.\n",
    "    domain_all_model_data = []\n",
    "    domain_all_agent_data = []\n",
    "    for dataset_name in datasetlist:\n",
    "        model_data = load_cleaned_data(model_name, dataset_name)\n",
    "        agent_data = load_cleaned_data(agent_model_name, dataset_name)\n",
    "        domain_all_model_data.extend(model_data)\n",
    "        domain_all_agent_data.extend(agent_data)\n",
    "    domain_model_data_list.append(domain_all_model_data)\n",
    "    domain_agent_data_list.append(domain_all_agent_data)\n",
    "    print(\"\\n=== Summary ===\")\n",
    "    print(\"Domain total model_data:\", len(domain_all_model_data))\n",
    "    # Report the agent list itself (the model count used to be printed twice).\n",
    "    print(\"Domain total agent_data:\", len(domain_all_agent_data))\n",
    "\n",
    "# ---- result containers ----\n",
    "methods = [\"CM-EB\"]\n",
    "# results: per-experiment bound curves; results_box: per-experiment alarm times.\n",
    "results = {m: {\"supervised\": [], \"unsupervised\": [], \"ppi\": [], \"ideal_ppi\": [], \"adaptive_ppi\": []} for m in methods}\n",
    "results_box = {m: {\"supervised\": [], \"unsupervised\": [], \"ppi\": [], \"ideal_ppi\": [], \"adaptive_ppi\": []} for m in methods}\n",
    "all_traj_list = []\n",
    "eta_seq_all_exp = []\n",
    "for exp_id in range(N_EXPERIMENTS):\n",
    "    # Per-batch accuracies, collected over all domains of this experiment.\n",
    "    labeled_batch_acc = []\n",
    "    unlabeled_batch_acc = []\n",
    "    labeled_batch_agent_acc = []\n",
    "    unlabeled_batch_agent_acc = []\n",
    "    eta_seq = []\n",
    "    # Sliding-window history of per-batch 0-1 loss vectors used to fit eta_t.\n",
    "    labeled_losses_hist = []\n",
    "    agent_labeled_losses_hist = []\n",
    "    agent_unlabeled_losses_hist = []\n",
    "    for idx in range(len(domain_model_data_list)):\n",
    "        model_data = domain_model_data_list[idx]\n",
    "        agent_data = domain_agent_data_list[idx]\n",
    "        batches, agent_batches = split_into_batches_list(\n",
    "            model_data, agent_data, batch_size=labeled_batch_size+unlabeled_batch_size)\n",
    "        # Use at most num_batches[idx] batches of this domain. The cap used to\n",
    "        # be tested only after a batch had been processed, which let\n",
    "        # num_batches[idx] + 1 batches through.\n",
    "        n_batches_used = min(len(batches), num_batches[idx])\n",
    "        for b_idx in range(n_batches_used):\n",
    "            batch = batches[b_idx]\n",
    "            agent_batch = agent_batches[b_idx]\n",
    "            # The first labeled_batch_size items act as labeled data, the\n",
    "            # remainder of the batch as unlabeled data.\n",
    "            l_batch = batch[:labeled_batch_size]\n",
    "            u_batch = batch[labeled_batch_size:]\n",
    "            agent_l_batch = agent_batch[:labeled_batch_size]\n",
    "            agent_u_batch = agent_batch[labeled_batch_size:]\n",
    "            true_labels_l = [d[\"answer\"] for d in l_batch]\n",
    "            true_labels_u = [d[\"answer\"] for d in u_batch]\n",
    "            agent_labels_l = [d[\"response\"] for d in agent_l_batch]\n",
    "            agent_labels_u = [d[\"response\"] for d in agent_u_batch]\n",
    "            labeled_batch_acc.append(compute_accuracy(l_batch, true_labels_l))\n",
    "            unlabeled_batch_acc.append(compute_accuracy(u_batch, true_labels_u))\n",
    "            labeled_batch_agent_acc.append(compute_accuracy(l_batch, labels=agent_labels_l))\n",
    "            unlabeled_batch_agent_acc.append(compute_accuracy(u_batch, labels=agent_labels_u))\n",
    "            # ---- 0-1 losses ----\n",
    "            u_t = compute_01_loss_vector(l_batch, true_labels_l)\n",
    "            u_tilde_l_t = compute_01_loss_vector(l_batch, agent_labels_l)\n",
    "            u_tilde_u_t = compute_01_loss_vector(u_batch, agent_labels_u)\n",
    "\n",
    "            # ---- compute eta_t using ONLY history ----\n",
    "            if len(labeled_losses_hist) > WINDOW_SIZE - 5:\n",
    "                eta_t = compute_eta_t(labeled_losses_hist, agent_labeled_losses_hist, agent_unlabeled_losses_hist, ETA_MAX)\n",
    "            else:\n",
    "                # Warm-up: too little history, fall back to the mid value.\n",
    "                eta_t = ETA_MAX / 2\n",
    "            eta_seq.append(eta_t)\n",
    "            # ---- update history ----\n",
    "            labeled_losses_hist.append(u_t)\n",
    "            agent_labeled_losses_hist.append(u_tilde_l_t)\n",
    "            agent_unlabeled_losses_hist.append(u_tilde_u_t)\n",
    "            # ---- sliding window ----\n",
    "            if len(labeled_losses_hist) > WINDOW_SIZE:\n",
    "                labeled_losses_hist.pop(0)\n",
    "                agent_labeled_losses_hist.pop(0)\n",
    "                agent_unlabeled_losses_hist.pop(0)\n",
    "\n",
    "    # One eta trajectory per experiment. This append used to sit inside the\n",
    "    # domain loop and stored the same list object once per domain.\n",
    "    eta_seq_all_exp.append(eta_seq)\n",
    "\n",
    "    # ---- risk trajectories ----\n",
    "    risk_traj_supervised_true = 1 - np.array(labeled_batch_acc)\n",
    "    risk_traj_supervised_pred = 1 - np.array(labeled_batch_agent_acc)\n",
    "    risk_traj_unsupervised_pred = 1 - np.array(unlabeled_batch_agent_acc)\n",
    "    risk_traj_unsupervised_true = 1 - np.array(unlabeled_batch_acc)\n",
    "\n",
    "    traj_sup_true = risk_traj_supervised_true\n",
    "    traj_ppi_ideal = risk_traj_unsupervised_true\n",
    "    traj_unsup_proxy = risk_traj_unsupervised_true   # Currently unavailable, input a meaningless variable\n",
    "\n",
    "    # Explicit array so the element-wise product does not rely on list * ndarray.\n",
    "    eta_arr = np.asarray(eta_seq)\n",
    "    traj_ppi_pred = eta_fixed * risk_traj_unsupervised_pred + risk_traj_supervised_true - eta_fixed * risk_traj_supervised_pred\n",
    "    traj_ppi_pred_adaptive = eta_arr * risk_traj_unsupervised_pred + risk_traj_supervised_true - eta_arr * risk_traj_supervised_pred\n",
    "\n",
    "    all_traj_list.append(traj_ppi_ideal)\n",
    "    if VERBOSE:\n",
    "        print(f\"=== Experiment {exp_id+1}/{N_EXPERIMENTS} ===\")\n",
    "    try:\n",
    "        sup_bounds, unsup_bounds = compute_bounds_cmeb(traj_sup_true, traj_unsup_proxy, L1_proxy, alpha=ALPHA, t_opt_ratio=T_OPT_RATIO)\n",
    "        ppi_bounds =               compute_bounds_cmeb_ppi(traj_ppi_pred, alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=eta_fixed)\n",
    "        adaptive_ppi_bounds =      compute_bounds_cmeb_ppi(traj_ppi_pred_adaptive,  alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=ETA_MAX)\n",
    "        ideal_ppi_bounds =         compute_bounds_cmeb_ppi(traj_ppi_ideal,  alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=eta_fixed)\n",
    "    except Exception as e:\n",
    "        # Best effort: a failed experiment is reported and left out of the results.\n",
    "        print(f\"[Warning] Skipping exp {exp_id}: compute_bounds() failed — {e}\")\n",
    "        continue\n",
    "    for m in methods:\n",
    "        results[m][\"supervised\"].append(sup_bounds[m])\n",
    "        results[m][\"unsupervised\"].append(unsup_bounds[m])\n",
    "        results[m][\"ppi\"].append(ppi_bounds[m])\n",
    "        results[m][\"ideal_ppi\"].append(ideal_ppi_bounds[m])\n",
    "        results[m][\"adaptive_ppi\"].append(adaptive_ppi_bounds[m])\n",
    "    for m in methods:\n",
    "        sup_arr = np.asarray(sup_bounds.get(m, []))\n",
    "        unsup_arr = np.asarray(unsup_bounds.get(m, []))\n",
    "        adaptive_ppi_arr = np.asarray(adaptive_ppi_bounds.get(m, []))\n",
    "        ppi_arr = np.asarray(ppi_bounds.get(m, []))\n",
    "        ideal_ppi_arr = np.asarray(ideal_ppi_bounds.get(m, []))\n",
    "\n",
    "        # Every bound curve must have one value per time step (the adaptive\n",
    "        # curve is now checked as well).\n",
    "        assert len(sup_arr) == len(ppi_arr) == len(adaptive_ppi_arr) == len(ideal_ppi_arr) == len(unsup_arr) == len(traj_sup_true)\n",
    "\n",
    "        # find threshold crossings: 1-based index of the first value above\n",
    "        # THRESH, or NaN when the curve never crosses it\n",
    "        t_sup = int(np.argmax(sup_arr > THRESH) + 1) if np.any(sup_arr > THRESH) else np.nan\n",
    "        t_unsup = int(np.argmax(unsup_arr > THRESH) + 1) if np.any(unsup_arr > THRESH) else np.nan\n",
    "        t_ppi = int(np.argmax(ppi_arr > THRESH) + 1) if np.any(ppi_arr > THRESH) else np.nan\n",
    "        t_adaptive = int(np.argmax(adaptive_ppi_arr > THRESH) + 1) if np.any(adaptive_ppi_arr > THRESH) else np.nan\n",
    "        t_ideal = int(np.argmax(ideal_ppi_arr > THRESH) + 1) if np.any(ideal_ppi_arr > THRESH) else np.nan\n",
    "\n",
    "        results_box[m][\"supervised\"].append(t_sup)\n",
    "        results_box[m][\"unsupervised\"].append(t_unsup)\n",
    "        results_box[m][\"ppi\"].append(t_ppi)\n",
    "        results_box[m][\"adaptive_ppi\"].append(t_adaptive)\n",
    "        results_box[m][\"ideal_ppi\"].append(t_ideal)\n",
    "\n",
    "# ========== running-risk lower bounds over time ==========\n",
    "colors = {\n",
    "    \"supervised\": \"#1f77b4\",\n",
    "    \"ppi\": \"#ff7f0e\",\n",
    "    \"adaptive_ppi\": \"#2ca02c\",\n",
    "    \"ideal_ppi\": \"#9467bd\",\n",
    "}\n",
    "labels = {\n",
    "    \"supervised\": \"SRM\",\n",
    "    \"adaptive_ppi\": r\"PPRM\",\n",
    "    \"ppi\": \"PPRM\",\n",
    "    \"ideal_ppi\": \"Ideal PPRM\",\n",
    "}\n",
    "fig, ax = plt.subplots(figsize=(6, 6.5))\n",
    "m = methods[0]\n",
    "\n",
    "# Mean curve with a +/- one standard deviation band across experiments.\n",
    "for key in (\"ideal_ppi\", \"adaptive_ppi\", \"supervised\"):\n",
    "    arrs = np.array(results[m][key])\n",
    "    if len(arrs) > 0:\n",
    "        mean_curve = np.nanmean(arrs, axis=0)\n",
    "        std_curve = np.nanstd(arrs, axis=0)\n",
    "        steps = np.arange(len(mean_curve))\n",
    "\n",
    "        ax.plot(steps, mean_curve, label=labels[key], color=colors[key])\n",
    "        ax.fill_between(steps, mean_curve - std_curve, mean_curve + std_curve,\n",
    "                        color=colors[key], alpha=0.2)\n",
    "\n",
    "# Running average of the trajectories stored in all_traj_list, averaged over\n",
    "# experiments and drawn as a dashed reference curve.\n",
    "try:\n",
    "    traj_ppi_ideal_all = np.array([running_average_cumulative(traj) for traj in all_traj_list])\n",
    "    traj_ppi_ideal_mean = np.nanmean(traj_ppi_ideal_all, axis=0)\n",
    "    traj_ppi_ideal_std = np.nanstd(traj_ppi_ideal_all, axis=0)\n",
    "    steps = np.arange(len(traj_ppi_ideal_mean))\n",
    "    ax.plot(steps, traj_ppi_ideal_mean, linestyle=\"--\", color=\"#9467bd\",\n",
    "            linewidth=2, label=\"Running Risk \")\n",
    "    ax.fill_between(steps,\n",
    "                    traj_ppi_ideal_mean - traj_ppi_ideal_std,\n",
    "                    traj_ppi_ideal_mean + traj_ppi_ideal_std,\n",
    "                    color=\"#9467bd\", alpha=0.15)\n",
    "except Exception as e:\n",
    "    print(f\"[Warning] Failed to plot averaged PPI_Ideal risk trajectory: {e}\")\n",
    "\n",
    "# ---- axes cosmetics ----\n",
    "font_size = 22.5\n",
    "ax.axhline(y=THRESH, color=\"red\", linestyle=\"--\", linewidth=2, label=\"Risk Threshold\")\n",
    "ax.set_xlabel(r\"Time Step $t$\", fontsize=font_size)\n",
    "ax.grid(True, linestyle=\":\", linewidth=2, alpha=0.4)\n",
    "ax.set_ylabel(\"Running risk lower bound\", fontsize=font_size+1)\n",
    "ax.legend(fontsize=font_size, loc='upper left', ncol=1)\n",
    "ax.tick_params(axis='both', labelsize=font_size+1)\n",
    "ax.set_ylim(0.25, 0.51)\n",
    "fig.tight_layout()\n",
    "\n",
    "# ---- save and show ----\n",
    "save_dir = \"Simulations/Results_LLM\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "save_path_pdf = os.path.join(save_dir, \"sim_fig_llm_Medgemma_increase_shift_lowerbound.pdf\")\n",
    "fig.savefig(save_path_pdf, format='pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "# ========== box plot of the time to alarm ==========\n",
    "font_size = 22.5\n",
    "plot_params = {\n",
    "    \"title_fontsize\": font_size,\n",
    "    \"xlabel_fontsize\": font_size,\n",
    "    \"ylabel_fontsize\": font_size,\n",
    "    \"xtick_fontsize\": font_size,\n",
    "    \"ytick_fontsize\": font_size,\n",
    "    \"legend_fontsize\": font_size,\n",
    "    \"suptitle_fontsize\": font_size,\n",
    "    \"title_fontweight\": \"bold\",\n",
    "    \"label_fontweight\": \"normal\"\n",
    "}\n",
    "\n",
    "colors = [\"#1f77b4\", \"#2ca02c\", \"#9467bd\"]\n",
    "\n",
    "title_map = {\"srm\": \"SRM\", \"pprm\": \"PPRM\", \"Adaptive pprm\": \"PPRM\", \"ideal pprm\": \"Ideal PPRM\", \"urm\": \"URM\"}\n",
    "fig, ax = plt.subplots(figsize=(6, 6.5))\n",
    "m = methods[0]\n",
    "\n",
    "\n",
    "def _alarm_times(key):\n",
    "    \"\"\"Alarm times stored under ``key`` without NaNs; ``[nan]`` if none remain.\"\"\"\n",
    "    values = np.array(results_box[m][key], dtype=float)\n",
    "    values = values[~np.isnan(values)]\n",
    "    if values.size == 0:\n",
    "        values = np.array([np.nan])\n",
    "    return values\n",
    "\n",
    "\n",
    "sup_list = _alarm_times(\"supervised\")\n",
    "adaptive_ppi_list = _alarm_times(\"adaptive_ppi\")\n",
    "ideal_list = _alarm_times(\"ideal_ppi\")\n",
    "data_for_box = [sup_list, adaptive_ppi_list, ideal_list]\n",
    "labels_for_box = [\"SRM\", \"PPRM\", \"Ideal PPRM\"]\n",
    "\n",
    "line_style = dict(linewidth=2)\n",
    "bp = ax.boxplot(\n",
    "    data_for_box,\n",
    "    labels=labels_for_box,\n",
    "    showmeans=True,\n",
    "    patch_artist=True,\n",
    "    boxprops=dict(line_style),\n",
    "    whiskerprops=dict(line_style),\n",
    "    capprops=dict(line_style),\n",
    "    medianprops=dict(line_style),\n",
    "    showfliers=False)\n",
    "\n",
    "# One fill colour per box, in the order of labels_for_box.\n",
    "for patch, color in zip(bp['boxes'], colors):\n",
    "    patch.set_facecolor(color)\n",
    "    patch.set_alpha(0.6)\n",
    "\n",
    "for mean in bp['means']:\n",
    "    mean.set_markerfacecolor(\"red\")\n",
    "    mean.set_markeredgecolor(\"black\")\n",
    "\n",
    "ax.tick_params(axis=\"x\", labelsize=plot_params[\"xtick_fontsize\"])\n",
    "ax.tick_params(axis=\"y\", labelsize=plot_params[\"ytick_fontsize\"])\n",
    "ax.grid(True, linestyle=\":\", linewidth=2, alpha=0.4)\n",
    "ax.set_ylabel(\"Average time to alarm\",\n",
    "              fontsize=plot_params[\"ylabel_fontsize\"],\n",
    "              fontweight=plot_params[\"label_fontweight\"])\n",
    "fig.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "\n",
    "# ---- save and show ----\n",
    "save_dir = \"Simulations/Results_LLM\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "save_path_pdf = os.path.join(save_dir, \"sim_fig_llm_Medgemma_increase_shift_alarm_time.pdf\")\n",
    "fig.savefig(save_path_pdf, format='pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- experiment configuration ----\n",
    "N_EXPERIMENTS = 20\n",
    "ALPHA = 0.2\n",
    "THRESH = 0.35\n",
    "T_OPT_RATIO = 0.2\n",
    "VERBOSE = True\n",
    "ETA_MAX =  1.0\n",
    "# Assigned once; the second, identical assignment further down was removed.\n",
    "L1_proxy = torch.tensor(1.)\n",
    "WINDOW_SIZE = 40\n",
    "eta_fixed = 1.0\n",
    "num_batches = [600,600]\n",
    "labeled_batch_size = 1\n",
    "unlabeled_batch_size = 5\n",
    "\n",
    "# dataset_list_domainB_dict[\"Qwen/Qwen2-VL-2B-Instruct\"] = [\"MMLU-elementary_mathematics\", \"MMLU-abstract_algebra\", \"MMLU-college_chemistry\", \"MMLU-professional_accounting\",\"MMLU-professional_law\"]\n",
    "dataset_list_domainB_dict[\"Qwen/Qwen2-VL-2B-Instruct\"] = [\"MMLU-professional_psychology\",\"CMExam\"]\n",
    "test_domain = \"domain_B\"\n",
    "# NOTE(review): model_name and agent_model_name are read below but are not\n",
    "# assigned in the preceding cells - confirm they are defined before this runs.\n",
    "dataset_groups = [dataset_list_domainA_dict[model_name]] if test_domain =='domain_A'  else  [dataset_list_domainB_dict[model_name]]\n",
    "\n",
    "domain_model_data_list = []\n",
    "domain_agent_data_list =[]\n",
    "for idx, datasetlist in enumerate(dataset_groups):\n",
    "    # Concatenate all datasets of one domain, for the model and for the agent.\n",
    "    domain_all_model_data = []\n",
    "    domain_all_agent_data = []\n",
    "    for dataset_name in datasetlist:\n",
    "        model_data = load_cleaned_data(model_name, dataset_name)\n",
    "        agent_data = load_cleaned_data(agent_model_name, dataset_name)\n",
    "        domain_all_model_data.extend(model_data)\n",
    "        domain_all_agent_data.extend(agent_data)\n",
    "    domain_model_data_list.append(domain_all_model_data)\n",
    "    domain_agent_data_list.append(domain_all_agent_data)\n",
    "    print(\"\\n=== Summary ===\")\n",
    "    print(\"Domain total model_data:\", len(domain_all_model_data))\n",
    "    # Report the agent list itself (the model count used to be printed twice).\n",
    "    print(\"Domain total agent_data:\", len(domain_all_agent_data))\n",
    "\n",
    "\n",
    "# ---- result containers: per-experiment bound curves for each method ----\n",
    "methods = [\"PM-EB\", \"Betting\"]\n",
    "results = {\n",
    "    m: {\"supervised\": [],\n",
    "        \"unsupervised\": [],\n",
    "        \"ppi\": [],\n",
    "        \"adaptive_ppi\": [],\n",
    "        \"ideal_ppi\": []}\n",
    "    for m in methods}\n",
    "all_traj_list = []\n",
    "\n",
    "eta_seq_all_exp = []\n",
    "for exp_id in range(N_EXPERIMENTS):\n",
    "    labeled_batch_acc = []\n",
    "    unlabeled_batch_acc = []\n",
    "    labeled_batch_agent_acc = []\n",
    "    unlabeled_batch_agent_acc = []\n",
    "    eta_seq = []\n",
    "    labeled_losses_hist = []\n",
    "    agent_labeled_losses_hist = []\n",
    "    agent_unlabeled_losses_hist = []\n",
    "    for idx in range(len(domain_model_data_list)):\n",
    "        model_data = domain_model_data_list[idx]\n",
    "        agent_data = domain_agent_data_list[idx]\n",
    "        batches, agent_batches = split_into_batches_list(\n",
    "            model_data, agent_data, batch_size=labeled_batch_size+unlabeled_batch_size)\n",
    "        for b_idx in range(len(batches)):\n",
    "            batch = batches[b_idx]\n",
    "            agent_batch = agent_batches[b_idx]\n",
    "            l_batch = batch[:labeled_batch_size]\n",
    "            u_batch = batch[labeled_batch_size:]\n",
    "            agent_l_batch = agent_batch[:labeled_batch_size]\n",
    "            agent_u_batch = agent_batch[labeled_batch_size:]\n",
    "            true_labels_l = [d[\"answer\"] for d in l_batch]\n",
    "            true_labels_u = [d[\"answer\"] for d in u_batch]\n",
    "            agent_labels_l = [d[\"response\"] for d in agent_l_batch]\n",
    "            agent_labels_u = [d[\"response\"] for d in agent_u_batch]\n",
    "            labeled_batch_acc.append(compute_accuracy(l_batch, true_labels_l))\n",
    "            unlabeled_batch_acc.append(compute_accuracy(u_batch, true_labels_u))\n",
    "            labeled_batch_agent_acc.append(compute_accuracy(l_batch, labels=agent_labels_l))\n",
    "            unlabeled_batch_agent_acc.append(compute_accuracy(u_batch, labels=agent_labels_u))\n",
    "            # ---- 0–1 losses ----\n",
    "            u_t = compute_01_loss_vector(l_batch, true_labels_l)\n",
    "            u_tilde_l_t = compute_01_loss_vector(l_batch, agent_labels_l)\n",
    "            u_tilde_u_t = compute_01_loss_vector(u_batch, agent_labels_u)\n",
    "\n",
    "            # ---- compute eta_t using ONLY history ----\n",
    "            if len(labeled_losses_hist) > WINDOW_SIZE - 30:\n",
    "                eta_t = compute_eta_t( labeled_losses_hist, agent_labeled_losses_hist, agent_unlabeled_losses_hist,ETA_MAX)\n",
    "            else:\n",
    "                eta_t = ETA_MAX / 2   \n",
    "            eta_seq.append(eta_t)\n",
    "            # ---- update history ----\n",
    "            labeled_losses_hist.append(u_t)\n",
    "            agent_labeled_losses_hist.append(u_tilde_l_t)\n",
    "            agent_unlabeled_losses_hist.append(u_tilde_u_t)\n",
    "            # ---- sliding window ----\n",
    "            if len(labeled_losses_hist) > WINDOW_SIZE:\n",
    "                labeled_losses_hist.pop(0)\n",
    "                agent_labeled_losses_hist.pop(0)\n",
    "                agent_unlabeled_losses_hist.pop(0)\n",
    "            if b_idx >= num_batches[idx]:\n",
    "                break\n",
    "\n",
    "        eta_seq_all_exp.append(eta_seq)\n",
    "\n",
    "    risk_traj_supervised_true = 1 - np.array(labeled_batch_acc)\n",
    "    risk_traj_supervised_pred = 1 - np.array(labeled_batch_agent_acc)\n",
    "    risk_traj_unsupervised_pred = 1 - np.array(unlabeled_batch_agent_acc)\n",
    "    risk_traj_unsupervised_true = 1 - np.array(unlabeled_batch_acc)\n",
    "\n",
    "    traj_sup_true = risk_traj_supervised_true\n",
    "    traj_ppi_ideal = risk_traj_unsupervised_true\n",
    "    traj_unsup_proxy = risk_traj_unsupervised_true   # Currently unavailable, input a meaningless variable\n",
    "\n",
    "    traj_ppi_pred = eta_fixed * risk_traj_unsupervised_pred + risk_traj_supervised_true - eta_fixed * risk_traj_supervised_pred\n",
    "    traj_ppi_pred_adaptive = eta_seq * risk_traj_unsupervised_pred + risk_traj_supervised_true - eta_seq * risk_traj_supervised_pred\n",
    "\n",
    "    all_traj_list.append(traj_ppi_ideal)\n",
    "    if VERBOSE:\n",
    "        print(f\"=== Experiment {exp_id+1}/{N_EXPERIMENTS} ===\")\n",
    "    try:\n",
    "        sup_bounds, unsup_bounds = compute_bounds_pmeb_betting(traj_sup_true, traj_unsup_proxy, L1_proxy, alpha=ALPHA)\n",
    "        ppi_bounds =               compute_bounds_pmeb_betting_ppi(traj_ppi_pred, alpha=ALPHA, eta_max=eta_fixed)\n",
    "        adaptive_ppi_bounds =      compute_bounds_pmeb_betting_ppi(traj_ppi_pred_adaptive,  alpha=ALPHA, eta_max=ETA_MAX)\n",
    "        ideal_ppi_bounds =         compute_bounds_pmeb_betting_ppi(traj_ppi_ideal,  alpha=ALPHA,eta_max=eta_fixed)\n",
    "    except Exception as e:\n",
    "        print(f\"[Warning] Skipping exp {exp_id}: compute_bounds() failed — {e}\")\n",
    "        continue\n",
    "    for m in methods:\n",
    "        results[m][\"supervised\"].append(sup_bounds[m])\n",
    "        results[m][\"unsupervised\"].append(unsup_bounds[m])\n",
    "        results[m][\"ppi\"].append(ppi_bounds[m])\n",
    "        results[m][\"ideal_ppi\"].append(ideal_ppi_bounds[m])\n",
    "        results[m][\"adaptive_ppi\"].append(adaptive_ppi_bounds[m])\n",
    "\n",
    "colors = {\"supervised\": \"#1f77b4\",  \"ppi\": \"#ff7f0e\",  \"adaptive_ppi\": \"#2ca02c\", \"ideal_ppi\": \"#9467bd\" }\n",
    "labels = {\"supervised\": \"SRM\", \"adaptive_ppi\": r\"PPRM\", \"ppi\": \"PPRM\", \"ideal_ppi\": \"Ideal PPRM\"}\n",
    "fig, axes = plt.subplots(1, 2, figsize=(11, 5), sharey=True)\n",
    "if True:\n",
    "    for ax, m in zip(axes, methods):\n",
    "        for key in [\"ideal_ppi\",  \"adaptive_ppi\", \"supervised\"]:\n",
    "            arrs = np.array(results[m][key])\n",
    "            if len(arrs) == 0:\n",
    "                continue\n",
    "            mean_curve = np.nanmean(arrs, axis=0)\n",
    "            std_curve = np.nanstd(arrs, axis=0)\n",
    "            steps = np.arange(len(mean_curve))\n",
    "\n",
    "            ax.plot(steps, mean_curve, label=labels[key], color=colors[key])\n",
    "            ax.fill_between(\n",
    "                steps,\n",
    "                mean_curve - std_curve,\n",
    "                mean_curve + std_curve,\n",
    "                color=colors[key],\n",
    "                alpha=0.2\n",
    "            )\n",
    "        try:\n",
    "            traj_ppi_ideal_all = np.array([\n",
    "                running_average_cumulative(traj)\n",
    "                for traj in all_traj_list \n",
    "            ])\n",
    "            traj_ppi_ideal_mean = np.nanmean(traj_ppi_ideal_all, axis=0)\n",
    "            traj_ppi_ideal_std = np.nanstd(traj_ppi_ideal_all, axis=0)\n",
    "            steps = np.arange(len(traj_ppi_ideal_mean))\n",
    "            ax.plot(\n",
    "                steps,\n",
    "                traj_ppi_ideal_mean,\n",
    "                linestyle=\"--\",\n",
    "                color=\"#9467bd\",\n",
    "                linewidth=2,\n",
    "                label=\"Running Risk \"\n",
    "            )\n",
    "            ax.fill_between(\n",
    "                steps,\n",
    "                traj_ppi_ideal_mean - traj_ppi_ideal_std,\n",
    "                traj_ppi_ideal_mean + traj_ppi_ideal_std,\n",
    "                color=\"#9467bd\",\n",
    "                alpha=0.15\n",
    "            )\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f\"[Warning] Failed to plot averaged PPI_Ideal risk trajectory: {e}\")\n",
    "        ax.axhline(y=THRESH, color=\"red\", linestyle=\"--\", linewidth=2, label=f\"Risk Threshold\")\n",
    "        font_size = 18\n",
    "        ax.set_xlabel(r\"Time Step $t$\", fontsize=font_size)\n",
    "        ax.grid(True, linestyle=\":\")\n",
    "        ax.set_ylabel(\"Running risk lower bound\", fontsize=font_size)\n",
    "        ax.legend(fontsize=font_size-1, loc='upper left')\n",
    "        ax.tick_params(axis='both', labelsize=font_size-1) \n",
    "\n",
    "plt.ylim(0.2,0.75)\n",
    "plt.tight_layout()\n",
    "save_dir = \"Simulations/Results_LLM\"\n",
    "os.makedirs(save_dir, exist_ok=True) \n",
    "\n",
    "save_path_pdf = os.path.join(save_dir, f\"sim_fig_Medgemma_llm_domain{test_domain}_lowerbound.pdf\")\n",
    "plt.savefig(save_path_pdf, format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_model_name_list = [\n",
    "    \"Qwen/Qwen2.5-VL-3B-Instruct\",\n",
    "    \"Qwen/Qwen2-VL-7B-Instruct\",\n",
    "    \"gpt-4.1-nano\",\n",
    "    # \"Qwen/Qwen2.5-VL-32B-Instruct\",\n",
    "    \"gpt-4.1\",\n",
    "    # \"gpt-5-mini\"\n",
    "]\n",
    "agent_model_short_name = {\n",
    "    \"Qwen/Qwen2-VL-7B-Instruct\": \"Q2-VL7B\",\n",
    "    \"Qwen/Qwen2.5-VL-3B-Instruct\": \"Q2.5-VL3B\",\n",
    "    \"Qwen/Qwen2.5-VL-32B-Instruct\": \"Q2.5-VL32B\",\n",
    "    \"gpt-4.1\": \"GPT4.1\",\n",
    "    \"gpt-4.1-nano\": \"GPT4.1-n\",\n",
    "    # \"gpt-5-mini\": \"GPT5M\"\n",
    "}\n",
    "dataset_list_domainA_dict[\"unsloth/medgemma-4b-it-bnb-4bit\"] = [\"MMLU-high_school_psychology\", \"MMLU-miscellaneous\"]\n",
    "dataset_list_domainB_dict[\"unsloth/medgemma-4b-it-bnb-4bit\"] = [\"MMLU-elementary_mathematics\", \"MMLU-abstract_algebra\", \"MMLU-college_chemistry\", \"MMLU-professional_accounting\",\"CMExam\"]\n",
    "\n",
    "\n",
    "dataset_groups = [\n",
    "    dataset_list_domainA_dict[model_name],\n",
    "    dataset_list_domainB_dict[model_name]\n",
    "]\n",
    "\n",
    "domain_model_data_list = []\n",
    "\n",
    "for idx, datasetlist in enumerate(dataset_groups):\n",
    "    domain_all_model_data = []\n",
    "    for dataset_name in datasetlist:\n",
    "        model_data = load_cleaned_data(model_name, dataset_name)\n",
    "        domain_all_model_data.extend(model_data)\n",
    "\n",
    "    domain_model_data_list.append(domain_all_model_data)\n",
    "\n",
    "    print(f\"\\n=== Summary (Model, Domain {idx}) ===\")\n",
    "    print(\"Domain total model_data:\", len(domain_all_model_data))\n",
    "\n",
    "domain_agent_data_list = {}\n",
    "\n",
    "\n",
    "for agent_model_name in agent_model_name_list:\n",
    "    domain_agent_data_list[agent_model_name] = []\n",
    "    for idx, datasetlist in enumerate(dataset_groups):\n",
    "        domain_all_agent_data = []\n",
    "        for dataset_name in datasetlist:\n",
    "            agent_data = load_cleaned_data(agent_model_name, dataset_name)\n",
    "            domain_all_agent_data.extend(agent_data)\n",
    "        domain_agent_data_list[agent_model_name].append(domain_all_agent_data)\n",
    "        print(\"\\n=== Summary ===\")\n",
    "        print(\"Domain total agent_data:\", len(domain_agent_data_list[agent_model_name][idx]))\n",
    "    \n",
    "def split_into_batches_list_dict(model_data, agent_data_dict, idx,  batch_size):\n",
    "    n = len(model_data)\n",
    "    n = n // batch_size * batch_size\n",
    "    indices = list(range(n))\n",
    "    random.shuffle(indices)\n",
    "    model_data_shuf = [model_data[i] for i in indices]\n",
    "    batches = [\n",
    "        model_data_shuf[i : i + batch_size]\n",
    "        for i in range(0, n, batch_size)\n",
    "    ]\n",
    "    agent_batches_dict = {}\n",
    "    for agent_model_name, data_seq in agent_data_dict.items():\n",
    "        # print(data_seq)\n",
    "        agent_data_shuf = [data_seq[idx][i] for i in indices]\n",
    "        agent_batches_dict[agent_model_name] = [\n",
    "            agent_data_shuf[i : i + batch_size]\n",
    "            for i in range(0, n, batch_size)]\n",
    "    return batches, agent_batches_dict\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_EXPERIMENTS = 30                  \n",
    "ALPHA = 0.2                          \n",
    "THRESH = 0.35                      \n",
    "T_OPT_RATIO = 0.2\n",
    "VERBOSE = True\n",
    "ETA_MAX =  1.0\n",
    "L1_proxy = torch.tensor(1.)\n",
    "WINDOW_SIZE = 60\n",
    "eta_fixed = 1.0\n",
    "num_batches = [200,400] \n",
    "labeled_batch_size = 1\n",
    "unlabeled_batch_size = 5\n",
    "set_seed(0)\n",
    "methods = [\"CM-EB\"]\n",
    "results = {m: {\n",
    "        \"supervised\": [],            \n",
    "        \"ppi\": {agent: [] for agent in agent_model_name_list}, \n",
    "        \"ideal_ppi\": []           }\n",
    "    for m in methods}\n",
    "\n",
    "for exp_id in range(N_EXPERIMENTS):\n",
    "    labeled_batch_acc = []\n",
    "    unlabeled_batch_acc = []\n",
    "\n",
    "    labeled_batch_agent_acc = {}\n",
    "    unlabeled_batch_agent_acc = {}\n",
    "\n",
    "    labeled_losses_hist = {}\n",
    "    agent_labeled_losses_hist = {}\n",
    "    agent_unlabeled_losses_hist = {}\n",
    "    eta_seq = {}\n",
    "    for agent_model_name in agent_model_name_list:\n",
    "        labeled_batch_agent_acc[agent_model_name] = []\n",
    "        unlabeled_batch_agent_acc[agent_model_name] = []\n",
    "\n",
    "        labeled_losses_hist[agent_model_name] = []\n",
    "        agent_labeled_losses_hist[agent_model_name] = []\n",
    "        agent_unlabeled_losses_hist[agent_model_name] = []\n",
    "        eta_seq[agent_model_name] = []\n",
    "\n",
    "\n",
    "    for idx in range(len(domain_model_data_list)):\n",
    "        model_data = domain_model_data_list[idx]\n",
    "        batches, agent_batches_list = split_into_batches_list_dict(\n",
    "            model_data, domain_agent_data_list, idx, batch_size=labeled_batch_size+unlabeled_batch_size\n",
    "        )\n",
    "        for b_idx in range(len(batches)):\n",
    "            l_batch = batch[:labeled_batch_size]\n",
    "            u_batch = batch[labeled_batch_size:]\n",
    "            true_labels_l = [d[\"answer\"] for d in l_batch]\n",
    "            true_labels_u = [d[\"answer\"] for d in u_batch]\n",
    "            labeled_batch_acc.append(compute_accuracy(l_batch, true_labels_l))\n",
    "            unlabeled_batch_acc.append(compute_accuracy(u_batch, true_labels_u))\n",
    "\n",
    "            \n",
    "            for agent_model_name in agent_model_name_list:\n",
    "                agent_batch = agent_batches_list[agent_model_name][b_idx]\n",
    "                agent_l_batch = agent_batch[:labeled_batch_size]\n",
    "                agent_u_batch = agent_batch[labeled_batch_size:]\n",
    "                agent_labels_l = [d[\"response\"] for d in agent_l_batch]\n",
    "                agent_labels_u = [d[\"response\"] for d in agent_u_batch]\n",
    "                labeled_batch_agent_acc[agent_model_name].append(compute_accuracy(l_batch, labels=agent_labels_l))\n",
    "                unlabeled_batch_agent_acc[agent_model_name].append(compute_accuracy(u_batch, labels=agent_labels_u))\n",
    "\n",
    "                u_t = compute_01_loss_vector(l_batch, true_labels_l)\n",
    "                u_tilde_l_t = compute_01_loss_vector(l_batch, agent_labels_l)\n",
    "                u_tilde_u_t = compute_01_loss_vector(u_batch, agent_labels_u)\n",
    "            \n",
    "                # ---- compute eta_t using ONLY history ----\n",
    "                if len(labeled_losses_hist[agent_model_name]) > WINDOW_SIZE - 10:\n",
    "                    eta_t = compute_eta_t( labeled_losses_hist[agent_model_name], agent_labeled_losses_hist[agent_model_name], agent_unlabeled_losses_hist[agent_model_name],ETA_MAX)\n",
    "                else:\n",
    "                    eta_t = ETA_MAX / 2   \n",
    "                eta_seq[agent_model_name].append(eta_t)\n",
    "                # ---- update history ----\n",
    "                labeled_losses_hist[agent_model_name].append(u_t)\n",
    "                agent_labeled_losses_hist[agent_model_name].append(u_tilde_l_t)\n",
    "                agent_unlabeled_losses_hist[agent_model_name].append(u_tilde_u_t)\n",
    "                # ---- sliding window ----\n",
    "                if len(labeled_losses_hist[agent_model_name]) > WINDOW_SIZE:\n",
    "                    labeled_losses_hist[agent_model_name].pop(0)\n",
    "                    agent_labeled_losses_hist[agent_model_name].pop(0)\n",
    "                    agent_unlabeled_losses_hist[agent_model_name].pop(0)\n",
    "            if b_idx >= num_batches[idx]:\n",
    "                break\n",
    "\n",
    "    labeled_batch_err = 1 - np.array(labeled_batch_acc)\n",
    "    unlabeled_batch_err = 1 - np.array(unlabeled_batch_acc)\n",
    "    risk_traj_supervised_true = labeled_batch_err\n",
    "    risk_traj_unsupervised_true = unlabeled_batch_err\n",
    "    traj_sup_true = risk_traj_supervised_true\n",
    "    traj_ppi_ideal = risk_traj_unsupervised_true\n",
    "    traj_unsup_proxy = risk_traj_unsupervised_true \n",
    "\n",
    "    traj_ppi_pred = {}\n",
    "    for agent_model_name in agent_model_name_list:\n",
    "        labeled_batch_agent_err = 1 - np.array(labeled_batch_agent_acc[agent_model_name])\n",
    "        unlabeled_batch_agent_err = 1 - np.array(unlabeled_batch_agent_acc[agent_model_name])\n",
    "        risk_traj_supervised_pred = labeled_batch_agent_err\n",
    "        risk_traj_unsupervised_pred = unlabeled_batch_agent_err\n",
    "        \n",
    "        traj_ppi_pred[agent_model_name] = eta_seq[agent_model_name] * risk_traj_unsupervised_pred + risk_traj_supervised_true - eta_seq[agent_model_name] * risk_traj_supervised_pred\n",
    "\n",
    "\n",
    "    if VERBOSE:\n",
    "        print(f\"=== Experiment {exp_id+1}/{N_EXPERIMENTS} ===\")\n",
    "    try:\n",
    "        # 只在这里捕获失败\n",
    "        sup_bounds, unsup_bounds = compute_bounds_cmeb(traj_sup_true, traj_unsup_proxy, L1_proxy, alpha=ALPHA, t_opt_ratio=T_OPT_RATIO)\n",
    "        ideal_ppi_bounds =         compute_bounds_cmeb_ppi(traj_ppi_ideal,  alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=eta_fixed)\n",
    "        ppi_bounds_dict = {}\n",
    "        for agent_model_name in agent_model_name_list:\n",
    "            ppi_bounds_dict[agent_model_name] = compute_bounds_cmeb_ppi(traj_ppi_pred[agent_model_name], alpha=ALPHA, t_opt_ratio=T_OPT_RATIO, eta_max=eta_fixed)\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"[Warning] Skipping exp {exp_id}: compute_bounds() failed — {e}\")\n",
    "        traceback.print_exc()\n",
    "        continue \n",
    "    for m in methods:\n",
    "        sup_arr = np.asarray(sup_bounds.get(m, []))\n",
    "        unsup_arr = np.asarray(unsup_bounds.get(m, []))\n",
    "        ppi_arr_dict = {}\n",
    "        for agent_model_name in agent_model_name_list:\n",
    "            ppi_arr_dict[agent_model_name] = np.asarray(ppi_bounds_dict[agent_model_name].get(m, []))\n",
    "        ideal_ppi_arr = np.asarray(ideal_ppi_bounds.get(m, []))\n",
    "        assert len(sup_arr) == len(ppi_arr) == len(ideal_ppi_arr) == len(unsup_arr) == len(traj_sup_true)\n",
    "        t_sup = int(np.argmax(sup_arr > THRESH) + 1) if np.any(sup_arr > THRESH) else np.nan\n",
    "        t_ppi_dict = {}\n",
    "        for agent_model_name in agent_model_name_list:\n",
    "            t_ppi = int(np.argmax(ppi_arr_dict[agent_model_name] > THRESH) + 1) if np.any(ppi_arr_dict[agent_model_name] > THRESH) else np.nan\n",
    "            t_ppi_dict[agent_model_name] = t_ppi\n",
    "        t_ideal = int(np.argmax(ideal_ppi_arr > THRESH) + 1) if np.any(ideal_ppi_arr > THRESH) else np.nan\n",
    "        results[m][\"supervised\"].append(t_sup)\n",
    "        for agent_model_name in agent_model_name_list:\n",
    "            results[m][\"ppi\"][agent_model_name].append(t_ppi_dict[agent_model_name])\n",
    "        results[m][\"ideal_ppi\"].append(t_ideal)\n",
    "\n",
    "if True:\n",
    "    font_size = 17\n",
    "    plot_params = {\n",
    "        \"title_fontsize\": font_size,\n",
    "        \"xlabel_fontsize\": font_size,\n",
    "        \"ylabel_fontsize\": font_size,\n",
    "        \"xtick_fontsize\": font_size,\n",
    "        \"ytick_fontsize\": font_size,\n",
    "        \"legend_fontsize\": font_size,\n",
    "        \"suptitle_fontsize\": font_size,\n",
    "        \"title_fontweight\": \"bold\",\n",
    "        \"label_fontweight\": \"normal\"\n",
    "    }\n",
    "\n",
    "    colors = [\"#1f77b4\", \"#ff7f0e\", \"#2ca02c\", \"#d62728\", \n",
    "          \"#9467bd\", \"#8c564b\", \"#e377c2\", \"#7f7f7f\"]  \n",
    "\n",
    "    title_map = {\"srm\": \"SRM\", \"pprm\": \"PPRM\", \"ideal pprm\": \"Ideal PPRM\", \"urm\": \"URM\"}\n",
    "    fig, ax = plt.subplots(figsize=(7.5, 5))\n",
    "    m = methods[0]\n",
    "    data_for_box = []\n",
    "    labels_for_box = []\n",
    "    # supervised\n",
    "    sup_list = np.array(results[m][\"supervised\"], dtype=float)\n",
    "    sup_list = sup_list[~np.isnan(sup_list)]\n",
    "    if sup_list.size == 0:\n",
    "        sup_list = np.array([np.nan])\n",
    "    data_for_box.append(sup_list)\n",
    "    labels_for_box.append(\"SRM\")\n",
    "    # ppi\n",
    "    for agent_model_name in agent_model_name_list:\n",
    "        ppi_list = np.array(results[m][\"ppi\"][agent_model_name], dtype=float)\n",
    "        ppi_list = ppi_list[~np.isnan(ppi_list)]\n",
    "        if ppi_list.size == 0:\n",
    "            ppi_list = np.array([np.nan])\n",
    "        data_for_box.append(ppi_list)\n",
    "        labels_for_box.append(f\"{agent_model_short_name[agent_model_name]}\")\n",
    "    # ideal ppi\n",
    "    ideal_list = np.array(results[m][\"ideal_ppi\"], dtype=float)\n",
    "    ideal_list = ideal_list[~np.isnan(ideal_list)]\n",
    "    if ideal_list.size == 0:\n",
    "        ideal_list = np.array([np.nan])\n",
    "    data_for_box.append(ideal_list)\n",
    "    labels_for_box.append(\"Ideal PPRM\")\n",
    "\n",
    "    bp = ax.boxplot(\n",
    "    data_for_box,\n",
    "    labels=labels_for_box,\n",
    "    showmeans=True,\n",
    "    patch_artist=True,\n",
    "    boxprops=dict(linewidth=2),\n",
    "    whiskerprops=dict(linewidth=2),\n",
    "    capprops=dict(linewidth=2),\n",
    "    medianprops=dict(linewidth=2),\n",
    "    widths=0.3,\n",
    "    showfliers=False)\n",
    "    color_cycle = cycle(colors)  \n",
    "    for patch, color in zip(bp['boxes'], color_cycle):\n",
    "        patch.set_facecolor(color)\n",
    "        patch.set_alpha(0.6)\n",
    "\n",
    "    for mean in bp['means']:\n",
    "        mean.set_markerfacecolor(\"red\")\n",
    "        mean.set_markeredgecolor(\"black\")\n",
    "    ax.tick_params(axis=\"x\", labelsize=plot_params[\"xtick_fontsize\"]-1)\n",
    "    ax.tick_params(axis=\"y\", labelsize=plot_params[\"ytick_fontsize\"])\n",
    "    ax.grid(True, linestyle=\":\", linewidth=2, alpha=0.4)\n",
    "\n",
    "    ax.set_ylabel(\"Average time to alarm\",\n",
    "                    fontsize=plot_params[\"ylabel_fontsize\"],\n",
    "                    fontweight=plot_params[\"label_fontweight\"])\n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "    save_dir = \"Simulations/Results_LLM\"\n",
    "    os.makedirs(save_dir, exist_ok=True)  \n",
    "plt.setp(ax.get_xticklabels())\n",
    "save_path_pdf = os.path.join(save_dir, f\"sim_fig_llm_Medgemma_increase_variaousmodel_alarm_time.pdf\")\n",
    "plt.savefig(save_path_pdf, format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
