% \documentclass[runningheads]{standalone}
%%%%%% compile from scratch!!!!! %%%%%%
\documentclass[tikz]{standalone}
\input{tikzstyles}

\begin{document}
    \begin{tikzpicture}
        % input nodes: the three raw inputs, stacked top to bottom with 0.5cm gaps
        \node (input_I) [inputbox] {$I$};
        \node (input_A) [inputbox, below=0.5cm of input_I] {$A$};
        \node (input_S) [inputbox, below=0.5cm of input_A] {$S$};

        % encoding: one encoder box per input row, each 1cm right of its input.
        % All three share the same two-colour axis shading set-up, so it is
        % factored into a local style: #1 = top color, #2 = bottom color.
        % The key order inside the style matches the order in which the options
        % are applied after `encoder` and the placement key.
        \tikzset{
            encoder fade/.style 2 args={
                shading=axis,
                top color=#1,
                bottom color=#2,
                shading angle=180,
                opacity=0.9,
            },
        }
        \node (encoder_I) [
            encoder,
            right=1cm of input_I,
            encoder fade={continuous!20!white}{image}
        ] {VAE Encoder};
        \node (encoder_A) [
            encoder,
            right=1cm of input_A,
            encoder fade={continuous!20!white}{continuous}
        ] {Normalize};
        \node (encoder_S) [
            encoder,
            right=1cm of input_S,
            encoder fade={discrete!20!white}{discrete}
        ] {One-hot Encoding};

        % input encoding connects: one thick arrow per row, input -> encoder
        \foreach \modality in {I, A, S} {
            \draw[->, thick] (input_\modality.east) -- (encoder_\modality.west);
        }

        % z_0: initial latent of each row, 1cm right of the matching encoder
        \node (z_I) [inputbox, right=1cm of encoder_I] {$z^I_0$};
        \node (z_A) [inputbox, right=1cm of encoder_A] {$z^A_0$};
        \node (z_S) [inputbox, right=1cm of encoder_S] {$\mathbf{z}^S_0$};

        % encoding z_0 connects: encoder -> latent, row by row
        \foreach \modality in {I, A, S} {
            \draw[->, thick] (encoder_\modality.east) -- (z_\modality.west);
        }

        % forward process
        % Upper panel: continuous forward step. The node is anchored at its west
        % side, 1cm to the right of z_I and 0.875cm below z_I's east anchor.
        % `text opacity=1` keeps the label fully opaque over the translucent fill.
        \node[
            draw=none,
            minimum width=10.5cm,
            minimum height=2.75cm,
            fill=continuous,
            fill opacity=0.2,
            text opacity=1,
            anchor=west
        ] (forward_continuous) at ($(z_I.east)+(1.0cm,-0.875cm)$)  {
        \begin{tabular}{c}
            Forward Process (Continuous) \\
            $q(z_t \mid z_{t-1}) = \mathcal{N}\!\left(z_t \mid \sqrt{1 - \beta_t^{\text{DDPM}}}\, z_{t-1},\, \beta_t^{\text{DDPM}} I \right)$
        \end{tabular}
        };
        % Lower panel: discrete forward step, attached directly below the
        % continuous panel (no gap). The formula is shown on its own inside the
        % figure, so it carries no trailing punctuation.
        \node[
            draw=none,
            minimum width=10.5cm,
            minimum height=1.25cm,
            fill=discrete,
            fill opacity=0.2,
            text opacity=1,
            below=0cm of forward_continuous
        ] (forward_discrete)  {
        \begin{tabular}{c}
            Forward Process (Discrete) \\
            $q(\mathbf{z}_t \mid \mathbf{z}_{t-1}) = \mathrm{Cat}(\mathbf{z}_t; \mathbf{p} = \mathbf{z}_{t-1} \mathbf{Q}_t)$
        \end{tabular}
        };
        % Dashed frame fitted tightly (zero inner/outer sep) around both panels.
        \node[
            draw=black,
            dashed,
            fit=(forward_continuous) (forward_discrete),
            inner sep=0pt,
            outer sep=0pt,
        ] (combined_forward) {};


        % z ---> forward ---> z_T
        % Per row: a 1cm arrow leaves the z_0 latent towards the forward-process
        % panel, and a second 1cm arrow starts on the vertical line through
        % forward_continuous.east at the same height as the first.
        % z_<row>_vert is a helper coordinate that carries the row's height.
        \foreach \modality in {I, A, S} {
            \draw[->, thick] (z_\modality.east) -- ++(1.0cm,0);
            \coordinate (z_\modality_vert) at ($(z_\modality.east) + (1cm,0)$);
            \draw[->, thick] (forward_continuous.east |- z_\modality_vert) -- ++(1.0cm,0);
        }

        % z_T column: the top node's west side sits at the tip of the image-row
        % arrow; the other two are stacked below it with 0.5cm gaps.
        \coordinate (z_I_T_west) at ($(forward_continuous.east |- z_I_vert)+(1cm,0)$);
        \node (z_I_T) [inputbox, anchor=west] at (z_I_T_west) {$z_T^I$};
        \node (z_A_T) [inputbox, below=0.5cm of z_I_T] {$z_T^A$};
        \node (z_S_T) [inputbox, below=0.5cm of z_A_T] {$\mathbf{z}_T^S$};

        % denoising process
        % Second copy of the z_T latents, starting 1.5cm below the first z_T
        % column and stacked with 0.5cm gaps; these feed the denoising network.
        \node[inputbox, below=1.5cm of z_S_T](z_I_T2) {$z_T^I$};
        \node[inputbox, below=0.5cm of z_I_T2](z_A_T2) {$z_T^A$};
        \node[inputbox, below=0.5cm of z_A_T2](z_S_T2) {$\mathbf{z}_T^S$};

        % Dashed, translucent green background panel, placed 1cm left of z_A_T2
        % and nudged up by 0.3cm. It is named f_theta_panel rather than f_theta
        % because the small $f_\theta$ box further down is also named f_theta;
        % reusing the name would let that later node silently overwrite this one.
        \node[
            draw=black,
            dashed,
            minimum width=7.5cm,
            minimum height=5.0cm,
            fill=green!30,
            fill opacity=0.4,
            text opacity=1,
            left=1.0cm of z_A_T2,
            yshift=0.3cm
        ] (f_theta_panel){};
        

        % network
        % U-Net halves on the image row, laid out leftwards from z_I_T2.
        \node (unet_encoder) [unet_encoder, fill=gray!30, left=2.0cm of z_I_T2] {};
        \node (unet_decoder) [unet_decoder, fill=gray!30, left=1.0cm of unet_encoder] {};

        % One MLP node on each of the A and S rows: its east side is flush with
        % unet_decoder.east, at the height of the matching z_T input.
        \foreach \modality in {A, S} {
            \coordinate (mlp_\modality_east) at (unet_decoder.east |- z_\modality_T2.east);
            \node (mlp_\modality) [mlp, fill=gray!30, anchor=east] at (mlp_\modality_east) {};
        }

        % Encoder output: straight across to the decoder, plus one branch per
        % MLP that steps 0.5cm left, turns vertically to that row and enters
        % the MLP from the east.
        \draw[rounded corners=5pt, ->, thick] (unet_encoder.west) -- (unet_decoder.east);
        \foreach \modality in {A, S} {
            \coordinate (mlp_\modality_turn) at ($(unet_encoder.west |- z_\modality_T2.east) + (-0.5cm,0)$);
            \draw[rounded corners=5pt, thick, ->]
                (unet_encoder.west) -- ++(-0.5cm,0) -- (mlp_\modality_turn) -- (mlp_\modality.east);
        }

        % Four "CA" blocks: two positioned relative to unet_encoder.east and
        % two relative to unet_decoder.west.
        \node (CA1) [attention, left=0.25cm of unet_encoder.east, fill=gray!30, fill opacity=0.9] {CA};
        \node (CA2) [attention, left=1.25cm of unet_encoder.east, fill=gray!30, fill opacity=0.9] {CA};
        \node (CA3) [attention, right=1.25cm of unet_decoder.west, fill=gray!30, fill opacity=0.9] {CA};
        \node (CA4) [attention, right=0.25cm of unet_decoder.west, fill=gray!30, fill opacity=0.9] {CA};
        
        % unet skip connection
        % Two dotted arrows from unet_encoder.west to unet_decoder.east, each
        % raised by \rise above those anchors. The negative `shorten` values
        % lengthen the arrow by \overshoot at both its start and its end.
        \foreach \rise/\overshoot in {0.6cm/0.7cm, 1.0cm/2.2cm} {
            \draw[->, thick, dotted, shorten <=-\overshoot, shorten >=-\overshoot]
                ($(unet_encoder.west)+(0,\rise)$) -- ($(unet_decoder.east)+(0,\rise)$);
        }
        
        % ---> unet
        % The image-row latent goes straight into the U-Net encoder.
        \draw[->, thick] (z_I_T2.west) -- (unet_encoder.east);
        % The A and S latents each run 1.5cm left, turn vertically to `concat`
        % (a point on the image row, 1.5cm left of z_I_T2) and continue 0.25cm
        % further left along that row. No arrow tip on these two paths.
        \coordinate (concat) at ($(z_I_T2.west) + (-1.5cm,0)$);
        \foreach \modality in {A, S} {
            \draw[rounded corners=5pt, thick]
                (z_\modality_T2.west) -- ++(-1.5cm,0) -- (concat) -- ++(-0.25cm,0);
        }

        % z_t-1: one output box per row, 1cm left of the node that produces it
        \node (z_I_T-1) [inputbox, left=1.0cm of unet_decoder] {$\tilde{z}_{T-1}^I$};
        \node (z_A_T-1) [inputbox, left=1.0cm of mlp_A] {$\tilde{z}_{T-1}^A$};
        \node (z_S_T-1) [inputbox, left=1.0cm of mlp_S] {$\tilde{\mathbf{z}}_{T-1}^S$};

        % network ---> z_t-1: arrow from each producing node into its output box
        \foreach \head/\modality in {unet_decoder/I, mlp_A/A, mlp_S/S} {
            \draw[->, thick] (\head.west) -- (z_\modality_T-1.east);
        }

        % network overall
        % Two-line caption, centred on a point midway between the west anchors
        % of z_A_T2 and z_S_T2 and shifted 3.1cm to the left.
        \coordinate (label) at ($(z_A_T2.west)!0.5!(z_S_T2.west) + (-3.1cm,0)$);
        \node (labelnode) [inputbox, font=\Large] at (label) {%
            \begin{tabular}{l}
                Denoising \\
                network $f_\theta$
            \end{tabular}%
        };

        % z_T-1 to f
        % Narrow dashed green box labelled $f_\theta$, 0.5cm left of z_A_T-1 and
        % nudged up by 0.3cm.
        \node (f_theta) [
            draw=black,
            dashed,
            minimum width=1.0cm,
            minimum height=5.0cm,
            fill=green!30,
            fill opacity=0.4,
            text opacity=1,
            left=0.5cm of z_A_T-1,
            yshift=0.3cm,
            font=\LARGE
        ] {$f_\theta$};
        % Caption box placed below f_theta; the negative distance (-0.2cm)
        % pulls it upwards so that it overlaps the box's lower edge.
        \node [inputbox, below=-0.2cm of f_theta] {$\times(T-1)$};
        % 0.5cm connector (no arrow tip) leaving each z_{T-1} box to the left
        \foreach \modality in {I, A, S} {
            \draw[thick] (z_\modality_T-1.west) -- ++(-0.5cm,0);
        }

        % f to z_0: denoised latents, 2.5cm left of the matching z_{T-1} box
        \node (z_I_0) [inputbox, left=2.5cm of z_I_T-1] {$\tilde{z}_{0}^I$};
        \node (z_A_0) [inputbox, left=2.5cm of z_A_T-1] {$\tilde{z}_{0}^A$};
        \node (z_S_0) [inputbox, left=2.5cm of z_S_T-1] {$\tilde{\mathbf{z}}_{0}^S$};

        % f ---> z_0: each arrow starts 1.5cm left of the z_{T-1} box's west
        % anchor and ends at the east side of the z_0 box on the same row
        \foreach \modality in {I, A, S} {
            \draw[->, thick] ($(z_\modality_T-1.west) + (-1.5cm,0)$) -- (z_\modality_0.east);
        }

        % z_0 to decoder
        % The three decoder boxes reuse the `encoder` node style with a shared
        % two-colour axis shading, factored into a local style:
        % #1 = top color, #2 = bottom color. Key order inside the style matches
        % the order in which the options are applied after `encoder`.
        \tikzset{
            decoder fade/.style 2 args={
                shading=axis,
                top color=#1,
                bottom color=#2,
                shading angle=180,
                opacity=0.9,
            },
        }
        % Image decoder: placed 1.5cm below encoder_S.
        \node (decoder_I) [
            encoder,
            below=1.5cm of encoder_S,
            decoder fade={image}{continuous!20!white}
        ] {VAE Decoder};
        % A and S decoders: east side flush with decoder_I.east, at the height
        % of the matching z_0 box.
        \coordinate (decoder_A_east) at (decoder_I.east |- z_A_0.west);
        \node (decoder_A) [
            encoder,
            decoder fade={continuous}{continuous!20!white},
            anchor=east
        ] at (decoder_A_east) {Unnormalize};
        \coordinate (decoder_S_east) at (decoder_I.east |- z_S_0.west);
        \node (decoder_S) [
            encoder,
            decoder fade={discrete}{discrete!20!white},
            anchor=east
        ] at (decoder_S_east) {Argmax};

        % denoised latent -> decoder, row by row
        \foreach \modality in {I, A, S} {
            \draw[->, thick] (z_\modality_0.west) -- (decoder_\modality.east);
        }

        % out: reconstructed outputs, 1cm left of the matching decoder
        \node (out_I) [inputbox, left=1.0cm of decoder_I] {$\tilde{I}$};
        \node (out_A) [inputbox, left=1.0cm of decoder_A] {$\tilde{A}$};
        \node (out_S) [inputbox, left=1.0cm of decoder_S] {$\tilde{S}$};

        % decoder -> output arrows
        \foreach \modality in {I, A, S} {
            \draw[->, thick] (decoder_\modality.west) -- (out_\modality.east);
        }
    \end{tikzpicture}
\end{document}
