"""
Taken from https://github.com/vincentherrmann/pytorch-wavenet
"""
import os
import os.path
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable, Function
import numpy as np

from src.models.sequence.base import SequenceModule

def mu_law_expansion(data, mu):
    """Invert mu-law companding.

    Computes ``sign(y) * ((1 + mu) ** |y| - 1) / mu`` elementwise, written as
    ``exp(|y| * log(1 + mu)) - 1``. Inputs of magnitude 1 map back to +/-1 and
    0 maps to 0.

    :param data: scalar or array of companded values.
    :param mu: companding parameter (e.g. 255).
    :return: array (or scalar) of expanded values with the same shape as ``data``.
    """
    log_base = np.log(mu + 1)
    growth = np.exp(np.abs(data) * log_base) - 1
    return np.sign(data) * growth / mu

def dilate(x, dilation, init_dilation=1):
    """Re-fold a tensor from dilation ``init_dilation`` to dilation ``dilation``.

    Time steps are moved between the length and batch dimensions: folding by a
    factor ``f = dilation / init_dilation`` sends position ``j`` of batch entry
    ``b`` to batch entry ``(j % f) * N + b`` at position ``j // f``. Neighbouring
    positions of the result are therefore ``f`` steps apart in the input, so a
    plain convolution on the result acts as a dilated convolution. A factor
    below 1 (e.g. ``dilation=1, init_dilation=d``) undoes an earlier fold.

    :param x: Tensor of size (N, C, L), where N is the batch dimension already
        folded by ``init_dilation``.
    :param dilation: Target dilation.
    :param init_dilation: Dilation ``x`` is currently folded at.
    :return: Tensor of size (N * f, C, ceil(L / f)), with ``f`` as above. If L is
        not a multiple of ``f``, ``x`` is first zero-padded at the start so that
        the last time step stays aligned.
    """
    n, c, l = x.size()
    dilation_factor = dilation / init_dilation
    if dilation_factor == 1:
        return x

    # Left-pad with zeros so the length is a multiple of the fold factor.
    new_l = int(np.ceil(l / dilation_factor) * dilation_factor)
    if new_l != l:
        l = new_l
        x = constant_pad_1d(x, new_l)

    l = math.ceil(l * init_dilation / dilation)
    n = math.ceil(n * dilation / init_dilation)

    # Reshape according to dilation.
    x = x.permute(1, 2, 0).contiguous()  # (n, c, l) -> (c, l, n)
    x = x.view(c, l, n)
    x = x.permute(2, 0, 1).contiguous()  # (c, l, n) -> (n, c, l)

    return x


class DilatedQueue:
    """Circular buffer that serves dilated windows of past inputs one step at a time.

    ``WaveNetModel`` creates one queue per layer for ``step``-wise generation.
    Each ``enqueue`` writes a new time step; each ``dequeue`` returns a window
    that ends at the read position and includes earlier steps spaced
    ``dilation`` apart. This lets a dilated convolution run on a single new
    sample without recomputing the whole sequence.

    ``data`` starts as a 2-D ``(num_channels, max_length)`` zero buffer. The
    first ``enqueue`` after construction or ``reset`` expands it to
    ``(batch, num_channels, max_length)``.
    """

    def __init__(self, max_length, data=None, dilation=1, num_deq=1, num_channels=1, dtype=torch.FloatTensor):
        self.in_pos = 0  # slot the next enqueue writes to
        self.out_pos = 0  # read position; advances once per dequeue
        self.num_deq = num_deq
        self.num_channels = num_channels
        self.dilation = dilation
        self.max_length = max_length
        self.dtype = dtype
        # Use `is None`, not `== None`: `data` may be a tensor, and tensor `==`
        # is an elementwise comparison rather than an identity check.
        self.data = self._zeros() if data is None else data

    def _zeros(self):
        # Fresh (num_channels, max_length) zero buffer built with self.dtype.
        return self.dtype(self.num_channels, self.max_length).zero_()

    def enqueue(self, input):
        """Write one time step at ``in_pos`` and advance it (wrapping around).

        :param input: 3-D tensor; assumed (batch, num_channels, 1), since dim 2
            is squeezed before the write.
        """
        assert len(input.shape) == 3
        if len(self.data.shape) == 2:
            # First write since construction/reset: add a batch dimension.
            self.data = self.data.unsqueeze(0).repeat(input.shape[0], 1, 1)
        self.data[:, :, self.in_pos] = input.squeeze(2)
        self.in_pos = (self.in_pos + 1) % self.max_length

    def dequeue(self, num_deq=1, dilation=1):
        """Return ``num_deq`` samples ending at ``out_pos``, ``dilation`` apart, then advance.

        :return: tensor of shape (batch, num_channels, num_deq), oldest sample first.
        """
        #       |
        #  |6|7|8|1|2|3|4|5|
        #         |
        start = self.out_pos - ((num_deq - 1) * dilation)
        if start < 0:
            # The window wraps around the end of the circular buffer.
            t1 = self.data[:, :, start::dilation]
            t2 = self.data[:, :, self.out_pos % dilation:self.out_pos + 1:dilation]
            t = torch.cat((t1, t2), 2)
        else:
            t = self.data[:, :, start:self.out_pos + 1:dilation]

        self.out_pos = (self.out_pos + 1) % self.max_length
        return t

    def reset(self, device):
        """Clear the buffer (back to 2-D zeros on ``device``) and rewind both positions."""
        self.data = self._zeros().to(device)
        self.in_pos = 0
        self.out_pos = 0

def constant_pad_1d(
    input,
    target_size,
):
    """Zero-pad the last dimension of ``input`` on the left up to ``target_size``.

    The padding is added only at the start, so the existing trailing elements
    keep their positions relative to the end of the tensor.

    :param input: tensor whose last dimension is padded.
    :param target_size: desired length of the last dimension.
    :return: padded tensor.
    """
    missing = target_size - input.size(-1)
    return torch.nn.functional.pad(input, (missing, 0), "constant", 0)

class WaveNetModel(SequenceModule):
    """
    A complete WaveNet model.

    The model has three parts:
      - a 1x1 input projection (``start_conv``);
      - ``blocks`` stacks of ``layers`` gated dilated convolutions with
        residual and skip connections;
      - a two-layer 1x1 output head applied to the summed skip connections.

    Args:
        layers (int):               Number of dilated layers in each block; the dilation
                                    doubles per layer (1, 2, 4, ...) and restarts at 1
                                    in every block
        blocks (int):               Number of blocks
        dilation_channels (int):    Output channels of the filter and gate convolutions
        residual_channels (int):    Channels of the residual stream
        skip_channels (int):        Channels of the skip connections
        end_channels (int):         Hidden channels of the output head
        classes (int):              Input feature size (consumed by ``start_conv``) and
                                    number of output channels
        kernel_size (int):          Size of the dilated convolution kernel
        dtype:                      Legacy tensor constructor (e.g. ``torch.FloatTensor``)
                                    used to allocate the per-layer generation queues
        bias (bool):                Whether the start, dilated, residual and skip
                                    convolutions have a bias (the output head always does)
    Shape:
        - Input: ``(N, L_in, classes)``
        - Output: ``(N, L_in - receptive_field, classes)`` when ``L_in > receptive_field``,
          where ``receptive_field = 1 + blocks * (kernel_size - 1) * (2 ** layers - 1)``
    """

    @property
    def d_output(self):
        return self.classes

    def default_state(self, *batch_shape, device=None):
        # Generation state lives in self.dilated_queues; step() resets them when state is None.
        return None

    def __init__(
        self,
        layers=10,
        blocks=4,
        dilation_channels=32,
        residual_channels=32,
        skip_channels=256,
        end_channels=256,
        classes=256,
        kernel_size=2,
        dtype=torch.FloatTensor,
        bias=False,
    ):

        super(WaveNetModel, self).__init__()

        self.layers = layers
        self.blocks = blocks
        self.dilation_channels = dilation_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.end_channels = end_channels
        self.classes = classes
        self.kernel_size = kernel_size
        self.dtype = dtype

        # The input feature size is what start_conv consumes, i.e. `classes`
        # (previously hard-coded to 256, which only matched the default).
        self.d_model = classes

        # build model
        receptive_field = 1
        init_dilation = 1

        self.dilations = []  # (dilation, init_dilation) for each layer, in order
        self.dilated_queues = []  # one DilatedQueue per layer, used by step()
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()

        # 1x1 convolution to create channels
        self.start_conv = nn.Conv1d(in_channels=self.classes,
                                    out_channels=residual_channels,
                                    kernel_size=1,
                                    bias=bias)

        for b in range(blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for i in range(layers):
                # dilations of this layer; init_dilation is the previous layer's
                # dilation (the first layer of a later block un-folds the last one)
                self.dilations.append((new_dilation, init_dilation))

                # dilated queues for fast generation: just enough history for one kernel
                self.dilated_queues.append(DilatedQueue(max_length=(kernel_size - 1) * new_dilation + 1,
                                                        num_channels=residual_channels,
                                                        dilation=new_dilation,
                                                        dtype=dtype))

                # dilated convolutions (dilation is realised by folding with dilate())
                self.filter_convs.append(nn.Conv1d(in_channels=residual_channels,
                                                   out_channels=dilation_channels,
                                                   kernel_size=kernel_size,
                                                   bias=bias))

                self.gate_convs.append(nn.Conv1d(in_channels=residual_channels,
                                                 out_channels=dilation_channels,
                                                 kernel_size=kernel_size,
                                                 bias=bias))

                # 1x1 convolution for residual connection
                self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels,
                                                     out_channels=residual_channels,
                                                     kernel_size=1,
                                                     bias=bias))

                # 1x1 convolution for skip connection
                self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels,
                                                 out_channels=skip_channels,
                                                 kernel_size=1,
                                                 bias=bias))

                receptive_field += additional_scope
                additional_scope *= 2
                init_dilation = new_dilation
                new_dilation *= 2

        self.end_conv_1 = nn.Conv1d(in_channels=skip_channels,
                                  out_channels=end_channels,
                                  kernel_size=1,
                                  bias=True)

        self.end_conv_2 = nn.Conv1d(in_channels=end_channels,
                                    out_channels=classes,
                                    kernel_size=1,
                                    bias=True)

        self.receptive_field = receptive_field

    def wavenet(self, input, dilation_func):
        """Apply the full WaveNet stack to a channels-first tensor.

        Args:
            input: tensor of shape (N, classes, L).
            dilation_func: callable ``(x, dilation, init_dilation, layer_index)``
                returning the (dilated) residual input for that layer, e.g.
                ``wavenet_dilate`` for whole sequences or ``queue_dilate`` for stepping.

        Returns:
            Tensor with ``classes`` channels (dim 1), computed from the summed skip
            connections.
        """
        x = self.start_conv(input)
        skip = None

        # WaveNet layers
        for i in range(self.blocks * self.layers):

            #            |----------------------------------------|     *residual*
            #            |                                        |
            #            |    |-- conv -- tanh --|                |
            # -> dilate -|----|                  * ----|-- 1x1 -- + -->  *input*
            #                 |-- conv -- sigm --|     |
            #                                         1x1
            #                                          |
            # ---------------------------------------> + ------------->  *skip*

            (dilation, init_dilation) = self.dilations[i]

            residual = dilation_func(x, dilation, init_dilation, i)

            # gated dilated convolution
            filter_out = torch.tanh(self.filter_convs[i](residual))
            gate_out = torch.sigmoid(self.gate_convs[i](residual))
            x = filter_out * gate_out

            # parametrized skip connection; un-fold to dilation 1 so every layer's
            # skip output has the same batch layout
            s = x
            if x.size(2) != 1:
                s = dilate(x, 1, init_dilation=dilation)
            s = self.skip_convs[i](s)
            if skip is None:
                skip = s
            else:
                # Later layers are shorter; keep only the time steps they cover.
                skip = s + skip[:, :, -s.size(2):]

            x = self.residual_convs[i](x)
            x = x + residual[:, :, (self.kernel_size - 1):]

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)

        return x

    def wavenet_dilate(self, input, dilation, init_dilation, i):
        """Full-sequence dilation: re-fold ``input`` with ``dilate`` (``i`` is unused)."""
        x = dilate(input, dilation, init_dilation)
        return x

    def queue_dilate(self, input, dilation, init_dilation, i):
        """Step-wise dilation: push ``input`` into layer ``i``'s queue and read back
        ``kernel_size`` samples spaced ``dilation`` apart (``init_dilation`` is unused)."""
        queue = self.dilated_queues[i]
        queue.enqueue(input)
        x = queue.dequeue(num_deq=self.kernel_size,
                          dilation=dilation)

        return x

    def forward(self, input, state=None):
        """Run the model on whole sequences.

        Args:
            input: tensor of shape (N, L_in, classes).
            state: ignored.

        Returns:
            ``(output, None)`` where output is (N, L_in - receptive_field, classes)
            when ``L_in > receptive_field``.
        """
        # BLD -> BDL
        input = input.transpose(1, 2).contiguous()

        x = self.wavenet(
            input,
            dilation_func=self.wavenet_dilate,
        )

        # BDL -> BLD, then keep the last (L_in - receptive_field) time steps
        x = x.transpose(1, 2).contiguous()
        # NOTE(review): when L_in == receptive_field the index below is -0, which keeps
        # the whole output instead of none; confirm this is intended.
        x = x[:, -(input.shape[2] - self.receptive_field):]
        return x, None

    def step(self, x, state=None):
        """Advance generation by one time step using the per-layer dilated queues.

        Args:
            x: one time step. A 1-D ``x`` is reshaped to (len, 1, 1) and a 2-D ``x``
                (presumably (B, classes)) to (B, 1, classes); otherwise it is used as is.
            state: ``None`` to start a new sequence (all queues are zeroed on
                ``x.device``); otherwise the value returned by the previous call.

        Returns:
            ``(output, self.dilated_queues)`` with output of shape (B, 1, classes).
        """
        if len(x.shape) == 1:
            x = x.unsqueeze(1).unsqueeze(1)
        elif len(x.shape) == 2:
            x = x.unsqueeze(1)

        if state is None:
            # Reset dilated queues
            for queue in self.dilated_queues:
                queue.reset(device=x.device)

        x = x.transpose(1, 2).contiguous()
        x = self.wavenet(x, dilation_func=self.queue_dilate)
        x = x.transpose(1, 2).contiguous()

        return x, self.dilated_queues

def test_wavenet():
    """Manual smoke test for WaveNetModel.

    Runs the full-sequence forward pass and step-by-step generation on the same
    random input, then prints the largest difference between the two over the
    time steps ``forward`` returns. Uses CUDA when available and otherwise falls
    back to CPU; the step loop runs once per input time step, so it is slow on CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    wavenet = WaveNetModel(
        layers=10,
        blocks=4,
        dilation_channels=32,
        residual_channels=32,
        skip_channels=256,
        end_channels=256,
        classes=256,
        kernel_size=2,
    ).to(device)

    print(wavenet)
    # Count parameters directly instead of relying on a parameter_count() helper.
    print(sum(p.numel() for p in wavenet.parameters()))
    print(wavenet.receptive_field)

    # BLD input: receptive field plus 16 extra steps -> 16 forward outputs.
    x = torch.randn(7, wavenet.receptive_field + 16, wavenet.classes).to(device)
    with torch.no_grad():
        y, _ = wavenet(x)
    print(y.shape)

    with torch.no_grad():
        state = None
        ys = []
        for i in range(x.shape[1]):
            y_i, state = wavenet.step(x[:, i, :], state)
            ys.append(y_i)
        # (L, B, 1, classes) -> (L, B, classes) -> (B, L, classes)
        y_ = torch.stack(ys).squeeze().transpose(0, 1)

    # step() emits one output per input step, while forward() keeps only the last
    # y.shape[1] steps, so compare against the matching tail.
    max_diff = (y - y_[:, -y.shape[1]:]).abs().max().item()
    print("max |forward - step| difference:", max_diff)


if __name__ == "__main__":
    test_wavenet()
