{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import copy\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(data, model, path, identifier):\n",
    "    \"\"\"Load per-seed linear-probe results and attach winner labels and meta-features.\n",
    "\n",
    "    Reads every CSV in ``{path}/{model}_lp``, averages ``wga_te_err`` over seeds for each\n",
    "    (configuration, algorithm) pair and labels each configuration with every algorithm\n",
    "    whose seed-averaged error is within 0.05 of the best one.\n",
    "\n",
    "    Args:\n",
    "        data: dataset name. ``\"celeba\"`` also parses ``y_task``/``a_task`` from the file\n",
    "            name; ``\"coco\"`` renames ``c_dist``/``a_dist`` to ``c_intra``/``a_intra``.\n",
    "        model: backbone name used as a directory prefix (e.g. ``\"clip\"``).\n",
    "        path: root directory of the experiment results.\n",
    "        identifier: columns that together identify one configuration.\n",
    "\n",
    "    Returns:\n",
    "        (all_res, all_res_ori): ``all_res`` keeps one row per configuration (its rank-1\n",
    "        algorithm) with NaN ``conf`` replaced by 0; ``all_res_ori`` keeps every\n",
    "        (configuration, algorithm) row with ``conf`` as computed.\n",
    "\n",
    "    Note:\n",
    "        Reads the notebook-global ``algorithms`` list to build ``multi_hot``.\n",
    "    \"\"\"\n",
    "    def _leading_int(text):\n",
    "        # integer made of the digits at the start of text, e.g. \"12_y3.csv\" -> 12\n",
    "        digits = \"\"\n",
    "        for ch in text:\n",
    "            if not ch.isdigit():\n",
    "                break\n",
    "            digits += ch\n",
    "        return int(digits)\n",
    "\n",
    "    data_path = os.path.join(path, f\"{model}_lp\")\n",
    "    all_res = []\n",
    "    for file in os.listdir(data_path):\n",
    "        res = pd.read_csv(os.path.join(data_path, file))\n",
    "        # parse the seed from the file name; reading only the first character\n",
    "        # silently truncated multi-digit seeds (seed12 -> 1)\n",
    "        res[\"seed\"] = _leading_int(file.split('seed')[-1])\n",
    "        if data == \"celeba\":\n",
    "            res[\"y_task\"] = int(file.split('y')[-1].split('_')[0])\n",
    "            res[\"a_task\"] = int(file.split('a')[-1].split('.')[0])\n",
    "        all_res.append(res)\n",
    "    all_res = pd.concat(all_res, ignore_index=True)\n",
    "    all_res = all_res.rename(columns={\"method\": \"algorithm\"})\n",
    "\n",
    "    # average over seeds\n",
    "    all_res = all_res[identifier + [\"algorithm\", \"wga_te_err\", \"seed\"]]\n",
    "    all_res = all_res.groupby(identifier + [\"algorithm\"]).mean().reset_index()\n",
    "    all_res = all_res.drop(columns=[\"seed\"])\n",
    "\n",
    "    # rank algorithms inside each configuration (ties broken by row order);\n",
    "    # the rank-1 row is the one kept when deduplicating below\n",
    "    all_res['rank'] = all_res.groupby(identifier)['wga_te_err'].rank(\"first\")\n",
    "\n",
    "    # winners: every algorithm within filter_thre of the best error, joined with '|'\n",
    "    def get_gt_rank(x, filter_thre=0.05):\n",
    "        min_err = x['wga_te_err'].min()\n",
    "        winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "        return '|'.join(winners)\n",
    "    winners_series = all_res.groupby(identifier)[['algorithm', 'wga_te_err']].apply(get_gt_rank).reset_index(name='winners')\n",
    "    all_res = all_res.merge(winners_series, on=identifier)\n",
    "    # multi-hot encoding of the winners, in the order of the global `algorithms` list\n",
    "    all_res[\"multi_hot\"] = all_res[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in algorithms])\n",
    "\n",
    "    # confidence: gap between the best loser and the worst winner\n",
    "    # (NaN when every algorithm is a winner, because there is no loser)\n",
    "    def get_conf(x, filter_thre=0.05):\n",
    "        min_err = x['wga_te_err'].min()\n",
    "        winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "        winner_max = x[x['algorithm'].isin(winners)]['wga_te_err'].max()\n",
    "        loser_min = x[~x['algorithm'].isin(winners)]['wga_te_err'].min()\n",
    "        return loser_min - winner_max\n",
    "    conf_series = all_res.groupby(identifier)[['algorithm', 'wga_te_err']].apply(get_conf).reset_index(name='conf')\n",
    "    all_res = all_res.merge(conf_series, on=identifier)\n",
    "\n",
    "    # load extra data (static meta-features)\n",
    "    tr_df = pd.read_csv(os.path.join(path, f\"{model}_embeddings_tr/proc_emb.csv\"))\n",
    "    all_res = all_res.merge(tr_df, on=identifier)\n",
    "    if data == \"coco\":\n",
    "        all_res = all_res.rename(columns={\"c_dist\": \"c_intra\", \"a_dist\": \"a_intra\"})\n",
    "\n",
    "    # retain the original (per-algorithm) results\n",
    "    all_res_ori = all_res.copy()\n",
    "    # deduplicate; .copy() makes this an independent frame so the fillna below is a real\n",
    "    # assignment (an inplace fillna on a column of a filtered slice is chained assignment\n",
    "    # and does nothing under pandas Copy-on-Write)\n",
    "    all_res = all_res[all_res[\"rank\"] == 1.0].copy()\n",
    "    all_res['conf'] = all_res['conf'].fillna(0)\n",
    "\n",
    "    return all_res, all_res_ori\n",
    "\n",
    "\n",
    "def prepare_data(celeba_df, celeba_df_full, coco_df, coco_df_full, celeba_feats, coco_feats, seed):\n",
    "    \"\"\"Split CelebA into train/test, standardize features and build the CoCo test sets.\n",
    "\n",
    "    Two views are produced for each dataset:\n",
    "      * per-configuration (\"dedup\" frames): meta-features -> multi-hot winner target;\n",
    "      * per-(configuration, algorithm) (\"full\" frames): meta-features + algorithm index\n",
    "        -> ``wga_te_err``.\n",
    "    The index arithmetic assumes celeba_df_full holds len(algorithms) consecutive rows per\n",
    "    configuration, in the same configuration order as celeba_df.\n",
    "\n",
    "    Args:\n",
    "        celeba_df, celeba_df_full: deduplicated / full CelebA frames from load_data.\n",
    "        coco_df, coco_df_full: deduplicated / full CoCo frames (used only as test data).\n",
    "        celeba_feats, coco_feats: feature columns; the last entry is the target column.\n",
    "        seed: ``random_state`` of the 80/20 train/test split.\n",
    "\n",
    "    Returns:\n",
    "        18-tuple: X_train, X_test, y_train, y_test, X_train_full, y_train_full,\n",
    "        X_test_full, y_test_full, tr_idx, te_idx, X_test_coco, y_test_coco,\n",
    "        X_test_coco_full, y_test_coco_full, X_train_alg, X_test_alg, tr_idx_full,\n",
    "        te_idx_full.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if celeba_df_full does not have len(algorithms) rows per configuration.\n",
    "    \"\"\"\n",
    "    from sklearn.preprocessing import StandardScaler\n",
    "    from sklearn.model_selection import train_test_split\n",
    "\n",
    "    n_alg = len(algorithms)\n",
    "\n",
    "    # --- per-configuration view: features -> multi-hot target (last feature column) ---\n",
    "    all_X = celeba_df[celeba_feats].to_numpy()[:, :-1].astype('float')\n",
    "    all_y = np.array(list(celeba_df[celeba_feats].to_numpy()[:, -1]))\n",
    "    X_train, X_test, y_train, y_test, tr_idx, te_idx = train_test_split(\n",
    "        all_X, all_y, range(len(all_X)), test_size=0.2, random_state=seed)\n",
    "\n",
    "    scaler = StandardScaler()\n",
    "    X_train = scaler.fit_transform(X_train)\n",
    "    X_test = scaler.transform(X_test)\n",
    "\n",
    "    X_test_coco = coco_df[coco_feats].to_numpy()[:, :-1].astype('float')\n",
    "    y_test_coco = np.array(list(coco_df[coco_feats].to_numpy()[:, -1]))\n",
    "    X_test_coco = scaler.transform(X_test_coco)\n",
    "\n",
    "    # --- per-(configuration, algorithm) view of CelebA ---\n",
    "    if len(celeba_df_full) != n_alg * len(celeba_df):\n",
    "        raise ValueError(\n",
    "            f\"celeba_df_full has {len(celeba_df_full)} rows; expected {n_alg} per \"\n",
    "            f\"configuration ({n_alg * len(celeba_df)} in total)\")\n",
    "    # configuration i owns rows [n_alg*i, n_alg*(i+1)) of the full frame\n",
    "    tr_idx_full = np.concatenate([np.arange(n_alg * i, n_alg * (i + 1)) for i in tr_idx])\n",
    "    te_idx_full = np.concatenate([np.arange(n_alg * i, n_alg * (i + 1)) for i in te_idx])\n",
    "    # explicit copy: assigning into a plain column selection triggers SettingWithCopy\n",
    "    celeba_full = celeba_df_full[celeba_feats[:-1] + [\"algorithm\"]].copy()\n",
    "    celeba_full['algorithm'] = celeba_full['algorithm'].map(algorithms.index)\n",
    "    all_X_full = celeba_full.to_numpy().astype('float')\n",
    "    all_X_alg = all_X_full[:, -1]  # unscaled algorithm index of every row\n",
    "    all_y_full = celeba_df_full['wga_te_err'].to_numpy()\n",
    "    X_train_full, y_train_full = all_X_full[tr_idx_full], all_y_full[tr_idx_full]\n",
    "    X_test_full, y_test_full = all_X_full[te_idx_full], all_y_full[te_idx_full]\n",
    "    X_train_alg, X_test_alg = all_X_alg[tr_idx_full], all_X_alg[te_idx_full]\n",
    "\n",
    "    # separate scaler: the full view has the extra algorithm-index column\n",
    "    scaler_full = StandardScaler()\n",
    "    X_train_full = scaler_full.fit_transform(X_train_full)\n",
    "    X_test_full = scaler_full.transform(X_test_full)\n",
    "\n",
    "    # --- per-(configuration, algorithm) view of CoCo (test only) ---\n",
    "    coco_full = coco_df_full[coco_feats[:-1] + [\"algorithm\"]].copy()\n",
    "    coco_full['algorithm'] = coco_full['algorithm'].map(algorithms.index)\n",
    "    # transform a float ndarray as for CelebA (the scaler was fitted without feature names)\n",
    "    X_test_coco_full = scaler_full.transform(coco_full.to_numpy().astype('float'))\n",
    "    y_test_coco_full = coco_df_full['wga_te_err'].to_numpy()\n",
    "\n",
    "    return (X_train, X_test, y_train, y_test, X_train_full, y_train_full, X_test_full, y_test_full,\n",
    "            tr_idx, te_idx, X_test_coco, y_test_coco, X_test_coco_full, y_test_coco_full,\n",
    "            X_train_alg, X_test_alg, tr_idx_full, te_idx_full)\n",
    "\n",
    "\n",
    "def eval_acc(y_test, y_pred, mode, verbose=True):\n",
    "    \"\"\"Accuracy of multi-hot winner predictions.\n",
    "\n",
    "    Args:\n",
    "        y_test: 0/1 ground-truth matrix, one row per configuration.\n",
    "        y_pred: 0/1 predictions with the same shape (an array or a list of rows).\n",
    "        mode: \"0-1\" -- a row is correct only if it equals its y_test row exactly;\n",
    "            \"soft 0-1\" -- a row is correct if every predicted 1 is also a 1 in y_test\n",
    "            (so an all-zero prediction row counts as correct).\n",
    "        verbose: print the accuracy.\n",
    "\n",
    "    Returns:\n",
    "        The accuracy as a float (it used to be computed but never returned).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for an unknown mode (it used to pass silently).\n",
    "    \"\"\"\n",
    "    y_test = np.asarray(y_test)\n",
    "    y_pred = np.asarray(y_pred)\n",
    "    if mode == \"0-1\":\n",
    "        acc = (y_pred == y_test).all(axis=1).mean()\n",
    "    elif mode == \"soft 0-1\":\n",
    "        # a row fails only where a predicted 1 meets a ground-truth value other than 1\n",
    "        acc = ((y_pred != 1) | (y_test == 1)).all(axis=1).mean()\n",
    "    else:\n",
    "        raise ValueError(f\"unknown mode: {mode!r}\")\n",
    "    acc = float(acc)\n",
    "    if verbose:\n",
    "        print(f\"Eval {mode} accuracy: {acc}\")\n",
    "    return acc\n",
    "\n",
    "\n",
    "def eval_wga_err(test_df, y_pred, df_full, identifier, rng=None):\n",
    "    \"\"\"Mean wga_te_err obtained by picking one predicted winner per test configuration.\n",
    "\n",
    "    For each configuration one algorithm is drawn uniformly from its predicted winners\n",
    "    (from all algorithms if none is predicted) and its ``wga_te_err`` is looked up in\n",
    "    ``df_full``.\n",
    "\n",
    "    Args:\n",
    "        test_df: deduplicated rows of the test configurations (not modified).\n",
    "        y_pred: multi-hot predictions aligned with the rows of test_df.\n",
    "        df_full: per-(configuration, algorithm) results with ``algorithm``/``wga_te_err``.\n",
    "        identifier: columns identifying a configuration (merge keys).\n",
    "        rng: object with a ``choice`` method (e.g. ``np.random.default_rng(0)``) for\n",
    "            reproducible draws; defaults to the global ``np.random`` state.\n",
    "\n",
    "    Returns:\n",
    "        The mean error over configurations (also printed).\n",
    "    \"\"\"\n",
    "    if rng is None:\n",
    "        rng = np.random\n",
    "    # check if y_pred is already a list\n",
    "    if not isinstance(y_pred, list):\n",
    "        y_pred = y_pred.tolist()\n",
    "    # work on a copy: callers pass e.g. coco_df directly, which used to be modified in place\n",
    "    test_df = test_df.copy()\n",
    "    test_df['pred'] = y_pred\n",
    "    # transform each multi-hot vector into the array of algorithm names it selects\n",
    "    alg_names = np.array(algorithms)\n",
    "    test_df['winners_pred'] = [alg_names[np.where(np.array(y_) == 1)[0]] for y_ in y_pred]\n",
    "    # overlapping columns get suffixes: *_x come from df_full, *_y from test_df\n",
    "    fine_test_df = df_full.merge(test_df, on=identifier)\n",
    "\n",
    "    def get_pred_wga_te_err(x):\n",
    "        pred_winners = x['winners_pred'].iloc[0]\n",
    "        # randomly select one winner; fall back to any algorithm if none was predicted\n",
    "        pool = algorithms if len(pred_winners) == 0 else pred_winners\n",
    "        pred_winner = rng.choice(pool)\n",
    "        return x[x['algorithm_x'] == pred_winner]['wga_te_err_x'].iloc[0]\n",
    "\n",
    "    # list (not tuple) column selection: tuples raise ValueError in pandas >= 2.0\n",
    "    pred_err = fine_test_df.groupby(identifier)[['wga_te_err_x', 'algorithm_x', 'winners_pred']].apply(get_pred_wga_te_err)\n",
    "    print(f\"Eval wg err: {pred_err.mean()}\")\n",
    "    return pred_err.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### End-to-end (E2E): predict the multi-hot set of winning algorithms directly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Define the MLP model using nn.Sequential\n",
    "class MLPTorch(nn.Module):\n",
    "    \"\"\"Fully connected ReLU network mapping input_size features to output_size logits.\n",
    "\n",
    "    Layers: Linear -> ReLU for every entry of hidden_layer_sizes, then a final Linear\n",
    "    with no activation. The stack is stored as ``self.network`` (an nn.Sequential).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_layer_sizes, output_size):\n",
    "        super().__init__()\n",
    "        widths = [input_size, *hidden_layer_sizes]\n",
    "        stack = []\n",
    "        for fan_in, fan_out in zip(widths[:-1], widths[1:]):\n",
    "            stack.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])\n",
    "        stack.append(nn.Linear(widths[-1], output_size))\n",
    "        self.network = nn.Sequential(*stack)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return raw (pre-sigmoid) scores for the batch x.\"\"\"\n",
    "        return self.network(x)\n",
    "\n",
    "# Define the linear model\n",
    "class LinearTorch(nn.Module):\n",
    "    \"\"\"Single affine layer (``self.linear``); the linear alternative to MLPTorch in run_e2e.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, output_size):\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(input_size, output_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        logits = self.linear(x)\n",
    "        return logits\n",
    "\n",
    "# Define the KNN model using the sklearn implementation\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "class KNN:\n",
    "    \"\"\"Minimal fit/predict wrapper around sklearn's KNeighborsClassifier.\"\"\"\n",
    "\n",
    "    def __init__(self, n_neighbors):\n",
    "        self.model = KNeighborsClassifier(n_neighbors=n_neighbors)\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"Fit the wrapped classifier (returns None).\"\"\"\n",
    "        self.model.fit(X, y)\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Predicted labels for X.\"\"\"\n",
    "        predictions = self.model.predict(X)\n",
    "        return predictions\n",
    "\n",
    "# Training function\n",
    "def train_model(model, dataloader, criterion, optimizer, num_epochs, patience, tol, verbose=True):\n",
    "    \"\"\"Run a weighted training loop with loss-plateau early stopping.\n",
    "\n",
    "    dataloader must yield (inputs, labels, weights) batches and criterion is called as\n",
    "    criterion(outputs, labels, weights). Training stops after num_epochs, or once the\n",
    "    epoch loss has failed to improve by more than tol for patience consecutive epochs.\n",
    "\n",
    "    Returns:\n",
    "        The same model object, trained in place.\n",
    "    \"\"\"\n",
    "    best_loss = float('inf')\n",
    "    stale_epochs = 0\n",
    "    n_samples = len(dataloader.dataset)\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        loss_sum = 0.0\n",
    "        for inputs, labels, weights in dataloader:\n",
    "            optimizer.zero_grad()\n",
    "            batch_loss = criterion(model(inputs), labels, weights)\n",
    "            batch_loss.backward()\n",
    "            optimizer.step()\n",
    "            # assumes criterion averages over the batch; rescale to a per-sample sum\n",
    "            loss_sum += batch_loss.item() * inputs.size(0)\n",
    "\n",
    "        epoch_loss = loss_sum / n_samples\n",
    "        if verbose:\n",
    "            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')\n",
    "\n",
    "        improved = epoch_loss < best_loss - tol\n",
    "        if improved:\n",
    "            best_loss = epoch_loss\n",
    "        stale_epochs = 0 if improved else stale_epochs + 1\n",
    "        if stale_epochs >= patience:\n",
    "            print(\"Early stopping triggered\")\n",
    "            break\n",
    "\n",
    "    return model\n",
    "\n",
    "class WeightedBCELoss(nn.Module):\n",
    "    \"\"\"Binary cross-entropy with logits where every sample's loss is scaled by its weight.\n",
    "\n",
    "    forward(outputs, targets, weights): outputs and targets are expected to share a\n",
    "    (batch, n_labels) shape; weights is a length-batch vector that is unsqueezed to\n",
    "    broadcast over the label dimension. Returns the mean over all elements of the\n",
    "    per-element BCE times the weight of its row.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(WeightedBCELoss, self).__init__()\n",
    "        # reduction='none' keeps per-element losses. With the default 'mean' the loss was\n",
    "        # already a scalar, so the weights only rescaled the whole batch by their mean.\n",
    "        self.bce = nn.BCEWithLogitsLoss(reduction='none')\n",
    "\n",
    "    def forward(self, outputs, targets, weights):\n",
    "        loss = self.bce(outputs, targets)\n",
    "        weighted_loss = loss * weights.unsqueeze(1)\n",
    "        return weighted_loss.mean()\n",
    "\n",
    "def sigmoid(x, k):\n",
    "    \"\"\"Logistic function with steepness k, 1 / (1 + exp(-k * x)); scalars or arrays.\"\"\"\n",
    "    denominator = 1 + np.exp(-k * x)\n",
    "    return 1 / denominator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 464,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_e2e(X_train, y_train, conf_train=None, verbose=True, naive_feat=False, seed=None):\n",
    "    \"\"\"Train an MLP that predicts the multi-hot set of winning algorithms end to end.\n",
    "\n",
    "    Args:\n",
    "        X_train: (n_samples, n_features) meta-feature array.\n",
    "        y_train: (n_samples, len(algorithms)) multi-hot targets.\n",
    "        conf_train: optional per-sample loss weights; defaults to all ones.\n",
    "        verbose: print the per-epoch training loss.\n",
    "        naive_feat: zero every feature except column 0 (naive baseline).\n",
    "        seed: if given, passed to torch.manual_seed so weight initialisation and batch\n",
    "            shuffling are reproducible (None keeps the previous, unseeded behaviour).\n",
    "\n",
    "    Returns:\n",
    "        The trained MLPTorch model.\n",
    "    \"\"\"\n",
    "    hidden_layer_sizes = (100, 100, 50, 50)\n",
    "    num_epochs = 500\n",
    "    # NOTE(review): patience > num_epochs, so early stopping can never trigger\n",
    "    patience = 2000\n",
    "    tol = 1e-4\n",
    "    alpha = 0.0001  # L2 penalty, passed to Adam as weight_decay\n",
    "    batch_size = 64\n",
    "\n",
    "    if seed is not None:\n",
    "        torch.manual_seed(seed)\n",
    "    if conf_train is None:\n",
    "        conf_train = np.ones(y_train.shape[0])\n",
    "\n",
    "    input_size = X_train.shape[1]\n",
    "    output_size = len(algorithms)\n",
    "\n",
    "    # torch.tensor copies its input, so the masking below never touches the caller's array\n",
    "    X_train = torch.tensor(X_train).float()\n",
    "    y_train = torch.tensor(y_train).float()\n",
    "    conf_train = torch.tensor(conf_train).float()\n",
    "    if naive_feat:\n",
    "        X_train[:, 1:] = 0\n",
    "\n",
    "    tr_dataset = TensorDataset(X_train, y_train, conf_train)\n",
    "    tr_dataloader = DataLoader(tr_dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    model = MLPTorch(input_size, hidden_layer_sizes, output_size)\n",
    "    criterion = WeightedBCELoss()\n",
    "    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=alpha)\n",
    "    return train_model(model, tr_dataloader, criterion, optimizer, num_epochs, patience, tol, verbose=verbose)\n",
    "\n",
    "def predict_e2e(model, X, naive_feat=False):\n",
    "    \"\"\"Predict multi-hot winner sets with a trained E2E model.\n",
    "\n",
    "    Args:\n",
    "        model: torch module returning one logit per algorithm.\n",
    "        X: feature matrix (numpy array or tensor); it is never modified.\n",
    "        naive_feat: zero every feature except column 0 (naive baseline).\n",
    "\n",
    "    Returns:\n",
    "        int numpy array with 1 where sigmoid(logit) > 0.5, else 0.\n",
    "    \"\"\"\n",
    "    # always work on a private float copy: the naive-feature masking below used to\n",
    "    # write into the caller's tensor when X was already a tensor\n",
    "    X = X.detach().clone().float() if torch.is_tensor(X) else torch.tensor(X).float()\n",
    "    if naive_feat:\n",
    "        X[:, 1:] = 0\n",
    "\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        probs = torch.sigmoid(model(X))\n",
    "    return (probs > 0.5).int().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "seed = 0\n",
    "\n",
    "# --- run configuration: CLIP backbone ---\n",
    "model = \"clip\"\n",
    "algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "celeba_path = \"//exps/div_explore/celeba_v2/\"\n",
    "coco_path = \"//exps/div_explore/coco_v2/\"\n",
    "celeba_identifier = ['n', 'sc', 'ci', 'ai', 'y_task', 'a_task']\n",
    "coco_identifier = ['n', 'sc', 'ci', 'ai']\n",
    "\n",
    "celeba_df, celeba_df_full = load_data(\"celeba\", model, celeba_path, celeba_identifier)\n",
    "coco_df, coco_df_full = load_data(\"coco\", model, coco_path, coco_identifier)\n",
    "\n",
    "# model inputs; the last entry is the multi-hot target\n",
    "# (alternative set: ['n', 'sc_', 'ci_', 'ai_', 'c_intra', 'a_intra_i', 'multi_hot'])\n",
    "celeba_feats = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'multi_hot']\n",
    "coco_feats = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'multi_hot']\n",
    "\n",
    "(\n",
    "    X_train, X_test, y_train, y_test,\n",
    "    X_train_full, y_train_full, X_test_full, y_test_full,\n",
    "    tr_idx, te_idx,\n",
    "    X_test_coco, y_test_coco,\n",
    "    X_test_coco_full, y_test_coco_full,\n",
    "    X_train_alg, X_test_alg,\n",
    "    tr_idx_full, te_idx_full\n",
    ") = prepare_data(celeba_df, celeba_df_full, coco_df, coco_df_full, celeba_feats, coco_feats, seed)\n",
    "\n",
    "# sanity check\n",
    "print(model)\n",
    "dataset_pairs = [(celeba_df, celeba_df_full), (coco_df, coco_df_full)]\n",
    "for dedup_df, per_alg_df in dataset_pairs:\n",
    "    print(dedup_df.shape, per_alg_df.shape)\n",
    "for dedup_df, per_alg_df in dataset_pairs:\n",
    "    print(dedup_df.columns)\n",
    "for dedup_df, per_alg_df in dataset_pairs:\n",
    "    print(dedup_df['wga_te_err'].mean(), per_alg_df['wga_te_err'].mean())\n",
    "print(celeba_df.iloc[te_idx]['wga_te_err'].mean())\n",
    "print()\n",
    "# mean wga_te_err of each algorithm on the CelebA test split, then on CoCo\n",
    "report_order = ['ERM', 'GroupDRO', 'oversample', 'undersample', 'remax-margin']\n",
    "celeba_df_full_test = celeba_df_full.iloc[te_idx_full]\n",
    "for alg_name in report_order:\n",
    "    print(celeba_df_full_test[celeba_df_full_test['algorithm'] == alg_name]['wga_te_err'].mean())\n",
    "print()\n",
    "for alg_name in report_order:\n",
    "    print(coco_df_full[coco_df_full['algorithm'] == alg_name]['wga_te_err'].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fit the E2E multi-label model on the CelebA training split\n",
    "trained_model = run_e2e(X_train, y_train, conf_train=None, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CelebA held-out split: exact and soft accuracy, then the error of the selected algorithm\n",
    "y_pred = predict_e2e(trained_model, X_test)\n",
    "for acc_mode in ('0-1', 'soft 0-1'):\n",
    "    _ = eval_acc(y_test, y_pred, mode=acc_mode, verbose=True)\n",
    "_ = eval_wga_err(celeba_df.iloc[te_idx], y_pred, celeba_df_full, celeba_identifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# transfer check: the CelebA-trained model evaluated on CoCo\n",
    "y_pred_coco = predict_e2e(trained_model, X_test_coco)\n",
    "for acc_mode in ('0-1', 'soft 0-1'):\n",
    "    _ = eval_acc(y_test_coco, y_pred_coco, mode=acc_mode, verbose=True)\n",
    "_ = eval_wga_err(coco_df, y_pred_coco, coco_df_full, coco_identifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# map y_pred back to algorithm names: one (possibly empty) array per test configuration.\n",
    "# Kept as a list: np.array() on rows of different lengths raises in NumPy >= 1.24.\n",
    "alg_names = np.array(algorithms)\n",
    "y_pred_alg = [alg_names[np.where(np.array(y_) == 1)[0]] for y_ in y_pred]\n",
    "# show the counts for each algorithm (test configurations where it is predicted to win)\n",
    "counts = [sum(1 for alg in y_pred_alg if a in alg) for a in algorithms]\n",
    "counts, algorithms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# algorithm names predicted as winners, one entry per test configuration\n",
    "y_pred_alg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# naive baseline: only column 0 (n) carries information; all other features are zeroed\n",
    "trained_model = run_e2e(X_train, y_train, verbose=True, naive_feat=True)\n",
    "# y_pred = predict_e2e(trained_model, X_test, naive_feat=True)\n",
    "# _ = eval_acc(y_test, y_pred, mode='0-1', verbose=True)\n",
    "# _ = eval_acc(y_test, y_pred, mode='soft 0-1', verbose=True)\n",
    "# _ = eval_wga_err(celeba_df.iloc[te_idx], y_pred, celeba_df_full, celeba_identifier)\n",
    "# print()\n",
    "# the model only ever saw zeros in columns 1+, so test features must be masked the same way\n",
    "# (naive_feat=False here fed it feature values it was never trained on)\n",
    "y_pred_coco = predict_e2e(trained_model, X_test_coco, naive_feat=True)\n",
    "_ = eval_acc(y_test_coco, y_pred_coco, mode='0-1', verbose=True)\n",
    "_ = eval_acc(y_test_coco, y_pred_coco, mode='soft 0-1', verbose=True)\n",
    "_ = eval_wga_err(coco_df, y_pred_coco, coco_df_full, coco_identifier)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_regression(X_train, y_train, verbose=True, naive_feat=False):\n",
    "    \"\"\"Fit an MLP regressor predicting wga_te_err for each (configuration, algorithm) row.\n",
    "\n",
    "    Args:\n",
    "        X_train: full-view feature rows (meta-features followed by the algorithm index).\n",
    "        y_train: wga_te_err target of each row.\n",
    "        verbose: forwarded to MLPRegressor.\n",
    "        naive_feat: zero every feature except column 0, on a copy of X_train (the\n",
    "            caller's array used to be overwritten in place).\n",
    "\n",
    "    Returns:\n",
    "        The fitted sklearn MLPRegressor.\n",
    "    \"\"\"\n",
    "    from sklearn.neural_network import MLPRegressor\n",
    "\n",
    "    if naive_feat:\n",
    "        X_train = np.array(X_train, dtype=float)  # private copy\n",
    "        # NOTE(review): in prepare_data's full view the last column is the algorithm\n",
    "        # index; zeroing it too makes every algorithm look alike -- confirm intent\n",
    "        X_train[:, 1:] = 0\n",
    "\n",
    "    trained_model = MLPRegressor(random_state=0, max_iter=10000, verbose=verbose, tol=1e-3, n_iter_no_change=100, alpha=0.1, hidden_layer_sizes=(100,100,50,50)).fit(X_train, y_train)\n",
    "\n",
    "    return trained_model\n",
    "\n",
    "def predict_regression(model, X, naive_feat=False, n_algorithms=5, filter_thre=0.05):\n",
    "    \"\"\"Convert per-(configuration, algorithm) error predictions into multi-hot winner sets.\n",
    "\n",
    "    Args:\n",
    "        model: fitted regressor with a ``predict`` method.\n",
    "        X: feature rows grouped so that every n_algorithms consecutive rows belong to one\n",
    "            configuration. Never modified.\n",
    "        naive_feat: zero every feature except column 0 before predicting, on a copy of X\n",
    "            (it previously zeroed the unrelated global X_train instead).\n",
    "        n_algorithms: rows per configuration (previously hard-coded to 5).\n",
    "        filter_thre: an algorithm is a predicted winner if its predicted error is within\n",
    "            this margin of the lowest predicted error of its configuration.\n",
    "\n",
    "    Returns:\n",
    "        List with one 0/1 int array of length n_algorithms per configuration.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the number of predictions is not a multiple of n_algorithms.\n",
    "    \"\"\"\n",
    "    if naive_feat:\n",
    "        X = np.array(X, dtype=float)\n",
    "        # NOTE(review): in prepare_data's full view the last column is the algorithm\n",
    "        # index; zeroing it too makes every algorithm look alike -- confirm intent\n",
    "        X[:, 1:] = 0\n",
    "\n",
    "    y_pred = np.asarray(model.predict(X))\n",
    "    if len(y_pred) % n_algorithms != 0:\n",
    "        raise ValueError(f\"{len(y_pred)} predictions is not a multiple of n_algorithms={n_algorithms}\")\n",
    "    per_config = y_pred.reshape(-1, n_algorithms)\n",
    "    best = per_config.min(axis=1, keepdims=True)\n",
    "    return list((per_config <= best + filter_thre).astype(int))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "seed = 0\n",
    "\n",
    "# --- run configuration: ResNet backbone ---\n",
    "model = \"resnet\"\n",
    "algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "celeba_path = \"//exps/div_explore/celeba_v2/\"\n",
    "coco_path = \"//exps/div_explore/coco_v2/\"\n",
    "celeba_identifier = ['n', 'sc', 'ci', 'ai', 'y_task', 'a_task']\n",
    "coco_identifier = ['n', 'sc', 'ci', 'ai']\n",
    "\n",
    "celeba_df, celeba_df_full = load_data(\"celeba\", model, celeba_path, celeba_identifier)\n",
    "coco_df, coco_df_full = load_data(\"coco\", model, coco_path, coco_identifier)\n",
    "\n",
    "# model inputs; the last entry is the multi-hot target\n",
    "# (alternative set: ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'multi_hot'])\n",
    "celeba_feats = ['n', 'sc_', 'ci_', 'ai_', 'c_intra', 'a_intra_i', 'multi_hot']\n",
    "coco_feats = ['n', 'sc_', 'ci_', 'ai_', 'c_intra', 'a_intra_i', 'multi_hot']\n",
    "\n",
    "# sanity check\n",
    "print(model)\n",
    "dataset_pairs = [(celeba_df, celeba_df_full), (coco_df, coco_df_full)]\n",
    "for dedup_df, per_alg_df in dataset_pairs:\n",
    "    print(dedup_df.shape, per_alg_df.shape)\n",
    "for dedup_df, per_alg_df in dataset_pairs:\n",
    "    print(dedup_df.columns)\n",
    "for dedup_df, per_alg_df in dataset_pairs:\n",
    "    print(dedup_df['wga_te_err'].mean(), per_alg_df['wga_te_err'].mean())\n",
    "\n",
    "(\n",
    "    X_train, X_test, y_train, y_test,\n",
    "    X_train_full, y_train_full, X_test_full, y_test_full,\n",
    "    tr_idx, te_idx,\n",
    "    X_test_coco, y_test_coco,\n",
    "    X_test_coco_full, y_test_coco_full,\n",
    "    X_train_alg, X_test_alg,\n",
    "    tr_idx_full, te_idx_full\n",
    ") = prepare_data(celeba_df, celeba_df_full, coco_df, coco_df_full, celeba_feats, coco_feats, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fit the per-(configuration, algorithm) error regressor on the CelebA training rows\n",
    "trained_model = run_regression(\n",
    "    X_train_full, y_train_full, verbose=True, naive_feat=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# regression model on the CelebA test split: winners are within 0.05 of the best predicted error\n",
    "y_pred = predict_regression(trained_model, X_test_full, naive_feat=False)\n",
    "for acc_mode in ('0-1', 'soft 0-1'):\n",
    "    _ = eval_acc(y_test, y_pred, mode=acc_mode, verbose=True)\n",
    "_ = eval_wga_err(celeba_df.iloc[te_idx], y_pred, celeba_df_full, celeba_identifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# transfer check: the CelebA-trained regressor evaluated on CoCo\n",
    "y_pred_coco = predict_regression(trained_model, X_test_coco_full, naive_feat=False)\n",
    "for acc_mode in ('0-1', 'soft 0-1'):\n",
    "    _ = eval_acc(y_test_coco, y_pred_coco, mode=acc_mode, verbose=True)\n",
    "_ = eval_wga_err(coco_df, y_pred_coco, coco_df_full, coco_identifier)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ensemble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_binary_classifier(X_train, y_train, X_train_alg, method1, method2, filter_thre=0.05, verbose=True, naive_feat=False):\n",
    "    \"\"\"Fit a 3-class \"who wins\" classifier for one pair of algorithms.\n",
    "\n",
    "    Label per configuration: 1 if method1's error is lower by more than filter_thre,\n",
    "    0 if method2's is, 2 (tie) otherwise. Features are the rows of method1 with the\n",
    "    trailing algorithm-index column of X_train dropped.\n",
    "\n",
    "    Args:\n",
    "        X_train: per-(configuration, algorithm) feature rows; last column = algorithm index.\n",
    "        y_train: wga_te_err of each row of X_train.\n",
    "        X_train_alg: unscaled algorithm index (position in ``algorithms``) of each row.\n",
    "        method1, method2: names of the two algorithms to compare.\n",
    "        filter_thre: error margin inside which the pair counts as a tie.\n",
    "        verbose: print the training-set size and accuracy.\n",
    "        naive_feat: zero every feature except column 0 (naive baseline).\n",
    "\n",
    "    Returns:\n",
    "        The fitted sklearn MLPClassifier.\n",
    "    \"\"\"\n",
    "    from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "    rows_m1 = X_train_alg == algorithms.index(method1)\n",
    "    rows_m2 = X_train_alg == algorithms.index(method2)\n",
    "\n",
    "    # assumes both methods' rows list the configurations in the same order\n",
    "    err_gap = y_train[rows_m1] - y_train[rows_m2]\n",
    "    labels = np.select([err_gap < -filter_thre, err_gap > filter_thre], [1, 0], default=2)\n",
    "\n",
    "    # boolean-mask indexing returns a copy, so the masking below leaves X_train intact\n",
    "    features = X_train[:, :-1][rows_m1]\n",
    "    if naive_feat:\n",
    "        features[:, 1:] = 0\n",
    "\n",
    "    clf = MLPClassifier(random_state=1, max_iter=100000, verbose=False, tol=1e-3,\n",
    "                        n_iter_no_change=2000, alpha=0.01, hidden_layer_sizes=(100, 50))\n",
    "    clf.fit(features, labels)\n",
    "    train_acc = clf.score(features, labels)\n",
    "\n",
    "    if verbose:\n",
    "        print(method1, method2, \"train size: \", len(features), \"train acc: \", train_acc)\n",
    "\n",
    "    return clf\n",
    "\n",
    "def run_ensemble(X_train, y_train, X_test_full, y_test_full, X_train_alg, naive_feat=False):\n",
    "    \"\"\"Train one pairwise classifier for every unordered pair of algorithms.\n",
    "\n",
    "    X_test_full and y_test_full are accepted for call-site compatibility but not used.\n",
    "\n",
    "    Returns:\n",
    "        dict (method1, method2) -> fitted classifier, in itertools.combinations order.\n",
    "    \"\"\"\n",
    "    import itertools\n",
    "\n",
    "    pairs = np.array(list(itertools.combinations(algorithms, 2)))\n",
    "    return {\n",
    "        (m1, m2): train_binary_classifier(X_train, y_train, X_train_alg, m1, m2,\n",
    "                                          filter_thre=0.05, naive_feat=naive_feat)\n",
    "        for m1, m2 in pairs\n",
    "    }\n",
    "\n",
    "def predict_ensemble(model, X, naive_feat=False):\n",
    "    \"\"\"Predict multi-hot winner sets by Copeland voting over the pairwise classifiers.\n",
    "\n",
    "    Args:\n",
    "        model: dict (method1, method2) -> classifier, as built by run_ensemble. A predicted\n",
    "            label 1 means method1 wins, 0 means method2 wins, anything else is a tie.\n",
    "        X: (n_samples, n_features) meta-features, or a single 1-D sample; not modified.\n",
    "        naive_feat: zero every feature except column 0 (naive baseline).\n",
    "\n",
    "    Returns:\n",
    "        (n_samples, len(algorithms)) float array with 1 for every Copeland winner.\n",
    "    \"\"\"\n",
    "    def copeland_method(candidate_names, pairwise_results):\n",
    "        # +1 for each pairwise win, -1 for each loss; every top scorer is a winner\n",
    "        scores = {name: 0 for name in candidate_names}\n",
    "        for A, B, winner in pairwise_results:\n",
    "            if winner == A:\n",
    "                scores[A] += 1\n",
    "                scores[B] -= 1\n",
    "            elif winner == B:\n",
    "                scores[B] += 1\n",
    "                scores[A] -= 1\n",
    "        best = max(scores.values())\n",
    "        winners = [name for name in candidate_names if scores[name] == best]\n",
    "        return winners, scores\n",
    "\n",
    "    # private copy: the naive-feature masking used to write into the caller's array\n",
    "    X = np.array(X, dtype=float)\n",
    "    if X.ndim == 1:\n",
    "        X = X.reshape(1, -1)\n",
    "    if naive_feat:\n",
    "        X[:, 1:] = 0\n",
    "\n",
    "    y_pred = np.zeros((X.shape[0], len(algorithms)))\n",
    "    if X.shape[0] == 0:\n",
    "        return y_pred\n",
    "\n",
    "    # one batched predict per classifier instead of one call per (classifier, row)\n",
    "    pair_labels = {pair: clf.predict(X) for pair, clf in model.items()}\n",
    "\n",
    "    for i in range(X.shape[0]):\n",
    "        pair_res = []\n",
    "        for (m1, m2), labels in pair_labels.items():\n",
    "            if labels[i] == 1:\n",
    "                pair_res.append((m1, m2, m1))\n",
    "            elif labels[i] == 0:\n",
    "                pair_res.append((m1, m2, m2))\n",
    "            else:\n",
    "                # tie: one win each, so the two scores cancel out\n",
    "                pair_res.append((m1, m2, m1))\n",
    "                pair_res.append((m1, m2, m2))\n",
    "        winners, _ = copeland_method(algorithms, pair_res)\n",
    "        y_pred[i, [algorithms.index(m) for m in winners]] = 1\n",
    "    return y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "seed = 0\n",
    "\n",
    "# --- run configuration: CLIP backbone ---\n",
    "model = \"clip\"\n",
    "algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "celeba_path = \"//exps/div_explore/celeba_v2/\"\n",
    "coco_path = \"//exps/div_explore/coco_v2/\"\n",
    "celeba_identifier = ['n', 'sc', 'ci', 'ai', 'y_task', 'a_task']\n",
    "coco_identifier = ['n', 'sc', 'ci', 'ai']\n",
    "\n",
    "celeba_df, celeba_df_full = load_data(\"celeba\", model, celeba_path, celeba_identifier)\n",
    "coco_df, coco_df_full = load_data(\"coco\", model, coco_path, coco_identifier)\n",
    "\n",
    "# model inputs; the last entry is the multi-hot target\n",
    "# (alternative set: ['n', 'sc_', 'ci_', 'ai_', 'c_intra', 'a_intra_i', 'multi_hot'])\n",
    "celeba_feats = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'multi_hot']\n",
    "coco_feats = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'multi_hot']\n",
    "\n",
    "# sanity check\n",
    "print(model)\n",
    "dataset_pairs = [(celeba_df, celeba_df_full), (coco_df, coco_df_full)]\n",
    "for dedup_df, per_alg_df in dataset_pairs:\n",
    "    print(dedup_df.shape, per_alg_df.shape)\n",
    "for dedup_df, per_alg_df in dataset_pairs:\n",
    "    print(dedup_df.columns)\n",
    "for dedup_df, per_alg_df in dataset_pairs:\n",
    "    print(dedup_df['wga_te_err'].mean(), per_alg_df['wga_te_err'].mean())\n",
    "\n",
    "(\n",
    "    X_train, X_test, y_train, y_test,\n",
    "    X_train_full, y_train_full, X_test_full, y_test_full,\n",
    "    tr_idx, te_idx,\n",
    "    X_test_coco, y_test_coco,\n",
    "    X_test_coco_full, y_test_coco_full,\n",
    "    X_train_alg, X_test_alg,\n",
    "    tr_idx_full, te_idx_full\n",
    ") = prepare_data(celeba_df, celeba_df_full, coco_df, coco_df_full, celeba_feats, coco_feats, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one pairwise classifier per algorithm pair; the test arrays are passed but unused\n",
    "trained_model = run_ensemble(\n",
    "    X_train=X_train_full, y_train=y_train_full,\n",
    "    X_test_full=X_test_full, y_test_full=y_test_full,\n",
    "    X_train_alg=X_train_alg,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copeland-vote ensemble on the CelebA test split\n",
    "y_pred = predict_ensemble(trained_model, X_test, naive_feat=False)\n",
    "for acc_mode in ('0-1', 'soft 0-1'):\n",
    "    _ = eval_acc(y_test, y_pred, mode=acc_mode, verbose=True)\n",
    "_ = eval_wga_err(celeba_df.iloc[te_idx], y_pred, celeba_df_full, celeba_identifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the same trained ensemble on every row of the COCO frame (coco_df).\n",
    "y_pred_coco = predict_ensemble(trained_model, X_test_coco, naive_feat=False)\n",
    "for acc_mode in ('0-1', 'soft 0-1'):\n",
    "    _ = eval_acc(y_test_coco, y_pred_coco, mode=acc_mode, verbose=True)\n",
    "_ = eval_wga_err(coco_df, y_pred_coco, coco_df_full, coco_identifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each multi-hot label row of y_test back to the names of the\n",
    "# algorithms it marks with 1. Kept as a plain list: rows can mark different\n",
    "# numbers of algorithms, and np.array() on such a ragged list raises\n",
    "# ValueError on NumPy >= 1.24.\n",
    "y_test_alg = [np.asarray(algorithms)[np.flatnonzero(np.asarray(y_) == 1)] for y_ in y_test]\n",
    "# Per-algorithm number of rows in which that algorithm is marked\n",
    "# (sized from `algorithms` instead of a hard-coded 5).\n",
    "counts = [sum(a in row for row in y_test_alg) for a in algorithms]\n",
    "counts, algorithms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each predicted multi-hot row of y_pred back to the names of the\n",
    "# algorithms it marks with 1. Kept as a plain list: rows can mark different\n",
    "# numbers of algorithms, and np.array() on such a ragged list raises\n",
    "# ValueError on NumPy >= 1.24.\n",
    "y_pred_alg = [np.asarray(algorithms)[np.flatnonzero(np.asarray(y_) == 1)] for y_ in y_pred]\n",
    "# Per-algorithm number of rows in which that algorithm is selected\n",
    "# (sized from `algorithms` instead of a hard-coded 5).\n",
    "counts = [sum(a in row for row in y_pred_alg) for a in algorithms]\n",
    "counts, algorithms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "\n",
    "def plot_single_bar(numbers, suffix, out_dir=\"./figure\"):\n",
    "    \"\"\"Draw one horizontal 100%-stacked bar of `numbers` and save it as a PDF.\n",
    "\n",
    "    Each entry of `numbers` becomes one segment whose width is its share of\n",
    "    the total. Segment i is drawn with matplotlib cycle colour `C{i}` after\n",
    "    `sns.set_palette(\"Set2\")`, so a given position gets the same colour in\n",
    "    every call.\n",
    "\n",
    "    Args:\n",
    "        numbers: non-negative counts, one per segment (any length >= 1).\n",
    "        suffix: file-name suffix; the figure is written to\n",
    "            `{out_dir}/alg_sel_{suffix}.pdf`.\n",
    "        out_dir: output directory, created if missing (default \"./figure\").\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `numbers` is empty or sums to zero, because the\n",
    "            segment shares are then undefined.\n",
    "    \"\"\"\n",
    "    total = sum(numbers)\n",
    "    if len(numbers) == 0 or total == 0:\n",
    "        raise ValueError(f\"cannot plot shares of {numbers!r}: the total is zero\")\n",
    "    # Shares computed directly (no fixed-size label table), so any number of\n",
    "    # segments works.\n",
    "    ratios = [num / total for num in numbers]\n",
    "\n",
    "    # Only the last style setting takes effect, so one sns.set call suffices.\n",
    "    sns.set(style=\"whitegrid\")\n",
    "    sns.set_palette(\"Set2\")\n",
    "    fig, ax = plt.subplots(figsize=(2, 0.3))\n",
    "\n",
    "    # Stack segments left to right; each starts where the previous one ended.\n",
    "    left = 0.0\n",
    "    for i, ratio in enumerate(ratios):\n",
    "        ax.barh(y=0, width=ratio, left=left, color=f\"C{i}\", edgecolor=\"white\", alpha=0.7)\n",
    "        left += ratio\n",
    "\n",
    "    # Bare bar: fixed [0, 1] extent, no ticks, no axis frame.\n",
    "    ax.set_xlim(0, 1)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.axis(\"off\")\n",
    "\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    fig.savefig(os.path.join(out_dir, f\"alg_sel_{suffix}.pdf\"), dpi=300, bbox_inches=\"tight\", pad_inches=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Segment order of the bars (segment i uses colour C{i}):\n",
    "#   ERM, GroupDRO, remax-margin, undersample, oversample\n",
    "# Recorded as a comment rather than assigned to `algorithms`: the count cells\n",
    "# above index `algorithms` by multi-hot position and the Baselines section\n",
    "# defines it in a different order, so overwriting it here could relabel them.\n",
    "\n",
    "plot_single_bar([20,20,20,20,20], \"random\")\n",
    "# NOTE(review): these counts are listed in the order\n",
    "# ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample'], which differs\n",
    "# from the segment order stated above, so their colours may not line up with\n",
    "# the other bars; confirm the intended order before using this figure.\n",
    "plot_single_bar([263, 707, 388, 870, 709], \"gb\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Segment order of every bar drawn below (segment i uses colour C{i}):\n",
    "#   ERM, GroupDRO, remax-margin, undersample, oversample\n",
    "# Recorded as a comment rather than assigned to `algorithms`: the count cells\n",
    "# above index `algorithms` by multi-hot position and the Baselines section\n",
    "# defines it in a different order, so overwriting it here could relabel them.\n",
    "bar_counts = {\n",
    "    \"ours\": [37, 85, 96, 194, 82],\n",
    "    \"gt\": [62, 128, 132, 228, 130],\n",
    "    \"erm\": [100, 0, 0, 0, 0],\n",
    "    \"groupdro\": [0, 100, 0, 0, 0],\n",
    "    \"remax-margin\": [0, 0, 100, 0, 0],\n",
    "    \"undersample\": [0, 0, 0, 100, 0],\n",
    "    \"oversample\": [0, 0, 0, 0, 100],\n",
    "    \"ours_res\": [38, 137, 57, 195, 122],\n",
    "    \"gt_res\": [53, 163, 90, 225, 167],\n",
    "}\n",
    "# One PDF per entry, written as alg_sel_<suffix>.pdf.\n",
    "for bar_suffix, bar_numbers in bar_counts.items():\n",
    "    plot_single_bar(bar_numbers, bar_suffix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Baselines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "# Silences every warning for the rest of the kernel session, not just this cell.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "seed = 0\n",
    "# predict_baseline draws from the global np.random state; seeding it here makes\n",
    "# the random baselines below reproducible under Restart & Run All.\n",
    "np.random.seed(seed)\n",
    "\n",
    "model = \"resnet\"\n",
    "# Used to decode multi-hot rows into algorithm names; presumably must match\n",
    "# the multi-hot column order produced by load_data (TODO confirm).\n",
    "algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "celeba_path = \"//exps/div_explore/celeba_v2/\"\n",
    "coco_path = \"//exps/div_explore/coco_v2/\"\n",
    "celeba_identifier = ['n', 'sc', 'ci', 'ai', 'y_task', 'a_task']\n",
    "coco_identifier = ['n', 'sc', 'ci', 'ai']\n",
    "\n",
    "celeba_df, celeba_df_full = load_data(\"celeba\", model, celeba_path, celeba_identifier)\n",
    "coco_df, coco_df_full = load_data(\"coco\", model, coco_path, coco_identifier)\n",
    "\n",
    "# other variables\n",
    "celeba_feats = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'multi_hot']\n",
    "# celeba_feats = ['n', 'sc_', 'ci_', 'ai_', 'c_intra', 'a_intra_i', 'multi_hot']\n",
    "coco_feats = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'multi_hot']\n",
    "\n",
    "# sanity check\n",
    "print(model)\n",
    "print(celeba_df.shape, celeba_df_full.shape)\n",
    "print(coco_df.shape, coco_df_full.shape)\n",
    "print(celeba_df.columns)\n",
    "print(coco_df.columns)\n",
    "print(celeba_df['wga_te_err'].mean(), celeba_df_full['wga_te_err'].mean())\n",
    "print(coco_df['wga_te_err'].mean(), coco_df_full['wga_te_err'].mean())\n",
    "\n",
    "(\n",
    "    X_train, X_test, y_train, y_test,\n",
    "    X_train_full, y_train_full, X_test_full, y_test_full,\n",
    "    tr_idx, te_idx,\n",
    "    X_test_coco, y_test_coco,\n",
    "    X_test_coco_full, y_test_coco_full,\n",
    "    X_train_alg, X_test_alg,\n",
    "    tr_idx_full, te_idx_full\n",
    ") = prepare_data(celeba_df, celeba_df_full, coco_df, coco_df_full, celeba_feats, coco_feats, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "mode = 'global_best'\n",
    "# Column-wise sum of the multi-hot labels over the CelebA training rows\n",
    "# (tr_idx): entry j totals position j (a win count if labels are 0/1).\n",
    "celeba_train = celeba_df.iloc[tr_idx]\n",
    "global_rank = np.stack(celeba_train['multi_hot'].to_list()).sum(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the per-position totals of the training multi-hot labels.\n",
    "global_rank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_baseline(mode, X, n_algorithms=5, rank=None, rng=None):\n",
    "    \"\"\"Build multi-hot algorithm selections from a non-learned baseline.\n",
    "\n",
    "    Only ``len(X)`` is used; feature values are ignored. For each row a winner\n",
    "    count k is drawn uniformly from 1..n_algorithms, then\n",
    "      * \"random\":      k distinct algorithm indices are drawn uniformly;\n",
    "      * \"global_best\": the k indices with the highest `rank` score are taken.\n",
    "\n",
    "    Args:\n",
    "        mode: \"random\" or \"global_best\".\n",
    "        X: sized collection of rows to predict for.\n",
    "        n_algorithms: number of candidate algorithms (multi-hot width).\n",
    "        rank: per-algorithm score for \"global_best\"; defaults to the\n",
    "            notebook-level `global_rank`.\n",
    "        rng: source of randomness exposing numpy's `choice` (e.g. a\n",
    "            np.random.Generator); defaults to the global `np.random` state,\n",
    "            so seeding with np.random.seed still controls the draws.\n",
    "\n",
    "    Returns:\n",
    "        Integer array of shape (len(X), n_algorithms); 1 marks a selection.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for an unknown `mode` (checked even when X is empty).\n",
    "    \"\"\"\n",
    "    if mode not in (\"random\", \"global_best\"):\n",
    "        raise ValueError(f\"unknown mode {mode}\")\n",
    "    if rng is None:\n",
    "        rng = np.random\n",
    "    order = None\n",
    "    if mode == \"global_best\":\n",
    "        if rank is None:\n",
    "            rank = global_rank\n",
    "        # Algorithm indices from highest to lowest score; loop-invariant.\n",
    "        order = np.argsort(rank)[::-1]\n",
    "\n",
    "    y_preds = np.zeros((len(X), n_algorithms), dtype=int)\n",
    "    for row in range(len(X)):\n",
    "        # Draw order matters for reproducibility: winner count first, then\n",
    "        # (random mode only) the winners themselves.\n",
    "        num_winners = rng.choice(n_algorithms, 1, replace=False)[0] + 1\n",
    "        if mode == \"random\":\n",
    "            winners = rng.choice(n_algorithms, num_winners, replace=False)\n",
    "        else:\n",
    "            winners = order[:num_winners]\n",
    "        y_preds[row, winners] = 1\n",
    "    return y_preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the baseline chosen by `mode` on the CelebA rows selected by te_idx.\n",
    "y_pred = predict_baseline(mode, X_test)\n",
    "for acc_mode in ('0-1', 'soft 0-1'):\n",
    "    _ = eval_acc(y_test, y_pred, mode=acc_mode, verbose=True)\n",
    "_ = eval_wga_err(celeba_df.iloc[te_idx], y_pred, celeba_df_full, celeba_identifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the baseline chosen by `mode` on every row of the COCO frame (coco_df).\n",
    "y_pred_coco = predict_baseline(mode, X_test_coco)\n",
    "for acc_mode in ('0-1', 'soft 0-1'):\n",
    "    _ = eval_acc(y_test_coco, y_pred_coco, mode=acc_mode, verbose=True)\n",
    "_ = eval_wga_err(coco_df, y_pred_coco, coco_df_full, coco_identifier)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "div_backup",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
