\documentclass{midl} % Include author names
% \documentclass[anon]{midl} % Anonymized submission
% \usepackage[utf8]{inputenc} % allow utf-8 input
% \usepackage[T1]{fontenc}    % use 8-bit T1 fonts
\usepackage{hyperref}       % hyperlinks
\usepackage{url}            % simple URL typesetting
\usepackage{booktabs}       % professional-quality tables
\usepackage{nicefrac}       % compact symbols for 1/2, etc.
\usepackage{microtype}      % microtypography

\usepackage{amsmath}
\usepackage{amsfonts}
\usepackage{amssymb}
\usepackage{graphicx}
% \usepackage{subfig}
\usepackage{float}

\usepackage{gensymb}

% Optional math commands from https://github.com/goodfeli/dlbook_notation.
% \input{math_commands.tex}
%%%%% NEW MATH DEFINITIONS %%%%%

% \usepackage{amsmath,amsfonts,bm}

% Mark sections of captions for referring to divisions of figures.
% \emph is the semantic replacement for the obsolete {\em ...} switch; it
% still toggles between italic and upright depending on the surrounding font.
\newcommand{\figleft}{\emph{(Left)}}
\newcommand{\figcenter}{\emph{(Center)}}
\newcommand{\figright}{\emph{(Right)}}
\newcommand{\figtop}{\emph{(Top)}}
\newcommand{\figbottom}{\emph{(Bottom)}}
\newcommand{\captiona}{\emph{(a)}}
\newcommand{\captionb}{\emph{(b)}}
\newcommand{\captionc}{\emph{(c)}}
\newcommand{\captiond}{\emph{(d)}}

% Highlight a newly defined term (bold face; \textbf replaces obsolete {\bf ...}).
\newcommand{\newterm}[1]{\textbf{#1}}


% Cross-reference helpers: each macro prints the object's name followed by its
% \ref number. A tie (~) before every \ref keeps the name and number on the
% same line.
% Figure reference, lower-case.
\def\figref#1{figure~\ref{#1}}
% Figure reference, capital. For start of sentence
\def\Figref#1{Figure~\ref{#1}}
\def\twofigref#1#2{figures~\ref{#1} and~\ref{#2}}
\def\quadfigref#1#2#3#4{figures~\ref{#1}, \ref{#2}, \ref{#3} and~\ref{#4}}
% Section reference, lower-case.
\def\secref#1{section~\ref{#1}}
% Section reference, capital.
\def\Secref#1{Section~\ref{#1}}
% Reference to two sections.
\def\twosecrefs#1#2{sections~\ref{#1} and~\ref{#2}}
% Reference to three sections.
\def\secrefs#1#2#3{sections~\ref{#1}, \ref{#2} and~\ref{#3}}
% Reference to an equation, lower-case.
% NOTE(review): amsmath (loaded above) already defines \eqref as "(N)"; this
% \def replaces it so \eqref prints "equation~N" instead.
\def\eqref#1{equation~\ref{#1}}
% Reference to an equation, upper case
\def\Eqref#1{Equation~\ref{#1}}
% A raw reference to an equation---avoid using if possible
\def\plaineqref#1{\ref{#1}}
% Reference to a chapter, lower-case.
\def\chapref#1{chapter~\ref{#1}}
% Reference to a chapter, upper case.
\def\Chapref#1{Chapter~\ref{#1}}
% Reference to a range of chapters (the tie prevents "chapters1--3").
\def\rangechapref#1#2{chapters~\ref{#1}--\ref{#2}}
% Reference to an algorithm, lower-case.
\def\algref#1{algorithm~\ref{#1}}
% Reference to an algorithm, upper case.
\def\Algref#1{Algorithm~\ref{#1}}
\def\twoalgref#1#2{algorithms~\ref{#1} and~\ref{#2}}
\def\Twoalgref#1#2{Algorithms~\ref{#1} and~\ref{#2}}
% Reference to a part, lower case
\def\partref#1{part~\ref{#1}}
% Reference to a part, upper case
\def\Partref#1{Part~\ref{#1}}
\def\twopartref#1#2{parts~\ref{#1} and~\ref{#2}}

% Miscellaneous notation helpers.
\def\ceil#1{\lceil #1 \rceil}
\def\floor#1{\lfloor #1 \rfloor}
% Bold one. NOTE(review): \bm comes from the bm package, whose \usepackage
% line in this preamble is commented out -- confirm the class provides it
% before using \1.
\def\1{\bm{1}}
% Dataset symbols: only the D is calligraphic; the upright label is attached
% as a subscript outside \mathcal.
\newcommand{\train}{\mathcal{D}}
\newcommand{\valid}{\mathcal{D}_{\mathrm{valid}}}
\newcommand{\test}{\mathcal{D}_{\mathrm{test}}}

\def\eps{{\epsilon}}


% Random variables
% Each \r<letter> expands to that letter set with \textnormal (the document's
% normal text font), inside an extra brace group; \reta does the same for a
% math-mode \eta.
\def\reta{{\textnormal{$\eta$}}}
\def\ra{{\textnormal{a}}}
\def\rb{{\textnormal{b}}}
\def\rc{{\textnormal{c}}}
\def\rd{{\textnormal{d}}}
\def\re{{\textnormal{e}}}
\def\rf{{\textnormal{f}}}
\def\rg{{\textnormal{g}}}
\def\rh{{\textnormal{h}}}
\def\ri{{\textnormal{i}}}
\def\rj{{\textnormal{j}}}
\def\rk{{\textnormal{k}}}
\def\rl{{\textnormal{l}}}
% rm is already a command, just don't name any random variables m
\def\rn{{\textnormal{n}}}
\def\ro{{\textnormal{o}}}
\def\rp{{\textnormal{p}}}
% NOTE(review): \rq is also a LaTeX kernel command (right quote); this \def
% silently replaces it. Confirm nothing in the document relies on the kernel
% meaning.
\def\rq{{\textnormal{q}}}
\def\rr{{\textnormal{r}}}
\def\rs{{\textnormal{s}}}
\def\rt{{\textnormal{t}}}
\def\ru{{\textnormal{u}}}
\def\rv{{\textnormal{v}}}
\def\rw{{\textnormal{w}}}
\def\rx{{\textnormal{x}}}
\def\ry{{\textnormal{y}}}
\def\rz{{\textnormal{z}}}

% Random vectors
% Each \rv<letter> expands to the bold upright (\mathbf) letter, inside an
% extra brace group.
% NOTE(review): with standard math fonts \mathbf does not embolden lowercase
% Greek, so \rvepsilon and \rvtheta may render as ordinary italic Greek --
% confirm the intended appearance (a \bm-based form would need the bm package).
\def\rvepsilon{{\mathbf{\epsilon}}}
\def\rvtheta{{\mathbf{\theta}}}
\def\rva{{\mathbf{a}}}
\def\rvb{{\mathbf{b}}}
\def\rvc{{\mathbf{c}}}
\def\rvd{{\mathbf{d}}}
\def\rve{{\mathbf{e}}}
\def\rvf{{\mathbf{f}}}
\def\rvg{{\mathbf{g}}}
\def\rvh{{\mathbf{h}}}
\def\rvi{{\mathbf{i}}}
\def\rvj{{\mathbf{j}}}
\def\rvk{{\mathbf{k}}}
\def\rvl{{\mathbf{l}}}
\def\rvm{{\mathbf{m}}}
\def\rvn{{\mathbf{n}}}
\def\rvo{{\mathbf{o}}}
\def\rvp{{\mathbf{p}}}
\def\rvq{{\mathbf{q}}}
\def\rvr{{\mathbf{r}}}
\def\rvs{{\mathbf{s}}}
\def\rvt{{\mathbf{t}}}
\def\rvu{{\mathbf{u}}}
\def\rvv{{\mathbf{v}}}
\def\rvw{{\mathbf{w}}}
\def\rvx{{\mathbf{x}}}
\def\rvy{{\mathbf{y}}}
\def\rvz{{\mathbf{z}}}

% Elements of random vectors
% Each \erv<letter> expands to the letter set with \textnormal inside an extra
% brace group. Unlike the random-variable list, this list also defines \ervm.
\def\erva{{\textnormal{a}}}
\def\ervb{{\textnormal{b}}}
\def\ervc{{\textnormal{c}}}
\def\ervd{{\textnormal{d}}}
\def\erve{{\textnormal{e}}}
\def\ervf{{\textnormal{f}}}
\def\ervg{{\textnormal{g}}}
\def\ervh{{\textnormal{h}}}
\def\ervi{{\textnormal{i}}}
\def\ervj{{\textnormal{j}}}
\def\ervk{{\textnormal{k}}}
\def\ervl{{\textnormal{l}}}
\def\ervm{{\textnormal{m}}}
\def\ervn{{\textnormal{n}}}
\def\ervo{{\textnormal{o}}}
\def\ervp{{\textnormal{p}}}
\def\ervq{{\textnormal{q}}}
\def\ervr{{\textnormal{r}}}
\def\ervs{{\textnormal{s}}}
\def\ervt{{\textnormal{t}}}
\def\ervu{{\textnormal{u}}}
\def\ervv{{\textnormal{v}}}
\def\ervw{{\textnormal{w}}}
\def\ervx{{\textnormal{x}}}
\def\ervy{{\textnormal{y}}}
\def\ervz{{\textnormal{z}}}

% Random matrices
% Each \rm<Letter> expands to the bold upright (\mathbf) capital letter inside
% an extra brace group. These are separate control sequences from the \rm
% font switch.
\def\rmA{{\mathbf{A}}}
\def\rmB{{\mathbf{B}}}
\def\rmC{{\mathbf{C}}}
\def\rmD{{\mathbf{D}}}
\def\rmE{{\mathbf{E}}}
\def\rmF{{\mathbf{F}}}
\def\rmG{{\mathbf{G}}}
\def\rmH{{\mathbf{H}}}
\def\rmI{{\mathbf{I}}}
\def\rmJ{{\mathbf{J}}}
\def\rmK{{\mathbf{K}}}
\def\rmL{{\mathbf{L}}}
\def\rmM{{\mathbf{M}}}
\def\rmN{{\mathbf{N}}}
\def\rmO{{\mathbf{O}}}
\def\rmP{{\mathbf{P}}}
\def\rmQ{{\mathbf{Q}}}
\def\rmR{{\mathbf{R}}}
\def\rmS{{\mathbf{S}}}
\def\rmT{{\mathbf{T}}}
\def\rmU{{\mathbf{U}}}
\def\rmV{{\mathbf{V}}}
\def\rmW{{\mathbf{W}}}
\def\rmX{{\mathbf{X}}}
\def\rmY{{\mathbf{Y}}}
\def\rmZ{{\mathbf{Z}}}

% Elements of random matrices
% Each \erm<Letter> expands to the capital letter set with \textnormal inside an
% extra brace group.
\def\ermA{{\textnormal{A}}}
\def\ermB{{\textnormal{B}}}
\def\ermC{{\textnormal{C}}}
\def\ermD{{\textnormal{D}}}
\def\ermE{{\textnormal{E}}}
\def\ermF{{\textnormal{F}}}
\def\ermG{{\textnormal{G}}}
\def\ermH{{\textnormal{H}}}
\def\ermI{{\textnormal{I}}}
\def\ermJ{{\textnormal{J}}}
\def\ermK{{\textnormal{K}}}
\def\ermL{{\textnormal{L}}}
\def\ermM{{\textnormal{M}}}
\def\ermN{{\textnormal{N}}}
\def\ermO{{\textnormal{O}}}
\def\ermP{{\textnormal{P}}}
\def\ermQ{{\textnormal{Q}}}
\def\ermR{{\textnormal{R}}}
\def\ermS{{\textnormal{S}}}
\def\ermT{{\textnormal{T}}}
\def\ermU{{\textnormal{U}}}
\def\ermV{{\textnormal{V}}}
\def\ermW{{\textnormal{W}}}
\def\ermX{{\textnormal{X}}}
\def\ermY{{\textnormal{Y}}}
\def\ermZ{{\textnormal{Z}}}

% Vectors
% Each \v<x> expands to \bm applied to the symbol, inside an extra brace group.
% NOTE(review): \bm comes from the bm package, whose \usepackage line in this
% preamble is commented out -- confirm the document class provides it before
% using these macros.
\def\vzero{{\bm{0}}}
\def\vone{{\bm{1}}}
\def\vmu{{\bm{\mu}}}
\def\vtheta{{\bm{\theta}}}
\def\va{{\bm{a}}}
\def\vb{{\bm{b}}}
\def\vc{{\bm{c}}}
\def\vd{{\bm{d}}}
\def\ve{{\bm{e}}}
\def\vf{{\bm{f}}}
\def\vg{{\bm{g}}}
\def\vh{{\bm{h}}}
\def\vi{{\bm{i}}}
\def\vj{{\bm{j}}}
\def\vk{{\bm{k}}}
\def\vl{{\bm{l}}}
\def\vm{{\bm{m}}}
\def\vn{{\bm{n}}}
\def\vo{{\bm{o}}}
\def\vp{{\bm{p}}}
\def\vq{{\bm{q}}}
\def\vr{{\bm{r}}}
\def\vs{{\bm{s}}}
\def\vt{{\bm{t}}}
\def\vu{{\bm{u}}}
\def\vv{{\bm{v}}}
\def\vw{{\bm{w}}}
\def\vx{{\bm{x}}}
\def\vy{{\bm{y}}}
\def\vz{{\bm{z}}}

% Elements of vectors
% Plain symbols with no font change (e.g. \eva is just a in the default math
% font), inside an extra brace group.
\def\evalpha{{\alpha}}
\def\evbeta{{\beta}}
\def\evepsilon{{\epsilon}}
\def\evlambda{{\lambda}}
\def\evomega{{\omega}}
\def\evmu{{\mu}}
\def\evpsi{{\psi}}
\def\evsigma{{\sigma}}
\def\evtheta{{\theta}}
\def\eva{{a}}
\def\evb{{b}}
\def\evc{{c}}
\def\evd{{d}}
\def\eve{{e}}
\def\evf{{f}}
\def\evg{{g}}
\def\evh{{h}}
\def\evi{{i}}
\def\evj{{j}}
\def\evk{{k}}
\def\evl{{l}}
\def\evm{{m}}
\def\evn{{n}}
\def\evo{{o}}
\def\evp{{p}}
\def\evq{{q}}
\def\evr{{r}}
\def\evs{{s}}
\def\evt{{t}}
\def\evu{{u}}
\def\evv{{v}}
\def\evw{{w}}
\def\evx{{x}}
\def\evy{{y}}
\def\evz{{z}}

% Matrix
% Each \m<Letter> expands to \bm applied to the capital letter, plus the bold
% Greek capitals \mBeta (\beta), \mPhi, \mLambda and \mSigma.
% NOTE(review): \bm comes from the bm package, whose \usepackage line in this
% preamble is commented out -- confirm the document class provides it.
% NOTE(review): \mBeta applies \bm to lowercase \beta despite its capitalized
% name -- confirm this is intended.
\def\mA{{\bm{A}}}
\def\mB{{\bm{B}}}
\def\mC{{\bm{C}}}
\def\mD{{\bm{D}}}
\def\mE{{\bm{E}}}
\def\mF{{\bm{F}}}
\def\mG{{\bm{G}}}
\def\mH{{\bm{H}}}
\def\mI{{\bm{I}}}
\def\mJ{{\bm{J}}}
\def\mK{{\bm{K}}}
\def\mL{{\bm{L}}}
\def\mM{{\bm{M}}}
\def\mN{{\bm{N}}}
\def\mO{{\bm{O}}}
\def\mP{{\bm{P}}}
\def\mQ{{\bm{Q}}}
\def\mR{{\bm{R}}}
\def\mS{{\bm{S}}}
\def\mT{{\bm{T}}}
\def\mU{{\bm{U}}}
\def\mV{{\bm{V}}}
\def\mW{{\bm{W}}}
\def\mX{{\bm{X}}}
\def\mY{{\bm{Y}}}
\def\mZ{{\bm{Z}}}
\def\mBeta{{\bm{\beta}}}
\def\mPhi{{\bm{\Phi}}}
\def\mLambda{{\bm{\Lambda}}}
\def\mSigma{{\bm{\Sigma}}}

% Tensor
% \mathsfit: a math alphabet using the sans-serif family (\sfdefault) in the
% current encoding, medium weight, slanted shape. In the bold math version it
% is reset to bold-extended, upright shape (bx/n).
\DeclareMathAlphabet{\mathsfit}{\encodingdefault}{\sfdefault}{m}{sl}
\SetMathAlphabet{\mathsfit}{bold}{\encodingdefault}{\sfdefault}{bx}{n}
% \tens{X}: \bm applied to the \mathsfit letter. NOTE(review): \bm needs the bm
% package, whose \usepackage line here is commented out -- confirm the class
% provides it.
\newcommand{\tens}[1]{\bm{\mathsfit{#1}}}
\def\tA{{\tens{A}}}
\def\tB{{\tens{B}}}
\def\tC{{\tens{C}}}
\def\tD{{\tens{D}}}
\def\tE{{\tens{E}}}
\def\tF{{\tens{F}}}
\def\tG{{\tens{G}}}
\def\tH{{\tens{H}}}
\def\tI{{\tens{I}}}
\def\tJ{{\tens{J}}}
\def\tK{{\tens{K}}}
\def\tL{{\tens{L}}}
\def\tM{{\tens{M}}}
\def\tN{{\tens{N}}}
\def\tO{{\tens{O}}}
\def\tP{{\tens{P}}}
\def\tQ{{\tens{Q}}}
\def\tR{{\tens{R}}}
\def\tS{{\tens{S}}}
\def\tT{{\tens{T}}}
\def\tU{{\tens{U}}}
\def\tV{{\tens{V}}}
\def\tW{{\tens{W}}}
\def\tX{{\tens{X}}}
\def\tY{{\tens{Y}}}
\def\tZ{{\tens{Z}}}


% Graph
% Each \g<Letter> expands to the calligraphic (\mathcal) capital letter inside
% an extra brace group.
\def\gA{{\mathcal{A}}}
\def\gB{{\mathcal{B}}}
\def\gC{{\mathcal{C}}}
\def\gD{{\mathcal{D}}}
\def\gE{{\mathcal{E}}}
\def\gF{{\mathcal{F}}}
\def\gG{{\mathcal{G}}}
\def\gH{{\mathcal{H}}}
\def\gI{{\mathcal{I}}}
\def\gJ{{\mathcal{J}}}
\def\gK{{\mathcal{K}}}
\def\gL{{\mathcal{L}}}
\def\gM{{\mathcal{M}}}
\def\gN{{\mathcal{N}}}
\def\gO{{\mathcal{O}}}
\def\gP{{\mathcal{P}}}
\def\gQ{{\mathcal{Q}}}
\def\gR{{\mathcal{R}}}
\def\gS{{\mathcal{S}}}
\def\gT{{\mathcal{T}}}
\def\gU{{\mathcal{U}}}
\def\gV{{\mathcal{V}}}
\def\gW{{\mathcal{W}}}
\def\gX{{\mathcal{X}}}
\def\gY{{\mathcal{Y}}}
\def\gZ{{\mathcal{Z}}}

% Sets
% Each \s<Letter> expands to the blackboard-bold (\mathbb, from the
% amsfonts/amssymb packages loaded above) capital letter inside an extra brace
% group.
\def\sA{{\mathbb{A}}}
\def\sB{{\mathbb{B}}}
\def\sC{{\mathbb{C}}}
\def\sD{{\mathbb{D}}}
% Don't use a set called E, because this would be the same as our symbol
% for expectation.
\def\sF{{\mathbb{F}}}
\def\sG{{\mathbb{G}}}
\def\sH{{\mathbb{H}}}
\def\sI{{\mathbb{I}}}
\def\sJ{{\mathbb{J}}}
\def\sK{{\mathbb{K}}}
\def\sL{{\mathbb{L}}}
\def\sM{{\mathbb{M}}}
\def\sN{{\mathbb{N}}}
\def\sO{{\mathbb{O}}}
\def\sP{{\mathbb{P}}}
\def\sQ{{\mathbb{Q}}}
\def\sR{{\mathbb{R}}}
\def\sS{{\mathbb{S}}}
\def\sT{{\mathbb{T}}}
\def\sU{{\mathbb{U}}}
\def\sV{{\mathbb{V}}}
\def\sW{{\mathbb{W}}}
\def\sX{{\mathbb{X}}}
\def\sY{{\mathbb{Y}}}
\def\sZ{{\mathbb{Z}}}

% Entries of a matrix
% Plain capital letters with no font change, plus the Greek capitals \Lambda
% and \Sigma, inside an extra brace group. The \em<X> names are separate control
% sequences from the \em font switch.
\def\emLambda{{\Lambda}}
\def\emA{{A}}
\def\emB{{B}}
\def\emC{{C}}
\def\emD{{D}}
\def\emE{{E}}
\def\emF{{F}}
\def\emG{{G}}
\def\emH{{H}}
\def\emI{{I}}
\def\emJ{{J}}
\def\emK{{K}}
\def\emL{{L}}
\def\emM{{M}}
\def\emN{{N}}
\def\emO{{O}}
\def\emP{{P}}
\def\emQ{{Q}}
\def\emR{{R}}
\def\emS{{S}}
\def\emT{{T}}
\def\emU{{U}}
\def\emV{{V}}
\def\emW{{W}}
\def\emX{{X}}
\def\emY{{Y}}
\def\emZ{{Z}}
\def\emSigma{{\Sigma}}

% entries of a tensor
% Same font as tensor, without \bm wrapper
% \etens{X} applies the \mathsfit alphabet declared in the Tensor section above.
\newcommand{\etens}[1]{\mathsfit{#1}}
\def\etLambda{{\etens{\Lambda}}}
\def\etA{{\etens{A}}}
\def\etB{{\etens{B}}}
\def\etC{{\etens{C}}}
\def\etD{{\etens{D}}}
\def\etE{{\etens{E}}}
\def\etF{{\etens{F}}}
\def\etG{{\etens{G}}}
\def\etH{{\etens{H}}}
\def\etI{{\etens{I}}}
\def\etJ{{\etens{J}}}
\def\etK{{\etens{K}}}
\def\etL{{\etens{L}}}
\def\etM{{\etens{M}}}
\def\etN{{\etens{N}}}
\def\etO{{\etens{O}}}
\def\etP{{\etens{P}}}
\def\etQ{{\etens{Q}}}
\def\etR{{\etens{R}}}
\def\etS{{\etens{S}}}
\def\etT{{\etens{T}}}
\def\etU{{\etens{U}}}
\def\etV{{\etens{V}}}
\def\etW{{\etens{W}}}
\def\etX{{\etens{X}}}
\def\etY{{\etens{Y}}}
\def\etZ{{\etens{Z}}}

% Distribution shorthands. The subscript labels use \mathrm (the obsolete
% \rm switch is avoided) so they are set upright as names, not variables.
% The true underlying data generating distribution
\newcommand{\pdata}{p_{\mathrm{data}}}
% The empirical distribution defined by the training set
\newcommand{\ptrain}{\hat{p}_{\mathrm{data}}}
\newcommand{\Ptrain}{\hat{P}_{\mathrm{data}}}
% The model distribution
\newcommand{\pmodel}{p_{\mathrm{model}}}
\newcommand{\Pmodel}{P_{\mathrm{model}}}
\newcommand{\ptildemodel}{\tilde{p}_{\mathrm{model}}}
% Stochastic autoencoder distributions
\newcommand{\pencode}{p_{\mathrm{encoder}}}
\newcommand{\pdecode}{p_{\mathrm{decoder}}}
\newcommand{\precons}{p_{\mathrm{reconstruct}}}

\newcommand{\laplace}{\mathrm{Laplace}} % Laplace distribution

% General math shorthands.
\newcommand{\E}{\mathbb{E}}
\newcommand{\Ls}{\mathcal{L}}
\newcommand{\R}{\mathbb{R}}
\newcommand{\emp}{\tilde{p}}
\newcommand{\lr}{\alpha}
\newcommand{\reg}{\lambda}
\newcommand{\rect}{\mathrm{rectifier}}
\newcommand{\softmax}{\mathrm{softmax}}
\newcommand{\sigmoid}{\sigma}
\newcommand{\softplus}{\zeta}
\newcommand{\KL}{D_{\mathrm{KL}}}
\newcommand{\Var}{\mathrm{Var}}
\newcommand{\standarderror}{\mathrm{SE}}
\newcommand{\Cov}{\mathrm{Cov}}
% Wolfram Mathworld says $L^2$ is for function spaces and $\ell^2$ is for vectors
% But then they seem to use $L^2$ for vectors throughout the site, and so does
% wikipedia.
\newcommand{\normlzero}{L^0}
\newcommand{\normlone}{L^1}
\newcommand{\normltwo}{L^2}
\newcommand{\normlp}{L^p}
\newcommand{\normmax}{L^\infty}

% See usage in notation.tex. Chosen to match Daphne's book.
% Set upright so the two letters read as one name rather than the italic
% product P*a; \ensuremath keeps the macro usable in both text and math.
\newcommand{\parents}{\ensuremath{\mathrm{Pa}}}

\DeclareMathOperator*{\argmax}{arg\,max}
\DeclareMathOperator*{\argmin}{arg\,min}

\DeclareMathOperator{\sign}{sign}
\DeclareMathOperator{\Tr}{Tr}
% Short alias for \allowbreak.
\let\ab\allowbreak




\usepackage{tikz}
\usetikzlibrary{shapes,arrows,fit}
\usetikzlibrary{positioning, shapes.geometric}
% TikZ styles (node shapes and arrow types), presumably for the graphical-model
% figure. \tikzset{<name>/.style={...}} replaces the deprecated
% \tikzstyle{<name>}=[...] syntax with identical keys.
\tikzset{
  node/.style={circle, minimum size=10mm, thick, draw=black!80},
  nodeobserved/.style={circle, minimum size=10mm, thick, draw=black!80, fill=gray!30},
  semi1/.style={circle, minimum size=10mm, draw=black!80, thick},
  semi2/.style={semicircle, fill=gray!30, minimum size=5mm, rotate=90},
  box/.style={rectangle, draw=black!0},
  arrow/.style={thick, ->, >=stealth},
  arrow2/.style={dashed, ->, >=stealth},
}
% The following packages will be automatically loaded:
% jmlr, amsmath, amssymb, natbib, graphicx, url, algorithm2e
% ifoddpage, relsize and probably more
% make sure they are installed with your latex distribution

\usepackage{mwe} % to get dummy images
% Front-matter metadata for the jmlr-based midl class: volume, year, workshop
% and editor strings for the under-review version of the paper.
\jmlrvolume{-- Under Review}
\jmlryear{2020}
\jmlrworkshop{Full Paper -- MIDL 2020 submission}
\editors{Under Review for MIDL 2020}

% Paper title; the optional [..] argument is presumably the short (running
% head) title -- confirm against the midl/jmlr class documentation.
\title[DIVA: Domain Invariant Variational Autoencoders]{DIVA: Domain Invariant Variational Autoencoders}

 % Use \Name{Author Name} to specify the name.
 % If the surname contains spaces, enclose the surname
 % in braces, e.g. \Name{John {Smith Jones}} similarly
 % if the name has a "von" part, e.g \Name{Jane {de Winter}}.
 % If the first letter in the forenames is a diacritic
 % enclose the diacritic in braces, e.g. \Name{{\'E}louise Smith}

 % Two authors with the same address
 % \midlauthor{\Name{Author Name1} \Email{abc@sample.edu}\and
 %  \Name{Author Name2} \Email{xyz@sample.edu}\\
 %  \addr Address}

 % Three or more authors with the same address:
 % \midlauthor{\Name{Author Name1} \Email{an1@sample.edu}\\
 %  \Name{Author Name2} \Email{an2@sample.edu}\\
 %  \Name{Author Name3} \Email{an3@sample.edu}\\
 %  \addr Address}


% Authors with different addresses:
% \midlauthor{\Name{Author Name1} \Email{abc@sample.edu}\\
% \addr Address 1
% \AND
% \Name{Author Name2} \Email{xyz@sample.edu}\\
% \addr Address 2
% }

%\footnotetext[1]{Contributed equally}

% More complicate cases, e.g. with dual affiliations and joint authorship
% All four authors share affiliation 1. \nametag follows each surname directly
% (no space) so the superscript is set identically for every author.
\midlauthor{\Name{Maximilian Ilse\nametag{$^{1}$}} \Email{m.ilse@uva.nl}\\
\Name{Jakub M. Tomczak\nametag{$^{1}$}} \Email{j.m.tomczak@uva.nl}\\
\Name{Christos Louizos\nametag{$^{1}$}} \Email{c.louizos@uva.nl}\\
\Name{Max Welling\nametag{$^{1}$}} \Email{m.welling@uva.nl}\\
\addr $^{1}$ University of Amsterdam \\
}

% \author{Maximilian Ilse$^{\dagger}$, Jakub M. Tomczak$^{\dagger}$, Christos Louizos$^{\dagger,\ddagger}$ \& Max Welling$^{\dagger, c}$\\
% $^\dagger$Amsterdam Machine Learning Lab, University of Amsterdam\\
% $^\ddagger$TNO, Intelligent Imaging\\
% $^c$CIFAR\\
% \texttt{\{m.ilse,j.m.tomczak,c.louizos,m.welling\}@uva.nl} \\
% }

\begin{document}

\maketitle

\begin{abstract}
  We consider the problem of domain generalization, namely, how to learn representations given data from a set of domains that generalize to data from a previously unseen domain. We propose the Domain Invariant Variational Autoencoder (DIVA), a generative model that tackles this problem by learning three independent latent subspaces, one for the domain, one for the class, and one for any residual variations. We highlight that, due to the generative nature of our model, we can also incorporate unlabeled data from known or previously unseen domains. To the best of our knowledge, this has not been done before in a domain generalization setting. This property is highly desirable in fields like medical imaging, where labeled data is scarce. We experimentally evaluate our model on the rotated MNIST benchmark and a malaria cell images dataset, where we show that (i) the learned subspaces are indeed complementary to each other, (ii) we improve upon recent work on this task, and (iii) incorporating unlabeled data can boost the performance even further.
\end{abstract}

\begin{keywords}
generative model, representation learning, invariance, domain generalization
\end{keywords}

% --SECTION--
\section{Introduction}
% \input{introduction.tex}
Deep neural networks (DNNs) have led to major breakthroughs in a variety of areas like computer vision and natural language processing. Despite their success, recent research shows that DNNs learn the bias present in the training data. As a result, they are not invariant to cues that are irrelevant to the actual task \cite{azulay_why_2018}. This leads to a dramatic decrease in performance when they are tested on data from a different distribution with a different bias.

In domain generalization, the goal is to learn representations from a set of similar distributions, here called domains, that can be transferred to a previously unseen domain at test time. A common motivating application, where domain generalization is crucial, is medical imaging \cite{blanchard_generalizing_2011, muandet_domain_2013}. For instance, in digital histopathology a typical task is the classification of benign and malignant tissue. However, the preparation of a histopathology image includes the staining and scanning of tissue, which can vary greatly between hospitals. Moreover, a sample from a patient could be preserved in different conditions \cite{ciompi_importance_2017}. As a result, each patient's data could be treated as a separate domain \cite{lafarge_domain-adversarial_2017}. Another problem commonly encountered in medical imaging is class label scarcity. Annotating medical images is an extremely time-consuming task that requires expert knowledge. However, obtaining domain labels is surprisingly cheap, since hospitals generally store information about the patient (e.g., age and sex) and the medical equipment (e.g., manufacturer and settings). Therefore, we are interested in extending the domain generalization framework to be able to deal with additional unlabeled data, as we hypothesize that it can help to improve performance. %. %We hypothesize that additional unlabeled data can lead to better domain generalization.

In this paper, we propose to tackle domain generalization via a new deep generative model that we refer to as the Domain Invariant Variational Autoencoder (DIVA). We extend the variational autoencoder (VAE) framework \cite{kingma_auto-encoding_2013,rezende2014stochastic} by introducing independent latent representations for a domain label $d$, a class label $y$, and any residual variations in the input $\rvx$. Such partitioning of the latent space will encourage and guide the model to disentangle these sources of variation. Finally, by virtue of having a generative model, we can naturally handle the semi-supervised scenario, similarly to \citet{kingma_semi-supervised_2014}. We evaluate our model on a version of the MNIST dataset in which each domain corresponds to a specific rotation angle of the digits, as well as on a Malaria Cell Images dataset in which each domain corresponds to a different patient. An implementation of DIVA is available at \url{https://github.com/AMLab-Amsterdam/DIVA}.

% --SECTION--
\section{DIVA: Domain Invariant VAE}
% \input{method.tex}


% \begin{itemize}
%     \item A domain is defined as a joint distribution $p(x,y)$ on $\mathcal{X}\times\mathcal{Y}$, where $\mathcal{X}$ denotes an input space and $\mathcal{Y}$ an output space\\
%     \item The training set consists of samples $\mathcal{S}$ taken from $N$ domains, $\mathcal{S} = \{S^{(d=i)}\}^N_{i=1}$\\
%     \item The $i$th domain $p^{(d=i)}(x,y)$ is represented by $n_i$ samples, $S^{(d=i)} = \{(x^{(d=i)}_k, y^{(d=i)}_k)\}^{n_i}_{k=1}$\\
%     \item We further assume that $p^{(d=i)}(x,y)$ $\neq$ $p^{(d=j)}(x,y)$, therefore, the samples in $\mathcal{S}$ are non-i.i.d.\\
%     \item During test time we are presented with samples $S^{(d=N+1)}$ from a previously unseen domain $p^{(d=N+1)}(x,y)$
% \end{itemize}




\label{sec:method}
Assuming a perfectly disentangled latent space \cite{higgins_towards_2018}, we hypothesize that there exists a latent subspace that is invariant to changes in $d$, i.e., it is domain invariant. We propose a generative model with three independent sources of variation: $\rvz_d$, which is domain specific; $\rvz_y$, which is class specific; and finally $\rvz_x$, which captures any residual variations left in $\rvx$. While $\rvz_x$ has an independent Gaussian prior $p(\rvz_x)$, $\rvz_d$ and $\rvz_y$ have conditional priors $p_{\theta_d}(\rvz_d|d)$ and $p_{\theta_y}(\rvz_y|y)$ with learnable parameters $\theta_d, \theta_y$. This will encourage information about the domain $d$ and label $y$ to be encoded into $\rvz_d$ and $\rvz_y$, respectively. Furthermore, as $\rvz_d$ and $\rvz_y$ are marginally independent by construction, we argue that the model will learn representations $\rvz_y$ that are invariant with respect to the domain $d$. All three of these latent variables are then used by a single decoder $p_\theta(\rvx|\rvz_d, \rvz_x, \rvz_y)$ for the reconstruction of $\rvx$.

Since we are interested in using neural networks to parameterize $p_\theta(\rvx|\rvz_d, \rvz_x, \rvz_y)$, exact inference will be intractable. For this reason, we perform amortized variational inference with an inference network \cite{kingma_auto-encoding_2013,rezende2014stochastic}, i.e., we employ a VAE-type framework. We introduce three separate encoders $q_{\phi_d}(\rvz_d|\rvx)$, $q_{\phi_x}(\rvz_x|\rvx)$ and $q_{\phi_y}(\rvz_y|\rvx)$ that serve as variational posteriors over the latent variables. Notice that we do not share their parameters, as we empirically found that sharing parameters leads to decreased generalization performance. For the prior and variational posterior distributions over the latent variables $\rvz_x, \rvz_d, \rvz_y$, we assume fully factorized Gaussians with parameters given as a function of their input. We coin the term Domain Invariant VAE (DIVA) for our overall model, which is depicted in Figure~\ref{fig:graph_model} in the Appendix.

Given a specific dataset, all of the aforementioned parameters can be optimized by maximizing the following variational lower bound per input $\rvx$:
% \begin{align}
% \log p(x|y,d) \geq \mathcal{L}_s(d, x, y)
% \end{align}
\begin{align}
\mathcal{L}_s(d, \rvx, y) &= \mathbb{E}_{q_{\phi_d}(\rvz_d|\rvx)q_{\phi_x}(\rvz_x|\rvx)q_{\phi_y}(\rvz_y|\rvx)} \left[ \log p_\theta(\rvx|\rvz_d, \rvz_x, \rvz_y) \right] -\beta KL\left(q_{\phi_d}(\rvz_d|\rvx)||p_{\theta_d}(\rvz_d|d)\right) \nonumber \\
&- \beta KL\left(q_{\phi_x}(\rvz_x|\rvx)||p(\rvz_x)\right) 
-\beta KL\left(q_{\phi_y}(\rvz_y|\rvx)||p_{\theta_y}(\rvz_y|y)\right).
\end{align}
Notice that we have introduced a weighting term, $\beta$. This is motivated by the $\beta$-VAE~\cite{higgins_-vae:_2017} and serves as a constraint that controls the capacity of the latent spaces of DIVA. Larger values of $\beta$ limit the capacity of each $\rvz$, and in the ideal case each dimension of $\rvz$ captures one of the conditionally independent factors in $\rvx$.

% \subsection{$\beta$-VAE}
% The $\beta$-VAE \citep{higgins_-vae:_2017} is a modification of the VAE where a constraint is introduced that controls the capacity of the latent space of a VAE. The objective of the $\beta$-VAE is given by
% \begin{align}
%     \mathcal{L}(\theta, \phi, \beta, x, z) = \mathbb{E}_{q_\phi(z|x)} \lbrack \log p_\theta(x|z) \rbrack - \beta KL(q_\phi(z|x)||p(z)) .
% \end{align}
% Changing the value of $\beta$ results in different learned representations $z$. Larger values of $\beta$ limits the capacity of $z$ and in the ideal case each dimension of $z$ captures one of the conditionally independent factors in data. The $\beta$-VAE frameworks offers a trade-off between the information preservation, i.e., how well one can reconstruct $x$ from $z$, and the capacity, i.e., how well $z$ compresses information about $x$ \citep{chen_isolating_2018}.

% Assuming a perfectly disentangled latent space, we hypothesize that there exists a subspace that is invariant to changes in the $d$, i.e., that is domain invariant. But instead of separating the latent dimensions after training we want to guide the disentanglement towards a domain invariant subspace during training.

To further encourage separation of $\rvz_d$ and $\rvz_y$ into domain and class specific information respectively, we add two auxiliary objectives. During training $\rvz_d$ is used to predict the domain $d$ and $\rvz_y$ is used to predict the class $y$ for a given input $\rvx$:
\begin{align}
\mathcal{F}_{\text{DIVA}}(d, \rvx, y) := \mathcal{L}_s(d, \rvx, y) 
+ \alpha_d\mathbb{E}_{q_{\phi_d}(\rvz_d|\rvx)}\left[\log q_{\omega_d}(d|\rvz_d)\right]
+ \alpha_y\mathbb{E}_{q_{\phi_y}(\rvz_y|\rvx)}\left[\log q_{\omega_y}(y|\rvz_y)\right], \label{eq:aux}
\end{align}
where $\alpha_d$, $\alpha_y$ are weighting terms for each of these auxiliary objectives. Since our main goal is a domain invariant classifier, during inference we only use the encoder $q_{\phi_y}(\rvz_y|\rvx)$ and the auxiliary classifier $q_{\omega_y}(y|\rvz_y)$. To predict the class $y$ for a new input $\rvx$, we use the mean of $\rvz_y$. Consequently, we regard the variational lower bound $\mathcal{L}_s(d, \rvx, y)$ as a regularizer. Therefore, evaluating the marginal likelihood $p(\rvx)$ of DIVA is outside the scope of this paper.

% \subsection{Guided disentanglement}
\citet{locatello_challenging_2018} and \citet{dai_diagnosing_2019} claim that learning a disentangled representation, i.e., $q_{\phi}(\rvz) = \prod_i q_\phi(z_i)$, in an unsupervised fashion is impossible for arbitrary generative models. Inductive biases, e.g., some form of supervision or constraints on the latent space, are necessary to find a specific set of solutions that matches the true generative model. Consequently, DIVA uses domain labels $d$ and class labels $y$ in addition to input data $\rvx$ during training. Recent work by \citet{khemakhem_variational_2019} shows that conditional priors, like $p_{\theta_d}(\rvz_d|d)$ and $p_{\theta_y}(\rvz_y|y)$ in DIVA, lead to identifiability guarantees in VAEs. Furthermore, we enforce the factorization of the marginal distribution of $\rvz$ in the following form: $q_{\phi}(\rvz) = q_{\phi_d}(\rvz_d)q_{\phi_x}(\rvz_x)q_{\phi_y}(\rvz_y)$, which circumvents the impossibility result described in \citet{locatello_challenging_2018}. We argue that the strong inductive biases in DIVA make it possible to learn disentangled representations that match the ground truth factors of interest, namely, the domain factors $\rvz_d$ and class factors $\rvz_y$. To highlight the importance of a partitioned latent space, we compare DIVA to a VAE with a single latent space; the results of this comparison can be found in the Appendix.

\subsection{Semi-supervised DIVA}
\label{sec:semisuper_diva}
In \citet{kingma_semi-supervised_2014}, an extension to the VAE framework was introduced that allows the use of labeled as well as unlabeled data during training. While \citet{kingma_semi-supervised_2014} introduced a two-step procedure, \citet{louizos_variational_2015} presented a way of optimizing the decoder of the VAE and the auxiliary classifier jointly. We use the latter approach to learn from supervised data $\{(d_n, \rvx_n, y_n)\}$ as well as from unsupervised data $\{(d_m,\rvx_m)\}$. Analogously to \citet{louizos_variational_2015}, we use $q_{\omega_y}(y|\rvz_y)$ to impute $y$:
% \begin{align}
% \log p(x|y,d) \geq \mathcal{L}_u(d, x)
% \end{align}
\begin{align}
\mathcal{L}_u(d, \rvx) &= \mathbb{E}_{q_{\phi_d}(\rvz_d|\rvx)q_{\phi_x}(\rvz_x|\rvx)q_{\phi_y}(\rvz_y|\rvx)} \lbrack \log p_\theta(\rvx|\rvz_d, \rvz_x, \rvz_y) \rbrack 
-\beta KL(q_{\phi_d}(\rvz_d|\rvx)||p_{\theta_d}(\rvz_d|d)) \nonumber\\ &-\beta KL(q_{\phi_x}(\rvz_x|\rvx)||p(\rvz_x))
+\beta \mathbb{E}_{q_{\phi_y}(\rvz_y|\rvx)q_{\omega_y}(y|\rvz_y)}\lbrack \log p_{\theta_y}(\rvz_y|y) - \log q_{\phi_y}(\rvz_y|\rvx) \rbrack \nonumber\\ &+ \mathbb{E}_{q_{\phi_y}(\rvz_y|\rvx)q_{\omega_y}(y|\rvz_y)}\lbrack \log p(y) - \log q_{\omega_y}(y|\rvz_y) \rbrack,
\end{align}
where we use Monte Carlo sampling with the reparametrization trick~\cite{kingma_auto-encoding_2013} for the continuous latents $\rvz_d, \rvz_x, \rvz_y$ and explicitly marginalize over the discrete variable $y$.
% In order to prevent noisy gradients in the early stages of training, we replace the Monte Carlo sample $y \sim q_{\omega_y}(y|z_y)$ by marginalization over $y$ as proposed in \citep{kingma_semi-supervised_2014} that yields:
% \begin{align}
% &\mathbb{E}_{q_{\phi_y}(z_y|x)q_{\omega_y}(y|z_y)}\lbrack \log p_{\theta_y}(z_y|y) - \log q_{\phi_y}(z_y|x) \rbrack 
% \nonumber\\
% &= \mathbb{E}_{q_{\phi_y}(z_y|x)}\Big\lbrack \sum_y q_{\omega_y}(y|z_y) (\log p_{\theta_y}(z_y|y) - \log q_{\phi_y}(z_y|x)) \Big\rbrack , \\
% & \mathbb{E}_{q_{\phi_y}(z_y|x)q_{\omega_y}(y|z_y)}\lbrack \log p(y) - \log q_{\omega_y}(y|z_y) \rbrack \nonumber\\
% &= \mathbb{E}_{q_{\phi_y}(z_y|x)}\Big\lbrack \sum_y q_{\omega_y}(y|z_y) (\log p(y) - \log q_{\omega_y}(y|z_y)) \Big\rbrack .
% \end{align}
The final objective combines the supervised and unsupervised variational lower bounds as well as the two auxiliary losses. Assuming $N$ labeled and $M$ unlabeled examples, we obtain the following objective:
\begin{align}
\mathcal{F}_{\text{SS-DIVA}} & = \sum_{n=1}^N \mathcal{F}_{\text{DIVA}}(d_n, \rvx_n, y_n) + \sum_{m=1}^M \Bigl( \mathcal{L}_u(d_m, \rvx_m)
+ \alpha_d \mathbb{E}_{q_{\phi_d}(\rvz_d|\rvx_m)} \lbrack \log q_{\omega_d}(d_m|\rvz_d) \rbrack \Bigr).
\end{align}
% \subsection{DIVA: Distributions}
% According to the stochastic dependencies used in DIVA, we propose to use the following distributions:

% \begin{minipage}{.5\linewidth}
% \begin{align}
% p_\theta(x_{i}|z_d, z_x, z_y) &= Cat(x_{i}|\pi_{i}(z_d, z_x, z_y)) \\
% p_{\theta_d}(z_d|d) &= \mathcal{N}(z_d|\mu_d(d), \sigma_d(d))\\
% p(z_x) &= \mathcal{N}(z_x|0, I)\\
% p_{\theta_d}(z_y|y) &= \mathcal{N}(z_y|\mu_y(y), \sigma_y(y))\\
% p(d) &= Cat(d|\pi_{d})\\
% p(y) &=  Cat(y|\pi_{y})
% \end{align}
% \end{minipage}
% \begin{minipage}{.5\linewidth}
% \begin{align}
% q_{\phi_d}(z_d|x) &= \mathcal{N}(z_d|\mu_{q,d}(x), \sigma_{q,d}(x))\\
% q_{\phi_x}(z_x|x) &= \mathcal{N}(z_x|\mu_{q,x}(x), \sigma_{q,x}(x))\\
% q_{\phi_y}(z_y|x) &= \mathcal{N}(z_y|\mu_{q,y}(x), \sigma_{q,y}(x))\\
% q_{\omega_d}(d_j|z_d) &= Cat(d|\pi_{d,j}(z_d))\\
% q_{\omega_y}(y_k|z_y) &= Cat(y|\pi_{y,k}(z_y))
% \end{align}
% \end{minipage}

%\subsection{During test time}
% Since our main goal is a domain invariant classifier, during the test time we are left with the encoder $q_{\phi_y}(x|z_y)$ and the auxiliary classifier $q_{\omega_y}(y|z_y)$. For predicting the class $y$ of a new input $x$ we only use the mean of $z_y$.% The resulting model has a simple architecture with 7 convolutions and 2 fully connected layers.

% --SECTION--
\section{Related Work}
% \input{related_work.tex}
The majority of proposed deep learning methods for domain generalization fall into one of two categories: 1) Learning a single domain invariant representation, e.g., using adversarial methods \cite{carlucci_agnostic_2018, ghifary_domain_2015, li_domain_2018, li_learning_2017, motiian_unified_2017, shankar_generalizing_2018, wang_learning_2019}. While DIVA falls under this category, there is a key difference: we do not explicitly regularize $\rvz_y$ using $d$. Instead, we learn complementary representations $\rvz_d$, $\rvz_x$ and $\rvz_y$ utilizing a generative architecture. 2) Ensembling models, each trained on an individual domain from the training set \cite{ding_deep_2018, mancini_best_2018}. The size of models in this category scales linearly with the number of training domains. This leads to slow inference if the number of training domains is large. In contrast, the size of DIVA is independent of the number of training domains. In addition, at inference time we only use the mean of the encoder $q_{\phi_y}(\rvz_y|\rvx)$ and the auxiliary classifier $q_{\omega_y}(y|\rvz_y)$.

Concurrently with DIVA, \citet{cai_learning_2019} developed a framework to learn latent disentangled semantic representations (DSR) for domain adaptation. DSR assumes that the data generation process is exclusively controlled by the domain $d$ and class $y$. As a result, DSR lacks a third latent space $\rvz_x$. We designed DIVA assuming that not all variations in $\rvx$ can be explained by the domain $d$ and the class $y$. Therefore, we introduce $\rvz_x$ in order to capture these residual variations. Furthermore, while DSR uses gradient reversal layers, we directly parameterize the ground truth generative model. As a result, the priors in DIVA are conditional, which recent research has shown to be necessary for guaranteed disentanglement \cite{khemakhem_variational_2019}. More related work is published under the name of multiple source domain adaptation \cite{zhao_adversarial_2018}.

An area closely related to domain generalization is statistical parity in fairness. The goal of fair classification is to learn a meaningful representation that at the same time cannot be used to associate a data sample to a certain group \cite{pmlr-v28-zemel13}. The major difference to domain generalization is the intention behind that goal, i.e., protecting groups of individuals versus being robust to variations in the input. Consequently, DIVA is closely related to the fair VAE \cite{louizos_variational_2015}. In contrast to the fair VAE, which uses a hierarchical latent space, DIVA uses a partitioned latent space. Moreover, the fair VAE requires the domain label during inference, while DIVA alleviates this issue by learning the classifier without $d$. Similar to DIVA, there is an increasing number of methods showing the benefits of using latent subspaces in generative models \cite{siddharth_learning_2017, klys_learning_2018, jacobsen_excessive_2018, bouchacourt_multi-level_nodate, atanov_semi-conditional_2019, antoran_disentangling_2019}.

We derived DIVA by following the VAE framework, where the generative process is the starting point. A conditional version of the variational information bottleneck (CVIB) was proposed by \citet{moyer_invariant_2018}; it likewise leads to an objective that includes a reconstruction loss. However, CVIB suffers from the same limitation as the fair VAE, namely that the domain must be known during inference; hence, we excluded it from our experiments.

% --SECTION--
\section{Experiments}
% \input{experiments.tex}
% Qualitative and quantitative evaluation of the model rotated MNIST is used. First we show that the latent space of the model is indeed disentangled by visualizing. Conditional generation. Then we compare our method to other and show promising results. We evaluate how our model improves when unlabeled data is added. We compare two scenarios: . In both scenarios we show performance gains. Lastly, we the effect of varying of $beta_d$, $beta_x$ and $beta_y$ on disentanglement. Architecture, hyperparameters and training schedule are the same throughout the paper if not stated differently, see appendix for details.

% We evaluate the performance of DIVA on two datasets: Rotated MNIST \citep{ghifary_domain_2015} and Malaria Cell Images \citep{rajaraman_pre-trained_2018}. In both cases we first investigate if DIVA is able to successfully learn disentangled representations. Furthermore, we compare DIVA to other methods in a supervised and semi-supervised setting. While for the rotated MNIST dataset DIVA's graphical model is matching the ground truth generative model, the Malaria Cell Images dataset poses a more challenging and realistic scenario, where the ground truth generative model is unknown.

\subsection{Rotated MNIST}
\label{sec:rot_mnist}
The construction of the Rotated MNIST dataset follows \citet{ghifary_domain_2015}. We sample 100 images from each of the 10 classes from the original MNIST training dataset. This set of images is denoted $\mathcal{M}_{0\degree}$. To create five additional domains, the images in $\mathcal{M}_{0\degree}$ are rotated by 15, 30, 45, 60 and 75 degrees. In order to evaluate their domain generalization abilities, models are trained on five domains and tested on the remaining sixth domain, e.g., train on $\mathcal{M}_{0\degree}$, $\mathcal{M}_{15\degree}$, $\mathcal{M}_{30\degree}$, $\mathcal{M}_{45\degree}$ and $\mathcal{M}_{60\degree}$, test on $\mathcal{M}_{75\degree}$.
The evaluation metric is the classification accuracy on the test domain. All experiments are repeated 10 times. Detailed information about hyperparameters, architecture and training schedule can be found in the Appendix.




% Figure \ref{fig:train_zy} shows 5000 embeddings $z_y$ encoded by $q_{\phi_y}(z_y|x)$. In Figure \ref{fig:train_zy} (left) the colors indicate the 10 different classes of the MNIST dataset. We observe 10 well separated clusters, each corresponding to one of the 10 classes. In stark contrast, Figure \ref{fig:train_zy} (right), where the colors indicate the five different training domains, shows no such clustering. It appears that $z_y$ is indeed capturing all necessary information for predicting the class $y$ while containing very little information about the domain $d$.

% \begin{figure}[h]
% \begin{minipage}{.5\textwidth}
% \centering
% \includegraphics[scale=0.37]{mnist_plots/train_zy_class.png}
% \end{minipage}
% \begin{minipage}{.5\textwidth}
% \centering
% \includegraphics[scale=0.37]{mnist_plots/train_zy_domain.png}
% \end{minipage}
% \caption{Left: 5000 two-dimensional embeddings $z_y$ encoded by $q_{\phi_y}(z_y|x)$. The color of each point indicates the associated class. Both dimensions of $z_y$ are used to encode the label. Right: Plot of the same embeddings as seen on the left. This time the color indicates the associated domain. No apparent clustering is visible.}
% \label{fig:train_zy}
% \end{figure}

% In Figure \ref{fig:train_zd} we visualize the two-dimensional latent space for $z_d$. Each of the 5000 training images is encoded by $q_{\phi_d}(z_d|x)$. Figure \ref{fig:train_zd} (left) shows very little clustering according to the class label $y$. However, Figure \ref{fig:train_zd} (right), where each color represents a different domain, shows strong clustering. Each cluster corresponds to one of the five domains.



% \begin{figure}[h]
% \begin{minipage}{.5\textwidth}
% \centering
% \includegraphics[scale=0.37]{mnist_plots/train_zd_class.png}
% \end{minipage}
% \begin{minipage}{.5\textwidth}
% \centering
% \includegraphics[scale=0.37]{mnist_plots/train_zd_domain.png}
% \end{minipage}
% \caption{Left: 5000 two-dimensional embeddings $z_d$ encoded by $q_{\phi_d}(z_d|x)$. The color of each point indicates the associated class. No apparent clustering is visible. Right: Plot of the same embeddings as seen on the left. This time the color indicates the associated domain. The plot shows five distinct clusters, where each cluster corresponds to a single domain.}
% \label{fig:train_zd}
% \end{figure}

% In contrast to $z_y$ and $z_d$, $z_x$ is only used for reconstructing $x$ and has an independent Gaussian prior. In Figure \ref{fig:train_zx} (left) we can see that there is a certain amount of clustering. We find that narrow numbers with thin lines, e.g., '1' and '7', cluster in the bottom left half of the plot. Whereas round numbers with thicker lines, e.g.,  '0' and '6' appear to cluster in the top right half. We conclude that $z_x$ models the remaining variations that are not captured by $z_d$ and $z_y$. In Figure \ref{fig:train_zx} (right) we do not notice any apparent clustering.

% \begin{figure}[h]
% \begin{minipage}{.5\textwidth}
% \centering
% \includegraphics[scale=0.385]{mnist_plots/train_zx_class.png}
% \end{minipage}
% \begin{minipage}{.5\textwidth}
% \centering
% \includegraphics[scale=0.385]{mnist_plots/train_zx_domain.png}
% \end{minipage}
% \caption{Left: 5000 two-dimensional embeddings $z_x$ encoded by $q_{\phi_x}(z_x|x)$. The color of each point indicates the associated class. We observe a high correlation between the line thickness of each MNIST digit and $z_x$[0]. Whereas, $z_x[1]$ is highly correlated with the width of each MNIST image. For that reason we observe a clustering of embeddings with class '1' at the lower half of the plot. Right: Plot of the same embeddings as seen on the left. This time the color indicates the associated domain. No apparent clustering is visible.}
% \label{fig:train_zx}
% \end{figure}
First of all, we visualize the three latent spaces $\rvz_d$, $\rvz_x$ and $\rvz_y$, to see if DIVA is able to successfully disentangle them. In addition, we want to verify whether DIVA utilizes $\rvz_x$ in a meaningful way, since it is not directly connected to any downstream task. We restrict the size of each latent space $\rvz_d$, $\rvz_x$ and $\rvz_y$ to 2 dimensions. Therefore, we can plot the latent subspaces directly without applying dimensionality reduction, see Figure \ref{fig:embeddings_mnist} in the Appendix. Here, we trained
DIVA on 5000 images from five domains: $\mathcal{M}_{0\degree}$, $\mathcal{M}_{15\degree}$, $\mathcal{M}_{30\degree}$, $\mathcal{M}_{45\degree}$ and $\mathcal{M}_{60\degree}$.
From these initial qualitative results we conclude that DIVA disentangles the information contained in $\rvx$ as intended: $\rvz_d$ contains only information about $d$ and $\rvz_y$ only information about $y$. In the case of the Rotated MNIST dataset, $\rvz_x$ captures any residual variation that is not explained by the domain $d$ or the class $y$. In addition, we are able to generate conditional reconstructions as well as entirely new samples with DIVA. We provide these in the Appendix. %Furthermore, we are able to generate entirely new samples using DIVA, see Appendix.

\label{sec:comparison}
We compare DIVA against the well-known domain adversarial neural networks (DA) \cite{ganin_domain-adversarial_2015} as well as three recently proposed methods: LG \cite{shankar_generalizing_2018}, HEX \cite{wang_learning_2019} and ADV \cite{wang_learning_2019}. For the first half of Table \ref{tab:comparison} (until the vertical line) we only use labeled data. The first column indicates the rotation angle of the test domain. We report test accuracy on $y$ for all methods. For DIVA we report the mean and standard error for 10 repetitions. DIVA achieves the highest accuracy on every test domain and the highest average test accuracy among all compared methods.

\setlength{\tabcolsep}{0.2cm}
\begin{table}[t]
\caption{Comparison with other state-of-the-art domain generalization methods. Methods in the first half of the table (until the vertical line) use only labeled data. The second half of the table shows results of DIVA when trained semi-supervised (+ X times the amount of unlabeled data). We report the average and standard error of the classification accuracy.}
\small
\begin{tabular}{cccccc|cccc}
Test &DA &LG &HEX &ADV & DIVA & DIVA(+1) & DIVA(+3) & DIVA(+5) & DIVA(+9)\\
\hline
$\mathcal{M}_{0\degree}$  &86.7 &89.7 &90.1 &89.9 &\textbf{93.5} $\pm$ 0.3 &93.8 $\pm$ 0.4 & 93.9 $\pm$ 0.5 & 93.2 $\pm$ 0.5 & 93.0 $\pm$ 0.4\\
$\mathcal{M}_{15\degree}$ &98.0 &97.8 &98.9 &98.6 &\textbf{99.3} $\pm$ 0.1 &99.4 $\pm$ 0.1 & 99.5 $\pm$ 0.1 & 99.5 $\pm$ 0.1 & 99.6 $\pm$ 0.1\\
$\mathcal{M}_{30\degree}$ &97.8 &98.0 &98.9 &98.8 &\textbf{99.1} $\pm$ 0.1 &99.3 $\pm$ 0.1 & 99.3 $\pm$ 0.1 & 99.3 $\pm$ 0.1 & 99.3 $\pm$ 0.1\\
$\mathcal{M}_{45\degree}$ &97.4 &97.1 &98.8 &98.7 &\textbf{99.2} $\pm$ 0.1 &99.0 $\pm$ 0.2 & 99.2 $\pm$ 0.1 & 99.3 $\pm$ 0.1 & 99.3 $\pm$ 0.1\\
$\mathcal{M}_{60\degree}$ &96.9 &96.6 &98.3 &98.6 &\textbf{99.3} $\pm$ 0.1 &99.4 $\pm$ 0.1 & 99.4 $\pm$ 0.1 & 99.4 $\pm$ 0.1 & 99.2 $\pm$ 0.2\\
$\mathcal{M}_{75\degree}$ &89.1 &92.1 &90.0 &90.4 &\textbf{93.0} $\pm$ 0.4 &93.8 $\pm$ 0.4 & 93.8 $\pm$ 0.2 & 93.5 $\pm$ 0.4 & 93.2 $\pm$ 0.3\\
\hline
% NOTE(review): the ADV average below (95.2) does not equal the mean of its six per-domain entries (575.0/6 = 95.8), and LG's (95.3) differs slightly from its row mean (95.2); confirm against the cited papers.
Avg &94.3 &95.3 &95.8 & 95.2 &\textbf{97.2} $\pm$ 1.3 &97.5 $\pm$ 1.1 & 97.5 $\pm$ 1.2 & 97.4 $\pm$ 1.3 & 97.3 $\pm$ 1.3\\
\end{tabular}
\label{tab:comparison}
\end{table}

The second half of Table \ref{tab:comparison} showcases the ability of DIVA to use unlabeled data. For this experiment we add the same amount (+1) of unlabeled data as well as three (+3), five (+5) and nine (+9) times the amount of unlabeled data to our training set. We first add the unlabeled data to $\mathcal{M}_{0\degree}$ and then create the data for the other domains from it by rotation, as described in Section \ref{sec:rot_mnist}. In Table \ref{tab:comparison} we can clearly see a performance increase when unlabeled data is added to the training set. When the amount of unlabeled data is much larger than the amount of labeled data, balancing the loss terms becomes increasingly challenging, which can lead to a declining performance of DIVA, as seen in the last two columns of Table \ref{tab:comparison}.

\label{sec:mnist_semi}
In the experiment described above each training domain consists of labeled and unlabeled examples. Now we investigate a more challenging scenario: We add an additional domain to our training set that consists of only unlabeled examples. Coming back to our introductory example of medical imaging, here we would add unlabeled data from a new patient or new hospital to the training set.

In the following, we are looking at two different experimental setups, in both cases $\mathcal{M}_{75\degree}$ is the test domain:
For the first experiment we choose the domains $\mathcal{M}_{0\degree}$, $\mathcal{M}_{15\degree}$, $\mathcal{M}_{45\degree}$ and $\mathcal{M}_{60\degree}$ to be part of the labeled training set. In addition, unlabeled data from $\mathcal{M}_{30\degree}$ is used. We find that even when the additional domain is dissimilar to the test domain, DIVA is able to improve slightly, see Table \ref{tab:additional_domains_main}. For the second experiment we choose the domains $\mathcal{M}_{0\degree}$, $\mathcal{M}_{15\degree}$, $\mathcal{M}_{30\degree}$ and $\mathcal{M}_{45\degree}$ to be part of the labeled training set. In addition, unlabeled data from $\mathcal{M}_{60\degree}$ is used. When comparing with the results in Table \ref{tab:comparison}, we notice a drop in accuracy of about 20 percentage points for DIVA trained with only labeled data. However, when trained with unlabeled data from $\mathcal{M}_{60\degree}$, we see an improvement of about 7 percentage points, see Table \ref{tab:additional_domains_main}. The comparison shows that DIVA can successfully learn from samples of a domain without any labels.
\begin{table}[h]
\small
\caption{Comparison of DIVA trained supervised to DIVA trained semi-supervised with additional unlabeled data from $\mathcal{M}_{30\degree}$ and $\mathcal{M}_{60\degree}$. We report the average and standard error of the classification accuracy on $\mathcal{M}_{75\degree}$.}
\begin{center}
% \vspace{-0.5cm}
\begin{tabular}{ccc}
Unsupervised domain & DIVA supervised & DIVA semi-supervised\\
\hline
$\mathcal{M}_{30\degree}$ & 93.1 $\pm$ 0.5   &  93.3 $\pm$ 0.4 \\
$\mathcal{M}_{60\degree}$ & 73.8 $\pm$ 0.8 & 80.6 $\pm$ 1.1
\end{tabular}
\end{center}
\label{tab:additional_domains_main}
\end{table}
% \begin{table}[h]
% \caption{We compare DIVA trained with only labeled data to DIVA trained with additional unlabeled data from $\mathcal{M}_{30\degree}$. We report the average and standard error of the classification accuracy on $\mathcal{M}_{75\degree}$.}
% \begin{center}
% \begin{tabular}{ccc}
% Test &Only labeled data &Additional unlabeled\\
% \hline
% $\mathcal{M}_{75\degree}$ &93.1 $\pm$ 0.2  & 93.3 $\pm$ 0.1\\
% \end{tabular}
% \end{center}
% \label{tab:dissimilar}
% \end{table}


% \vspace{-0.5cm}
\subsection{Malaria Cell Images}
The majority of medical imaging datasets consist of images from a multitude of patients. In a domain generalization setting each patient is viewed as an individual domain. While we focus on \emph{patients as domains} in this paper, this type of reasoning can be extended to, e.g., \emph{hospitals as domains}. 
We, among others \cite{rajaraman_pre-trained_2018, lafarge_domain-adversarial_2017}, argue that machine learning algorithms trained with medical imaging datasets should be evaluated on a subset of hold-out patients. This presents a more realistic scenario since the algorithm is tested on images from a previously unseen domain.
In the following, we use a Malaria Cell Images dataset \cite{rajaraman_pre-trained_2018} as an example of a dataset consisting of samples from multiple patients. The images in this dataset were collected and photographed at Chittagong Medical College Hospital, Bangladesh. It consists of 27558 single red blood cell images taken from 150 infected and 50 healthy patients. The images were manually annotated by a human expert. A cell has the label $y=1$ if it shows the parasite and the label $y=0$ if not. To facilitate the counting of parasitized and uninfected cells, the cells were stained using Giemsa stain, which turns the parasites inside the cell pink. In addition, the staining process leads to a variety of colors of the cell itself. While the color of the cell is relatively constant for a single patient, it can vary greatly between patients, see the first row in Figure \ref{fig:recon_malaria}. This variability in the appearance of the cells can easily be ignored by a human observer; however, machine learning models can fail to generalize across patients. In our experiments we will use the patient ID as the domain label $d$. We argue that for this specific dataset the patient ID is a good proxy of appearance variability. In addition, there is no extra cost for obtaining the patient ID for each cell.

For our experiments, we use a subset of the Malaria Cell Images dataset that consists of the 10 patients with the highest number of cells. The number of cells per patient varies between 400 and 700 and there are 5922 cell images in total. The choice of this subset is motivated by the similar number of cells as well as the similar marginal label distributions per patient, the latter being a necessary condition for successful domain generalization \cite{zhao_learning_2019}. Furthermore, we rescale all images to 64 $\times$ 64 pixels. To artificially expand the size of the training dataset we use data augmentation in the form of vertical flips, horizontal flips and random rotations.

% : If the marginal label distributions of the train domains and test domain are significantly different among each other, learning an invariant distribution with a low training classification error will only increase the test error

% \subsubsection{Qualitative disentanglement}
We investigate the three latent subspaces $\rvz_d$, $\rvz_x$ and $\rvz_y$ to see if DIVA is able to successfully disentangle them. In addition, we want to see if DIVA utilizes $\rvz_x$ in a meaningful way, since it is not directly connected to any downstream task. Figure \ref{fig:recon_malaria} shows the reconstructions of $\rvx$ using all three latent subspaces as well as reconstructions of $\rvx$ using only a single latent subspace at a time. First, we find that DIVA is able to reconstruct the original cell images using all three subspaces (Figure \ref{fig:recon_malaria}, second row). Second, we find that the three latent subspaces are indeed disentangled: $\rvz_d$ contains the color of the cell (Figure \ref{fig:recon_malaria}, third row), $\rvz_x$ the shape of the cell (Figure \ref{fig:recon_malaria}, fourth row) and $\rvz_y$ the location of the parasite (Figure \ref{fig:recon_malaria}, fifth row). The holes in the reconstructions using only $\rvz_x$ indicate that there is no probability mass in $\rvz_d$ and $\rvz_y$ at 0, similar to Figure \ref{fig:embeddings_mnist}. From the reconstructions in Figure \ref{fig:recon_malaria} we conclude that DIVA is able to learn disentangled representations that match the ground truth factors of interest, here, the appearance of the cell and the presence of the parasite. In addition to these qualitative results, we show that a classifier for $y$ trained on $\rvz_d$ or $\rvz_x$ performs worse than a classifier that would always predict the majority class; the results can be found in the Appendix.
\begin{figure}[h]
\begin{minipage}{.77\textwidth}
 \includegraphics[scale=0.36]{reconstructions_10_in_one_plot_beta_1_64_50000.png} 
% \hspace*{0.8em} 0 \hspace{1.15em} 1 \hspace{1.2em} 0 \hspace{1.15em} 1 \hspace{1.2em} 0 \hspace{1.15em} 1 \hspace{1.2em} 1 \hspace{1.2em} 0 \hspace{1.15em} 1 \hspace{1.2em} 1 \hspace{1.2em} 1 \hspace{1.2em} 1 \hspace{1.2em} 1 \hspace{1.2em} 0
 \end{minipage}
 \begin{minipage}{.22\textwidth}
\vspace{0.85cm}$\rvx$ \vspace{0.35cm}\\
$\rvx \sim p_\theta(\rvx|\rvz_d, \rvz_x, \rvz_y)$\vspace{0.35cm}\\
$\rvx \sim p_\theta(\rvx|\rvz_d, 0, 0)$\vspace{0.35cm}\\
$\rvx \sim p_\theta(\rvx|0, \rvz_x, 0)$\vspace{0.35cm}\\
$\rvx \sim p_\theta(\rvx|0, 0, \rvz_y)$\vspace{0.35cm}\\

 \end{minipage}
\caption{Reconstructions of $\rvx$ using all three latent subspaces as well as reconstructions of $\rvx$ using only a single latent subspace at a time.}
\label{fig:recon_malaria}
\end{figure}

% \subsubsection{Semi-supervised case}
To further evaluate domain generalization abilities, models are trained on nine domains (patient IDs) and tested on the remaining 10th domain. We choose ROC AUC on the held-out test domain as the evaluation metric, since the two classes are highly imbalanced. All experiments are repeated five times. We compare DIVA with a ResNet-like \cite{he_deep_2015} baseline and DA. During inference all three models have the same architecture: seven ResNet blocks followed by two linear layers. Detailed information about hyperparameters, architecture and training schedule can be found in the Appendix. We find that the results vary considerably across test domains. For five of the ten test domains, DIVA is able to significantly improve upon the baseline model and DA. However, averaged over all domains none of the three methods performs significantly better than the others, see Tables \ref{tab:malaria_super_1} and \ref{tab:malaria_super_2} in the Appendix.

\label{sec:malaria_semi}
As described in Section \ref{sec:semisuper_diva}, we are interested in learning from domains with no class labels, since such an approach can drastically lower the amount of labeled data needed to learn a domain invariant representation, i.e., a model that generalizes well across patients. For the semi-supervised experiments we choose domain C116P77 to be the test domain, since its cells show a unique dark pink stain. Furthermore, unlabeled data from domain C59P20 is used, since it is visually the closest to domain C116P77, see Appendix. As before, the evaluation metric on the held-out test domain is ROC AUC. In Table \ref{tab:malaria_semi} we compare the baseline model, DA and DIVA trained with labeled data from domain C59P20, unlabeled data from domain C59P20 and no data from domain C59P20. We argue that the improvement of DIVA over DA arises from the way the additional unlabeled data is utilized. In the case of DA, the unlabeled data ($d$, $\rvx$) is only used to train the domain classifier and the feature extractor in an adversarial manner. In Section \ref{sec:semisuper_diva} we show that due to DIVA's generative nature $q_{\phi_y}(\rvz_y|\rvx)$, $p_{\theta_y}(\rvz_y|y)$ can be updated using $q_{\omega_y}(y|\rvz_y)$ to marginalize over $y$ for an unlabeled sample $\rvx$. In addition, the unlabeled data ($d$, $\rvx$) is used to update $q_{\phi_d}(\rvz_d|\rvx)$, $p_{\theta_d}(\rvz_d|d)$, $q_{\omega_d}(d|\rvz_d)$, $q_{\phi_x}(\rvz_x|\rvx)$ and $p_{\theta}(\rvx|\rvz_d, \rvz_x, \rvz_y)$ in the same way as in the supervised case.

% \vspace{-1cm}
\begin{table}[h]
\caption{Results of the semi-supervised experiments for domain C116P77. Comparison of baseline method, DA and DIVA trained with labeled data from domain C59P20, unlabeled data from domain C59P20 and no data from domain C59P20. We report the average and standard error of ROC AUC.}
\centering
\small
\begin{tabular}{cccc}
Training data & Baseline & DA & DIVA\\
\hline
Labeled data from C59P20 & 90.6 $\pm$ 0.7 & 90.6 $\pm$ 1.7  & \textbf{93.3} $\pm$ 0.4\\
Unlabeled data from C59P20 & - & 72.05 $\pm$ 2.2 & \textbf{79.4} $\pm$ 2.8\\
No data from C59P20 & 70.0 $\pm$ 2.6 & 69.2 $\pm$ 1.9 & 71.9 $\pm$ 2.7
\end{tabular}
\label{tab:malaria_semi}
\end{table}

% As described in Section \ref{sec:mnist_semi} we are interested in learning from domains with no class labels, since such an approach can drastically lower the amount of labeled data needed to learn a domain invariant representation, i.e., a model that generalizes well across patients. For the semi-supervised experiments we randomly select one, two, three, four and five domains from the nine training domains. The samples from the selected domains are subsequently treated as unlabeled. Samples from supervised domains are presented as ($d$, $x$, $y$) whereas samples from unsupervised domains are presented as ($d$, $x$). Note that in the case of five unsupervised domains more than 50$\%$ of the training dataset is unlabeled. In all five experiments we choose the test domain to be domain '1', since all three methods show very similar results on this domain across all three metrics in a supervised setting, see Figure \ref{fig:malaria_super}. All experiments are repeated five times. Each time another subset of domains is selected, which results in the large standard errors seen in Figure \ref{fig:malaria_semi}. The architecture and hyperparameters for all models are the same as for the experiments in Section \ref{sec:supervised_malaria}.
% In Figure \ref{fig:malaria_semi} we see a steady decline in performance of the baseline model across all three metrics. However, we find that DIVA and DA benefit from jointly learning from labeled and unlabeled data. As a result DIVA and DA show better performance especially for experiment with three, four and five unsupervised domains compared to the baseline model. Furthermore, we argue that the improvement of DIVA over DA arises from the way the additional unlabeled data is utilized. In case of DA the unlabeled data ($d$, $x$) is only used to train the domain classifier and the feature extractor in an adversarial manner.  In Section \ref{sec:semisuper_diva} we show that due to DIVA's generative nature $q_{\phi_y}(z_y|x)$, $p_{\theta_y}(z_y|y)$ can be updated using $q_{\omega_y}(y|z_y)$ to marginalize over $y$ for an unlabeled sample $x$. In addition, the unlabeled data ($d$, $x$) is used to update $q_{\phi_d}(z_d|x)$, $p_{\theta_d}(z_d|d)$, $q_{\omega_d}(d|z_d)$, $q_{\phi_x}(z_x|x)$ and $p_{\theta}(x|z_d, z_x, z_y)$ in the same way as in the supervised case.

% \begin{figure}[h]
% \begin{minipage}{.33\textwidth}
% \includegraphics[scale=0.3]{malaria_plots/unsupervised_acc.png}
% \end{minipage}
% \begin{minipage}{.33\textwidth}
% \includegraphics[scale=0.3]{malaria_plots/unsupervised_f1.png}
% \end{minipage}
% \begin{minipage}{.33\textwidth}
% \includegraphics[scale=0.3]{malaria_plots/unsupervised_auc.png}
% \end{minipage}
% \caption{Results of the semi-supervised experiments, where we gradually reduce the number of supervised domains. We report the average and standard error for all three metrics.}
% \label{fig:malaria_semi}
% \end{figure}

% --SECTION--
\section{Conclusion}
% \input{conclusion.tex}
We have proposed DIVA as a generative model with three latent subspaces. We evaluated DIVA on Rotated MNIST and a Malaria Cell Images dataset. In both cases DIVA is able to learn disentangled representations that match the ground truth factors of interest, represented by the class $y$ and the domain $d$. By learning representations $\rvz_y$ that are invariant with respect to the domain $d$, DIVA is able to improve upon other methods on both datasets. Furthermore, we show that we can boost DIVA's performance by incorporating unlabeled samples, even from entirely new domains for which no labeled examples are available. This property is highly desirable in fields like medical imaging, where the labeling process is very time-consuming and costly. It appears that there is a key difference between interpolation and extrapolation, a distinction currently not made by the domain generalization community: if we assume that the domains lie in intervals like [$0\degree$, $15\degree$, $30\degree$] or [`red', `orange', `yellow'], then the performance for the domains in the center of the interval, e.g., $15\degree$ and `orange', seems to be better than for the domains at the ends of the interval. We argue that DIVA can make use of unlabeled data from a domain that is close to the test domain to improve its extrapolation performance.

 

% Acknowledgments---Will not appear in anonymized version
\midlacknowledgments{The authors thank Patrick Forr\'{e}, Rianne van den Berg, Marco Federici, Daniel Worrall and Bas Veeling for the helpful discussions and comments.\\\\
Maximilian Ilse was funded by the Nederlandse Organisatie voor Wetenschappelijk Onderzoek (Grant DLMedIa: Deep Learning for Medical Image Analysis).}


\bibliography{ilse20}

\appendix
\section{}
% \input{appendix.tex}
\label{sec:appendix}
\subsection{Graphical model of DIVA}
\begin{figure}[H]
    \centering
    % \input{tikz_plots/gen_and_inf.tex}
    \begin{tikzpicture}
\node [nodeobserved] (x) {$\rvx$};
\node [node, above = of x](zx) {$\rvz_x$};
\node [node, left = of zx](zd) {$\rvz_d$};
\node [node, right = of zx] (zy) {$\rvz_y$};
\node [semi2, right = of x, xshift=-4.5mm, yshift=-3.0mm] (y_2) {};
\node [semi1, right = of x] (y_1) {$y$};
\node [nodeobserved, left = of x] (d) {$d$};
% \node [box, below = of x] (text1) {generative: $p(x|z_d, z_x, z_y)p(z_d|d)p(z_x)p(z_y|y)p(d)p(y)$};

\draw [arrow] (d) -- (zd);
\draw [arrow] (y_1) -- (zy);
\draw [arrow] (zd) -- (x);
\draw [arrow] (zx) -- (x);
\draw [arrow] (zy) -- (x);
\end{tikzpicture}
\hspace{2cm}
\begin{tikzpicture}
\node [nodeobserved] (x) {$\rvx$};
\node [node, above = of x](zx) {$\rvz_x$};
\node [node, left = of zx](zd) {$\rvz_d$};
\node [node, right = of zx] (zy) {$\rvz_y$};
\node [semi2, right = of x, xshift=-4.5mm, yshift=-3.0mm] (y_2) {};
\node [semi1, right = of x] (y_1) {$y$};
\node [nodeobserved, left = of x] (d) {$d$};
% \node [box, below = of x] (text1) {inference: $q(z_d|x)q(z_x|x)q(z_y|x)$, aux: $q(d|z_d)$ and $q(y|z_y)$};

\draw [arrow2] (zd) -- (d);
\draw [arrow2] (zy) -- (y_1);
\draw [arrow] (x) -- (zd);
\draw [arrow] (x) -- (zx);
\draw [arrow] (x) -- (zy);
\end{tikzpicture}
    \caption{Left: Generative model. According to the graphical model we obtain $p(d, \rvx, y, \rvz_d, \rvz_x, \rvz_y) = p_\theta(\rvx| \rvz_d, \rvz_x, \rvz_y)p_{\theta_d}(\rvz_d|d)p(\rvz_x)p_{\theta_y}(\rvz_y|y)p(d)p(y)$. Right: Inference model. We propose to factorize the variational posterior as $q_{\phi_d}(\rvz_d|\rvx)q_{\phi_x}(\rvz_x|\rvx)q_{\phi_y}(\rvz_y|\rvx)$. Dashed arrows represent the two auxiliary classifiers $q_{\omega_d}(d|\rvz_d)$ and $q_{\omega_y}(y|\rvz_y)$.}
    \label{fig:my_label}
\label{fig:graph_model}
\end{figure}

\subsection{Definition of Domain Generalization}
% \subsection{Domain generalization}
We follow the domain generalization definitions used in \citet{muandet_domain_2013}. A domain is defined as a joint distribution $p(\rvx,y)$ on $\mathcal{X}\times\mathcal{Y}$, where $\mathcal{X}$ denotes the input space and $\mathcal{Y}$ denotes the output space. Let $\mathfrak{P}_{\mathcal{X}\times\mathcal{Y}}$ be the set of all domains. The training set consists of samples $\mathcal{S}$ taken from $N$ domains, $\mathcal{S} = \{S^{(d=i)}\}^N_{i=1}$. Here, the $i$th domain $p^{(d=i)}(\rvx,y)$ is represented by $n_i$ samples, $S^{(d=i)} = \{(\rvx^{(d=i)}_k, y^{(d=i)}_k)\}^{n_i}_{k=1}$. Each of the $N$ distributions $p^{(d=1)}(\rvx,y), \ldots , p^{(d=i)}(\rvx,y), \ldots , p^{(d=N)}(\rvx,y)$ is sampled from $\mathfrak{P}_{\mathcal{X}\times\mathcal{Y}}$. We further assume that $p^{(d=i)}(\rvx,y) \neq p^{(d=j)}(\rvx,y)$ for $i \neq j$; therefore, the samples in $\mathcal{S}$ are non-i.i.d.
During test time we are presented with samples $S^{(d=N+1)}$ from a previously unseen domain $p^{(d=N+1)}(\rvx,y)$. We are interested in learning representations that generalize from $p^{(d=1)}(\rvx,y), \ldots , p^{(d=N)}(\rvx,y)$ to this new domain. Training data are given as tuples $(d,\rvx,y)$ in the case of supervised data or as $(d,\rvx)$ in the case of unsupervised data.

\subsection{Rotated MNIST}
\subsubsection{Training procedure}
All DIVA models are trained for 500 epochs. The training is terminated if the training loss for $y$ has not improved for 100 epochs. As proposed in \citet{burgess_understanding_2018}, we linearly increase $\beta$ from 0.0 to 1.0 during the first 100 epochs of training. We set $\alpha_d = 2000$. Following \citet{maaloe_biva:_2019}, we adjust $\alpha_y$ according to the ratio of labeled ($N$) and unlabeled ($M$) data points,
\begin{align}
    \alpha_y = \gamma \dfrac{M + N}{N},
\label{eq:semi_diva}
\end{align}
where we set $\gamma = 3500$. Last, $\rvz_d$, $\rvz_x$ and $\rvz_y$ each have 64 latent dimensions. All hyperparameters were determined by training DIVA on $\mathcal{M}_{0\degree}$, $\mathcal{M}_{15\degree}$, $\mathcal{M}_{30\degree}$, $\mathcal{M}_{45\degree}$ and testing on $\mathcal{M}_{60\degree}$. We searched over the following parameters: $\alpha_d, \alpha_y \in \{1500, 2000, 2500, 3000, 3500, 4000\}$; $\mathrm{dim}(\rvz_d) = \mathrm{dim}(\rvz_x) = \mathrm{dim}(\rvz_y) \in \{16, 32, 64\}$; $\beta_{\max} \in \{1, 5, 10\}$.

All models were trained using Adam \citep{kingma_adam:_2014} (with default settings), a pixel-wise cross-entropy loss and a batch size of 100.

\subsubsection{Architectures}
To enable a fair experiment, the encoder $q_{\phi_y}(\rvz_y|\rvx)$ and auxiliary classifier $q_{\omega_y}(y|\rvz_y)$ form a DNN with the same number of layers and weights as described in \citet{wang_learning_2019}.

\begin{table}[H]
\caption{Architecture for  $p_\theta(\rvx|\rvz_d, \rvz_x, \rvz_y)$. The parameter for Linear is output features. The parameters for ConvTranspose2d are output channels and kernel size. The parameter for Upsample is the upsampling factor. The parameters for Conv2d are output channels and kernel size.}
\begin{center}
\begin{tabular}{c|c}
block & details\\
\hline
1 & Linear(1024), BatchNorm1d, ReLU\\
2 & Upsample(2)\\
3 & ConvTranspose2d(128, 5), BatchNorm2d, ReLU\\
4 & Upsample(2)\\
5 & ConvTranspose2d(256, 5), BatchNorm2d, ReLU\\
6 & Conv2d(256, 1)
\end{tabular}
\end{center}
\end{table}

\begin{table}[H]
\caption{Architecture for $p_{\theta_d}(\rvz_d|d)$ and $p_{\theta_y}(\rvz_y|y)$. Each network has two heads: one for the mean and one for the scale. The parameter for Linear is output features.}
\begin{center}
\begin{tabular}{c|c}
block & details\\
\hline
1 & Linear(64), BatchNorm1d, ReLU\\
2.1 & Linear(64)\\
2.2 & Linear(64), Softplus\\
\end{tabular}
\end{center}
\end{table}

\begin{table}[H]
\caption{Architecture for $q_{\phi_d}(\rvz_d|\rvx)$, $q_{\phi_x}(\rvz_x|\rvx)$ and $q_{\phi_y}(\rvz_y|\rvx)$. Each network has two heads: one for the mean and one for the scale. The parameters for Conv2d are output channels and kernel size. The parameters for MaxPool2d are kernel size and stride. The parameter for Linear is output features.}
\begin{center}
\begin{tabular}{c|c}
block & details\\
\hline
1 & Conv2d(32, 5), BatchNorm2d, ReLU\\
2 & MaxPool2d(2, 2)\\
3 & Conv2d(64, 5), BatchNorm2d, ReLU\\
4 & MaxPool2d(2, 2)\\
5.1 & Linear(64)\\
5.2 & Linear(64), Softplus\\
\end{tabular}
\end{center}
\end{table}

\begin{table}[H]
\caption{Architecture for $q_{\omega_d}(d|\rvz_d)$ and $q_{\omega_y}(y|\rvz_y)$. The parameter for Linear is output features.}
\begin{center}
\begin{tabular}{c|c}
block & details\\
\hline
1 & ReLU, Linear(5 for $q_{\omega_d}(d|\rvz_d)$/10 for $q_{\omega_y}(y|\rvz_y)$), Softmax\\
\end{tabular}
\end{center}
\end{table}

% \subsubsection*{Quantitative disentanglement}
% While Section \ref{sec:quali_disentangle} contains only qualitative results, we now quantify the disentanglement and its influence on the classification accuracy. A good indicator for disentanglement is to test how predictive $z_d$ is of $y$ and $z_y$ is of $d$. First, we train DIVA using the training data from $\mathcal{M}_{0\degree}$, $\mathcal{M}_{15\degree}$, $\mathcal{M}_{30\degree}$, $\mathcal{M}_{45\degree}$ and $\mathcal{M}_{60\degree}$. After training DIVA, we use $q_{\phi_d}(z_d|x)$ and $q_{\phi_y}(z_y|x)$ to generate embeddings $z_d$ and $z_y$ for every $x$. Finally we train a two layer MLP for 500 epochs for the following 4 cases: predicting $y$ from $z_d$,  predicting $d$ from $z_d$,  predicting $y$ from $z_y$ and predicting $d$ from $z_y$.

% We repeat this experiment 3 times for $\beta_d$ = $\beta_x$ = $\beta_y$ = 1, $\beta_d$ = $\beta_x$ = $\beta_y$ = 5 and $\beta_d$ = $\beta_x$ = $\beta_y$ = 10. The first thing we notice in Table \ref{tab:different_betas} is that the two layer MLP is able to overfit perfectly when predicting $d$ from $z_d$ and $y$ from $z_y$ for all choices of $\beta_d$, $\beta_x$ and $\beta_y$. When we try to predict the class label $y$ from $z_d$ the accuracy is capped at around 26$\%$, while the random chance accuracy for this task is 10$\%$. In the other case, where we try to predict the domain label $d$ from $z_y$ the accuracy is capped at around 41$\%$. The random chance accuracy for this task is 20$\%$. In both cases we observe that the classification accuracy decreases if we increase $\beta_d$, $\beta_x$ and $\beta_y$. This trend implies that for higher $\beta$s $z_d$ becomes less predictive of $y$ and  $z_y$ becomes less predictive of $d$, i.e. $z_y$ and $z_d$ are better disentangled. This finding aligns with observations made in \citep{higgins_-vae:_2017}.

% \begin{table}[h]
% \caption{Classification results for embeddings of the training set for different values of $\beta_d$, $\beta_x$ and $\beta_y$.}
% \begin{center}
% \begin{tabular}{cccc}
% Task &$\beta = 1 $ &$\beta = 5 $ &$\beta = 10 $\\
% \hline
% $y$ from $z_d$ & 26.6 $\pm$ 0.1 & 26.5 $\pm$ 0.1 &26.0 $\pm$ >0.1\\
% $d$ from $z_d$ & 100.0 $\pm$ 0.0  & 100.0 $\pm$ 0.0 & 100.0 $\pm$ 0.0 \\
% $y$ from $z_y$ & 100.0 $\pm$ 0.0  & 100.0 $\pm$ 0.0 & 100.0 $\pm$ 0.0 \\
% $d$ from $z_y$ & 41.4 $\pm$ 0.1  & 41.0 $\pm$ 0.1 & 40.8 $\pm$ >0.1 \\
% \end{tabular}
% \end{center}
% \label{tab:different_betas}
% \end{table}

% In addition, we look at the classification performance on the test domain $\mathcal{M}_{75\degree}$. In Table \ref{tab:different_betas_test} we notice that test accuracy decreases when we increase $\beta_d$, $\beta_x$ and $\beta_y$. By comparing this observation with the results in Table \ref{tab:different_betas}, we conclude that a better disentangled latent space results in worse test accuracy on the rotated MNIST dataset. This finding aligns well with observation made in \citep{wang_learning_2019}, that some information about the domain seems to be useful for predicting the class in case of the rotated MNIST dataset.
% \begin{table}[h]
% \caption{Classification results on the test domain $\mathcal{M}_{75\degree}$ for different values of $\beta_d$, $\beta_x$ and $\beta_y$.}
% \begin{center}
% \begin{tabular}{cccc}
% Test &$\beta = 1 $ &$\beta = 5 $ &$\beta = 10 $\\
% \hline
% $\mathcal{M}_{75\degree}$ &90.9 $\pm$ 0.2  & 89.4 $\pm$ 0.2 &88.9 $\pm$ 0.2 \\
% \end{tabular}
% \end{center}
% \label{tab:different_betas_test}
% \end{table}
\newpage
\subsubsection{Qualitative Disentanglement: Training domains}
Here, we trained
DIVA on 5000 images from five domains: $\mathcal{M}_{0\degree}$, $\mathcal{M}_{15\degree}$, $\mathcal{M}_{30\degree}$, $\mathcal{M}_{45\degree}$ and $\mathcal{M}_{60\degree}$.

\label{sec:quali_disentangle}

\begin{figure}[H]
\begin{minipage}{.32\textwidth}
\includegraphics[scale=0.25]{train_zd_domain.png}
\end{minipage}
\begin{minipage}{.32\textwidth}
\includegraphics[scale=0.25]{train_zx_domain.png}
\end{minipage}
\begin{minipage}{.32\textwidth}
\includegraphics[scale=0.25]{train_zy_domain.png}
\end{minipage}\\
\begin{minipage}{.32\textwidth}
\includegraphics[scale=0.25]{train_zd_class.png}
\end{minipage}
\begin{minipage}{.32\textwidth}
\includegraphics[scale=0.25]{train_zx_class.png}
\end{minipage}
\begin{minipage}{.32\textwidth}
\includegraphics[scale=0.25]{train_zy_class.png}
\end{minipage}
\caption{2D embeddings of all three latent subspaces. In the top row, embeddings are colored according to their domain; in the bottom row, they are colored according to their class. First column: $\rvz_d$ encoded by $q_{\phi_d}(\rvz_d|\rvx)$. The top plot shows five distinct clusters, where each cluster corresponds to a single domain. In the bottom plot, no clustering is visible. Second column: $\rvz_x$ encoded by $q_{\phi_x}(\rvz_x|\rvx)$. In the top plot, we observe a correlation between the rotation angle of each MNIST digit and $\rvz_x[0]$. Upon visual inspection of the original inputs $\rvx$, we find a correlation between the line thickness of the digit and $\rvz_x[0]$ as well as a correlation between the digit width and $\rvz_x[1]$ in the bottom plot. As a result, we observe a clustering of embeddings with class `1' at the lower left part of the plot. Third column: $\rvz_y$ encoded by $q_{\phi_y}(\rvz_y|\rvx)$. In the top plot, no clustering is visible. The bottom plot shows ten distinct clusters, where each cluster corresponds to a class. Plots of the 2D embeddings for the test domain $\mathcal{M}_{75\degree}$ can be found in Figures~\ref{fig:test_zd} and~\ref{fig:test_zy}.}
\label{fig:embeddings_mnist}
\end{figure}

\subsubsection{Qualitative Disentanglement: Test domain}
In this section, we visualize the embeddings $\rvz_d$ and $\rvz_y$ for data points $\rvx$ from the test domain $\mathcal{M}_{75\degree}$ for the model trained in Section~\ref{sec:quali_disentangle}. Figure~\ref{fig:test_zy} shows 1000 embeddings $\rvz_y$ encoded by 
$q_{\phi_y}(\rvz_y|\rvx)$. Figure~\ref{fig:test_zd} shows 1000 embeddings $\rvz_d$ encoded by $q_{\phi_d}(\rvz_d|\rvx)$.

\begin{figure}[H]
\centering
\includegraphics[scale=0.5]{test_zd_class_with_train_test_75.png}
% \vskip -3mm
\caption{1000 two-dimensional embeddings $\rvz_d$ encoded by $q_{\phi_d}(\rvz_d|\rvx)$ for $\rvx$ from the test domain $\mathcal{M}_{75\degree}$. The color of each point indicates the associated domain.}
\label{fig:test_zd}
\end{figure}

\begin{figure}[H]
\centering
\includegraphics[scale=0.5]{test_zy_class.png}
% \vskip -3mm
\caption{1000 two-dimensional embeddings $\rvz_y$ encoded by $q_{\phi_y}(\rvz_y|\rvx)$ for $\rvx$ from the test domain $\mathcal{M}_{75\degree}$. The color of each point indicates the associated class.}
\label{fig:test_zy}
\end{figure}

\subsubsection{Samples}
We present samples from DIVA by sampling $\rvz_d$, $\rvz_x$ and $\rvz_y$ from their priors and then decoding them. Generated examples on the Rotated MNIST data are given in Figure~\ref{fig:samples_diva}. DIVA is able to generate images that are almost indistinguishable from real data points.
\begin{figure}[H]
\centering
\includegraphics[scale=0.8]{samples.png}
% \caption{64 rotated MNIST samples created by DIVA.}
\caption{Samples from DIVA trained on Rotated MNIST.}
\label{fig:samples_diva}
\end{figure}

\subsubsection{Conditional generation}
\label{sec:conditional_generation}
Yet another way to gain insight into the disentanglement abilities of DIVA is conditional generation. We first train DIVA with $\beta = 10$ using $\mathcal{M}_{0\degree}$, $\mathcal{M}_{15\degree}$, $\mathcal{M}_{30\degree}$, $\mathcal{M}_{45\degree}$ and $\mathcal{M}_{60\degree}$ as training domains. After training, we perform two experiments. In the first experiment, we fix the domain and vary the class. In the second experiment, we fix the class and vary the domain.

\paragraph{Change of class}
The first row of Figure~\ref{fig:cond_gen} (left) shows the input images $\rvx$ for DIVA. First, we generate embeddings $\rvz_d$, $\rvz_x$ and $\rvz_y$ for each $\rvx$ using $q_{\phi_d}(\rvz_d|\rvx)$, $q_{\phi_x}(\rvz_x|\rvx)$ and $q_{\phi_y}(\rvz_y|\rvx)$. Second, we replace $\rvz_y$ with a sample $\rvz_y'$ from the conditional prior $p_{\theta_y}(\rvz_y|y)$. Last, we generate new images from $\rvz_d$, $\rvz_x$ and $\rvz_y'$ using the trained decoder $p_\theta(\rvx|\rvz_d, \rvz_x, \rvz_y)$. In Figure~\ref{fig:cond_gen} (left), rows 2 to 11 correspond to the classes `0' to `9'. We observe that the rotation angle (encoded in $\rvz_d$) and the line thickness (encoded in $\rvz_x$) are well preserved, while the class of the image changes as intended.

\begin{figure}[H]
\begin{minipage}{.5\textwidth}
\centering
\includegraphics[scale=0.8]{reconstruction_label.png}
\end{minipage}
\begin{minipage}{.5\textwidth}
\centering
\includegraphics[scale=0.8]{reconstruction_domain.png}
\end{minipage}
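\caption{Conditional reconstructions. Left: The first row shows the inputs; rows 2 to 11 correspond to labels `0' to `9'. Right: The first row shows the inputs; rows 2 to 6 correspond to domains $\mathcal{M}_{0\degree}$, $\mathcal{M}_{15\degree}$, $\mathcal{M}_{30\degree}$, $\mathcal{M}_{45\degree}$ and $\mathcal{M}_{60\degree}$.}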
\label{fig:cond_gen}
\end{figure}

\paragraph{Change of domain}
We repeat the experiment from above, but this time we keep $\rvz_x$ and $\rvz_y$ fixed while changing the domain. After generating embeddings $\rvz_d$, $\rvz_x$ and $\rvz_y$ for each $\rvx$ in the first row of Figure~\ref{fig:cond_gen} (right), we replace $\rvz_d$ with a sample $\rvz_d'$ from the conditional prior $p_{\theta_d}(\rvz_d|d)$. Finally, we generate new images from $\rvz_d'$, $\rvz_x$ and $\rvz_y$ using the trained decoder $p_\theta(\rvx|\rvz_d, \rvz_x, \rvz_y)$. In Figure~\ref{fig:cond_gen} (right), rows 2 to 6 correspond to the domains $\mathcal{M}_{0\degree}$ to $\mathcal{M}_{60\degree}$. Again, DIVA shows the desired behaviour: while the rotation angle changes, the class and style of the original image are maintained.

\subsubsection{Ablation study: Partitioned latent space}
We compare DIVA to a VAE with a single latent space, a standard Gaussian prior and two auxiliary tasks. The resulting graphical model is shown in Figure~\ref{fig:graph_model_ablation}. The results in Table~\ref{tab:comparison_diva_vae} clearly show the benefits of having a partitioned latent space $\rvz$.

\begin{figure}[H]
    \centering
    % \input{tikz_plots/ablation.tex}
    \begin{tikzpicture}
\node [nodeobserved] (x) {$\rvx$};
\node [node, above = of x](z) {$\rvz$};
% \node [box, below = of x] (text1) {generative: $p(x|z_d, z_x, z_y)p(z_d|d)p(z_x)p(z_y|y)p(d)p(y)$};
\draw [arrow] (z) -- (x);
\end{tikzpicture}
\hspace{2cm}
\begin{tikzpicture}
\node [nodeobserved] (x) {$\rvx$};
\node [node, above = of x](z) {$\rvz$};
\node [nodeobserved, right = of x] (y) {$y$};
\node [nodeobserved, left = of x] (d) {$d$};
% \node [box, below = of x] (text1) {inference: $q(z_d|x)q(z_x|x)q(z_y|x)$, aux: $q(d|z_d)$ and $q(y|z_y)$};

\draw [arrow2] (z) -- (d);
\draw [arrow2] (z) -- (y);
\draw [arrow] (x) -- (z);
\end{tikzpicture}
    \caption{Left: Generative model. According to the graphical model we obtain $p(\rvx, \rvz) = p_\theta(\rvx|\rvz)p(\rvz)$. Right: Inference model. We propose $q_{\phi}(\rvz|\rvx)$ as the variational posterior. Dashed arrows represent the two auxiliary classifiers $q_{\omega_d}(d|\rvz)$ and $q_{\omega_y}(y|\rvz)$.}
\label{fig:graph_model_ablation}
\end{figure}

The objective is given by,
\begin{align}
\mathcal{F}_{\text{VAE}}(d, \rvx, y) &:= \mathbb{E}_{q_{\phi}(\rvz | \rvx)} \left[ \log p_\theta(\rvx|\rvz) \right] -\beta KL\left(q_{\phi}(\rvz|\rvx)||p(\rvz)\right)\nonumber\\ &+ \alpha_d\mathbb{E}_{q_{\phi}(\rvz|\rvx)}\left[\log q_{\omega_d}(d|\rvz)\right]
+ \alpha_y\mathbb{E}_{q_{\phi}(\rvz|\rvx)}\left[\log q_{\omega_y}(y|\rvz)\right].
\end{align}

\begin{table}[H]
\caption{Comparison of DIVA with a VAE with a single latent space, a standard Gaussian prior and two auxiliary tasks on Rotated MNIST. We report the average and standard error of the classification accuracy.}
\centering
\begin{tabular}{ccc}
Test & VAE & DIVA\\
\hline
$\mathcal{M}_{0\degree}$  & 88.4 $\pm$ 0.5 &\bf{93.5} $\pm$ 0.3\\
$\mathcal{M}_{15\degree}$ & 98.3 $\pm$ 0.1 &\bf{99.3} $\pm$ 0.1\\
$\mathcal{M}_{30\degree}$ & 97.4 $\pm$ 0.2 &\bf{99.1} $\pm$ 0.1\\
$\mathcal{M}_{45\degree}$ & 97.4 $\pm$ 0.4 &\bf{99.2} $\pm$ 0.1\\
$\mathcal{M}_{60\degree}$ & 97.9 $\pm$ 0.2 &\bf{99.3} $\pm$ 0.1\\
$\mathcal{M}_{75\degree}$ & 84.0 $\pm$ 0.3 &\bf{93.0} $\pm$ 0.4\\
\hline
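% NOTE(review): the DIVA average SE (1.3) matches the standard error across the six test domains; computed the same way, the VAE average SE would be about 2.5, not 0.1 -- please verify.
Avg & 93.9 $\pm$ 0.1 &\bf{97.2} $\pm$ 1.3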
\end{tabular}
\label{tab:comparison_diva_vae}
\end{table}

\subsubsection{Ablation study: DIVA without $\rvz_d$ and $\rvz_x$}
We compare DIVA as proposed in Section \ref{sec:method} to two ablated versions of DIVA:
\begin{enumerate}
    \item DIVA without $\rvz_d$: The domain label $d$ is not used during training. Therefore, there is no latent space $\rvz_d$, no encoder $q_{\phi_d}(\rvz_d|\rvx)$, no prior $p_{\theta_d}(\rvz_d|d)$ and no classifier $q_{\omega_d}(d|\rvz_d)$. The decoder becomes $p_\theta(\rvx|\rvz_x, \rvz_y)$.
    \item DIVA without $\rvz_x$: There is no latent space $\rvz_x$, no encoder $q_{\phi_x}(\rvz_x|\rvx)$ and no prior $p(\rvz_x)$. The decoder becomes $p_\theta(\rvx|\rvz_d, \rvz_y)$.
\end{enumerate}

In Table~\ref{tab:comparison_ablation}, we compare DIVA and the two ablated versions on the Rotated MNIST dataset. Surprisingly, for Rotated MNIST we could not find a significant difference in performance between DIVA and DIVA without $\rvz_d$, as seen in the third column. However, not having $\rvz_d$ drastically reduces the interpretability of our model, since without $\rvz_d$ we cannot identify the variations in $\rvx$ that are explained by the domain $d$. For example, in Section~\ref{sec:conditional_generation}, we show that we can generate samples conditioned on the domain label that give us a clear idea of the meaning of $d$. Furthermore, as seen in Figure~\ref{fig:recon_malaria}, the patient ID is highly correlated with the color of the stain, while the cell shape is not correlated with $d$ or $y$ and is therefore captured by $\rvz_x$. Without $\rvz_d$ we are unable to gain such insights, which are especially important from a medical perspective.
 In the fourth column, we see that for $\mathcal{M}_{0\degree}$ and $\mathcal{M}_{75\degree}$ DIVA with $\rvz_x$ performs better than DIVA without $\rvz_x$. We argue that if $\rvz_x$ does not exist, $\rvz_d$ and $\rvz_y$ will capture the residual variations in $\rvx$ that are not explained by $d$ or $y$. We believe this makes it harder to predict $y$ using $\rvz_y$ and $d$ using $\rvz_d$.

\begin{table}[H]
\centering
\caption{Results of ablation study.}
\begin{tabular}{cccc}
Test &DIVA &DIVA without $\rvz_d$ & DIVA without $\rvz_x$\\
\hline
$\mathcal{M}_{0\degree}$ &93.5 $\pm$ 0.3 & 93.4 $\pm$ 0.5  & 92.7 $\pm$ 0.5  \\
$\mathcal{M}_{15\degree}$ &99.3 $\pm$ 0.1 & 99.3 $\pm$ 0.1  & 99.4  $\pm$ 0.1 \\
$\mathcal{M}_{30\degree}$ &99.1 $\pm$ 0.1 & 98.9 $\pm$ 0.1  & 99.2  $\pm$ 0.1  \\
$\mathcal{M}_{45\degree}$ &99.2 $\pm$ 0.1 & 99.1 $\pm$ 0.1  & 99.1 $\pm$ 0.1  \\
$\mathcal{M}_{60\degree}$ &99.3 $\pm$ 0.1 & 99.1 $\pm$ 0.1  & 99.4 $\pm$ 0.1  \\
$\mathcal{M}_{75\degree}$ &93.0 $\pm$ 0.4 & 92.8 $\pm$ 0.4  & 92.4 $\pm$ 0.4  \\
\hline
Avg &97.2 $\pm$ 1.3 & 97.1 $\pm$ 1.3 & 97.1 $\pm$ 1.5 \\
\end{tabular}
\label{tab:comparison_ablation}
\end{table}

\subsection{Malaria Cell Images}
\subsubsection{Example cells from each domain}
\begin{figure}[H]
\centering
\subfigure[C116P77]{
  \includegraphics[width=27mm]{reconstruction_cell_train_top10_0.png}
}
\subfigure[C132P93]{
  \includegraphics[width=27mm]{reconstruction_cell_train_top10_1.png}
}
\subfigure[C137P98]{
  \includegraphics[width=27mm]{reconstruction_cell_train_top10_2.png}
}
\subfigure[C180P141]{
  \includegraphics[width=27mm]{reconstruction_cell_train_top10_3.png}
}
\subfigure[C182P143]{
  \includegraphics[width=27mm]{reconstruction_cell_train_top10_4.png}
}

\subfigure[C184P145]{
  \includegraphics[width=27mm]{reconstruction_cell_train_top10_5.png}
}
\subfigure[C39P4]{
  \includegraphics[width=27mm]{reconstruction_cell_train_top10_6.png}
}
\subfigure[C59P20]{
  \includegraphics[width=27mm]{reconstruction_cell_train_top10_7.png}
}
\subfigure[C68P29]{
  \includegraphics[width=27mm]{reconstruction_cell_train_top10_8.png}
}
\subfigure[C99P60]{
  \includegraphics[width=27mm]{reconstruction_cell_train_top10_9.png}
}
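\caption{Example cells from each of the ten domains (patient IDs) of the Malaria Cell Images dataset.}
\label{fig:malaria_domains}
\end{figure}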

\subsubsection{Training procedure: DIVA}
All DIVA models are trained for 500 epochs. The training is terminated if the validation accuracy for $y$ has not improved for 100 epochs. As proposed in \citet{burgess_understanding_2018}, we linearly increase $\beta$ from 0.0 to 1.0 during the first 100 epochs of training. We set $\alpha_d = 100000$ and $\alpha_y = 75000$. Last, $\rvz_d$, $\rvz_x$ and $\rvz_y$ each have 64 latent dimensions. We searched over the following parameters: $\alpha_d, \alpha_y \in \{25000, 50000, 75000, 100000\}$; $\mathrm{dim}(\rvz_d) = \mathrm{dim}(\rvz_x) = \mathrm{dim}(\rvz_y) \in \{32, 64\}$; $\beta_{\max} \in \{1, 5, 10\}$. All hyperparameters were determined using a validation set that consists of 20\% of the training set. All models were trained using Adam \citep{kingma_adam:_2014} (with default settings), a discretized logistic mixture loss \citep{salimans_pixelcnn++:_2017} and a batch size of 100. In case of the semi-supervised experiment in Section~\ref{sec:malaria_semi}, we adapt $\alpha_d$ and $\alpha_y$ according to Equation~\ref{eq:semi_diva}.

\subsubsection{Training procedure: Baseline and DA}
In case of the supervised experiments in Section~\ref{sec:supervised_malaria}, all models are trained for 500 epochs. The training is terminated if the validation accuracy for $y$ has not improved for 100 epochs. In case of the semi-supervised experiments in Section~\ref{sec:semisuper_diva}, the number of epochs is adjusted to match the number of parameter updates of DIVA.
For DA we follow the same training procedure as described in \citet{ganin_domain-adversarial_2015}. In the supervised case, a labeled batch randomly sampled from the training distributions is first used to update the class classifier, the domain classifier and the feature extractor in an adversarial fashion. Then, a second batch randomly sampled from the training distributions is used to update only the domain classifier and the feature extractor in an adversarial fashion. In the semi-supervised case, samples from the unsupervised domains form the second batch together with samples from the supervised domains. We use the same schedule for the domain adaptation parameter $\lambda$ as described in \citet{ganin_domain-adversarial_2015}. Hyperparameter search showed that DA performs better when $\lambda$ is scaled by a factor $\epsilon$, i.e., when $\lambda \cdot \epsilon$ is used, with $\epsilon = 0.001$. We searched over $\epsilon \in \{0.1, 0.05, 0.01, 0.005, 0.001, 0.0005, 0.0001\}$. For the semi-supervised experiment in Section~\ref{sec:malaria_semi}, hyperparameter search yielded $\epsilon = 0.01$.

\subsubsection{Architecture}
In the following, we describe the architecture of DIVA in detail. Note that the architecture of the baseline model is the same as $q_{\phi_y}(\rvz_y|\rvx)$ (we only use the mean of $\rvz_y$) followed by $q_{\omega_y}(y|\rvz_y)$, where $\rvz_y$ has 1024 dimensions. DA uses $q_{\phi_y}(\rvz_y|\rvx)$ without the linear layer as a feature extractor. The class classifier and the domain classifier consist of two linear layers. The feature extractor for all models consists of seven ResNet blocks \citep{he_deep_2015}. During training, batch normalization \citep{ioffe_batch_2015} is used for all layers.

\begin{table}[H]
\caption{Architecture for  $p_\theta(\rvx|\rvz_d, \rvz_x, \rvz_y)$. The parameter for Linear is output features. The parameters for ResidualConvTranspose2d are output channels and kernel size. The parameters for Conv2d are output channels and kernel size.}
\begin{center}
\begin{tabular}{c|c}
block & details\\
\hline
1 & Linear(1024), BatchNorm1d, LeakyReLU\\
2 & ResidualConvTranspose2d(64, 3), LeakyReLU\\
3 & ResidualConvTranspose2d(64, 3), LeakyReLU\\
4 & ResidualConvTranspose2d(64, 3), LeakyReLU\\
5 & ResidualConvTranspose2d(32, 3), LeakyReLU\\
6 & ResidualConvTranspose2d(32, 3), LeakyReLU\\
7 & ResidualConvTranspose2d(32, 3), LeakyReLU\\
8 & ResidualConvTranspose2d(32, 3), LeakyReLU\\
9 & ResidualConvTranspose2d(32, 3), LeakyReLU\\
10 & Conv2d(100, 3)\\
11 & Conv2d(100, 1)
\end{tabular}
\end{center}
\end{table}

\begin{table}[H]
\caption{Architecture for $p_{\theta_d}(\rvz_d|d)$ and $p_{\theta_y}(\rvz_y|y)$. Each network has two heads: one for the mean and one for the scale. The parameter for Linear is output features.}
\begin{center}
\begin{tabular}{c|c}
block & details\\
\hline
1 & Linear(64), BatchNorm1d, LeakyReLU\\
2.1 & Linear(64)\\
2.2 & Linear(64), Softplus\\
\end{tabular}
\end{center}
\end{table}

\begin{table}[H]
\caption{Architecture for $q_{\phi_d}(\rvz_d|\rvx)$, $q_{\phi_x}(\rvz_x|\rvx)$ and $q_{\phi_y}(\rvz_y|\rvx)$. Each network has two heads: one for the mean and one for the scale. The parameters for Conv2d are output channels and kernel size. The parameters for ResidualConv2d are output channels and kernel size. The parameter for Linear is output features.}
\begin{center}
\begin{tabular}{c|c}
block & details\\
\hline
1 & Conv2d(32, 3), BatchNorm2d, LeakyReLU\\
2 & ResidualConv2d(32), LeakyReLU\\
3 & ResidualConv2d(32), LeakyReLU\\
4 & ResidualConv2d(64, 3), LeakyReLU\\
5 & ResidualConv2d(64, 3), LeakyReLU\\
6 & ResidualConv2d(64, 3), LeakyReLU\\
7 & ResidualConv2d(64, 3), LeakyReLU\\
8 & ResidualConv2d(64, 3), LeakyReLU\\
9.1 & Linear(64)\\
9.2 & Linear(64), Softplus\\
\end{tabular}
\end{center}
\end{table}

\begin{table}[H]
\caption{Architecture for $q_{\omega_d}(d|\rvz_d)$ and $q_{\omega_y}(y|\rvz_y)$. The parameter for Linear is output features.}
\begin{center}
\begin{tabular}{c|c}
block & details\\
\hline
1 & LeakyReLU, Linear(9 for $q_{\omega_d}(d|\rvz_d)$/2 for $q_{\omega_y}(y|\rvz_y)$), Softmax\\
\end{tabular}
\end{center}
\end{table}

\subsubsection{Supervised case}
\label{sec:supervised_malaria}
To further evaluate domain generalization abilities, models are trained on nine domains (patient IDs) and tested on the remaining 10th domain. We choose ROC AUC on the held-out test domain as the evaluation metric, since the two classes are highly imbalanced. All experiments are repeated five times.

\begin{table}[H]
\caption{Results of the supervised experiments for the first part of the domains. We report the average
and standard error of ROC AUC. }
\centering
\small
\begin{tabular}{ccccccc}
Model &C116P77 &C132P93 &C137P98 &C180P141 &C182P143 &C184P145\\
\hline
Baseline & 90.6 $\pm$ 0.7 & 97.8 $\pm$ 0.5 & 98.9 $\pm$ 0.2 & 98.5 $\pm$ 0.2 & 96.7 $\pm$ 0.4 & 98.1 $\pm$ 0.2 \\
DA & 90.6 $\pm$ 1.7 & \bf{98.3} $\pm$ 0.4 & 99.0 $\pm$ 0.1 & 98.8 $\pm$ 0.1 & 96.9 $\pm$ 0.4 & 97.1 $\pm$ 0.8 \\
DIVA & \bf{93.3} $\pm$ 0.4 & \bf{98.4} $\pm$ 0.3 & 99.0 $\pm$ 0.1 & \bf{99.0} $\pm$ 0.1 & 96.5 $\pm$ 0.3 & \bf{98.5} $\pm$ 0.3 \\
\end{tabular}
\label{tab:malaria_super_1}
\end{table}

\begin{table}[H]
\centering
\small
\caption{Results of the supervised experiments for the second part of the domains, as well as the average across all domains. We report the average and standard error of ROC AUC.}
\begin{tabular}{ccccc|c}
Model &C39P4 &C59P20 &C68P29 &C99P60 &Average\\
\hline
Baseline & 97.1 $\pm$ 0.4 & 82.8 $\pm$ 2.8 & 95.3 $\pm$ 0.6 & 96.2 $\pm$ 0.1 & 95.2 $\pm$ 1.6 \\
DA & 97.4 $\pm$ 0.3 & 83.2 $\pm$ 3.3 & \bf{96.3} $\pm$ 0.1 & 96.1 $\pm$ 0.3 & 95.4 $\pm$ 1.6\\
DIVA & \bf{97.8} $\pm$ 0.2 & 82.1 $\pm$ 3.0 & \bf{96.3} $\pm$ 0.2 & \bf{96.6} $\pm$ 0.3 & 95.8 $\pm$ 1.6 \\
\end{tabular}
\label{tab:malaria_super_2}
\end{table}

We compare DIVA with a ResNet-like \citep{he_deep_2015} baseline and DA. During inference, all three models have the same architecture: seven ResNet blocks followed by two linear layers. Detailed information about hyperparameters, architecture and training schedule can be found in the preceding subsections. In Tables~\ref{tab:malaria_super_1} and~\ref{tab:malaria_super_2} we find that the results are not equally distributed across all test domains. In five cases, DIVA is able to significantly improve upon the baseline model and DA. Upon visual inspection, we find that cells from domain C116P77 and domain C59P20 are stained pink, similar to the stain of the parasite (see the example cells from each domain above). In case of C116P77, DIVA achieves the highest ROC AUC of all three models. In case of domain C59P20, all three methods have difficulties detecting the parasite, which leads to the lowest ROC AUC among all domains. We believe that the difficulties arise from the poor contrast between cell and parasite. Last, DIVA is able to improve on average when compared to the baseline model and DA, although the improvements are within the standard error.

\subsubsection{Predicting $y$ using either $\rvz_d$, $\rvz_x$ or $\rvz_y$}
We test how predictive $\rvz_d$, $\rvz_x$ and $\rvz_y$ are for the class $y$ on the Malaria Cell Images dataset. First, we use the trained DIVA models from Section~\ref{sec:supervised_malaria} to create embeddings $\rvz_d$, $\rvz_x$ and $\rvz_y$ for every $\rvx$ in the training domains and the held-out test domain. Second, we train a separate two-layer MLP on each of the embeddings $\rvz_d$, $\rvz_x$ and $\rvz_y$ from the training domains. We train the MLPs for 100 epochs using Adam \citep{kingma_adam:_2014}. After training, we test the MLPs on the embeddings $\rvz_d$, $\rvz_x$ and $\rvz_y$ from the test domain. In Table~\ref{tab:predict_y} we see that $\rvz_y$ is the most predictive of $y$ for every test domain, while the MLPs trained using $\rvz_d$ and $\rvz_x$ perform no better than a classifier that always predicts the majority class.
\begin{table}[H]
\centering
\caption{Prediction of $y$ using a two-layer MLP trained on $\rvz_d$, $\rvz_x$ or $\rvz_y$. We report the mean and standard error of the classification accuracy on the held-out test domain. The last column gives the accuracy of always predicting the majority class.}
\begin{tabular}{ccccc}
test domain &$\rvz_d$ &$\rvz_x$ &$\rvz_y$ & majority class\\
\hline
0 & 84.6 $\pm$ 1.0 & 85.0 $\pm$ 0.2 & \bf{87.9} $\pm$ 0.9 & 86 \\
1 & 89.5 $\pm$ 0.4 & 88.2 $\pm$ 0.5 & \bf{96.8} $\pm$ 0.1 & 90 \\
2 & 68.2 $\pm$ 3.5 & 80.0 $\pm$ 1.6 & \bf{96.9} $\pm$ 0.5 & 81 \\
3 & 87.0 $\pm$ 0.3 & 75.2 $\pm$ 2.9 & \bf{95.5} $\pm$ 0.2 & 88 \\
4 & 89.1 $\pm$ 0.3 & 82.7 $\pm$ 2.4 & \bf{92.5} $\pm$ 0.4 & 90 \\
5 & 88.3 $\pm$ 0.2 & 87.7 $\pm$ 0.2 & \bf{90.6} $\pm$ 0.5 & 88 \\
6 & 82.6 $\pm$ 3.7 & 56.3 $\pm$ 5.1 & \bf{91.1} $\pm$ 0.1 & 90 \\
7 & 88.3 $\pm$ 0.1 & 88.3 $\pm$ 0.1 & \bf{90.8} $\pm$ 0.8 & 88 \\
8 & 89.5 $\pm$ 0.3 & 85.3 $\pm$ 1.7 & \bf{93.5} $\pm$ 0.4 & 90 \\
9 & 89.1 $\pm$ 0.2 & 86.9 $\pm$ 1.5 & \bf{94.0} $\pm$ 0.3 & 89
\end{tabular}
\label{tab:predict_y}
\end{table}
\end{document}
