import torch
import torch.nn as nn
import torch.nn.functional as F


class LeftPad1d(nn.Module):
    """Pad only the left side of a tensor's last dimension.

    The padding is delegated to ``F.pad`` with its default mode and fill
    value; nothing is added on the right side.
    """

    def __init__(self, left_pad: int) -> None:
        """Store the number of elements to prepend along the last dimension.

        Args:
            left_pad: Amount of padding inserted before the first element of
                the last dimension.
        """
        super().__init__()
        self.left_pad = left_pad

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with ``left_pad`` extra leading entries on its last dim."""
        # F.pad reads the tuple as (left, right) for the last dimension.
        padding = (self.left_pad, 0)
        return F.pad(x, padding)


class Conv_Lob(nn.Module):
    """Stack of left-padded, dilated 1-D convolutions, each followed by ReLU.

    Every convolution is preceded by a ``LeftPad1d`` of
    ``dilation * (kernel - 1)`` elements, so each layer keeps the sequence
    length unchanged and only looks at the current and earlier positions.

    The dilation used by each layer depends on ``conv_type``:

    * ``'exp'``: ``1, dilation**1, ..., dilation**(num_layers - 1)``.
    * ``'constant'``: ``dilation`` for every layer.
    * ``'linear'``: ``1, 2, ..., dilation``. In this mode the depth is set by
      ``dilation``; ``num_layers`` and ``groups`` are ignored.

    At least one layer is always built, even when ``num_layers < 1``.
    """

    def __init__(self, conv_type='exp', in_c=41, out_c=14, kernel=2, dilation=2, num_layers=5, groups=1):
        """Build the convolution stack.

        Args:
            conv_type: Dilation schedule: ``'exp'``, ``'constant'`` or
                ``'linear'``.
            in_c: Number of input channels of the first convolution.
            out_c: Number of output channels of every convolution (and input
                channels of every convolution after the first).
            kernel: Kernel size shared by all convolutions.
            dilation: Base dilation; its meaning depends on ``conv_type``.
            num_layers: Number of conv layers for ``'exp'`` and
                ``'constant'``; unused for ``'linear'``.
            groups: ``groups`` argument of the convolutions for ``'exp'`` and
                ``'constant'``; unused for ``'linear'``.

        Raises:
            ValueError: If ``conv_type`` is not one of the supported values.
        """
        super().__init__()

        # Resolve the per-layer dilation schedule first, so an invalid
        # conv_type fails before any layer is allocated.
        if conv_type == 'exp':
            dilations = [1] + [dilation ** i for i in range(1, num_layers)]
            conv_groups = groups
        elif conv_type == 'constant':
            dilations = [dilation] + [dilation for _ in range(1, num_layers)]
            conv_groups = groups
        elif conv_type == 'linear':
            dilations = [1] + list(range(2, dilation + 1))
            # NOTE(review): 'linear' has always used ungrouped convolutions
            # regardless of `groups`; kept as-is so parameter shapes of
            # existing models do not change. Confirm whether this is intended.
            conv_groups = 1
        else:
            # Raise instead of print + exit(): library code must not terminate
            # the interpreter, and callers can handle a ValueError.
            raise ValueError(
                f"Missing valid conv type: got {conv_type!r}, "
                "expected 'exp', 'constant' or 'linear'"
            )

        conv_layers = []
        channels = in_c  # only the first convolution consumes in_c channels
        for layer_dilation in dilations:
            conv_layers.extend([
                LeftPad1d(layer_dilation * (kernel - 1)),
                nn.Conv1d(
                    in_channels=channels,
                    out_channels=out_c,
                    kernel_size=kernel,
                    dilation=layer_dilation,
                    groups=conv_groups,
                ),
                nn.ReLU(),
            ])
            channels = out_c

        self.conv_layers = nn.Sequential(*conv_layers)

        # Architecture summary on construction (pre-existing behaviour).
        print(self.conv_layers)

    def forward(self, x):
        """Apply the convolution stack to a 3-D, channels-last input.

        Args:
            x: Tensor laid out as (B, L, D); D must equal ``in_c``.

        Returns:
            Tensor laid out as (B, L, out_c).
        """
        x = torch.permute(x, (0, 2, 1))     # B, L, D --> B, D, L (Conv1d layout)
        x = self.conv_layers(x)
        x = torch.permute(x, (0, 2, 1))     # B, D, L --> B, L, D

        return x

