from torch import Tensor, nn
import torch
from torch.nn import functional as F


class LinearLayer(nn.Module):
    """Project a list of feature tensors to a common width with per-level linears.

    One ``nn.Linear`` is created per level and stored in ``self.fc``.

    Args:
        dim_in: Base input feature width.
        dim_out: Output feature width shared by every level.
        k: Number of levels, i.e. number of linear layers to create.
        model: Backbone identifier. If it contains ``'ViT'`` every level
            expects ``dim_in`` input features; otherwise level ``i`` expects
            ``dim_in * 2 ** (i + 2)`` input features.
    """

    def __init__(self, dim_in: int, dim_out: int, k: int, model: str):
        super().__init__()
        if 'ViT' in model:
            self.fc = nn.ModuleList([nn.Linear(dim_in, dim_out) for _ in range(k)])
        else:
            self.fc = nn.ModuleList(
                [nn.Linear(dim_in * 2 ** (i + 2), dim_out) for i in range(k)]
            )

    def forward(self, tokens):
        """Apply the i-th linear layer to the i-th entry of ``tokens``.

        * 3-D entries: the first position along dim 1 is dropped (presumably
          the class token -- confirm against the backbone) and the remainder
          is projected.
        * Other entries must be 4-D ``(B, C, H, W)``; they are rearranged to
          ``(B, H*W, C)`` and then projected.

        Args:
            tokens: Mutable sequence of tensors, at most ``k`` entries long.

        Returns:
            The same ``tokens`` object; its entries are replaced in place by
            the projected tensors.
        """
        for i, feat in enumerate(tokens):
            if feat.dim() == 3:
                tokens[i] = self.fc[i](feat[:, 1:, :])
            else:
                batch, channels, _, _ = feat.shape
                # reshape (not view): view raises on non-contiguous inputs such
                # as transposed or strided-sliced feature maps, whereas reshape
                # falls back to a copy only when a view is impossible.
                flat = feat.reshape(batch, channels, -1)
                tokens[i] = self.fc[i](flat.permute(0, 2, 1).contiguous())
        return tokens
class LinearDiscriminator(nn.Module):
    """Single fully connected layer mapping ``input_dim`` to ``output_dim``."""

    def __init__(self, input_dim, output_dim):
        """Create the underlying ``nn.Linear(input_dim, output_dim)`` as ``self.fc``."""
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return ``self.fc`` applied to ``x``."""
        scores = self.fc(x)
        return scores
