import torch
import torch.nn as nn
import torch.nn.functional as F

# Module-level debug switch.
# NOTE(review): no code in the visible part of this file reads it.
debug = False

class ConvLayer(nn.Module):
    """Down-sampling block: Conv1d -> BatchNorm1d -> ELU -> MaxPool1d.

    The forward pass swaps the last two axes of its input before the
    convolution and swaps them back at the end, so the last axis of both the
    input and the output has size ``c_in``.

    With the layer settings used here (conv: kernel 3, circular padding 2;
    pool: kernel 3, stride 2, padding 1) an input of middle-axis length ``L``
    produces an output of middle-axis length ``floor((L + 1) / 2) + 1``.

    Args:
        c_in: Channel count used for the convolution's input and output
            channels and for the batch norm.
    """

    def __init__(self, c_in):
        super().__init__()
        # Attribute names are kept as-is because they form the state_dict keys.
        self.downConv = nn.Conv1d(
            c_in,
            c_in,
            kernel_size=3,
            padding=2,
            padding_mode="circular",
        )
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Apply the block to a 3-D tensor and return the down-sampled result.

        Args:
            x: 3-D tensor whose last axis has size ``c_in``.

        Returns:
            3-D tensor with the same first and last axis sizes as ``x`` and a
            shortened middle axis.
        """
        # Conv1d/BatchNorm1d/MaxPool1d operate with channels on axis 1.
        channels_first = x.permute(0, 2, 1)
        for stage in (self.downConv, self.norm, self.activation, self.maxPool):
            channels_first = stage(channels_first)
        # Restore the caller's axis order.
        return channels_first.transpose(1, 2)

class EncoderLayer(nn.Module):
    """Self-attention sub-layer followed by a two-convolution feed-forward sub-layer.

    Each sub-layer is wrapped in a residual connection and a ``LayerNorm``.
    The feed-forward part is two kernel-size-1 ``Conv1d`` layers
    (``d_model -> d_ff -> d_model``) with an activation in between.

    Args:
        attention: Callable invoked as ``attention(x, x, x, attn_mask=...)``;
            it must return a ``(output, attn)`` pair.
        d_model: Size of the last axis of the input.
        d_ff: Hidden width of the feed-forward part; any falsy value
            (including ``None``) selects ``4 * d_model``.
        dropout: Dropout probability shared by all dropout applications.
        activation: ``"relu"`` selects ``F.relu``; any other value selects
            ``F.gelu``.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        if not d_ff:
            d_ff = 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, attn_mask=None):
        """Run the layer.

        Args:
            x: Input tensor; used as query, key and value of the attention.
            attn_mask: Passed through unchanged to ``self.attention``.

        Returns:
            ``(output, attn)`` where ``attn`` is whatever the attention
            callable returned as its second value.
        """
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask)
        residual = self.norm1(x + self.dropout(attended))

        # Conv1d expects channels on axis 1, so swap axes around the two convs.
        hidden = self.conv1(residual.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        projected = self.conv2(hidden).transpose(-1, 1)
        projected = self.dropout(projected)

        return self.norm2(residual + projected), attn

class Encoder(nn.Module):
    """Stack of attention layers, optionally interleaved with conv layers.

    Args:
        attn_layers: Iterable of modules, each called as
            ``layer(x, attn_mask=...)`` and returning ``(x, attn)``.
        conv_layers: Optional iterable of modules, each called as
            ``layer(x)``.  When given, attention and conv layers are paired
            with ``zip`` and the last attention layer is then applied once
            more on its own.
            NOTE(review): this presumably expects
            ``len(attn_layers) == len(conv_layers) + 1``; with equal lengths
            the last attention layer runs twice.
        norm_layer: Optional module applied to the final output and to every
            collected intermediate output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        """Run the encoder.

        Args:
            x: Input tensor (the original comment gives its layout as
                ``[B, L, D]``).
            attn_mask: Forwarded unchanged to every attention layer.

        Returns:
            ``(x, attns, x_list)``: the final output, the list of attention
            values returned by each attention call, and the list of outputs
            collected after each stage (normalized when ``norm_layer`` was
            given).
        """
        attns = []
        x_list = []
        if self.conv_layers is not None:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                x_list.append(x)
                attns.append(attn)
            # The trailing attention layer has no conv partner.  It must see
            # the same mask as the paired layers above; previously the mask
            # was silently dropped here.
            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
            x_list.append(x)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x_list.append(x)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)
            # Apply the same norm to every collected intermediate output.
            x_list = [self.norm(item) for item in x_list]
        return x, attns, x_list



class YformerEncoder(nn.Module):
    """Encoder that records its input and every stage output.

    Three configurations are supported:

    * attention and conv layers: paired with ``zip``; each pair applies the
      attention layer and then the conv layer;
    * conv layers only: each conv layer is applied in turn and ``None`` is
      recorded in place of an attention value;
    * attention layers only: each attention layer is applied in turn.

    When neither is given the encoder has no stages and passes its input
    through (only ``norm_layer``, if any, is applied).

    Args:
        attn_layers: Optional iterable of modules, each called as
            ``layer(x, attn_mask=...)`` and returning ``(x, attn)``.
        conv_layers: Optional iterable of modules, each called as ``layer(x)``.
        norm_layer: Optional module applied to the final output and to every
            entry of the collected list, including the original input.
    """

    def __init__(self, attn_layers=None, conv_layers=None, norm_layer=None):
        super(YformerEncoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers) if attn_layers is not None else None
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        """Run the encoder.

        Args:
            x: Input tensor (the original comment gives its layout as
                ``[B, L, D]``).
            attn_mask: Forwarded unchanged to every attention layer.

        Returns:
            ``(x, attns, x_list)``: the final output, one attention entry per
            stage (``None`` for conv-only stages), and a list holding the
            input followed by the output of each stage (all normalized when
            ``norm_layer`` was given).
        """
        attns = []
        x_list = [x]  # the raw input is always the first collected entry

        if self.conv_layers is None:
            # Attention-only pipeline.  A missing attention stack is treated
            # as empty so a default-constructed encoder does not fail with
            # "'NoneType' object is not iterable".
            attn_layers = self.attn_layers if self.attn_layers is not None else ()
            for attn_layer in attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x_list.append(x)
                attns.append(attn)
        elif self.attn_layers is None:
            # Conv-only pipeline: no attention values to report.
            for conv_layer in self.conv_layers:
                x = conv_layer(x)
                x_list.append(x)
                attns.append(None)
        else:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                x_list.append(x)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)
            # Apply the same norm to every collected entry.
            x_list = [self.norm(item) for item in x_list]
        return x, attns, x_list




class EncoderStack(nn.Module):
    """Run several encoders on progressively shorter tails and concatenate.

    Args:
        encoders: Iterable of encoder modules; entries may be ``None``, in
            which case that slot is skipped but the tail length is still
            halved.  Each encoder is called with a single tensor and must
            return at least ``(output, attn)``; any further return values
            (such as the third value returned by ``Encoder.forward``) are
            ignored.
    """

    def __init__(self, encoders):
        super(EncoderStack, self).__init__()
        self.encoders = nn.ModuleList(encoders)

    def forward(self, x, attn_mask=None):
        """Run every encoder and concatenate the outputs along axis ``-2``.

        Args:
            x: Input tensor (the original comment gives its layout as
                ``[B, L, D]``); the initial tail length is ``x.shape[1]``.
            attn_mask: Accepted for signature compatibility; it is not
                forwarded to the encoders.

        Returns:
            ``(x_stack, attns)``: the concatenated encoder outputs and the
            list of each encoder's attention value.
        """
        inp_len = x.shape[1]
        x_stack = []
        attns = []
        for encoder in self.encoders:
            if encoder is not None:
                # NOTE(review): ``x`` is rebound to the encoder output, so
                # later encoders slice the previous output rather than the
                # original input - confirm this is intended.
                # ``*_`` tolerates encoders that return extra values, e.g.
                # the 3-tuple produced by ``Encoder.forward``.
                x, attn, *_ = encoder(x[:, -inp_len:, :])
                x_stack.append(x)
                attns.append(attn)
            # Halve the tail length for the next slot, even for a skipped one.
            inp_len //= 2

        return torch.cat(x_stack, -2), attns