import torch.nn as nn
from torch.nn.utils import weight_norm

class chop1d(nn.Module):
    """Drop the trailing ``chop_size`` elements along dim 2 of the input.

    Used after a symmetrically padded ``nn.Conv1d`` to remove the extra
    right-hand outputs produced by the right padding. The result is made
    contiguous.

    Args:
        chop_size: number of trailing elements to remove along dim 2.
            ``0`` returns the input unchanged (previously ``x[..., :-0]``
            produced an empty tensor, e.g. for ``kernel_size == 1``).
    """

    def __init__(self, chop_size):
        super(chop1d, self).__init__()
        self.chop_size = chop_size

    def forward(self, x):
        # Indexing below requires at least 3 dims; presumably x is
        # (batch, channels, length) as produced by nn.Conv1d.
        if self.chop_size == 0:
            # `:-0` means `:0`, i.e. an empty slice, so special-case it.
            return x.contiguous()
        return x[:, :, :-self.chop_size].contiguous()

class TemporalBlock(nn.Module):
    """Residual block of two dilated, weight-normalized 1-D convolutions.

    Each convolution is followed by ``chop1d(padding)`` (drop the trailing
    ``padding`` time steps), ReLU and dropout. The block input is added back,
    through a 1x1 convolution when ``in_channels != out_channels``, and the
    sum is passed through ``tanh``.

    Inputs are ``(batch, channels, length)`` tensors, as ``nn.Conv1d``
    expects. The residual addition only lines up when the convolution branch
    preserves the length, i.e. ``stride == 1`` and
    ``padding == (kernel_size - 1) * dilation``. In that configuration,
    output step ``t`` depends only on input steps ``<= t``.

    Args:
        in_channels: channels of the block input.
        out_channels: channels produced by both convolutions.
        kernel_size: kernel size of both convolutions.
        stride: stride of both convolutions (see the length note above).
        dilation: dilation of both convolutions.
        padding: padding applied on both sides of each convolution; the same
            amount is then chopped from the end.
        dropout: dropout probability after each ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(in_channels, out_channels, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chop1 = chop1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(out_channels, out_channels, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chop2 = chop1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chop1, self.relu1, self.dropout1,
                                 self.conv2, self.chop2, self.relu2, self.dropout2)
        # 1x1 conv to match channel counts on the residual path.
        self.downsample = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else None
        # Not used by forward (which applies tanh); kept so existing code that
        # references `block.relu` keeps working.
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.init_weights()

    def init_weights(self):
        """Initialize convolution weights from N(0, 0.01).

        ``conv1``/``conv2`` are wrapped in ``weight_norm``, whose forward
        pre-hook recomputes ``weight`` from ``weight_g`` and ``weight_v``, so
        writing to ``weight`` directly would be discarded on the first forward
        pass. The underlying parameters are initialized instead.
        """
        self._init_weight_normed(self.conv1)
        self._init_weight_normed(self.conv2)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    @staticmethod
    def _init_weight_normed(conv, std=0.01):
        """Make the effective weight of ``conv`` a fresh N(0, std) sample.

        Handles convs wrapped with the legacy ``torch.nn.utils.weight_norm``
        (``weight_g``/``weight_v`` parameters) as well as plain convs.
        """
        if hasattr(conv, "weight_v") and hasattr(conv, "weight_g"):
            v = conv.weight_v.data
            v.normal_(0, std)
            # weight = g * v / ||v|| (norm over all dims except dim 0); setting
            # g = ||v|| makes the effective weight equal to v exactly.
            reduce_dims = tuple(range(1, v.dim()))
            conv.weight_g.data.copy_(v.pow(2).sum(dim=reduce_dims, keepdim=True).sqrt())
            # Keep the cached `weight` attribute consistent until the next
            # forward pass recomputes it.
            conv.weight.data.copy_(v)
        else:
            conv.weight.data.normal_(0, std)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.tanh(out + res)

class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock``s whose dilation doubles at every level.

    Level ``i`` uses dilation ``2 ** i``, stride 1 and padding
    ``(kernel_size - 1) * 2 ** i``. Its input width is ``in_channels`` for
    the first level and ``out_channels[i - 1]`` afterwards.

    Args:
        in_channels: channels of the input sequence.
        out_channels: per-level output channel counts; its length sets the
            number of levels (an empty sequence yields an identity network).
        kernel_size: kernel size shared by every block.
        dropout: dropout probability passed to every block.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        widths = [in_channels] + list(out_channels)
        blocks = []
        for level, (width_in, width_out) in enumerate(zip(widths[:-1], widths[1:])):
            dilation = 2 ** level
            blocks.append(
                TemporalBlock(
                    width_in,
                    width_out,
                    kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every level in order."""
        return self.network(x)