\documentclass[margin=2pt]{standalone}

\usepackage{amsmath,amsfonts}
\usepackage{booktabs}
\usepackage[dvipsnames]{xcolor}
\usepackage{tikz}
\usetikzlibrary{positioning}
\usetikzlibrary{arrows}
\usetikzlibrary{calc,fit}
\usetikzlibrary{shapes.geometric}
\usetikzlibrary{shapes.misc}
\usetikzlibrary{decorations.pathmorphing}
\usetikzlibrary{decorations.pathreplacing}
\usetikzlibrary{snakes}
\input{uai2022/math_commands.tex}

\begin{document}
\begin{tikzpicture}[
    % Default: no node padding; styles that want padding set it explicitly
    % (e.g. block below).
    every node/.style={inner sep=0pt},
    % Variables and directed edges of the example graph G.
    var/.style={draw=black, minimum size=15pt, circle, font=\small, thick, anchor=center, fill=white},
    edge/.style={thick, -latex},
    % Cells of the 3x3 grids (input mask and both output boxes).
    adjacency/.style={minimum width=10pt, minimum height=10pt, rectangle, rounded corners=2pt, font=\bfseries\scriptsize},
    % Rounded boxes for network modules; "transformer" is a block without
    % vertical padding, used for the node that wraps the nested layer picture.
    block/.style={rounded corners=2pt, draw=black, thick, minimum height=1.5em, minimum width=6em, align=center, inner xsep=5pt, inner ysep=5pt},
    transformer/.style={block, inner ysep=0pt},
    % Circular residual adder; the "+" glyph inside it is drawn separately.
    plus/.style={draw=black, thick, fill=white, circle, minimum size=1em},
    % Embedding bars: full-width vectors and the narrower edge-flag cell.
    vector/.style={inner sep=0pt, minimum height=5pt, minimum width=2em, draw=black, thick, rounded corners=1pt},
    mini vector/.style={vector, minimum width=1em},
    % Coordinate units.
    y=12pt, x=50pt]

% Unused alternative palettes, kept commented out for reference.
% \colorlet{colorA}{RoyalBlue!30}
% \colorlet{colorB}{ForestGreen!30}
% \colorlet{colorC}{YellowOrange!50}
% \colorlet{colorA}{LimeGreen!30}
% \colorlet{colorB}{Dandelion!30}
% \colorlet{colorC}{RoyalBlue!30}
% Active palette: one colour per variable A, B, C of the example graph; the
% same colours tag the corresponding embedding vectors.
\definecolor{colorA}{HTML}{a5d6a7}  % material green@200  -> variable A
\definecolor{colorB}{HTML}{ffcc80}  % material orange@200 -> variable B
\definecolor{colorC}{HTML}{90caf9}  % material blue@200   -> variable C

\node[transformer, fill=gray!10, rotate=270] (transformer) at (0, 0) {\begin{tikzpicture}
    % One linear-transformer layer, laid out bottom-to-top; the enclosing
    % node's rotate=270 turns the whole drawing sideways.
    % Modules: each Norm feeds a sub-layer whose output goes into an adder.
    \node[block, fill=Dandelion!30] (norm1) {Norm};
    \node[block, fill=LimeGreen!30, above=1em of norm1] (mha) {Linearized\\[-0.2em]Attention};
    \node[plus, above=3pt of mha] (plus1) {};
    \node[block, fill=Dandelion!30, above=0.8em of plus1] (norm2) {Norm};
    \node[block, fill=RoyalBlue!30, above=1em of norm2] (mlp) {MLP};
    \node[plus, above=3pt of mlp] (plus2) {};

    % "+" glyph inside each adder (both strokes stop 3pt short of the border).
    \foreach \sumnode in {plus1,plus2} {
        \draw[thick, shorten <=3pt, shorten >=3pt] (\sumnode.north) -- (\sumnode.south);
        \draw[thick, shorten <=3pt, shorten >=3pt] (\sumnode.west) -- (\sumnode.east);
    }

    % Sub-layer 1 (attention): a direct Norm -> Attention arrow plus two side
    % arrows forking 20% of the way up, entering at boundary angles 220 and 320.
    \coordinate (attn_fork) at ($(norm1.north)!0.2!(mha.south)$);
    \draw[edge] (norm1) -- (mha);
    \foreach \inangle in {220,320}
        \draw[edge, rounded corners=2pt] (attn_fork) -| (mha.\inangle);
    \draw[thick] (mha) -- (plus1);
    % Residual 1: leaves halfway along the 0.8em input stub below the first
    % Norm, runs 5pt beside the attention block, and enters the first adder.
    \draw[thick] (norm1.south) -- ++(270:0.8em) coordinate[midway] (skip1_start);
    \coordinate[right=5pt of mha] (skip1_side);
    \draw[edge, rounded corners=2pt] (skip1_start) -| (skip1_side) |- (plus1);

    % Sub-layer 2 (MLP), same pattern; residual 2 leaves halfway between the
    % first adder and the second Norm.
    \draw[thick] (plus1) -- (norm2) coordinate[midway] (skip2_start);
    \draw[edge] (norm2) -- (mlp);
    \draw[thick] (mlp) -- (plus2);
    % Short output stub above the last adder.
    \draw[thick] (plus2.north) -- ++(90:3pt);
    \coordinate[right=5pt of mlp] (skip2_side);
    \draw[edge, rounded corners=2pt] (skip2_start) -| (skip2_side) |- (plus2);
\end{tikzpicture}};

\node[block, fill=red!10, left=1em of transformer.south] (embedding) {\begin{tikzpicture}
    % Nine rows, one per ordered pair of the variables A, B, C (coloured
    % colorA/colorB/colorC like the graph nodes): [first | second | flag].
    % The flag cell is black!90 for the pairs (A,B) and (A,C), which are the
    % two edges of G, and white otherwise.
    % Row k creates nodes v<k>1, v<k>2, v<k>3; rows stack 3pt apart and the
    % cells within a row sit 1pt apart.
    \foreach \srccolor/\dstcolor/\edgeflag [count=\rowidx] in {%
        colorA/colorA/white,%
        colorA/colorB/black!90,%
        colorA/colorC/black!90,%
        colorB/colorA/white,%
        colorB/colorB/white,%
        colorB/colorC/white,%
        colorC/colorA/white,%
        colorC/colorB/white,%
        colorC/colorC/white%
    } {%
        \ifnum\rowidx=1
            \node[vector, fill=\srccolor] (v\rowidx1) {};
        \else
            \pgfmathtruncatemacro{\prevrow}{\rowidx-1}
            \node[vector, fill=\srccolor, below=3pt of v\prevrow1] (v\rowidx1) {};
        \fi
        \node[vector, fill=\dstcolor, right=1pt of v\rowidx1] (v\rowidx2) {};
        \node[mini vector, fill=\edgeflag, right=1pt of v\rowidx2] (v\rowidx3) {};
    }
\end{tikzpicture}};
% Connector bridging the 1em gap between the embeddings and the transformer.
% NOTE(review): the 2.6pt lift presumably re-centres the line on the layer's
% main data path, which is off-centre in the rotated transformer node — confirm.
\draw[thick] ([yshift=2.6pt]embedding.east) -- ++(0:1em);

\node[left=2em of embedding.west] (graph_mask) {\begin{tikzpicture}
    % Top: the example DAG G over A, B, C with edges A -> B and A -> C.
    \node (graph) at (0, 0) {\begin{tikzpicture}
        \node[var, fill=colorA] (A) {$A$};
        \node[var, fill=colorB, above right=1em and 0.5em of A] (B) {$B$};
        \node[var, fill=colorC, below right=1em and 0.5em of B] (C) {$C$};
        \foreach \dagchild in {B,C}
            \draw[edge] (A) -- (\dagchild);
    \end{tikzpicture}};
    \node[below=0.5em of graph] (graph_label) {$G$};

    % Bottom: the mask m as a 3x3 grid; cell XY is row X, column Y (A, B, C).
    % Every cell is black!85 except (B,C) and (C,B), which are gray!20.
    % NOTE(review): the light cells are presumably the edges that can still be
    % added to G (not present, not a self-loop, not creating a cycle) — confirm.
    \node[below=2em of graph_label] (mask) {\begin{tikzpicture}[every node/.append style={adjacency, fill=black!85}]
        \node                            (AA) {};
        \node[right=1pt of AA]           (AB) {};
        \node[right=1pt of AB]           (AC) {};
        \node[below=1pt of AA]           (BA) {};
        \node[right=1pt of BA]           (BB) {};
        \node[fill=gray!20, right=1pt of BB] (BC) {};
        \node[below=1pt of BA]           (CA) {};
        \node[fill=gray!20, right=1pt of CA] (CB) {};
        \node[right=1pt of CB]           (CC) {};
    \end{tikzpicture}};
    \node[below=0.5em of mask] (mask_label) {$\vm$};
\end{tikzpicture}};

% "x L" tag (presumably the number of stacked layers), offset 3pt from the
% transformer node's north-west anchor.
\node[font=\scriptsize, anchor={north east}, xshift=-3pt, yshift=-3pt] (timesL)
    at (transformer.{north west})
    {$\times L$};

% Captions: the embedding caption sits under the embeddings; the transformer
% caption is level with it, aligned on the transformer's east anchor.
\node[below=0.5em of embedding, font=\small, align=center] (embedding_label)
    {Embeddings};
\node[font=\small] (transformer_label)
    at (embedding_label -| transformer.east)
    {Linear Transformer};

% Both heads branch from the same point on the transformer output.
% NOTE(review): the 2.6pt lift presumably keeps the start on the layer's main
% data path inside the rotated transformer node — confirm.
\coordinate (transformer_out) at ([yshift=2.6pt]transformer.north);

% Upper head (its result is captioned P_theta(G' | G, not s_f)):
% Linear Transformer -> MLP -> Mask -> Softmax.
\node[block, fill=gray!10, rotate=270, above right=3.6em and 2em of {transformer.north}, anchor=south,
      font=\small, align=center] (logits_head) {Linear\\[-0.2em]Transformer};
\node[block, fill=RoyalBlue!30, rotate=270, right=1em of logits_head.north, anchor=south]
    (logits_head_1) {MLP};
\node[block, fill=Plum!30, rotate=270, right=1em of logits_head_1.north, anchor=south]
    (logits_head_2) {Mask};
\node[block, fill=Peach!30, rotate=270, right=1em of logits_head_2.north, anchor=south]
    (softmax) {Softmax};
% Elbow route: vertical turn 30% of the way towards the head.
\coordinate (mid_logits) at ($(transformer.north)!0.3!(logits_head)$);
\draw[edge, rounded corners=2pt] (transformer_out) -| (mid_logits) |- (logits_head.south);
\draw[edge] (logits_head) -- (logits_head_1);
\draw[thick] (logits_head_1) -- (logits_head_2);
\draw[edge] (logits_head_2) -- (softmax);

% Lower head (its result is captioned P_theta(s_f | G)):
% Linear Transformer -> Pool -> MLP.
\node[block, fill=gray!10, rotate=270, below right=3em and 2em of {transformer.north}, anchor=south,
      font=\small, align=center] (stop_head) {Linear\\[-0.2em]Transformer};
\node[block, fill=Rhodamine!30, rotate=270, right=1em of stop_head.north, anchor=south]
    (stop_head_1) {Pool};
\node[block, fill=RoyalBlue!30, rotate=270, right=1em of stop_head_1.north, anchor=south]
    (stop_head_2) {MLP};
\coordinate (mid_stop) at ($(transformer.north)!0.3!(stop_head)$);
\draw[edge, rounded corners=2pt] (transformer_out) -| (mid_stop) |- (stop_head.south);
\draw[thick] (stop_head) -- (stop_head_1);
\draw[edge] (stop_head_1) -- (stop_head_2);

% Output of the upper head: a 3x3 grid laid out like the input mask.
% Cells that are dark in the mask are uniform gray!30. The two remaining cells
% are shaded with RoyalBlue: (B,C) at 30% strength (white!70!RoyalBlue) and
% (C,B) at full strength.
\node[right=1.5em of softmax.north, inner sep=5pt, rounded corners=2pt, fill=gray!10, draw=black, thick]
    (adjacency_out) {\begin{tikzpicture}[every node/.append style={adjacency, fill=gray!30}]
    \node                                        (AA) {};
    \node[right=1pt of AA]                       (AB) {};
    \node[right=1pt of AB]                       (AC) {};
    \node[below=1pt of AA]                       (BA) {};
    \node[right=1pt of BA]                       (BB) {};
    \node[fill=white!70!RoyalBlue, right=1pt of BB] (BC) {};
    \node[below=1pt of BA]                       (CA) {};
    \node[fill=white!00!RoyalBlue, right=1pt of CA] (CB) {};
    \node[right=1pt of CB]                       (CC) {};
\end{tikzpicture}};
\draw[edge] (softmax.north) -- (adjacency_out);

% Sigmoid after the lower head. The box holds only \vphantom{Softmax}, so it
% gets the same text height/depth as the Softmax box; the sigma is a separate,
% unrotated node placed on its centre.
\node[block, fill=Peach!30, rotate=270, right=1em of stop_head_2.north, anchor=south]
    (sigmoid) {\vphantom{Softmax}};
\node[font=\large] (sigmoid_label) at (sigmoid.center) {$\sigma$};
\draw[edge] (stop_head_2) -- (sigmoid);

% Stop output: a 3x3 grid of unfilled (hence invisible) cells with only the
% centre cell BB filled green. The grid is built like the upper output box, so
% both boxes have the same footprint.
\node[right=1.5em of sigmoid.north, inner sep=5pt, rounded corners=2pt, fill=gray!10, draw=black, thick]
    (stop_out) {\begin{tikzpicture}[every node/.append style={adjacency}]
    \node                                         (AA) {};
    \node[right=1pt of AA]                        (AB) {};
    \node[right=1pt of AB]                        (AC) {};
    \node[below=1pt of AA]                        (BA) {};
    \node[fill=white!5!ForestGreen, right=1pt of BA] (BB) {};
    \node[right=1pt of BB]                        (BC) {};
    \node[below=1pt of BA]                        (CA) {};
    \node[right=1pt of CA]                        (CB) {};
    \node[right=1pt of CB]                        (CC) {};
\end{tikzpicture}};
\draw[edge] (sigmoid.north) -- (stop_out);

% Captions directly under the two output boxes (below=... of already anchors
% each caption at its north side).
\node[below=0em of adjacency_out, font=\small, inner sep=4pt] (logits_label)
    {$P_{\theta}(G'\mid G, \neg s_{f})$};
\node[below=0em of stop_out, font=\small, inner sep=4pt] (stop_label)
    {$P_{\theta}(s_{f}\mid G)$};

\end{tikzpicture}
\end{document}