{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "import copy\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi-classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(path, seeds=[0]):\n",
    "    \"\"\"Load the per-seed result CSVs, average them over seeds and label the winners.\n",
    "\n",
    "    `path` is a template that points at the seed-0 file inside the `toy_v2`\n",
    "    folder. For each seed the actual file name is derived from the template by\n",
    "    replacing `seed0` with `seed{s}`; seeds greater than 1 are read from the\n",
    "    `toy_v3` folder instead of `toy_v2`.\n",
    "\n",
    "    Returns a frame with one row per (configuration, method) holding the\n",
    "    seed-averaged columns plus:\n",
    "      - `rank`: rank of `wga_te_err` inside the configuration (1 = lowest),\n",
    "      - `wga_te_err_var`: variance of `wga_te_err` across the methods,\n",
    "      - `winners`: '|'-joined methods within 0.05 of the lowest error,\n",
    "      - `conf`: lowest loser error minus highest winner error (NaN if no loser),\n",
    "      - `multi_hot`: 0/1 list over the methods in alphabetical label order.\n",
    "    \"\"\"\n",
    "    group_cols = ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat']\n",
    "\n",
    "    all_df_seeds = []\n",
    "    for s in seeds:\n",
    "        folder = 'toy_v3' if s > 1 else 'toy_v2'\n",
    "        # Always start from the ORIGINAL template. Rebinding `path` here would\n",
    "        # erase the 'seed0' placeholder after the first non-zero seed, and every\n",
    "        # later seed would silently re-read that same file.\n",
    "        seed_path = path.replace('toy_v2', folder).replace('seed0', f'seed{s}')\n",
    "\n",
    "        seed_df = pd.read_csv(seed_path)\n",
    "        seed_df['seed'] = s\n",
    "        all_df_seeds.append(seed_df)\n",
    "    all_df_seeds = pd.concat(all_df_seeds)\n",
    "    all_df = all_df_seeds.groupby(group_cols + ['method']).mean().reset_index()\n",
    "    all_df['rank'] = all_df.groupby(group_cols)['wga_te_err'].rank(\"first\")\n",
    "    all_df['wga_te_err_var'] = all_df.groupby(group_cols)['wga_te_err'].transform('var')\n",
    "\n",
    "    def get_gt_rank(x, mode='tie', filter_thre=0.05):\n",
    "        # x holds the rows of one configuration; every method whose error is\n",
    "        # within filter_thre of the smallest one counts as a winner (tie).\n",
    "        min_err = x['wga_te_err'].min()\n",
    "        winners = x[x['wga_te_err'] <= min_err + filter_thre][\"method\"].to_list()\n",
    "        return '|'.join(winners)\n",
    "\n",
    "    # One winners string per configuration, merged back onto every method row.\n",
    "    winners_series = all_df.groupby(group_cols)[['method', 'wga_te_err']].apply(lambda x: get_gt_rank(x)).reset_index(name='winners')\n",
    "    all_df = all_df.merge(winners_series, on=group_cols)\n",
    "\n",
    "    def get_conf(x, mode='tie', filter_thre=0.05):\n",
    "        # Margin between the best loser and the worst winner. When every method\n",
    "        # is a winner there is no loser and the result is NaN.\n",
    "        min_err = x['wga_te_err'].min()\n",
    "        is_winner = x['wga_te_err'] <= min_err + filter_thre\n",
    "        winner_max = x[is_winner]['wga_te_err'].max()\n",
    "        loser_min = x[~is_winner]['wga_te_err'].min()\n",
    "        return loser_min - winner_max\n",
    "\n",
    "    conf_series = all_df.groupby(group_cols)[['method', 'wga_te_err']].apply(lambda x: get_conf(x)).reset_index(name='conf')\n",
    "    all_df = all_df.merge(conf_series, on=group_cols)\n",
    "\n",
    "    # Label order of the multi-hot target; decoding code must use the same order.\n",
    "    methods = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "    all_df[\"multi_hot\"] = all_df[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in methods])\n",
    "\n",
    "    return all_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "seeds = [0,2,3]\n",
    "path = \"//exps/div_explore/toy_v2/results_seed0_sc.csv\"\n",
    "sc_df = load_data(path, seeds)\n",
    "\n",
    "path = \"//exps/div_explore/toy_v2/results_seed0_ci.csv\"\n",
    "ci_df = load_data(path, seeds)\n",
    "\n",
    "path = \"//exps/div_explore/toy_v2/results_seed0_ai.csv\"\n",
    "ai_df = load_data(path, seeds)\n",
    "\n",
    "# The 3-shift results are loaded with the default seed list only.\n",
    "path = \"//exps/div_explore/toy_v2/results_seed0_3shifts.csv\"\n",
    "mix_df = load_data(path)\n",
    "\n",
    "all_df = pd.concat([sc_df, ci_df, ai_df, mix_df], ignore_index=True)\n",
    "# NOTE(review): a configuration that occurs in more than one result file is\n",
    "# kept once per file (no de-duplication is applied here) - confirm this is intended.\n",
    "\n",
    "# Keep the best-ranked row of every configuration. The explicit copy makes\n",
    "# all_df_rank an independent frame, so the column writes below (and in later\n",
    "# cells) really land in it instead of in a temporary slice of all_df.\n",
    "all_df_rank = all_df[all_df[\"rank\"]==1.0].copy()\n",
    "\n",
    "# set all nan to 0\n",
    "all_df_rank['conf'] = all_df_rank['conf'].fillna(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frequency of every winner set (method names joined by '|') over the rank-1 rows.\n",
    "all_df_rank[\"winners\"].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of the 'conf' column (lowest loser error minus highest winner error).\n",
    "all_df_rank['conf'].hist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Define the MLP model using nn.Sequential\n",
    "class MLPTorch(nn.Module):\n",
    "    \"\"\"Feed-forward net: one Linear+ReLU pair per hidden width, then a linear output layer.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_layer_sizes, output_size):\n",
    "        super().__init__()\n",
    "        widths = [input_size, *hidden_layer_sizes]\n",
    "        modules = []\n",
    "        # Consecutive width pairs give the (in, out) sizes of each hidden layer.\n",
    "        for fan_in, fan_out in zip(widths[:-1], widths[1:]):\n",
    "            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]\n",
    "        # The output layer stays linear: the network returns raw logits.\n",
    "        modules.append(nn.Linear(widths[-1], output_size))\n",
    "        self.network = nn.Sequential(*modules)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.network(x)\n",
    "\n",
    "class FullModel(nn.Module):\n",
    "    \"\"\"Applies an embedding network followed by an MLP head.\"\"\"\n",
    "\n",
    "    def __init__(self, emb_mlp, mlp):\n",
    "        super().__init__()\n",
    "        self.emb_mlp = emb_mlp  # first stage: embedding model\n",
    "        self.mlp = mlp  # second stage: MLP applied to the embedding\n",
    "\n",
    "    def forward(self, x):\n",
    "        # TODO (carried over from the original; the open item is not specified)\n",
    "        embedded = self.emb_mlp(x)\n",
    "        return self.mlp(embedded)\n",
    "\n",
    "# Training function\n",
    "def train_model(model, dataloader, criterion, optimizer, num_epochs, patience, tol, X_val=None, y_val=None, verbose=True):\n",
    "    \"\"\"Train `model` for at most `num_epochs` epochs with loss-based early stopping.\n",
    "\n",
    "    Every batch of `dataloader` must yield (inputs, labels, weights) and\n",
    "    `criterion` is called as criterion(outputs, labels, weights). Training\n",
    "    stops once the epoch loss has failed to improve on the best loss by more\n",
    "    than `tol` for `patience` epochs. When `X_val` is given, eval_model is\n",
    "    run on it every fifth epoch. Returns the (in-place trained) model.\n",
    "    \"\"\"\n",
    "    lowest_loss = float('inf')\n",
    "    stale_epochs = 0\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        loss_sum = 0.0\n",
    "\n",
    "        for batch_x, batch_y, batch_w in dataloader:\n",
    "            optimizer.zero_grad()\n",
    "            batch_loss = criterion(model(batch_x), batch_y, batch_w)\n",
    "            batch_loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # Weight the batch loss by the batch size so the epoch value is a per-sample mean.\n",
    "            loss_sum += batch_loss.item() * batch_x.size(0)\n",
    "\n",
    "        epoch_loss = loss_sum / len(dataloader.dataset)\n",
    "        if verbose:\n",
    "            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')\n",
    "\n",
    "        improved = epoch_loss < lowest_loss - tol\n",
    "        if improved:\n",
    "            lowest_loss, stale_epochs = epoch_loss, 0\n",
    "        else:\n",
    "            stale_epochs += 1\n",
    "\n",
    "        if stale_epochs >= patience:\n",
    "            print(\"Early stopping triggered\")\n",
    "            break\n",
    "\n",
    "        if X_val is not None and epoch % 5 == 0:\n",
    "            eval_model(model, X_val, y_val, criterion)\n",
    "\n",
    "    return model\n",
    "\n",
    "def eval_model(model, X, y, criterion, mode=\"0-1\", verbose=True):\n",
    "    \"\"\"Threshold the model's sigmoid outputs at 0.5 and score them against multi-hot targets.\n",
    "\n",
    "    Args:\n",
    "        model: module that maps X to one logit per label.\n",
    "        X: input tensor.\n",
    "        y: 0/1 target tensor with the same shape as the model output.\n",
    "        criterion: unused; kept so existing call sites keep working.\n",
    "        mode: \"0-1\" (the whole label vector must match), \"soft 0-1\" (no\n",
    "            predicted label outside the true set; an empty prediction\n",
    "            counts as correct) or \"jaccard\" (intersection over union).\n",
    "        verbose: print the resulting metric.\n",
    "\n",
    "    Returns:\n",
    "        (model, acc). acc is a Python float for \"0-1\" and \"soft 0-1\"; for\n",
    "        \"jaccard\" it is the tensor of per-sample similarities.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if mode is not one of the supported names.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        y_pred = (torch.sigmoid(model(X)) > 0.5).float()\n",
    "\n",
    "        if mode == \"0-1\":\n",
    "            # Exact match on every label, independent of the number of labels.\n",
    "            acc = (y_pred == y).all(dim=1).float().mean().item()\n",
    "            if verbose:\n",
    "                print(f\"Eval accuracy: {acc}\")\n",
    "        elif mode == \"soft 0-1\":\n",
    "            # Correct when every predicted 1 sits on a true 1 (no false positive).\n",
    "            no_false_positive = ((y_pred == 0) | (y == 1)).all(dim=1)\n",
    "            acc = no_false_positive.sum().item() / y.shape[0]\n",
    "            if verbose:\n",
    "                print(f\"Eval accuracy: {acc}\")\n",
    "        elif mode == \"jaccard\":\n",
    "            intersection = torch.sum(y_pred * y, dim=1)\n",
    "            # |A or B|: a label present in both sets must be counted once, so the\n",
    "            # element-wise sum is capped at 1 before summing over the labels.\n",
    "            union = torch.sum(torch.clamp(y_pred + y, max=1), dim=1)\n",
    "            # Two empty sets are identical -> similarity 1 instead of 0/0.\n",
    "            acc = torch.where(union > 0, intersection / union.clamp(min=1), torch.ones_like(union))\n",
    "            if verbose:\n",
    "                print(f\"Eval Jaccard: {acc.mean().item()}\")\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown eval mode: {mode!r}\")\n",
    "\n",
    "    return model, acc\n",
    "\n",
    "class WeightedBCELoss(nn.Module):\n",
    "    \"\"\"Binary cross-entropy on logits with one weight per sample.\n",
    "\n",
    "    forward(outputs, targets, weights) multiplies the loss of every label of\n",
    "    sample i by weights[i] and returns the mean over all elements.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # reduction='none' keeps one loss value per (sample, label). With the\n",
    "        # default 'mean' reduction the loss is already a scalar and the weights\n",
    "        # would merely rescale the batch mean instead of weighting samples.\n",
    "        self.bce = nn.BCEWithLogitsLoss(reduction='none')\n",
    "\n",
    "    def forward(self, outputs, targets, weights):\n",
    "        per_element = self.bce(outputs, targets)\n",
    "        # weights is indexed by sample; unsqueeze(1) broadcasts it over the labels.\n",
    "        weighted_loss = per_element * weights.unsqueeze(1)\n",
    "        return weighted_loss.mean()\n",
    "\n",
    "def sigmoid(x, k):\n",
    "    \"\"\"Logistic function with steepness k: 1 / (1 + exp(-k * x)).\"\"\"\n",
    "    scaled = k * x\n",
    "    return 1 / (np.exp(-scaled) + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import warnings\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# ignore the warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "verbose = False\n",
    "seed_list = [2023, 2024, 2025]\n",
    "# Steepness values of the sigmoid that turns 'conf' into a per-sample loss weight.\n",
    "k_list = [0,1,2.5,5,10,15,20,50,100]\n",
    "group_cols = ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat']\n",
    "# Same label order as the one used to build the 'multi_hot' column in load_data.\n",
    "methods = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "\n",
    "# Hyperparameters\n",
    "hidden_layer_sizes = (100, 100, 50, 50)\n",
    "output_size = 5\n",
    "num_epochs = 500\n",
    "patience = 2000\n",
    "tol = 1e-4\n",
    "alpha = 0.0001\n",
    "batch_size = 256\n",
    "\n",
    "# Independent copy: the column added inside the loop must not be written into a slice of all_df.\n",
    "all_df_rank = all_df_rank.copy()\n",
    "\n",
    "\n",
    "def get_pred_wga_te_err(x):\n",
    "    \"\"\"Test error of one method drawn at random from a configuration's predicted winners.\"\"\"\n",
    "    pred_winners = x['winners_pred'].iloc[0]\n",
    "    # No label predicted: fall back to a uniformly random method.\n",
    "    candidates = methods if len(pred_winners) == 0 else pred_winners\n",
    "    pred_winner = np.random.choice(candidates)\n",
    "    return x[x['method_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "\n",
    "\n",
    "for k in k_list:\n",
    "    hard_acc, soft_acc, err = [], [], []\n",
    "    for seed in seed_list:\n",
    "        # fix all the seeds in numpy and torch\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "        all_df_rank['sigmoid_conf'] = sigmoid(all_df_rank['conf'], k)\n",
    "\n",
    "        all_df_rank_data = all_df_rank[group_cols + ['multi_hot']]\n",
    "\n",
    "        # Every column but the last is a feature; the last one holds the multi-hot label list.\n",
    "        all_X = all_df_rank_data.to_numpy()[:,:-1].astype('float')\n",
    "        all_y = np.array(list(all_df_rank_data.to_numpy()[:, -1]))\n",
    "\n",
    "        # Per-sample loss weight: sigmoid(k * conf) shifted by 0.5.\n",
    "        all_conf = all_df_rank['sigmoid_conf'].to_numpy() + 0.5\n",
    "\n",
    "        # The row positions are split alongside so that the test rows can be looked up again below.\n",
    "        X_train, X_test, y_train, y_test, conf_train, conf_test, tr_idx, te_idx = train_test_split(all_X, all_y, all_conf, range(len(all_X)), test_size=0.2, random_state=0)\n",
    "        X_val, y_val = None, None\n",
    "\n",
    "        scaler = StandardScaler()\n",
    "        X_train = scaler.fit_transform(X_train)\n",
    "        X_test = scaler.transform(X_test)\n",
    "\n",
    "        X_train, X_test = torch.tensor(X_train).float(), torch.tensor(X_test).float()\n",
    "        y_train, y_test = torch.tensor(y_train).float(), torch.tensor(y_test).float()\n",
    "        conf_train, conf_test = torch.tensor(conf_train).float(), torch.tensor(conf_test).float()\n",
    "\n",
    "        # Feature-selection hook: currently every column is kept.\n",
    "        cols = list(range(X_train.shape[1]))\n",
    "        X_train, X_test = X_train[:,cols], X_test[:,cols]\n",
    "\n",
    "        input_size = X_train.shape[1]\n",
    "\n",
    "        # Create dataloader\n",
    "        tr_dataset = TensorDataset(X_train, y_train, conf_train)\n",
    "        tr_dataloader = DataLoader(tr_dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "        # Initialize model, criterion, and optimizer\n",
    "        model = MLPTorch(input_size, hidden_layer_sizes, output_size)\n",
    "        criterion = WeightedBCELoss()\n",
    "        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=alpha)\n",
    "        # Train the model\n",
    "        trained_model = train_model(model, tr_dataloader, criterion, optimizer, num_epochs, patience, tol, X_val, y_val, verbose=verbose)\n",
    "\n",
    "        # Eval on train, then on test (exact-match and soft accuracy)\n",
    "        _, curr_acc = eval_model(trained_model, X_train, y_train, criterion, verbose=verbose)\n",
    "        _, curr_hard_acc = eval_model(trained_model, X_test, y_test, criterion, verbose=verbose)\n",
    "        _, curr_soft_acc = eval_model(trained_model, X_test, y_test, criterion, mode=\"soft 0-1\", verbose=verbose)\n",
    "        hard_acc.append(curr_hard_acc)\n",
    "        soft_acc.append(curr_soft_acc)\n",
    "\n",
    "        # Copy before adding a column so the write does not target a slice of all_df_rank.\n",
    "        test_df = all_df_rank.iloc[te_idx].copy()\n",
    "        trained_model.eval()\n",
    "        with torch.no_grad():\n",
    "            y_pred = (torch.sigmoid(trained_model(X_test)) > 0.5).float()\n",
    "        label_names = np.array(methods)\n",
    "        winners_pred = [label_names[np.where(y_ == 1)[0]] for y_ in y_pred.numpy()]\n",
    "        test_df['winners_pred'] = winners_pred\n",
    "        # all_df contributes the per-method rows (suffix _x), test_df the per-configuration prediction.\n",
    "        fine_test_df = all_df.merge(test_df, on=group_cols)\n",
    "\n",
    "        # Columns are selected with a list: tuple selection on a groupby is no longer accepted by pandas 2.\n",
    "        pred_err = fine_test_df.groupby(group_cols)[['wga_te_err_x', 'method_x', 'winners_pred']].apply(get_pred_wga_te_err)\n",
    "        err.append(pred_err.mean())\n",
    "\n",
    "    print(k, np.mean(hard_acc), np.std(hard_acc), np.mean(soft_acc), np.std(soft_acc), np.mean(err), np.std(err))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "# Work on an independent deep copy of the rank-1 frame.\n",
    "all_df_rank = copy.deepcopy(all_df_rank)\n",
    "\n",
    "all_df_rank_data = all_df_rank[['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat', 'multi_hot']]\n",
    "\n",
    "# Every column but the last is a feature; the last one holds the multi-hot label list.\n",
    "all_X = all_df_rank_data.to_numpy()[:,:-1].astype('float')\n",
    "all_y = np.array(list(all_df_rank_data.to_numpy()[:, -1]))\n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "# 80/20 split with a fixed random_state; the row positions are split alongside\n",
    "# so that te_idx can be used to look the test rows up in all_df_rank.\n",
    "X_train, X_test, y_train, y_test, tr_idx, te_idx = train_test_split(all_X, all_y, range(len(all_X)), test_size=0.2, random_state=0)\n",
    "\n",
    "# The scaler is fitted on the training part only and reused for the test part.\n",
    "scaler = StandardScaler()\n",
    "X_train = scaler.fit_transform(X_train)\n",
    "X_test = scaler.transform(X_test)\n",
    "\n",
    "# Feature columns, in order: ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat']\n",
    "# Feature-selection hook: currently every column is kept.\n",
    "cols = list(range(X_train.shape[1]))\n",
    "# cols.remove(0)\n",
    "# cols = [0, 5]\n",
    "X_train, X_test = X_train[:,cols], X_test[:,cols]\n",
    "\n",
    "# Alternative classifiers that were tried:\n",
    "# clf = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train)\n",
    "# clf = LogisticRegression(random_state=0, max_iter=int(1e8), verbose=True, C=0.1).fit(X_train, y_train)\n",
    "clf = MLPClassifier(random_state=0, max_iter=10000, verbose=True, tol=1e-3, n_iter_no_change=2000, alpha=0.1, hidden_layer_sizes=(100, 100, 50, 50)).fit(X_train, y_train)\n",
    "\n",
    "# solver=\"sgd\", learning_rate_init=0.01\n",
    "# clf.score on the training split, then on the test split.\n",
    "print(clf.score(X_train, y_train))\n",
    "print(clf.score(X_test, y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = clf.predict(X_test)\n",
    "# A sample counts as correct when every predicted 1 sits on a true 1,\n",
    "# i.e. each label is either not predicted or really is a winner.\n",
    "no_false_positive = (y_pred == 0) | (y_test == 1)\n",
    "correct = int(np.all(no_false_positive, axis=1).sum())\n",
    "acc = correct / y_test.shape[0]\n",
    "print(f\"Eval accuracy: {acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The decoding order must be the order used to build the 'multi_hot' targets\n",
    "# in load_data; the previous order here swapped 'oversample' and 'remax-margin',\n",
    "# so predictions for one of them were reported as the other.\n",
    "methods = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "# Copy before adding a column so the write does not target a slice of all_df_rank.\n",
    "test_df = all_df_rank.iloc[te_idx].copy()\n",
    "y_pred = clf.predict(X_test)\n",
    "label_names = np.array(methods)\n",
    "# One array of predicted method names per test configuration (possibly empty).\n",
    "winners_pred = [label_names[np.where(y_ == 1)[0]] for y_ in y_pred]\n",
    "test_df['winners_pred'] = winners_pred\n",
    "# all_df contributes the per-method rows (suffix _x), test_df the per-configuration prediction.\n",
    "fine_test_df = all_df.merge(test_df, on=['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pred_wga_te_err(x):\n",
    "    \"\"\"Test error of one method drawn at random from a configuration's predicted winners.\n",
    "\n",
    "    x holds the per-method rows of one configuration. When no winner was\n",
    "    predicted, the method is drawn uniformly from `methods` instead.\n",
    "    \"\"\"\n",
    "    pred_winners = x['winners_pred'].iloc[0]\n",
    "    candidates = methods if len(pred_winners) == 0 else pred_winners\n",
    "    pred_winner = np.random.choice(candidates)\n",
    "    return x[x['method_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "\n",
    "# Columns are selected with a list: tuple selection on a groupby is no longer accepted by pandas 2.\n",
    "pred_err = fine_test_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])[['wga_te_err_x', 'method_x', 'winners_pred']].apply(get_pred_wga_te_err)\n",
    "# Mean error of the predicted method vs. mean error of the rank-1 method.\n",
    "print(pred_err.mean())\n",
    "print(test_df['wga_te_err'].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pred_wga_te_err(x):\n",
    "    \"\"\"Mean test error over the ground-truth winners of one configuration.\n",
    "\n",
    "    Despite the name this reads the true winner set ('winners_x'), not a\n",
    "    prediction; note that it replaces the function of the same name defined\n",
    "    in an earlier cell.\n",
    "    \"\"\"\n",
    "    gt_winners = x['winners_x'].iloc[0].split(\"|\")\n",
    "    err = 0.0\n",
    "    for w in gt_winners:\n",
    "        err += x[x['method_x']==w]['wga_te_err_x'].iloc[0]\n",
    "    return err / len(gt_winners)\n",
    "\n",
    "# Columns are selected with a list: tuple selection on a groupby is no longer accepted by pandas 2.\n",
    "pred_err = fine_test_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])[['wga_te_err_x', 'method_x', 'winners_x']].apply(get_pred_wga_te_err)\n",
    "print(pred_err.mean())\n",
    "print(test_df['wga_te_err'].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same label order as the one used to build the 'multi_hot' targets in load_data.\n",
    "methods = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "\n",
    "def jac_sim(list1, list2):\n",
    "    \"\"\"Jaccard similarity |A & B| / |A | B| of two collections, compared as sets.\n",
    "\n",
    "    Two empty collections are treated as identical and give 1.0 instead of\n",
    "    raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    set1, set2 = set(list1), set(list2)\n",
    "    union = len(set1 | set2)\n",
    "    if union == 0:\n",
    "        return 1.0\n",
    "    return len(set1 & set2) / union\n",
    "y_pred = clf.predict(X_test)\n",
    "\n",
    "label_names = np.array(methods)\n",
    "\n",
    "# Turn every multi-hot row into the array of method names it selects.\n",
    "winners_gt = [label_names[np.where(y_ == 1)[0]] for y_ in y_test]\n",
    "winners_pred = [label_names[np.where(y_ == 1)[0]] for y_ in y_pred]\n",
    "\n",
    "# Per-sample Jaccard similarity between true and predicted winner sets.\n",
    "res = [jac_sim(winners_gt[i], winners_pred[i]) for i in range(len(winners_pred))]\n",
    "acc = np.mean(res)\n",
    "print(\"accuracy: \", acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = clf.predict_proba(X_test)\n",
    "pred_rank = y_pred.argsort(axis=1)\n",
    "n_methods = y_pred.shape[1]\n",
    "\n",
    "# te_idx holds row positions in all_df_rank, which has ONE row per configuration,\n",
    "# whereas all_df has one row per (configuration, method). Using te directly as a\n",
    "# row offset into all_df therefore picked rows of unrelated configurations.\n",
    "# NOTE(review): assumes all_df stores every configuration as n_methods consecutive\n",
    "# rows, in the same configuration order as all_df_rank - confirm.\n",
    "real_err = []\n",
    "for te in te_idx:\n",
    "    start = n_methods * te\n",
    "    wga_te_err = all_df.iloc[start:start+n_methods][\"wga_te_err\"].to_numpy()\n",
    "    real_err.append(wga_te_err)\n",
    "\n",
    "pearson = []\n",
    "for i, r in enumerate(real_err):\n",
    "    # ranks[j] is the position of output j when the probabilities are sorted ascending.\n",
    "    ranks = np.empty_like(pred_rank[i])\n",
    "    ranks[pred_rank[i]] = np.arange(n_methods)\n",
    "    curr_pearson = np.corrcoef(r, ranks)[0,1]\n",
    "    pearson.append(curr_pearson)\n",
    "pearson = np.array(pearson)\n",
    "# np.corrcoef yields nan when one of the inputs is constant; those samples are dropped.\n",
    "pearson = pearson[~np.isnan(pearson)]\n",
    "np.mean(pearson)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# def load_data(path):\n",
    "#     all_df = pd.read_csv(path)\n",
    "#     all_df['n'] = all_df['n'].astype('str')\n",
    "\n",
    "#     return all_df\n",
    "\n",
    "def load_data(path, seeds=[0]):\n",
    "    \"\"\"Load the per-seed result CSVs and average them over seeds.\n",
    "\n",
    "    `path` is a template that points at the seed-0 file inside the `toy_v2`\n",
    "    folder. For each seed the actual file name is derived from the template by\n",
    "    replacing `seed0` with `seed{s}`; seeds greater than 1 are read from the\n",
    "    `toy_v3` folder instead of `toy_v2`.\n",
    "\n",
    "    Returns one row per (n, sc, ci, ai, var_causal, d_feat, method) with the\n",
    "    mean of the remaining columns over the seeds.\n",
    "    \"\"\"\n",
    "    all_df_seeds = []\n",
    "    for s in seeds:\n",
    "        folder = 'toy_v3' if s > 1 else 'toy_v2'\n",
    "        # Always start from the ORIGINAL template. Rebinding `path` here would\n",
    "        # erase the 'seed0' placeholder after the first non-zero seed, and every\n",
    "        # later seed would silently re-read that same file.\n",
    "        seed_path = path.replace('toy_v2', folder).replace('seed0', f'seed{s}')\n",
    "\n",
    "        seed_df = pd.read_csv(seed_path)\n",
    "        seed_df['seed'] = s\n",
    "        all_df_seeds.append(seed_df)\n",
    "    all_df_seeds = pd.concat(all_df_seeds)\n",
    "    all_df = all_df_seeds.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat', 'method']).mean().reset_index()\n",
    "\n",
    "    return all_df\n",
    "\n",
    "seeds = [0,2,3]\n",
    "\n",
    "# Single-shift result files: loaded for every seed in `seeds`.\n",
    "sc_df = load_data(\"//exps/div_explore/toy_v2/results_seed0_sc.csv\", seeds)\n",
    "ci_df = load_data(\"//exps/div_explore/toy_v2/results_seed0_ci.csv\", seeds)\n",
    "ai_df = load_data(\"//exps/div_explore/toy_v2/results_seed0_ai.csv\", seeds)\n",
    "\n",
    "# The 3-shift file is loaded with the default seed list only.\n",
    "path = \"//exps/div_explore/toy_v2/results_seed0_3shifts.csv\"\n",
    "mix_df = load_data(path)\n",
    "\n",
    "keep_cols = ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat', 'method', 'wga_te_err']\n",
    "all_df = pd.concat([sc_df, ci_df, ai_df, mix_df], ignore_index=True)\n",
    "all_df = all_df[keep_cols]\n",
    "\n",
    "# Encode the method name as its index in `methods` so it can serve as a numeric feature.\n",
    "methods = np.array([\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"])\n",
    "all_df['method'] = all_df['method'].map(lambda name: np.where(methods==name)[0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.neural_network import MLPRegressor\n",
    "\n",
    "\n",
    "def expand_to_rows(group_ids):\n",
    "    \"\"\"Row positions 5*g .. 5*g+4 for every group id g, concatenated in order.\"\"\"\n",
    "    return np.concatenate([np.arange(5*g, 5*g+5) for g in group_ids])\n",
    "\n",
    "\n",
    "# Split on blocks of 5 consecutive rows so that a block never straddles train and test.\n",
    "n_groups = len(all_df) // 5\n",
    "tr_idx, te_idx = train_test_split(range(n_groups), test_size=0.2, random_state=0)\n",
    "\n",
    "all_df_rank = copy.deepcopy(all_df)\n",
    "all_df_rank_data = all_df_rank[['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat', 'method', 'wga_te_err']]\n",
    "\n",
    "# Every column but the last is a feature; the last one (wga_te_err) is the target.\n",
    "table = all_df_rank_data.to_numpy()\n",
    "all_X = table[:,:-1].astype('float')\n",
    "all_y = np.array(list(table[:, -1]))\n",
    "\n",
    "tr_idx = expand_to_rows(tr_idx)\n",
    "te_idx = expand_to_rows(te_idx)\n",
    "X_train, y_train = all_X[tr_idx], all_y[tr_idx]\n",
    "X_test, y_test = all_X[te_idx], all_y[te_idx]\n",
    "\n",
    "# The scaler is fitted on the training part only and reused for the test part.\n",
    "scaler = StandardScaler()\n",
    "X_train = scaler.fit_transform(X_train)\n",
    "X_test = scaler.transform(X_test)\n",
    "\n",
    "# Feature-selection hook: currently every column is kept.\n",
    "cols = list(range(X_train.shape[1]))\n",
    "X_train, X_test = X_train[:,cols], X_test[:,cols]\n",
    "\n",
    "clf = MLPRegressor(random_state=0, max_iter=10000, verbose=True, tol=1e-3, n_iter_no_change=100, alpha=0.1, hidden_layer_sizes=(100,100,50,50)).fit(X_train, y_train)\n",
    "\n",
    "y_pred = clf.predict(X_test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the selected rows: the column writes below must land in an independent\n",
    "# frame, not in a temporary slice of all_df.\n",
    "test_df = all_df.iloc[te_idx].copy()\n",
    "test_df['pred'] = y_pred\n",
    "\n",
    "# Map the integer method code back to its name (same order as used for the encoding).\n",
    "methods = np.array([\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"])\n",
    "test_df['method'] = test_df['method'].map(lambda x: methods[x])\n",
    "\n",
    "def get_rank(x, col, mode='tie', filter_thre=0.05):\n",
    "    \"\"\"Return the '|'-joined methods whose `col` value lies within filter_thre of the group minimum.\n",
    "\n",
    "    x is a frame with a 'method' column and the score column `col`; `mode` is not used.\n",
    "    \"\"\"\n",
    "    cutoff = x[col].min() + filter_thre\n",
    "    near_best = x.loc[x[col] <= cutoff, \"method\"]\n",
    "    return '|'.join(near_best.to_list())\n",
    "\n",
    "group_cols = ['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat']\n",
    "methods = [\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"]\n",
    "\n",
    "# Same two steps for the true error and for the predicted error:\n",
    "#   1. per configuration, the '|'-joined set of (near-)best methods,\n",
    "#   2. that set as a 0/1 list over `methods`, merged back onto every row.\n",
    "for score_col, winners_col, hot_col in (\n",
    "    ('wga_te_err', 'winners', 'multi_hot'),\n",
    "    ('pred', 'winners_pred', 'multi_hot_pred'),\n",
    "):\n",
    "    winners_series = test_df.groupby(group_cols)[['method', score_col]].apply(lambda x: get_rank(x, score_col)).reset_index(name=winners_col)\n",
    "    test_df = test_df.merge(winners_series, on=group_cols)\n",
    "    test_df[hot_col] = test_df[winners_col].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in methods])\n",
    "\n",
    "\n",
    "def get_all_rank(x, col, mode='tie', filter_thre=0.05):\n",
    "    \"\"\"Rank the 5 values of x[col] from lowest (rank 0) upwards, with ties.\n",
    "\n",
    "    Walking through the sorted values, a value gets the same rank as its\n",
    "    predecessor when the gap between them is below filter_thre; otherwise\n",
    "    the rank increases by one. The ranks are returned in the original row\n",
    "    order of x. `mode` is not used.\n",
    "\n",
    "    If x does not hold exactly 5 rows, [None] * 5 is returned.\n",
    "    \"\"\"\n",
    "    values = x[col].to_list()\n",
    "    if len(values) != 5:\n",
    "        # Explicit replacement for the former raise + bare except: only the\n",
    "        # size mismatch is tolerated, any other error now surfaces.\n",
    "        return [None] * 5\n",
    "\n",
    "    ordered = sorted(values)\n",
    "    ranks = [0]\n",
    "    for prev, cur in zip(ordered, ordered[1:]):\n",
    "        same_tier = cur - prev < filter_thre\n",
    "        ranks.append(ranks[-1] if same_tier else ranks[-1] + 1)\n",
    "\n",
    "    # Equal values share one dict entry; the last occurrence in sorted order wins.\n",
    "    rank_of = dict(zip(ordered, ranks))\n",
    "    return [rank_of[v] for v in values]\n",
    "# Rank the predicted errors inside every configuration; each group yields one list of ranks.\n",
    "ranking_series = test_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])[['method', 'pred']].apply(lambda x: get_all_rank(x, 'pred')).reset_index(name='pred_rank')\n",
    "# Merge the per-configuration rank list back onto every row of that configuration.\n",
    "test_df = test_df.merge(ranking_series, on=['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the test rows together with the prediction, winner and rank columns added above.\n",
    "test_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per block of 5: the multi-hot lists are repeated on all 5 rows of a block.\n",
    "y_pred = test_df.iloc[::5, :]['multi_hot_pred'].to_numpy()\n",
    "y_test = test_df.iloc[::5, :]['multi_hot'].to_numpy()\n",
    "\n",
    "# A sample is correct when every predicted 1 sits on a true 1.\n",
    "correct = sum(\n",
    "    all(true_bit == 1 for pred_bit, true_bit in zip(pred_row, true_row) if pred_bit != 0)\n",
    "    for pred_row, true_row in zip(y_pred, y_test)\n",
    ")\n",
    "acc = correct / y_test.shape[0]\n",
    "print(f\"Eval accuracy: {acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pred_wga_te_err(x):\n",
    "    \"\"\"True test error of one method drawn at random from a configuration's predicted winners.\n",
    "\n",
    "    x holds the per-method rows of one configuration; 'winners_pred' is the\n",
    "    '|'-joined set of predicted winners.\n",
    "    \"\"\"\n",
    "    pred_winners = x['winners_pred'].iloc[0].split(\"|\")\n",
    "    pred_winner = np.random.choice(pred_winners)\n",
    "    return x[x['method']==pred_winner]['wga_te_err'].iloc[0]\n",
    "\n",
    "# Columns are selected with a list: tuple selection on a groupby is no longer accepted by pandas 2.\n",
    "pred_err = test_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])[['wga_te_err', 'method', 'winners_pred']].apply(get_pred_wga_te_err)\n",
    "print(pred_err.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlation between true and predicted error inside every block of 5 rows.\n",
    "n_groups = len(test_df)//5\n",
    "pearson = []\n",
    "for g in range(n_groups):\n",
    "    block = test_df.iloc[5*g:5*g+5]\n",
    "    wga_te_err = block[\"wga_te_err\"].to_numpy()\n",
    "    pred_rank = block[\"pred\"].to_numpy()\n",
    "    pearson.append(np.corrcoef(wga_te_err, pred_rank)[0,1])\n",
    "pearson = np.array(pearson)\n",
    "# Blocks whose correlation is nan are left out of the average.\n",
    "pearson = pearson[~np.isnan(pearson)]\n",
    "np.mean(pearson)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample test_df every 5 rows\n",
    "test_df_dataset = test_df.iloc[::5, :]\n",
    "\n",
    "def jac_sim(list1, list2):\n",
    "        set1, set2 = set(list1), set(list2)\n",
    "        intersection = len(set1.intersection(set2))\n",
    "        union = len(set1.union(set2))\n",
    "        return intersection / union\n",
    "\n",
    "\n",
    "correct = 0\n",
    "jac = []\n",
    "for row in test_df_dataset.iterrows():\n",
    "    # print(row[1]['multi_hot'], row[1]['multi_hot_pred'])\n",
    "    if row[1]['multi_hot'] == row[1]['multi_hot_pred']:\n",
    "        correct += 1\n",
    "    jac.append(jac_sim(row[1]['winners'].split(\"|\"), row[1]['winners_pred'].split(\"|\")))\n",
    "print(\"jaccard:\", np.mean(jac))\n",
    "print(\"0-1 accuracy:\", correct/len(test_df_dataset))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiple binary classifiers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "more powerful, more flexible, more intuitive, more intepretable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def load_data(path, seeds=[0]):\n",
    "\n",
    "    all_df_seeds = []\n",
    "    for s in seeds:\n",
    "        if s > 1: folder = 'toy_v3'\n",
    "        else: folder = 'toy_v2'\n",
    "        # replace the path\n",
    "        path = path.replace('toy_v2', folder)\n",
    "        path = path.replace('seed0', f'seed{s}')\n",
    "\n",
    "    all_df = pd.read_csv(path)\n",
    "    all_df['n'] = all_df['n'].astype('str')\n",
    "    all_df['rank'] = all_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])['wga_te_err'].rank(\"first\")\n",
    "    all_df['wga_te_err_var'] = all_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])['wga_te_err'].transform('var')\n",
    "\n",
    "    def get_gt_rank(x, mode='tie', filter_thre=0.05):\n",
    "        # x is a dataframe with columns ['method', 'wga_te_err'], find the method(s) with smallest wga_te_err, with filter_thre tolerance for tie\n",
    "        min_err = x['wga_te_err'].min()\n",
    "        winners = x[x['wga_te_err'] <= min_err + filter_thre][\"method\"].to_list()\n",
    "        return '|'.join(winners)\n",
    "\n",
    "    # Apply the function to each group and reset the index to merge back with the original dataframe\n",
    "    winners_series = all_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])[['method', 'wga_te_err']].apply(lambda x: get_gt_rank(x)).reset_index(name='winners')\n",
    "\n",
    "    # Merge the results back with the original dataframe\n",
    "    all_df = all_df.merge(winners_series, on=['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])\n",
    "\n",
    "    methods = [\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"]\n",
    "    all_df[\"multi_hot\"] = all_df[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in methods])\n",
    "\n",
    "    return all_df\n",
    "\n",
    "seeds = [0,2,3]\n",
    "path = f\"//exps/div_explore/toy_v2/results_seed0_sc.csv\"\n",
    "sc_df = load_data(path, seeds)\n",
    "\n",
    "path = f\"//exps/div_explore/toy_v2/results_seed0_ci.csv\"\n",
    "ci_df = load_data(path, seeds)\n",
    "\n",
    "path = f\"//exps/div_explore/toy_v2/results_seed0_ai.csv\"\n",
    "ai_df = load_data(path, seeds)\n",
    "\n",
    "path = \"//exps/div_explore/toy_v2/results_seed0_3shifts.csv\"\n",
    "mix_df = load_data(path)\n",
    "\n",
    "all_df = pd.concat([sc_df, ci_df, ai_df, mix_df], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert to multiple binary classification\n",
    "def prepare_data(n, tr_size):\n",
    "    from sklearn.model_selection import train_test_split\n",
    "    tr_idx, te_idx = train_test_split(range(n), test_size=0.2, random_state=0)\n",
    "\n",
    "    if tr_size > 0:\n",
    "        num = len(tr_idx)\n",
    "        te_size = (num - tr_size)/num\n",
    "        tr_idx, val_idx = train_test_split(tr_idx, test_size=te_size, random_state=0)\n",
    "        print(\"train size:\", len(tr_idx))\n",
    "\n",
    "    return tr_idx, te_idx\n",
    "\n",
    "def train_binary_classifier(all_df, tr_idx, te_idx, method1, method2, filter_thre=0.05, mode=\"filter\"):\n",
    "    # mode can be tie\n",
    "    import copy\n",
    "    import warnings\n",
    "    from sklearn.preprocessing import StandardScaler\n",
    "    from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "    warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "    all_df_m1 = all_df[all_df[\"method\"]==method1]\n",
    "    all_df_m2 = all_df[all_df[\"method\"]==method2]\n",
    "    all_df_m1[\"loss_diff\"] = all_df_m1['wga_te_err'].to_numpy() - all_df_m2['wga_te_err'].to_numpy()\n",
    "\n",
    "    all_df = copy.deepcopy(all_df_m1)\n",
    "\n",
    "    if mode == \"filter\":\n",
    "        all_df['over_rank'] = all_df['loss_diff'] < 0\n",
    "    elif mode ==\"tie\":\n",
    "        def lossdiff2rank(x):\n",
    "            if x < -filter_thre:\n",
    "                return 1\n",
    "            elif x > filter_thre:\n",
    "                return 0\n",
    "            else:\n",
    "                return 2\n",
    "        all_df['over_rank'] = all_df['loss_diff'].map(lambda x: lossdiff2rank(x))\n",
    "    all_df['over_rank'] = all_df['over_rank'].astype('int')\n",
    "\n",
    "    all_df_tr = all_df.iloc[tr_idx]\n",
    "    all_df_te = all_df.iloc[te_idx]\n",
    "\n",
    "    if mode == \"filter\":\n",
    "        all_df_tr = all_df_tr[(all_df_tr['loss_diff'] < -filter_thre)|(all_df_tr['loss_diff'] > filter_thre)]\n",
    "        all_df_te = all_df_te[(all_df_te['loss_diff'] < -filter_thre)|(all_df_te['loss_diff'] > filter_thre)]\n",
    "\n",
    "    all_df_tr = all_df_tr[['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat', 'over_rank']]\n",
    "    all_df_te = all_df_te[['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat', 'over_rank']]\n",
    "\n",
    "    X_train = all_df_tr.to_numpy()[:,:-1].astype('float')\n",
    "    y_train = all_df_tr.to_numpy()[:,-1].astype('int')\n",
    "    X_test = all_df_te.to_numpy()[:,:-1].astype('float')\n",
    "    y_test = all_df_te.to_numpy()[:,-1].astype('int')\n",
    "\n",
    "    scaler = StandardScaler()\n",
    "    X_train = scaler.fit_transform(X_train)\n",
    "    X_test = scaler.transform(X_test)\n",
    "\n",
    "    # clf = MLPClassifier(random_state=1, max_iter=1000000, verbose=False, tol=5e-3, n_iter_no_change=10000, alpha=0.0001, hidden_layer_sizes=(100,10,)).fit(X_train, y_train)\n",
    "    clf = MLPClassifier(random_state=1, max_iter=1000000, verbose=False, tol=1e-3, n_iter_no_change=2000, alpha=0.01, hidden_layer_sizes=(100,10,)).fit(X_train, y_train)\n",
    "    train_acc = clf.score(X_train, y_train)\n",
    "    test_acc = clf.score(X_test, y_test)\n",
    "    print(method1, method2, \"train size: \", len(all_df_tr), \"train acc: \", train_acc, \"test acc: \", test_acc)\n",
    "    # compute test acc by class\n",
    "    for i in range(3):\n",
    "        te_idx = y_test == i\n",
    "        tr_idx = y_train == i\n",
    "        # print y distribution\n",
    "        print(\"      class\", i, \"train size: \", len(y_train[tr_idx]))\n",
    "        if len(y_test[te_idx]) == 0:\n",
    "            continue\n",
    "        print(\"      class\", i, \"test acc: \", clf.score(X_test[te_idx], y_test[te_idx]))\n",
    "    # print()\n",
    "\n",
    "    return clf, X_test\n",
    "\n",
    "def extract_pairwise_results(zoo, X_test):\n",
    "    res = []\n",
    "    # if 1d then expand to 2d\n",
    "    if len(X_test.shape) == 1:\n",
    "        X_test = X_test.reshape(1, -1)\n",
    "    for (m1, m2), clf in zoo.items():\n",
    "        y_pred = clf.predict(X_test)\n",
    "        # print(y_pred)\n",
    "        if y_pred == 1:\n",
    "            res.append((m1, m2, m1))\n",
    "        elif y_pred == 0:\n",
    "            res.append((m1, m2, m2))\n",
    "        else:\n",
    "            res.append((m1, m2, m1))\n",
    "            res.append((m1, m2, m2))\n",
    "    # print()\n",
    "    # print(res)\n",
    "    return res\n",
    "\n",
    "def copeland_method(candidate_names, pairwise_results):\n",
    "    # Initialize scores dictionary\n",
    "    scores = {name: 0 for name in candidate_names}\n",
    "\n",
    "    # Update scores based on pairwise results\n",
    "    for result in pairwise_results:\n",
    "        A, B, winner = result\n",
    "        if winner == A:\n",
    "            scores[A] += 1\n",
    "            scores[B] -= 1\n",
    "        elif winner == B:\n",
    "            scores[B] += 1\n",
    "            scores[A] -= 1\n",
    "    print(scores)\n",
    "    # # Generate the ranking\n",
    "    # ranking = sorted(candidate_names, key=lambda x: scores[x], reverse=True)\n",
    "    # get the ones with the highest score, note that some can have the same score, and we want all them\n",
    "    winners = [name for name in candidate_names if scores[name] == max(scores.values())]\n",
    "\n",
    "    return winners, scores\n",
    "\n",
    "def bradley_terry_method(candidate_names, pairwise_results):\n",
    "    from scipy.optimize import minimize\n",
    "\n",
    "    n = len(candidate_names)\n",
    "    candidate_index = {name: i for i, name in enumerate(candidate_names)}\n",
    "\n",
    "    # Initialize ability scores\n",
    "    abilities = np.zeros(n)\n",
    "\n",
    "    def log_likelihood(abilities):\n",
    "        ll = 0\n",
    "        for A, B, winner in pairwise_results:\n",
    "            i, j = candidate_index[A], candidate_index[B]\n",
    "            pi = 1 / (1 + np.exp(abilities[j] - abilities[i]))\n",
    "            if winner == A:\n",
    "                ll += np.log(pi)\n",
    "            elif winner == B:\n",
    "                ll += np.log(1 - pi)\n",
    "        return -ll\n",
    "\n",
    "    result = minimize(log_likelihood, abilities, method='BFGS')\n",
    "    abilities = result.x\n",
    "\n",
    "    # Generate the ranking\n",
    "    ranking = sorted(candidate_names, key=lambda x: abilities[candidate_index[x]], reverse=True)\n",
    "\n",
    "    return ranking[0] # , abilities\n",
    "\n",
    "def binary_classifiers(all_df, tr_size=-1, mode=\"filter\"):\n",
    "    import itertools\n",
    "    methods = np.array([\"ERM\", \"GroupDRO\", \"oversample\", \"undersample\", \"remax-margin\"])\n",
    "    combinations = np.array(list(itertools.combinations(methods, 2)))\n",
    "\n",
    "    num_exps = all_df[all_df['method']==methods[0]].shape[0]\n",
    "    tr_idx, te_idx = prepare_data(num_exps, tr_size)\n",
    "\n",
    "    zoo = {}\n",
    "    for m1, m2 in combinations:\n",
    "        zoo[(m1, m2)], X_test = train_binary_classifier(all_df, tr_idx, te_idx, m1, m2, filter_thre=0.05, mode=mode)\n",
    "    return zoo, te_idx, X_test, methods\n",
    "\n",
    "def ranking_acc(methods, zoo, X_test, y_test, mode=\"copeland\", acc_mode=\"jac\"):\n",
    "\n",
    "    def jac_sim(list1, list2):\n",
    "        set1, set2 = set(list1), set(list2)\n",
    "        intersection = len(set1.intersection(set2))\n",
    "        union = len(set1.union(set2))\n",
    "        return intersection / union\n",
    "\n",
    "    if mode == \"copeland\":\n",
    "        eval_fn = copeland_method\n",
    "    elif mode == \"bradley_terry\":\n",
    "        eval_fn = bradley_terry_method\n",
    "    else:\n",
    "        raise ValueError(f\"unknown mode {mode}\")\n",
    "\n",
    "    y_pred, scores = [], []\n",
    "    for X in X_test:\n",
    "        pairwise_results = extract_pairwise_results(zoo, X)\n",
    "        winners, s = eval_fn(methods, pairwise_results)\n",
    "        y_pred.append(winners)\n",
    "        scores.append(s)\n",
    "\n",
    "    if acc_mode == \"jac\":\n",
    "        res = []\n",
    "        for i in range(len(y_pred)):\n",
    "            res.append(jac_sim(y_test[i], y_pred[i]))\n",
    "        acc = np.mean(res)\n",
    "        print(f\"{mode} method accuracy: \", acc)\n",
    "    elif acc_mode == \"accuracy\":\n",
    "        correct = 0\n",
    "        for i in range(len(y_pred)):\n",
    "            if set(y_test[i]) == set(y_pred[i]):\n",
    "                correct += 1\n",
    "        acc = correct / len(y_test)\n",
    "        print(f\"{mode} method accuracy: \", acc)\n",
    "    elif acc_mode == \"all\":\n",
    "        res = []\n",
    "        for i in range(len(y_pred)):\n",
    "            res.append(jac_sim(y_test[i], y_pred[i]))\n",
    "        acc = np.mean(res)\n",
    "        print(f\"{mode} method accuracy: \", acc)\n",
    "\n",
    "        correct = 0\n",
    "        for i in range(len(y_pred)):\n",
    "            if set(y_test[i]) == set(y_pred[i]):\n",
    "                correct += 1\n",
    "        acc = correct / len(y_test)\n",
    "        print(f\"{mode} method accuracy: \", acc)\n",
    "    elif acc_mode == \"soft 0-1\":\n",
    "        correct = 0\n",
    "        for i, curr_y in enumerate(y_test):\n",
    "            curr_y = np.array(curr_y)\n",
    "            curr_pred =np.array(y_pred[i])\n",
    "            # # check if the positions of 1s in curr_pred are also 1s in curr_y\n",
    "            # pos = np.nonzero(curr_pred)[0]\n",
    "            # print(\"1\", curr_y)\n",
    "            # print(\"2\", curr_pred)\n",
    "            # print(\"3\", pos)\n",
    "            # (array([0, 1, 2, 3, 4]),)\n",
    "            # ['GroupDRO']\n",
    "            # ['GroupDRO' 'oversample' 'undersample']\n",
    "            # (array([0, 1, 2]),)\n",
    "            if np.all(np.isin(curr_pred, curr_y)):\n",
    "                correct += 1\n",
    "        acc = correct / y_test.shape[0]\n",
    "        print(f\"{mode} method accuracy: {acc}\")\n",
    "    else:\n",
    "        raise ValueError(f\"unknown acc_mode {acc_mode}\")\n",
    "\n",
    "    return y_pred, scores\n",
    "\n",
    "    # # correct = 0\n",
    "    # # for i in range(len(y_pred)):\n",
    "    # #     if y_test[i] in y_pred[i]:\n",
    "    # #         correct += 1\n",
    "    # # acc = correct / len(y_test)\n",
    "    # # print(f\"{mode} method accuracy: \", acc)\n",
    "    # res = []\n",
    "    # for i in range(len(y_pred)):\n",
    "    #     res.append(jac_sim(y_test[i], y_pred[i]))\n",
    "    # acc = np.mean(res)\n",
    "    # print(f\"{mode} method accuracy: \", acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "zoo, te_idx, X_test, methods = binary_classifiers(all_df, tr_size=-1, mode=\"tie\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "zoo, te_idx, X_test, methods = binary_classifiers(all_df, mode=\"tie\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# y_test = all_df_rank[\"method\"].to_numpy()[te_idx]\n",
    "all_df_rank = all_df[all_df[\"rank\"]==1.0]\n",
    "winners = all_df_rank[\"winners\"].to_numpy()[te_idx]\n",
    "y_test = np.array([w.split(\"|\") for w in winners])\n",
    "\n",
    "# filter_idx = []\n",
    "# for i in range(len(y_test)):\n",
    "#     if len(y_test[i]) != 5:\n",
    "#         filter_idx.append(i)\n",
    "# y_test = y_test[filter_idx]\n",
    "# X_test = X_test[filter_idx]\n",
    "\n",
    "y_pred, scores = ranking_acc(methods, zoo, X_test, y_test, mode=\"copeland\", acc_mode=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred, scores = ranking_acc(methods, zoo, X_test, y_test, mode=\"copeland\", acc_mode=\"soft 0-1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = [\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"]\n",
    "test_df = all_df_rank.iloc[te_idx]\n",
    "eval_fn = copeland_method\n",
    "y_pred, scores = [], []\n",
    "for X in X_test:\n",
    "    pairwise_results = extract_pairwise_results(zoo, X)\n",
    "    winners, s = eval_fn(methods, pairwise_results)\n",
    "    y_pred.append(winners)\n",
    "    scores.append(s)\n",
    "\n",
    "test_df['winners_pred'] = y_pred\n",
    "fine_test_df = all_df.merge(test_df, on=['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])\n",
    "\n",
    "def get_pred_wga_te_err(x):\n",
    "    pred_winners = x['winners_pred'].iloc[0]\n",
    "    if len(pred_winners) == 0:\n",
    "        # randomly select one\n",
    "        pred_winner = np.random.choice(methods)\n",
    "    else:\n",
    "        # random select from the predicted\n",
    "        pred_winner = np.random.choice(pred_winners)\n",
    "    return x[x['method_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "pred_err = fine_test_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])['wga_te_err_x', 'method_x' ,'winners_pred'].apply(lambda x: get_pred_wga_te_err(x))\n",
    "pred_err.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Random Baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def load_data(path, seeds=[0]):\n",
    "    all_df_seeds = []\n",
    "    for s in seeds:\n",
    "        if s > 1: folder = 'toy_v3'\n",
    "        else: folder = 'toy_v2'\n",
    "        # replace the path\n",
    "        path = path.replace('toy_v2', folder)\n",
    "        path = path.replace('seed0', f'seed{s}')\n",
    "\n",
    "        all_df = pd.read_csv(path)\n",
    "        all_df['seed'] = s\n",
    "        # all_df['n'] = all_df['n'].astype('str')\n",
    "        all_df_seeds.append(all_df)\n",
    "    all_df_seeds = pd.concat(all_df_seeds)\n",
    "    all_df = all_df_seeds.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat', 'method']).mean().reset_index()\n",
    "    all_df['rank'] = all_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])['wga_te_err'].rank(\"first\")\n",
    "    all_df['wga_te_err_var'] = all_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])['wga_te_err'].transform('var')\n",
    "\n",
    "\n",
    "    def get_gt_rank(x, mode='tie', filter_thre=0.05):\n",
    "        # x is a dataframe with columns ['method', 'wga_te_err'], find the method(s) with smallest wga_te_err, with filter_thre tolerance for tie\n",
    "        min_err = x['wga_te_err'].min()\n",
    "        winners = x[x['wga_te_err'] <= min_err + filter_thre][\"method\"].to_list()\n",
    "        return '|'.join(winners)\n",
    "\n",
    "    # Apply the function to each group and reset the index to merge back with the original dataframe\n",
    "    winners_series = all_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])[['method', 'wga_te_err']].apply(lambda x: get_gt_rank(x)).reset_index(name='winners')\n",
    "\n",
    "    # Merge the results back with the original dataframe\n",
    "    all_df = all_df.merge(winners_series, on=['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])\n",
    "\n",
    "    # methods = [\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"]\n",
    "    methods = ['ERM', 'GroupDRO', 'oversample', 'remax-margin', 'undersample']\n",
    "    all_df[\"multi_hot\"] = all_df[\"winners\"].map(lambda x: [1 if m in x.split(\"|\") else 0 for m in methods])\n",
    "\n",
    "    return all_df\n",
    "\n",
    "seeds = [0,2,3]\n",
    "path = f\"//exps/div_explore/toy_v2/results_seed0_sc.csv\"\n",
    "sc_df = load_data(path, seeds)\n",
    "\n",
    "path = f\"//exps/div_explore/toy_v2/results_seed0_ci.csv\"\n",
    "ci_df = load_data(path, seeds)\n",
    "\n",
    "path = f\"//exps/div_explore/toy_v2/results_seed0_ai.csv\"\n",
    "ai_df = load_data(path, seeds)\n",
    "\n",
    "path = \"//exps/div_explore/toy_v2/results_seed0_3shifts.csv\"\n",
    "mix_df = load_data(path)\n",
    "\n",
    "all_df = pd.concat([sc_df, ci_df, ai_df, mix_df], ignore_index=True)\n",
    "# all_df = all_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat', 'method']).first().reset_index()\n",
    "# all_df = pd.concat([sc_df, ci_df, ai_df], ignore_index=True)\n",
    "# all_df = sc_df\n",
    "all_df_rank = all_df[all_df[\"rank\"]==1.0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_df_rank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jac_sim(list1, list2):\n",
    "    set1, set2 = set(list1), set(list2)\n",
    "    intersection = len(set1.intersection(set2))\n",
    "    union = len(set1.union(set2))\n",
    "    return intersection / union\n",
    "\n",
    "def baseline_pred(mode):\n",
    "    methods = np.array([\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"])\n",
    "    if mode == \"real_random\":\n",
    "        num_winners = np.random.choice(5, 1, replace=False)[0] + 1\n",
    "        y_pred = np.random.choice(5, num_winners, replace=False)\n",
    "    elif mode == \"enhanced_random\":\n",
    "        # num_winners = sum(y)\n",
    "        y_pred = np.random.choice(5, 5, replace=False)\n",
    "    elif mode == \"global_rank_baseline\":\n",
    "        num_winners = np.random.choice(5, 1, replace=False)[0] + 1\n",
    "        y_pred = np.argsort(global_rank)[::-1][:num_winners]\n",
    "    elif mode == \"global_rank_baseline_random\":\n",
    "        num_winners = np.random.choice(5, 1, replace=False)[0] + 1\n",
    "        # print(num_winners)\n",
    "        y_pred = np.argsort(global_rank)[::-1][:num_winners]\n",
    "    else:\n",
    "        raise ValueError(f\"unknown mode {mode}\")\n",
    "    y_pred_str = methods[y_pred].tolist()\n",
    "    return y_pred, y_pred_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seeds = [2023, 2024, 2025]\n",
    "\n",
    "mode = \"global_rank_baseline\"\n",
    "acc01, err, soft_acc01 = [], [], []\n",
    "for s in seeds:\n",
    "    from sklearn.preprocessing import StandardScaler\n",
    "    from sklearn.model_selection import train_test_split\n",
    "    train_df, test_df, tr_idx, te_idx = train_test_split(all_df_rank, range(len(all_df_rank)), test_size=0.2, random_state=0)\n",
    "    test_df = all_df_rank.iloc[te_idx]\n",
    "\n",
    "    gt = train_df[\"multi_hot\"].to_list()\n",
    "    global_rank = np.array(gt).sum(axis=0)\n",
    "    methods = [\"ERM\", \"GroupDRO\", \"remax-margin\", \"oversample\", \"undersample\"]\n",
    "    global_rank\n",
    "\n",
    "    # num = 0\n",
    "    jac, correct, soft_correct, pearson = [], [], [], []\n",
    "\n",
    "    y_preds = []\n",
    "    for y in test_df[\"multi_hot\"]:\n",
    "        # print(y)\n",
    "        # if sum(y) == 5:\n",
    "        #     continue\n",
    "        # else:\n",
    "        #     num+=1\n",
    "\n",
    "        y_pred, y_pred_str = baseline_pred(mode)\n",
    "        y = np.where(np.array(y)==1)[0]\n",
    "        # print(y_pred)\n",
    "        jac.append(jac_sim(y, y_pred))\n",
    "\n",
    "        if set(y) == set(y_pred):\n",
    "            correct.append(1)\n",
    "        else:\n",
    "            correct.append(0)\n",
    "\n",
    "        # check if the positions of 1s in curr_pred are also 1s in curr_y\n",
    "        curr_y = np.zeros(5)\n",
    "        curr_y[y] = 1\n",
    "        curr_pred = np.zeros(5)\n",
    "        curr_pred[y_pred] = 1\n",
    "        pos = np.nonzero(curr_pred)\n",
    "        if np.all(curr_y[pos] == 1):\n",
    "            soft_correct.append(1)\n",
    "        else:\n",
    "            soft_correct.append(0)\n",
    "\n",
    "        y_preds.append(curr_pred)\n",
    "    test_df[\"pred\"] = y_preds\n",
    "    # print(num)\n",
    "    # acc = np.mean(jac)\n",
    "    # print(f\"random baseline jac accuracy: \", acc)\n",
    "    acc = np.mean(correct)\n",
    "    # print(f\"random baseline 0-1 accuracy: \", acc)\n",
    "    soft_acc = np.mean(soft_correct)\n",
    "    acc01.append(acc)\n",
    "    soft_acc01.append(soft_acc)\n",
    "\n",
    "    fine_test_df = all_df.merge(test_df, on=['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])\n",
    "    def get_pred_wga_te_err(x, mode):\n",
    "        pred_winners = baseline_pred(mode)[1]\n",
    "        pred_winner = np.random.choice(pred_winners)\n",
    "        return x[x['method_x']==pred_winner]['wga_te_err_x'].iloc[0]\n",
    "        # err = 0.0\n",
    "        # for w in pred_winners:\n",
    "        #     err += x[x['method_x']==w]['wga_te_err_x'].iloc[0]\n",
    "        # err /= len(pred_winners)\n",
    "        # return err\n",
    "    pred_err = fine_test_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])['wga_te_err_x', 'method_x'].apply(lambda x: get_pred_wga_te_err(x, mode))\n",
    "    # print(pred_err.mean())\n",
    "    # print(test_df['wga_te_err'].mean())\n",
    "    err.append(pred_err.mean())\n",
    "\n",
    "            # pearson = []\n",
    "        # for i in range(len(test_df)//5):\n",
    "        #     wga_te_err = test_df.iloc[5*i:5*i+5][\"wga_te_err\"].to_numpy()\n",
    "        #     pred_rank = test_df.iloc[5*i:5*i+5][\"pred\"].to_numpy()\n",
    "        #     curr_pearson = np.corrcoef(wga_te_err, pred_rank)[0,1]\n",
    "        #     pearson.append(curr_pearson)\n",
    "        # pearson = np.array(pearson)\n",
    "        # pearson = pearson[~np.isnan(pearson)]\n",
    "        # np.mean(pearson)\n",
    "    def get_pearson_corr(x):\n",
    "        wga_te_err = x['wga_te_err_x'].iloc[:5]\n",
    "        # pred_rank = x['pred'].iloc[0]\n",
    "        pred_rank = [4, 0, 3, 1, 2]\n",
    "        # print(wga_te_err)\n",
    "        # print(pred_rank)\n",
    "        curr_pearson = np.corrcoef(wga_te_err, pred_rank)[0,1]\n",
    "        return curr_pearson\n",
    "    pearson_corr = fine_test_df.groupby(['n', 'sc', 'ci', 'ai', 'var_causal', 'd_feat'])['wga_te_err_x', 'pred'].apply(lambda x: get_pearson_corr(x))\n",
    "    pearson.append(pearson_corr.mean())\n",
    "\n",
    "print(np.mean(acc01), np.std(acc01), np.mean(err), np.std(err), np.mean(soft_acc01), np.std(soft_acc01), np.mean(pearson), np.std(pearson))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "div_backup",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
