{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import copy\n",
    "import torch\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(data, model, path, identifier):\n",
    "    \"\"\"Load per-seed results, average them over seeds and label the winners.\n",
    "\n",
    "    Every file in ``{path}/{model}_lp`` is read with ``pd.read_csv``. The\n",
    "    ``wga_te_err`` column is averaged per (identifier, algorithm) and each row\n",
    "    then receives:\n",
    "\n",
    "    * ``rank``: rank of the algorithm inside its configuration (ties broken by\n",
    "      order of appearance),\n",
    "    * ``winners``: '|'-joined algorithms within 0.05 of the best error,\n",
    "    * ``multi_hot``: 0/1 encoding of the winners over the global ``algorithms``,\n",
    "    * ``conf``: smallest loser error minus largest winner error (NaN when every\n",
    "      algorithm is a winner),\n",
    "    * the columns of ``{path}/{model}_embeddings_tr/proc_emb.csv``.\n",
    "\n",
    "    Args:\n",
    "        data: dataset name. \"celeba\" additionally parses ``y_task`` / ``a_task``\n",
    "            from the file names; \"coco\" renames ``c_dist`` / ``a_dist`` to\n",
    "            ``c_intra`` / ``a_intra``.\n",
    "        model: model name used to build the directory names.\n",
    "        path: root directory of the experiment results.\n",
    "        identifier: list of columns identifying one dataset configuration.\n",
    "\n",
    "    Returns:\n",
    "        ``(all_res, all_res_ori)``: the rank-1 row of every configuration (with\n",
    "        missing ``conf`` set to 0) and the full per-algorithm table.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no seed can be parsed from a file name.\n",
    "    \"\"\"\n",
    "    import re\n",
    "\n",
    "    data_path = os.path.join(path, f\"{model}_lp\")\n",
    "    all_res = []\n",
    "    # sorted() makes the load order independent of the filesystem\n",
    "    for file in sorted(os.listdir(data_path)):\n",
    "        res = pd.read_csv(os.path.join(data_path, file))\n",
    "        # parse the seed from the file name; take every leading digit so that\n",
    "        # seeds >= 10 are not truncated to their first digit\n",
    "        seed_match = re.match(r'\\d+', file.split('seed')[-1])\n",
    "        if seed_match is None:\n",
    "            raise ValueError(f\"cannot parse seed from file name: {file}\")\n",
    "        res[\"seed\"] = int(seed_match.group())\n",
    "        if data == \"celeba\":\n",
    "            res[\"y_task\"] = int(file.split('y')[-1].split('_')[0])\n",
    "            res[\"a_task\"] = int(file.split('a')[-1].split('.')[0])\n",
    "        all_res.append(res)\n",
    "    all_res = pd.concat(all_res, ignore_index=True)\n",
    "    all_res = all_res.rename(columns={\"method\": \"algorithm\"})\n",
    "\n",
    "    # average over seeds\n",
    "    all_res = all_res[identifier + [\"algorithm\", \"wga_te_err\", \"seed\"]]\n",
    "    all_res = all_res.groupby(identifier + [\"algorithm\"]).mean().reset_index()\n",
    "    all_res = all_res.drop(columns=[\"seed\"])\n",
    "\n",
    "    # rank inside each configuration, used below to keep one row per configuration\n",
    "    all_res['rank'] = all_res.groupby(identifier)['wga_te_err'].rank(\"first\")\n",
    "\n",
    "    def winner_mask(x, filter_thre):\n",
    "        # rows whose error is within filter_thre of the best error of the group\n",
    "        return x['wga_te_err'] <= x['wga_te_err'].min() + filter_thre\n",
    "\n",
    "    def get_gt_rank(x, filter_thre=0.05):\n",
    "        # '|'-joined names of the winning algorithms of one configuration\n",
    "        return '|'.join(x[winner_mask(x, filter_thre)][\"algorithm\"].to_list())\n",
    "\n",
    "    def get_conf(x, filter_thre=0.05):\n",
    "        # margin between the best loser and the worst winner\n",
    "        is_winner = winner_mask(x, filter_thre)\n",
    "        winner_max = x[is_winner]['wga_te_err'].max()\n",
    "        loser_min = x[~is_winner]['wga_te_err'].min()\n",
    "        return loser_min - winner_max\n",
    "\n",
    "    # compute both per-configuration summaries, then merge them back onto the rows\n",
    "    per_config = all_res.groupby(identifier)[['algorithm', 'wga_te_err']]\n",
    "    winners_series = per_config.apply(get_gt_rank).reset_index(name='winners')\n",
    "    conf_series = per_config.apply(get_conf).reset_index(name='conf')\n",
    "\n",
    "    all_res = all_res.merge(winners_series, on=identifier)\n",
    "    # get the multi-hot encoding of the winners\n",
    "    all_res[\"multi_hot\"] = all_res[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in algorithms])\n",
    "    all_res = all_res.merge(conf_series, on=identifier)\n",
    "\n",
    "    # load extra data (static meta-features)\n",
    "    tr_df = pd.read_csv(os.path.join(path, f\"{model}_embeddings_tr/proc_emb.csv\"))\n",
    "    all_res = all_res.merge(tr_df, on=identifier)\n",
    "    if data == \"coco\":\n",
    "        all_res = all_res.rename(columns={\"c_dist\": \"c_intra\", \"a_dist\": \"a_intra\"})\n",
    "\n",
    "    # retain the original results\n",
    "    all_res_ori = copy.deepcopy(all_res)\n",
    "    # deduplicate; .copy() gives an independent frame so the fill below is a plain\n",
    "    # column assignment instead of a chained in-place fill on a slice\n",
    "    all_res = all_res[all_res[\"rank\"] == 1.0].copy()\n",
    "    all_res['conf'] = all_res['conf'].fillna(0)\n",
    "\n",
    "    return all_res, all_res_ori\n",
    "\n",
    "\n",
    "def prepare_data(celeba_df, celeba_df_full, coco_df, coco_df_full, celeba_feats, coco_feats, seed):\n",
    "    \"\"\"Split CelebA into train/test and build the scaled arrays used by every model.\n",
    "\n",
    "    Two views of the data are prepared:\n",
    "\n",
    "    * one row per configuration (``celeba_df`` / ``coco_df``): features are all\n",
    "      columns of ``*_feats`` but the last, the label is the last column;\n",
    "    * one row per (configuration, algorithm) (``*_df_full``): the same features\n",
    "      plus the index of the algorithm in the global ``algorithms`` list, with\n",
    "      ``wga_te_err`` as the target.\n",
    "\n",
    "    Configuration ``i`` of ``celeba_df`` is assumed to occupy the consecutive rows\n",
    "    ``[k*i, k*(i+1))`` of ``celeba_df_full`` with ``k = len(algorithms)`` (5 in the\n",
    "    original code) -- TODO confirm against the loader.\n",
    "\n",
    "    Args:\n",
    "        celeba_df, coco_df: per-configuration tables.\n",
    "        celeba_df_full, coco_df_full: per-algorithm tables.\n",
    "        celeba_feats, coco_feats: column lists, label column last.\n",
    "        seed: random state of the 80/20 train/test split.\n",
    "\n",
    "    Returns:\n",
    "        The 18-tuple ``(X_train, X_test, y_train, y_test, X_train_full,\n",
    "        y_train_full, X_test_full, y_test_full, tr_idx, te_idx, X_test_coco,\n",
    "        y_test_coco, X_test_coco_full, y_test_coco_full, X_train_alg, X_test_alg,\n",
    "        tr_idx_full, te_idx_full)``.\n",
    "    \"\"\"\n",
    "    from sklearn.preprocessing import StandardScaler\n",
    "    from sklearn.model_selection import train_test_split\n",
    "\n",
    "    # per-configuration view: split, then scale with statistics of the train part\n",
    "    all_X = celeba_df[celeba_feats[:-1]].to_numpy().astype('float')\n",
    "    all_y = np.array(celeba_df[celeba_feats[-1]].tolist())\n",
    "\n",
    "    X_train, X_test, y_train, y_test, tr_idx, te_idx = train_test_split(\n",
    "        all_X, all_y, range(len(all_X)), test_size=0.2, random_state=seed\n",
    "    )\n",
    "\n",
    "    scaler = StandardScaler()\n",
    "    X_train = scaler.fit_transform(X_train)\n",
    "    X_test = scaler.transform(X_test)\n",
    "\n",
    "    X_test_coco = scaler.transform(coco_df[coco_feats[:-1]].to_numpy().astype('float'))\n",
    "    y_test_coco = np.array(coco_df[coco_feats[-1]].tolist())\n",
    "\n",
    "    # process celeba_df_full: expand every configuration index to its block of rows\n",
    "    group_size = len(algorithms)\n",
    "    tr_idx_full = np.concatenate([np.arange(group_size*i, group_size*(i+1)) for i in tr_idx])\n",
    "    te_idx_full = np.concatenate([np.arange(group_size*i, group_size*(i+1)) for i in te_idx])\n",
    "    # explicit copy: the column assignment below must not target a slice of the caller's frame\n",
    "    all_X_full = celeba_df_full[celeba_feats[:-1]+[\"algorithm\"]].copy()\n",
    "    all_X_full['algorithm'] = all_X_full['algorithm'].map(lambda x: algorithms.index(x))\n",
    "    all_X_full = all_X_full.to_numpy().astype('float')\n",
    "    # unscaled algorithm ids, kept so callers can select rows by algorithm\n",
    "    all_X_alg = all_X_full[:, -1]\n",
    "    all_y_full = celeba_df_full['wga_te_err'].to_numpy()\n",
    "    X_train_full, y_train_full = all_X_full[tr_idx_full], all_y_full[tr_idx_full]\n",
    "    X_test_full, y_test_full = all_X_full[te_idx_full], all_y_full[te_idx_full]\n",
    "    scaler_full = StandardScaler()\n",
    "    X_train_full = scaler_full.fit_transform(X_train_full)\n",
    "    X_test_full = scaler_full.transform(X_test_full)\n",
    "    X_train_alg = all_X_alg[tr_idx_full]\n",
    "    X_test_alg = all_X_alg[te_idx_full]\n",
    "\n",
    "    # process coco_df_full with the scaler fitted on the CelebA per-algorithm rows\n",
    "    X_test_coco_full = coco_df_full[coco_feats[:-1]+[\"algorithm\"]].copy()\n",
    "    X_test_coco_full['algorithm'] = X_test_coco_full['algorithm'].map(lambda x: algorithms.index(x)).to_numpy().astype('float')\n",
    "    y_test_coco_full = coco_df_full['wga_te_err'].to_numpy()\n",
    "    # pass an ndarray, matching what the scaler was fitted on\n",
    "    X_test_coco_full = scaler_full.transform(X_test_coco_full.to_numpy().astype('float'))\n",
    "\n",
    "    return X_train, X_test, y_train, y_test, X_train_full, y_train_full, X_test_full, y_test_full, tr_idx, te_idx, X_test_coco, y_test_coco, X_test_coco_full, y_test_coco_full, X_train_alg, X_test_alg, tr_idx_full, te_idx_full\n",
    "\n",
    "\n",
    "def eval_acc(y_test, y_pred, mode, verbose=True):\n",
    "    \"\"\"Score multi-hot predictions against multi-hot ground truth.\n",
    "\n",
    "    Args:\n",
    "        y_test: 2-D array-like of 0/1 ground-truth labels, one row per sample.\n",
    "        y_pred: 2-D array-like (or list of 1-D arrays) of 0/1 predictions.\n",
    "        mode: \"0-1\" counts a row as correct only when it matches exactly;\n",
    "            \"soft 0-1\" counts a row as correct when every predicted 1 is also a 1\n",
    "            in the ground truth (a row without any predicted 1 counts as correct).\n",
    "        verbose: print the accuracy when True.\n",
    "\n",
    "    Returns:\n",
    "        The accuracy as a float (the original computed it but returned nothing).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``mode`` is not one of the two supported values.\n",
    "    \"\"\"\n",
    "    y_test = np.asarray(y_test)\n",
    "    y_pred = np.asarray(y_pred)\n",
    "\n",
    "    if mode == \"0-1\":\n",
    "        acc = (y_pred == y_test).all(axis=1).mean()\n",
    "    elif mode == \"soft 0-1\":\n",
    "        # a row is wrong only if some predicted 1 sits on a ground-truth non-1\n",
    "        wrong = ((y_pred == 1) & (y_test != 1)).any(axis=1)\n",
    "        acc = (~wrong).sum() / y_test.shape[0]\n",
    "    else:\n",
    "        raise ValueError(f\"unknown mode: {mode!r}\")\n",
    "\n",
    "    acc = float(acc)\n",
    "    if verbose:\n",
    "        print(f\"Eval {mode} accuracy: {acc}\")\n",
    "    return acc\n",
    "\n",
    "\n",
    "def eval_wga_err(test_df, y_pred, df_full, identifier):\n",
    "    \"\"\"Mean worst-group test error obtained by following the predicted winners.\n",
    "\n",
    "    For every row of ``test_df`` one algorithm is drawn at random among the\n",
    "    predicted winners (among all of the global ``algorithms`` when nothing was\n",
    "    predicted) and its ``wga_te_err`` is looked up in ``df_full``.\n",
    "\n",
    "    Args:\n",
    "        test_df: per-configuration test rows, aligned positionally with ``y_pred``.\n",
    "        y_pred: multi-hot predictions (array or list), one entry per test row.\n",
    "        df_full: per-algorithm table holding ``algorithm`` and ``wga_te_err``.\n",
    "        identifier: columns identifying one configuration, used as the merge key.\n",
    "\n",
    "    Returns:\n",
    "        The mean error of the selected algorithms over all test configurations.\n",
    "    \"\"\"\n",
    "    # check if y_pred is already a list\n",
    "    if not isinstance(y_pred, list):\n",
    "        y_pred = y_pred.tolist()\n",
    "\n",
    "    # work on a copy so the caller's frame (often an .iloc slice) is left untouched\n",
    "    test_df = test_df.copy()\n",
    "    test_df['pred'] = y_pred\n",
    "    # transform each multi-hot vector into the names of the predicted winners\n",
    "    alg_names = np.array(algorithms)\n",
    "    test_df['winners_pred'] = [alg_names[np.where(np.array(y_) == 1)[0]] for y_ in y_pred]\n",
    "\n",
    "    # columns present in both frames get the suffix _x (df_full) / _y (test_df)\n",
    "    fine_test_df = df_full.merge(test_df, on=identifier)\n",
    "\n",
    "    def get_pred_wga_te_err(x):\n",
    "        pred_winners = x['winners_pred'].iloc[0]\n",
    "        if len(pred_winners) == 0:\n",
    "            # randomly select one\n",
    "            pred_winner = np.random.choice(algorithms)\n",
    "        else:\n",
    "            pred_winner = np.random.choice(pred_winners)\n",
    "        return x[x['algorithm_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "\n",
    "    # select the columns with a list: tuple-style selection on a groupby raises in pandas >= 2.0\n",
    "    per_config = fine_test_df.groupby(identifier)[['wga_te_err_x', 'algorithm_x', 'winners_pred']]\n",
    "    pred_err = per_config.apply(get_pred_wga_te_err)\n",
    "    print(f\"Eval wg err: {pred_err.mean()}\")\n",
    "    return pred_err.mean()\n",
    "\n",
    "\n",
    "def load_ft_data(y_a, n, root='//exps/div_explore/celeba_v2'):\n",
    "    \"\"\"Load the fine-tuning results of one CelebA task and label the winners.\n",
    "\n",
    "    Walks ``{root}/celeba_y{y}_a{a}_ds0_n{n}/{algorithm}/<dir>/<exp>/results.json``\n",
    "    for the five algorithms, keeps the records logged at step 1000, converts the\n",
    "    metric to an error (``1 - metric``), averages per (n, sc, ci, ai, algorithm)\n",
    "    and attaches ``rank``, ``winners`` and ``multi_hot`` columns.\n",
    "\n",
    "    Args:\n",
    "        y_a: pair ``(y, a)`` of task indices used in the directory name.\n",
    "        n: value used in the directory name and stored (as a string) in column ``n``.\n",
    "        root: directory holding the per-task result folders; the default is the\n",
    "            location that was previously hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        ``(all_res, all_res_ori)``: the rank-1 row of every configuration and the\n",
    "        full per-algorithm table.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if no ``results.json`` file is found.\n",
    "    \"\"\"\n",
    "    dseed = [0]\n",
    "    # only sub-directories whose name contains this marker are read\n",
    "    attr = {'ERM': 'Yes', 'GroupDRO': 'Yes', 'OverSample': 'Yes', 'UnderSample': 'Yes', 'ReWeightLogits': 'Yes'}\n",
    "    # local list on purpose: these are the names used on disk, in multi-hot order\n",
    "    algorithms = ['ERM', 'GroupDRO', 'OverSample', 'ReWeightLogits', 'UnderSample']\n",
    "    final_step = 1000\n",
    "    group_keys = ['n', 'sc', 'ci', 'ai']\n",
    "\n",
    "    frames = []\n",
    "    for s in dseed:\n",
    "        for alg in algorithms:\n",
    "            res_dir = f'{root}/celeba_y{y_a[0]}_a{y_a[1]}_ds0_n{n}/{alg.lower()}'\n",
    "            for d in os.listdir(res_dir):\n",
    "                if attr[alg] not in d:\n",
    "                    continue\n",
    "                for exp in os.listdir(os.path.join(res_dir, d)):\n",
    "                    res_file = os.path.join(res_dir, d, exp, 'results.json')\n",
    "                    print(res_file)\n",
    "                    if os.path.exists(res_file):\n",
    "                        frames.append(get_trajectory_results(res_file, s))\n",
    "\n",
    "    if not frames:\n",
    "        raise FileNotFoundError(f\"no results.json found for y_a={y_a}, n={n} under {root}\")\n",
    "\n",
    "    # concatenate once instead of growing a DataFrame inside the loop\n",
    "    all_res = pd.concat(frames)\n",
    "    # .copy() after filtering so the new columns are written to an independent frame\n",
    "    all_res = all_res[all_res['step'] == final_step].copy()\n",
    "    all_res[\"wga_te_err\"] = 1 - all_res[\"metric\"]\n",
    "    all_res = all_res[[\"sc\", \"ci\", \"ai\", \"algorithm\", \"wga_te_err\", \"seed\"]].copy()\n",
    "    all_res['n'] = n\n",
    "    all_res['n'] = all_res['n'].astype('str')\n",
    "\n",
    "    # average wga_te_err over runs of the same configuration and algorithm\n",
    "    all_res = all_res.groupby(group_keys + ['algorithm']).mean().reset_index()\n",
    "    all_res = all_res.drop(columns=[\"seed\"])\n",
    "\n",
    "    all_res['rank'] = all_res.groupby(group_keys)['wga_te_err'].rank(\"first\")\n",
    "\n",
    "    def get_gt_rank(x, filter_thre=0.05):\n",
    "        # '|'-joined algorithms whose error is within filter_thre of the best one\n",
    "        min_err = x['wga_te_err'].min()\n",
    "        winners = x[x['wga_te_err'] <= min_err + filter_thre][\"algorithm\"].to_list()\n",
    "        return '|'.join(winners)\n",
    "\n",
    "    # one winners string per configuration, merged back onto every row\n",
    "    winners_series = all_res.groupby(group_keys)[['algorithm', 'wga_te_err']].apply(get_gt_rank).reset_index(name='winners')\n",
    "    all_res = all_res.merge(winners_series, on=group_keys)\n",
    "\n",
    "    all_res[\"multi_hot\"] = all_res[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in algorithms])\n",
    "\n",
    "    all_res_ori = copy.deepcopy(all_res)\n",
    "    # deduplicate\n",
    "    all_res = all_res[all_res[\"rank\"]==1.0]\n",
    "\n",
    "    return all_res, all_res_ori\n",
    "\n",
    "def get_trajectory_results(file_path, ds, metric='te_worst_acc'):\n",
    "    \"\"\"Parse a file of back-to-back JSON records into a per-step trajectory table.\n",
    "\n",
    "    The file content is split on ``'}\\\\n{'``, the braces removed by the split are\n",
    "    restored, and every piece is decoded with ``json.loads``; pieces that fail to\n",
    "    decode are reported with ``print`` and skipped. ``sc``, ``ci`` and ``ai`` are\n",
    "    parsed from the ``args['metadata']`` string of the last record and repeated on\n",
    "    every row, as is ``args['algorithm']``.\n",
    "\n",
    "    Args:\n",
    "        file_path: path of the results file.\n",
    "        ds: value stored in the ``data_seed`` column of every row.\n",
    "        metric: record key copied into the ``metric`` column.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns ``sc, ci, ai, algorithm, step, metric, seed,\n",
    "        data_seed``, one row per decoded record.\n",
    "    \"\"\"\n",
    "    with open(file_path, 'r') as handle:\n",
    "        chunks = handle.read().split('}\\n{')\n",
    "\n",
    "    last_pos = len(chunks) - 1\n",
    "    records = []\n",
    "    for pos, chunk in enumerate(chunks):\n",
    "        # the split swallowed the braces between neighbouring objects; put them back\n",
    "        opening = '' if pos == 0 else '{'\n",
    "        closing = '' if pos == last_pos else '}'\n",
    "        try:\n",
    "            records.append(json.loads(opening + chunk + closing))\n",
    "        except json.JSONDecodeError as e:\n",
    "            print(f\"Error decoding JSON object: {e}\")\n",
    "\n",
    "    results_df = pd.DataFrame(records)\n",
    "\n",
    "    # run-level values come from the arguments stored with the last record\n",
    "    final_args = results_df.iloc[-1]['args']\n",
    "    metadata = final_args['metadata']\n",
    "    sc = float(metadata.split('sc')[-1].split('_')[0])\n",
    "    ci = float(metadata.split('ci')[-1].split('_')[0])\n",
    "    ai = float(metadata.split('ai')[-1][:4])\n",
    "    alg = final_args['algorithm']\n",
    "\n",
    "    n_rows = len(results_df)\n",
    "    return pd.DataFrame({\n",
    "        'sc': [sc] * n_rows,\n",
    "        'ci': [ci] * n_rows,\n",
    "        'ai': [ai] * n_rows,\n",
    "        'algorithm': [alg] * n_rows,\n",
    "        'step': results_df['step'].tolist(),\n",
    "        'metric': results_df[metric].tolist(),\n",
    "        'seed': [run_args['seed'] for run_args in results_df['args']],\n",
    "        'data_seed': [ds] * n_rows,\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "# Silence library warnings for the rest of the session\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Run configuration\n",
    "seed = 0\n",
    "model = \"resnet\"\n",
    "algorithms = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "celeba_path = \"//exps/div_explore/celeba_v2/\"\n",
    "coco_path = \"//exps/div_explore/coco_v2/\"\n",
    "celeba_identifier = ['n', 'sc', 'ci', 'ai', 'y_task', 'a_task']\n",
    "coco_identifier = ['n', 'sc', 'ci', 'ai']\n",
    "\n",
    "# Per-configuration and per-algorithm result tables of both datasets\n",
    "celeba_df, celeba_df_full = load_data(\"celeba\", model, celeba_path, celeba_identifier)\n",
    "coco_df, coco_df_full = load_data(\"coco\", model, coco_path, coco_identifier)\n",
    "\n",
    "# Fine-tuning results: per-algorithm table of every (y, a) task and every n,\n",
    "# tagged with the task it belongs to and stacked into one frame\n",
    "celeba_df_ft_full = pd.concat(\n",
    "    [\n",
    "        load_ft_data(y_a, n)[1].assign(y_task=y_a[0], a_task=y_a[1])\n",
    "        for y_a in [[2, 31], [21, 36], [8, 20], [25, 19]]\n",
    "        for n in [200, 500, 1000, 2000, 5000, 10000]\n",
    "    ],\n",
    "    ignore_index=True,\n",
    ")\n",
    "celeba_df_ft_full['n'] = celeba_df_ft_full['n'].astype('int64')\n",
    "\n",
    "# Map the on-disk algorithm names to the names used in `algorithms`\n",
    "alg_dict = {'ERM': 'ERM', 'GroupDRO': 'GroupDRO', 'OverSample': 'oversample', 'ReWeightLogits': 'remax-margin', 'UnderSample': 'undersample'}\n",
    "celeba_df_ft_full['algorithm'] = celeba_df_ft_full['algorithm'].map(alg_dict)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Join the static meta-features (celeba_df_full) with the fine-tuning results.\n",
    "# Columns present in both frames keep their name on the left and get '_ft' on the\n",
    "# right; the left-hand result columns are then dropped so only the fine-tuning\n",
    "# values remain.\n",
    "celeba_df_full1 = (\n",
    "    celeba_df_full\n",
    "    .merge(celeba_df_ft_full, on=celeba_identifier + ['algorithm'], suffixes=('', '_ft'))\n",
    "    .drop(columns=['wga_te_err', 'rank', 'winners', 'multi_hot'])\n",
    ")\n",
    "# remove the ft suffix\n",
    "celeba_df_full1.columns = celeba_df_full1.columns.str.replace('_ft', '')\n",
    "\n",
    "# translate every '|'-joined winner name with alg_dict\n",
    "celeba_df_full1['winners'] = [\n",
    "    '|'.join([alg_dict[name] for name in joined.split('|')])\n",
    "    for joined in celeba_df_full1['winners'].tolist()\n",
    "]\n",
    "\n",
    "# one row per configuration: the best-ranked algorithm\n",
    "celeba_df1 = celeba_df_full1[celeba_df_full1['rank'] == 1.0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the merged per-algorithm CelebA table\n",
    "celeba_df_full1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column lists passed to prepare_data: every entry but the last is a feature,\n",
    "# the last one (multi_hot) is the label column\n",
    "celeba_feats = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'multi_hot']\n",
    "# celeba_feat = ['n', 'sc_', 'ci_', 'ai_', 'c_intra', 'a_intra_i', 'multi_hot']\n",
    "coco_feats = ['n', 'sc', 'ci', 'ai', 'c_intra', 'a_intra', 'multi_hot']\n",
    "# sanity check: model name, table shapes, column names and mean errors\n",
    "print(model)\n",
    "print(celeba_df1.shape, celeba_df_full1.shape)\n",
    "print(coco_df.shape, coco_df_full.shape)\n",
    "# NOTE(review): these are the columns of celeba_df, while the split below uses celeba_df1\n",
    "print(celeba_df.columns)\n",
    "print(coco_df.columns)\n",
    "print(celeba_df1['wga_te_err'].mean(), celeba_df_full1['wga_te_err'].mean())\n",
    "print(coco_df['wga_te_err'].mean(), coco_df_full['wga_te_err'].mean())\n",
    "\n",
    "# Build the train/test arrays; the order of the names must match the tuple\n",
    "# returned by prepare_data\n",
    "(\n",
    "    X_train, X_test, y_train, y_test,\n",
    "    X_train_full, y_train_full, X_test_full, y_test_full,\n",
    "    tr_idx, te_idx,\n",
    "    X_test_coco, y_test_coco,\n",
    "    X_test_coco_full, y_test_coco_full,\n",
    "    X_train_alg, X_test_alg,\n",
    "    tr_idx_full, te_idx_full\n",
    ") = prepare_data(celeba_df1, celeba_df_full1, coco_df, coco_df_full, celeba_feats, coco_feats, seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### E2E"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Multi-layer perceptron assembled as one nn.Sequential stack\n",
    "class MLPTorch(nn.Module):\n",
    "    \"\"\"Fully connected network: Linear + ReLU per hidden layer, then a linear head.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_layer_sizes, output_size):\n",
    "        \"\"\"Create the layers.\n",
    "\n",
    "        Args:\n",
    "            input_size: number of input features.\n",
    "            hidden_layer_sizes: iterable with the width of every hidden layer.\n",
    "            output_size: number of outputs of the final linear layer.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        widths = [input_size, *hidden_layer_sizes]\n",
    "        stack = []\n",
    "        # consecutive width pairs give the (in, out) sizes of the hidden layers\n",
    "        for fan_in, fan_out in zip(widths[:-1], widths[1:]):\n",
    "            stack.append(nn.Linear(fan_in, fan_out))\n",
    "            stack.append(nn.ReLU())\n",
    "        stack.append(nn.Linear(widths[-1], output_size))\n",
    "        self.network = nn.Sequential(*stack)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the output of the final linear layer (no activation applied).\"\"\"\n",
    "        return self.network(x)\n",
    "\n",
    "# Training loop with early stopping on the epoch loss\n",
    "def train_model(model, dataloader, criterion, optimizer, num_epochs, patience, tol, verbose=True):\n",
    "    \"\"\"Train ``model`` and return it.\n",
    "\n",
    "    Args:\n",
    "        model: module to optimise.\n",
    "        dataloader: yields ``(inputs, labels, weights)`` batches.\n",
    "        criterion: callable ``criterion(outputs, labels, weights)`` returning a scalar loss.\n",
    "        optimizer: optimizer already bound to the parameters of ``model``.\n",
    "        num_epochs: maximum number of passes over the data.\n",
    "        patience: training stops once this many consecutive epochs failed to\n",
    "            lower the best epoch loss by more than ``tol``.\n",
    "        tol: minimum decrease of the epoch loss that counts as an improvement.\n",
    "        verbose: print the loss of every epoch when True.\n",
    "\n",
    "    Returns:\n",
    "        The same ``model`` object after training.\n",
    "    \"\"\"\n",
    "    lowest_loss = float('inf')\n",
    "    stale_epochs = 0\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        loss_sum = 0.0\n",
    "\n",
    "        for inputs, labels, weights in dataloader:\n",
    "            optimizer.zero_grad()\n",
    "            batch_loss = criterion(model(inputs), labels, weights)\n",
    "            batch_loss.backward()\n",
    "            optimizer.step()\n",
    "            # weight the batch loss by the batch size to get a per-sample average\n",
    "            loss_sum += batch_loss.item() * inputs.size(0)\n",
    "\n",
    "        epoch_loss = loss_sum / len(dataloader.dataset)\n",
    "        if verbose:\n",
    "            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')\n",
    "\n",
    "        improved = epoch_loss < lowest_loss - tol\n",
    "        if improved:\n",
    "            lowest_loss = epoch_loss\n",
    "        stale_epochs = 0 if improved else stale_epochs + 1\n",
    "\n",
    "        if stale_epochs >= patience:\n",
    "            print(\"Early stopping triggered\")\n",
    "            break\n",
    "\n",
    "    return model\n",
    "\n",
    "class WeightedBCELoss(nn.Module):\n",
    "    \"\"\"Binary cross-entropy on logits with one weight per sample.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(WeightedBCELoss, self).__init__()\n",
    "        # reduction='none' keeps one loss value per (sample, label). With the\n",
    "        # default 'mean' the loss is already a scalar when the weights are applied,\n",
    "        # so they would only rescale it by their average.\n",
    "        self.bce = nn.BCEWithLogitsLoss(reduction='none')\n",
    "\n",
    "    def forward(self, outputs, targets, weights):\n",
    "        \"\"\"Return the mean of the element-wise losses scaled by the sample weights.\n",
    "\n",
    "        Args:\n",
    "            outputs: logits, one row per sample.\n",
    "            targets: 0/1 targets with the same shape as ``outputs``.\n",
    "            weights: 1-D tensor with one weight per sample.\n",
    "\n",
    "        With all weights equal to 1 this equals ``nn.BCEWithLogitsLoss()``.\n",
    "        \"\"\"\n",
    "        loss = self.bce(outputs, targets)\n",
    "        # broadcast each sample weight over the labels of that sample\n",
    "        weighted_loss = loss * weights.unsqueeze(1)\n",
    "        return weighted_loss.mean()\n",
    "\n",
    "def sigmoid(x, k):\n",
    "    \"\"\"Logistic function with steepness ``k``: ``1 / (1 + exp(-k * x))``.\"\"\"\n",
    "    exponent = -k * x\n",
    "    return 1 / (1 + np.exp(exponent))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_e2e(X_train, y_train, conf_train=None, verbose=True, naive_feat=False):\n",
    "    \"\"\"Train the end-to-end multi-label network and return it.\n",
    "\n",
    "    Args:\n",
    "        X_train: 2-D feature array.\n",
    "        y_train: 2-D multi-hot targets with one column per entry of ``algorithms``.\n",
    "        conf_train: optional per-sample loss weights; all ones when None.\n",
    "        verbose: forwarded to ``train_model``.\n",
    "        naive_feat: when True every feature column except the first is zeroed\n",
    "            (on the tensor copy, not on the caller's array).\n",
    "\n",
    "    Returns:\n",
    "        The trained ``MLPTorch`` model.\n",
    "    \"\"\"\n",
    "    # Fixed hyperparameters\n",
    "    hidden_layer_sizes = (100, 100, 50, 50)\n",
    "    num_epochs, patience, tol = 500, 2000, 1e-4\n",
    "    alpha = 0.0001\n",
    "    batch_size = 256\n",
    "\n",
    "    sample_weights = np.ones(y_train.shape[0]) if conf_train is None else conf_train\n",
    "\n",
    "    # convert to torch tensors\n",
    "    features = torch.tensor(X_train).float()\n",
    "    targets = torch.tensor(y_train).float()\n",
    "    sample_weights = torch.tensor(sample_weights).float()\n",
    "\n",
    "    if naive_feat:\n",
    "        features[:, 1:] = 0\n",
    "\n",
    "    loader = DataLoader(\n",
    "        TensorDataset(features, targets, sample_weights),\n",
    "        batch_size=batch_size,\n",
    "        shuffle=True,\n",
    "    )\n",
    "\n",
    "    # network, loss and optimizer (alpha is the Adam weight decay)\n",
    "    net = MLPTorch(features.shape[1], hidden_layer_sizes, len(algorithms))\n",
    "    loss_fn = WeightedBCELoss()\n",
    "    adam = optim.Adam(net.parameters(), lr=0.001, weight_decay=alpha)\n",
    "\n",
    "    return train_model(net, loader, loss_fn, adam, num_epochs, patience, tol, verbose=verbose)\n",
    "\n",
    "def predict_e2e(model, X, naive_feat=False):\n",
    "    \"\"\"Predict multi-hot labels with a trained logit model.\n",
    "\n",
    "    Args:\n",
    "        model: module mapping a float tensor to one logit per label.\n",
    "        X: 2-D features, either a tensor or anything ``torch.tensor`` accepts.\n",
    "        naive_feat: when True every feature column except the first is zeroed\n",
    "            before prediction.\n",
    "\n",
    "    Returns:\n",
    "        Integer numpy array of 0/1 values, 1 where the sigmoid of the logit exceeds 0.5.\n",
    "    \"\"\"\n",
    "    # check if X is tensor otherwise convert\n",
    "    if not torch.is_tensor(X):\n",
    "        X = torch.tensor(X).float()\n",
    "\n",
    "    if naive_feat:\n",
    "        # mask a clone: a tensor handed in by the caller must not be zeroed in place\n",
    "        X = X.clone()\n",
    "        X[:, 1:] = 0\n",
    "\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        probs = torch.sigmoid(model(X))\n",
    "    return (probs > 0.5).int().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the end-to-end multi-label network on the per-configuration training split\n",
    "trained_model = run_e2e(X_train, y_train, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the end-to-end model on the held-out split: exact-match accuracy,\n",
    "# soft accuracy, and the mean worst-group error of the algorithm it selects\n",
    "y_pred = predict_e2e(trained_model, X_test)\n",
    "_ = eval_acc(y_test, y_pred, mode='0-1', verbose=True)\n",
    "_ = eval_acc(y_test, y_pred, mode='soft 0-1', verbose=True)\n",
    "_ = eval_wga_err(celeba_df1.iloc[te_idx], y_pred, celeba_df_full1, celeba_identifier)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_regression(X_train, y_train, verbose=True, naive_feat=False):\n",
    "    \"\"\"Fit an ``MLPRegressor`` on the given features and targets.\n",
    "\n",
    "    Args:\n",
    "        X_train: 2-D feature array.\n",
    "        y_train: regression targets.\n",
    "        verbose: forwarded to ``MLPRegressor``.\n",
    "        naive_feat: when True every feature column except the first is zeroed\n",
    "            on a copy of ``X_train`` before fitting.\n",
    "\n",
    "    Returns:\n",
    "        The fitted regressor.\n",
    "    \"\"\"\n",
    "    from sklearn.neural_network import MLPRegressor\n",
    "\n",
    "    if naive_feat:\n",
    "        # mask a copy: the original zeroed the caller's array in place, which\n",
    "        # silently changed the data seen by every later cell\n",
    "        X_train = np.array(X_train, copy=True)\n",
    "        X_train[:, 1:] = 0\n",
    "\n",
    "    trained_model = MLPRegressor(random_state=3, max_iter=500, verbose=verbose, alpha=0.1, hidden_layer_sizes=(100,100,50,50)).fit(X_train, y_train)\n",
    "\n",
    "    return trained_model\n",
    "\n",
    "def predict_regression(model, X, naive_feat=False, group_size=5, filter_thre=0.05):\n",
    "    \"\"\"Turn predicted errors into multi-hot winners, one vector per group of rows.\n",
    "\n",
    "    Rows of ``X`` are taken in consecutive groups of ``group_size``. Inside each\n",
    "    group a row is marked 1 when its predicted value is within ``filter_thre``\n",
    "    of the smallest predicted value of the group.\n",
    "\n",
    "    Args:\n",
    "        model: object with a ``predict(X)`` method returning one value per row.\n",
    "        X: 2-D feature array whose row count is a multiple of ``group_size``.\n",
    "        naive_feat: when True every feature column except the first is zeroed\n",
    "            on a copy of ``X`` before prediction.\n",
    "        group_size: number of consecutive rows forming one group (previously\n",
    "            hard-coded to 5).\n",
    "        filter_thre: tolerance above the group minimum (previously hard-coded to 0.05).\n",
    "\n",
    "    Returns:\n",
    "        List with one integer 0/1 array of length ``group_size`` per group.\n",
    "    \"\"\"\n",
    "\n",
    "    def get_rank(x):\n",
    "        min_err = x.min()\n",
    "        return (x <= min_err + filter_thre).astype(int)\n",
    "\n",
    "    if naive_feat:\n",
    "        # mask a copy of the argument; the original masked the global X_train here,\n",
    "        # leaving X untouched and corrupting the training features\n",
    "        X = np.array(X, copy=True)\n",
    "        X[:, 1:] = 0\n",
    "\n",
    "    y_pred = np.asarray(model.predict(X))\n",
    "    assert len(y_pred) % group_size == 0\n",
    "    return [get_rank(y_pred[start:start + group_size]) for start in range(0, len(y_pred), group_size)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the error regressor on the per-algorithm rows (features plus algorithm index -> wga_te_err)\n",
    "trained_model = run_regression(X_train_full, y_train_full, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the regressor: predicted errors are converted to one multi-hot winner\n",
    "# vector per group of rows, then scored like the end-to-end model\n",
    "y_pred = predict_regression(trained_model, X_test_full, naive_feat=False)\n",
    "_ = eval_acc(y_test, y_pred, mode='0-1', verbose=True)\n",
    "_ = eval_acc(y_test, y_pred, mode='soft 0-1', verbose=True)\n",
    "_ = eval_wga_err(celeba_df1.iloc[te_idx], y_pred, celeba_df_full1, celeba_identifier)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ensemble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_binary_classifier(X_train, y_train, X_train_alg, method1, method2, filter_thre=0.05, verbose=True, naive_feat=False):\n",
    "    \"\"\"Fit a classifier that compares two algorithms on the same configurations.\n",
    "\n",
    "    The rows of ``method1`` and ``method2`` are selected through ``X_train_alg``\n",
    "    and their targets subtracted element-wise. The difference becomes the class\n",
    "    label: 1 when ``method1`` is lower by more than ``filter_thre``, 0 when\n",
    "    ``method2`` is lower by more than ``filter_thre``, 2 otherwise.\n",
    "\n",
    "    Args:\n",
    "        X_train: per-algorithm features; the last column is dropped before fitting.\n",
    "        y_train: per-algorithm targets aligned with ``X_train``.\n",
    "        X_train_alg: index (in the global ``algorithms``) of the algorithm of each row.\n",
    "        method1, method2: names of the two algorithms to compare.\n",
    "        filter_thre: half-width of the band treated as a tie.\n",
    "        verbose: print the training size and training accuracy when True.\n",
    "        naive_feat: when True every feature column except the first is zeroed.\n",
    "\n",
    "    Returns:\n",
    "        The fitted ``MLPClassifier``.\n",
    "    \"\"\"\n",
    "    from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "    rows_m1 = X_train_alg == algorithms.index(method1)\n",
    "    rows_m2 = X_train_alg == algorithms.index(method2)\n",
    "\n",
    "    # element-wise gap between the two methods, mapped to the three classes\n",
    "    err_gap = y_train[rows_m1] - y_train[rows_m2]\n",
    "    labels = np.array([\n",
    "        1 if gap < -filter_thre else (0 if gap > filter_thre else 2)\n",
    "        for gap in err_gap\n",
    "    ])\n",
    "\n",
    "    # drop the last column, keep the rows of method1 (boolean indexing returns a copy)\n",
    "    feats = X_train[:, :-1][rows_m1]\n",
    "    if naive_feat:\n",
    "        feats[:, 1:] = 0\n",
    "\n",
    "    clf = MLPClassifier(\n",
    "        random_state=2,\n",
    "        max_iter=500,\n",
    "        verbose=False,\n",
    "        tol=1e-3,\n",
    "        n_iter_no_change=2000,\n",
    "        alpha=0.1,\n",
    "        hidden_layer_sizes=(100,100,50,50),\n",
    "    ).fit(feats, labels)\n",
    "    train_acc = clf.score(feats, labels)\n",
    "\n",
    "    if verbose:\n",
    "        print(method1, method2, \"train size: \", len(feats), \"train acc: \", train_acc)\n",
    "\n",
    "    return clf\n",
    "\n",
    "def run_ensemble(X_train, y_train, X_train_alg, naive_feat=False):\n",
    "    \"\"\"Train one pairwise classifier for every unordered pair of ``algorithms``.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping ``(method1, method2)`` to the classifier returned by\n",
    "        ``train_binary_classifier`` for that pair.\n",
    "    \"\"\"\n",
    "    import itertools\n",
    "\n",
    "    pairs = np.array(list(itertools.combinations(algorithms, 2)))\n",
    "    return {\n",
    "        (m1, m2): train_binary_classifier(X_train, y_train, X_train_alg, m1, m2, filter_thre=0.05, naive_feat=naive_feat)\n",
    "        for m1, m2 in pairs\n",
    "    }\n",
    "\n",
    "def predict_ensemble(model, X, naive_feat=False):\n",
    "    \"\"\"Combine the pairwise classifiers into multi-hot winners with Copeland scoring.\n",
    "\n",
    "    Every classifier votes on its pair: label 1 credits the first method, label 0\n",
    "    the second, any other label is a tie. A win adds 1 to the winner's score and\n",
    "    subtracts 1 from the loser's; all algorithms sharing the highest score are\n",
    "    marked as winners.\n",
    "\n",
    "    Args:\n",
    "        model: dict ``{(method1, method2): classifier}`` as built by ``run_ensemble``.\n",
    "        X: 2-D feature array, one row per configuration.\n",
    "        naive_feat: when True every feature column except the first is zeroed\n",
    "            on a copy of ``X`` before prediction.\n",
    "\n",
    "    Returns:\n",
    "        Float array of shape ``(len(X), len(algorithms))`` holding 0/1 values.\n",
    "    \"\"\"\n",
    "    def copeland_method(candidate_names, pairwise_results):\n",
    "        # Initialize scores dictionary\n",
    "        scores = {name: 0 for name in candidate_names}\n",
    "        # Update scores based on pairwise results\n",
    "        for A, B, winner in pairwise_results:\n",
    "            if winner == A:\n",
    "                scores[A] += 1\n",
    "                scores[B] -= 1\n",
    "            elif winner == B:\n",
    "                scores[B] += 1\n",
    "                scores[A] -= 1\n",
    "        top_score = max(scores.values())\n",
    "        winners = [name for name in candidate_names if scores[name] == top_score]\n",
    "        return winners, scores\n",
    "\n",
    "    X = np.asarray(X)\n",
    "    y_pred = np.zeros((X.shape[0], len(algorithms)))\n",
    "    if X.shape[0] == 0:\n",
    "        return y_pred\n",
    "\n",
    "    if naive_feat:\n",
    "        # mask a copy so the caller's array is not zeroed in place\n",
    "        X = X.copy()\n",
    "        X[:, 1:] = 0\n",
    "\n",
    "    # one batched predict per classifier instead of one call per sample\n",
    "    pair_labels = {pair: np.asarray(clf.predict(X)) for pair, clf in model.items()}\n",
    "\n",
    "    for i in range(X.shape[0]):\n",
    "        pair_res = []\n",
    "        for (m1, m2), labels in pair_labels.items():\n",
    "            if labels[i] == 1:\n",
    "                pair_res.append((m1, m2, m1))\n",
    "            elif labels[i] == 0:\n",
    "                pair_res.append((m1, m2, m2))\n",
    "            # tie: crediting both sides cancels out, so nothing is recorded\n",
    "        winners, _ = copeland_method(algorithms, pair_res)\n",
    "        # convert winners to multi-hot\n",
    "        y_pred[i, [algorithms.index(m) for m in winners]] = 1\n",
    "    return y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train one pairwise classifier per pair of algorithms on the per-algorithm rows\n",
    "trained_model = run_ensemble(X_train_full, y_train_full, X_train_alg, naive_feat=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the ensemble on the per-configuration test features (no algorithm\n",
    "# column), then score it like the other two approaches\n",
    "y_pred = predict_ensemble(trained_model, X_test, naive_feat=False)\n",
    "_ = eval_acc(y_test, y_pred, mode='0-1', verbose=True)\n",
    "_ = eval_acc(y_test, y_pred, mode='soft 0-1', verbose=True)\n",
    "_ = eval_wga_err(celeba_df1.iloc[te_idx], y_pred, celeba_df_full1, celeba_identifier)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "div_backup",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
