{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi-label classifier (FT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(y_a, n):\n",
    "    hseed, dseed = 0, [0]\n",
    "    data = 'CelebA'\n",
    "    attr = {'ERM': 'Yes', 'GroupDRO': 'Yes', 'OverSample': 'Yes', 'UnderSample': 'Yes', 'ReWeightLogits': 'Yes'}\n",
    "    algorithms = attr.keys()\n",
    "\n",
    "    seed_diff = []\n",
    "    all_res = pd.DataFrame([])\n",
    "    for s in dseed:\n",
    "        for alg in algorithms:\n",
    "            res_dir = f'/exps/div_explore/celeba_v2/celeba_y{y_a[0]}_a{y_a[1]}_ds0_n{n}/{alg.lower()}'\n",
    "            # find all the sub-directory under res_dir\n",
    "            dirs = os.listdir(res_dir)\n",
    "            for d in dirs:\n",
    "                if attr[alg] in d:\n",
    "                    res_file = os.path.join(res_dir, d)\n",
    "                    exps = os.listdir(res_file)# [1]\n",
    "                    for exp in exps:\n",
    "                        res_file = os.path.join(res_dir, d, exp, 'results.json')\n",
    "                        print(res_file)\n",
    "                        if os.path.exists(res_file):\n",
    "                            final_res = get_trajectory_results(res_file, s)\n",
    "                            all_res = pd.concat([all_res, final_res])\n",
    "                    # res_file = os.path.join(res_dir, d, exps, 'results.json')\n",
    "                    # print(res_file)\n",
    "                    # if os.path.exists(res_file):\n",
    "                    #     final_res = get_trajectory_results(res_file, s)\n",
    "                    #     all_res = pd.concat([all_res, final_res])\n",
    "\n",
    "    all_res = all_res[all_res['step']==1000]\n",
    "    all_res[\"wga_te_err\"] = 1 - all_res[\"metric\"]\n",
    "    all_res = all_res[[\"sc\", \"ci\", \"ai\", \"algorithm\", \"wga_te_err\", \"seed\"]]\n",
    "    all_res['n'] = n\n",
    "    all_res['n'] = all_res['n'].astype('str')\n",
    "    # all_res_ori = copy.deepcopy(all_res)\n",
    "    # group by \"sc\", \"ci\", \"ai\", \"algorithm\", \"seed\"] then average over \"wga_te_err\"\n",
    "    all_res = all_res.groupby(['n', 'sc', 'ci', 'ai', 'algorithm']).mean().reset_index()\n",
    "    all_res.drop(columns=[\"seed\"], inplace=True)\n",
    "\n",
    "    all_res['rank'] = all_res.groupby(['n', 'sc', 'ci', 'ai'])['wga_te_err'].rank(\"first\")\n",
    "    # all_res_ori = copy.deepcopy(all_res)\n",
    "    # print(all_res.head(100))\n",
    "\n",
    "    def get_gt_rank(x, mode='tie', filter_thre=0.05):\n",
    "            # print(x)\n",
    "            min_err = x['wga_te_err'].min()\n",
    "            winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "            return '|'.join(winners)\n",
    "\n",
    "    # Apply the function to each group and reset the index to merge back with the original dataframe\n",
    "    winners_series = all_res.groupby(['n', 'sc', 'ci', 'ai'])[['algorithm', 'wga_te_err']].apply(lambda x: get_gt_rank(x)).reset_index(name='winners')\n",
    "    all_res = all_res.merge(winners_series, on=['n', 'sc', 'ci', 'ai'])\n",
    "\n",
    "    all_res[\"multi_hot\"] = all_res[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in algorithms])\n",
    "\n",
    "    # deduplicate\n",
    "    all_res = all_res[all_res[\"rank\"]==1.0]\n",
    "\n",
    "    return all_res\n",
    "\n",
    "def get_trajectory_results(file_path, ds, metric='te_worst_acc'):\n",
    "    # List to hold all JSON objects\n",
    "    results = []\n",
    "\n",
    "    # Open the file and read it line by line or as a whole if the objects are not line-delimited\n",
    "    with open(file_path, 'r') as file:\n",
    "        file_contents = file.read()\n",
    "        # Attempt to split the file contents by a delimiter if they're not newline-delimited\n",
    "        # This delimiter needs to be defined based on your specific file structure\n",
    "        json_objects = file_contents.split('}\\n{')  # Example: split on possible JSON end/start\n",
    "\n",
    "        # Correct split parts and parse each as a JSON object\n",
    "        for i, obj in enumerate(json_objects):\n",
    "            try:\n",
    "                # Add missing curly braces if split was done in the middle of objects\n",
    "                if i != 0:\n",
    "                    obj = '{' + obj\n",
    "                if i != len(json_objects) - 1:\n",
    "                    obj += '}'\n",
    "                result = json.loads(obj)\n",
    "                # print(json.dumps(result, indent=4))\n",
    "                results.append(result)\n",
    "            except json.JSONDecodeError as e:\n",
    "                print(f\"Error decoding JSON object: {e}\")\n",
    "\n",
    "    results_df = pd.DataFrame(results)\n",
    "    trajectory = {'sc': [], 'ci': [], 'ai': [], 'algorithm': [], 'step': [], 'metric': [], 'seed': [], 'data_seed': []}\n",
    "\n",
    "    sc = float(results_df.iloc[-1]['args']['metadata'].split('sc')[-1].split('_')[0])\n",
    "    ci = float(results_df.iloc[-1]['args']['metadata'].split('ci')[-1].split('_')[0])\n",
    "    ai = float(results_df.iloc[-1]['args']['metadata'].split('ai')[-1][:4])\n",
    "    # sc = float(results_df.iloc[-1]['args']['cmnist_spur_prob'])\n",
    "    # ci = float(results_df.iloc[-1]['args']['cmnist_label_prob'])\n",
    "    # ai = float(results_df.iloc[-1]['args']['cmnist_attr_prob'])\n",
    "    alg = results_df.iloc[-1]['args']['algorithm']\n",
    "    for row in results_df.iterrows():\n",
    "        # extract the data\n",
    "        trajectory['sc'].append(sc)\n",
    "        trajectory['ci'].append(ci)\n",
    "        trajectory['ai'].append(ai)\n",
    "        trajectory['algorithm'].append(alg)\n",
    "        trajectory['step'].append(row[1]['step'])\n",
    "        trajectory['metric'].append(row[1][metric])\n",
    "        trajectory['seed'].append(row[1]['args']['seed'])\n",
    "        trajectory['data_seed'].append(ds)\n",
    "\n",
    "    # # smooth the 'metric'\n",
    "    # trajectory['metric'] = np.convolve(trajectory['metric'], np.ones(10)/10, mode='valid')\n",
    "\n",
    "    return pd.DataFrame(trajectory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res = []\n",
    "# for i, y_a in enumerate([[2, 31], [18, 6], [21, 36], [8, 20], [25, 19]]):\n",
    "# for i, y_a in enumerate([[25,19]]):\n",
    "for i, y_a in enumerate([[2, 31], [21, 36], [8, 20], [25, 19]]):\n",
    "    for n in [200, 500, 1000, 2000, 5000, 10000]:\n",
    "# for i, y_a in enumerate([[2, 31]]):\n",
    "#     for n in [200]:\n",
    "        curr_res = load_data(y_a, n)\n",
    "        curr_res[\"task_id\"] = i\n",
    "        all_res.append(curr_res)\n",
    "all_res = pd.concat(all_res, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res['winners'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(data, model, path, identifier):\n",
    "    data_path = os.path.join(path, f\"{model}_lp\")\n",
    "    res_files = os.listdir(data_path)\n",
    "    all_res = []\n",
    "    for file in res_files:\n",
    "        res_file = os.path.join(data_path, file)\n",
    "        # print(file)\n",
    "        # parse the seed from the file name\n",
    "        res = pd.read_csv(res_file)\n",
    "        res[\"seed\"] = int(file.split('seed')[-1][0])\n",
    "        if data == \"celeba\":\n",
    "            res[\"y_task\"] = int(file.split('y')[-1].split('_')[0])\n",
    "            res[\"a_task\"] = int(file.split('a')[-1].split('.')[0])\n",
    "        all_res.append(res)\n",
    "    all_res = pd.concat(all_res, ignore_index=True)\n",
    "    all_res.rename(columns={\"method\": \"algorithm\"}, inplace=True)\n",
    "\n",
    "    # average over seeds\n",
    "    all_res = all_res[identifier + [\"algorithm\", \"wga_te_err\", \"seed\"]]\n",
    "    all_res = all_res.groupby(identifier + [\"algorithm\"]).mean().reset_index()\n",
    "    all_res.drop(columns=[\"seed\"], inplace=True)\n",
    "\n",
    "    # get the best performing algorithm for each dataset, used for deduplicate\n",
    "    all_res['rank'] = all_res.groupby(identifier)['wga_te_err'].rank(\"first\")\n",
    "\n",
    "    # get the winners for each dataset\n",
    "    def get_gt_rank(x, filter_thre=0.05):\n",
    "        # print(x)\n",
    "        min_err = x['wga_te_err'].min()\n",
    "        winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "        return '|'.join(winners)\n",
    "    # apply the function to each group and reset the index to merge back with the original dataframe\n",
    "    winners_series = all_res.groupby(identifier)[['algorithm', 'wga_te_err']].apply(lambda x: get_gt_rank(x)).reset_index(name='winners')\n",
    "    all_res = all_res.merge(winners_series, on=identifier)\n",
    "    # get the multi-hot encoding of the winners\n",
    "    all_res[\"multi_hot\"] = all_res[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in algorithms])\n",
    "\n",
    "    # get the confidence for each dataset\n",
    "    def get_conf(x, filter_thre=0.05):\n",
    "        min_err = x['wga_te_err'].min()\n",
    "        winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "        # get minimum of winner and maximum of loser\n",
    "        winner_max = x[x['algorithm'].isin(winners)]['wga_te_err'].max()\n",
    "        loser_min = x[~x['algorithm'].isin(winners)]['wga_te_err'].min()\n",
    "        return loser_min - winner_max\n",
    "    conf_series = all_res.groupby(identifier)[['algorithm', 'wga_te_err']].apply(lambda x: get_conf(x)).reset_index(name='conf')\n",
    "    all_res = all_res.merge(conf_series, on=identifier)\n",
    "\n",
    "    # load extra data (static meta-features)\n",
    "    tr_df = pd.read_csv(os.path.join(path, f\"{model}_embeddings_tr/proc_emb.csv\"))\n",
    "    all_res = all_res.merge(tr_df, on=identifier)\n",
    "    if data == \"coco\":\n",
    "        # rename columns\n",
    "        all_res.rename(columns={\"c_dist\": \"c_intra\", \"a_dist\": \"a_intra\"}, inplace=True)\n",
    "\n",
    "    # retain the original results\n",
    "    all_res_ori = copy.deepcopy(all_res)\n",
    "    # deduplicate\n",
    "    all_res = all_res[all_res[\"rank\"]==1.0]\n",
    "    # set all nan to 0\n",
    "    all_res['conf'].fillna(0, inplace=True)\n",
    "\n",
    "    return all_res, all_res_ori\n",
    "\n",
    "\n",
    "def prepare_data(celeba_df, celeba_df_full, coco_df, coco_df_full, celeba_feats, coco_feats, seed):\n",
    "    # split the data and convert to trainable format\n",
    "    all_X = celeba_df[celeba_feats].to_numpy()[:,:-1].astype('float')\n",
    "    all_y = np.array(list(celeba_df[celeba_feats].to_numpy()[:,-1]))\n",
    "\n",
    "    from sklearn.preprocessing import StandardScaler\n",
    "    from sklearn.model_selection import train_test_split\n",
    "\n",
    "    X_train, X_test, y_train, y_test, tr_idx, te_idx = train_test_split(all_X, all_y, range(len(all_X)), test_size=0.2, random_state=seed)\n",
    "\n",
    "    scaler = StandardScaler()\n",
    "    X_train = scaler.fit_transform(X_train)\n",
    "    X_test = scaler.transform(X_test)\n",
    "\n",
    "    X_test_coco = coco_df[coco_feats].to_numpy()[:,:-1].astype('float')\n",
    "    y_test_coco = np.array(list(coco_df[coco_feats].to_numpy()[:,-1]))\n",
    "    X_test_coco = scaler.transform(X_test_coco)\n",
    "\n",
    "    # process celeba_df_full\n",
    "    tr_idx_full = np.concatenate([np.arange(5*i, 5*i+5) for i in tr_idx])\n",
    "    te_idx_full = np.concatenate([np.arange(5*i, 5*i+5) for i in te_idx])\n",
    "    all_X_full = celeba_df_full[celeba_feats[:-1]+[\"algorithm\"]]\n",
    "    all_X_full['algorithm'] = all_X_full['algorithm'].map(lambda x: algorithms.index(x))\n",
    "    all_X_full = all_X_full.to_numpy().astype('float')\n",
    "    all_X_alg = all_X_full[:, -1]\n",
    "    all_y_full = celeba_df_full['wga_te_err'].to_numpy()\n",
    "    X_train_full, y_train_full, X_test_full, y_test_full = all_X_full[tr_idx_full], all_y_full[tr_idx_full], all_X_full[te_idx_full], all_y_full[te_idx_full]\n",
    "    scaler = StandardScaler()\n",
    "    X_train_full = scaler.fit_transform(X_train_full)\n",
    "    X_test_full = scaler.transform(X_test_full)\n",
    "    X_train_alg = all_X_alg[tr_idx_full]\n",
    "    X_test_alg = all_X_alg[te_idx_full]\n",
    "\n",
    "    # process celeba_df_full\n",
    "    X_test_coco_full = coco_df_full[coco_feats[:-1]+[\"algorithm\"]]\n",
    "    X_test_coco_full['algorithm'] = X_test_coco_full['algorithm'].map(lambda x: algorithms.index(x)).to_numpy().astype('float')\n",
    "    y_test_coco_full = coco_df_full['wga_te_err'].to_numpy()\n",
    "    X_test_coco_full = scaler.transform(X_test_coco_full)\n",
    "\n",
    "    return X_train, X_test, y_train, y_test, X_train_full, y_train_full, X_test_full, y_test_full, tr_idx, te_idx, X_test_coco, y_test_coco, X_test_coco_full, y_test_coco_full, X_train_alg, X_test_alg\n",
    "\n",
    "\n",
    "def eval_acc(y_test, y_pred, mode, verbose=True):\n",
    "    # assuming numpy arrays\n",
    "    if mode == \"0-1\":\n",
    "        # acc = ((y_pred == y_test).float().sum(dim=1)==5).float().mean()\n",
    "        acc = (y_pred == y_test).all(axis=1).mean()\n",
    "        if verbose:\n",
    "            print(f\"Eval {mode} accuracy: {acc}\")\n",
    "    elif mode == \"soft 0-1\":\n",
    "        correct = 0\n",
    "        for i, curr_y in enumerate(y_test):\n",
    "            curr_pred = y_pred[i]\n",
    "            # check if the positions of 1s in curr_pred are also 1s in curr_y\n",
    "            pos = np.where(curr_pred==1)[0]\n",
    "            if np.all(curr_y[pos] == 1):\n",
    "                correct += 1\n",
    "        acc = correct / y_test.shape[0]\n",
    "        if verbose:\n",
    "            print(f\"Eval {mode} accuracy: {acc}\")\n",
    "\n",
    "\n",
    "def eval_wga_err(test_df, y_pred, df_full, identifier):\n",
    "    # check if y_pred is already a list\n",
    "    if not isinstance(y_pred, list):\n",
    "        y_pred = y_pred.tolist()\n",
    "    test_df['pred'] = y_pred\n",
    "    winners_pred = []\n",
    "    for y_ in y_pred:\n",
    "        y_ = np.array(y_)\n",
    "        winners_pred.append(np.array(algorithms)[np.where(y_==1)[0]])\n",
    "    test_df['winners_pred'] = winners_pred\n",
    "    # need to transform the multi-hot vector to the algorithm name\n",
    "    fine_test_df = df_full.merge(test_df, on=identifier)\n",
    "    def get_pred_wga_te_err(x):\n",
    "        pred_winners = x['winners_pred'].iloc[0]\n",
    "        if len(pred_winners) == 0:\n",
    "            # randomly select one\n",
    "            pred_winner = np.random.choice(algorithms)\n",
    "        else:\n",
    "            pred_winner = np.random.choice(pred_winners)\n",
    "        return x[x['algorithm_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "    pred_err = fine_test_df.groupby(identifier)['wga_te_err_x', 'algorithm_x' ,'winners_pred'].apply(lambda x: get_pred_wga_te_err(x))\n",
    "    print(f\"Eval wg err: {pred_err.mean()}\")\n",
    "    return pred_err.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "seed = 0\n",
    "\n",
    "model = \"resnet\"\n",
    "algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "celeba_path = f\"exps/div_explore/celeba_v2/\"\n",
    "coco_path = f\"exps/div_explore/coco_v2/\"\n",
    "celeba_identifier = ['n', 'sc', 'ci', 'ai', 'y_task', 'a_task']\n",
    "coco_identifier = ['n', 'sc', 'ci', 'ai']\n",
    "\n",
    "celeba_df, celeba_df_full = load_data(\"celeba\", model, celeba_path, celeba_identifier)\n",
    "coco_df, coco_df_full = load_data(\"coco\", model, coco_path, coco_identifier)\n",
    "\n",
    "# other variables\n",
    "celeba_feats = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'multi_hot']\n",
    "# celeba_feat = ['n', 'sc_', 'ci_', 'ai_', 'c_intra', 'a_intra_i', 'multi_hot']\n",
    "coco_feats = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'multi_hot']\n",
    "\n",
    "# sanity check\n",
    "print(model)\n",
    "print(celeba_df.shape, celeba_df_full.shape)\n",
    "print(coco_df.shape, coco_df_full.shape)\n",
    "print(celeba_df.columns)\n",
    "print(coco_df.columns)\n",
    "print(celeba_df['wga_te_err'].mean(), celeba_df_full['wga_te_err'].mean())\n",
    "print(coco_df['wga_te_err'].mean(), coco_df_full['wga_te_err'].mean())\n",
    "\n",
    "(\n",
    "    X_train, X_test, y_train, y_test,\n",
    "    X_train_full, y_train_full, X_test_full, y_test_full,\n",
    "    tr_idx, te_idx,\n",
    "    X_test_coco, y_test_coco,\n",
    "    X_test_coco_full, y_test_coco_full,\n",
    "    X_train_alg, X_test_alg\n",
    ") = prepare_data(celeba_df, celeba_df_full, coco_df, coco_df_full, celeba_feats, coco_feats, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "celeba_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res.merge(celeba_df, on=['n', 'sc', 'ci', 'ai', 'algorithm'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "all_df_rank = copy.deepcopy(all_res)\n",
    "\n",
    "all_df_rank_data = all_df_rank[['n', 'sc', 'ci', 'ai', 'task_id', 'multi_hot']]\n",
    "\n",
    "all_X = all_df_rank_data.to_numpy()[:,:-1].astype('float')\n",
    "all_y = np.array(list(all_df_rank_data.to_numpy()[:, -1]))\n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(all_X, all_y, test_size=0.2, random_state=0)\n",
    "\n",
    "scaler = StandardScaler()\n",
    "X_train = scaler.fit_transform(X_train)\n",
    "X_test = scaler.transform(X_test)\n",
    "\n",
    "# X_train, y_train = all_X, all_y\n",
    "# exclude column 2\n",
    "# ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat']\n",
    "cols = list(range(X_train.shape[1]))\n",
    "# cols.remove(0)\n",
    "X_train, X_test = X_train[:,cols], X_test[:,cols]\n",
    "\n",
    "# clf = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train)\n",
    "# clf = LogisticRegression(random_state=0, max_iter=int(1e8), verbose=True, C=0.1).fit(X_train, y_train)\n",
    "clf = MLPClassifier(random_state=0, max_iter=500, verbose=True, tol=1e-4, n_iter_no_change=2000, alpha=0.1, hidden_layer_sizes=(100,100, 50, 50)).fit(X_train, y_train)\n",
    "\n",
    "# clf = MLPClassifier(random_state=1, max_iter=1000000, verbose=True, tol=1e-3, n_iter_no_change=2000, alpha=0.0001, hidden_layer_sizes=(100, 10)).fit(X_train, y_train)\n",
    "\n",
    "# solver=\"sgd\", learning_rate_init=0.01\n",
    "print(clf.score(X_train, y_train))\n",
    "print(clf.score(X_test, y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_acc(y_test, y_pred, mode, verbose=True):\n",
    "    # assuming numpy arrays\n",
    "    if mode == \"0-1\":\n",
    "        # acc = ((y_pred == y_test).float().sum(dim=1)==5).float().mean()\n",
    "        acc = (y_pred == y_test).all(axis=1).mean()\n",
    "        if verbose:\n",
    "            print(f\"Eval {mode} accuracy: {acc}\")\n",
    "    elif mode == \"soft 0-1\":\n",
    "        correct = 0\n",
    "        for i, curr_y in enumerate(y_test):\n",
    "            curr_pred = y_pred[i]\n",
    "            # check if the positions of 1s in curr_pred are also 1s in curr_y\n",
    "            pos = np.where(curr_pred==1)[0]\n",
    "            if np.all(curr_y[pos] == 1):\n",
    "                correct += 1\n",
    "        acc = correct / y_test.shape[0]\n",
    "        if verbose:\n",
    "            print(f\"Eval {mode} accuracy: {acc}\")\n",
    "\n",
    "\n",
    "def eval_wga_err(test_df, y_pred, df_full, identifier):\n",
    "    # check if y_pred is already a list\n",
    "    if not isinstance(y_pred, list):\n",
    "        y_pred = y_pred.tolist()\n",
    "    test_df['pred'] = y_pred\n",
    "    winners_pred = []\n",
    "    for y_ in y_pred:\n",
    "        y_ = np.array(y_)\n",
    "        winners_pred.append(np.array(algorithms)[np.where(y_==1)[0]])\n",
    "    test_df['winners_pred'] = winners_pred\n",
    "    # need to transform the multi-hot vector to the algorithm name\n",
    "    fine_test_df = df_full.merge(test_df, on=identifier)\n",
    "    def get_pred_wga_te_err(x):\n",
    "        pred_winners = x['winners_pred'].iloc[0]\n",
    "        if len(pred_winners) == 0:\n",
    "            # randomly select one\n",
    "            pred_winner = np.random.choice(algorithms)\n",
    "        else:\n",
    "            pred_winner = np.random.choice(pred_winners)\n",
    "        return x[x['algorithm_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "    pred_err = fine_test_df.groupby(identifier)['wga_te_err_x', 'algorithm_x' ,'winners_pred'].apply(lambda x: get_pred_wga_te_err(x))\n",
    "    print(f\"Eval wg err: {pred_err.mean()}\")\n",
    "    return pred_err.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = eval_acc(y_test, y_pred, mode='0-1', verbose=True)\n",
    "_ = eval_acc(y_test, y_pred, mode='soft 0-1', verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Linear probe (emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import copy\n",
    "model = \"clip\"\n",
    "path = f\"exps/div_explore/celeba_v2/{model}_lp\"\n",
    "res_files = os.listdir(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res = []\n",
    "for file in res_files:\n",
    "    res_file = os.path.join(path, file)\n",
    "    print(file)\n",
    "    # parse the seed from the file name\n",
    "    res = pd.read_csv(res_file)\n",
    "    res[\"seed\"] = int(file.split('seed')[-1][0])\n",
    "    res[\"y_task\"] = int(file.split('y')[-1].split('_')[0])\n",
    "    res[\"a_task\"] = int(file.split('a')[-1].split('.')[0])\n",
    "    all_res.append(res)\n",
    "all_res = pd.concat(all_res, ignore_index=True)\n",
    "\n",
    "algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "\n",
    "# rename col\n",
    "all_res.rename(columns={\"method\": \"algorithm\"}, inplace=True)\n",
    "all_res = all_res[[\"n\", \"sc\", \"ci\", \"ai\", \"algorithm\", \"wga_te_err\", \"seed\", \"y_task\", \"a_task\"]]\n",
    "all_res = all_res.groupby(['n', 'sc', 'ci', 'ai', 'algorithm', 'y_task', 'a_task']).mean().reset_index()\n",
    "all_res.drop(columns=[\"seed\"], inplace=True)\n",
    "\n",
    "all_res['rank'] = all_res.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])['wga_te_err'].rank(\"first\")\n",
    "# all_res_ori = copy.deepcopy(all_res)\n",
    "# print(all_res.head(100))\n",
    "\n",
    "def get_gt_rank(x, mode='tie', filter_thre=0.05):\n",
    "    # print(x)\n",
    "    min_err = x['wga_te_err'].min()\n",
    "    winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "    return '|'.join(winners)\n",
    "\n",
    "# Apply the function to each group and reset the index to merge back with the original dataframe\n",
    "winners_series = all_res.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])[['algorithm', 'wga_te_err']].apply(lambda x: get_gt_rank(x)).reset_index(name='winners')\n",
    "all_res = all_res.merge(winners_series, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "\n",
    "def get_conf(x, mode='tie', filter_thre=0.05):\n",
    "    # print(x)\n",
    "    min_err = x['wga_te_err'].min()\n",
    "    winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "    # get minimum of winner and maximum of loser\n",
    "    winner_max = x[x['algorithm'].isin(winners)]['wga_te_err'].max()\n",
    "    loser_min = x[~x['algorithm'].isin(winners)]['wga_te_err'].min()\n",
    "    return loser_min - winner_max\n",
    "conf_series = all_res.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])[['algorithm', 'wga_te_err']].apply(lambda x: get_conf(x)).reset_index(name='conf')\n",
    "all_res = all_res.merge(conf_series, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "\n",
    "all_res[\"multi_hot\"] = all_res[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in algorithms])\n",
    "\n",
    "# retain the original results\n",
    "all_res_ori = copy.deepcopy(all_res)\n",
    "# deduplicate\n",
    "all_res = all_res[all_res[\"rank\"]==1.0]\n",
    "\n",
    "# set all nan to 0\n",
    "all_res['conf'].fillna(0, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res['winners'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the histogram of the confidence\n",
    "all_res['conf'].hist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# to load the coco results\n",
    "model = \"clip\"\n",
    "path = f\"exps/div_explore/coco/{model}_lp\"\n",
    "res_files = os.listdir(path)\n",
    "\n",
    "def load_coco_res(res_files):\n",
    "    all_res = []\n",
    "    for file in res_files:\n",
    "        res_file = os.path.join(path, file)\n",
    "        print(file)\n",
    "        # parse the seed from the file name\n",
    "        res = pd.read_csv(res_file)\n",
    "        res[\"seed\"] = int(file.split('seed')[-1][0])\n",
    "        all_res.append(res)\n",
    "    all_res = pd.concat(all_res, ignore_index=True)\n",
    "\n",
    "    algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "\n",
    "    # rename col\n",
    "    all_res.rename(columns={\"method\": \"algorithm\"}, inplace=True)\n",
    "    all_res = all_res[[\"n\", \"sc\", \"ci\", \"ai\", \"algorithm\", \"wga_te_err\", \"seed\"]]\n",
    "    all_res = all_res.groupby(['n', 'sc', 'ci', 'ai', 'algorithm']).mean().reset_index()\n",
    "    all_res.drop(columns=[\"seed\"], inplace=True)\n",
    "\n",
    "    all_res['rank'] = all_res.groupby(['n', 'sc', 'ci', 'ai'])['wga_te_err'].rank(\"first\")\n",
    "    # all_res_ori = copy.deepcopy(all_res)\n",
    "    # print(all_res.head(100))\n",
    "\n",
    "    def get_gt_rank(x, mode='tie', filter_thre=0.05):\n",
    "        # print(x)\n",
    "        min_err = x['wga_te_err'].min()\n",
    "        winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "        return '|'.join(winners)\n",
    "\n",
    "    # Apply the function to each group and reset the index to merge back with the original dataframe\n",
    "    winners_series = all_res.groupby(['n', 'sc', 'ci', 'ai'])[['algorithm', 'wga_te_err']].apply(lambda x: get_gt_rank(x)).reset_index(name='winners')\n",
    "    all_res = all_res.merge(winners_series, on=['n', 'sc', 'ci', 'ai'])\n",
    "\n",
    "    all_res[\"multi_hot\"] = all_res[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in algorithms])\n",
    "\n",
    "    all_res_ori = copy.deepcopy(all_res)\n",
    "    # deduplicate\n",
    "    all_res = all_res[all_res[\"rank\"]==1.0]\n",
    "    return all_res, all_res_ori\n",
    "\n",
    "coco_res, coco_res_ori = load_coco_res(res_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_df = pd.read_csv(f\"exps/div_explore/celeba_v2/{model}_embeddings_tr/proc_emb.csv\")\n",
    "te_df = pd.read_csv(f\"exps/div_explore/celeba_v2/{model}_embeddings_te/proc_emb.csv\")\n",
    "all_res = all_res.merge(tr_df, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "# all_res['c_dist'] = all_res['c_dist'] ceiling\n",
    "# all_res['a_dist'] = all_res['a_dist'].apply(np.ceil).astype(int)\n",
    "# all_res['c_dist'] = all_res['c_dist'].apply(np.ceil).astype(int)\n",
    "# all_res['ca_diff'] = all_res['c_dist'] - all_res['a_dist']\n",
    "# all_res['ca_diff'].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res_ori['wga_te_err'].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_df = pd.read_csv(f\"exps/div_explore/coco/{model}_embeddings_tr/proc_emb.csv\")\n",
    "te_df = pd.read_csv(f\"exps/div_explore/coco/{model}_embeddings_te/proc_emb.csv\")\n",
    "coco_res = coco_res.merge(tr_df, on=['n', 'sc', 'ci', 'ai'])\n",
    "# all_res['c_dist'] = all_res['c_dist'] ceiling\n",
    "# all_res['a_dist'] = all_res['a_dist'].apply(np.ceil).astype(int)\n",
    "# all_res['c_dist'] = all_res['c_dist'].apply(np.ceil).astype(int)\n",
    "# coco_res['ca_diff'] = coco_res['c_dist'] - coco_res['a_dist']\n",
    "# coco_res['ca_diff'].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coco_res_ori['wga_te_err'].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Define the MLP model using nn.Sequential\n",
    "class MLPTorch(nn.Module):\n",
    "    def __init__(self, input_size, hidden_layer_sizes, output_size):\n",
    "        super(MLPTorch, self).__init__()\n",
    "        layers = []\n",
    "        layer_sizes = [input_size] + list(hidden_layer_sizes)\n",
    "        for i in range(len(layer_sizes) - 1):\n",
    "            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))\n",
    "            layers.append(nn.ReLU())\n",
    "        layers.append(nn.Linear(layer_sizes[-1], output_size))\n",
    "        self.network = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.network(x)\n",
    "\n",
    "class FullModel(nn.Module):\n",
    "    def __init__(self, emb_mlp, mlp):\n",
    "        super(FullModel, self).__init__()\n",
    "        self.emb_mlp = emb_mlp  # Embedding model\n",
    "        self.mlp = mlp  # MLP model\n",
    "\n",
    "    def forward(self, x):\n",
    "        # TODO\n",
    "        x = self.emb_mlp(x)\n",
    "        return self.mlp(x)\n",
    "\n",
    "# Training function\n",
    "def train_model(model, dataloader, criterion, optimizer, num_epochs, patience, tol, X_val=None, y_val=None, verbose=True):\n",
    "    best_loss = float('inf')\n",
    "    patience_counter = 0\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        running_loss = 0.0\n",
    "\n",
    "        for inputs, labels, weights in dataloader:\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            # print(outputs.shape, labels.shape)\n",
    "            loss = criterion(outputs, labels, weights)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            running_loss += loss.item() * inputs.size(0)\n",
    "\n",
    "        epoch_loss = running_loss / len(dataloader.dataset)\n",
    "        if verbose:\n",
    "            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')\n",
    "\n",
    "        if epoch_loss < best_loss - tol:\n",
    "            best_loss = epoch_loss\n",
    "            patience_counter = 0\n",
    "        else:\n",
    "            patience_counter += 1\n",
    "\n",
    "        if patience_counter >= patience:\n",
    "            print(\"Early stopping triggered\")\n",
    "            break\n",
    "\n",
    "        if epoch % 5 == 0 and X_val is not None:\n",
    "            eval_model(model, X_val, y_val, criterion)\n",
    "\n",
    "    return model\n",
    "\n",
    "def eval_model(model, X, y, criterion, mode=\"0-1\", verbose=True):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        y_pred = model(X)\n",
    "        # loss = criterion(y_pred, y)\n",
    "        # # print(f\"Eval loss: {loss.item()}\")\n",
    "\n",
    "        y_pred = torch.sigmoid(y_pred)\n",
    "        y_pred = (y_pred > 0.5).float()\n",
    "        if mode == \"0-1\":\n",
    "            acc = ((y_pred == y).float().sum(dim=1)==5).float().mean()\n",
    "            acc = acc.item()\n",
    "            if verbose:\n",
    "                print(f\"Eval accuracy: {acc}\")\n",
    "        elif mode == \"soft 0-1\":\n",
    "            correct = 0\n",
    "            for i, curr_y in enumerate(y):\n",
    "                curr_pred = y_pred[i]\n",
    "                # check if the positions of 1s in curr_pred are also 1s in curr_y\n",
    "                pos = torch.nonzero(curr_pred).squeeze()\n",
    "                if torch.all(curr_y[pos] == 1):\n",
    "                    correct += 1\n",
    "            acc = correct / y.shape[0]\n",
    "            if verbose:\n",
    "                print(f\"Eval accuracy: {acc}\")\n",
    "        elif mode == \"jaccard\":\n",
    "            # Jaccard similarity\n",
    "            intersection = torch.sum(y_pred * y, dim=1)\n",
    "            union = torch.sum(y_pred + y, dim=1)\n",
    "            acc = intersection / union\n",
    "            if verbose:\n",
    "                print(f\"Eval Jaccard: {acc.mean().item()}\")\n",
    "\n",
    "    return model, acc\n",
    "\n",
    "class WeightedBCELoss(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(WeightedBCELoss, self).__init__()\n",
    "        self.bce = nn.BCEWithLogitsLoss()\n",
    "\n",
    "    def forward(self, outputs, targets, weights):\n",
    "        loss = self.bce(outputs, targets)\n",
    "        # print(weights)\n",
    "        weighted_loss = loss * weights.unsqueeze(1)\n",
    "        return weighted_loss.mean()\n",
    "\n",
    "def sigmoid(x, k):\n",
    "    return 1 / (1 + np.exp(-k * x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coco_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "# ignore the warning\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "# fix all the seeds in numpy and torch\n",
    "verbose=False\n",
    "seed_list = [2023, 2024, 2025]\n",
    "k_list = [0,1,2.5,5,10,15,20,50,100]\n",
    "# seed_list = [2023]\n",
    "# k_list = [0]\n",
    "\n",
    "for k in k_list:\n",
    "    all_res['sigmoid_conf'] = sigmoid(all_res['conf'], k)\n",
    "    hard_acc, soft_acc, err = [], [], []\n",
    "    coco_acc, coco_acc_soft, coco_err = [], [], []\n",
    "    for seed in seed_list:\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "        all_df_rank = copy.deepcopy(all_res)\n",
    "        # feat_list = ['n', 'sc', 'ci', 'ai', 'c_inter_n', 'a_inter_n', 'multi_hot']\n",
    "        feat_list = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'multi_hot']\n",
    "        # feat_list = ['n', 'sc', 'ci', 'ai', 'y_task', 'a_task', 'multi_hot']\n",
    "        # feat_list = ['n', 'sc', 'ci', 'ai', 'multi_hot']\n",
    "        # feat_list = ['n', 'sc_', 'ci_', 'ai_', 'c_intra', 'a_intra_i', 'multi_hot']\n",
    "        # feat_list = ['n', 'sc_', 'ci_', 'ai_', 'multi_hot']\n",
    "        coco_feat_list = ['n', 'sc', 'ci', 'ai', 'c_dist', 'a_dist', 'multi_hot']\n",
    "\n",
    "        all_df_rank_data = all_df_rank[feat_list]\n",
    "\n",
    "        all_X = all_df_rank_data.to_numpy()[:,:-1].astype('float')\n",
    "        all_y = np.array(list(all_df_rank_data.to_numpy()[:, -1]))\n",
    "\n",
    "        # all_X[:, [1,2,3,4,5]] = 1\n",
    "\n",
    "        # check if 'conf' exists\n",
    "        if 'sigmoid_conf' in all_df_rank.columns:\n",
    "            all_conf = all_df_rank['sigmoid_conf'].to_numpy() + 0.5\n",
    "        else:\n",
    "            # all ones\n",
    "            all_conf = np.ones_like(all_y)\n",
    "\n",
    "        from sklearn.preprocessing import StandardScaler\n",
    "        from sklearn.model_selection import train_test_split\n",
    "\n",
    "        X_train, X_test, y_train, y_test, conf_train, conf_test, tr_idx, te_idx = train_test_split(all_X, all_y, all_conf, range(len(all_X)), test_size=0.2, random_state=0)\n",
    "        X_val, y_val = None, None\n",
    "\n",
    "        scaler = StandardScaler()\n",
    "        X_train = scaler.fit_transform(X_train)\n",
    "        X_test = scaler.transform(X_test)\n",
    "\n",
    "        coco = coco_res[coco_feat_list]\n",
    "        coco_X = coco.to_numpy()[:,:-1].astype('float')\n",
    "        coco_y = np.array(list(coco.to_numpy()[:, -1]))\n",
    "        coco_X = scaler.transform(coco_X)\n",
    "        coco_X, coco_y = torch.tensor(coco_X).float(), torch.tensor(coco_y).float()\n",
    "\n",
    "        X_train, X_test = torch.tensor(X_train).float(), torch.tensor(X_test).float()\n",
    "        y_train, y_test = torch.tensor(y_train).float(), torch.tensor(y_test).float()\n",
    "        conf_train, conf_test = torch.tensor(conf_train).float(), torch.tensor(conf_test).float()\n",
    "\n",
    "        cols = list(range(X_train.shape[1]))\n",
    "        X_train, X_test = X_train[:,cols], X_test[:,cols]\n",
    "        # X_train = torch.ones_like(X_train)\n",
    "        # X_test = torch.ones_like(X_test)\n",
    "\n",
    "        # Hyperparameters and data\n",
    "        input_size = X_train.shape[1]\n",
    "        hidden_layer_sizes = (100, 100, 50, 50)\n",
    "        output_size = 5\n",
    "        num_epochs = 500\n",
    "        patience = 2000\n",
    "        tol = 1e-4\n",
    "        alpha = 0.0001\n",
    "        batch_size = 256\n",
    "\n",
    "        # Create dataloader\n",
    "        tr_dataset = TensorDataset(X_train, y_train, conf_train)\n",
    "        tr_dataloader = DataLoader(tr_dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "        # Initialize model, criterion, and optimizer\n",
    "        model = MLPTorch(input_size, hidden_layer_sizes, output_size)\n",
    "\n",
    "\n",
    "        # criterion = nn.BCEWithLogitsLoss()\n",
    "        criterion = WeightedBCELoss()\n",
    "        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=alpha)\n",
    "        # Train the model\n",
    "        trained_model = train_model(model, tr_dataloader, criterion, optimizer, num_epochs, patience, tol, X_val, y_val, verbose=verbose)\n",
    "\n",
    "        # Eval on train\n",
    "        _, curr_acc = eval_model(trained_model, X_train, y_train, criterion, verbose=verbose)\n",
    "        _, curr_hard_acc = eval_model(trained_model, X_test, y_test, criterion, verbose=verbose)\n",
    "        _, curr_soft_acc = eval_model(trained_model, X_test, y_test, criterion, mode=\"soft 0-1\", verbose=verbose)\n",
    "        hard_acc.append(curr_hard_acc)\n",
    "        soft_acc.append(curr_soft_acc)\n",
    "\n",
    "        # evaluate the pred err\n",
    "        algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "        test_df = all_df_rank.iloc[te_idx]\n",
    "        y_pred = trained_model(X_test)\n",
    "        y_pred = torch.sigmoid(y_pred)\n",
    "        y_pred = (y_pred > 0.5).float()\n",
    "        winners_pred = []\n",
    "        for y_ in y_pred:\n",
    "            winners_pred.append(np.array(algorithms)[np.where(y_==1)[0]])\n",
    "        test_df['winners_pred'] = winners_pred\n",
    "        fine_test_df = all_res_ori.merge(test_df, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "\n",
    "        def get_pred_wga_te_err(x):\n",
    "            pred_winners = x['winners_pred'].iloc[0]\n",
    "            if len(pred_winners) == 0:\n",
    "                # randomly select one\n",
    "                pred_winner = np.random.choice(algorithms)\n",
    "            else:\n",
    "                # es = []\n",
    "                # for w in pred_winners:\n",
    "                #     es.append(x[x['algorithm_x']==w]['wga_te_err_x'].iloc[0])\n",
    "                # return np.min(es)\n",
    "                # return np.min(x['wga_te_err_x'])\n",
    "                # random select from the predicted\n",
    "                pred_winner = np.random.choice(pred_winners)\n",
    "            # # print(x[x['algorithm_x']==pred_winner]['wga_te_err_x'].iloc[0])\n",
    "            return x[x['algorithm_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "        pred_err = fine_test_df.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])['wga_te_err_x', 'algorithm_x' ,'winners_pred'].apply(lambda x: get_pred_wga_te_err(x))\n",
    "        err.append(pred_err.mean())\n",
    "\n",
    "        _, curr_coco_acc = eval_model(trained_model, coco_X, coco_y, criterion, verbose=verbose)\n",
    "        _, curr_coco_acc_soft = eval_model(trained_model, coco_X, coco_y, criterion, verbose=verbose, mode=\"soft 0-1\")\n",
    "        coco_acc.append(curr_coco_acc)\n",
    "        coco_acc_soft.append(curr_coco_acc_soft)\n",
    "\n",
    "        # evaluate the pred err\n",
    "        algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "        test_df_coco = coco_res\n",
    "        y_pred = trained_model(coco_X)\n",
    "        y_pred = torch.sigmoid(y_pred)\n",
    "        y_pred = (y_pred > 0.5).float()\n",
    "        winners_pred = []\n",
    "        for y_ in y_pred:\n",
    "            winners_pred.append(np.array(algorithms)[np.where(y_==1)[0]])\n",
    "        test_df_coco['winners_pred'] = winners_pred\n",
    "        fine_test_df = coco_res_ori.merge(test_df_coco, on=['n', 'sc', 'ci', 'ai'])\n",
    "\n",
    "        def get_pred_wga_te_err(x):\n",
    "            pred_winners = x['winners_pred'].iloc[0]\n",
    "            if len(pred_winners) == 0:\n",
    "                # randomly select one\n",
    "                pred_winner = np.random.choice(algorithms)\n",
    "            else:\n",
    "                # es = []\n",
    "                # for w in pred_winners:\n",
    "                #     es.append(x[x['algorithm_x']==w]['wga_te_err_x'].iloc[0])\n",
    "                # return np.min(es)\n",
    "                # return np.min(x['wga_te_err_x'])\n",
    "                # random select from the predicted\n",
    "                pred_winner = np.random.choice(pred_winners)\n",
    "            # # print(x[x['algorithm_x']==pred_winner]['wga_te_err_x'].iloc[0])\n",
    "            return x[x['algorithm_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "        pred_err = fine_test_df.groupby(['n', 'sc', 'ci', 'ai'])['wga_te_err_x', 'algorithm_x' ,'winners_pred'].apply(lambda x: get_pred_wga_te_err(x))\n",
    "        coco_err.append(pred_err.mean())\n",
    "\n",
    "    print(k, np.mean(hard_acc), np.std(hard_acc), np.mean(soft_acc), np.std(soft_acc), np.mean(err), np.std(err))\n",
    "    print(k, np.mean(coco_acc), np.std(coco_acc), np.mean(coco_acc_soft), np.std(coco_acc_soft), np.mean(coco_err), np.std(coco_err))\n",
    "\n",
    "# # newly added\n",
    "# X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=0)\n",
    "# X_val = scaler.transform(X_val)\n",
    "# cols.remove(0)\n",
    "# ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat']\n",
    "# X_train, X_val, X_test = X_train[:,cols], X_val[:,cols], X_test[:,cols]\n",
    "# X_train, X_val, X_test = torch.tensor(X_train).float(), torch.tensor(X_val).float(), torch.tensor(X_test).float()\n",
    "# y_train, y_val, y_test = torch.tensor(y_train).float(), torch.tensor(y_val).float(), torch.tensor(y_test).float()\n",
    "# emb_model = MLPTorch(input_size, hidden_layer_sizes, output_size)\n",
    "# model = FullModel(emb_model, model)\n",
    "# optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=alpha)\n",
    "# _ = eval_model(model, coco_X, coco_y, criterion)\n",
    "# _ = eval_model(model, X_test, y_test, criterion, mode=\"jaccard\")\n",
    "# _ = eval_model(trained_model, coco_X, coco_y, criterion, mode=\"soft 0-1\")\n",
    "# _ = eval_model(model, coco_X, coco_y, criterion, mode=\"jaccard\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coco_res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 262,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jac_sim(list1, list2):\n",
    "    set1, set2 = set(list1), set(list2)\n",
    "    intersection = len(set1.intersection(set2))\n",
    "    union = len(set1.union(set2))\n",
    "    return intersection / union\n",
    "\n",
    "def baseline_pred(mode):\n",
    "\n",
    "    methods = np.array(['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample'])\n",
    "    if mode == \"real_random\":\n",
    "        num_winners = np.random.choice(5, 1, replace=False)[0] + 1\n",
    "        y_pred = np.random.choice(5, num_winners, replace=False)\n",
    "    elif mode == \"enhanced_random\":\n",
    "        # num_winners = sum(y)\n",
    "        y_pred = np.random.choice(5, 5, replace=False)\n",
    "    elif mode == \"global_rank_baseline\":\n",
    "        num_winners = np.random.choice(5, 1, replace=False)[0] + 1\n",
    "        y_pred = np.argsort(global_rank)[::-1][:num_winners]\n",
    "    elif mode == \"global_rank_baseline_random\":\n",
    "        num_winners = np.random.choice(5, 1, replace=False)[0] + 1\n",
    "        # print(num_winners)\n",
    "        y_pred = np.argsort(global_rank)[::-1][:num_winners]\n",
    "    else:\n",
    "        raise ValueError(f\"unknown mode {mode}\")\n",
    "    y_pred_str = methods[y_pred].tolist()\n",
    "    return y_pred, y_pred_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "# fix all the seeds in numpy and torch\n",
    "verbose=False\n",
    "seed_list = [2023, 2024, 2025]\n",
    "acc01, soft_acc01, err = [], [], []\n",
    "coco_soft, coco_err = [], []\n",
    "mode=\"global_rank_baseline\"\n",
    "for seed in seed_list:\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    all_df_rank = copy.deepcopy(all_res)\n",
    "\n",
    "    from sklearn.preprocessing import StandardScaler\n",
    "    from sklearn.model_selection import train_test_split\n",
    "    train_df, test_df, tr_idx, te_idx = train_test_split(all_df_rank, range(len(all_df_rank)), test_size=0.2, random_state=0)\n",
    "    test_df = all_df_rank.iloc[te_idx]\n",
    "\n",
    "    gt = train_df[\"multi_hot\"].to_list()\n",
    "    global_rank = np.array(gt).sum(axis=0)\n",
    "    methods = [\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"]\n",
    "    global_rank\n",
    "\n",
    "    # num = 0\n",
    "    jac, correct, soft_correct, pearson = [], [], [], []\n",
    "\n",
    "    y_preds = []\n",
    "    for y in test_df[\"multi_hot\"]:\n",
    "        # print(y)\n",
    "        # if sum(y) == 5:\n",
    "        #     continue\n",
    "        # else:\n",
    "        #     num+=1\n",
    "\n",
    "        y_pred, y_pred_str = baseline_pred(mode)\n",
    "        y = np.where(np.array(y)==1)[0]\n",
    "        # print(y_pred)\n",
    "        jac.append(jac_sim(y, y_pred))\n",
    "\n",
    "        if set(y) == set(y_pred):\n",
    "            correct.append(1)\n",
    "        else:\n",
    "            correct.append(0)\n",
    "\n",
    "        # check if the positions of 1s in curr_pred are also 1s in curr_y\n",
    "        curr_y = np.zeros(5)\n",
    "        curr_y[y] = 1\n",
    "        curr_pred = np.zeros(5)\n",
    "        curr_pred[y_pred] = 1\n",
    "        pos = np.nonzero(curr_pred)\n",
    "        if np.all(curr_y[pos] == 1):\n",
    "            soft_correct.append(1)\n",
    "        else:\n",
    "            soft_correct.append(0)\n",
    "\n",
    "        y_preds.append(curr_pred)\n",
    "    test_df[\"pred\"] = y_preds\n",
    "    # print(num)\n",
    "    # acc = np.mean(jac)\n",
    "    # print(f\"random baseline jac accuracy: \", acc)\n",
    "    acc = np.mean(correct)\n",
    "    # print(f\"random baseline 0-1 accuracy: \", acc)\n",
    "    soft_acc = np.mean(soft_correct)\n",
    "    acc01.append(acc)\n",
    "    soft_acc01.append(soft_acc)\n",
    "\n",
    "    fine_test_df = all_res_ori.merge(test_df, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "    def get_pred_wga_te_err(x, mode):\n",
    "        pred_winners = baseline_pred(mode)[1]\n",
    "        pred_winner = np.random.choice(pred_winners)\n",
    "        return x[x['algorithm_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "    pred_err = fine_test_df.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])['wga_te_err_x', 'algorithm_x'].apply(lambda x: get_pred_wga_te_err(x, mode))\n",
    "    err.append(pred_err.mean())\n",
    "\n",
    "    def get_pearson_corr(x):\n",
    "        wga_te_err = x['wga_te_err_x'].iloc[:5]\n",
    "        pred_rank = x['pred'].iloc[0]\n",
    "        # pred_rank = [4, 0, 3, 1, 2]\n",
    "        # print(wga_te_err)\n",
    "        # print(pred_rank)\n",
    "        curr_pearson = np.corrcoef(wga_te_err, pred_rank)[0,1]\n",
    "        return curr_pearson\n",
    "    pearson_corr = fine_test_df.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])['wga_te_err_x', 'pred'].apply(lambda x: get_pearson_corr(x))\n",
    "    pearson.append(pearson_corr.mean())\n",
    "\n",
    "    coco_soft_correct = []\n",
    "    y_preds = []\n",
    "    for y in coco_res[\"multi_hot\"]:\n",
    "\n",
    "        y_pred, y_pred_str = baseline_pred(mode)\n",
    "        y = np.where(np.array(y)==1)[0]\n",
    "\n",
    "        # check if the positions of 1s in curr_pred are also 1s in curr_y\n",
    "        curr_y = np.zeros(5)\n",
    "        curr_y[y] = 1\n",
    "        curr_pred = np.zeros(5)\n",
    "        curr_pred[y_pred] = 1\n",
    "        pos = np.nonzero(curr_pred)\n",
    "        if np.all(curr_y[pos] == 1):\n",
    "            coco_soft_correct.append(1)\n",
    "        else:\n",
    "            coco_soft_correct.append(0)\n",
    "\n",
    "        y_preds.append(curr_pred)\n",
    "    coco_res[\"pred\"] = y_preds\n",
    "    soft_acc = np.mean(coco_soft_correct)\n",
    "    coco_soft.append(soft_acc)\n",
    "\n",
    "    fine_test_df = coco_res_ori.merge(coco_res, on=['n', 'sc', 'ci', 'ai'])\n",
    "    def get_pred_wga_te_err(x, mode):\n",
    "        pred_winners = baseline_pred(mode)[1]\n",
    "        pred_winner = np.random.choice(pred_winners)\n",
    "        return x[x['algorithm_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "    pred_err = fine_test_df.groupby(['n', 'sc', 'ci', 'ai'])['wga_te_err_x', 'algorithm_x'].apply(lambda x: get_pred_wga_te_err(x, mode))\n",
    "    coco_err.append(pred_err.mean())\n",
    "\n",
    "\n",
    "print(np.mean(acc01), np.std(acc01), np.mean(err), np.std(err), np.mean(soft_acc01), np.std(soft_acc01), np.mean(pearson), np.std(pearson))\n",
    "print(np.mean(coco_soft), np.std(coco_soft), np.mean(coco_err), np.std(coco_err))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import copy\n",
    "import torch\n",
    "model = \"clip\"\n",
    "path = f\"exps/div_explore/celeba_v2/{model}_lp\"\n",
    "res_files = os.listdir(path)\n",
    "\n",
    "# set seeds\n",
    "np.random.seed(2023)\n",
    "torch.manual_seed(2023)\n",
    "\n",
    "all_res = []\n",
    "for file in res_files:\n",
    "    res_file = os.path.join(path, file)\n",
    "    print(file)\n",
    "    # parse the seed from the file name\n",
    "    res = pd.read_csv(res_file)\n",
    "    res[\"seed\"] = int(file.split('seed')[-1][0])\n",
    "    res[\"y_task\"] = int(file.split('y')[-1].split('_')[0])\n",
    "    res[\"a_task\"] = int(file.split('a')[-1].split('.')[0])\n",
    "    all_res.append(res)\n",
    "all_res = pd.concat(all_res, ignore_index=True)\n",
    "\n",
    "algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "\n",
    "# rename col\n",
    "all_res.rename(columns={\"method\": \"algorithm\"}, inplace=True)\n",
    "all_res = all_res[[\"n\", \"sc\", \"ci\", \"ai\", \"algorithm\", \"wga_te_err\", \"seed\", \"y_task\", \"a_task\"]]\n",
    "all_res = all_res.groupby(['n', 'sc', 'ci', 'ai', 'algorithm', 'y_task', 'a_task']).mean().reset_index()\n",
    "all_res.drop(columns=[\"seed\"], inplace=True)\n",
    "\n",
    "all_res['rank'] = all_res.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])['wga_te_err'].rank(\"first\")\n",
    "# all_res_ori = copy.deepcopy(all_res)\n",
    "# print(all_res.head(100))\n",
    "\n",
    "def get_gt_rank(x, mode='tie', filter_thre=0.05):\n",
    "    # print(x)\n",
    "    min_err = x['wga_te_err'].min()\n",
    "    winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "    return '|'.join(winners)\n",
    "\n",
    "# Apply the function to each group and reset the index to merge back with the original dataframe\n",
    "winners_series = all_res.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])[['algorithm', 'wga_te_err']].apply(lambda x: get_gt_rank(x)).reset_index(name='winners')\n",
    "all_res = all_res.merge(winners_series, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "\n",
    "def get_conf(x, mode='tie', filter_thre=0.05):\n",
    "    # print(x)\n",
    "    min_err = x['wga_te_err'].min()\n",
    "    winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "    # get minimum of winner and maximum of loser\n",
    "    winner_max = x[x['algorithm'].isin(winners)]['wga_te_err'].max()\n",
    "    loser_min = x[~x['algorithm'].isin(winners)]['wga_te_err'].min()\n",
    "    return loser_min - winner_max\n",
    "conf_series = all_res.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])[['algorithm', 'wga_te_err']].apply(lambda x: get_conf(x)).reset_index(name='conf')\n",
    "all_res = all_res.merge(conf_series, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "\n",
    "all_res[\"multi_hot\"] = all_res[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in algorithms])\n",
    "\n",
    "# retain the original results\n",
    "all_res_ori = copy.deepcopy(all_res)\n",
    "# deduplicate\n",
    "all_res = all_res[all_res[\"rank\"]==1.0]\n",
    "\n",
    "# set all nan to 0\n",
    "all_res['conf'].fillna(0, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_df = pd.read_csv(f\"exps/div_explore/celeba_v2/{model}_embeddings_tr/proc_emb.csv\")\n",
    "te_df = pd.read_csv(f\"exps/div_explore/celeba_v2/{model}_embeddings_te/proc_emb.csv\")\n",
    "all_res_ori = all_res_ori.merge(tr_df, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "# all_res['c_dist'] = all_res['c_dist'] ceiling\n",
    "# all_res['a_dist'] = all_res['a_dist'].apply(np.ceil).astype(int)\n",
    "# all_res['c_dist'] = all_res['c_dist'].apply(np.ceil).astype(int)\n",
    "# all_res['ca_diff'] = all_res['c_dist'] - all_res['a_dist']\n",
    "# all_res['ca_diff'].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# to load the coco results\n",
    "model = \"clip\"\n",
    "path = f\"exps/div_explore/coco/{model}_lp\"\n",
    "res_files = os.listdir(path)\n",
    "\n",
    "def load_coco_res(res_files):\n",
    "    all_res = []\n",
    "    for file in res_files:\n",
    "        res_file = os.path.join(path, file)\n",
    "        print(file)\n",
    "        # parse the seed from the file name\n",
    "        res = pd.read_csv(res_file)\n",
    "        res[\"seed\"] = int(file.split('seed')[-1][0])\n",
    "        all_res.append(res)\n",
    "    all_res = pd.concat(all_res, ignore_index=True)\n",
    "\n",
    "    algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "\n",
    "    # rename col\n",
    "    all_res.rename(columns={\"method\": \"algorithm\"}, inplace=True)\n",
    "    all_res = all_res[[\"n\", \"sc\", \"ci\", \"ai\", \"algorithm\", \"wga_te_err\", \"seed\"]]\n",
    "    all_res = all_res.groupby(['n', 'sc', 'ci', 'ai', 'algorithm']).mean().reset_index()\n",
    "    all_res.drop(columns=[\"seed\"], inplace=True)\n",
    "\n",
    "    all_res['rank'] = all_res.groupby(['n', 'sc', 'ci', 'ai'])['wga_te_err'].rank(\"first\")\n",
    "    # all_res_ori = copy.deepcopy(all_res)\n",
    "    # print(all_res.head(100))\n",
    "\n",
    "    def get_gt_rank(x, mode='tie', filter_thre=0.05):\n",
    "        # print(x)\n",
    "        min_err = x['wga_te_err'].min()\n",
    "        winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "        return '|'.join(winners)\n",
    "\n",
    "    # Apply the function to each group and reset the index to merge back with the original dataframe\n",
    "    winners_series = all_res.groupby(['n', 'sc', 'ci', 'ai'])[['algorithm', 'wga_te_err']].apply(lambda x: get_gt_rank(x)).reset_index(name='winners')\n",
    "    all_res = all_res.merge(winners_series, on=['n', 'sc', 'ci', 'ai'])\n",
    "\n",
    "    all_res[\"multi_hot\"] = all_res[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in algorithms])\n",
    "\n",
    "    all_res_ori = copy.deepcopy(all_res)\n",
    "    # deduplicate\n",
    "    all_res = all_res[all_res[\"rank\"]==1.0]\n",
    "    return all_res, all_res_ori\n",
    "\n",
    "coco_res, coco_res_ori = load_coco_res(res_files)\n",
    "\n",
    "tr_df = pd.read_csv(f\"exps/div_explore/coco/{model}_embeddings_tr/proc_emb.csv\")\n",
    "te_df = pd.read_csv(f\"exps/div_explore/coco/{model}_embeddings_te/proc_emb.csv\")\n",
    "coco_res = coco_res.merge(tr_df, on=['n', 'sc', 'ci', 'ai'])\n",
    "# all_res['c_dist'] = all_res['c_dist'] ceiling\n",
    "# all_res['a_dist'] = all_res['a_dist'].apply(np.ceil).astype(int)\n",
    "# all_res['c_dist'] = all_res['c_dist'].apply(np.ceil).astype(int)\n",
    "# coco_res['ca_diff'] = coco_res['c_dist'] - coco_res['a_dist']\n",
    "# coco_res['ca_diff'].describe()\n",
    "coco_res_ori = coco_res_ori.merge(tr_df, on=['n', 'sc', 'ci', 'ai'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = np.array(['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample'])\n",
    "all_res_ori['algorithm'] = all_res_ori['algorithm'].map(lambda x: np.where(methods==x)[0][0])\n",
    "coco_res_ori['algorithm'] = coco_res_ori['algorithm'].map(lambda x: np.where(methods==x)[0][0])\n",
    "\n",
    "import copy\n",
    "from sklearn.model_selection import train_test_split\n",
    "tr_idx, te_idx = train_test_split(range(int(len(all_res_ori)/5)), test_size=0.2, random_state=0)\n",
    "\n",
    "all_df_rank = copy.deepcopy(all_res_ori)\n",
    "\n",
    "# feat_list = ['n', 'sc', 'ci', 'ai', 'y_task', 'a_task', 'algorithm', 'wga_te_err']\n",
    "feat_list = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'algorithm', 'wga_te_err']\n",
    "# feat_list = ['n', 'sc_', 'ci_', 'ai_', 'c_intra', 'a_intra_i', 'algorithm', 'wga_te_err']\n",
    "\n",
    "all_df_rank_data = all_df_rank[feat_list]\n",
    "\n",
    "all_X = all_df_rank_data.to_numpy()[:,:-1].astype('float')\n",
    "all_y = np.array(list(all_df_rank_data.to_numpy()[:, -1]))\n",
    "\n",
    "# all_X[:, [1,2,3,4,5]] = 1\n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.neural_network import MLPRegressor\n",
    "\n",
    "\n",
    "tr_idx = np.concatenate([np.arange(5*i, 5*i+5) for i in tr_idx])\n",
    "X_train = all_X[tr_idx]\n",
    "y_train = all_y[tr_idx]\n",
    "# the remaining ones are test\n",
    "te_idx = np.concatenate([np.arange(5*i, 5*i+5) for i in te_idx])\n",
    "X_test = all_X[te_idx]\n",
    "y_test = all_y[te_idx]\n",
    "\n",
    "scaler = StandardScaler()\n",
    "X_train = scaler.fit_transform(X_train)\n",
    "X_test = scaler.transform(X_test)\n",
    "\n",
    "coco_feat_list = ['n', 'sc', 'ci', 'ai', 'c_dist', 'a_dist', 'algorithm', 'wga_te_err']\n",
    "coco = coco_res_ori[coco_feat_list]\n",
    "coco_X = coco.to_numpy()[:,:-1].astype('float')\n",
    "coco_y = np.array(list(coco.to_numpy()[:, -1]))\n",
    "coco_X = scaler.transform(coco_X)\n",
    "\n",
    "cols = list(range(X_train.shape[1]))\n",
    "# cols.remove(0)\n",
    "X_train, X_test = X_train[:,cols], X_test[:,cols]\n",
    "\n",
    "clf = MLPRegressor(random_state=0, max_iter=10000, verbose=True, tol=1e-3, n_iter_no_change=100, alpha=0.1, hidden_layer_sizes=(100,100,50,50)).fit(X_train, y_train)\n",
    "\n",
    "y_pred = clf.predict(X_test)\n",
    "coco_y_pred = clf.predict(coco_X)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_df = all_res_ori.iloc[te_idx]\n",
    "test_df['pred'] = y_pred\n",
    "\n",
    "test_df['algorithm'] = test_df['algorithm'].map(lambda x: methods[x])\n",
    "\n",
    "def get_rank(x, col, mode='tie', filter_thre=0.05):\n",
    "    # x is a dataframe with columns ['method', 'wga_te_err'], find the method(s) with smallest wga_te_err, with filter_thre tolerance for tie\n",
    "    min_err = x[col].min()\n",
    "    winners = x[x[col] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "    return '|'.join(winners)\n",
    "\n",
    "# Ground-truth winners of every experiment group, computed once per group.\n",
    "winners_series = (\n",
    "    test_df.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])[['algorithm', 'wga_te_err']]\n",
    "    .apply(lambda grp: get_rank(grp, 'wga_te_err'))\n",
    "    .reset_index(name='winners')\n",
    ")\n",
    "# Broadcast the group-level result back onto the rows of that group.\n",
    "test_df = test_df.merge(winners_series, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "# methods = [\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"]\n",
    "# NOTE(review): this reads 'winners_x', which only exists if test_df already had a\n",
    "# 'winners' column before the merge above -- confirm.\n",
    "test_df[\"multi_hot\"] = test_df[\"winners_x\"].map(lambda w: [int(m in w.split(\"|\")) for m in methods])\n",
    "\n",
    "\n",
    "# Same procedure on the predicted errors.\n",
    "winners_series = (\n",
    "    test_df.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])[['algorithm', 'pred']]\n",
    "    .apply(lambda grp: get_rank(grp, 'pred'))\n",
    "    .reset_index(name='winners_pred')\n",
    ")\n",
    "test_df = test_df.merge(winners_series, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "# methods = [\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"]\n",
    "test_df[\"multi_hot_pred\"] = test_df[\"winners_pred\"].map(lambda w: [int(m in w.split(\"|\")) for m in methods])\n",
    "\n",
    "\n",
    "def get_all_rank(x, col, mode='tie', filter_thre=0.05):\n",
    "    # x is a dataframe with columns ['method', 'wga_te_err'], find the method(s) with smallest wga_te_err, with filter_thre tolerance for tie\n",
    "    try:\n",
    "        float_list = x[col].to_list()\n",
    "        if len(float_list) != 5:\n",
    "            raise ValueError(\"The input list must contain exactly 5 float numbers.\")\n",
    "\n",
    "        sorted_list = sorted(float_list)\n",
    "        ranks = [None] * 5\n",
    "        current_rank = 0\n",
    "        ranks[0] = current_rank\n",
    "\n",
    "        for i in range(1, 5):\n",
    "            if sorted_list[i] - sorted_list[i - 1] < 0.05:\n",
    "                ranks[i] = current_rank\n",
    "            else:\n",
    "                current_rank += 1\n",
    "                ranks[i] = current_rank\n",
    "\n",
    "        rank_map = {v: ranks[i] for i, v in enumerate(sorted_list)}\n",
    "        result = [rank_map[v] for v in float_list]\n",
    "    except:\n",
    "        result = [None] * 5\n",
    "    return result\n",
    "# Rank the predicted errors inside every experiment group ...\n",
    "ranking_series = (\n",
    "    test_df.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])[['algorithm', 'pred']]\n",
    "    .apply(lambda grp: get_all_rank(grp, 'pred'))\n",
    "    .reset_index(name='pred_rank')\n",
    ")\n",
    "# ... and attach the per-group rank list to each row of the group.\n",
    "test_df = test_df.merge(ranking_series, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every 5th row is taken as the representative of its group.\n",
    "y_pred = test_df['multi_hot_pred'].iloc[::5].to_numpy()\n",
    "y_test = test_df['multi_hot'].iloc[::5].to_numpy()\n",
    "# A group counts as correct when every predicted winner is also a true winner.\n",
    "correct = sum(\n",
    "    int(np.all(np.array(truth)[np.nonzero(np.array(guess))] == 1))\n",
    "    for truth, guess in zip(y_test, y_pred)\n",
    ")\n",
    "acc = correct / y_test.shape[0]\n",
    "print(f\"Eval accuracy: {acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pred_wga_te_err(x):\n",
    "    \"\"\"Return the true wga_te_err of one algorithm drawn from the group's predicted winners.\n",
    "\n",
    "    x: rows of one experiment group with 'winners_pred' ('|'-joined names, read from\n",
    "       the first row), 'algorithm' and 'wga_te_err'.\n",
    "    \"\"\"\n",
    "    pred_winners = x['winners_pred'].iloc[0].split(\"|\")\n",
    "    # Ties are broken with an unseeded random draw, so the mean can vary between runs.\n",
    "    pred_winner = np.random.choice(pred_winners)\n",
    "    return x[x['algorithm']==pred_winner]['wga_te_err'].iloc[0]\n",
    "\n",
    "# Columns are selected with a list: tuple-style selection on a groupby is rejected by pandas >= 2.0.\n",
    "pred_err = test_df.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])[['wga_te_err', 'algorithm', 'winners_pred']].apply(get_pred_wga_te_err)\n",
    "print(pred_err.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy first: a plain assignment would make test_df another name for coco_res_ori,\n",
    "# so the column writes below would modify coco_res_ori itself and re-running this\n",
    "# cell would remap the already-remapped 'algorithm' column.\n",
    "test_df = coco_res_ori.copy()\n",
    "test_df['pred'] = coco_y_pred\n",
    "\n",
    "# Translate each 'algorithm' entry through the `methods` lookup.\n",
    "test_df['algorithm'] = test_df['algorithm'].map(lambda x: methods[x])\n",
    "\n",
    "def get_rank(x, col, mode='tie', filter_thre=0.05):\n",
    "    # x is a dataframe with columns ['method', 'wga_te_err'], find the method(s) with smallest wga_te_err, with filter_thre tolerance for tie\n",
    "    min_err = x[col].min()\n",
    "    winners = x[x[col] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "    return '|'.join(winners)\n",
    "\n",
    "# Ground-truth winners of every COCO experiment group, computed once per group.\n",
    "winners_series = (\n",
    "    test_df.groupby(['n', 'sc', 'ci', 'ai'])[['algorithm', 'wga_te_err']]\n",
    "    .apply(lambda grp: get_rank(grp, 'wga_te_err'))\n",
    "    .reset_index(name='winners')\n",
    ")\n",
    "# Broadcast the group-level result back onto the rows of that group.\n",
    "test_df = test_df.merge(winners_series, on=['n', 'sc', 'ci', 'ai'])\n",
    "# methods = [\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"]\n",
    "# NOTE(review): this reads 'winners_x', which only exists if test_df already had a\n",
    "# 'winners' column before the merge above -- confirm.\n",
    "test_df[\"multi_hot\"] = test_df[\"winners_x\"].map(lambda w: [int(m in w.split(\"|\")) for m in methods])\n",
    "\n",
    "\n",
    "# Same procedure on the predicted errors.\n",
    "winners_series = (\n",
    "    test_df.groupby(['n', 'sc', 'ci', 'ai'])[['algorithm', 'pred']]\n",
    "    .apply(lambda grp: get_rank(grp, 'pred'))\n",
    "    .reset_index(name='winners_pred')\n",
    ")\n",
    "test_df = test_df.merge(winners_series, on=['n', 'sc', 'ci', 'ai'])\n",
    "# methods = [\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"]\n",
    "test_df[\"multi_hot_pred\"] = test_df[\"winners_pred\"].map(lambda w: [int(m in w.split(\"|\")) for m in methods])\n",
    "\n",
    "\n",
    "def get_all_rank(x, col, mode='tie', filter_thre=0.05):\n",
    "    # x is a dataframe with columns ['method', 'wga_te_err'], find the method(s) with smallest wga_te_err, with filter_thre tolerance for tie\n",
    "    try:\n",
    "        float_list = x[col].to_list()\n",
    "        if len(float_list) != 5:\n",
    "            raise ValueError(\"The input list must contain exactly 5 float numbers.\")\n",
    "\n",
    "        sorted_list = sorted(float_list)\n",
    "        ranks = [None] * 5\n",
    "        current_rank = 0\n",
    "        ranks[0] = current_rank\n",
    "\n",
    "        for i in range(1, 5):\n",
    "            if sorted_list[i] - sorted_list[i - 1] < 0.05:\n",
    "                ranks[i] = current_rank\n",
    "            else:\n",
    "                current_rank += 1\n",
    "                ranks[i] = current_rank\n",
    "\n",
    "        rank_map = {v: ranks[i] for i, v in enumerate(sorted_list)}\n",
    "        result = [rank_map[v] for v in float_list]\n",
    "    except:\n",
    "        result = [None] * 5\n",
    "    return result\n",
    "# Rank the predicted errors inside every COCO experiment group ...\n",
    "ranking_series = (\n",
    "    test_df.groupby(['n', 'sc', 'ci', 'ai'])[['algorithm', 'pred']]\n",
    "    .apply(lambda grp: get_all_rank(grp, 'pred'))\n",
    "    .reset_index(name='pred_rank')\n",
    ")\n",
    "# ... and attach the per-group rank list to each row of the group.\n",
    "test_df = test_df.merge(ranking_series, on=['n', 'sc', 'ci', 'ai'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the frame assembled in the previous cell (winners, multi-hot vectors, predicted ranks).\n",
    "test_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every 5th row is taken as the representative of its group.\n",
    "y_pred = test_df['multi_hot_pred'].iloc[::5].to_numpy()\n",
    "y_test = test_df['multi_hot'].iloc[::5].to_numpy()\n",
    "# A group counts as correct when every predicted winner is also a true winner.\n",
    "correct = sum(\n",
    "    int(np.all(np.array(truth)[np.nonzero(np.array(guess))] == 1))\n",
    "    for truth, guess in zip(y_test, y_pred)\n",
    ")\n",
    "acc = correct / y_test.shape[0]\n",
    "print(f\"Eval accuracy: {acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pred_wga_te_err(x):\n",
    "    \"\"\"Return the true wga_te_err of one algorithm drawn from the group's predicted winners.\n",
    "\n",
    "    x: rows of one experiment group with 'winners_pred' ('|'-joined names, read from\n",
    "       the first row), 'algorithm' and 'wga_te_err'.\n",
    "    \"\"\"\n",
    "    pred_winners = x['winners_pred'].iloc[0].split(\"|\")\n",
    "    # Ties are broken with an unseeded random draw, so the mean can vary between runs.\n",
    "    pred_winner = np.random.choice(pred_winners)\n",
    "    return x[x['algorithm']==pred_winner]['wga_te_err'].iloc[0]\n",
    "\n",
    "# Columns are selected with a list: tuple-style selection on a groupby is rejected by pandas >= 2.0.\n",
    "pred_err = test_df.groupby(['n', 'sc', 'ci', 'ai'])[['wga_te_err', 'algorithm', 'winners_pred']].apply(get_pred_wga_te_err)\n",
    "print(pred_err.mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ensemble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import copy\n",
    "import re\n",
    "\n",
    "model = \"clip\"\n",
    "path = f\"exps/div_explore/celeba_v2/{model}_lp\"\n",
    "res_files = os.listdir(path)\n",
    "\n",
    "# Read every result CSV and tag its rows with the values encoded in the file name.\n",
    "all_res = []\n",
    "for file in res_files:\n",
    "    res_file = os.path.join(path, file)\n",
    "    print(file)\n",
    "    res = pd.read_csv(res_file)\n",
    "    # Parse the full seed number after the last 'seed'; reading a single character\n",
    "    # would truncate seeds >= 10.\n",
    "    res[\"seed\"] = int(re.findall(r'seed(\\d+)', file)[-1])\n",
    "    # Text between the last 'y' and the next '_'.\n",
    "    res[\"y_task\"] = int(file.split('y')[-1].split('_')[0])\n",
    "    # Text between the last 'a' and the next '.'.\n",
    "    res[\"a_task\"] = int(file.split('a')[-1].split('.')[0])\n",
    "    all_res.append(res)\n",
    "all_res = pd.concat(all_res, ignore_index=True)\n",
    "\n",
    "algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "\n",
    "# Keep the needed columns (with 'method' renamed to 'algorithm'), then average\n",
    "# over seeds for every configuration/algorithm/task combination.\n",
    "all_res = (\n",
    "    all_res.rename(columns={\"method\": \"algorithm\"})\n",
    "    .loc[:, [\"n\", \"sc\", \"ci\", \"ai\", \"algorithm\", \"wga_te_err\", \"seed\", \"y_task\", \"a_task\"]]\n",
    "    .groupby(['n', 'sc', 'ci', 'ai', 'algorithm', 'y_task', 'a_task'])\n",
    "    .mean()\n",
    "    .reset_index()\n",
    "    .drop(columns=[\"seed\"])\n",
    ")\n",
    "\n",
    "# Rank the algorithms inside each experiment group by error (1 = lowest, ties broken by row order).\n",
    "all_res['rank'] = all_res.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])['wga_te_err'].rank(\"first\")\n",
    "# all_res_ori = copy.deepcopy(all_res)\n",
    "# print(all_res.head(100))\n",
    "\n",
    "def get_gt_rank(x, mode='tie', filter_thre=0.05):\n",
    "    # print(x)\n",
    "    min_err = x['wga_te_err'].min()\n",
    "    winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "    return '|'.join(winners)\n",
    "\n",
    "# Compute the winners string once per experiment group, then copy it onto the group's rows.\n",
    "winners_series = (\n",
    "    all_res.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])[['algorithm', 'wga_te_err']]\n",
    "    .apply(lambda grp: get_gt_rank(grp))\n",
    "    .reset_index(name='winners')\n",
    ")\n",
    "all_res = all_res.merge(winners_series, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "\n",
    "def get_conf(x, mode='tie', filter_thre=0.05):\n",
    "    # print(x)\n",
    "    min_err = x['wga_te_err'].min()\n",
    "    winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "    # get minimum of winner and maximum of loser\n",
    "    winner_max = x[x['algorithm'].isin(winners)]['wga_te_err'].max()\n",
    "    loser_min = x[~x['algorithm'].isin(winners)]['wga_te_err'].min()\n",
    "    return loser_min - winner_max\n",
    "# Winner/loser margin of every experiment group, copied onto the group's rows.\n",
    "conf_series = all_res.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])[['algorithm', 'wga_te_err']].apply(lambda x: get_conf(x)).reset_index(name='conf')\n",
    "all_res = all_res.merge(conf_series, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "\n",
    "# Multi-hot encoding of the winners in the order given by `algorithms`.\n",
    "all_res[\"multi_hot\"] = all_res[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in algorithms])\n",
    "\n",
    "# retain the original results (all algorithms per group, 'conf' left unfilled)\n",
    "all_res_ori = copy.deepcopy(all_res)\n",
    "# deduplicate: keep only the rank-1 row of every group, as an independent frame\n",
    "all_res = all_res[all_res[\"rank\"]==1.0].copy()\n",
    "\n",
    "# set all nan to 0 (groups where every algorithm is a winner have no margin).\n",
    "# Assign the filled column back: an in-place fillna on a column selected from a\n",
    "# filtered frame is chained assignment and has no effect under copy-on-write.\n",
    "all_res['conf'] = all_res['conf'].fillna(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the processed CSVs stored under the '<model>_embeddings_tr' / '_te' folders.\n",
    "tr_df = pd.read_csv(f\"exps/div_explore/celeba_v2/{model}_embeddings_tr/proc_emb.csv\")\n",
    "te_df = pd.read_csv(f\"exps/div_explore/celeba_v2/{model}_embeddings_te/proc_emb.csv\")\n",
    "# Join the '_tr' table onto every result row by experiment configuration and task.\n",
    "# NOTE(review): all_res_ori is rebound to the merged frame, so running this cell twice\n",
    "# merges twice -- confirm it is run exactly once after the loading cell.\n",
    "all_res_ori = all_res_ori.merge(tr_df, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "# all_res['c_dist'] = all_res['c_dist'] ceiling\n",
    "# all_res['a_dist'] = all_res['a_dist'].apply(np.ceil).astype(int)\n",
    "# all_res['c_dist'] = all_res['c_dist'].apply(np.ceil).astype(int)\n",
    "# all_res['ca_diff'] = all_res['c_dist'] - all_res['a_dist']\n",
    "# all_res['ca_diff'].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# to load the coco results\n",
    "model = \"clip\"\n",
    "path = f\"exps/div_explore/coco/{model}_lp\"\n",
    "res_files = os.listdir(path)\n",
    "\n",
    "def load_coco_res(res_files):\n",
    "    all_res = []\n",
    "    for file in res_files:\n",
    "        res_file = os.path.join(path, file)\n",
    "        print(file)\n",
    "        # parse the seed from the file name\n",
    "        res = pd.read_csv(res_file)\n",
    "        res[\"seed\"] = int(file.split('seed')[-1][0])\n",
    "        all_res.append(res)\n",
    "    all_res = pd.concat(all_res, ignore_index=True)\n",
    "\n",
    "    algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "\n",
    "    # rename col\n",
    "    all_res.rename(columns={\"method\": \"algorithm\"}, inplace=True)\n",
    "    all_res = all_res[[\"n\", \"sc\", \"ci\", \"ai\", \"algorithm\", \"wga_te_err\", \"seed\"]]\n",
    "    all_res = all_res.groupby(['n', 'sc', 'ci', 'ai', 'algorithm']).mean().reset_index()\n",
    "    all_res.drop(columns=[\"seed\"], inplace=True)\n",
    "\n",
    "    all_res['rank'] = all_res.groupby(['n', 'sc', 'ci', 'ai'])['wga_te_err'].rank(\"first\")\n",
    "    # all_res_ori = copy.deepcopy(all_res)\n",
    "    # print(all_res.head(100))\n",
    "\n",
    "    def get_gt_rank(x, mode='tie', filter_thre=0.05):\n",
    "        # print(x)\n",
    "        min_err = x['wga_te_err'].min()\n",
    "        winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "        return '|'.join(winners)\n",
    "\n",
    "    # Apply the function to each group and reset the index to merge back with the original dataframe\n",
    "    winners_series = all_res.groupby(['n', 'sc', 'ci', 'ai'])[['algorithm', 'wga_te_err']].apply(lambda x: get_gt_rank(x)).reset_index(name='winners')\n",
    "    all_res = all_res.merge(winners_series, on=['n', 'sc', 'ci', 'ai'])\n",
    "\n",
    "    all_res[\"multi_hot\"] = all_res[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in algorithms])\n",
    "\n",
    "    all_res_ori = copy.deepcopy(all_res)\n",
    "    # deduplicate\n",
    "    all_res = all_res[all_res[\"rank\"]==1.0]\n",
    "    return all_res, all_res_ori\n",
    "\n",
    "# coco_res: rank-1 row per configuration; coco_res_ori: all algorithms per configuration.\n",
    "coco_res, coco_res_ori = load_coco_res(res_files)\n",
    "\n",
    "# Read the processed CSVs stored under the '<model>_embeddings_tr' / '_te' folders.\n",
    "tr_df = pd.read_csv(f\"exps/div_explore/coco/{model}_embeddings_tr/proc_emb.csv\")\n",
    "te_df = pd.read_csv(f\"exps/div_explore/coco/{model}_embeddings_te/proc_emb.csv\")\n",
    "# Only coco_res receives the '_tr' columns; coco_res_ori is left as loaded.\n",
    "coco_res = coco_res.merge(tr_df, on=['n', 'sc', 'ci', 'ai'])\n",
    "# all_res['c_dist'] = all_res['c_dist'] ceiling\n",
    "# all_res['a_dist'] = all_res['a_dist'].apply(np.ceil).astype(int)\n",
    "# all_res['c_dist'] = all_res['c_dist'].apply(np.ceil).astype(int)\n",
    "# coco_res['ca_diff'] = coco_res['c_dist'] - coco_res['a_dist']\n",
    "# coco_res['ca_diff'].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert to multiple binary classification\n",
    "def prepare_data(n, tr_size):\n",
    "    \"\"\"Split range(n) into train/test index lists, optionally subsampling the train part.\n",
    "\n",
    "    Args:\n",
    "        n: number of experiments to split (80% train / 20% test, random_state=0).\n",
    "        tr_size: when 0 < tr_size < number of train indices, keep only about that\n",
    "            many train indices (drawn with random_state=0); otherwise keep them all.\n",
    "\n",
    "    Returns:\n",
    "        (tr_idx, te_idx) lists of positional indices.\n",
    "    \"\"\"\n",
    "    from sklearn.model_selection import train_test_split\n",
    "    tr_idx, te_idx = train_test_split(range(n), test_size=0.2, random_state=0)\n",
    "\n",
    "    num = len(tr_idx)\n",
    "    # tr_size >= num would give a held-out fraction <= 0, which train_test_split\n",
    "    # rejects; in that case the full training set is used as is.\n",
    "    if 0 < tr_size < num:\n",
    "        te_size = (num - tr_size)/num\n",
    "        tr_idx, _ = train_test_split(tr_idx, test_size=te_size, random_state=0)\n",
    "        print(\"train size:\", len(tr_idx))\n",
    "\n",
    "    return tr_idx, te_idx\n",
    "\n",
    "def train_binary_classifier(all_df, tr_idx, te_idx, method1, method2, filter_thre=0.05, mode=\"filter\"):\n",
    "    \"\"\"Train an MLP that predicts whether `method1` beats `method2` on an experiment.\n",
    "\n",
    "    Args:\n",
    "        all_df: one row per (experiment, algorithm) with 'algorithm', 'wga_te_err' and\n",
    "            the feature columns 'n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra'. The rows\n",
    "            of the two methods are paired by position, so both must list the\n",
    "            experiments in the same order.\n",
    "        tr_idx, te_idx: positional indices of the train / test experiments.\n",
    "        method1, method2: names of the two algorithms being compared.\n",
    "        filter_thre: margin applied to the error difference (method1 - method2).\n",
    "        mode: \"filter\" labels 1 when method1 has the lower error, else 0, and drops\n",
    "            experiments whose absolute difference is not above `filter_thre`;\n",
    "            \"tie\" keeps every experiment and labels those within the margin as 2.\n",
    "\n",
    "    Returns:\n",
    "        (clf, X_test, coco_X, coco_y): the fitted classifier, the standardised test\n",
    "        features, the notebook-level `coco_res` features passed through the same\n",
    "        scaler, and the `coco_res` wga_te_err values.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `mode` is neither \"filter\" nor \"tie\".\n",
    "    \"\"\"\n",
    "    import warnings\n",
    "    from sklearn.preprocessing import StandardScaler\n",
    "    from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "    if mode not in (\"filter\", \"tie\"):\n",
    "        raise ValueError(f\"unknown mode {mode}\")\n",
    "\n",
    "    warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "    # Copy so the new columns are not written into a slice of the caller's frame.\n",
    "    pair_df = all_df[all_df[\"algorithm\"]==method1].copy()\n",
    "    other_err = all_df[all_df[\"algorithm\"]==method2]['wga_te_err'].to_numpy()\n",
    "    pair_df[\"loss_diff\"] = pair_df['wga_te_err'].to_numpy() - other_err\n",
    "\n",
    "    if mode == \"filter\":\n",
    "        labels = pair_df['loss_diff'] < 0\n",
    "    else:\n",
    "        # 1: method1 clearly better, 0: method2 clearly better, 2: within the tie margin.\n",
    "        labels = np.select(\n",
    "            [pair_df['loss_diff'] < -filter_thre, pair_df['loss_diff'] > filter_thre],\n",
    "            [1, 0],\n",
    "            default=2,\n",
    "        )\n",
    "    pair_df['over_rank'] = np.asarray(labels).astype('int')\n",
    "\n",
    "    df_tr = pair_df.iloc[tr_idx]\n",
    "    df_te = pair_df.iloc[te_idx]\n",
    "\n",
    "    if mode == \"filter\":\n",
    "        # Near-ties carry no reliable label; train and evaluate on clear outcomes only.\n",
    "        df_tr = df_tr[(df_tr['loss_diff'] < -filter_thre)|(df_tr['loss_diff'] > filter_thre)]\n",
    "        df_te = df_te[(df_te['loss_diff'] < -filter_thre)|(df_te['loss_diff'] > filter_thre)]\n",
    "\n",
    "    feature_cols = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra']\n",
    "    X_train = df_tr[feature_cols].to_numpy().astype('float')\n",
    "    y_train = df_tr['over_rank'].to_numpy().astype('int')\n",
    "    X_test = df_te[feature_cols].to_numpy().astype('float')\n",
    "    y_test = df_te['over_rank'].to_numpy().astype('int')\n",
    "\n",
    "    scaler = StandardScaler()\n",
    "    X_train = scaler.fit_transform(X_train)\n",
    "    X_test = scaler.transform(X_test)\n",
    "\n",
    "    # NOTE(review): the scaler is fitted on 'c_intra'/'a_intra' but applied here to\n",
    "    # 'c_dist'/'a_dist' -- confirm these columns hold the same quantities.\n",
    "    coco_X = coco_res[['n', 'sc', 'ci', 'ai', 'c_dist', 'a_dist']].to_numpy().astype('float')\n",
    "    coco_y = np.array(list(coco_res['wga_te_err'].to_numpy()))\n",
    "    coco_X = scaler.transform(coco_X)\n",
    "\n",
    "    clf = MLPClassifier(random_state=1, max_iter=1000000, verbose=False, tol=1e-3, n_iter_no_change=2000, alpha=0.1, hidden_layer_sizes=(100,100,50,50)).fit(X_train, y_train)\n",
    "    train_acc = clf.score(X_train, y_train)\n",
    "    test_acc = clf.score(X_test, y_test)\n",
    "    print(method1, method2, \"train size: \", len(df_tr), \"train acc: \", train_acc, \"test acc: \", test_acc)\n",
    "    # Per-class report; the masks get their own names so the index arguments are not shadowed.\n",
    "    for label in range(3):\n",
    "        te_mask = y_test == label\n",
    "        tr_mask = y_train == label\n",
    "        print(\"      class\", label, \"train size: \", len(y_train[tr_mask]))\n",
    "        if len(y_test[te_mask]) == 0:\n",
    "            continue\n",
    "        print(\"      class\", label, \"test acc: \", clf.score(X_test[te_mask], y_test[te_mask]))\n",
    "\n",
    "    return clf, X_test, coco_X, coco_y\n",
    "\n",
    "def extract_pairwise_results(zoo, X_test):\n",
    "    res = []\n",
    "    # if 1d then expand to 2d\n",
    "    if len(X_test.shape) == 1:\n",
    "        X_test = X_test.reshape(1, -1)\n",
    "    for (m1, m2), clf in zoo.items():\n",
    "        y_pred = clf.predict(X_test)\n",
    "        # print(y_pred)\n",
    "        if y_pred == 1:\n",
    "            res.append((m1, m2, m1))\n",
    "        elif y_pred == 0:\n",
    "            res.append((m1, m2, m2))\n",
    "        else:\n",
    "            res.append((m1, m2, m1))\n",
    "            res.append((m1, m2, m2))\n",
    "    return res\n",
    "\n",
    "def copeland_method(candidate_names, pairwise_results):\n",
    "    # Initialize scores dictionary\n",
    "    scores = {name: 0 for name in candidate_names}\n",
    "\n",
    "    # Update scores based on pairwise results\n",
    "    for result in pairwise_results:\n",
    "        A, B, winner = result\n",
    "        if winner == A:\n",
    "            scores[A] += 1\n",
    "            scores[B] -= 1\n",
    "        elif winner == B:\n",
    "            scores[B] += 1\n",
    "            scores[A] -= 1\n",
    "    # print(scores)\n",
    "    # # Generate the ranking\n",
    "    # ranking = sorted(candidate_names, key=lambda x: scores[x], reverse=True)\n",
    "    # get the ones with the highest score, note that some can have the same score, and we want all them\n",
    "    winners = [name for name in candidate_names if scores[name] == max(scores.values())]\n",
    "\n",
    "    return winners, scores\n",
    "\n",
    "def bradley_terry_method(candidate_names, pairwise_results):\n",
    "    from scipy.optimize import minimize\n",
    "\n",
    "    n = len(candidate_names)\n",
    "    candidate_index = {name: i for i, name in enumerate(candidate_names)}\n",
    "\n",
    "    # Initialize ability scores\n",
    "    abilities = np.zeros(n)\n",
    "\n",
    "    def log_likelihood(abilities):\n",
    "        ll = 0\n",
    "        for A, B, winner in pairwise_results:\n",
    "            i, j = candidate_index[A], candidate_index[B]\n",
    "            pi = 1 / (1 + np.exp(abilities[j] - abilities[i]))\n",
    "            if winner == A:\n",
    "                ll += np.log(pi)\n",
    "            elif winner == B:\n",
    "                ll += np.log(1 - pi)\n",
    "        return -ll\n",
    "\n",
    "    result = minimize(log_likelihood, abilities, method='BFGS')\n",
    "    abilities = result.x\n",
    "\n",
    "    # Generate the ranking\n",
    "    ranking = sorted(candidate_names, key=lambda x: abilities[candidate_index[x]], reverse=True)\n",
    "\n",
    "    return ranking[0] # , abilities\n",
    "\n",
    "def binary_classifiers(all_df, tr_size=-1, mode=\"filter\"):\n",
    "    \"\"\"Train one pairwise classifier for every pair of the five methods.\n",
    "\n",
    "    all_df: frame with one row per (experiment, algorithm).\n",
    "    tr_size: forwarded to prepare_data.\n",
    "    mode: labelling mode forwarded to train_binary_classifier.\n",
    "\n",
    "    Returns (zoo, te_idx, X_test, methods, coco_X, coco_y); zoo maps (m1, m2) to its\n",
    "    classifier, and X_test, coco_X, coco_y are those returned for the last pair.\n",
    "    \"\"\"\n",
    "    import itertools\n",
    "    methods = np.array([\"ERM\", \"GroupDRO\", \"oversample\", \"undersample\", \"remax-margin\"])\n",
    "\n",
    "    # The number of experiments is the number of rows of the first method.\n",
    "    n_experiments = all_df[all_df['algorithm']==methods[0]].shape[0]\n",
    "    tr_idx, te_idx = prepare_data(n_experiments, tr_size)\n",
    "\n",
    "    zoo = {}\n",
    "    for m1, m2 in np.array(list(itertools.combinations(methods, 2))):\n",
    "        fitted, X_test, coco_X, coco_y = train_binary_classifier(all_df, tr_idx, te_idx, m1, m2, filter_thre=0.05, mode=mode)\n",
    "        zoo[(m1, m2)] = fitted\n",
    "    return zoo, te_idx, X_test, methods, coco_X, coco_y\n",
    "\n",
    "def ranking_acc(methods, zoo, X_test, y_test, mode=\"copeland\", acc_mode=\"jac\"):\n",
    "    \"\"\"Predict the winning method(s) of every test row and print the agreement with `y_test`.\n",
    "\n",
    "    Args:\n",
    "        methods: candidate method names.\n",
    "        zoo: dict mapping (m1, m2) to a fitted pairwise classifier.\n",
    "        X_test: 2-D array, one feature row per experiment.\n",
    "        y_test: per experiment, the collection of true winner names (list or array).\n",
    "        mode: \"copeland\" or \"bradley_terry\" aggregation of the pairwise outcomes.\n",
    "        acc_mode: \"jac\" (mean Jaccard similarity), \"accuracy\" (exact set match),\n",
    "            \"all\" (both of the former) or \"soft 0-1\" (correct when every predicted\n",
    "            winner is a true winner).\n",
    "\n",
    "    Returns:\n",
    "        (y_pred, scores): predicted winner lists and, per experiment, the Copeland\n",
    "        score dict (None for \"bradley_terry\", which exposes no scores).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for an unknown `mode` or `acc_mode`.\n",
    "    \"\"\"\n",
    "\n",
    "    def jac_sim(list1, list2):\n",
    "        set1, set2 = set(list1), set(list2)\n",
    "        union = set1 | set2\n",
    "        # Two empty sets are identical; avoid dividing by zero.\n",
    "        return len(set1 & set2) / len(union) if union else 1.0\n",
    "\n",
    "    if mode == \"copeland\":\n",
    "        eval_fn = copeland_method\n",
    "    elif mode == \"bradley_terry\":\n",
    "        # bradley_terry_method returns only the top-ranked name; adapt it to the\n",
    "        # (winners, scores) shape unpacked below.\n",
    "        def eval_fn(names, results):\n",
    "            return [bradley_terry_method(names, results)], None\n",
    "    else:\n",
    "        raise ValueError(f\"unknown mode {mode}\")\n",
    "\n",
    "    # Fail before the (slow) prediction loop rather than after it.\n",
    "    if acc_mode not in (\"jac\", \"accuracy\", \"all\", \"soft 0-1\"):\n",
    "        raise ValueError(f\"unknown acc_mode {acc_mode}\")\n",
    "\n",
    "    y_pred, scores = [], []\n",
    "    for X in X_test:\n",
    "        pairwise_results = extract_pairwise_results(zoo, X)\n",
    "        winners, s = eval_fn(methods, pairwise_results)\n",
    "        y_pred.append(winners)\n",
    "        scores.append(s)\n",
    "\n",
    "    # len() works for both lists and arrays of true winners.\n",
    "    n_samples = len(y_test)\n",
    "\n",
    "    if acc_mode in (\"jac\", \"all\"):\n",
    "        acc = np.mean([jac_sim(truth, guess) for truth, guess in zip(y_test, y_pred)])\n",
    "        print(f\"{mode} method accuracy: \", acc)\n",
    "    if acc_mode in (\"accuracy\", \"all\"):\n",
    "        correct = sum(set(truth) == set(guess) for truth, guess in zip(y_test, y_pred))\n",
    "        acc = correct / n_samples\n",
    "        print(f\"{mode} method accuracy: \", acc)\n",
    "    if acc_mode == \"soft 0-1\":\n",
    "        correct = sum(\n",
    "            bool(np.all(np.isin(np.array(guess), np.array(truth))))\n",
    "            for truth, guess in zip(y_test, y_pred)\n",
    "        )\n",
    "        acc = correct / n_samples\n",
    "        print(f\"{mode} method accuracy: {acc}\")\n",
    "\n",
    "    return y_pred, scores\n",
    "\n",
    "    # # correct = 0\n",
    "    # # for i in range(len(y_pred)):\n",
    "    # #     if y_test[i] in y_pred[i]:\n",
    "    # #         correct += 1\n",
    "    # # acc = correct / len(y_test)\n",
    "    # # print(f\"{mode} method accuracy: \", acc)\n",
    "    # res = []\n",
    "    # for i in range(len(y_pred)):\n",
    "    #     res.append(jac_sim(y_test[i], y_pred[i]))\n",
    "    # acc = np.mean(res)\n",
    "    # print(f\"{mode} method accuracy: \", acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train every pairwise classifier in \"tie\" mode (three labels, no experiments dropped).\n",
    "zoo, te_idx, X_test, methods, coco_X, coco_y = binary_classifiers(all_res_ori, tr_size=-1, mode=\"tie\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per experiment group: the rank-1 rows carry the group's 'winners' string.\n",
    "all_df_rank = all_res_ori[all_res_ori[\"rank\"]==1.0]\n",
    "winners = all_df_rank[\"winners\"].to_numpy()[te_idx]\n",
    "# The winner lists have different lengths; dtype=object is required because NumPy\n",
    "# refuses to build a ragged array implicitly (ValueError since NumPy 1.24).\n",
    "y_test = np.array([w.split(\"|\") for w in winners], dtype=object)\n",
    "\n",
    "y_pred, scores = ranking_acc(methods, zoo, X_test, y_test, mode=\"copeland\", acc_mode=\"all\")\n",
    "y_pred, scores = ranking_acc(methods, zoo, X_test, y_test, mode=\"copeland\", acc_mode=\"soft 0-1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = [\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"]\n",
    "# Copy: a column is added below, and assigning to a slice of all_df_rank would\n",
    "# trigger SettingWithCopyWarning.\n",
    "test_df = all_df_rank.iloc[te_idx].copy()\n",
    "eval_fn = copeland_method\n",
    "y_pred, scores = [], []\n",
    "for X in X_test:\n",
    "    pairwise_results = extract_pairwise_results(zoo, X)\n",
    "    winners, s = eval_fn(methods, pairwise_results)\n",
    "    y_pred.append(winners)\n",
    "    scores.append(s)\n",
    "\n",
    "test_df['winners_pred'] = y_pred\n",
    "# Join the per-group predictions onto the per-algorithm rows; the columns present in\n",
    "# both frames come back with _x (all_res_ori) and _y (test_df) suffixes.\n",
    "fine_test_df = all_res_ori.merge(test_df, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])\n",
    "\n",
    "def get_pred_wga_te_err(x):\n",
    "    \"\"\"Return the true error of one algorithm drawn from the group's predicted winners.\n",
    "\n",
    "    x: rows of one experiment group with 'winners_pred' (list of names, read from the\n",
    "       first row), 'algorithm_x' and 'wga_te_err_x'.\n",
    "    \"\"\"\n",
    "    pred_winners = x['winners_pred'].iloc[0]\n",
    "    if len(pred_winners) == 0:\n",
    "        # randomly select one\n",
    "        pred_winner = np.random.choice(methods)\n",
    "    else:\n",
    "        # random select from the predicted (unseeded, so the mean can vary between runs)\n",
    "        pred_winner = np.random.choice(pred_winners)\n",
    "    return x[x['algorithm_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "\n",
    "# Columns are selected with a list: tuple-style selection on a groupby is rejected by pandas >= 2.0.\n",
    "pred_err = fine_test_df.groupby(['n', 'sc', 'ci', 'ai', 'y_task', 'a_task'])[['wga_te_err_x', 'algorithm_x', 'winners_pred']].apply(get_pred_wga_te_err)\n",
    "pred_err.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "winners = coco_res[\"winners\"].to_numpy()\n",
    "# The winner lists have different lengths; dtype=object is required because NumPy\n",
    "# refuses to build a ragged array implicitly (ValueError since NumPy 1.24).\n",
    "coco_y = np.array([w.split(\"|\") for w in winners], dtype=object)\n",
    "\n",
    "y_pred, scores = ranking_acc(methods, zoo, coco_X, coco_y, mode=\"copeland\", acc_mode=\"soft 0-1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = [\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"]\n",
    "# test_df is another name for coco_res, so the column added below lands on coco_res too.\n",
    "test_df = coco_res\n",
    "eval_fn = copeland_method\n",
    "y_pred, scores = [], []\n",
    "for X in coco_X:\n",
    "    pairwise_results = extract_pairwise_results(zoo, X)\n",
    "    winners, s = eval_fn(methods, pairwise_results)\n",
    "    y_pred.append(winners)\n",
    "    scores.append(s)\n",
    "\n",
    "test_df['winners_pred'] = y_pred\n",
    "# Join the per-group predictions onto the per-algorithm rows; the columns present in\n",
    "# both frames come back with _x (coco_res_ori) and _y (test_df) suffixes.\n",
    "fine_test_df = coco_res_ori.merge(test_df, on=['n', 'sc', 'ci', 'ai'])\n",
    "\n",
    "def get_pred_wga_te_err(x):\n",
    "    \"\"\"Return the true error of one algorithm drawn from the group's predicted winners.\n",
    "\n",
    "    x: rows of one experiment group with 'winners_pred' (list of names, read from the\n",
    "       first row), 'algorithm_x' and 'wga_te_err_x'.\n",
    "    \"\"\"\n",
    "    pred_winners = x['winners_pred'].iloc[0]\n",
    "    if len(pred_winners) == 0:\n",
    "        # randomly select one\n",
    "        pred_winner = np.random.choice(methods)\n",
    "    else:\n",
    "        # random select from the predicted (unseeded, so the mean can vary between runs)\n",
    "        pred_winner = np.random.choice(pred_winners)\n",
    "    return x[x['algorithm_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "\n",
    "# Columns are selected with a list: tuple-style selection on a groupby is rejected by pandas >= 2.0.\n",
    "pred_err = fine_test_df.groupby(['n', 'sc', 'ci', 'ai'])[['wga_te_err_x', 'algorithm_x', 'winners_pred']].apply(get_pred_wga_te_err)\n",
    "pred_err.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### check results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    y_pred = trained_model(X_test)\n",
    "    y_pred = torch.sigmoid(y_pred)\n",
    "    y_pred = (y_pred > 0.5).float()\n",
    "    y_pred = y_pred.numpy()\n",
    "    y_test = y_test.numpy()\n",
    "# compute confusion matrix, and visualize it\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "cm = confusion_matrix(y_test.argmax(axis=1), y_pred.argmax(axis=1))\n",
    "sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "model = \"resnet\"\n",
    "path = f\"exps/div_explore/celeba_v2/{model}_lp\"\n",
    "res_files = os.listdir(path)\n",
    "seed = 0\n",
    "\n",
    "all_res = []\n",
    "for file in res_files:\n",
    "    res_file = os.path.join(path, file)\n",
    "    curr_seed = int(file.split('seed')[-1][0])\n",
    "    if curr_seed == seed:\n",
    "        print(file)\n",
    "        res = pd.read_csv(res_file)\n",
    "        res[\"seed\"] = curr_seed\n",
    "        res[\"y_task\"] = int(file.split('y')[-1].split('_')[0])\n",
    "        res[\"a_task\"] = int(file.split('a')[-1].split('.')[0])\n",
    "        all_res.append(res)\n",
    "all_res = pd.concat(all_res, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = \"resnet\"\n",
    "path = f\"exps/div_explore/celeba_v2/{model}_lp\"\n",
    "res_files = os.listdir(path)\n",
    "seed = 1\n",
    "\n",
    "all_res1 = []\n",
    "for file in res_files:\n",
    "    res_file = os.path.join(path, file)\n",
    "    curr_seed = int(file.split('seed')[-1][0])\n",
    "    if curr_seed == seed:\n",
    "        print(file)\n",
    "        res = pd.read_csv(res_file)\n",
    "        res[\"seed\"] = curr_seed\n",
    "        res[\"y_task\"] = int(file.split('y')[-1].split('_')[0])\n",
    "        res[\"a_task\"] = int(file.split('a')[-1].split('.')[0])\n",
    "        all_res1.append(res)\n",
    "all_res1 = pd.concat(all_res1, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = \"resnet\"\n",
    "path = f\"exps/div_explore/celeba_v2/{model}_lp\"\n",
    "res_files = os.listdir(path)\n",
    "seed = 2\n",
    "\n",
    "all_res2 = []\n",
    "for file in res_files:\n",
    "    res_file = os.path.join(path, file)\n",
    "    curr_seed = int(file.split('seed')[-1][0])\n",
    "    if curr_seed == seed:\n",
    "        print(file)\n",
    "        res = pd.read_csv(res_file)\n",
    "        res[\"seed\"] = curr_seed\n",
    "        res[\"y_task\"] = int(file.split('y')[-1].split('_')[0])\n",
    "        res[\"a_task\"] = int(file.split('a')[-1].split('.')[0])\n",
    "        all_res2.append(res)\n",
    "all_res2 = pd.concat(all_res2, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res_merge = all_res.merge(all_res1, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task', 'method'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 194,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res_merge = all_res_merge.merge(all_res2, on=['n', 'sc', 'ci', 'ai', 'y_task', 'a_task', 'method'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 195,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute variance of three columns\n",
    "all_res_merge['var_causal'] = all_res_merge[['wga_te_err_x', 'wga_te_err_y', 'wga_te_err']].var(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res_merge['var_causal'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# resnet\n",
    "import seaborn as sns\n",
    "sns.histplot(all_res_merge['wga_te_err_x'] - all_res_merge['wga_te_err_y'], bins=50, kde=True)\n",
    "(all_res_merge['wga_te_err_x'] - all_res_merge['wga_te_err_y']).abs().mean()\n",
    "lp_hist = all_res_merge['wga_te_err_x'] - all_res_merge['wga_te_err_y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clip\n",
    "import seaborn as sns\n",
    "sns.histplot(all_res_merge['wga_te_err_x'] - all_res_merge['wga_te_err_y'], bins=50, kde=True)\n",
    "(all_res_merge['wga_te_err_x'] - all_res_merge['wga_te_err_y']).abs().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 199,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import copy\n",
    "\n",
    "def load_data(y_a, n, seed):\n",
    "    hseed, dseed = 0, [0]\n",
    "    data = 'CelebA'\n",
    "    attr = {'ERM': 'Yes', 'GroupDRO': 'Yes', 'OverSample': 'Yes', 'UnderSample': 'Yes', 'ReWeightLogits': 'Yes'}\n",
    "    algorithms = attr.keys()\n",
    "\n",
    "    seed_diff = []\n",
    "    all_res = pd.DataFrame([])\n",
    "    for s in dseed:\n",
    "        for alg in algorithms:\n",
    "            res_dir = f'exps/div_explore/celeba_v2/celeba_y{y_a[0]}_a{y_a[1]}_ds0_n{n}/{alg.lower()}'\n",
    "            # find all the sub-directory under res_dir\n",
    "            dirs = os.listdir(res_dir)\n",
    "            for d in dirs:\n",
    "                if attr[alg] in d:\n",
    "                    res_file = os.path.join(res_dir, d)\n",
    "                    exps = os.listdir(res_file)[seed]\n",
    "                    # for exp in exps:\n",
    "                    #     res_file = os.path.join(res_dir, d, exp, 'results.json')\n",
    "                    #     print(res_file)\n",
    "                    #     if os.path.exists(res_file):\n",
    "                    #         final_res = get_trajectory_results(res_file, s)\n",
    "                    #         all_res = pd.concat([all_res, final_res])\n",
    "                    res_file = os.path.join(res_dir, d, exps, 'results.json')\n",
    "                    print(res_file)\n",
    "                    if os.path.exists(res_file):\n",
    "                        final_res = get_trajectory_results(res_file, s)\n",
    "                        all_res = pd.concat([all_res, final_res])\n",
    "\n",
    "    all_res = all_res[all_res['step']==1000]\n",
    "    all_res[\"wga_te_err\"] = 1 - all_res[\"metric\"]\n",
    "    all_res = all_res[[\"sc\", \"ci\", \"ai\", \"algorithm\", \"wga_te_err\", \"seed\"]]\n",
    "    all_res['n'] = n\n",
    "    all_res['n'] = all_res['n'].astype('str')\n",
    "\n",
    "    return all_res\n",
    "\n",
    "def get_trajectory_results(file_path, ds, metric='te_worst_acc'):\n",
    "    # List to hold all JSON objects\n",
    "    results = []\n",
    "\n",
    "    # Open the file and read it line by line or as a whole if the objects are not line-delimited\n",
    "    with open(file_path, 'r') as file:\n",
    "        file_contents = file.read()\n",
    "        # Attempt to split the file contents by a delimiter if they're not newline-delimited\n",
    "        # This delimiter needs to be defined based on your specific file structure\n",
    "        json_objects = file_contents.split('}\\n{')  # Example: split on possible JSON end/start\n",
    "\n",
    "        # Correct split parts and parse each as a JSON object\n",
    "        for i, obj in enumerate(json_objects):\n",
    "            try:\n",
    "                # Add missing curly braces if split was done in the middle of objects\n",
    "                if i != 0:\n",
    "                    obj = '{' + obj\n",
    "                if i != len(json_objects) - 1:\n",
    "                    obj += '}'\n",
    "                result = json.loads(obj)\n",
    "                # print(json.dumps(result, indent=4))\n",
    "                results.append(result)\n",
    "            except json.JSONDecodeError as e:\n",
    "                print(f\"Error decoding JSON object: {e}\")\n",
    "\n",
    "    results_df = pd.DataFrame(results)\n",
    "    trajectory = {'sc': [], 'ci': [], 'ai': [], 'algorithm': [], 'step': [], 'metric': [], 'seed': [], 'data_seed': []}\n",
    "\n",
    "    sc = float(results_df.iloc[-1]['args']['metadata'].split('sc')[-1].split('_')[0])\n",
    "    ci = float(results_df.iloc[-1]['args']['metadata'].split('ci')[-1].split('_')[0])\n",
    "    ai = float(results_df.iloc[-1]['args']['metadata'].split('ai')[-1][:4])\n",
    "    # sc = float(results_df.iloc[-1]['args']['cmnist_spur_prob'])\n",
    "    # ci = float(results_df.iloc[-1]['args']['cmnist_label_prob'])\n",
    "    # ai = float(results_df.iloc[-1]['args']['cmnist_attr_prob'])\n",
    "    alg = results_df.iloc[-1]['args']['algorithm']\n",
    "    for row in results_df.iterrows():\n",
    "        # extract the data\n",
    "        trajectory['sc'].append(sc)\n",
    "        trajectory['ci'].append(ci)\n",
    "        trajectory['ai'].append(ai)\n",
    "        trajectory['algorithm'].append(alg)\n",
    "        trajectory['step'].append(row[1]['step'])\n",
    "        trajectory['metric'].append(row[1][metric])\n",
    "        trajectory['seed'].append(row[1]['args']['seed'])\n",
    "        trajectory['data_seed'].append(ds)\n",
    "\n",
    "    # # smooth the 'metric'\n",
    "    # trajectory['metric'] = np.convolve(trajectory['metric'], np.ones(10)/10, mode='valid')\n",
    "\n",
    "    return pd.DataFrame(trajectory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed=0\n",
    "all_res = []\n",
    "for i, y_a in enumerate([[2, 31], [21, 36], [8, 20], [25, 19]]):\n",
    "    for n in [200, 500, 1000, 2000, 5000, 10000]:\n",
    "        curr_res = load_data(y_a, n, seed)\n",
    "        curr_res[\"task_id\"] = i\n",
    "        all_res.append(curr_res)\n",
    "all_res = pd.concat(all_res, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed=1\n",
    "all_res1 = []\n",
    "for i, y_a in enumerate([[2, 31], [21, 36], [8, 20], [25, 19]]):\n",
    "    for n in [200, 500, 1000, 2000, 5000, 10000]:\n",
    "        curr_res = load_data(y_a, n, seed)\n",
    "        curr_res[\"task_id\"] = i\n",
    "        all_res1.append(curr_res)\n",
    "all_res1 = pd.concat(all_res1, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 203,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res_merge = all_res.merge(all_res1, on=['n', 'sc', 'ci', 'ai', 'task_id', 'algorithm'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_res_merge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# full resnet\n",
    "import seaborn as sns\n",
    "sns.histplot(all_res_merge['wga_te_err_x'] - all_res_merge['wga_te_err_y'], bins=50, kde=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "div_backup",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
