"""
Diagonal Transformer Neural Process (TNPD) with MixtureGaussian head.
"""

from src.models.benchmarks.tnpd_flex_head import TNPDFlexHead
from src.models.modules import MixtureGaussian, MultiChannelMixtureGaussian


class TNPDMG(TNPDFlexHead):
    """TNPD model that plugs a mixture-Gaussian output head into ``TNPDFlexHead``.

    The head module is constructed here. It is then handed to the
    ``TNPDFlexHead`` constructor together with the remaining backbone
    hyperparameters, which are forwarded unchanged.

    Args:
        dim_x: Forwarded to ``TNPDFlexHead``.
        dim_y: Passed to the head as ``dim_y`` and forwarded to ``TNPDFlexHead``.
        d_model: Passed to the head as ``dim_model`` and forwarded to
            ``TNPDFlexHead``.
        emb_depth: Forwarded to ``TNPDFlexHead``.
        dim_feedforward: Passed to the head and forwarded to ``TNPDFlexHead``.
        nhead: Forwarded to ``TNPDFlexHead``.
        dropout: Forwarded to ``TNPDFlexHead``.
        num_layers: Forwarded to ``TNPDFlexHead``.
        head_num_components: Passed to the head as ``num_components``.
        correlate_outputs: Selects the head class. When truthy, the head is a
            ``MultiChannelMixtureGaussian``; otherwise it is a ``MixtureGaussian``.
        pos_emb_init: Forwarded to ``TNPDFlexHead``.
    """

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        d_model: int,
        emb_depth: int,
        dim_feedforward: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        head_num_components: int,
        correlate_outputs: bool = False,
        pos_emb_init: bool = False,
    ):
        # Both head variants accept the same keyword arguments, so only the
        # class needs to be chosen; construction is shared.
        head_cls = (
            MultiChannelMixtureGaussian if correlate_outputs else MixtureGaussian
        )
        head = head_cls(
            dim_y=dim_y,
            dim_model=d_model,
            dim_feedforward=dim_feedforward,
            num_components=head_num_components,
        )

        # The base class owns the transformer backbone; it receives the
        # prebuilt head alongside the shared hyperparameters.
        super().__init__(
            dim_x=dim_x,
            dim_y=dim_y,
            d_model=d_model,
            emb_depth=emb_depth,
            dim_feedforward=dim_feedforward,
            nhead=nhead,
            dropout=dropout,
            num_layers=num_layers,
            head=head,
            pos_emb_init=pos_emb_init,
        )
