import sys
import numpy as np

import torch
import torch.nn as torch_nn
import torch.nn.functional as F


class MaxFeatureMap2D(torch_nn.Module):
    """ Max feature map (MFM) activation.

    Splits the selected dimension into two equal halves and returns the
    element-wise maximum of the two halves, so that dimension shrinks by
    a factor of two.

    Example:
        l_mfm = MaxFeatureMap2D(1)
        data_in = torch.rand([1, 4, 5, 5])
        data_out = l_mfm(data_in)    # shape (1, 2, 5, 5)

    Args:
        max_dim: int, the dimension MFM is applied on (default 1, i.e. the
            channel dimension). Negative values count from the last
            dimension, as in torch.

    Input:
    ------
    data_in: tensor of shape (batch, channel, ...)

    Output:
    -------
    data_out: tensor of shape (batch, channel//2, ...) for max_dim=1

    Note
    ----
    Element k of the first half is compared with element k + size//2 of
    the second half. The program exits via sys.exit(1) when max_dim is out
    of range or the selected dimension has an odd size.
    """

    def __init__(self, max_dim=1):
        super().__init__()
        self.max_dim = max_dim

    def forward(self, inputs):
        shape = list(inputs.size())
        ndim = len(shape)

        # Normalize negative indices the way torch does. Without this, a
        # negative max_dim inserted the split axis at the wrong position
        # and the max was taken over the wrong axis.
        max_dim = self.max_dim + ndim if self.max_dim < 0 else self.max_dim

        if not 0 <= max_dim < ndim:
            print("MaxFeatureMap: maximize on %d dim" % (self.max_dim))
            print("But input has %d dimensions" % (ndim))
            sys.exit(1)
        if shape[max_dim] % 2 != 0:
            print("MaxFeatureMap: maximize on %d dim" % (self.max_dim))
            print("But this dimension has an odd number of data")
            sys.exit(1)

        # (..., D, ...) -> (..., 2, D//2, ...): the new axis of size 2
        # indexes the two halves of the original dimension
        shape[max_dim] //= 2
        shape.insert(max_dim, 2)

        # maximize over the halves; the argmax indices are not needed
        return inputs.view(*shape).max(max_dim).values


# For blstm
class BLSTMLayer(torch_nn.Module):
    """ Wrapper over a single-layer bidirectional LSTM.

    The forward and backward directions each have output_dim // 2 hidden
    units. Their outputs are concatenated, so the output feature size is
    output_dim and the sequence length is kept the same.

    Input tensor:  (batchsize, length, dim_in)
    Output tensor: (batchsize, length, dim_out)

    The program exits via sys.exit(1) if output_dim is odd.
    """

    def __init__(self, input_dim, output_dim):
        super(BLSTMLayer, self).__init__()
        if output_dim % 2 != 0:
            print("Output_dim of BLSTMLayer is {:d}".format(output_dim))
            print("BLSTMLayer expects a layer size of even number")
            sys.exit(1)
        # bi-directional LSTM; it is created with the default
        # batch_first=False (time-major), hence the permutes in forward()
        self.l_blstm = torch_nn.LSTM(input_dim, output_dim // 2,
                                     bidirectional=True)

    def forward(self, x):
        # compact the weights into one contiguous chunk (matters for cuDNN,
        # e.g. after DataParallel replication)
        self.l_blstm.flatten_parameters()
        # permute to (length, batchsize, dim) for the time-major LSTM
        blstm_data, _ = self.l_blstm(x.permute(1, 0, 2))
        # permute it back to (batchsize, length, dim)
        return blstm_data.permute(1, 0, 2)


class LcnnASV(torch_nn.Module):
    """ LCNN + BLSTM classifier.

    The model has four parts, applied in order:
      1. A stack of Conv2d / MaxFeatureMap2D / BatchNorm2d / MaxPool2d
         layers, ending in dropout.
      2. Two BLSTM layers with a residual connection around them.
      3. Averaging over frames.
      4. A linear output layer.

    Args:
        lfcc_dim: int, size of the last (feature) axis of the input.
        class_num: int, number of output units of the final linear layer.
    """

    def __init__(self, lfcc_dim, class_num):
        super(LcnnASV, self).__init__()
        self.class_num = class_num
        print("model:lcnn")

        # Convolution paddings keep the spatial size. Only the four 2x2
        # max-pooling layers shrink it, so the feature axis ends up at
        # lfcc_dim // 16. The last MaxFeatureMap2D leaves 32 channels, so
        # each frame flattens to 32 * (lfcc_dim // 16) values.
        hidden_dim = (lfcc_dim // 16) * 32

        ####
        # create network
        ####
        # 1st part of the classifier
        self.m_transform = torch_nn.Sequential(
            torch_nn.Conv2d(1, 64, [5, 5], 1, padding=[2, 2]),
            MaxFeatureMap2D(),
            torch_nn.MaxPool2d([2, 2], [2, 2]),

            torch_nn.Conv2d(32, 64, [1, 1], 1, padding=[0, 0]),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(32, affine=False),
            torch_nn.Conv2d(32, 96, [3, 3], 1, padding=[1, 1]),
            MaxFeatureMap2D(),

            torch_nn.MaxPool2d([2, 2], [2, 2]),
            torch_nn.BatchNorm2d(48, affine=False),

            torch_nn.Conv2d(48, 96, [1, 1], 1, padding=[0, 0]),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(48, affine=False),
            torch_nn.Conv2d(48, 128, [3, 3], 1, padding=[1, 1]),
            MaxFeatureMap2D(),

            torch_nn.MaxPool2d([2, 2], [2, 2]),

            torch_nn.Conv2d(64, 128, [1, 1], 1, padding=[0, 0]),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(64, affine=False),
            torch_nn.Conv2d(64, 64, [3, 3], 1, padding=[1, 1]),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(32, affine=False),

            torch_nn.Conv2d(32, 64, [1, 1], 1, padding=[0, 0]),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(32, affine=False),
            torch_nn.Conv2d(32, 64, [3, 3], 1, padding=[1, 1]),
            MaxFeatureMap2D(),
            torch_nn.MaxPool2d([2, 2], [2, 2]),

            torch_nn.Dropout(0.7)
        )

        # BLSTM layers applied before pooling over frames
        self.m_before_pooling = torch_nn.Sequential(
            BLSTMLayer(hidden_dim, hidden_dim),
            BLSTMLayer(hidden_dim, hidden_dim)
        )

        # 2nd part of the classifier
        self.m_output_act = torch_nn.Linear(hidden_dim, self.class_num)

    def forward(self, x, mask=None):
        """ Compute output scores and the pooled embedding.

        Args:
            x: tensor (batchsize, 1, length, dim), or (batchsize, length, dim),
               in which case a channel axis is inserted. dim // 16 must equal
               lfcc_dim // 16 so the flattened frames match the BLSTM size.
            mask: unused; accepted for interface compatibility.

        Returns:
            A tuple (output_emb, pooled):
            output_emb: tensor (batchsize, class_num), the output of the
                final linear layer.
            pooled: tensor (batchsize, 32 * (lfcc_dim // 16)), the
                frame-average of the BLSTM output plus its input; this is
                the vector fed to the linear layer.
        """
        #  1. make sure there is a channel axis: (batch, 1, frame, feat)
        if x.dim() == 3:
            x = x.unsqueeze(1)

        #  2. compute hidden features: (batch, channel, frame//N, feat//N),
        #     where N = 16 comes from the four 2x2 max-pooling layers
        hidden_features = self.m_transform(x)

        #  3. (batch, channel, frame//N, feat//N) ->
        #     (batch, frame//N, channel * feat//N)
        hidden_features = hidden_features.permute(0, 2, 1, 3).contiguous()
        batch_size, frame_num = hidden_features.shape[:2]
        hidden_features = hidden_features.view(batch_size, frame_num, -1)

        #  4. BLSTM with a residual connection, then average over frames
        hidden_features_lstm = self.m_before_pooling(hidden_features)
        pooled = (hidden_features_lstm + hidden_features).mean(1)

        #  5. pass through the output layer; pooled is computed once and
        #     reused for both return values
        output_emb = self.m_output_act(pooled)

        return output_emb, pooled
