
"""MSA model."""
import torch
import torch.nn as nn

from .mlp import Net

# from .positional_encoding import RFF, PosEncCat
from .transformer import TFSEncoder


class MSA(nn.Module):
    """Default Transformer for the experiments.

    Context points and target inputs are embedded to width ``dim_h``,
    concatenated along the sequence axis (context first, targets last) and
    passed through a stack of ``nn.TransformerEncoderLayer`` blocks. The
    rows belonging to the targets are then decoded into predictions.
    """

    def __init__(
        self,
        dim_x,
        dim_cy,
        dim_ty,
        dim_h,
        nhead,
        nlayers,
        pos_enc_freq=100.0,
        use_rff=False,
        use_same_pos_enc=False,
        use_mlps=False,
        dropout=0.0,
    ):
        """Initialize the model.

        Args:
            dim_x (int): Dimension of the input.
            dim_cy (int): Dimension of the context.
            dim_ty (int): Dimension of the target.
            dim_h (int): Dimension of the hidden layers. The transformer
                feed-forward width is ``2 * dim_h``.
            nhead (int): Number of heads in the multi-head attention.
            nlayers (int): Number of transformer encoder layers.
            pos_enc_freq (float): Currently ignored; kept so existing callers
                keep working.
            use_rff (bool): Currently ignored; kept so existing callers keep
                working.
            use_same_pos_enc (bool): If True, one network embeds both the
                context inputs and the target inputs, the context values get
                their own network, and the two context embeddings are summed.
                If False, the concatenated ``(cx, cy)`` pair is embedded by
                one network and the target inputs by another.
            use_mlps (bool): Currently ignored; kept so existing callers keep
                working.
            dropout (float): Dropout probability of the transformer layers.
        """
        super().__init__()
        self.use_same_pos_enc = use_same_pos_enc

        # NOTE: the RFF / PosEncCat positional-encoding variants and the
        # linear-vs-MLP switch are disabled, so `pos_enc_freq`, `use_rff` and
        # `use_mlps` have no effect. `Net` is project-local (see `.mlp`);
        # presumably an MLP whose layer widths are given by `dims` - confirm.
        #
        # Attribute names and construction order are kept stable so that
        # state_dict keys and seeded initialisation do not change.
        if use_same_pos_enc:
            self.pos_encoder = Net(dims=[dim_x, dim_h, dim_h, dim_h])
            self.val_encoder = Net(dims=[dim_cy, dim_h, dim_h, dim_h])
            self.decoder = Net(dims=[dim_h, dim_h, dim_ty])
        else:
            self.encoder = Net(dims=[dim_x + dim_cy, dim_h, dim_h, dim_h])
            self.decoder = Net(dims=[dim_h, dim_h, dim_h, dim_ty])
            self.q_encoder = Net(dims=[dim_x, dim_h, dim_h, dim_h])

        encoder_layer = nn.TransformerEncoderLayer(
            dim_h, nhead, 2 * dim_h, dropout=dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, nlayers)

    def forward(self, cx, cy, tx):
        """Forward pass.

        All inputs are indexed as 3-D tensors with the sequence on axis 1
        (the transformer is built with ``batch_first=True``).

        Args:
            cx (torch.Tensor): Context input.
            cy (torch.Tensor): Context value.
            tx (torch.Tensor): Target input.

        Returns:
            torch.Tensor: Target value, one row per target input. An empty
            target set (``tx.shape[1] == 0``) yields an empty sequence axis.
        """
        if self.use_same_pos_enc:
            # Inputs and values are embedded separately and summed.
            ce = self.pos_encoder(cx) + self.val_encoder(cy)
            te = self.pos_encoder(tx)
        else:
            ce = self.encoder(torch.cat((cx, cy), dim=-1))
            te = self.q_encoder(tx)

        num_context = ce.shape[1]
        z = self.transformer_encoder(torch.cat((ce, te), dim=1))

        # Targets occupy the tail of the sequence. Slice from the context
        # length rather than with `-L:`: for L == 0 the latter is `-0:`,
        # which selects the whole sequence instead of nothing.
        return self.decoder(z[:, num_context:, :])


class MSAEncoderOnly(nn.Module):
    """MSA for baselines.

    Context pairs and target inputs are embedded to width ``dim_h``,
    concatenated along the sequence axis (context first, targets last) and
    passed through ``nlayers`` ``TFSEncoder`` blocks. The rows belonging to
    the targets are then decoded.
    """

    def __init__(
        self,
        dim_x,
        dim_cy,
        dim_ty,
        dim_h,
        nhead,
        nlayers,
        share_blocks=True,
    ):
        """Initialize the model.

        Args:
            dim_x (int): Dimension of the input.
            dim_cy (int): Dimension of the context.
            dim_ty (int): Dimension of the target.
            dim_h (int): Dimension of the hidden layers.
            nhead (int): Number of heads in the multi-head attention.
            nlayers (int): Number of layers.
            share_blocks (bool): Whether to share the blocks. If True, a
                single block instance is applied ``nlayers`` times; otherwise
                ``nlayers`` independent blocks are created.
        """
        super().__init__()
        # `Net` is project-local (see `.mlp`); presumably an MLP whose layer
        # widths are given by `dims` - confirm.
        self.context_encoder = Net(dims=[dim_x + dim_cy, dim_h, dim_h, dim_h])
        self.target_encoder = Net(dims=[dim_x, dim_h, dim_h, dim_h])
        self.decoder = Net(dims=[dim_h, dim_h, dim_h, dim_ty])

        if share_blocks:
            # Deliberate: the list repeats ONE module object, so every layer
            # reuses the same parameters.
            self.blocks = nn.ModuleList([TFSEncoder(dim_h, dim_h, nhead)] * nlayers)
        else:
            self.blocks = nn.ModuleList(
                [TFSEncoder(dim_h, dim_h, nhead) for _ in range(nlayers)]
            )

    def forward(self, cx, cy, tx):
        """Forward pass.

        All inputs are indexed as 3-D tensors with the sequence on axis 1.

        Args:
            cx (torch.Tensor): Context input.
            cy (torch.Tensor): Context output.
            tx (torch.Tensor): Target input.

        Returns:
            torch.Tensor: Target output, one row per target input. An empty
            target set (``tx.shape[1] == 0``) yields an empty sequence axis.
        """
        latents_context = self.context_encoder(torch.cat((cx, cy), dim=-1))
        latents_target = self.target_encoder(tx)
        num_context = latents_context.shape[1]

        latents = torch.cat((latents_context, latents_target), dim=1)
        for block in self.blocks:
            latents = block(latents)

        # Targets occupy the tail of the sequence. Slice from the context
        # length rather than with `-L:`: for L == 0 the latter is `-0:`,
        # which selects the whole sequence instead of nothing.
        return self.decoder(latents[:, num_context:, :])
