import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearConverter(nn.Module):
    """Apply a 1x1 ``nn.Conv2d`` to a contiguous range of entries in a list input.

    The module holds ``total`` sub-modules in ``self.convs``. Position ``i``
    holds a 1x1 convolution when ``start <= i < end`` and ``nn.Identity``
    otherwise. Positions at or beyond ``total`` are never created, even if
    ``end > total``.

    Args:
        in_features: Input channel count for the converted positions. Either a
            single int shared by all of them, or a tuple/list indexed by
            ``i - start`` (one entry per converted position).
        out_features: Output channel count, with the same int-or-sequence
            convention as ``in_features``.
        start: First position (inclusive) that gets a convolution.
        end: Last position (exclusive) that gets a convolution.
        total: Number of positions, i.e. ``len(self.convs)``.
    """

    def __init__(self, in_features=256, out_features=256, start=0, end=5, total=5):
        super().__init__()

        def channels_for(spec, offset):
            # A tuple/list carries one value per converted position, counted
            # from ``start``. Anything else is used for every position.
            if isinstance(spec, (tuple, list)):
                return spec[offset]
            return spec

        # Modules are built in position order so that parameter
        # initialisation draws from the RNG in the same sequence as before.
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    channels_for(in_features, pos - start),
                    channels_for(out_features, pos - start),
                    1,
                )
                if start <= pos < end
                else nn.Identity()
                for pos in range(total)
            ]
        )
        self.start = start
        self.end = end
        self.total = total

    def forward(self, x):
        """Run each of the first ``total`` entries of ``x`` through its module.

        ``x`` must be indexable with at least ``self.total`` entries.
        Presumably it is a list of 4D feature maps; confirm this against the
        callers. Entries outside ``[start, end)`` are returned unchanged,
        because ``nn.Identity`` passes its input through.

        Returns:
            A list of length ``self.total``.
        """
        return [self.convs[level](x[level]) for level in range(self.total)]


def add_dim(f, dim=0):
    """Return a new list with a size-1 axis inserted into every tensor in *f*.

    Args:
        f: Iterable of tensors, or of any objects that have an ``unsqueeze(dim=...)`` method.
        dim: Position of the new axis, passed to ``unsqueeze``. Defaults to
            ``0``, which preserves the original behaviour of adding a leading axis.

    Returns:
        A list with the same length and order as *f*. The input is not modified.
    """
    return [x.unsqueeze(dim=dim) for x in f]


def print_shape(f):
    """Print the ``.shape`` of each item in *f*, one per line.

    Prints nothing when *f* is empty. Returns ``None``.
    """
    shapes = (item.shape for item in f)
    for shape in shapes:
        print(shape)
