import torch
import torch.nn as nn

from models.utils.sr import SR


class SRCMA(nn.Module):
    """Linear projection followed by an ``SR`` module.

    The input is first passed through ``nn.Linear(in_features, out_features)``
    and the projected result is then fed to ``SR(r=r, d=d)``.

    Args:
        in_features: Size of the last dimension of the input to the linear layer.
        out_features: Output size of the linear layer, i.e. the size of the
            last dimension of the tensor handed to ``SR``.
        bias: Whether the linear layer learns an additive bias.
        device: Device for the linear layer's parameters (forwarded to ``nn.Linear``).
        dtype: Dtype for the linear layer's parameters (forwarded to ``nn.Linear``).
        r: Forwarded unchanged to ``SR`` as its ``r`` argument.
        d: Forwarded unchanged to ``SR`` as its ``d`` argument.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        r=20,
        d=5,
    ):
        super().__init__()
        # Register the projection first so submodule order (and state_dict
        # layout / parameter-init RNG order) is fc -> sr.
        self.fc = nn.Linear(
            in_features,
            out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        # NOTE(review): the meaning of r and d is defined by
        # models.utils.sr.SR, not here; confirm there before relying on it.
        self.sr = SR(r=r, d=d)

    def forward(self, x):
        """Project ``x`` with the linear layer, then apply the ``SR`` module."""
        projected = self.fc(x)
        return self.sr(projected)
