\documentclass[accepted]{uai2023} 
%
\usepackage[american]{babel}
\usepackage{enumitem}
\usepackage{tikz}
\usepackage[most]{tcolorbox}
\usepackage{bbm}
\usepackage{latexsym, epsfig, amssymb, amsmath, amsthm, graphicx, mathrsfs,mathtools}
\usepackage{enumerate}
\usepackage{romannum}% for approach #1 and #2
\usepackage{multirow}
\usepackage{multicol}
\usepackage{enumitem}
\usepackage{epstopdf}
\usepackage{caption}
\usepackage{braket}
\usepackage{subcaption}
\usepackage{adjustbox}
\usepackage{multirow}
\usepackage{multicol}
\usepackage{algorithm}
\usepackage{algpseudocode}
\usepackage{enumitem}
\usepackage{epstopdf}
\usepackage{tikz}
\usepackage[toc,page,header]{appendix}
\usepackage{minitoc}
\usepackage{comment}
\usepackage{xspace}
\usepackage{booktabs}
\usepackage{tablefootnote}
\usepackage{tabularray}
\usepackage{setspace}
\usepackage{empheq}
\usepackage{tabularray}
\usepackage{makecell}
\usepackage{multirow}
\usepackage{siunitx}
\usepackage{hyperref}
\hypersetup{
    colorlinks,
    citecolor=black,
    filecolor=black,
    linkcolor=black,
    urlcolor=black
}
%
\usepackage[toc,page,header]{appendix}
\usepackage{minitoc}
%
%%%%%%%%%%% theorem %%%%%%%%%%%%%%%%%%%%%%%%%%
% Theorem-like environments (amsthm). thm, prop, defn, cor, pre, rem, que and
% lem share the Theorem counter; exm and assumption keep their own counters.
\newtheorem{thm}{Theorem}
\newtheorem{prop}[thm]{Proposition}
\newtheorem{defn}[thm]{Definition}
\newtheorem{cor}[thm]{Corollary}
\newtheorem{exm}{Example}
\newtheorem{pre}[thm]{Procedure}
\newtheorem{rem}[thm]{Remark}
\newtheorem{que}[thm]{Question}
\newtheorem{lem}[thm]{Lemma}
% Assumptions are lettered: Assumption A, Assumption B, ...
\newtheorem{assumption}{Assumption}
\renewcommand*{\theassumption}{\Alph{assumption}}
%%%%%%%%%%%%%%%%%%%%%%davoud's notations%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% One-argument notation shorthands.
% \mc{A}: calligraphic letter (e.g. \mc{A}, \mc{Z}, \mc{L}); enters math mode
% automatically when used in running text.
\newcommand{\mc}[1]{\ensuremath{\mathcal{#1}}}
% \m{x}: bold upright symbol in math (e.g. \m{x}, \m{Z}, \m{v}); bold text
% when used outside math, as the former {\bf ...} definition did.
\newcommand{\m}[1]{\TextOrMath{\textbf{#1}}{\mathbf{#1}}}
\newcommand{\wh}[1]{{\widehat{#1}}}
\newcommand{\ti}[1]{{\tilde{#1}}}
% \g{\beta}: bold italic symbol, including Greek letters; unlike an \mbox it
% scales correctly in sub- and superscripts.
\newcommand{\g}[1]{\ensuremath{\boldsymbol{#1}}}
% \mb{E}: blackboard-bold letter (e.g. \mb{E}, \mb{P}, \mb{R}).
\newcommand{\mb}[1]{{\mathbb{#1}}}
%%%%%%%%%%%%%%%%%%%%%%%%%%%%
\renewcommand{\qedsymbol}{$\blacksquare$}
\newcommand{\norm}[1]{\left\lVert#1\right\rVert}
\DeclareMathOperator*{\argmax}{arg\,max}
\DeclareMathOperator*{\argmin}{arg\,min}
% First-level enumerate counters (labels and \ref's) print as capital letters.
\renewcommand{\theenumi}{\Alph{enumi}}
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% In your camera-ready you should use the 'accepted' parameter. This shows the authors and how an accepted paper will look like. The footer is 'Acccepted for X'. In the final version, the proceedings chairs will add the page numbers for PMLR and the final footer will be 'Proceedings of X'.
%
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2022} % ptmx math instead of Computer
                                         % Modern (has noticable issues)
% \documentclass[mathfont=newtx]{uai2022} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

%% Choose your variant of English; be consistent
% \usepackage[american]{babel}
% % \usepackage[british]{babel}

% %% Some suggested packages, as needed:
\usepackage{natbib} % has a nice set of citation styles and commands
% natbib-compatible bibliography style (plainnat).
\bibliographystyle{plainnat}
% Print the reference list under an unnumbered \subsubsection* heading
% titled "References" (overrides \bibsection).
\renewcommand{\bibsection}{\subsubsection*{References}}
%% Self-defined macros
% \swap[s]{a}{b} expands to b, s, a with no spaces; the separator s defaults
% to "-", e.g. \swap{x}{y} gives y-x. NOTE(review): template example macro,
% appears unused in the visible text.
\newcommand{\swap}[3][-]{#3#1#2} % just an example

\title{Fairness-Aware Class Imbalanced Learning on Multiple Subgroups
}
%
% Important:  case of equal contributions, we strongly recommend to NOT show it in this part of the paper, but rather describe it in the appropriate section at the end of the paper "Author Contribution", where you have more space to describe how each author contributed.
%
% Add authors
% Remember to use the order convention "First/Given name" "Last/Family name", e.g. John Smith, Hanako Yamada, Marco Rossi, Wei Zhang
\author[]{\href{mailto:<tarzanaq@upenn.edu>}{Davoud Ataee Tarzanagh}{}}
\author[]{\href{mailto:<bojian.hou@pennmedicine.upenn.edu>}{Bojian Hou}{}}
\author[]{\href{mailto:<boningt@seas.upenn.edu>}{Boning Tong}{}}
\author[]{\href{mailto:<qlong@pennmedicine.upenn.edu>}{Qi Long}{}}
\author[]{\href{mailto:<li.shen@pennmedicine.upenn.edu>}{Li Shen}{}}
% Add affiliations after the authors
\affil[]{ University of Pennsylvania}
\begin{document}
%
\maketitle 
\begin{abstract}
We present a novel Bayesian-based optimization framework that addresses the challenge of generalization in overparameterized models when dealing with imbalanced subgroups and limited samples per subgroup. Our proposed tri-level optimization framework utilizes \textit{local} predictors, which are trained on a small amount of data, as well as a fair and class-balanced predictor at the middle and lower levels. To effectively overcome saddle points for minority classes, our lower-level formulation incorporates sharpness-aware minimization. Meanwhile, at the upper level, the framework dynamically adjusts the loss function based on validation loss, ensuring a close alignment between the \textit{global} predictor and local predictors. Theoretical analysis demonstrates the framework's ability to enhance classification and fairness generalization, potentially resulting in improvements in the generalization bound. Empirical results validate the superior performance of our tri-level framework compared to existing state-of-the-art approaches. The source code can be found at \url{https://github.com/PennShenLab/FACIMS}.
\end{abstract}

\section{Introduction}\label{sec:intro}
Machine learning has achieved exceptional performance through overparameterization and advanced techniques. This progress is supported by high-quality datasets with sufficient samples for each data class and subgroup. However, real-world datasets frequently exhibit imbalances of different types and magnitudes, reflecting the significance and diversity of the underlying domains \citep{barocas2017fairness}. Two common imbalances are observed in label-imbalanced and group-sensitive classification scenarios.

Label-imbalanced classification (LIC) suffers from a significant discrepancy in the number of examples across classes, which makes balanced accuracy a more suitable metric than the conventional misclassification error. To improve model performance and balanced accuracy, various methods have been developed, including re-sampling \citep{buda2018systematic} and loss re-weighting \citep{he2009learning}. Weighted cross-entropy (wCE) loss, a classical approach, amplifies the contribution of minority examples in proportion to the imbalance level. However, wCE may not effectively handle the imbalance in overparameterized models \citep{cao2019learning}, which can result in poor generalization. Recent studies propose alternative loss functions, such as the logit-adjusted loss \citep{menon2020long,cao2019learning}, the class-dependent temperature loss \citep{ye2020identifying}, and the vector-scaling loss \citep{kini2021label}, to address the challenges associated with overparameterization. Nonetheless, despite these advances, there is still a risk of overfitting on minority-class samples \citep{rangwani2022escaping}.

In group-sensitive classification (GSC), the goal is to ensure fairness with respect to protected attributes such as gender or race, addressing the issue of \textit{stereotyping}, where certain target labels are more frequently associated with specific groups \citep{mehrabi2021survey}. For instance, the occupation ``nurse'' is commonly associated with females. While there is no universal fairness metric \citep{kleinberg2016inherent}, one widely used criterion is \textit{group sufficiency}. It requires the conditional expectation of the ground-truth label given the predictor's output, $\mathbb{E}[Y | f(X), A]$, to be identical across all subgroups $A\in\mathcal{A}$. However, in overparameterized models with \textit{limited} samples per subgroup, group sufficiency may not be controlled, even though it is implied by unconstrained learning under certain assumptions \citep{liu2019implicit, shui2022learning}.



\begin{figure*}
\centering
\includegraphics[width=0.9\textwidth]{uai_plot.d1.png}
\caption{An illustration of the FACIMS model defined in \eqref{eqn:sharp:FACIMS}. $f^a$ and $f^b$ maximize the margin for minority classes for groups $a$ and $b$ in \eqref{eqn:in:sharp:FACIMS}. In the upper-level problem~\eqref{eqn:out:sharp:FACIMS}, FACIMS finds $\mathbf{Z}\in\mathcal{Z}$ that achieves a small balanced loss while minimizing its discrepancy from the local solutions $(\mathbf{Z}^{a,\star},\mathbf{Z}^{b,\star})$. The approximation term $\text{KL}(\mathbf{Z}^{a,\star}\| \mathcal{P}(Y|X,A=a))$ is determined by the distribution family $\mathcal{Z}$ (orange region). If the predefined $\mathcal{Z}$ has good expressive power, the approximation term is treated as a small constant.
}
\label{fig:one}
\end{figure*}
Given the aforementioned challenges regarding the performance of LIC and GSC in overparameterized models, we pose the following question:
\vspace{-10pt}
\begin{quote}
\hspace{-20pt}
\textbf{Q:} How can a classifier be designed to effectively generalize on imbalanced subgroups with limited samples?
%\vspace{-2pt}
\end{quote}
To address  \textbf{Q}, we establish a link between LIC and GSC and propose a novel Bayesian framework that maintains informative predictions for imbalanced data while minimizing generalization error. Our contributions can be summarized as follows.
\\
$\bullet$ We design a Bayesian-based \textit{tri-level} optimization framework called Fairness-Aware Class Imbalanced Learning on Multiple Subgroups (FACIMS). In FACIMS, \textit{local} predictors are learned using a small amount of training data and a fair, class-balanced predictor. The lower-level formulation uses sharpness-aware minimization (SAM) \citep{foret2020sharpness} to encourage convergence to a flat minimum and to avoid saddle points for minority classes. The upper-level problem dynamically adjusts the loss function by monitoring the validation loss, following an approach similar to that of \citet{li2021autobalance}, and updates the \textit{global} predictor to align with all subgroup-specific predictors.
\\
$\bullet$ We establish an $\mathcal{O}(1/\sqrt{T})$ convergence rate for our proposed tri-level optimization framework, which corresponds to an $\mathcal{O}(\epsilon^{-2})$ sample complexity with a fixed number of samples used per iteration.
\\
$\bullet$ We quantify the generalization performance of models trained with our proposed tri-level FACIMS approach. The generalization bound analysis demonstrates that our method can achieve better generalization than bilevel variants, such as that of \citet{rangwani2022escaping}, for fair learning on multiple subgroups.
\\
$\bullet$ We conduct experiments on synthetic and real-world datasets to evaluate the balanced accuracy, demographic parity, equalized odds, and group sufficiency. The results showcase the effectiveness of our proposed method.

\section{Preliminaries}\label{sec:preliminary} 
We consider a joint random variable $(X, Y, A)$ that follows an underlying distribution $\mathcal{P}(X, Y, A)$, where $X \in \mathcal{X}$ represents the input, $Y \in \mathcal{Y} = \{1, \ldots, K\}$ represents the label, and $A \in \mathcal{A}$ is a scalar discrete random variable that denotes the sensitive attribute or subgroup index. For instance, $A$ could represent gender or race. Throughout, $\mathbb{E}[Y|X]$ denotes the conditional expectation of $Y$ given $X$, which can be seen as a function of $X$, and $\mathbb{E}_{A, X}[\cdot]$ denotes the expectation over the marginal distribution $\mathcal{P}(A, X)$.

Suppose we have a dataset $\mc{S}=(\m{x}_i,y_i)_{i=1}^n$ sampled i.i.d.~from a distribution $\mc{P}$ with input space $\mc{X}$ and $K$ classes. Let $f:\mc{X}\rightarrow\mb{R}^K$ be a model that outputs a distribution over classes, and let $h_f(\m{x}) = \hat{y}_f(\m{x}) := \argmax_{i\in[K]} f(\m{x})_i$ denote its predicted label. The standard classification error is denoted by $\text{ACC}=\mb{P}_{\mc{S}}[y\neq \hat{y}_f(\m{x})]$. For a loss function $\ell(y,\hat{y})$, we similarly denote %\vsp\vsp
\begin{subequations}
\begin{align}
&\text{Population risk:}~\mc{L}(f; \mc{P}):=\mb{E}_{\mc{P}}[\ell(y,\hat{y}_f(\m{x}))],\\
&\text{Empirical risk:}~\mc{L}(f;\mc{S}):=\frac{1}{n}\sum_{i=1}^n\ell(y_i,\hat{y}_f(\m{x}_i)).
\end{align}
\end{subequations}
We denote the frequency of the $k$-th class by $\pi_k=\mb{P}_{(\m{x},y)\sim\mc{P}}(y=k)$, and by $\mc{P}_k$ the conditional distribution of $(\m{x},y)\sim\mc{P}$ given $y=k$. Label (class) imbalance occurs when the class frequencies differ substantially, i.e.,~$\max_{i\in[K]} \pi_i \gg \min_{i\in[K]} \pi_i$. We define
\begin{subequations}
\begin{align}
&\text{Class-conditional risk:}~f_k:=\mb{E}_{\mc{P}_k}[\ell(y,\hat{y}_f(\m{x}))],\\
&\text{Balanced risk:}~~\text{BACC}(f):=\frac{1}{K}\sum_{k=1}^K f_k.
\end{align}
\end{subequations}
\subsection{ Parametric Losses} 
We review several state-of-the-art margin-based and re-weighting losses for training on imbalanced data under distribution shift. 

The Label-Distribution-Aware Margin (LDAM) loss \citep{cao2019learning} derives class-dependent margins by minimizing a margin-based generalization bound. It uses $\Delta_j$ as the margin of class $j$, defined as follows:
% We review some of the state-of-the-art reweighting methods used for training on imbalanced data with distribution shifts. LDAM \citep{cao2019learning} gives optimal margins for each class based on reducing the error through a generalization bound. It uses $\Delta_j$ as the margin for each class as follows
\begin{equation}
    \begin{split}
& \ell_{\Delta}(f ; \m{x},y) = -\log \frac{e^{f(\m{x})_y - \Delta_y}}{e^{f(\m{x})_y - \Delta_y} + \sum_{j \neq y} e^{f(\m{x})_j}}, \qquad\\
& \qquad \text{where} \; \; \Delta_j= \frac{C}{n_j^{1/4}} \; \text{for} \; j \in \{1, \dots, K\}.
    \label{eq:ldam_loss}
    \end{split}
    \tag{LDAM}
\end{equation}
LDAM enforces larger margins for classes with fewer training samples $n_j$. Deferred Re-Weighting (DRW)~\citep{cao2019learning} trains the model with the unweighted average loss until a prescribed epoch $T_0$. From then on, it weights each class's loss term by the inverse of that class's sample size. The loss function for DRW is as follows:
\begin{equation}\label{eqn:drw}
    \begin{split}
&\ell_{u}(f; \m{x},y) = -u_{y}\log  \frac{e^{f(\m{x})_y}}{\sum_{j = 1}^{K}   e^{f(\m{x})_j}},   \qquad \qquad\\
&  \qquad \text{where} \; u_{j} = \frac{1}{1 + (n_j - 1)\mathbbm{1}_{\text{epoch} \geq T_0} }.
    \end{split}
    \tag{DRW}
\end{equation}
This way of re-weighting has been shown to be effective for improving generalization performance when combined with various losses. 

Vector Scaling (VS)~\citep{kini2021label} loss is a recently proposed loss function that unifies the idea of multiplicative shift  \citep{ye2020identifying}, additive shift \citep{menon2020long}, and loss re-weighting. It has the following form:
\begin{equation}\label{eq:vs_loss}
\ell(f, \m{v}; \m{x},y)= -u_{y}\log  \frac{e^{\gamma_y f(\m{x})_y + \Delta_y}}{\sum_{j = 1}^{K}e^{\gamma_j f(\m{x})_j + \Delta_j}}. \tag{VS}
\end{equation}
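Here, $\m{v}:=[\m{v}_1, \ldots, \m{v}_K]$, where $\m{v}_j:=(u_j,\Delta_j,\gamma_j)$ collects the class-dependent hyperparameters of class $j$: the loss weight $u_j$, the additive logit shift $\Delta_j$, and the multiplicative logit scale $\gamma_j$. 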


Throughout, our main focus is on  VS loss, but our framework can also accommodate other loss functions.
%, followed by optimization techniques related to our work.
%\subsection{Long-Tailed Learning}
 %Re-sampling \cite{buda2018systematic} 
%  and Re-weighting \cite{5128907} are the most commonly used methods to train on class-imbalanced datasets. Oversampling the minority classes \cite{chawla2002smote} and undersampling the majority classes \cite{buda2018systematic} are two approaches to re-sampling. Oversampling leads to overfitting on the tail classes, and undersampling discards a large amount of data, which inevitably results in poor generalization. 
%  ~\citet{Kang2020Decoupling} proposed to decouple representation learning and classifier training to improve performance with the same. Mixup Shifted Label-Aware Smoothing model (MiSLAS) \cite{zhong2021improving} aims to improve the calibration of models trained on long-tailed datasets by mixup and label-aware smoothing and thereby improve performance. RIDE \cite{wang2021longtailed} and TADE \cite{zhang2021test} are ensemble-based methods that achieve state-of-the-art on the long-tailed visual recognition.  \citet{Samuel_2021_ICCV} introduces a new loss, DRO-LT, based on distributionally robust optimization for learning balanced feature representations. % We explore the problem of training class-imbalanced datasets through the lens of optimization and loss landscape. We will now describe some representative recent effective methods in detail, which we will use as baselines. Additional discussion on long-tailed learning methods is present in App. \ref{app:related_work}. 
\subsection{Fairness Notions} Next, we discuss fairness notions and their gaps.
\begin{defn}\label{def:fair:notions}
% A score function~$f$ maps the random variable $X$ to a real number. We say that
% \\
% $\bullet$ \textnormal{\textbf{Group Sufficiency (GS)}}.  $f$ is
% \emph{sufficient} with respect to attribute~$A$ if: $\mb{E}[Y\mid f(X)]=\mb{E}[Y\mid f(X), A]$.
% \\
% $\bullet$ \textnormal{\textbf{Demographic Parity (DP)}}. $f$ satisfies the demographic parity with respect to $A$ if: 
% $\mb{E}[f(X)] = \mb{E}[f(X)|A]$.
% \\
% $\bullet$ \textnormal{\textbf{Equalized Odds (EO)}}. $f$ satisfies the equalized odds with respect to $A$ if:
% $\mb{E}[f(X)|Y] = \mb{E}[f(X)|Y,A]$.
Let $f$ be a score function that maps the random variable $X$ to a real number.
\vspace{.15cm}
\\
 $\bullet$ 
\textnormal{\textbf{Group Sufficiency (GS)}}: We say that $f$ is sufficient with respect to attribute $A$ if $\mathbb{E}[Y|f(X)] = \mathbb{E}[Y|f(X), A]$.
\vspace{.15cm}
\\
 $\bullet$ \textnormal{ \textbf{Demographic Parity (DP)}}: $f$ satisfies demographic parity with respect to $A$ if $\mathbb{E}[f(X)] = \mathbb{E}[f(X)|A]$. 
 \vspace{.15cm}
 \\
$\bullet$  \textnormal{ \textbf{Equalized Odds (EO)}}: $f$ satisfies equalized odds with respect to $A$ if $\mathbb{E}[f(X)|Y] = \mathbb{E}[f(X)|Y,A]$. 
\end{defn}
GS means that, once the score $f(X)$ is known, the attribute $A$ carries no additional information about the expected label $Y$. DP requires the expected score $f(X)$ to be the same for every value of the attribute $A$, so that, on average, scores are unaffected by the sensitive attribute. EO requires that, for every label $Y$, the expected score $f(X)$ is the same across all values of $A$. In other words, individuals who share the same label but differ in the sensitive attribute receive, on average, the same predicted score.

%DP means that the expected score $f(X)$ is the same regardless of the attribute $A$. ensures that the distribution of scores is independent of the sensitive attribute, promoting fairness in the decision-making process. EO means that the expected score $f(X)$ is the same for each combination of labels $Y$ and attributes $A$. It ensures that individuals with the same label but different attributes are treated equally in terms of their predicted scores, regardless of the sensitive attribute.
% The impossibility theorem of fairness states that outside of special cases, one cannot exactly and simultaneously satisfy all common and intuitive definitions of fairness. Specifically,  \citep{barocas-hardt-narayanan,chouldechova2017fair} show that if $A \not\!\perp\!\!\!\perp Y$, group sufficiency, and demographic parity could not be simultaneously achieved. Further, \citep{barocas-hardt-narayanan,pleiss2017fairness} reveal that if $\mc{P}(X,Y,A)>0$ and $A \not\!\perp\!\!\!\perp Y$, group sufficiency and demographic can not both hold. 
The impossibility theorem of fairness asserts that, outside of special cases, common and intuitive fairness definitions cannot all be satisfied simultaneously. Notably, \citet{barocas-hardt-narayanan, chouldechova2017fair} show that if $A \not\perp Y$, group sufficiency and demographic parity cannot both be achieved. Moreover, \citet{barocas-hardt-narayanan, pleiss2017fairness} show that when $\mc{P}(X,Y,A)>0$ and $A \not\perp Y$, group sufficiency and equalized odds cannot hold simultaneously.

Definition~\ref{def:fair:notions} leads to a notion of the \emph{group sufficiency gap}, \emph{demographic parity gap}, and \emph{equalized odds gap} defined, respectively, as:
\begin{subequations}\label{eqn:gaps:def}
\begin{align}
%\hspace{-.8cm}
\textbf{SGap}_f(A) &= \mb{E}[\left|\mb{E}[Y\mid f(X)]-\mb{E}[Y\mid f(X), A]\right|],\\ 
\textbf{PGap}_f(A) &= \mb{E}[\left|\mb{E}[f(X)] - \mb{E}[f(X)\mid A]\right|], \\
\textbf{OGap}_f(A) &= \mb{E}[\left|\mb{E}[f(X)\mid Y] - \mb{E}[f(X)\mid Y,A]\right|].
\end{align}    
\end{subequations}
$\textbf{SGap}_f$ measures the extent to which the predictor $f$ violates group sufficiency, with the expectation taken over $(X,A)$. Hence, $\textbf{SGap}_f = 0$ if and only if $f$ satisfies group sufficiency (almost surely). For completeness, we also discuss how to compute these gaps in the Appendix. %Appendix~\ref{app:fair:notion}. 


To conclude this section, we recall the $A$-group Bayes predictor and an upper bound on $\textbf{SGap}_f$ from \citet{shui2022learning}. These results serve as the foundation for our Bayesian-based tri-level optimization framework.
  
\begin{defn}[$A$-group Bayes predictor]\label{def:g:bays}
The $A$-group Bayes predictor $f^{A,\texttt{Bayes}}$ associated with distribution $\mc{P}(X, Y, A)$ is defined as: $f^{A,\texttt{Bayes}}(X) = \mb{E}[Y|X,A]$.
\end{defn}
%
The following theorem provides an upper bound on the group sufficiency gap for any predictor $f$:
%
\begin{thm}\label{thm:uff:upp} If $A$ takes finitely many values, i.e., $|\mc{A}|<+\infty$, and is uniformly distributed with $p(A=a)=1/|\mc{A}|$, then
\begin{equation}\label{eqn:b:one}
 \textbf{SGap}_f(A)\leq \frac{4}{|\mc{A}|}\sum_{a \in\mc{A}}\mb{E}_{X}\left[|f-f^{A,\texttt{Bayes}}|\big|A=a\right].   
\end{equation}
\end{thm}
Hence, $\textbf{SGap}_f(A)$ is controlled by the discrepancy between the predictor $f$ and the $A$-group Bayes predictor $f^{A,\texttt{Bayes}}$. In other words, across the different subgroups $A=a$, the optimal predictor $f$ should closely align with every group Bayes predictor $f^{A=a,\texttt{Bayes}}$, $a\in\mc{A}$.
%
\section{Proposed Framework}

% \begin{figure*}[t]
% \begin{center}
% \begin{tcolorbox}[enhanced, width=8.5cm, title= {\small \textsc{FACIMS}},colframe=green!3!black,colback=green!3!white,colbacktitle=orange!5!yellow!10!white,
% fonttitle=\bfseries,coltitle=black,attach boxed title to top center=
% {yshift=-0.25mm-\tcboxedtitleheight/2,yshifttext=2mm-\tcboxedtitleheight/2},
% boxed title style={boxrule=0.2mm,
% frame code={ \path[tcb fill frame] ([xshift=-4mm]frame.west)
% -- (frame.north west) -- (frame.north east) -- ([xshift=4mm]frame.east)
% -- (frame.south east) -- (frame.south west) -- cycle; },
% interior code={ \path[tcb fill interior] ([xshift=-2mm]interior.west)
% -- (interior.north west) -- (interior.north east)
% -- ([xshift=2mm]interior.east) -- (interior.south east) -- (interior.south west)
% -- cycle;} }]
% \begin{subequations}\label{eqn:sharp:FACIMS}
% %\leqnomode
% \vspace{-0.3cm}
% \begin{align}%\label{FACIMS}
% \nonumber
% \hspace{-.9cm}&\min_{\m{Z} \sim\mc{Z},\m{v}}~%\textcolor{blue}{\max_{||\epsilon||_2 \leq \alpha_{\rm up}}}
% %\textcolor{red}{\max_{\|\epsilon\| \leq \beta_{\rm out} }} 
% \sum_{a \in \mc{A}}  \alpha^{\rm up}  \textnormal{KL}(\m{Z}^{a,*}\|\m{Z})\\
% &+  \mb{E}_{\m{w}\sim \m{Z}}  \mc{L}_{\rm bal} (\tilde{f}_{\m{w}}; \mc{V}^a), \label{eqn:out:sharp:FACIMS}\\
% \nonumber
% &{\rm {\rm s.t.}}~\m{Z}^{a,*}= \arg\min_{\m{Z}^a\in\mc{Z}}~{\max_{\|\boldsymbol{\epsilon}^a\| \leq \beta^a }}~ \alpha^{\rm low}\ \text{KL}(\m{Z}^a\|\m{Z}) \\
% &+ \mb{E}_{\m{w}^a\sim \m{Z}^a} \mc{L}_{\rm vs}(\tilde{f}_{\m{w}^a +\boldsymbol{\epsilon}^a}, \m{v}; \mc{T}^a),
% %\arg\min_{\cdot^a}~\mathcal{L}(\m{Z}_{a}; \mc{T}^a)\\
% ~~\forall a \in \mc{A}.
% \label{eqn:in:sharp:FACIMS}
% \end{align}
% \end{subequations}
% \end{tcolorbox}
% \end{center}
% \end{figure*}
In this section, we present the formulation of FACIMS, which is a framework designed to promote both classification accuracy and fairness through a randomized algorithm. FACIMS achieves this by learning a predictive distribution $\m{Z}$, which assigns higher scores to predictors that are favorable based on the available data. In the context of the Bayes framework, the predictor is sampled from the posterior distribution, represented as $\tilde{f}\sim \m{Z}$. During the inference stage, the predictor's output is computed as the expectation of the learned predictive distribution $\m{Z}$:~$f(X)= \mb{E}_{\tilde{f}\sim \m{Z}} \tilde{f}(X)$.
%


In practical scenarios, it is infeasible to optimize over the entire space of possible distributions. Therefore, we restrict the predictive distribution $\m{Z}$ to a specific distribution family $\mc{Z}$, i.e., $\m{Z} \in \mc{Z}$, such as the family of Gaussian distributions. Additionally, we denote by $\m{Z}^{a,\star}\in\mc{Z}$ the optimal predictive distribution with respect to the subgroup $A=a$ within the family $\mc{Z}$:
\[
\m{Z}^{a,*}=\argmin_{\m{Z}^a\in\mc{Z}}~~\mb{E}_{{\tilde{f}}^a\sim \m{Z}^a} \mc{L}(\tilde{f}^a,  \m{v}; \mc{S}^a),
\]
where $\mc{S}^a$ denotes the observed samples from subgroup $A=a$. In general, $\m{Z}^{a, \star} \neq \tilde{f}^{a,\texttt{Bayes}}$, since the distribution family $\mc{Z}$ is only a subset of all possible distributions. 
\begin{cor}[\cite{shui2022learning}]
The group sufficiency gap in a randomized algorithm w.r.t. the learned predictive distribution $\m{Z}$ is bounded as follows:
\begin{align}\label{eqn:gap:sec:b}
\textbf{SGap}_{f} & \leq \mc{O}\left(\texttt{Optim} + \texttt{Approx}\right), 
% \\
% \textnormal{where}~~\texttt{Optim} &:=\frac{1}{|\mc{A}|}\sum_{a} \sqrt{\text{KL}(\m{Z}^{a,\star}\|\m{Z})},\\
%  \texttt{Approx}&:=\frac{1}{|\mc{A}|}\sum_{a}\sqrt{{\text{KL}(\m{Z}^{a,\star}\|\mc{P}(Y|X,A=a))}}.
\end{align}
where 
%\begin{subequations}
\begin{align*}
   \texttt{Optim} &:=\frac{1}{|\mc{A}|}\sum_{a} \sqrt{\text{KL}(\m{Z}^{a,\star}\|\m{Z})},\\
 \texttt{Approx}&:=\frac{1}{|\mc{A}|}\sum_{a}\sqrt{{\text{KL}(\m{Z}^{a,\star}\|\mc{P}(Y|X,A=a))}}.
\end{align*}
\end{cor}
% From \eqref{eqn:gap:sec:b}, minimizing the  \texttt{Optim} term implies that the learned distribution $\m{Z}$ will be both fair and informative for the prediction. On the other hand, \texttt{Approx} term is the KL divergence between the optimal distribution $\m{Z}^{a,\star}$ and $\mc{P}$. The \texttt{Approx} term will be small if, for example, the distribution family $\mathcal{Z}$ has a rich expressive power such as a deep neural network; see Figure \ref{fig:one} for an illustration.

Minimizing the \texttt{Optim} term ensures that the learned distribution $\m{Z}$ is both fair and informative for making predictions. The \texttt{Approx} term, on the other hand, is the KL divergence between the optimal in-family distribution $\m{Z}^{a,\star}$ and the true conditional distribution $\mc{P}(Y|X,A=a)$. If the distribution family $\mathcal{Z}$ has rich expressive power, like that of a deep neural network, the \texttt{Approx} term will be small. See Figure~\ref{fig:one} for an illustration.


We now present a framework for fairness-aware class-imbalanced learning on multiple subgroups, with a potentially improved generalization bound and $\textbf{SGap}_{f}$. We begin by formulating the loss-function design as a bilevel optimization over the hyperparameters $\m{v}$ and a distribution $\m{Z}\in\mc{Z}$. Assume each group $a \in \mc{A}$ has a fine-tuning training set $\mc{T}^a= \cup^{n}_{i=1}\{(\m{x}_i^a,  y_i^a)\}$ and a separate validation set $\mc{V}^a=\cup^{n}_{i=1}\{(\tilde{\m{x}}_i^a,\tilde{y}_i^a)\}$. The samples in both sets are drawn independently and identically distributed (i.i.d.) from the per-group data distribution $\mathcal{P}^a$. Following \citet{li2021autobalance}, we define the \emph{empirical risk} and the \emph{balanced empirical risk} over a finite-sample dataset $\mathcal{S}$ of size $n$ as $\mathcal{L}_{\rm vs}(\tilde{f}, \m{v};\mathcal{S}) := 1/n\sum_{i=1}^{n}\ell(\tilde{f},\m{v}; \m{x}_i, y_i)$ and $\mathcal{L}_{\rm bal}(\tilde{f};\mathcal{S}) := 1/n \sum_{i=1}^{n} \ell (\tilde{f}, \bar{\m{v}}; \m{x}_i, y_i)$, respectively. Here, $\bar{\m{v}}$ is a fixed choice of loss hyperparameters that can be set manually, e.g., following \eqref{eqn:drw}, \eqref{eq:ldam_loss}, or \eqref{eq:vs_loss}.
%and \eqref{eq:ldam_loss}.
%\\
% Consider a per datum loss $l:\cdot \times \m{Z} \times \mathcal{X} \times \mathcal{Y} \rightarrow \mathbb{R}_{+}$;
%and the  \emph{balanced risk} over as $\mathcal{L}_{\rm bal} (\cdot;{\mathcal{P}})=\mathbb{E}_{(x, y)\sim \mathcal{P}}[l(\cdot, x, y)]$. 
% For a particular group $m$, they become  $\mathcal{L} (\cdot;\mathcal{S}_m)$ or $\mathcal{L} (\cdot;\mathcal{S}^a')$ and $\mathcal{L} (\cdot;{\mathcal{P}_m})$.
% Let $\mc{S}^a=\{(x^{a}_i,y^{a}_i)\}_{i=1}^{m}$ as the observed data w.r.t. subgroups $A=a$, which are i.i.d. samplings from the underlying  distribution $\mc{D}(x,y|A=a)$. %We also denote the {empirical} loss w.r.t $A=a$ as: $\mc{L}^{\mc{S}}^a(f)$
% \begin{center}
% \begin{tcolorbox}[enhanced, width=8.2cm, title= {\small \textsc{FACIMS}},colframe=green!3!black,colback=green!3!white,colbacktitle=orange!5!yellow!10!white,
% fonttitle=\bfseries,coltitle=black,attach boxed title to top center=
% {yshift=-0.25mm-\tcboxedtitleheight/2,yshifttext=2mm-\tcboxedtitleheight/2},
% boxed title style={boxrule=0.2mm,
% frame code={ \path[tcb fill frame] ([xshift=-4mm]frame.west)
% -- (frame.north west) -- (frame.north east) -- ([xshift=4mm]frame.east)
% -- (frame.south east) -- (frame.south west) -- cycle; },
% interior code={ \path[tcb fill interior] ([xshift=-2mm]interior.west)
% -- (interior.north west) -- (interior.north east)
% -- ([xshift=2mm]interior.east) -- (interior.south east) -- (interior.south west)
% -- cycle;} }]
%\\
% \end{tcolorbox}
% \end{center}
 %Let $\m{Z}$ stand for a \emph{prior} information for minimizing the loss. 
% \begin{center}
% \begin{tcolorbox}[enhanced, width=9cm, title= {\small \textsc{FedBLO}},colframe=green!3!black,colback=green!3!white,colbacktitle=orange!5!yellow!10!white,
% fonttitle=\bfseries,coltitle=black,attach boxed title to top center=
% {yshift=-0.25mm-\tcboxedtitleheight/2,yshifttext=2mm-\tcboxedtitleheight/2},
% boxed title style={boxrule=0.2mm,
% frame code={ \path[tcb fill frame] ([xshift=-4mm]frame.west)
% -- (frame.north west) -- (frame.north east) -- ([xshift=4mm]frame.east)
% -- (frame.south east) -- (frame.south west) -- cycle; },
% interior code={ \path[tcb fill interior] ([xshift=-2mm]interior.west)
% -- (interior.north west) -- (interior.north east)
% -- ([xshift=2mm]interior.east) -- (interior.south east) -- (interior.south west)
% -- cycle;} }]
% \begin{subequations}\label{eq:map:bilevel}
% \begin{align}
% \hspace{-.7cm}\m{q}^m(\m{x},\m{Z})&=
% \begin{bmatrix}
%  \nabla_\m{w} g^m(\m{x}, \m{w})\\
%     \nabla_{\m{w}}^2 g^m(\m{x}, \m{w}) \m{v} - \nabla_{\m{x}} f^m(\m{x},\m{w})
% \end{bmatrix}    
% \\
% \hspace{-.7cm}\m{h}^m(\m{x},\m{Z}) &= \nabla_\m{x} f^m(\m{x},\m{w}) -\nabla_{\m{xw}}^2 g^m(\m{x},\m{w})\m{v}.    
% \end{align}
% \end{subequations}
% \end{tcolorbox}
% \end{center}
% \begin{tcolorbox}[enhanced, width=7cm, title= {\small \textsc{FedBLO}},colframe=green!3!black,colback=green!3!white,colbacktitle=orange!5!yellow!10!white,
% fonttitle=\bfseries,coltitle=black,attach boxed title to top center=
% {yshift=-0.25mm-\tcboxedtitleheight/2,yshifttext=2mm-\tcboxedtitleheight/2},
% boxed title style={boxrule=0.2mm,
% frame code={ \path[tcb fill frame] ([xshift=-4mm]frame.west)
% -- (frame.north west) -- (frame.north east) -- ([xshift=4mm]frame.east)
% -- (frame.south east) -- (frame.south west) -- cycle; },
% interior code={ \path[tcb fill interior] ([xshift=-2mm]interior.west)
% -- (interior.north west) -- (interior.north east)
% -- ([xshift=2mm]interior.east) -- (interior.south east) -- (interior.south west)
% -- cycle;} }]
% %\leqnomode
% \vspace{-0.3cm}
% \begin{align}\label{SAM_MAML}
% \hspace{-2cm}\tag{\bf P}\hspace{-1cm}& ~~~~~~~ \min_{\cdot}~~~\, \text{\sf (upper)}\hspace{-2cm}\\
% &{\rm {\rm s. t.}}~~\cdot_{m}^{*}(\cdot) = \arg \min_{\cdot_m}  \textcolor{red}{\max_{||\epsilon_m||_2 \leq \beta_{\rm low}}} \!\mathcal{L} (\cdot_{m}+\textcolor{red}{\epsilon_m}; \mathcal{S}^a)\nonumber\\
% &~~~~~~~~~~~~~~~~ ~+ \frac{\|\cdot_m-\cdot\|^2}{2\beta_{\rm low}},~  {m}=1,...,M.~~~~~~\text{\sf (lower)} \nonumber
% \end{align}
% \end{tcolorbox}
% \begin{subequations}
% \begin{align}
% &\min_{\m{Z}\in\mathcal{Z}} \frac{1}{|\mc{A}|}\sum_{a}\text{KL}(\overline{\m{Z}}_a^{\star}\|\m{Z}) \qquad \qquad  \qquad  \\
%      \nonumber
%  \text{s.t.}~~~~~~\overline{\m{Z}}_a^{\star} &= \text{argmin}_{\m{Z}_a\in\mc{Z}}~\mb{E}_{{f}^a\sim \m{Z}_a} \mc{L}^{\mc{S}}_a(f_a)~~~~~ \\
%      &+ \alpha^{\rm low} \text{KL}(\m{Z}_a\|\m{Z})\}, ~~\forall a\in\mc{A} 
%      \label{eqn:in:fims}
% \end{align}
% \end{subequations}
%

Let $\m{Z}$ denote a predictive distribution that is both \emph{fair} and \emph{informative}. Building on prior work~\citep{kini2021label,li2021autobalance,shui2022learning}, we design the following objective:  
\begin{subequations}\label{eqn:FACIMS}
%\leqnomode
\begin{align}%\label{FACIMS}
&\min_{\m{Z} \in\mc{Z},\m{v}}~\sum_{a \in \mc{A}} \Big( \alpha^{\rm up} \textnormal{KL}(\m{Z}^{a,*}\|\m{Z}) + \mb{E}_{{\tilde{f}} \sim \m{Z}}  \mathcal{L}_{\rm bal} (\tilde{f}; \mc{V}^a) \Big), 
\label{eqn:out:FACIMS}\\
&{\rm {\rm s.t.}}~~\m{Z}^{a,*} =  \arg\min_{\m{Z}^a\in\mc{Z}}~
\alpha^{\rm low}\textnormal{KL}(\m{Z}^a\|\m{Z}) \label{eqn:in:FACIMS}\\
\nonumber
& \qquad \qquad  + \mb{E}_{{\tilde{f}}^a\sim \m{Z}^a} \mc{L}_{\rm vs}(\tilde{f}^a, \m{v}; \mc{T}^a),~~\forall a \in \mc{A}.
\end{align}
\end{subequations}
Here, the lower-level problem \eqref{eqn:in:FACIMS} includes a regularization term $\textnormal{KL}(\m{Z}^a\|\m{Z})$ that acts as an informative prior for learning the local predictor $\m{Z}^{a,*}$, given a fixed predictive distribution $\m{Z}$. Optimizing this lower-level objective minimizes an upper bound on the group-specific generalization error.

In the upper-level problem \eqref{eqn:out:FACIMS}, we update $\m{Z}$ by minimizing the KL divergences between the group-specific solutions $\m{Z}^{a,*}$ and $\m{Z}$, which controls the upper bound of $\textbf{SGap}_f$ according to \eqref{eqn:gap:sec:b}, together with the balanced empirical risk. An alternative, single-level approach would solve \eqref{eqn:in:FACIMS} independently for each subgroup and then fit $\m{Z}$ to the learned $\m{Z}^{a,*}$ by minimizing $\sum_{a \in \mc{A}}\textnormal{KL}(\m{Z}^{a,*}\|\m{Z})$. However, this does not work well in our setting because each subgroup contains only a limited number of samples, which leads to overfitting and a large generalization error for each subgroup. To address this, we make additional assumptions, such as the similarity of the data-generating distributions $\mc{P}[Y|X,A]$ across subgroups. Under these assumptions, we can learn shared models that are fair (in terms of group sufficiency) and informative across a large number of subgroups.
% Here, the lower problem~\eqref{eqn:in:FACIMS} minimize the empirical term $\m{Z}^{a,\star}$ for each $a\in\mc{A}$. The loss in lower-level adds a regularization term $\textnormal{KL}(\m{Z}^a\|\m{Z})$ as an \textit{informative prior} in learning $\m{Z}^{a,\star}$, given a fixed predictive-distribution $\m{Z}$. Note that optimizing the lower-level loss is to minimize the upper bound of the group-specific generalization error. In the upper-level problem \eqref{eqn:out:FACIMS}, $\m{Z}$ is updated by minimizing the average KL-divergence between different $\m{Z}^{a,\star}$ that controls the upper bound of $\textbf{SGap}_f$ according to \eqref{eqn:gap:sec:b} as well as the balanced empirical risk. We note that an alternative approach is to minimize  \eqref{eqn:in:FACIMS}. Then, $\m{Z}$ can be updated through minimizing the average KL-divergence: $\sum_{a \in \mc{A}}\textnormal{KL}(\m{Z}^{a,*}\|\m{Z})$ from learned $\m{Z}^{a,\star}$. However, this single-level idea generally \emph{does not work} in our setting, because each subgroup contains \textbf{limited} number of samples. Therefore, a straightforward minimization leads to overfitting for each subgroup and the generalization error will be very large. On the other hand, without any assumptions, it is impossible to learn a fair and informative predictor from limited samples for each subgroup. As a result, additional assumptions will be considered, such as the data generation distribution of each subgroup  $\mc{P}[Y|X,A]$ being similar. Then we could learn shared and fair (in terms of group sufficiency) models from a large number of subgroups.
%see Figure~\ref{fig:gen:bound}. 
%
%\subsection{$\m{Z}$ as an informative prior}\label{sec:bilevel}
% We have demonstrated that $\m{Z}$ can achieve both fair and informative prediction. Therefore, we regard $\m{Z}$ as a \emph{prior} information for minimizing the loss, yielding a bilevel objective.
% \begin{align*}
%      & \min_{\m{Z}\in\mathcal{\m{Z}}} \frac{1}{|\mc{A}|}\sum_{a}\text{KL}(\overline{\m{Z}}_a^{\star}\|\m{Z}) \tag*{(\textcolor{cyan}{Upper-level})}\\
%      & \text{s.t.}~~\overline{\m{Z}}_a^{\star} = \text{argmin}_{Q_a\in\mathcal{\m{Z}}} \{\E_{\tilde{f}_a\sim Q_a} ~\hat{\mc{L}}(\tilde{f}_a) + \lambda \text{KL}(Q_a\|\m{Z})\}, \forall a\in\mc{A} \tag*{(\textcolor{brown}{Lower-level})}
% \end{align*}
% Where $\lambda>0$ is the hyper-parameter. The proposed loss is a typical bilevel optimization. \textcolor{brown}{(1)} In the lower-level, we aim to learn $\overline{\m{Z}}_a^{\star}$ for each $a\in\mathcal{A}$. Different from Eq.~(\ref{eq:erm}), the loss in lower-level adds a regularization term $\text{KL}(Q_a\|\m{Z})$ as an informative prior in learning $\overline{\m{Z}}_a^{\star}$, given a fixed predictive-distribution $\m{Z}$. Moreover, Theorem \ref{thm:pac_bayes} formally justified that optimizing the lower-level loss is to minimize the upper bound of the generalization error. \textcolor{cyan}{(2)} In the upper-level, $\m{Z}$ is updated through minimizing the average KL divergence between different $\overline{\m{Z}}_a^{\star}$, which controls the upper bound of $\textbf{SGap}_f$.
%We formulate the loss function design as a bilevel optimization over {hyperparameters} $\m{Z}$ and a hypothesis set $\mc{F}$. We split the dataset $\mc{S}$ into training $\mc{S}^t$ and validation $\mc{S}^v$ sets with $n^t$ and $n^v$ examples respectively. Let $\mc{E}^t$ be the desired test-error objective. When $\mc{E}^t$ is not differentiable, we use a weighted cross-entropy loss function $\ell(y,\hat{y})$ chosen to be consistent with $\mc{E}^t$. 
% For instance, $\mc{E}^t$ could be a superposition of standard and balanced classification errors, i.e. $\Eout=(1-\la) {\cal{E}}+\la{\cal{E}}_{\text{bal}}$. Then, we simply choose $\lout=(1-\la) \CE+\CE_{\text{bal}}$. 
%We also note that if the randomized prediction is applied to the group-specific $\tilde{f}^a$, i.e., $\alpha^{\rm low} > 0$  and $\alpha^{\rm up}=0$, and the parametric VS loss is replaced with standard CE, then {FACIMS} reduces to FAMS framework \citep{shui2022learning} for group-sensetive classification on multiple subgroups. 
%$\bullet$ \textbf{FACIMS-II}: randomized prediction is applied to only the meta-update step, i.e., $\alpha_{\rm low} = 0$ and $\alpha_{\rm up}>0$.
%{\bf (c) FACIMS: SAM is applied to both fine-tuning and meta-update steps, i.e., $\alpha_{\rm up}$, $\alpha_{\rm low} > 0$.
%\end{rem}
\begin{figure*}[t]
\begin{center}
\begin{tcolorbox}[enhanced, width=16cm, title= {\small \textsc{FACIMS}},colframe=green!3!black,colback=green!3!white,colbacktitle=orange!5!yellow!10!white,
fonttitle=\bfseries,coltitle=black,attach boxed title to top center=
{yshift=-0.25mm-\tcboxedtitleheight/2,yshifttext=2mm-\tcboxedtitleheight/2},
boxed title style={boxrule=0.2mm,
frame code={ \path[tcb fill frame] ([xshift=-4mm]frame.west)
-- (frame.north west) -- (frame.north east) -- ([xshift=4mm]frame.east)
-- (frame.south east) -- (frame.south west) -- cycle; },
interior code={ \path[tcb fill interior] ([xshift=-2mm]interior.west)
-- (interior.north west) -- (interior.north east)
-- ([xshift=2mm]interior.east) -- (interior.south east) -- (interior.south west)
-- cycle;} }]
\begin{subequations}\label{eqn:sharp:FACIMS}\vspace{-0.3cm}
\begin{align} 
&\min_{\m{v}, \m{Z} \in\mc{Z}}~~~\sum_{a \in \mc{A}} \Big( \alpha^{\rm up}  \textnormal{KL}\left(\m{Z}^{a,*}\|\m{Z}\right)+  \mb{E}_{\m{w}\sim \m{Z}}~\mc{L}_{\rm bal} (\tilde{f}_{\m{w}}; \mc{V}^a)\Big), \label{eqn:out:sharp:FACIMS}\\
&~~{\rm{\rm s.t.}}~~~\m{Z}^{a,*} \in \argmin_{\m{Z}^a\in\mc{Z}}~{\max_{\|\boldsymbol{\epsilon}^a\| \leq \beta^a }}~ \alpha^{\rm low}\ \text{KL}(\m{Z}^a\|\m{Z}) + \mb{E}_{\m{w}^a\sim \m{Z}^a}~\mc{L}_{\rm vs}(\tilde{f}_{\m{w}^a +\boldsymbol{\epsilon}^a}, \m{v}; \mc{T}^a), ~~~~\forall a \in \mc{A}.
\label{eqn:in:sharp:FACIMS}
\end{align}
\end{subequations}
\end{tcolorbox}
\end{center}
%\vspace{-.5cm}
\end{figure*}
\subsection{Parametric Models and FACIMS}



\begin{algorithm*}[!t]
\caption{ An Alternating Stochastic Gradient Method for FACIMS
}
\begin{algorithmic}[1] 
\State{\bfseries Input:} VS loss hyperparameters $\m{v}_0$; distribution  parameters $(\g{\theta}_0,\boldsymbol\sigma_0)$ and  $(\g{\theta}_0^a,\boldsymbol\sigma_0^a)$ for all $a \in \mc{A}$; regularization parameters $(\alpha^{\rm low}, \alpha^{\rm up})$; sharpness parameters $\{\beta^a\}_a$; and stepsizes  $ (\gamma^{\rm up},\gamma^{\rm low})$.
\For{$t =0 \ldots T-1$}
\State Sample dataset $S^a_t=(\mc{T}^a_t, \mc{V}^a_t)$, where $a\in \mc{A}^{\prime}\subseteq \mc{A}$.
\For{$a \in \mc{A}^{\prime}$}
\State Update $\g{\epsilon}^a$ \textcolor{brown}{in the lower level}:
\begin{equation}
  \g{\epsilon}^a_{t+1} \longleftarrow	 {\argmax_{\|\boldsymbol{\epsilon}^a\| \leq \beta^a }}~ \alpha^{\rm low}\ \text{KL}(\m{Z}^a_t\|\m{Z}_t) + \mb{E}_{\m{w}^a\sim \m{Z}^a_t} \mc{L}_{\rm vs}(\tilde{f}_{\m{w}^a +\boldsymbol{\epsilon}^a}, \m{v}_t; \mc{T}^a_t).  
\end{equation}
\State Update $\m{Z}^a=\mc{N} (\g{\theta}^{a},\boldsymbol\sigma^{a})$ using SGD (with step size $\gamma^{\rm low}$)  \textcolor{purple}{in the middle level}:
\begin{align}
\m{Z}^a_{t+1} \longleftarrow	 \argmin_{\m{Z}^a\in\mc{Z}}~~ \alpha^{\rm low}\ \text{KL}(\m{Z}^a\|\m{Z}_t) + \mb{E}_{\m{w}^a\sim \m{Z}^a} \mc{L}_{\rm vs}(\tilde{f}_{\m{w}^a +\boldsymbol{\epsilon}^a_{t+1}}, \m{v}_t; \mc{T}^a_t).  
\end{align}
\EndFor        
\State Update  $\m{Z}=\mc{N} (\g{\theta},\boldsymbol\sigma)$ and $\m{v}$ using SGD (with step size $\gamma^{\rm up}$) \textcolor{cyan}{in the upper level}:
\begin{equation}
({\m{Z}_{t+1},\m{v}_{t+1}})\longleftarrow	\argmin_{\m{Z} \in\mc{Z},\m{v}}~
\sum_{a \in \mc{A}} \Big( \alpha^{\rm up}  \textnormal{KL}(\m{Z}^{a}_{t+1}\|\m{Z})+  \mb{E}_{\m{w}\sim \m{Z}}  \mc{L}_{\rm bal} (\tilde{f}_{\m{w}}; \mc{V}^a_t)\Big).    
\end{equation} 
        \EndFor
\State\textbf{Return:}~~ $(\m{v}_T,\g{\theta}_T,\boldsymbol\sigma_T)$.
        \end{algorithmic}
        \label{alg:FACIMS}
\end{algorithm*}



In this section, we propose a practical learning algorithm applicable to various differentiable and parametric models, including neural networks. 

We use the family of diagonal (fully factorized) Gaussian distributions as $\mc{Z}$ and learn the global informative distribution $\m{Z}$ with parameters $(\g{\theta},\g{\sigma})$. For each subgroup $A=a$, we also learn group-specific parameters $(\g{\theta}^{a},\g{\sigma}^a)$ of $\m{Z}^{a}\in\mc{Z}$. This Gaussian family is chosen for computational efficiency in optimization, but other distributions with differentiable density functions can also be used. 

Given a training set, we learn a predictor $\tilde{f}_{\m{w}}:\mc{X}\to\mc{Y}$ parameterized by $\m{w}\in\mb{R}^d$. Drawing $\tilde{f}_{\m{w}} \sim \m{Z}$ is then equivalent to sampling the model parameter $\m{w}$ from the predictive distribution $\m{Z}$. Hence, learning the distribution $\m{Z}$ is equivalent to learning its parameters $(\g{\theta},\boldsymbol\sigma)$. For each subgroup $A=a$, $\tilde{f}_{\m{w}^a} \sim \m{Z}^{a}$ is modeled similarly. Both procedures can be formulated as follows: 
\begin{align*}
&\mathbf{w} \sim \mc{N}(\g{\theta},\boldsymbol\sigma) = \prod_{i=1}^{d}\mc{N}(\g{\theta}[i],\boldsymbol\sigma[i]), \\
&\mathbf{w}^a \sim \mc{N}(\g{\theta}^a,\boldsymbol\sigma^{a}) = \prod_{i=1}^{d} \mc{N}(\g{\theta}^{a}[i],\boldsymbol\sigma^{a}[i]),~~\forall a \in \mc{A}.
\end{align*}
To encourage convergence to flat minima and help escape saddle points for minority classes, we integrate the sharpness-aware minimization (SAM) algorithm \citep{foret2020sharpness} into \eqref{eqn:in:FACIMS}. SAM is a recently introduced technique that improves generalization by jointly minimizing the loss value and the loss sharpness, leveraging the geometry of the loss landscape. Given a perturbation radius $\beta >0$ and the empirical risk $\mathcal{L}(\tilde{f}_{\m{w}}; \mc{S})$, the goal of training is to choose $\m{w}$ with low population loss $\mathcal{L} (\tilde{f}_{\m{w}};{\mathcal{P}})$. SAM achieves this by solving
\begin{align}\label{sam}
\min_{\m{w}}~~~\max_{\|\g{\epsilon}\|\leq \beta}  \mathcal{L}(\tilde{f}_{\m{w} + \g{\epsilon}};\mathcal{S}).
\tag{SAM}
\end{align}
Given $\m{w}$, the maximization in \eqref{sam} seeks the weight perturbation $\g{\epsilon}$ in the Euclidean ball of radius $\beta$ that maximizes the empirical loss. If we define the \textit{sharpness} as
\begin{equation*}
    \max_{\|\g{\epsilon}\| \leq \beta}~ \left[\mathcal{L}(\tilde{f}_{\m{w} + \g{\epsilon}};\mathcal{S})-\mathcal{L}(\tilde{f}_{\m{w}};\mathcal{S})\right],
\end{equation*}
then \eqref{sam} minimizes the sum 
of the sharpness and the empirical loss $\mathcal{L}(\tilde{f}_{\m{w}};\mathcal{S})$. 

We incorporate \eqref{sam} into \eqref{eqn:in:FACIMS} and propose \eqref{eqn:sharp:FACIMS} by introducing a set of positive constants $\{\beta^a\}_{a \in \mc{A}}$. The FACIMS framework, combined with SAM, promotes convergence to a flat minimum and aids in escaping saddle points for minority classes \citep{rangwani2022escaping}. We empirically demonstrate the superiority of integrating SAM into FACIMS over popular baselines and provide theoretical evidence suggesting improved generalization bounds. Despite the tri-level problem formulation in \eqref{eqn:sharp:FACIMS}, our algorithm design efficiently approximates the maximization step, making the computational cost comparable to that of \eqref{eqn:FACIMS}.

Based on this analysis, we present an alternating stochastic optimization algorithm for solving \eqref{eqn:sharp:FACIMS} in Algorithm~\ref{alg:FACIMS}. Line~3 of Algorithm~\ref{alg:FACIMS} implements a \textit{partial group setting}: when there are many subgroups, we randomly sample a subset $\mc{A}^{\prime}$ with $|\mc{A}^{\prime}| \ll |\mc{A}|$ at each iteration to save memory. %For further details, we refer to Appendix~\ref{app:alg:details}.   




\section{Theoretical Analysis of FACIMS}\label{sec:gen}


Next, we analyze the performance of the FACIMS method. %in terms of the convergence rate and the generalization error. 

%\subsection{Optimization analysis}
% For simplicity, we set $\beta^a =\beta$ for all $a \in\mc{A}$. We  replace $\mc{L}_{\rm bal}$ and $\mc{L}_{\rm vs}$ with $\mc{L}$. We also use $\widetilde{\nabla}\mathcal{L}$ to denote the stochastic gradients of $\mc{L}$. Define $\g{\Theta}:=( \g{\theta},\g{\sigma}, \m{v})$ and let $F$ denote the objective of \eqref{eqn:out:sharp:FACIMS}.
For simplicity, we assume $\beta^a =\beta$ for all $a \in\mc{A}$ and denote both $\mc{L}_{\rm bal}$ and $\mc{L}_{\rm vs}$ by $\mc{L}$. We write $\widetilde{\nabla}\mathcal{L}$ for the stochastic gradients of $\mc{L}$. Let $\g{\Theta}:=( \g{\theta},\g{\sigma}, \m{v}) \in \mb{R}^p $, and let $F$ denote the objective of \eqref{eqn:out:sharp:FACIMS}.

\begin{assumption}[Lipschitz continuity]\label{li-lip}
Assume that $\mathcal{L}(\cdot;\mc{V}^a)$, $\nabla\mathcal{L}(\cdot;\mc{T}^a)$, $\nabla\mathcal{L}(\cdot;\mc{V}^a)$, and $\nabla^2\mathcal{L}(\cdot;\mc{T}^a)$ are Lipschitz continuous with constants $\ell_0$, $\ell_1$, $\ell_1$, and $\ell_2$, respectively, for all $a \in \mc{A}$.
\end{assumption}

\begin{table*}[t!]
\setlength{\tabcolsep}{3pt}
\centering
\small
\caption {\label{tab:datainfo} Statistical summary of the datasets including class and sensitive feature information.}
\begin{tabular}{c|c|c|c|c|c|c} 
%\sett
\hline
\textbf{Dataset} & \textbf{\#Instance} & \textbf{\#Features} & \textbf{Class}  & \textbf{Class Distr.} & \textbf{\makecell{Sensitive \\ Feature}} & \textbf{Sensitive Feature Distr.} \\
\hline
\makecell{Alzheimer's \\ Disease} & 5137 & 17 & AD / MCI & 21\% / 79\% & Race & 93.75\% / 3.20\% / 1.88\% / 1.17\% \\ 
\hline
\makecell {Credit \\ Card} & 30,000 & 22 & \makecell{Credible / \\  Not Credible} & 22\% / 78\%  & \makecell{Education \\ Level}  & \makecell{46.77\% / 35.28\% / 16.39\% / \\0.93\% / 0.41\% / 0.17\% / 0.05\%}  \\
\hline
\makecell{Drug \\ Consumption} & 1885 & 9 & \makecell{Never used / \\ Not used in the past year / \\ Used in the past year / \\ Used in the past day} & \makecell{ 1.81\% / \\  5.41\% / \\ 65.98\% / \\  26.80\%} & \makecell{Education \\ Level} & \makecell{6.74\% / 6.90\% / 26.86\% / \\ 14.28\% / 25.48\% / 15.02\% / 4.72\% } \\
\hline
\end{tabular}
\end{table*}


\begin{assumption}[Stochastic derivatives]\label{variance}
Assume that $\widetilde\nabla\mathcal{L}(\cdot;\mc{T}^a)$, $\widetilde\nabla^2\mathcal{L}(\cdot;\mc{T}^a)$, and $\widetilde\nabla\mathcal{L}(\cdot;\mc{V}^a)$ are unbiased estimators of $\nabla\mathcal{L}(\cdot;\mc{T}^a)$, $\nabla^2\mathcal{L}(\cdot;\mc{T}^a)$, and $\nabla\mathcal{L}(\cdot;\mc{V}^a)$, respectively, and that their variances are bounded by $\sigma^2$.
\end{assumption}
Similar versions of Assumptions~\ref{li-lip}--\ref{variance} are standard in the convergence analysis of stochastic bilevel optimization \citep{chen2021closing,tarzanagh2022fednest,abbas2022sharp}.
Under these assumptions, we obtain the following theorem, whose proof is deferred to the Appendix.
\begin{thm}\label{thm:con:FACIMS}
Under Assumptions~\ref{li-lip}--\ref{variance}, with stepsizes $\gamma^{\rm up}=\gamma^{\rm low}=\mc{O}(1/\sqrt{T})$, sharpness parameter $\beta={\cal O}(1)$, and appropriately chosen constants, the iterates $\{(\g{\theta}_t,\boldsymbol\sigma_t, \m{v}_t)\}_t$ generated by Algorithm~\ref{alg:FACIMS} satisfy
\begin{equation}
\frac{1}{T}\sum_{t=1}^T\mathbb{E}\left[\left\|\nabla F( \g{\theta}_t,\g{\sigma}_t, \m{v}_t)\right\|^2\right]={\cal O}\left(\frac{1}{\sqrt{T}}\right).
\end{equation}
%where $F(\boldsymbol \g{\theta},\boldsymbol\sigma, \m{v})$ is the objective of  \eqref{eqn:in:sharp:FACIMS}.
\end{thm}

Theorem~\ref{thm:con:FACIMS} implies that, under these standard assumptions, Algorithm~\ref{alg:FACIMS} finds an $\epsilon$-stationary point of the FACIMS objective \eqref{eqn:sharp:FACIMS}, i.e., a point with $\mathbb{E}[\|\nabla F\|^2] \leq \epsilon$, within ${\cal O}(\epsilon^{-2})$ iterations and ${\cal O}(\epsilon^{-2})$ samples.

% To analyze the generalization error of FACIMS, we make similar assumptions  
% Recall the \emph{population loss} $\mathcal{L} (\cdot;{\mathcal{P}})=\mathbb{E}_{(x, y)\sim \mathcal{P}}[l(\cdot, x, y)] $. Denote the stationary point obtained by  Sharp-FACIMS algorithm as $\hat{\cdot}$. 
%
%

% \begin{defn}[\cite{hardt2016train}]
% An  algorithm $A$ is $\gamma$-uniformly stable if for all data sets $S, S^{\prime} \in Z^{n}$ such that $S$ and $S^{\prime}$ differ in at most one example, we have
% \begin{equation}
%   \sup_S \left|\mathbb{E}_{S}\left[l\left(A(S) ; x,y\right)-l\left(A\left(S^{\prime}\right) ; x,y\right)\right]\right| \leq \gamma
% \end{equation}
% where $A(S)$ and $A(S^{\prime})$ are the outputs of the algorithm $A$ given datasets $S$ and $S^{\prime}$.
% \end{defn}
% \TC{@Lisha, please define $f(\cdot, z)$?}


 

Next, we establish a generalization bound for FACIMS. 
\begin{thm}%[FACIMS]
\label{thm:out:gen:bound}
Assume the function $\mathcal{L}(\cdot)$ is bounded for any $\mathcal{S}$. %Suppose $\m{Z}^a\in\mathcal{Z}$ is any learned distribution from dataset $\m{S}^a$ and $\m{Z}\in\mathcal{Q}$ is any distribution. 
Let $F(\cdot ; \mathcal{P})=\mathbb{E}_{\mathcal{S} \sim \mathcal{P}}\left[F(\cdot ; \mathcal{S})\right]$. %where $F$ is the (upper level) objective of  \eqref{eqn:in:sharp:FACIMS}
Let $(\hat{\g{\theta}} ,\hat{\g{\sigma}}, \hat{\m{v}})$ denote a stationary point of \eqref{eqn:sharp:FACIMS}, and assume that $F(\hat{\g{\theta}} ,\hat{\g{\sigma}}, \hat{\m{v}} ; \mathcal{P}) \leq \mathbb{E}_{\g{\epsilon} \sim \mc{N}\left(0, \beta^{2} \mb{I}\right)}\left[F(\hat{\g{\theta}} +\g{\epsilon},\hat{\g{\sigma}} +\g{\epsilon}, \hat{\m{v}}; \mathcal{P})\right]$. Then, with probability at least $1-\delta$ over the choice of the training set $\mathcal{S} \sim \mathcal{P}$ with $|\mathcal{S}| = n|\mc{A}|$, we have
\begin{align}\label{eqn:gen:bnd}
\nonumber 
& \quad F(\hat{\g{\Theta}}; \mathcal{P})  \leq \max_{\|\g{\epsilon}\| \leq \beta } F( \hat{\g{\theta}} +\g{\epsilon}, \hat{\g{\sigma}} +\g{\epsilon}, \hat{\m{v}}; \mathcal{S}) \\
& \qquad + \mc{O} \left( \left(\frac{p \ln A_{\beta} +\ln \frac{1}{\delta} + \ln (n|\mc{A}|)}{n|\mc{A}|}\right)^{\frac{1}{2}}\right).
% \sqrt{\frac{p \ln \Big(1+\frac{\|\hat{\g{\Theta}}\|_{2}^{2}}{ \beta^{2}} \Big(1+\sqrt{\frac{\ln (n|\mc{A}|)}{p}}\Big)^{2} \Big)}{4n|\mc{A}|}}\\
% & + 5\sqrt{\frac{\ln \frac{1}{\delta} + \ln (n|\mc{A}|)}{4n|\mc{A}|}}.
\end{align}
% \begin{align}\label{eqn:gen:bnd}
% & \leq  \sqrt{\frac{p \ln A_{\beta} +\ln \frac{1}{\delta} + \ln (n|\mc{A}|)}{n|\mc{A}|}}% & + 5\sqrt{\frac{\ln \frac{1}{\delta} + \ln (n|\mc{A}|)}{4n|\mc{A}|}}.
% \end{align}
Here,  $A_{\beta}:=1+\frac{\|\hat{\g{\Theta}}\|_{2}^{2}}{ \beta^{2}} \Big(1+\sqrt{\frac{\ln (n|\mc{A}|)}{p}}\Big)^{2}$.% and $\varrho_{\textnormal{SGD}}$ denotes the uniform stability of SGD \citep{hardt2016train}. 
% \begin{align}\label{eqn:gen:bnd}
% \nonumber 
% & \leq  \mc{O} (\frac{p \ln (1+\frac{\|\hat{\g{\Theta}}\|_{2}^{2}}{ \beta^{2}}) + \ln \frac{1}{\delta} + \ln (n|\mc{A}|)}{n|\mc{A}|}%\\
% %& %+ \sqrt{\frac{\ln \frac{1}{\delta} + \ln (n|\mc{A}|)}{n|\mc{A}|}}.
% \end{align}

\end{thm}

Theorem~\ref{thm:out:gen:bound} shows that the gap between the population loss and the worst-case perturbed empirical loss of FACIMS is bounded by %the stability of the lower-level update $\gamma_A$ and another
$\tilde{\mathcal{O}}\big(\sqrt{p/(n|\mc{A}|)}\big)$. Note that the bound in \eqref{eqn:gen:bnd} depends on $\beta$ through $A_{\beta}$. Since $A_{\beta}\to\infty$ as $\beta \rightarrow 0$, the bound deteriorates in this limit. Hence, choosing $\beta \rightarrow 0$ (i.e., removing the sharpness-aware maximization) is not optimal. This suggests that tri-level FACIMS can generalize better than bilevel variants~\citep{li2021autobalance,shui2022learning}.

\begin{table*}[t]
\setlength{\tabcolsep}{6.4pt}
\small
\centering
\caption{\label{tab:results}Numerical results (mean ± standard deviation) over 5 runs of different methods on the Alzheimer’s disease (AD) and Credit Card (CC) datasets for six metrics and running time. Time is in the format ``hours:minutes:seconds''. FACIMS-\Romannum{1} denotes FACIMS ($\beta=0$), and FACIMS-\Romannum{2} denotes FACIMS ($\beta=0$, $\m{v}=\bar{\m{v}}$). ``$\uparrow$'' indicates that larger is better, while ``$\downarrow$'' indicates that smaller is better. The best result in each column is in bold.} 
\begin{tabular}{c|c|c|c|c|c|c|c|c}
\hline 
\textbf{Data} & \textbf{Method}                                             & \textbf{\makecell{Balanced \\ Accuracy $\uparrow$}} & \textbf{\makecell{Demographic \\ Parity $\downarrow$}} & \textbf{\makecell{Equalized \\ Odds $\downarrow$}} & \textbf{\makecell {Sufficiency \\ Gap $\downarrow$}} & \textbf{Recall 0 $\uparrow$} & \textbf{Recall 1 $\uparrow$} & \textbf{Time $\downarrow$} \\
\hline
\multirow{8}{*}{AD} & EIIL                                                        & .8639±.0199              & .0764±.0176               & .1015±.0529           & .1193±.0206            & .9288±.0119     & .7991±.0409     & 0:03:32       \\
\cline{2-9}
& FSCS                                                        & .8498±.0485              & .0711±.0287               & .1650±.1008           & .1254±.0528            & .9504±.0426     & .7493±.1018     & 0:08:05       \\
\cline{2-9}
& FAMS                                                        & .8369±.0136              & \textbf{.0431±.0210}               & .1444±.0435           & .1328±.0273            & .7624±.0077     & .9114±.0096     & 0:09:51       \\
\cline{2-9}
& ERM                                                         & .8687±.0136              & .0550±.0196               & .1143±.0390           & .1701±.0387            & \textbf{.9883±.0053}     & .7491±.0430     & \textbf{0:00:51}       \\
\cline{2-9}
& BERM                                                & .8886±.0042              & .0869±.0204               & .0813±.0129           & .1456±.0330            & .9854±.0043     & .7918±.0520     & 0:02:24       \\
\cline{2-9}
& FACIMS-\Romannum{2}   & .8839±.0079              & .0747±.0182               & .0868±.0130           & .1167±.0139            & .8456±.0148     & \textbf{.9222±.0043}     & 0:09:58       \\
\cline{2-9}
& FACIMS-\Romannum{1}                              & .8887±.0066              & .0893±.0080               & \textbf{.0450±.0049}           & .1059±.0060            & .8780±.0104     & .8994±.0148     & 0:13:38       \\
\cline{2-9}
& FACIMS                                                      & \textbf{.8897±.0098}              & .0765±.0208               & .0616±.0142           & \textbf{.1052±.0197}            & .8832±.0072     & .8962±.0054     & 0:15:26       \\
\hline
\hline
\multirow{8}{*}{CC} & EIIL                                                        & .6357±.0267              & .0834±.0200               & .1723±.0515           & .1266±.023             & .7897±.0176     & .4817±.0448     & 0:03:30       \\
\cline{2-9}
& FSCS                                                        & .5976±.0277              & .0850±.0137               & .2000±.0456           & .2007±.0039            & .8953±.0130     & .3000±.0685     & 0:42:10       \\
\cline{2-9}
& FAMS                                                        & .6542±.0098              & .0746±.0066               & .1859±.0368           & .1352±.0106            & .8194±.0374     & .4890±.0270     & 0:10:21       \\
\cline{2-9}
& ERM                                                         & .6104±.0111              & .0599±.0173               & .1577±.0175           & .2760±.0710            & \textbf{.9919±.0233}     & .2289±.0820     & \textbf{0:02:07}       \\
\cline{2-9}
& BERM                                                & .6570±.0106              & .1060±.0125               & .1631±.0304           & .2315±.0623            & .8717±.0191     & .4423±.0146     & 0:02:09       \\
\cline{2-9}
& FACIMS-\Romannum{2}  & .6446±.0163              & .0707±.0073               & .1973±.0358           & .1340±.0147            & .8002±.0374     & .4890±.0270     & 0:10:03       \\
\cline{2-9}
& FACIMS-\Romannum{1}                             & .6768±.0040              & .0750±.0105               & .1951±.0524           & .1396±.0081            & .8081±.0114     & .5455±.0098     & 0:14:07       \\
\cline{2-9}
& FACIMS                                                      & \textbf{.6799±.0374}              & \textbf{.0593±.0070}               & \textbf{.1567±.0230}           & \textbf{.1264±.0145}            & .8136±.0054     & \textbf{.5462±.0017}     & 0:14:18       \\
\hline
\end{tabular}
\end{table*}

\section{Experiments}\label{sec:exprim} 
% \hline
% Toxicity            & 448,000    & 768        & Toxic / Not Toxic       & 8\% / 92\%         & Race              & \begin{tabular}[c]{@{}c@{}}92.56\% / 2.05\% \\/~4.50\% / 0.68\% / 0.21\%\end{tabular}   \\
% \begin{itemize}
% \item [1] {\color{teal} \textbf{Sharpness }:}
% \item[2] {\color{teal} \textbf{Limited Samples}: Amazon review dataset, each client with only limited reviews. We assume the the relationship between sentiment and review text: will be similar across clients. For instance, if X = "good product," Y will be positive by 5 star or 4 star depending on the client.}
%     \item[3] {\color{teal} \textbf{Scalability with subgroup numbers:} How to use partial clients $\mc{A}' \subset \mc{A}$?  In the Amazon-review experiment, we can select $|\mc{A}|=200,\ldots 400$. We can further evaluate our method for different client numbers (appendix). The results can indicate that our method is consistently superior in different group numbers. 
%  subgroup predictors in the lower-level optimization. To address this, we could randomly sample a subset of subgroup 
%  at each training epoch and solve the bi-level objective.
%     \item[4] \textbf{Multiple protected attributes.} We are currently focusing on a single protected attribute such as gender. The algorithm could be practically extended to multi-protected attributes. Consider gender and race, then we can create subgroups solely on gender and race (i.e., group 1 = male, group 2 = female, group 3 = black, group 4 = white, with some overlaps between the subgroups).}
% \end{itemize}
\subsection{Experimental Setup} \label{sec5.1}
%
\paragraph{Datasets.} We applied our model to the Alzheimer’s disease (AD), Credit Card (CC), and Drug Consumption datasets, which are summarized in Table~\ref{tab:datainfo}.

\textbf{Alzheimer's Disease dataset}\footnote{\url{http://adni.loni.usc.edu}} was obtained from the Alzheimer's Disease Neuroimaging Initiative (ADNI) database \citep{weiner2017recent,shen2014bior}.
We included 5137 instances for binary classification, comprising 4080 mild cognitive impairment (MCI, a prodromal stage of AD) instances and 1057 AD instances. We chose race as the sensitive feature and divided the participants into four subgroups, with white subjects accounting for more than 90\%. Our 17 features comprise AD-related biomarkers, including cognitive scores, volumes of brain regions extracted from magnetic resonance imaging (MRI) scans, and amyloid and tau measurements from positron emission tomography (PET) scans and cerebrospinal fluid (CSF), as well as risk factors such as APOE4 carrier status and age.


\textbf{Credit Card dataset}\footnote{\url{https://archive.ics.uci.edu/ml/datasets/default+of+credit+card+clients}} %\citep{yeh2009comparisons} were downloaded from the UCI Machine Learning Repository \citep{Dua:2019}
contains 22 attributes, such as clients' basic information, payment history, and bill statement amounts, which are used to classify whether clients are credible or not. We included 30,000 instances with 6636 credible and 23,364 not credible clients. We chose the education level as the sensitive feature. Clients who graduated from university outnumber those in each of the other six levels.


\textbf{Drug Consumption dataset}\footnote{\url{https://archive.ics.uci.edu/dataset/373/drug+consumption+quantified}} contains demographic information such as age, gender, and education level, as well as measures of personality traits thought to influence drug use, for 1885 respondents. The task is to predict alcohol use with $K = 4$ categories (never used, not used in the past year, used in the past year, and used in the past day), a multi-class problem. The sensitive feature is education level (Left school before or at 16, Left school at 17--18, Some college, Certificate/diploma, University degree, Masters, Doctorate). The data information is summarized in Table~\ref{tab:datainfo}. As the class distribution shows, the dataset suffers from heavy label imbalance.



\paragraph{Baselines}
To validate the effectiveness of our method, FACIMS, we compare it with seven baseline methods, including two ablated variants of FACIMS.
\begin{itemize}[noitemsep,topsep=0pt,parsep=0pt,partopsep=0pt]
    \item EIIL~\citep{creager2021environment}: An Invariant Risk Minimization (IRM) based approach that promotes group sufficiency.
    \item FSCS~\citep{lee2021fair}: An approach that adopts the conditional mutual information constraint to improve group sufficiency.
    \item FAMS~\citep{shui2022learning}: A bilevel framework that considers maintaining both the accuracy and group sufficiency gap for multiple subgroups.
    \item ERM: Empirical Risk Minimization using a four-layer fully connected neural network trained with cross-entropy loss.
    \item BERM: ERM with a balanced cross-entropy loss that uses class proportions as weights, similar to~\citet{cao2019learning}.
    \item FACIMS ($\beta=0$, $\m{v}=\bar{\m{v}}$): Our method without the lower level. Additionally, in the upper level, we manually adjust the logits using the class proportions~\citep{menon2020long,kini2021label} instead of learning the logit-adjustment hyperparameters.
    \item FACIMS ($\beta=0$): Our method without the lower level, whose role is to flatten the sharp landscape of the middle-level objective.
\end{itemize}
%In order to make the comparisons fair, we use the same backbone model for all the methods, namely the four-layer fully connected neural network. 
We set both $\alpha^{\rm up}$ and $\alpha^{\rm low}$ to 0.7. The learning rates of the global model and the local models are selected from the grid $\{0.1, 0.01, 0.001\}$, and we report results over five independent runs.
% \begin{table*}
% \setlength{\tabcolsep}{4pt}
% \small
% \centering
% \caption{Classification results (Mean ± Standard Deviation ) for 5 repeats of different methods on Credit dataset.}
% \begin{tabular}{c|c|c|c|c|c|c|c}
% \hline
% \textbf{Method}                                             & \textbf{\makecell{Balanced \\ Accuracy}} & \textbf{\makecell{Demographic \\ Parity}} & \textbf{\makecell{Equalized \\ Odds}} & \textbf{\makecell{Sufficiency \\ Gap}} & \textbf{Recall 0} & \textbf{Recall 1} & \textbf{Time} \\
% \hline
% EIIL                                                        & 0.6357±0.0267              & 0.0834±0.0200               & 0.1723±0.0515           & 0.1266±0.023             & 0.7897±0.0176     & 0.4817±0.0448     & 0:03:30       \\
% \hline
% FSCS                                                        & 0.5976±0.0277              & 0.0850±0.0137               & 0.2000±0.0456           & 0.2007±0.0039            & 0.8953±0.0130     & 0.3000±0.0685     & 0:42:10       \\
% \hline
% FAMS                                                        & 0.6542±0.0098              & 0.0746±0.0066               & 0.1859±0.0368           & 0.1352±0.0106            & 0.8194±0.0374     & 0.4890±0.0270     & 0:10:21       \\
% \hline
% ERM                                                         & 0.6104±0.0111              & 0.0599±0.0173               & 0.1577±0.0175           & 0.2760±0.0710            & 0.9919±0.0233     & 0.2289±0.0820     & 0:02:07       \\
% \hline
% BERM                                                & 0.6570±0.0106              & 0.1060±0.0125               & 0.1631±0.0304           & 0.2315±0.0623            & 0.8717±0.0191     & 0.4423±0.0146     & 0:02:09       \\
% \hline
% FACIMS-\Romannum{2}  & 0.6446±0.0163              & 0.0707±0.0073               & 0.1973±0.0358           & 0.1340±0.0147            & 0.8002±0.0374     & 0.4890±0.0270     & 0:10:03       \\
% \hline
% FACIMS-\Romannum{1}                             & 0.6768±0.0040              & 0.0750±0.0105               & 0.1951±0.0524           & 0.1396±0.0081            & 0.8081±0.0114     & 0.5455±0.0098     & 0:14:07       \\
% \hline
% FACIMS                                                      & 0.6799±0.0374              & 0.0593±0.0070               & 0.1567±0.0230           & 0.1264±0.0145            & 0.8136±0.0054     & 0.5462±0.0017     & 0:14:18       \\
% \hline
% \end{tabular}
% \end{table*}

% \begin{table*}[htb]
% %\resizebox{\linewidth}{!}{
% \setlength{\tabcolsep}{4pt}
% \small
% \centering
% \caption{\label{tab:results}Classification results (Mean ± Standard Deviation ) for 5 repeats of different methods and datasets. The best one in each column for each dataset is bold.}
% \begin{tabular}{c|l|c|c|c|c}
% \hline
% Dataset                  & \makecell[c]{Method}           & Balanced Accuracy & Demographic Parity & Equalized Odds & Group Sufficiency \\
% \hline
% \multirow{6}{*}{\makecell{Alzheimer's \\Disease}} & FAMS             & 0.8369±0.0136     & \textbf{0.0431±0.0210}      & 0.1444±0.0435  & 0.1328±0.0273   \\
%                          & ERM         & 0.8687±0.0136     & 0.0550±0.0196       & 0.1143±0.0390  & 0.1701±0.0387   \\
%                          & BERM     & 0.8886±0.0042     & 0.0869±0.0204      & 0.0813±0.0129  & 0.1456±0.0330   \\
%                          & FACIMS ($\beta=0$, $\m{v}=\bar{\m{v}}$)   & 0.8839±0.0079     & 0.0747±0.0182      & 0.0868±0.0130  & 0.1167±0.0139   \\
%                          & FACIMS ($\beta=0$) & 0.8887±0.0066     & 0.0893±0.0080      & \textbf{0.0450±0.0049}  & 0.1059±0.0060   \\
%                          & FACIMS    & \textbf{0.8897±0.0098}     & 0.0765±0.0208      & 0.0616±0.0142  & \textbf{0.1052±0.0197}   \\
%                          \hline
% \multirow{6}{*}{Credit Card}          & FAMS             & 0.6542±0.0098     & 0.0746±0.0066      & 0.1859±0.0368  & 0.1352±0.0106   \\
%                          & ERM         & 0.6104±0.0111     & 0.0599±0.0173      & 0.1577±0.0175  & 0.2760±0.0710   \\
%                          & BERM     & 0.6570±0.0106     & 0.1060±0.0125      & 0.1631±0.0304  & 0.2315±0.0623   \\
%                          & FACIMS ($\beta=0$, $\m{v}=\bar{\m{v}}$)   & 0.6446±0.0163     & 0.0707±0.0073      & 0.1973±0.0358  & 0.1340±0.0147   \\
%                          & FACIMS ($\beta=0$) & 0.6768±0.0040     & 0.0750±0.0105      & 0.1951±0.0524  & 0.1396±0.0081   \\
%                          & FACIMS     & \textbf{0.6799±0.0374}     & \textbf{0.0593±0.0070}      & \textbf{0.1567±0.0230}  & \textbf{0.1264±0.0145}  \\
%                          \hline
% \end{tabular}
% %}
% \end{table*}

\subsection{Experimental Results} \label{sec5.2}

\begin{figure*}[t]
  \centering
  \begin{subfigure}[b]{0.33\textwidth}
    \includegraphics[width=\textwidth]{AD_figure_6models.png}
    \subcaption{Alzheimer's Disease}
    \label{fig:subfiga}
  \end{subfigure}
  \begin{subfigure}[b]{0.33\textwidth}
    \includegraphics[width=\textwidth]{Credit_figure_6models.png}
    \subcaption{Credit Card}
    \label{fig:subfigb}
  \end{subfigure}
  \begin{subfigure}[b]{0.33\textwidth}
    \includegraphics[width=\textwidth]{Drug_figure_6models.png}
    \subcaption{Drug Consumption}
    \label{fig:subfigc}
  \end{subfigure}
  \caption{\label{fig:results} 
Boxplots comparing balanced accuracy and group sufficiency gap on three real datasets over 5 repeats. The middle of each box marks the mean, and the box width equals twice the standard deviation. Better performance corresponds to boxes located towards the bottom right (higher balanced accuracy and lower group sufficiency gap). Two FACIMS variants are excluded for clarity; complete results are available in the appendix.}
\end{figure*}

In this section, we analyze the results on the Alzheimer's Disease and Credit Card datasets. Due to page limits, the detailed numerical results for the multi-class Drug Consumption dataset are deferred to the appendix; a summary is shown in Figure~\ref{fig:subfigc}.

\paragraph{Balanced Accuracy and Sufficiency Gap} We primarily focus on balanced accuracy and the group sufficiency gap as our main goals. Table~\ref{tab:results} shows that on the Alzheimer's Disease dataset, our method FACIMS outperforms EIIL, FSCS, FAMS, and ERM in terms of balanced accuracy, with improvements of 2.6\%, 4.0\%, 5.3\%, and 2.1\%, respectively. While BERM addresses the class imbalance issue and improves balanced accuracy over ERM by nearly 2\%, our method still achieves a higher balanced accuracy than BERM. Our method significantly reduces the group sufficiency gap, by 6.5\% and 4.0\% compared to ERM and BERM, respectively; neither of these baselines addresses this issue. Although EIIL, FSCS, and FAMS specifically target the group sufficiency problem and achieve lower sufficiency gaps than ERM and BERM, our method still outperforms these three baselines, reducing the sufficiency gap by 1.4\%, 2.0\%, and 2.8\%, respectively.

Removing the lower level ($\beta=0$) slightly degrades both the balanced accuracy and the group sufficiency gap, since the objective landscape is no longer flattened in the middle level. Additionally, manually adjusting the logits instead of learning the hyperparameters, as in~\citet{menon2020long}, degrades both metrics further. Nevertheless, thanks to our bilevel structure for addressing fairness, the group sufficiency gap remains competitive despite these drops.

On the Credit Card dataset, we observe similar results. In terms of balanced accuracy, FACIMS improves the performance by 4.4\%, 8.2\%, 2.6\%, 7.0\%, and 2.3\% compared to EIIL, FSCS, FAMS, ERM, and BERM, respectively. In terms of the group sufficiency gap, our method improves by 0.02\%, 7.4\%, 0.8\%, 15\%, and 11\% compared to the same baselines as mentioned above. The performances of FACIMS ($\beta=0$) and FACIMS ($\beta=0$, $\m{v}=\bar{\m{v}}$) drop slightly on both measures.




To provide a more intuitive visualization of the results, we present boxplots in Figure~\ref{fig:results}. Each axis represents a measure; the middle of each box marks the mean value, and the box width corresponds to twice the standard deviation. A model's performance is reflected by the position of its box, with better performance towards the bottom right corner, indicating higher balanced accuracy and a lower group sufficiency gap. For clarity, we have excluded two variants of our method from Figure~\ref{fig:results}; the complete figures can be found in the appendix. Figure~\ref{fig:results} shows that our method is positioned towards the bottom right corner, indicating improved performance compared to the other methods.

\paragraph{Results on Other Metrics} In addition to balanced accuracy and group sufficiency, we also report the demographic parity gap, the equalized odds gap, per-class recall, and runtime in Table~\ref{tab:results}. Our method achieves competitive results on these metrics, although it does not outperform all baselines in terms of the demographic parity and equalized odds gaps. We emphasize that our method primarily addresses the group sufficiency gap, and that it is challenging to optimize all three fairness measures simultaneously, as discussed in Section~\ref{sec:preliminary}. When assessing a classifier's performance, it is also important to achieve a high recall for each class: since balanced accuracy is the average of the per-class recalls, a high balanced accuracy requires recall to be high across all classes rather than only on the majority class. Our approach and its variants achieve a more balanced recall across classes, as shown in Table~\ref{tab:results}.

Regarding runtime, although our model is trained with a complex tri-level optimization framework, its total runtime is not significantly longer than that of the other fairness baselines. Indeed, using differentiable bilevel methods for the large hyperparameter search substantially reduces the cost and speeds up training compared to traditional approaches such as grid search or random search. For instance, the first variant of our method, FACIMS ($\beta=0$, $\m{v}=\bar{\m{v}}$), runs in approximately 13 minutes, whereas tuning its parametric loss with grid search or random search would take significantly longer. For example, searching over only five different settings to improve the accuracy of FACIMS ($\beta=0$, $\m{v}=\bar{\m{v}}$) would take $5 \times 13\,\text{min} = 65\,\text{min}$ in total, which is around four times longer than our differentiable tri-level FACIMS approach.


\paragraph{Influence of $\alpha^{\rm low}$}
In the middle level, the parameter $\alpha^{\rm low}$ controls the weight given to $\textnormal{KL}(\m{Z}^a|\m{Z})$. A higher value of $\alpha^{\rm low}$ brings the local model closer to the global model, leading to a smaller group sufficiency gap but potentially lower balanced accuracy. We experimented with four values of $\alpha^{\rm low}$: 0.01, 0.1, 0.2, and 1. Figure~\ref{fig:gen:bound} illustrates the Accuracy-$\textbf{SGap}_{f}$ curve under varying $\alpha^{\rm low}$ on the Alzheimer's Disease dataset. The figure shows a clear trend: as $\alpha^{\rm low}$ increases, both the balanced accuracy and the group sufficiency gap decrease, in line with our expectations. This analysis provides insight into how the KL divergence in the middle level trades off the group sufficiency gap against balanced accuracy, enhancing our understanding of the framework's mechanism.
%
\begin{figure}
\centering
\includegraphics[width=0.43\textwidth]{lambda.png}
\caption{Accuracy-$\textbf{SGap}_{f}$ curve under different values of $\alpha^{\rm low}$ on the Alzheimer's Disease dataset.}
\label{fig:gen:bound}
%\vspace{-0.3cm}
\end{figure}
%
\section{Related Work}
\subsection{Long-tailed Learning} Re-sampling \citep{buda2018systematic} and re-weighting \citep{he2009learning} are commonly used methods for training on imbalanced datasets. Recent studies focus on optimizing loss landscapes for class-imbalanced datasets \citep{khan2017cost, cao2019learning, menon2020long, ye2020identifying, li2021autobalance, kini2021label, behnia2023implicit, thrampoulidis2022imbalance}. Our work is related to the long-tailed learning literature \citep{cao2019learning, menon2020long, ye2020identifying, kini2021label}, where the authors propose refined class-balanced loss functions that better adapt to the training data. These include the logit-adjusted loss \citep{menon2020long, cao2019learning}, the class-dependent temperature loss \citep{ye2020identifying}, and the VS loss \citep{kini2021label}, which unifies the concepts of multiplicative shift, additive shift, and loss re-weighting.

\subsection{Nested Optimization} Nested optimization involves solving hierarchical problems with multiple levels of optimization \citep{colson2007overview,tarzanagh2022fednest,chen2021closing,ji2021bilevel,tarzanagh2022online}. Min-max nested optimization is commonly used to learn fair representations in the context of demographic parity or equalized odds \citep{zemel2013learning,song2019learning,zhao2019conditional}. Bilevel optimization and meta-learning algorithms have also been explored in the context of fair learning and classification \citep{shui2022fair,abbas2022sharp}. Recent advances in differentiable algorithms have led to faster bilevel algorithms for learning hyperparameters and classification \citep{li2021autobalance,lorraine2020optimizing,tarzanagh2022fednest,chen2021closing,ji2021bilevel}. Building on \citet{li2021autobalance,abbas2022sharp}, we propose a theoretically justified \textit{tri-level} optimization perspective to control the group sufficiency gap and improve generalization performance across multiple subgroups with limited samples.

\subsection{Fairness} Group-sensitive learning aims to ensure fairness in the presence of under-represented groups~\citep{lin2023evaluate,zafar2017fairness,tarzanagh2021fair,chierichetti2017fair}. Our work mainly focuses on the fairness notion of group sufficiency. This notion has recently been studied in population health \citep{doi:10.1126/science.aax2342} and crime prediction \citep{chouldechova2017fair,pleiss2017fairness}. \citet{pmlr-v97-liu19f} show that, under some assumptions, group sufficiency can be controlled in unconstrained learning. On the other hand, \citet{doi:10.1126/science.aax2342,Shui2022FairRL,koh2021wilds} argue that this conclusion may not hold for overparameterized models with limited samples per group. \citet{subramanian2021fairness} provided a method for fair and class-imbalanced learning. \citet{lee2021fair} proposed a bilevel objective approach to achieve fairness in predictive models across all groups. In contrast, our tri-level algorithm incorporates a Bayesian framework for imbalanced learning, considering both class imbalance and subgroup distribution within each class, while also employing a nested optimization akin to SAM to overcome saddle points for minority classes.

\section{Conclusions}
We studied fairness-aware class imbalanced learning on multiple subgroups (FACIMS) using a Bayesian-based optimization framework. Through extensive empirical and theoretical analysis, we demonstrated that FACIMS enhances the generalization performance of overparameterized models when dealing with limited samples per subgroup.
%
\section*{Acknowledgements}
%
This work was supported in part by the NIH grants U01 AG066833, RF1 AG063481, U01 AG068057, R01 LM013463, P30 AG073105, and U01 CA274576, and the NSF grant IIS 1837964.
%
Data used in this study were obtained from the Alzheimer's Disease Neuroimaging Initiative database (\url{adni.loni.usc.edu}), which was funded by NIH U01 AG024904. The authors Davoud Ataee Tarzanagh, Bojian Hou, and Boning Tong contributed equally to this paper.

\bibliography{tarzanagh_561}



\end{document}
