import torch
import torch.nn as nn

class RepUnit(nn.Module):
    """A linear layer followed by BatchNorm1d that can be fused into one linear layer.

    Input is channels-last: the linear layer acts on the last dimension and
    BatchNorm normalizes that same dimension, with statistics taken over all
    leading dimensions (e.g. over B and L for a [B, L, C] input).

    Args:
        dim_in: the input channels.
        dim_out: the output channels.
    """
    def __init__(self, dim_in: int, dim_out: int):
        super(RepUnit, self).__init__()
        self.linear = nn.Linear(dim_in, dim_out)
        self.bn = nn.BatchNorm1d(dim_out)
        # 'deplay' (sic, "deploy"): becomes True once merge() has fused
        # linear+bn into self.weight / self.bias. Name kept for compatibility.
        self.deplay = False

    def forward(self, x):
        """Apply linear + BatchNorm (or the fused linear layer after merge())."""
        if self.deplay:
            return nn.functional.linear(x, self.weight, self.bias)
        x = self.linear(x)
        # Flattening every leading dim into the batch dim gives BatchNorm1d the
        # same per-channel statistics as permuting a [B, L, C] tensor to
        # [B, C, L], and additionally accepts 2-D [B, C] or higher-rank input.
        shape = x.shape
        return self.bn(x.reshape(-1, shape[-1])).reshape(shape)

    def merge(self):
        """Fold the BatchNorm into the linear layer; no-op if already merged.

        Uses the BatchNorm running statistics, so the fused layer reproduces
        forward() in eval mode. The fused weight/bias are registered as
        parameters so they follow .to()/.cuda() and appear in state_dict().
        """
        if self.deplay:
            return
        with torch.no_grad():
            # BN(Wx + b) = t * (Wx + b - mean) + beta, with t = gamma / std.
            std = (self.bn.running_var + self.bn.eps).sqrt()
            t = self.bn.weight / std
            weight = self.linear.weight * t.reshape(-1, 1)
            bias = (self.linear.bias - self.bn.running_mean) * t + self.bn.bias
        del self.linear
        del self.bn
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)
        self.deplay = True

class RepLinears(nn.Module):
    """A chain of ``depth`` RepUnits followed by a plain linear layer.

    The whole chain is affine (at inference), so ``merge()`` collapses it into
    a single weight/bias pair.

    Args:
        dim_in: the input channels.
        dim_out: the output channels.
        depth (int): the number of groups which consists of
                    a linear layer and a batch normalization (must be >= 1).
    """
    def __init__(self, dim_in: int, dim_out: int, depth: int):
        super(RepLinears, self).__init__()
        assert depth >= 1, "The value of depth should be at least 1."
        # 'deplay' (sic, "deploy"): becomes True once merge() has run.
        self.deplay = False
        self.linear_list = nn.ModuleList(
            [RepUnit(dim_in, dim_out)]
            + [RepUnit(dim_out, dim_out) for _ in range(depth - 1)]
        )
        self.linear = nn.Linear(dim_out, dim_out)

    def forward(self, x):
        """Run the RepUnit chain and final linear (or the fused layer after merge())."""
        if self.deplay:
            return nn.functional.linear(x, self.weight, self.bias)
        for layer in self.linear_list:
            x = layer(x)
        return self.linear(x)

    def merge(self):
        """Collapse the chain into one linear map; no-op if already merged.

        Uses BatchNorm running statistics (see RepUnit.merge), so the result
        matches forward() in eval mode. The fused weight/bias are registered
        as parameters so they follow .to()/.cuda() and appear in state_dict().
        """
        if self.deplay:
            return
        with torch.no_grad():
            # 1. Fuse linear+BN inside each RepUnit.
            for layer in self.linear_list:
                layer.merge()
            # 2. Compose the affine maps: applying (W2, b2) after (W1, b1)
            #    gives weight W2 @ W1 and bias W2 @ b1 + b2.
            weight = self.linear_list[0].weight
            bias = self.linear_list[0].bias
            for layer in self.linear_list[1:]:
                weight = layer.weight @ weight
                bias = layer.weight @ bias + layer.bias
            # 3. Fold in the final plain linear layer.
            weight = self.linear.weight @ weight
            bias = self.linear.weight @ bias + self.linear.bias
        del self.linear_list
        del self.linear
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)
        self.deplay = True

class TDRL(nn.Module):
    """Re-parameterizable linear layer (TDRL) that fuses into a single nn.Linear.

    In training form the output is ``linear(x) + sum(branch(x) for branch in
    linear_list)``, followed by a distribution rectification. ``merge()``
    folds everything into ``fused_linear``.

    Args:
        dim_in: the input channels.
        dim_out: the output channels.
        width (int): the number of re-parameterized branches (RepLinears).
        depth (int): the number of linear+BatchNorm groups per branch when
            ``type='regular'``. Ignored when ``type='pyramid'``, where
            branch ``i`` (0-based) has ``i + 1`` groups.
        type (str): the branch layout, ``'pyramid'`` or ``'regular'``.
        rectify (str): the rectification applied to the summed output:
            ``'scale'`` multiplies by ``(width + 1) ** -0.5``, ``'none'``
            leaves it unchanged, ``'norm'`` applies BatchNorm1d.
        deplay (bool): (sic, "deploy") if True, build only the fused
            ``nn.Linear`` directly, in the form ``merge()`` produces.

    Raises:
        ValueError: if ``type`` or ``rectify`` is not recognized (checked only
            when ``deplay`` is False).
    """
    def __init__(self, dim_in: int, dim_out: int, width: int,
                 depth: int, type: str = 'pyramid', rectify: str='scale',
                 deplay: bool=False):
        super(TDRL, self).__init__()
        self.rectify = rectify
        self.deplay = deplay
        self.dim_in = dim_in
        self.dim_out = dim_out

        if deplay:
            self.fused_linear = nn.Linear(dim_in, dim_out)
            return

        if type == 'pyramid':
            depth_list = [i + 1 for i in range(width)]
        elif type == 'regular':
            depth_list = [depth] * width
        else:
            raise ValueError("The type of TDRL {} is not recognized.".format(type))

        self.linear_list = nn.ModuleList([RepLinears(dim_in, dim_out, d) for d in depth_list])
        self.linear = nn.Linear(dim_in, dim_out)

        if rectify == "scale":
            self.scale_value = (width + 1) ** -0.5
        elif rectify == "none":
            self.scale_value = 1
        elif rectify == "norm":
            self.norm = nn.BatchNorm1d(dim_out)
        else:
            raise ValueError("The type of rectification {} is not recognized.".format(rectify))

    def forward(self, x):
        """Sum all branches and rectify (or apply fused_linear after merge()).

        Input is channels-last: the linear layers act on the last dimension.
        """
        if self.deplay:
            return self.fused_linear(x)
        y = self.linear(x)
        for layer in self.linear_list:
            y += layer(x)
        if self.rectify == "norm":
            # Flatten leading dims so BatchNorm1d normalizes the channel
            # (last) dim with statistics over all other dims, for any rank.
            shape = y.shape
            return self.norm(y.reshape(-1, shape[-1])).reshape(shape)
        return y * self.scale_value

    def merge(self):
        """Fold all branches and the rectification into ``fused_linear``.

        Uses BatchNorm running statistics, so the fused layer matches
        forward() in eval mode. No-op if already merged. The original
        parameters are not modified in place.
        """
        if self.deplay:
            return
        with torch.no_grad():
            for layer in self.linear_list:
                layer.merge()
            # Work on copies so a failure midway leaves self.linear intact.
            weight = self.linear.weight.detach().clone()
            bias = self.linear.bias.detach().clone()
            for layer in self.linear_list:
                weight += layer.weight
                bias += layer.bias
            if self.rectify == "scale":
                weight *= self.scale_value
                bias *= self.scale_value
            elif self.rectify == "norm":
                # BN(Wx + b) = t * (Wx + b - mean) + beta, with t = gamma / std.
                std = (self.norm.running_var + self.norm.eps).sqrt()
                t = self.norm.weight / std
                weight = weight * t.reshape(-1, 1)
                bias = (bias - self.norm.running_mean) * t + self.norm.bias

            # Build the fused layer on the same device/dtype as the weights.
            fused = nn.Linear(self.dim_in, self.dim_out).to(device=weight.device, dtype=weight.dtype)
            fused.weight.copy_(weight)
            fused.bias.copy_(bias)

        if self.rectify == "norm":
            del self.norm
        del self.linear_list
        del self.linear
        self.fused_linear = fused
        self.deplay = True
        
        
if __name__ == '__main__':
    print('Test the re-parameterized architecture.')

    torch.manual_seed(0)  # reproducible run
    model = TDRL(192, 384, 3, 3).train()
    inputs = torch.randn(4, 196, 192)

    with torch.no_grad():
        # Training-mode passes update the BatchNorm running statistics,
        # which merge() folds into the fused weights.
        for _ in range(3):
            model(inputs)

        # merge() uses running statistics, so compare outputs in eval mode.
        model.eval()
        output1 = model(inputs).clone()
        model.merge()
        output2 = model(inputs)

    diff = (output1 - output2).abs()
    print('The difference: ', diff.sum())
    print('Max abs difference: ', diff.max())