import numpy as np
from torch import nn
import librosa
import torch
import torch.nn.functional as F
from .config import *


class ToLog(nn.Module):
    """Resample a power spectrogram onto log-spaced frequency bins and convert it to dB.

    At construction time the module computes, for each of ``freq_bins`` output
    bins, the fractional position ``fmin * 2 ** (k / bins_per_octave)`` Hz
    expressed in units of ``sample_rate / n_fft``. The neighbouring integer
    indices and their mixing weights are stored as non-persistent buffers, so
    they follow the module across devices but are not written to the state dict.

    Args:
        sample_rate: Sampling rate used to derive the width of one input bin.
        n_fft: FFT size used to derive the width of one input bin.
        freq_bins: Number of log-spaced output bins.
        fmin: Frequency of the first output bin.
        bins_per_octave: Number of output bins per doubling of frequency.
        ref: Reference power; ``10 * log10(max(amin, ref))`` is subtracted.
        amin: Lower clamp applied to the power before taking the logarithm.
        top_db: If not ``None``, the output is clamped to no less than
            ``max - top_db``.
    """

    def __init__(self, sample_rate=SAMPLE_RATE, n_fft=N_FFT, freq_bins=N_CLASS, fmin=F_MIN, bins_per_octave=Q,
                 ref=1.0, amin=1e-10, top_db=80.0):
        super().__init__()
        self.ref = ref
        self.amin = amin
        self.top_db = top_db

        # Fractional input-bin position of every log-spaced output bin.
        hz_per_bin = sample_rate / n_fft
        bin_numbers = torch.arange(0, freq_bins)
        positions = fmin * (2 ** (bin_numbers / bins_per_octave)) / hz_per_bin

        lower = torch.floor(positions).long()
        upper = torch.ceil(positions).long()
        weight_shape = [1, 1, freq_bins]

        # Mixing rule: y_k = y_floor * (k - floor) + y_ceil * (ceil - k).
        # NOTE(review): this gives the *farther* neighbour the larger weight,
        # and both weights are 0 when k is an exact integer - the reverse of
        # textbook linear interpolation. Left untouched because changing it
        # alters the features any existing trained weights were fitted on;
        # confirm before "fixing".
        buffers = {
            "log_idxs_floor": lower,
            "log_idxs_floor_w": (positions - lower).reshape(weight_shape),
            "log_idxs_ceiling": upper,
            "log_idxs_ceiling_w": (upper - positions).reshape(weight_shape),
        }
        for buffer_name, buffer_value in buffers.items():
            self.register_buffer(buffer_name, buffer_value, persistent=False)

    def forward(self, specgram):
        """Return the dB-scaled, log-frequency version of ``specgram``.

        The input is indexed along its third axis and the weights are shaped
        ``[1, 1, freq_bins]``.
        Assumes a 3-D (batch, time, linear-frequency) power spectrogram - TODO confirm.
        """
        from_floor = specgram[:, :, self.log_idxs_floor] * self.log_idxs_floor_w
        from_ceiling = specgram[:, :, self.log_idxs_ceiling] * self.log_idxs_ceiling_w
        return self.power_to_db(from_floor + from_ceiling)

    def power_to_db(self, spec):
        """Convert power values to decibels relative to ``self.ref``.

        Raises:
            librosa.util.exceptions.ParameterError: If ``self.top_db`` is negative.
        """
        floored = torch.clamp(spec, min=self.amin, max=np.inf)
        log_spec = 10.0 * torch.log10(floored)
        log_spec -= 10.0 * np.log10(np.maximum(self.amin, self.ref))

        if self.top_db is None:
            return log_spec
        if self.top_db < 0:
            raise librosa.util.exceptions.ParameterError('top_db must be non-negative')

        # The threshold is relative to the maximum over the whole tensor,
        # not per example.
        floor_db = log_spec.max().item() - self.top_db
        return torch.clamp(log_spec, min=floor_db, max=np.inf)


class MRDCConv(nn.Module):
    """Sum of 1x1 convolutions, each shifted along the last (frequency) axis.

    One ``nn.Conv2d(planes, planes, kernel_size=(1, 1))`` is created per
    harmonic ``h = 1 .. N_har``. The output of the h-th convolution is shifted
    towards lower indices by ``round(Q * log2(h))`` bins before being added,
    so output bin ``k`` combines the input at bins ``k + round(Q * log2(h))``.
    Contributions that would read past the last bin are dropped.

    Args:
        planes: Number of input and output channels.
        Q: Number of frequency bins per octave.
        N_har: Number of harmonics, i.e. number of shifted branches.
    """

    def __init__(self, planes, Q=48, N_har=16):
        super(MRDCConv, self).__init__()
        harmonics = np.arange(1, N_har + 1)
        # `np.int` was removed in NumPy 1.24; the builtin `int` is the type it
        # used to alias, so the resulting dtype is unchanged.
        self.dilation_list = (np.log(harmonics) / np.log(2 ** (1.0 / Q))).round().astype(int)
        self.conv_list = nn.ModuleList(
            [nn.Conv2d(planes, planes, kernel_size=(1, 1)) for _ in self.dilation_list]
        )

    def forward(self, specgram):
        """Apply the shifted-and-summed convolutions.

        Args:
            specgram: Tensor of shape [b x C x T x n_freq].

        Returns:
            Tensor of shape [b x C x T x n_freq].
        """
        # Plain Python ints for F.pad and slicing rather than NumPy scalars.
        dilations = [int(d) for d in self.dilation_list]

        # First branch: pad the high end, then shift, which keeps n_freq
        # unchanged. Its offset is round(Q * log2(1)) == 0.
        y = self.conv_list[0](specgram)
        y = F.pad(y, pad=[0, dilations[0]])
        y = y[:, :, :, dilations[0]:]

        for i in range(1, len(self.conv_list)):
            # Shifted branch is shorter by its offset; add it onto the low bins.
            x = self.conv_list[i](specgram)[:, :, :, dilations[i]:]
            n_freq = x.size()[3]
            y[:, :, :, :n_freq] += x

        return y


class SDConv(nn.Module):
    """2-D convolution dilated along the last (frequency) axis only.

    Args:
        planes: Number of input and output channels.
        kernel_size: ``(time, frequency)`` kernel size.
        freq_dilation: Spacing between kernel taps along the frequency axis.
            Defaults to 48, the value that was previously hard-coded.

    The frequency axis is padded by ``(kernel_size[1] - 1) * freq_dilation // 2``
    on each side, which keeps its length unchanged whenever
    ``(kernel_size[1] - 1) * freq_dilation`` is even. The time axis is neither
    padded nor dilated.
    """

    def __init__(self, planes, kernel_size, freq_dilation=48):
        super(SDConv, self).__init__()
        freq_padding = (kernel_size[1] - 1) * freq_dilation // 2
        self.conv = nn.Conv2d(planes, planes, kernel_size,
                              padding=(0, freq_padding), dilation=(1, freq_dilation))

    def forward(self, x):
        """Apply the dilated convolution to ``x``."""
        return self.conv(x)


class ConvBlock(nn.Module):
    """Conv2d -> ReLU -> dilated conv -> ReLU -> BatchNorm2d.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the block.
        kernel_size: Kernel size of the leading ``nn.Conv2d``.
        d_kernel_size: Kernel size passed to ``SDConv``; unused for ``'MRDC'``.
        padding: Padding of the leading ``nn.Conv2d``.
        type: ``'MRDC'`` selects ``MRDCConv``; any other value selects ``SDConv``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, d_kernel_size, padding, type='SD'):
        super().__init__()
        entry_conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        if type == 'MRDC':
            dilated_conv = MRDCConv(out_channels)
        else:
            dilated_conv = SDConv(out_channels, d_kernel_size)

        # Positions inside the Sequential define the state-dict keys
        # (conv.0, conv.2, conv.4), so the order must stay fixed.
        stages = [entry_conv, nn.ReLU(), dilated_conv, nn.ReLU(), nn.BatchNorm2d(out_channels)]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the block."""
        return self.conv(x)


class HARMOF0(nn.Module):
    """Network mapping a spectrogram to per-bin sigmoid activations.

    The input is squared, passed through ``ToLog``, given a channel axis, and
    fed through one ``'MRDC'`` block, three ``'SD'`` blocks and a 1x1
    convolution with ReLU. A final 1x1 convolution with a sigmoid yields one
    output channel.
    """

    def __init__(self):
        super().__init__()
        self.log_spec = ToLog()

        # (in_channels, out_channels) of the three 'SD' blocks after the 'MRDC' one.
        sd_channels = [(32, 64), (64, 128), (128, 128)]
        trunk = [ConvBlock(1, 32, (3, 3), (0, 0), (1, 1), 'MRDC')]
        trunk += [ConvBlock(c_in, c_out, (3, 3), (1, 3), (1, 1), 'SD') for c_in, c_out in sd_channels]
        trunk += [nn.Conv2d(128, 64, (1, 1)), nn.ReLU()]
        self.model = nn.Sequential(*trunk)

        self.pitch = nn.Sequential(nn.Conv2d(64, 1, (1, 1)), nn.Sigmoid())

    def forward(self, spec):
        """Return the sigmoid activation map for ``spec``.

        ``spec`` is squared element-wise before the log-frequency front end.
        Presumably a magnitude spectrogram of shape (batch, time, n_fft // 2 + 1)
        - TODO confirm against caller.
        """
        power = spec * spec
        features = self.log_spec(power).unsqueeze(1)  # insert the channel axis
        return self.pitch(self.model(features))


# if __name__ == '__main__':
#     x = torch.ones((1, 128, 1025))
#     crepe = HARMOF0()
#     y = crepe(x)
#     print(y.shape)
