import torch
import torch.nn.functional as F

from torch import nn
from .CREPE import CREPE
from .HARMOF0 import HARMOF0
from .ByteSep import ByteSep
from .Unet import Unet
from .modules import Wav2Spec, Spec2Wav
from .config import *


class CM(nn.Module):
    """Separation network followed by a pitch estimator.

    ``svs`` (``Unet`` or ``ByteSep``) predicts a magnitude plus a phase mask
    from the input spectrogram. The masked result is resynthesised to a
    waveform, and ``pe`` (``HARMOF0`` or ``CREPE``) predicts pitch from the
    separated output.
    """

    # CREPE consumes fixed-length frames of raw audio. Each frame is centred
    # on a hop by zero-padding half a window on both sides of the waveform.
    _CREPE_WINDOW = 1024

    def __init__(self, in_channel, in_size, hop_length, svs='unet', pe='harmof0'):
        """Build the spectrogram transforms, separation net and pitch estimator.

        Args:
            in_channel: forwarded to the separation network constructor.
            in_size: forwarded to the separation network constructor.
            hop_length: STFT hop in samples; also the stride between CREPE frames.
            svs: separation network, ``'unet'`` or ``'bytesep'`` (case-insensitive).
            pe: pitch estimator, ``'harmof0'`` or ``'crepe'`` (case-insensitive).

        Raises:
            ValueError: if ``svs`` or ``pe`` names an unknown component.
        """
        super().__init__()
        self.hop_length = hop_length
        self.svs_type = svs.lower()
        self.pe_type = pe.lower()
        self.to_spec = Wav2Spec(hop_length, WINDOW_SIZE)
        self.to_wav = Spec2Wav(hop_length, WINDOW_SIZE)

        if self.svs_type == 'unet':
            self.svs = Unet(in_channel, in_size, hop_length)
        elif self.svs_type == 'bytesep':
            self.svs = ByteSep(in_channel, in_size, hop_length)
        else:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError('svs error.')

        if self.pe_type == 'harmof0':
            self.pe = HARMOF0()
        elif self.pe_type == 'crepe':
            self.pe = CREPE('full')
        else:
            raise ValueError('pe error.')

    def forward(self, audio_m, audio_v=None):
        """Separate ``audio_m`` and estimate pitch on the separated output.

        Args:
            audio_m: input waveform batch. The CREPE path slices the
                resynthesised output as ``(batch, samples)``.
            audio_v: optional second waveform batch. When given, its pitch and
                spectrogram are also returned (presumably the reference
                target used for training; confirm with the caller).

        Returns:
            ``(out_audio, pitch_pred)`` when ``audio_v`` is None, otherwise
            ``(out_audio, pitch_pred, v_pitch_pred, out_spec, spec_v)``.
        """
        spec_m, cos_m, sin_m = self.to_spec(audio_m)
        out_spec, mask_cos, mask_sin = self.svs(spec_m)
        # Rotate the input phase by the predicted phase mask, i.e. the complex
        # product (cos_m + i*sin_m) * (mask_cos + i*mask_sin).
        out_cos = cos_m * mask_cos - sin_m * mask_sin
        out_sin = sin_m * mask_cos + cos_m * mask_sin
        out_real, out_imag = out_spec * out_cos, out_spec * out_sin
        out_audio = self.to_wav(out_real, out_imag, audio_m.shape[-1])
        out_spec = out_spec.squeeze(1)

        pitch_pred = self._estimate_pitch(out_audio, out_spec)
        if audio_v is None:
            return out_audio, pitch_pred

        spec_v, _, _ = self.to_spec(audio_v)
        spec_v = spec_v.squeeze(1)
        v_pitch_pred = self._estimate_pitch(audio_v, spec_v)
        return out_audio, pitch_pred, v_pitch_pred, out_spec, spec_v

    def _estimate_pitch(self, audio, spec):
        """Run the configured pitch estimator on one waveform/spectrogram pair.

        CREPE receives raw-audio frames, one per spectrogram step
        (``spec.shape[-2]`` steps). HARMOF0 receives the spectrogram itself,
        and its output has dim 1 squeezed.
        """
        if self.pe_type == 'crepe':
            frames = self._frame_audio(audio, spec.shape[-2], self.hop_length)
            return self.pe(frames)
        return self.pe(spec).squeeze(1)

    @staticmethod
    def _frame_audio(audio, n_steps, hop_length, window=_CREPE_WINDOW):
        """Slice ``audio`` into ``n_steps`` overlapping frames for CREPE.

        The waveform is zero-padded by ``window // 2`` samples on both sides.
        Frame ``i`` covers padded samples
        ``[i * hop_length, i * hop_length + window)``.

        Args:
            audio: 2-D tensor ``(batch, samples)``.
            n_steps: number of frames to extract.
            hop_length: stride between consecutive frames, in samples.
            window: frame length in samples.

        Returns:
            Tensor of shape ``(batch, 1, n_steps, window)`` in torch's default
            floating dtype. It does not alias ``audio``.

        Raises:
            ValueError: if ``audio`` is not 2-D or is too short to yield
                ``n_steps`` full frames.
        """
        if audio.dim() != 2:
            raise ValueError(
                f'expected audio of shape (batch, samples), got {tuple(audio.shape)}'
            )
        half = window // 2
        padded = F.pad(audio, (half, half), 'constant', 0)
        # unfold produces every full window at the given stride in a single
        # vectorised op instead of a per-frame Python copy loop.
        frames = padded.unfold(1, window, hop_length)
        if frames.shape[1] < n_steps:
            raise ValueError(
                f'audio of {audio.shape[-1]} samples yields only {frames.shape[1]} '
                f'frames, {n_steps} requested'
            )
        # The previous implementation copied frames into a torch.zeros buffer,
        # i.e. the default dtype. Keep that dtype and return an owned tensor.
        return (
            frames[:, :n_steps]
            .unsqueeze(1)
            .to(torch.get_default_dtype())
            .contiguous()
        )
