import typing as ty
from typing import Annotated

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ablator import Enum, configclass, ModelConfig, Derived, Optional, Literal, List

from tablator.dataset import DatasetType
from tablator.masks import MaskType, full_mask, make_mask, random_mask
from tablator.modules import Transformer


class Activation(Enum):
    """String identifiers for the activation functions a model can be configured with.

    Used as the type of ``TablatorConfig.activation``. How each identifier maps
    to an actual activation implementation is not shown in this module.
    """

    # Member values are the lowercase string names; a member can be built
    # from its value, e.g. ``Activation("relu")``.
    REGLU = "reglu"
    GEGLU = "geglu"
    GELU = "gelu"
    RELU = "relu"
    LRELU = "leaky_relu"
    SIGMOID = "sigmoid"


@configclass
class TablatorConfig(ModelConfig):
    """Configuration consumed by ``Tablator``.

    ``Tablator.__init__`` passes this config to ``Transformer`` as keyword
    arguments via ``config.to_dict()`` and additionally reads ``data_type``
    (to pick the loss), ``mask_type`` and ``random_mask_alpha`` itself.
    """

    # Configurable attributes
    # NOTE(review): presumably toggles a bias term in the tokenizer — confirm
    # against ``Transformer``.
    token_bias: bool = False
    # transformer
    n_layers: int = 1
    d_token: int = 32
    n_heads: int = 2
    d_ffn_factor: float = 1.333
    attention_dropout: float = 0.1
    ffn_dropout: float = 0.1
    residual_dropout: float = 0
    prenormalization: bool = False
    initialization: Literal["xavier", "kaiming"] = "kaiming"
    # Derived fields: default to None here.
    # NOTE(review): presumably filled in from the dataset before the model is
    # built — confirm where these are assigned.
    d_out: Derived[Optional[int]] = None
    # Compared against "multiclass" / "binclass" / "regression" by Tablator
    # to select the loss function.
    data_type: Derived[Optional[DatasetType]] = None
    categories: Derived[Optional[List[int]]] = None
    d_numerical: Derived[Optional[int]] = None
    # linformer
    # Both key/value compression options are disabled (None) by default.
    kv_compression: Optional[float] = None
    kv_compression_sharing: Optional[
        Literal["layerwise", "key-value", "headwise"]
    ] = None
    activation: Activation = Activation("relu")
    residual: bool = True
    # Attention-mask settings read by Tablator: any mask type other than
    # RANDOM gets a fixed mask from ``make_mask`` at construction time; MIX and
    # RANDOM additionally get a fresh ``random_mask`` on every forward pass,
    # built with ``random_mask_alpha``.
    mask_type: MaskType = MaskType("random")
    random_mask_alpha: float = 0.80


class Tablator(nn.Module):
    """Wrap a ``Transformer`` with an attention mask and a task-specific loss.

    The loss is chosen from ``config.data_type``; the attention mask is a
    registered buffer (so it follows the module across devices) that may be
    combined with a freshly sampled random mask on every forward pass.
    """

    def __init__(
        self,
        config: TablatorConfig,
    ):
        """Build the transformer, the loss function and the base attention mask.

        Args:
            config: Model configuration. It is forwarded to ``Transformer`` as
                keyword arguments via ``config.to_dict()``.

        Raises:
            ValueError: If ``config.data_type`` is not one of ``"multiclass"``,
                ``"binclass"`` or ``"regression"``.
        """
        super().__init__()
        self.model_type = "Transformer"

        if config.data_type == "multiclass":
            self.loss_fn = self._multiclass_loss
        elif config.data_type == "binclass":
            self.loss_fn = self._binclass_loss
        elif config.data_type == "regression":
            self.loss_fn = self._regression_loss
        else:
            # Fail at construction time instead of leaving `loss_fn` undefined
            # and crashing with an AttributeError on the first forward pass.
            raise ValueError(
                f"Unsupported data_type {config.data_type!r}; expected one of "
                "'multiclass', 'binclass' or 'regression'."
            )

        self.model = Transformer(**config.to_dict())
        self.mask: torch.Tensor
        n_tokens = self.model.tokenizer.n_tokens
        # Registered as a buffer so `.to(device)` / `state_dict()` include it.
        self.register_buffer("mask", full_mask(n_tokens))
        if config.mask_type != MaskType.RANDOM:
            # Every non-random mask type contributes a fixed mask, OR-ed into
            # the base mask once at construction time.
            self.mask |= make_mask(config.mask_type, n_tokens)
        self.mask_type = config.mask_type
        self.random_mask_alpha = config.random_mask_alpha

    @staticmethod
    def _multiclass_loss(pred: Tensor, target: Tensor) -> Tensor:
        """Cross-entropy between predictions and targets."""
        return F.cross_entropy(pred, target)

    @staticmethod
    def _binclass_loss(pred: Tensor, target: Tensor) -> Tensor:
        """Binary cross-entropy on logits; targets are cast to float."""
        return F.binary_cross_entropy_with_logits(pred, target.float())

    @staticmethod
    def _regression_loss(pred: Tensor, target: Tensor) -> Tensor:
        """Mean squared error; targets are cast to float."""
        return F.mse_loss(pred, target.float())

    def forward(
        self,
        y: Tensor,
        x_num: ty.Optional[Tensor] = None,
        x_cat: ty.Optional[Tensor] = None,
    ) -> ty.Tuple[dict[str, torch.Tensor], torch.Tensor]:
        """Run the model on one batch and compute the loss.

        Args:
            y: Target tensor. It is returned unchanged under ``"target"`` and
                flattened before being passed to the loss function.
            x_num: Numerical features. Required, despite the ``None`` default.
            x_cat: Categorical features, passed through to the transformer;
                may be ``None``.

        Returns:
            A tuple ``(outputs, loss)`` where ``outputs`` is
            ``{"pred": <model output with a trailing size-1 dim squeezed>,
            "target": y}`` and ``loss`` is the result of the loss function
            selected in ``__init__``.

        Raises:
            ValueError: If ``x_num`` is ``None``.
        """
        # Explicit check rather than `assert`, which is removed under `-O`.
        if x_num is None:
            raise ValueError("Tablator.forward requires `x_num`, but got None.")

        # Work on a copy so the per-batch random mask never leaks into the
        # persistent buffer.
        attn_mask = self.mask.clone()
        if self.mask_type in (MaskType.MIX, MaskType.RANDOM):
            # Move the sampled mask to the buffer's device: that is the tensor
            # it is combined with.
            attn_mask |= random_mask(attn_mask.shape[0], self.random_mask_alpha).to(
                attn_mask.device
            )

        pred = self.model(x_num, x_cat, attn_mask=attn_mask)
        pred = pred.squeeze(-1)

        return {"pred": pred, "target": y}, self.loss_fn(pred, y.flatten())
