from .comlib import *

class VGG(nn.Module):
    """VGG-style CNN classifier built from a generated layer config.

    The feature extractor has five stages. Each stage is ``layer_num``
    3x3 conv + ReLU layers followed by a 2x2 max-pool. The channel width
    starts at ``min_hidden_num`` and doubles after a stage only while the
    doubled width does not exceed ``max_hidden_num`` (see ``create_cfg``).

    A global adaptive average pool reduces the final feature map to 1x1
    before the linear classifier. Any input whose spatial size survives five
    halvings (>= 32 in each dimension) is therefore accepted. For 32x32
    inputs that pool is an identity.
    """

    def __init__(self, in_channels, layer_num: int, min_hidden_num: int, max_hidden_num: int, output_len):
        """
        Args:
            in_channels: number of channels of the input tensor.
            layer_num: conv layers per stage; must be >= 1.
            min_hidden_num: channel width of the first stage.
            max_hidden_num: cap used when deciding whether to double the width.
            output_len: number of classifier outputs.

        Raises:
            ValueError: if ``layer_num`` < 1 (the network would contain no
                conv layer, so there is no width to feed the classifier).
        """
        super().__init__()
        if layer_num < 1:
            raise ValueError(f"layer_num must be >= 1, got {layer_num}")
        # Must be set before _make_layers, which reads it.
        self.in_channels = in_channels
        cfg = self.create_cfg(layer_num, min_hidden_num, max_hidden_num)
        self.features = self._make_layers(cfg)
        # The width of the last conv layer is the last integer entry of cfg.
        last_layer_size = [c for c in cfg if c != 'M'][-1]
        self.classifier = nn.Linear(last_layer_size, output_len)

    @staticmethod
    def create_cfg(layer_num: int, min_hidden_num: int, max_hidden_num: int):
        """Return the layer config: ints are conv widths, ``'M'`` is a max-pool.

        Example: ``create_cfg(1, 8, 32)`` returns
        ``[8, 'M', 16, 'M', 32, 'M', 32, 'M', 32, 'M']``
        (classic VGG-11 would be ``[64, 'M', 128, 'M', 256, 256, 'M', ...]``).
        """
        cfg = []
        hidden_num = min_hidden_num
        for _ in range(5):
            cfg.extend([hidden_num] * layer_num)
            cfg.append('M')
            # Double the width for the next stage unless that would exceed the cap.
            if hidden_num <= max_hidden_num / 2:
                hidden_num *= 2
        return cfg

    def forward(self, x):
        """Run the feature extractor, flatten per sample, and apply the classifier."""
        out = self.features(x)
        # After the adaptive pool the map is (B, C, 1, 1); flatten to (B, C).
        out = out.flatten(1)
        return self.classifier(out)

    def _make_layers(self, cfg):
        """Build the conv/ReLU/max-pool stack described by ``cfg``."""
        layers = []
        in_channels = self.in_channels
        for x in cfg:
            if x == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True)]
                in_channels = x
        # Global average pool to 1x1. It replaces a no-op AvgPool2d(kernel_size=1).
        # It is parameter-free and sits at the same index, so state_dict keys are
        # unchanged. For 32x32 inputs it is an identity; larger inputs now reach
        # the classifier with the expected width.
        layers.append(nn.AdaptiveAvgPool2d(1))
        return nn.Sequential(*layers)

class VGGEmbed(VGG):
    """Label-conditioned VGG that produces one output value per (image, label) pair.

    The integer label is embedded into a single 32x32 plane. That plane is
    concatenated to the image as an extra channel before the VGG feature
    extractor. Because of the concatenation along dim 1, input images must
    be 32x32 spatially.
    """

    def __init__(self, in_channels, layer_num: int, min_hidden_num: int, max_hidden_num: int, num_classes: int = 10):
        '''
        Args:
            in_channels: channels of the image input (the label plane adds one more).
            layer_num, min_hidden_num, max_hidden_num: forwarded to ``VGG``.
            num_classes: size of the label embedding table (valid labels are
                ``0 .. num_classes - 1``).
        '''
        # One extra input channel for the label plane; the head has a single output.
        super().__init__(in_channels + 1, layer_num, min_hidden_num, max_hidden_num, 1)
        self.num_classes = num_classes
        self.label_embedding_nn = nn.Sequential(
            nn.Embedding(self.num_classes, 64),
            nn.Linear(64, 32 * 32),
            nn.ReLU(True),
        )

    def get_label_embedding(self, label):
        """Embed integer labels and reshape them to planes of shape (N, 1, 32, 32).

        NOTE(review): assumes ``label`` is a 1-D integer tensor with one class
        index per sample; confirm against callers.
        """
        label_embedding = self.label_embedding_nn(label)
        return label_embedding.view(-1, 1, 32, 32)

    def forward(self, input, target):
        """Append the label plane for ``target`` as an extra channel of ``input``, then run VGG."""
        x = torch.cat((input, self.get_label_embedding(target)), dim=1)
        # Same feature -> flatten -> classifier pipeline as the base class.
        return super().forward(x)


def create_model(name:Literal["default","embed"],task,**kwargs):
    '''
    Build a VGG-family model sized for ``task``.

    Args:
        name: "default" builds ``VGG`` with ``task.class_num`` outputs;
            "embed" builds ``VGGEmbed`` with ``task.class_num`` label embeddings.
        task: object exposing ``input_shape`` (its first entry is used as the
            input channel count) and ``class_num``.
        kwargs: layer_num: int, min_hidden_num: int, max_hidden_num: int

    Raises:
        ValueError: if ``name`` is not one of the supported model names.
            Previously this printed a message and returned None.
    '''
    in_channels = task.input_shape[0]
    if name == "default":
        return VGG(in_channels=in_channels, output_len=task.class_num, **kwargs)
    if name == "embed":
        return VGGEmbed(in_channels=in_channels, num_classes=task.class_num, **kwargs)
    raise ValueError(f"invalid name: {name!r}; expected 'default' or 'embed'")



# class SimpleCNN(nn.Module):
#     """
#     输入: (batch, 1, h, w)   # 单通道 2-D 数据
#     输出: (batch, num_classes)
#     """
#     @staticmethod
#     def getBlock(in_feat, out_feat):
#             layers = [nn.Conv2d(in_feat,  out_feat, kernel_size=3, padding=1)]
#             layers.append(nn.LeakyReLU(0.2))
#             return layers
    
#     def __init__(self, layer_num: int, min_hidden_num: int, max_hidden_num: int, num_classes: int = 10):
#         super().__init__()
#         # 保持尺寸不变 + 通道扩增
#         for 
#         self.classifier = nn.Linear(64, num_classes)

#     def forward(self, x):
#         x = self.features(x)
#         x = x.view(x.size(0), -1)
#         return self.classifier(x)