import torch
import torch.nn.functional as F
from torch import nn
import torch.nn.init as init

from models_for_multibranch.convnet import convnet
from models_for_multibranch.resnet import resnet
from models_for_multibranch.mlp import mlp

class MultiBranchLayer(nn.Module):
    """Bank of ``num_branch`` independent linear heads over one shared input.

    Every branch owns a ``(num_class, dim_in)`` weight matrix and, when
    ``bias`` is true, a ``(1, num_class)`` bias row. All parameters are
    sampled from a standard normal distribution.

    Args:
        num_branch: Number of parallel branches (``k``).
        num_class: Output width of each branch (``c``).
        dim_in: Width of the input features (``d``).
        bias: Whether each branch gets an additive bias. Defaults to False.
    """

    def __init__(self, num_branch, num_class, dim_in, bias=False) -> None:
        super().__init__()
        self.num_branch = num_branch
        self.num_class = num_class

        # The weight is sampled before the bias, so the random stream is
        # always consumed in the same order.
        weight_shape = (num_branch, num_class, dim_in)
        self.weight = nn.Parameter(torch.randn(weight_shape))

        bias_shape = (num_branch, 1, num_class)
        self.bias = nn.Parameter(torch.randn(bias_shape)) if bias else None

    def forward(self, x):
        """Apply every branch to ``x``.

        Args:
            x: 2-D tensor of shape ``(n, dim_in)``.

        Returns:
            Tensor of shape ``(num_branch, n, num_class)``.
        """
        # (k, c, d) @ (d, n) -> (k, c, n); reorder the axes to (k, n, c).
        logits = torch.matmul(self.weight, x.T).permute(0, 2, 1)
        if self.bias is None:
            return logits
        # The (k, 1, c) bias broadcasts across the n samples.
        return logits + self.bias

# Multi-branch network structure with a feature adaptation layer
class MultiBranchWithFeatureAdaptationLayer(nn.Module):
    """Multi-branch linear head preceded by a per-branch feature adaptation layer.

    The input is first projected by a ``FeatureAdaptationLayer`` into one
    feature set per branch, passed through ReLU, and then multiplied by a
    per-branch classifier weight.

    Args:
        num_branch: Number of parallel branches (``k``).
        num_class: Output width of each branch (``c``).
        dim_in: Width of the input features (``t1``).
        low_dim: Width of the adapted features (``t2``). ``None`` means
            ``dim_in``.
        bias: Whether the classifier adds a per-branch bias. This flag is not
            forwarded to the adaptation layer, which is built with its own
            default bias setting.
    """

    def __init__(self, num_branch, num_class, dim_in, low_dim=None, bias=True) -> None:
        super(MultiBranchWithFeatureAdaptationLayer, self).__init__()
        self.num_branch = num_branch
        self.num_class = num_class
        self.dim_in = dim_in
        # Kept exactly as passed (possibly None) for backward compatibility.
        self.low_dim = low_dim
        self.falayer = FeatureAdaptationLayer(num_branch, dim_in, low_dim)
        # The adaptation layer emits ``low_dim`` features (``dim_in`` when
        # ``low_dim`` is None), so the classifier weight must consume that
        # width. Sizing it with ``dim_in`` makes ``bmm`` fail whenever the two
        # differ.
        adapted_dim = dim_in if low_dim is None else low_dim
        self.weight = torch.nn.Parameter(torch.randn((num_branch, adapted_dim, num_class), requires_grad=True))
        if bias:
            self.bias = torch.nn.Parameter(torch.randn((num_branch, 1, num_class), requires_grad=True))
        else:
            self.bias = None

    def forward(self, x):
        """Run adaptation, ReLU and the per-branch classifier.

        Args:
            x: 2-D tensor of shape ``(n, dim_in)``.

        Returns:
            Tensor of shape ``(num_branch, n, num_class)``.
        """
        # x: (n, t1) -> (k, n, t2)
        x = self.falayer(x)
        # Functional ReLU avoids constructing a new nn.ReLU module per call.
        x = F.relu(x, inplace=True)
        # (k, n, t2) @ (k, t2, c) -> (k, n, c)
        x = torch.bmm(x, self.weight)
        if self.bias is not None:
            # The (k, 1, c) bias broadcasts across the n samples.
            x = x + self.bias
        return x
    

class FeatureAdaptationLayer(nn.Module):
    """Per-branch linear projection of a shared feature matrix.

    Each of the ``num_branch`` branches owns a ``(low_dim, dim_in)`` weight
    matrix and, when ``bias`` is true, a ``(1, low_dim)`` bias row. All
    parameters are sampled from a standard normal distribution.

    Args:
        num_branch: Number of parallel branches (``k``).
        dim_in: Width of the input features (``t1``).
        low_dim: Width of the projected features (``t2``). ``None`` means
            ``dim_in``.
        bias: Whether each branch adds a bias. Defaults to True.
    """

    def __init__(self, num_branch, dim_in, low_dim=None, bias=True) -> None:
        super(FeatureAdaptationLayer, self).__init__()
        self.num_branch = num_branch
        self.dim_in = dim_in
        # Identity check: ``== None`` would go through ``__eq__`` instead.
        if low_dim is None:
            low_dim = dim_in
        # Expose the resolved output width alongside the other dimensions.
        self.low_dim = low_dim

        self.weight = torch.nn.Parameter(torch.randn((num_branch, low_dim, dim_in), requires_grad=True))
        if bias:
            self.bias = torch.nn.Parameter(torch.randn((num_branch, 1, low_dim), requires_grad=True))
        else:
            self.bias = None

    def forward(self, x):
        """Project ``x`` once per branch.

        Args:
            x: 2-D tensor of shape ``(n, dim_in)``.

        Returns:
            Tensor of shape ``(num_branch, n, low_dim)``.
        """
        # (k, t2, t1) @ (t1, n) -> (k, t2, n)
        x = self.weight @ x.T
        # permute from (k, t2, n) to (k, n, t2)
        x = x.permute(0, 2, 1)
        if self.bias is not None:
            # The (k, 1, t2) bias broadcasts across the n samples.
            x = x + self.bias
        return x
