import torch
import tqdm
from torchaudio.transforms import MelSpectrogram as MelS
from diffwave.inference import predict as diffwave_predict
from diffwave.params import AttrDict, params as base_params
from combine_f import network_F
from svoice.models.sisnr_loss import cal_loss as sisnr_calc

# Mel-spectrogram front end applied to the separator output before it is fed
# to the DiffWave model. win_length / n_fft / hop_length (384 / 384 / 192)
# match the STFT parameters used in calc_loss.
mel_args = {
    'sample_rate': 8000,
    'win_length': 384,
    'hop_length': 192,
    'n_fft': 384,
    'f_min': 80.0,
    'f_max': 3000.0,
    'n_mels': 64,
    'power': 2.0,
    'normalized': False,
}
mel_t = MelS(**mel_args).cuda()

separator_model_path = ''  # Path to the deterministic separator, referred to as B in the paper.
diffwave_model_path = ''  # Path to the diffusion model, referred to as GM in the paper.
# NOTE(review): torch.load unpickles a whole nn.Module, so only load trusted
# checkpoints. Newer torch releases default to weights_only=True, which may
# reject a pickled module; confirm against the installed torch version.
sep_model = torch.load(separator_model_path).cuda()
FNet = network_F().cuda()
# torch.optim optimizers have no .cuda() method (calling it raised
# AttributeError). Adam allocates its state alongside the parameters it is
# given, and FNet's parameters are already on the GPU.
optimizer = torch.optim.Adam(FNet.parameters(), lr=1e-3)
## Main Function ##
def calc_loss(v, mix, lengths):
    """Compute the SI-SNR training loss for the fusion network ``FNet``.

    Steps:
      1. ``sep_model`` (the deterministic separator B) maps ``mix`` to ``bar_vd``.
      2. ``bar_vd`` is converted to a log-compressed mel spectrogram rescaled
         to [0, 1], its last frame is dropped, and it is vocoded by the
         DiffWave model (GM) into ``bar_vg``.
      3. Both waveforms are transformed with an STFT (n_fft=384, hop=192) and
         fused by ``FNet``.
      4. The fused spectrogram is inverted back to waveforms ``bar_v`` of the
         same length as ``bar_vd`` and scored against ``v`` with ``sisnr_calc``.

    Only ``FNet`` is optimised, so steps 1-2 run without autograd.

    Args:
        v: reference sources, passed as the target to ``sisnr_calc``.
        mix: mixture fed to ``sep_model``.
        lengths: per-example lengths forwarded to ``sisnr_calc``.

    Returns:
        The loss value returned by ``sisnr_calc`` (its other outputs are dropped).
    """
    n_fft = 384
    hop = n_fft // 2
    # B and GM are frozen (the optimizer only holds FNet's parameters), so
    # don't build an autograd graph through them. FNet still gets gradients
    # from its own operations below.
    with torch.no_grad():
        bar_vd = sep_model(mix)
        mels = mel_t(bar_vd)
        # Convert to dB (floored at 1e-5) and offset by -20, then map the
        # [-100, 0] range onto [0, 1].
        mels = 20 * torch.log10(torch.clamp(mels, min=1e-5)) - 20
        est_mels = torch.clamp((mels + 100) / 100, 0.0, 1.0)
        # Drop the last time frame. The 4-way unpack below requires a 4-D
        # tensor; presumably (batch, channels/speakers, n_mels, frames) — confirm
        # against sep_model's output.
        spectrogram_trimmed = est_mels[:, :, :, :-1]
        B, C, nmel, ntime = spectrogram_trimmed.shape
        # Fold batch and channel dims so DiffWave vocodes every source at once.
        enlarged_spectrogram = spectrogram_trimmed.contiguous().view(B * C, nmel, ntime)
        bar_vg, _ = diffwave_predict(enlarged_spectrogram, diffwave_model_path, base_params, fast_sampling=True)

    # Remember the signal length so the inverse STFT reproduces it exactly.
    # Without `length=`, istft drops the tail whenever the length is not a
    # multiple of `hop`, and the estimate no longer matches `v` in shape.
    num_samples = bar_vd.shape[-1]
    # No window is passed, so torch uses a rectangular window; keep the
    # forward and inverse transforms in sync if this is ever changed.
    bar_Vd = torch.stft(bar_vd.reshape(B * C, -1).contiguous(), n_fft, onesided=True,
                        hop_length=hop, return_complex=True)
    bar_Vg = torch.stft(bar_vg.reshape(B * C, -1).contiguous(), n_fft, onesided=True,
                        hop_length=hop, return_complex=True)
    bar_V = FNet(bar_Vd, bar_Vg)
    bar_v = torch.istft(bar_V, n_fft, onesided=True, hop_length=hop,
                        length=num_samples).view(B, C, -1).contiguous()
    loss, max_snr, estimate_source, reorder_estimate_source = sisnr_calc(v, bar_v, lengths)
    return loss


## Training ## 
# Replace with your training data source: any iterable (e.g. a DataLoader)
# yielding (v, mix, lengths) batches of tensors.
Trainset_path = ''
for epoch in range(100):
    # `import tqdm` binds the module, so the progress-bar class is tqdm.tqdm.
    # A new bar is built each epoch because tqdm closes (and disables) a bar
    # once it has been fully iterated; reusing one hides progress after epoch 0.
    progression_bar = tqdm.tqdm(Trainset_path)
    for example in progression_bar:
        optimizer.zero_grad()
        v, mix, lengths = example
        v = v.cuda()
        mix = mix.cuda()
        lengths = lengths.cuda()
        loss = calc_loss(v, mix, lengths)
        loss.backward()
        # Clip FNet's gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(FNet.parameters(), 1.0)
        optimizer.step()
        progression_bar.set_description_str(f"Epoch {epoch}, Loss {loss.item():.2f}")
        