{
  "Selected_candidate": {
    "pr_number": 2368,
    "pr_title": "Fix pairgrid off-diagonal plots with non-string column names",
    "pr_body": "Fixes #2307\r\n\r\nWorking reprex from original issue:\r\n\r\n![image](https://user-images.githubusercontent.com/315810/100743239-108c0200-33aa-11eb-8885-9c61d89b3acc.png)\r\n",
    "issue_id": 2307,
    "issue_title": "map_* methods of PairGrid broken in 0.11.0",
    "issue_body": "Hi,\r\n\r\nI discovered that the map_* methods of PairGrid seem to be broken in version 0.11.0 for user defined functions. See reproducible example below with a corrfunc defined to plot the pearson correlation value on the lower plots. The function doesn't seem to get evaluated in version 0.11.0. When I pip install seaborn==0.10.1, I get the desired result. Plots from both cases also attached.\r\n\r\n```import numpy as np\r\nfrom scipy import stats\r\nimport pandas as pd\r\nimport seaborn as sns\r\nimport matplotlib.pyplot as plt\r\nsns.set(style=\"white\")\r\n\r\nmean = np.zeros(3)\r\ncov = np.random.uniform(.2, .4, (3, 3))\r\ncov += cov.T\r\ncov[np.diag_indices(3)] = 1\r\ndata = np.random.multivariate_normal(mean, cov, 100)\r\ndf = pd.DataFrame(data, columns=[\"X\", \"Y\", \"Z\"])\r\n\r\ndef corrfunc(x, y,**kws):\r\n    r, _ = stats.pearsonr(x, y)\r\n    ax = plt.gca()\r\n    ax.annotate(\"r = {:.2f}\".format(r),\r\n                xy=(.1, .9), xycoords=ax.transAxes)\r\n\r\ng = sns.PairGrid(df, palette=[\"red\"])\r\ng.map_upper(plt.scatter, s=10)\r\ng.map_diag(sns.distplot, kde=False)\r\ng.map_lower(sns.kdeplot, cmap=\"Blues_d\")\r\ng.map_lower(corrfunc)\r\nplt.show()\r\n\r\n```\r\n\r\n![seaborn-0-11-0](https://user-images.githubusercontent.com/3239171/94969718-380d3e00-04d1-11eb-821b-9aad80ec696e.png)\r\n\r\n![seaborn-0-10-1](https://user-images.githubusercontent.com/3239171/94969722-3a6f9800-04d1-11eb-8ef1-861d9beb1f26.png)\r\n\r\n\r\n\r\n",
    "issue_closed_at": "2020-12-01T20:48:45Z",
    "base_commit": "2717408b564994002fe08f72ba2dd7e1acf359b6",
    "changes": [
      {
        "file": "seaborn/axisgrid.py",
        "type": "function",
        "name": "__init__",
        "class_name": "JointGrid",
        "code": "def __init__(\n        self, *,\n        x=None, y=None,\n        data=None,\n        height=6, ratio=5, space=.2,\n        dropna=False, xlim=None, ylim=None, size=None, marginal_ticks=False,\n        hue=None, palette=None, hue_order=None, hue_norm=None,\n    ):\n        # Handle deprecations\n        if size is not None:\n            height = size\n            msg = (\"The `size` parameter has been renamed to `height`; \"\n                   \"please update your code.\")\n            warnings.warn(msg, UserWarning)\n\n        # Set up the subplot grid\n        f = plt.figure(figsize=(height, height))\n        gs = plt.GridSpec(ratio + 1, ratio + 1)\n\n        ax_joint = f.add_subplot(gs[1:, :-1])\n        ax_marg_x = f.add_subplot(gs[0, :-1], sharex=ax_joint)\n        ax_marg_y = f.add_subplot(gs[1:, -1], sharey=ax_joint)\n\n        self.fig = f\n        self.ax_joint = ax_joint\n        self.ax_marg_x = ax_marg_x\n        self.ax_marg_y = ax_marg_y\n\n        # Turn off tick visibility for the measure axis on the marginal plots\n        plt.setp(ax_marg_x.get_xticklabels(), visible=False)\n        plt.setp(ax_marg_y.get_yticklabels(), visible=False)\n        plt.setp(ax_marg_x.get_xticklabels(minor=True), visible=False)\n        plt.setp(ax_marg_y.get_yticklabels(minor=True), visible=False)\n\n        # Turn off the ticks on the density axis for the marginal plots\n        if not marginal_ticks:\n            plt.setp(ax_marg_x.yaxis.get_majorticklines(), visible=False)\n            plt.setp(ax_marg_x.yaxis.get_minorticklines(), visible=False)\n            plt.setp(ax_marg_y.xaxis.get_majorticklines(), visible=False)\n            plt.setp(ax_marg_y.xaxis.get_minorticklines(), visible=False)\n            plt.setp(ax_marg_x.get_yticklabels(), visible=False)\n            plt.setp(ax_marg_y.get_xticklabels(), visible=False)\n            plt.setp(ax_marg_x.get_yticklabels(minor=True), visible=False)\n            plt.setp(ax_marg_y.get_xticklabels(minor=True), visible=False)\n            ax_marg_x.yaxis.grid(False)\n            ax_marg_y.xaxis.grid(False)\n\n        # Process the input variables\n        p = VectorPlotter(data=data, variables=dict(x=x, y=y, hue=hue))\n        plot_data = p.plot_data.loc[:, p.plot_data.notna().any()]\n\n        # Possibly drop NA\n        if dropna:\n            plot_data = plot_data.dropna()\n\n        def get_var(var):\n            vector = plot_data.get(var, None)\n            if vector is not None:\n                vector = vector.rename(p.variables.get(var, None))\n            return vector\n\n        self.x = get_var(\"x\")\n        self.y = get_var(\"y\")\n        self.hue = get_var(\"hue\")\n\n        for axis in \"xy\":\n            name = p.variables.get(axis, None)\n            if name is not None:\n                getattr(ax_joint, f\"set_{axis}label\")(name)\n\n        if xlim is not None:\n            ax_joint.set_xlim(xlim)\n        if ylim is not None:\n            ax_joint.set_ylim(ylim)\n\n        # Store the semantic mapping parameters for axes-level functions\n        self._hue_params = dict(palette=palette, hue_order=hue_order, hue_norm=hue_norm)\n\n        # Make the grid look nice\n        utils.despine(f)\n        if not marginal_ticks:\n            utils.despine(ax=ax_marg_x, left=True)\n            utils.despine(ax=ax_marg_y, bottom=True)\n        for axes in [ax_marg_x, ax_marg_y]:\n            for axis in [axes.xaxis, axes.yaxis]:\n                axis.label.set_visible(False)\n        f.tight_layout()\n        f.subplots_adjust(hspace=space, wspace=space)"
      },
      {
        "file": "seaborn/axisgrid.py",
        "type": "function",
        "name": "_map_diag_iter_hue",
        "class_name": "PairGrid",
        "code": "def _map_diag_iter_hue(self, func, **kwargs):\n        \"\"\"Put marginal plot on each diagonal axes, iterating over hue.\"\"\"\n        # Plot on each of the diagonal axes\n        fixed_color = kwargs.pop(\"color\", None)\n\n        for var, ax in zip(self.diag_vars, self.diag_axes):\n            hue_grouped = self.data[var].groupby(self.hue_vals)\n\n            plt.sca(ax)\n\n            for k, label_k in enumerate(self.hue_names):\n\n                # Attempt to get data for this level, allowing for empty\n                try:\n                    data_k = hue_grouped.get_group(label_k)\n                except KeyError:\n                    data_k = pd.Series([], dtype=float)\n\n                if fixed_color is None:\n                    color = self.palette[k]\n                else:\n                    color = fixed_color\n\n                if self._dropna:\n                    data_k = utils.remove_na(data_k)\n\n                if str(func.__module__).startswith(\"seaborn\"):\n                    func(x=data_k, label=label_k, color=color, **kwargs)\n                else:\n                    func(data_k, label=label_k, color=color, **kwargs)\n\n            self._clean_axis(ax)\n\n        self._add_axis_labels()\n\n        return self"
      },
      {
        "file": "seaborn/axisgrid.py",
        "type": "function",
        "name": "_plot_bivariate_iter_hue",
        "class_name": "PairGrid",
        "code": "def _plot_bivariate_iter_hue(self, x_var, y_var, ax, func, **kwargs):\n        \"\"\"Draw a bivariate plot while iterating over hue subsets.\"\"\"\n        plt.sca(ax)\n        if x_var == y_var:\n            axes_vars = [x_var]\n        else:\n            axes_vars = [x_var, y_var]\n\n        hue_grouped = self.data.groupby(self.hue_vals)\n        for k, label_k in enumerate(self.hue_names):\n\n            kws = kwargs.copy()\n\n            # Attempt to get data for this level, allowing for empty\n            try:\n                data_k = hue_grouped.get_group(label_k)\n            except KeyError:\n                data_k = pd.DataFrame(columns=axes_vars,\n                                      dtype=float)\n\n            if self._dropna:\n                data_k = data_k[axes_vars].dropna()\n\n            x = data_k[x_var]\n            y = data_k[y_var]\n\n            for kw, val_list in self.hue_kws.items():\n                kws[kw] = val_list[k]\n            kws.setdefault(\"color\", self.palette[k])\n            if self._hue_var is not None:\n                kws[\"label\"] = label_k\n\n            if str(func.__module__).startswith(\"seaborn\"):\n                func(x=x, y=y, **kws)\n            else:\n                func(x, y, **kws)\n\n        self._update_legend_data(ax)\n        self._clean_axis(ax)"
      }
    ]
  },
  "Justification": "Candidate C is the most relevant to the CURRENT bug because it directly involves the `pairplot` function and shares the same file (`seaborn/axisgrid.py`). The bug report identifies an issue with handling data structures in the `PairGrid`, which is a close structural match to how `pairplot` operates. The errors in handling column names (e.g., non-string or MultiIndex) are likely to provide insights into the KeyError experienced in the CURRENT bug, particularly as both include issues related to DataFrame handling. This strong structural similarity, module relevance, and shared context of fixing plotting methods bolster its helpfulness for debugging the CURRENT issue.",
  "instance_id": "mwaskom__seaborn-3407",
  "repo": "mwaskom/seaborn",
  "created_at": "2023-06-27T23:17:29Z",
  "problem_statement": "pairplot raises KeyError with MultiIndex DataFrame\nWhen trying to pairplot a MultiIndex DataFrame, `pairplot` raises a `KeyError`:\r\n\r\nMRE:\r\n\r\n```python\r\nimport numpy as np\r\nimport pandas as pd\r\nimport seaborn as sns\r\n\r\n\r\ndata = {\r\n    (\"A\", \"1\"): np.random.rand(100),\r\n    (\"A\", \"2\"): np.random.rand(100),\r\n    (\"B\", \"1\"): np.random.rand(100),\r\n    (\"B\", \"2\"): np.random.rand(100),\r\n}\r\ndf = pd.DataFrame(data)\r\nsns.pairplot(df)\r\n```\r\n\r\nOutput:\r\n\r\n```\r\n[c:\\Users\\KLuu\\anaconda3\\lib\\site-packages\\seaborn\\axisgrid.py](file:///C:/Users/KLuu/anaconda3/lib/site-packages/seaborn/axisgrid.py) in pairplot(data, hue, hue_order, palette, vars, x_vars, y_vars, kind, diag_kind, markers, height, aspect, corner, dropna, plot_kws, diag_kws, grid_kws, size)\r\n   2142     diag_kws.setdefault(\"legend\", False)\r\n   2143     if diag_kind == \"hist\":\r\n-> 2144         grid.map_diag(histplot, **diag_kws)\r\n   2145     elif diag_kind == \"kde\":\r\n   2146         diag_kws.setdefault(\"fill\", True)\r\n\r\n[c:\\Users\\KLuu\\anaconda3\\lib\\site-packages\\seaborn\\axisgrid.py](file:///C:/Users/KLuu/anaconda3/lib/site-packages/seaborn/axisgrid.py) in map_diag(self, func, **kwargs)\r\n   1488                 plt.sca(ax)\r\n   1489 \r\n-> 1490             vector = self.data[var]\r\n   1491             if self._hue_var is not None:\r\n   1492                 hue = self.data[self._hue_var]\r\n\r\n[c:\\Users\\KLuu\\anaconda3\\lib\\site-packages\\pandas\\core\\frame.py](file:///C:/Users/KLuu/anaconda3/lib/site-packages/pandas/core/frame.py) in __getitem__(self, key)\r\n   3765             if is_iterator(key):\r\n   3766                 key = list(key)\r\n-> 3767             indexer = self.columns._get_indexer_strict(key, \"columns\")[1]\r\n   3768 \r\n   3769         # take() does not accept boolean indexers\r\n\r\n[c:\\Users\\KLuu\\anaconda3\\lib\\site-packages\\pandas\\core\\indexes\\multi.py](file:///C:/Users/KLuu/anaconda3/lib/site-packages/pandas/core/indexes/multi.py) in _get_indexer_strict(self, key, axis_name)\r\n   2534             indexer = self._get_indexer_level_0(keyarr)\r\n   2535 \r\n-> 2536             self._raise_if_missing(key, indexer, axis_name)\r\n   2537             return self[indexer], indexer\r\n   2538 \r\n\r\n[c:\\Users\\KLuu\\anaconda3\\lib\\site-packages\\pandas\\core\\indexes\\multi.py](file:///C:/Users/KLuu/anaconda3/lib/site-packages/pandas/core/indexes/multi.py) in _raise_if_missing(self, key, indexer, axis_name)\r\n   2552                 cmask = check == -1\r\n   2553                 if cmask.any():\r\n-> 2554                     raise KeyError(f\"{keyarr[cmask]} not in index\")\r\n   2555                 # We get here when levels still contain values which are not\r\n   2556                 # actually in Index anymore\r\n\r\nKeyError: \"['1'] not in index\"\r\n```\r\n\r\nA workaround is to \"flatten\" the columns:\r\n\r\n```python\r\ndf.columns = [\"\".join(column) for column in df.columns]\r\n```\n",
  "patch": "diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py\n--- a/seaborn/axisgrid.py\n+++ b/seaborn/axisgrid.py\n@@ -1472,8 +1472,8 @@ def map_diag(self, func, **kwargs):\n                 for ax in diag_axes[1:]:\n                     share_axis(diag_axes[0], ax, \"y\")\n \n-            self.diag_vars = np.array(diag_vars, np.object_)\n-            self.diag_axes = np.array(diag_axes, np.object_)\n+            self.diag_vars = diag_vars\n+            self.diag_axes = diag_axes\n \n         if \"hue\" not in signature(func).parameters:\n             return self._map_diag_iter_hue(func, **kwargs)\n"
}