import torch.nn as nn
from einops import rearrange
from models.custom_layers import act_registry

class MultiBatchConv(nn.Module):
    """Channel-last N-d convolution that also accepts an optional time axis.

    The wrapped ``nn.Conv{1,2,3}d`` expects channel-first input, so ``forward``
    moves the trailing channel axis next to the batch axis before convolving and
    moves it back afterwards. If the input carries an extra time axis, it is
    folded into the batch axis for the convolution and unfolded on output.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded
            positionally to the underlying ``nn.Conv{n_dim}d``.
        n_dim: number of spatial dimensions (1, 2 or 3).

    Raises:
        ValueError: if ``n_dim`` is not 1, 2 or 3.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, n_dim):
        super().__init__()
        self.n_dim = n_dim
        if self.n_dim == 1:
            self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        elif self.n_dim == 2:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        elif self.n_dim == 3:
            self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
        else:
            raise ValueError('n_dim must be 1, 2, or 3')

    def forward(self, x):
        """Convolve a channel-last tensor.

        input:  (B, [T], Sx, [Sy], [Sz], in_channels)
        output: (B, [T], Sx', [Sy'], [Sz'], out_channels)

        A time axis is assumed present exactly when ``x`` has ``n_dim + 3`` dims.
        """
        includes_time = x.dim() == self.n_dim + 3
        if includes_time:
            B, T = x.shape[:2]
            # reshape (not view): inputs that are non-contiguous in (B, T),
            # e.g. after a batch/time transpose, cannot be merged with view.
            z = x.reshape(B * T, *x.shape[2:])
        else:
            z = x
        # (B', Sx, [Sy], [Sz], in_channels) -> (B', in_channels, Sx, [Sy], [Sz])
        z = z.movedim(-1, 1)
        z = self.conv(z)
        # (B', out_channels, Sx', [Sy'], [Sz']) -> (B', Sx', [Sy'], [Sz'], out_channels)
        z = z.movedim(1, -1)
        if includes_time:
            z = z.reshape(B, T, *z.shape[1:])
        return z

class MultiBatchConvTranspose(nn.Module):
    """N-d transposed convolution with an optional time axis and optional channel-last layout.

    With ``transposed=True`` (default) the channel axis is last on both input and
    output and is moved next to the batch axis around the wrapped
    ``nn.ConvTranspose{1,2,3}d``. With ``transposed=False`` the tensor is passed
    to the deconvolution in its given (channel-first) layout. If the input carries
    an extra time axis, it is folded into the batch axis and unfolded on output.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded
            positionally to the underlying ``nn.ConvTranspose{n_dim}d``.
        n_dim: number of spatial dimensions (1, 2 or 3).
        transposed: if True, input/output are channel-last; otherwise channel-first.

    Raises:
        ValueError: if ``n_dim`` is not 1, 2 or 3.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, n_dim, transposed=True):
        super().__init__()
        self.n_dim = n_dim
        self.transposed = transposed
        if self.n_dim == 1:
            self.deconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding)
        elif self.n_dim == 2:
            self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        elif self.n_dim == 3:
            self.deconv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride, padding)
        else:
            raise ValueError('n_dim must be 1, 2, or 3')

    def forward(self, x):
        """Apply the transposed convolution.

        transposed=True:
            input:  (B, [T], Sx', [Sy'], [Sz'], in_channels)
            output: (B, [T], Sx, [Sy], [Sz], out_channels)
        transposed=False:
            input:  (B, [T], in_channels, Sx', [Sy'], [Sz'])
            output: (B, [T], out_channels, Sx, [Sy], [Sz])

        A time axis is assumed present exactly when ``x`` has ``n_dim + 3`` dims.
        """
        includes_time = x.dim() == self.n_dim + 3
        if includes_time:
            B, T = x.shape[:2]
            # reshape (not view): inputs that are non-contiguous in (B, T)
            # cannot be merged with view.
            z = x.reshape(B * T, *x.shape[2:])
        else:
            z = x
        if self.transposed:
            # (B', Sx', [Sy'], [Sz'], in_channels) -> (B', in_channels, Sx', [Sy'], [Sz'])
            z = z.movedim(-1, 1)
        z = self.deconv(z)
        if self.transposed:
            # (B', out_channels, Sx, [Sy], [Sz]) -> (B', Sx, [Sy], [Sz], out_channels)
            z = z.movedim(1, -1)
        if includes_time:
            z = z.reshape(B, T, *z.shape[1:])
        return z

def fast_input_layer(kernel_size, stride, in_channels, out_channels, n_dim):
    ''' Build the strided channel-last input convolution.

    Shapes: (B, [T], Sx, [Sy], [Sz], in_channels) -> (B, [T], Sx', [Sy'], [Sz'], out_channels).
    Padding is fixed to ``kernel_size // 2``; a short notice is printed to stdout.'''
    print(f"Using fast input layer with kernel size {kernel_size} and stride {stride}")
    half_kernel = kernel_size // 2
    layer = MultiBatchConv(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=half_kernel,
        n_dim=n_dim,
    )
    return layer

class _ChannelFirstLinear(nn.Linear):
    """nn.Linear applied along the channel axis of a channel-first tensor.

    Expects (B, [T], C, Sx, [Sy], [Sz]). The channel axis is located ``n_dim + 1``
    positions from the end, so the optional time axis does not matter. It
    subclasses nn.Linear rather than wrapping it, so parameter names
    ("weight", "bias") match a plain nn.Linear.
    """

    def __init__(self, in_features, out_features, n_dim):
        super().__init__(in_features, out_features)
        self.n_dim = n_dim

    def forward(self, x):
        channel_dim = -(self.n_dim + 1)
        x = x.movedim(channel_dim, -1)
        x = super().forward(x)
        return x.movedim(-1, channel_dim)


def fast_output_layer(kernel_size, stride, in_channels, out_channels, n_dim, transposed=True, final_mlp_hidden_expansion=None, final_mlp_act=None):
    ''' Build the output deconvolution, optionally preceded by a per-position MLP.

    transposed=True:  (B, [T], Sx', [Sy'], [Sz'], in_channels) -> (B, [T], Sx, [Sy], [Sz], out_channels)
    transposed=False: (B, [T], in_channels, Sx', [Sy'], [Sz']) -> (B, [T], out_channels, Sx, [Sy], [Sz])

    Without ``final_mlp_hidden_expansion`` this is a single MultiBatchConvTranspose.
    Otherwise it is Linear(in_channels -> expansion * in_channels), then
    ``act_registry[final_mlp_act]``, then MultiBatchConvTranspose(-> out_channels).
    Padding is fixed to ``kernel_size // 2``.'''
    if final_mlp_hidden_expansion is None:
        return MultiBatchConvTranspose(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2, stride=stride, n_dim=n_dim, transposed=transposed)
    mid_channels = final_mlp_hidden_expansion * in_channels
    if transposed:
        # Channel-last input: a plain Linear already acts on the channel axis.
        linear = nn.Linear(in_channels, mid_channels)
    else:
        # Channel-first input: the last axis is spatial, so the Linear must be
        # applied along the channel axis explicitly.
        linear = _ChannelFirstLinear(in_channels, mid_channels, n_dim)
    return nn.Sequential(
        linear,
        # NOTE(review): assumed to be an elementwise nn.Module instance; it is
        # inserted into nn.Sequential as-is.
        act_registry[final_mlp_act],
        MultiBatchConvTranspose(mid_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2, stride=stride, n_dim=n_dim, transposed=transposed),
    )

