import torch
import torch.nn as nn

from models.utils.linear_norm_3 import LinearNorm
from models.utils.nothing import Nothing
from models.utils.sr import SR


class SRCMD(nn.Module):
    """Optional linear pre-layer, then an ``SR`` module, then a ``LinearNorm`` head.

    ``forward`` returns both the head output and the intermediate features
    produced by ``SR``, so callers can use either.
    """

    def __init__(
            self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None,
            pre_layer: bool = True, r=20, d=5
    ):
        """Build the sub-modules.

        Args:
            in_features: Size of the input feature dimension. The optional
                pre-layer maps ``in_features -> in_features``, and the head
                consumes ``in_features`` (so ``SR`` is assumed to preserve the
                feature size — TODO confirm against ``SR``).
            out_features: Output size of the ``LinearNorm`` head.
            bias: Forwarded to ``LinearNorm`` only; the pre-layer always uses
                ``nn.Linear``'s default bias.
            device: Factory kwarg forwarded to both the pre-layer and ``LinearNorm``.
            dtype: Factory kwarg forwarded to both the pre-layer and ``LinearNorm``.
            pre_layer: If True, apply a learnable ``nn.Linear`` before ``SR``;
                otherwise use ``Nothing()`` (presumably an identity — verify).
            r: Passed through to ``SR`` unchanged; its meaning is defined there.
            d: Passed through to ``SR`` unchanged; its meaning is defined there.
        """
        super().__init__()
        if pre_layer:
            # device/dtype must reach the pre-layer too: otherwise its weights are
            # created on the default device/dtype while ``self.fc`` follows the
            # caller's choice, and ``forward`` fails with a mismatch.
            self.fc_hr = nn.Linear(
                in_features=in_features, out_features=in_features,
                device=device, dtype=dtype,
            )
        else:
            self.fc_hr = Nothing()
        self.sr = SR(r=r, d=d)
        self.fc = LinearNorm(in_features, out_features, bias, device, dtype)

    def forward(self, x):
        """Run ``x`` through the pre-layer, ``SR`` and the head.

        Args:
            x: Input accepted by ``self.fc_hr``; assumed to be a tensor whose
                last dimension is ``in_features`` — TODO confirm.

        Returns:
            A 2-tuple ``(self.fc(features), features)``, where ``features`` is
            the output of ``self.sr`` applied to the pre-layer output.
        """
        features = self.sr(self.fc_hr(x))
        return self.fc(features), features
