import random
import torch
import copy
import re

from torch import nn
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

class CNNAdapter(torch.nn.Module):
    """Causal two-layer 1-D CNN that maps encoder features to the LLM embedding size.

    Each convolution is left-padded with ``kernel_size - 1`` zeros and uses
    stride 1, so output frame ``t`` depends only on input frames ``<= t`` and
    the time length is preserved. Channels go
    ``enc_out_dim -> 2*enc_out_dim -> 4*enc_out_dim``, then a linear layer
    projects to ``llm_embed_dim``.
    """

    def __init__(
        self,
        enc_out_dim: int = 512,
        llm_embed_dim: int = 4096,
        kernel_size: int = 5,
    ):
        super().__init__()

        # Left-only padding keeps the convolutions causal and length-preserving.
        self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
        self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)

        self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
        self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 1, 0)

        # NOTE(review): in PyTorch, BatchNorm ``momentum`` is the weight of the
        # *new* batch statistic (running = (1 - m) * running + m * batch), so
        # 0.99 makes the running stats follow almost only the latest batch.
        # A Keras-style 0.99 would be 0.01 here -- confirm which was intended.
        self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
        self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)

        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

        self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)

    def forward(self, x, mask_pad):
        """Run the adapter.

        Args:
            x: encoder output of shape (B, T, enc_out_dim).
            mask_pad: boolean padding mask of shape (B, 1, T) or (B, T).
                Frames where the mask is ``False`` are zeroed before the
                convolutions.

        Returns:
            ``(out, mask_pad)``: ``out`` has shape (B, T, llm_embed_dim);
            ``mask_pad`` is returned unchanged (the time axis is not resampled).
        """
        x = x.transpose(1, 2)  # B, channels, T

        # Mask batch padding. Out-of-place on purpose: ``x`` is a view of the
        # caller's tensor, and an in-place fill would overwrite their data and
        # can break autograd for ops that saved that tensor.
        if mask_pad.size(-1) > 0:  # time > 0
            keep = mask_pad.unsqueeze(1) if mask_pad.dim() == 2 else mask_pad
            x = x.masked_fill(~keep, 0.0)

        x = self.left_padding1(x)
        x = self.conv1d1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.left_padding2(x)
        x = self.conv1d2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = x.transpose(1, 2)  # back to B, T, channels
        x = self.project(x)

        return x, mask_pad

class LinearAdapter(torch.nn.Module):
    """Project encoder features to the LLM embedding width with one linear layer.

    The time axis is untouched, so the padding mask is passed straight through.
    """

    def __init__(self, enc_out_dim: int = 512, llm_embed_dim: int = 4096):
        super().__init__()
        # The (misspelled) attribute name determines the state_dict keys
        # ``adpter.weight`` / ``adpter.bias``; keep it so saved weights still load.
        self.adpter = torch.nn.Linear(enc_out_dim, llm_embed_dim)

    def forward(self, x, mask_pad):
        """Apply the projection to the last dim of ``x``; return ``(projected, mask_pad)``."""
        projected = self.adpter(x)
        return projected, mask_pad

class CNNSubsampling(torch.nn.Module):
    """Causal 1-D CNN adapter that halves the frame rate and projects to the LLM width.

    Two variants are chosen from the sizes:

    * ``enc_out_dim * 4 < llm_embed_dim``: two convolutions (stride 1, then
      stride 2), channels ``enc -> 2*enc -> 4*enc``, each followed by
      BatchNorm + ReLU. ``activation_func`` and ``norm`` are ignored here.
    * otherwise: one stride-2 convolution ``enc -> 2*enc`` followed by the
      selected norm (``'batch'`` or ``'layer'``) and activation (``'gelu'``,
      anything else means ReLU).

    Convolutions are left-padded with ``kernel_size - 1`` frames (zeros, or the
    streaming cache), so outputs depend only on current and past frames.
    """

    def __init__(
        self,
        enc_out_dim: int = 512,
        llm_embed_dim: int = 4096,
        kernel_size: int = 5,
        activation_func: str = 'relu',
        norm: str = 'batch',
    ):
        super().__init__()

        self.kernel_size = kernel_size
        if enc_out_dim * 4 < llm_embed_dim:
            self.left_padding1 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
            self.conv1d1 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 1, 0)
            # NOTE(review): PyTorch BatchNorm momentum weights the *new* batch
            # statistic; 0.99 here equals Keras momentum 0.01 -- confirm intent.
            self.bn1 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
            self.relu1 = nn.ReLU()

            self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
            # Stride 2 halves the number of frames.
            self.conv1d2 = nn.Conv1d(2 * enc_out_dim, 4 * enc_out_dim, kernel_size, 2, 0)
            self.bn2 = nn.BatchNorm1d(4 * enc_out_dim, eps=1e-3, momentum=0.99)
            self.relu2 = nn.ReLU()

            self.project = nn.Linear(4 * enc_out_dim, llm_embed_dim)
            self.cnn_num = 2
        else:
            # Fail fast: otherwise ``self.bn2`` would be left unset and
            # forward() would crash with an AttributeError.
            if norm not in ('batch', 'layer'):
                raise ValueError(
                    f"Unsupported norm {norm!r}; expected 'batch' or 'layer'"
                )
            self.left_padding2 = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
            # Stride 2 halves the number of frames.
            self.conv1d2 = nn.Conv1d(enc_out_dim, 2 * enc_out_dim, kernel_size, 2, 0)
            if norm == 'batch':
                self.bn2 = nn.BatchNorm1d(2 * enc_out_dim, eps=1e-3, momentum=0.99)
            else:
                self.bn2 = nn.LayerNorm(2 * enc_out_dim, eps=1e-3)
            if activation_func == 'gelu':
                self.relu2 = nn.GELU()
            else:
                self.relu2 = nn.ReLU()

            self.project = nn.Linear(2 * enc_out_dim, llm_embed_dim)
            self.cnn_num = 1

    def _last_frames(self, x):
        """Return the trailing ``kernel_size - 1`` time frames of ``x`` (B, C, T).

        Computed from the front so that ``kernel_size == 1`` yields an empty
        slice; ``x[:, :, 1 - k:]`` would be ``x[:, :, 0:]`` (everything) there.
        """
        start = max(0, x.size(2) - (self.kernel_size - 1))
        return x[:, :, start:]

    def forward(self, x, mask_pad, cache=None, return_cache=False):
        """Run the adapter, optionally in streaming mode.

        Args:
            x: encoder output of shape (B, T, enc_out_dim).
            mask_pad: boolean padding mask of shape (B, 1, T) or (B, T).
                Frames where it is ``False`` are zeroed before the convolutions.
            cache: ``None`` for the first (or only) chunk. Otherwise the list
                returned by a previous call with ``return_cache=True``: index 0
                holds the context for ``conv1d2``, index 1 (two-layer variant
                only) the context for ``conv1d1``. The list is updated in place.
            return_cache: when True, also return the updated cache.

        Returns:
            ``(out, mask)`` or ``(out, mask, cache)``. ``out`` has shape
            (B, ceil(T / 2), llm_embed_dim) and ``mask`` is ``mask_pad`` with
            every second frame kept (same rank as the input mask).
        """
        x = x.transpose(1, 2)  # B, channels, T

        # Mask batch padding. Out-of-place so the caller's tensor (which ``x``
        # is a view of) is not overwritten.
        if mask_pad.size(-1) > 0:  # time > 0
            keep = mask_pad.unsqueeze(1) if mask_pad.dim() == 2 else mask_pad
            x = x.masked_fill(~keep, 0.0)

        if self.cnn_num == 2:
            if cache is None:
                x = self.left_padding1(x)
                cache = [None, self._last_frames(x)]
            else:
                x = torch.cat((cache[1], x), dim=2)
                cache[1] = self._last_frames(x)
            x = self.conv1d1(x)
            x = self.bn1(x)
            x = self.relu1(x)

        if cache is None or cache[0] is None:
            x = self.left_padding2(x)
        else:
            x = torch.cat((cache[0], x), dim=2)
        if cache is None:
            cache = [self._last_frames(x)]
        else:
            cache[0] = self._last_frames(x)
        x = self.conv1d2(x)
        if isinstance(self.bn2, nn.LayerNorm):
            # LayerNorm normalizes the last dim, so move channels there and back.
            x = self.bn2(x.transpose(1, 2)).transpose(1, 2)
        else:
            x = self.bn2(x)
        x = self.relu2(x)

        x = x.transpose(1, 2)  # back to B, T, channels
        x = self.project(x)

        # Output frame j ends at input frame 2j, so keep every second mask entry.
        out_mask = mask_pad[..., 0::2]
        if return_cache:
            return x, out_mask, cache
        return x, out_mask
