from .comlib import *

def flatten(tensors, keep_dim=0):
    """Flatten each tensor past its first ``keep_dim`` dims and concatenate them.

    Args:
        tensors: iterable of tensors. Their leading ``keep_dim`` dimensions
            must match, because they become the shared leading dimensions
            of the result.
        keep_dim: number of leading dimensions to preserve. 0 flattens each
            tensor completely.

    Returns:
        A tensor of shape ``(*leading_dims, total_features)``. The last
        dimension concatenates the flattened remainder of every input, in order.
    """
    # reshape rather than view: view raises on non-contiguous tensors (e.g.
    # transposed/permuted inputs), while reshape only copies when it must.
    # torch.cat always allocates a new tensor, so no aliasing is lost.
    pieces = [t.reshape(*t.shape[:keep_dim], -1) for t in tensors]
    return torch.cat(pieces, -1)


class SimpleNN(nn.Module):
    """Fully connected network: input block, ``layer_num`` hidden blocks, output Linear.

    Hidden blocks come from ``getBlock``. Here that is Linear + LeakyReLU(0.2);
    subclasses may override it. ``forward`` flattens and concatenates all
    positional inputs, keeping the leading ``batch_dim_num`` dimensions, and
    then runs the stack.
    """

    @staticmethod
    def getBlock(in_feat, out_feat):
        """Return the layers of one block: Linear(in_feat, out_feat) then LeakyReLU(0.2)."""
        return [nn.Linear(in_feat, out_feat), nn.LeakyReLU(0.2)]

    def __init__(self, layer_num, input_len, hidden_num, output_len, batch_dim_num=1):
        '''
        Args:
            layer_num: number of extra hidden->hidden blocks after the input block.
            input_len: flattened input feature count (e.g. MNIST: 28 * 28).
            hidden_num: width of every hidden layer.
            output_len: number of output features.
            batch_dim_num: leading dims kept intact when flattening in ``forward``.
        '''
        super().__init__()

        # Feature widths along the stack: input -> hidden, then layer_num x hidden -> hidden.
        widths = [input_len, hidden_num] + [hidden_num] * layer_num
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            # self.getBlock (not SimpleNN.getBlock) so subclass overrides take effect.
            layers += self.getBlock(fan_in, fan_out)
        # Final projection from the hidden width to the output size.
        layers.append(nn.Linear(hidden_num, output_len))
        self.model = nn.Sequential(*layers)

        self.batch_dim_num = batch_dim_num

    def forward(self, *inputs):
        # Flatten every input past its batch dims and concatenate along features.
        features = flatten(inputs, self.batch_dim_num)
        return self.model(features)
    
class SimpleNNEmbed(SimpleNN):
    """Label-conditioned SimpleNN that appends a learned label embedding to the input.

    The input batch is flattened to ``(batch, -1)`` and concatenated with an
    ``n_classes``-dimensional embedding of ``target``. ``input_len`` must
    therefore equal the flattened input size plus ``n_classes``.
    """

    def __init__(self, layer_num, input_len, hidden_num, output_len, n_classes=10, batch_dim_num=1):
        '''
        Args:
            layer_num, hidden_num, output_len: as in SimpleNN.
            input_len: flattened input size plus n_classes (e.g. MNIST: 28 * 28 + 10).
            n_classes: number of labels; also the embedding width.
            batch_dim_num: passed to SimpleNN. Note that ``forward`` below always
                treats only dim 0 as the batch dimension.
        '''
        super(SimpleNNEmbed, self).__init__(layer_num, input_len, hidden_num, output_len, batch_dim_num)
        self.label_embedding = nn.Embedding(n_classes, n_classes)

    def forward(self, input, target):
        """Run the MLP on [flattened input, embedding(target)].

        Assumes ``target`` is a 1-D tensor of integer class indices, one per
        sample, so its embedding is 2-D like the flattened input. TODO confirm
        against callers.
        """
        # Flatten the input images. reshape (not view) so non-contiguous batches,
        # e.g. permuted or channels_last tensors, work; it only copies when needed.
        flat = input.reshape(input.size(0), -1)
        x = torch.cat((flat, self.label_embedding(target)), dim=1)
        return self.model(x)
    
class SimpleNNLayerNorm(SimpleNN):
    """SimpleNN variant whose hidden blocks apply LayerNorm before the activation."""

    @staticmethod
    def getBlock(in_feat, out_feat):
        """Return one block: Linear(in_feat, out_feat) -> LayerNorm(out_feat) -> LeakyReLU(0.2)."""
        linear = nn.Linear(in_feat, out_feat)
        norm = nn.LayerNorm(out_feat)
        activation = nn.LeakyReLU(0.2)
        return [linear, norm, activation]
    
def create_model(name: Literal["default", "embed", "layerNorm"], task, **kwargs):
    '''Build one of the MLP variants sized for ``task``.

    Args:
        name: architecture to build:
            "default" -> SimpleNN with ``task.class_num`` outputs;
            "embed" -> SimpleNNEmbed (label-conditioned, single output);
            "layerNorm" -> SimpleNNLayerNorm with ``task.class_num`` outputs.
        task: object providing ``input_elenum`` (flattened input size) and
            ``class_num`` (number of classes).
        **kwargs: forwarded to the model constructor (layer_num, hidden_num).

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``name`` is not one of the supported names.
    '''
    input_len = task.input_elenum
    if name == "default":
        return SimpleNN(input_len=input_len, output_len=task.class_num, batch_dim_num=1, **kwargs)
    if name == "embed":
        # The label embedding (width class_num) is concatenated to the input,
        # so the first Linear must accept input_len + class_num features.
        return SimpleNNEmbed(input_len=input_len + task.class_num, output_len=1,
                             n_classes=task.class_num, batch_dim_num=1, **kwargs)
    if name == "layerNorm":
        return SimpleNNLayerNorm(input_len=input_len, output_len=task.class_num, batch_dim_num=1, **kwargs)

    # Previously this printed a message and returned None, which only failed
    # later and far from the cause. Fail loudly here instead.
    raise ValueError(f"invalid name {name!r}; expected one of 'default', 'embed', 'layerNorm'")

class Initialize:
    """Seeded weight initializer meant to be passed to ``module.apply(...)``.

    The string seed is reduced to an integer once, at construction. Every call
    to ``initialize_weights`` re-seeds torch's global RNG with that integer
    before touching the module.
    """

    def __init__(self, random_seed):
        # torch.manual_seed needs an integer; derive one from the string seed.
        self.random_seed = self.get_int_random_seed(random_seed)

    @staticmethod
    def get_int_random_seed(random_seed):
        """Return the numpy sum of the Unicode code points of ``random_seed``.

        This is order-insensitive: anagrams such as "ab" and "ba" map to the
        same seed. An empty string yields numpy's empty sum (0.0).
        """
        code_points = list(map(ord, random_seed))
        return np.sum(code_points)

    def initialize_weights(self, m):
        """Re-seed torch's global RNG, then He-initialize Linear/Conv2d modules.

        Weights get ``kaiming_normal_`` (He initialization) and biases, when
        present, are set to 0. Other module types are left untouched, but the
        global RNG is still re-seeded for them.

        NOTE(review): because the RNG is re-seeded before *every* module, all
        Linear/Conv2d layers with the same weight shape (on the same
        device/dtype) receive identical weights under ``model.apply``. The
        call also resets torch's global RNG state as a side effect. Confirm
        both are intended.
        """
        torch.manual_seed(self.random_seed)
        if not isinstance(m, (nn.Linear, nn.Conv2d)):
            return
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)