import torch
from fla import (
    GatedDeltaProductForCausalLM,
    GatedLinearAttentionForCausalLM,
    TransformerForCausalLM,
)
from torch import Tensor


class GatedDeltaProduct(GatedDeltaProductForCausalLM):
    """Thin adapter over ``GatedDeltaProductForCausalLM``.

    Exposes a positional ``forward(x)`` that returns only the logits, plus a
    trainable-parameter count and a (placeholder) stats hook.
    """

    @property
    def num_parameters(self) -> int:
        """Count the elements of every parameter that has ``requires_grad`` set.

        NOTE(review): if the base class is a Hugging Face ``PreTrainedModel``,
        it defines a ``num_parameters(...)`` *method*; this property shadows
        it, so ``model.num_parameters()`` would raise ``TypeError`` -- confirm
        nothing (including library code) calls it as a method.
        """
        trainable = (param for param in self.parameters() if param.requires_grad)
        return sum(param.numel() for param in trainable)

    def forward(self, x: Tensor) -> Tensor:
        """Run the base model on ``x`` (passed as ``input_ids``) and return its logits.

        Assumes ``x`` holds token ids, presumably shaped (batch, seq_len) --
        TODO confirm. Because this override only accepts ``x``, callers that
        pass ``input_ids=...`` by keyword will not work with this wrapper.
        """
        return super().forward(input_ids=x).logits

    @torch.no_grad()
    def get_useful_stats(self):
        """Placeholder stats hook: print a notice and return an empty dict.

        Runs with autograd disabled.
        """
        print("Empty useful stats for this model, to be implemented")
        return {}


class GLA(GatedLinearAttentionForCausalLM):
    """Thin adapter over ``GatedLinearAttentionForCausalLM``.

    Provides a logits-only ``forward(x)``, a trainable-parameter count and a
    placeholder statistics hook.
    """

    @property
    def num_parameters(self) -> int:
        """Total element count across parameters with ``requires_grad=True``.

        NOTE(review): a Hugging Face ``PreTrainedModel`` base would define
        ``num_parameters(...)`` as a method; this property shadows it, so
        calling ``model.num_parameters()`` would fail -- verify no caller
        relies on the method form.
        """
        counts = (
            tensor.numel()
            for tensor in filter(lambda t: t.requires_grad, self.parameters())
        )
        return sum(counts)

    def forward(self, x: Tensor) -> Tensor:
        """Feed ``x`` to the base model as ``input_ids``; return only ``.logits``.

        ``x`` is presumably a (batch, seq_len) tensor of token ids -- TODO
        confirm. Keyword callers using ``input_ids=`` are not supported by this
        signature.
        """
        model_output = super().forward(input_ids=x)
        logits = model_output.logits
        return logits

    @torch.no_grad()
    def get_useful_stats(self):
        """Not yet implemented: prints a notice and returns ``{}``.

        Executed under ``torch.no_grad`` so no autograd graph is recorded.
        """
        print("Empty useful stats for this model, to be implemented")
        stats = {}
        return stats


class Attn(TransformerForCausalLM):
    """Thin adapter over ``TransformerForCausalLM``.

    Offers the same surface as the other model wrappers in this module: a
    logits-only ``forward(x)``, a trainable-parameter count, and a
    ``get_useful_stats()`` hook (currently a placeholder).
    """

    @property
    def num_parameters(self) -> int:
        """Return the total number of trainable parameters.

        Sums ``numel()`` over every parameter with ``requires_grad=True``.

        NOTE(review): if the base class is a Hugging Face ``PreTrainedModel``,
        this property shadows its ``num_parameters(...)`` method, so
        ``model.num_parameters()`` would raise ``TypeError`` -- confirm nothing
        calls it as a method.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, x: Tensor) -> Tensor:
        """Run the base model with ``input_ids=x`` and return only the logits.

        ``x`` is presumably a (batch, seq_len) tensor of token ids -- TODO
        confirm. Callers passing ``input_ids=`` by keyword are not supported by
        this signature.
        """
        output = super().forward(input_ids=x)
        return output.logits

    @torch.no_grad()
    def get_useful_stats(self) -> dict:
        """Placeholder stats hook: print a notice and return an empty dict.

        Added so this wrapper exposes the same ``get_useful_stats()`` entry
        point as its sibling wrappers; code that queries stats uniformly
        across models would otherwise hit ``AttributeError`` here. Runs with
        autograd disabled.
        """
        print("Empty useful stats for this model, to be implemented")
        return {}
