import os
import functools

from evaluation.experiments import separate_slakh_msdm
from main.data import ChunkedSupervisedDataset, assert_is_audio
from main.module_base import Model
from audio_diffusion_pytorch import KarrasSchedule
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import torchaudio


@torch.no_grad()
def generate_track(
    denoise_fn,
    sigmas,
    noises,
    source=None,
    mask=None,
    num_resamples=1,
    s_churn=0.0,
) -> torch.Tensor:
    """Sample from a denoiser along a noise schedule, optionally inpainting.

    Starting from ``sigmas[0] * noises``, the state is advanced through each
    consecutive pair ``(sigmas[i], sigmas[i + 1])`` with a first-order update
    preceded by optional noise injection ("churn"). Where ``mask`` is 1 the
    state is overwritten at every step by a noised copy of ``source``; where
    it is 0 the state is generated freely.

    Args:
        denoise_fn: Callable invoked as ``denoise_fn(x_hat, sigma=sigma_hat)``
            that returns the denoised estimate of ``x_hat``.
        sigmas: 1-D tensor of noise levels, visited in order.
        noises: Initial noise tensor; it defines the shape of the result.
        source: Known signal used for the masked region. Defaults to zeros
            shaped like ``noises``.
        mask: Multiplicative mask (1 = keep ``source``, 0 = generate).
            Defaults to zeros, i.e. unconditional generation.
        num_resamples: Number of update passes per noise level; between
            passes the state is re-noised from ``sigma_next`` back to
            ``sigma``.
        s_churn: Total amount of churn; the per-step factor is
            ``min(s_churn / num_steps, sqrt(2) - 1)``.

    Returns:
        ``mask * source + (1 - mask) * x`` where ``x`` is the final state.
    """
    x = sigmas[0] * noises

    # Initialize default values
    source = torch.zeros_like(x) if source is None else source
    mask = torch.zeros_like(x) if mask is None else mask

    sigmas = sigmas.to(x.device)
    num_steps = len(sigmas) - 1
    if num_steps < 1:
        # A schedule with a single level has no interval to integrate over.
        # Return the merged initial state instead of dividing by zero when
        # computing the churn factor below.
        return mask * source + (1.0 - mask) * x

    gamma = min(s_churn / num_steps, 2**0.5 - 1)

    # Iterate over all timesteps
    for i in tqdm(range(num_steps)):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]

        # Noise source to current noise level (drawn once per level and
        # shared by every resample pass at that level)
        noisy_source = source + sigma * torch.randn_like(source)

        for r in range(num_resamples):
            # Merge noisy source and current x
            x = mask * noisy_source + (1.0 - mask) * x

            # Inject randomness: raise the noise level from sigma to sigma_hat
            sigma_hat = sigma * (gamma + 1)
            x_hat = x + torch.randn_like(x) * (sigma_hat**2 - sigma**2) ** 0.5

            # Compute conditioned derivative
            d = (x_hat - denoise_fn(x_hat, sigma=sigma_hat)) / sigma_hat

            # Update integral
            x = x_hat + d * (sigma_next - sigma_hat)

            # Renoise if not last resample step, so the next pass starts
            # again from noise level sigma
            if r < num_resamples - 1:
                x = x + torch.randn_like(x) * (sigma**2 - sigma_next**2) ** 0.5

    return mask * source + (1.0 - mask) * x

@torch.no_grad()
def generate_inpaint_mask(sources, stem_to_inpaint):
    """Build a keep-mask shaped like ``sources`` with the chosen stems zeroed.

    Args:
        sources: Tensor indexed as ``[:, stem, :]``; only its shape, dtype
            and device are used.
        stem_to_inpaint: Iterable of indices along dimension 1 to clear.

    Returns:
        A tensor of ones like ``sources`` whose entries at
        ``[:, stem, :]`` are 0 for every listed stem.
    """
    keep = torch.ones_like(sources)
    for channel in stem_to_inpaint:
        # Zero marks the stem as "to be generated"
        keep[:, channel, :] = 0.0
    return keep

def main():
    """Inpaint the selected stems of every dataset chunk and save them as WAV.

    For each chunk the script writes, under ``output_dir``:

    * ``separate/<chunk_id>/<name>.wav`` for every stem plus ``mixture``
      (sum of the inpainted stems) and ``gt_mixture`` (sum of the original
      stems);
    * ``sum/<chunk_id>/mixture.wav`` and ``sum/<chunk_id>/gt_mixture.wav``.

    All settings are the local constants at the top of the function.
    Requires a CUDA device (the model and the data are moved with
    ``.cuda()``).
    """
    # --- Experiment configuration ---
    dataset_path = '/nas/datasets/SLAKH/slakh2100/test'
    model_path = 'ckpts/glorious-star-335/epoch=729-valid_loss=0.014.ckpt'
    output_dir = 'output/partial_generating/B/try_reproduce'
    s_churn = 10.0
    num_resamples = 1
    sigma_min = 1e-4
    sigma_max = 20.0
    num_steps = 256
    batch_size = 32
    sample_rate = 22050
    chunk_size = 262144
    stems = ["bass", "drums", "guitar", "piano"]
    stems_to_inpaint = {"bass"}

    os.makedirs(output_dir, exist_ok=True)

    dataset = ChunkedSupervisedDataset(
        audio_dir=dataset_path,
        stems=list(stems),
        sample_rate=sample_rate,
        max_chunk_size=chunk_size,
        min_chunk_size=chunk_size,
    )
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=8)

    # Load model
    model = Model.load_from_checkpoint(model_path).cuda()
    schedule = KarrasSchedule(sigma_min=sigma_min, sigma_max=sigma_max, rho=7)(num_steps, model.device)
    denoise_fn = model.model.diffusion.denoise_fn

    stemidx_to_inpaint = [i for i, s in enumerate(stems) if s in stems_to_inpaint]
    separate_root = os.path.join(output_dir, 'separate')
    sum_root = os.path.join(output_dir, 'sum')

    inpaint_mask = None
    chunk_id = 0
    # Wrapping the loader itself (not an enumerate) lets tqdm show the total.
    for batch_data in tqdm(loader):
        # batch_data holds one tensor per stem; concatenate them along dim 1.
        # NOTE(review): assumes each stem tensor is (batch, 1, length), so
        # channel k of `data` is stems[k] -- confirm against the dataset.
        data = torch.cat([batch_data[k] for k in range(len(stems))], dim=1).cuda()

        # Generate the mask; rebuild it only when the batch shape changes
        # (e.g. a smaller final batch).
        if inpaint_mask is None or inpaint_mask.shape != data.shape:
            inpaint_mask = generate_inpaint_mask(data, stem_to_inpaint=stemidx_to_inpaint)

        inpainted_tracks = generate_track(
            source=data,
            mask=inpaint_mask,
            denoise_fn=denoise_fn,
            sigmas=schedule,
            noises=torch.randn_like(data),
            s_churn=s_churn,
            num_resamples=num_resamples,
        )

        num_samples = inpainted_tracks.shape[0]
        for i in range(num_samples):
            chunk_path_separate = os.path.join(separate_root, str(chunk_id))
            chunk_path_sum = os.path.join(sum_root, str(chunk_id))
            os.makedirs(chunk_path_separate, exist_ok=True)
            os.makedirs(chunk_path_sum, exist_ok=True)

            # Index with [k] (a list) to keep the channel dimension.
            one_track = {stem: inpainted_tracks[i, [k], :] for k, stem in enumerate(stems)}
            one_track['mixture'] = torch.sum(inpainted_tracks[i, :, :], dim=0, keepdim=True)
            one_track['gt_mixture'] = torch.sum(data[i, :, :], dim=0, keepdim=True)

            for stem, separated_track in one_track.items():
                assert_is_audio(separated_track)
                torchaudio.save(
                    os.path.join(chunk_path_separate, '{}.wav'.format(stem)),
                    separated_track.cpu(),
                    sample_rate=sample_rate,
                )

            assert_is_audio(one_track['mixture'])
            torchaudio.save(
                os.path.join(chunk_path_sum, 'mixture.wav'),
                one_track['mixture'].cpu(),
                sample_rate=sample_rate,
            )
            torchaudio.save(
                os.path.join(chunk_path_sum, 'gt_mixture.wav'),
                one_track['gt_mixture'].cpu(),
                sample_rate=sample_rate,
            )

            chunk_id += 1

# Run the experiment only when this file is executed as a script, so that
# importing the module (e.g. to reuse generate_track) has no side effects.
if __name__ == '__main__':
    main()