import torch
import torchaudio
import gc
import argparse
import os
from tqdm import tqdm
import wandb
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler, StyleVDiffusion, StyleVSampler
from audioldm.pipeline import build_model
from audio_data_pytorch import AllTransform
from wav_dataset import WAVDataset
import random
import soundfile as sf
import numpy as np
import laion_clap
# Seed the Python, NumPy and PyTorch random number generators with the same
# fixed value so that repeated runs of this script draw the same random values.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Request deterministic cuDNN kernels and turn off cuDNN's benchmark-based
# algorithm selection, which can otherwise pick different kernels between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Sample rate (Hz) handed to the dataset and used when writing the output WAVs.
SAMPLE_RATE = 16000
# Batch size used by the test DataLoader.
BATCH_SIZE = 1
# Crop length in samples handed to the dataset: 2.56 seconds at SAMPLE_RATE.
NUM_SAMPLES = int(2.56 * SAMPLE_RATE)
# Alternative crop length, currently disabled:
# NUM_SAMPLES = 2 ** 15


def create_model():
    """Build the (untrained) diffusion model used by this script.

    The configuration is a 2D U-Net (``UNetV0``) wrapped in a
    ``DiffusionModel`` that uses ``StyleVDiffusion`` for diffusion and
    ``StyleVSampler`` for sampling. Weights are not loaded here.

    Returns:
        A freshly constructed ``DiffusionModel`` instance.
    """
    model_config = {
        # Network type driven by the diffusion process.
        "net_t": UNetV0,
        # 2D convolutions, since the model works on spectrogram-like inputs.
        "dim": 2,
        # Number of channels going into / coming out of the U-Net.
        "in_channels": 16,
        "out_channels": 8,
        # Per-layer settings; each list has one entry per U-Net layer.
        "channels": [256, 512, 768, 1280, 1280],   # channel width
        "factors": [2, 2, 2, 2, 1],                # down/up-sampling factor
        "items": [2, 2, 2, 2, 2],                  # repeated items
        "attentions": [0, 0, 1, 1, 1],             # attention on/off
        # Size of each attention item.
        "attention_heads": 8,
        "attention_features": 64,
        # Diffusion objective and the matching sampler.
        "diffusion_t": StyleVDiffusion,
        "sampler_t": StyleVSampler,
        # Classifier-free guidance options, currently disabled:
        # "use_embedding_cfg": True,
        # "embedding_max_length": 1,
        # Width of the conditioning embedding consumed by cross-attention.
        "embedding_features": 512,
        # Cross-attention on/off, one entry per U-Net layer.
        "cross_attentions": [0, 0, 1, 1, 1],
    }
    return DiffusionModel(**model_config)

def main():
    """Sample a few outputs from a trained checkpoint and save them as WAV files.

    Steps:
      1. Load the test dataset, the diffusion model, the audio codec and the
         CLAP model.
      2. Restore the diffusion weights from
         ``<checkpoint>/<run_id>/best_test_loss.pt``.
      3. Encode one fixed input/target pair (dataset index 19) with the codec.
      4. For the first few conditioning clips yielded by the test loader,
         compute a CLAP audio embedding, run the sampler from a fixed noise
         tensor, and write the input, conditioning, generated and target
         audio to ``<checkpoint>/<run_id>/test_wavs``.

    Raises:
        SystemExit: if ``--run_id`` was not supplied on the command line.
    """
    args = parse_args()

    # Fail fast with a readable message; without a run id there is no
    # checkpoint directory to read from, and all the model loading below
    # would be wasted work.
    if args.run_id is None:
        raise SystemExit("--run_id is required: it names the checkpoint sub-directory to load.")
    run_id = args.run_id

    # Number of conditioning clips to sample and save.
    num_saved = 5
    # Number of sampler steps per generated clip.
    num_steps = 200

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    test_dataset = WAVDataset(
        path="/mnt/localssd/actionkid/test_speech_non_speech_actionkid_subset_timesteps.json",
        random_crop_size=NUM_SAMPLES,
        sample_rate=SAMPLE_RATE,
        transforms=AllTransform(
            mono=True,
        ),
        is_training=False
    )

    print(f"Dataset length: {len(test_dataset)}")

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=16,
        pin_memory=True,
    )

    model = create_model().to(device)
    audio_codec = build_model().to(device)
    clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base').to(device)
    clap_model.load_ckpt("/mnt/localssd/texture/checkpoints/music_speech_audioset_epoch_15_esc_89.98.pt")
    audio_codec.latent_t_size = 256

    print(f"Run ID: {run_id}")

    checkpoint_path = os.path.join(args.checkpoint, run_id)
    wav_dir = os.path.join(checkpoint_path, 'test_wavs')
    os.makedirs(wav_dir, exist_ok=True)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine,
    # matching the CPU fallback chosen for `device` above.
    checkpoint = torch.load(
        os.path.join(checkpoint_path, 'best_test_loss.pt'),
        map_location=device,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # One fixed input/target pair is reused for every conditioning clip.
    inp, tgt, _ = test_dataset[19]
    # Prepend two singleton axes (presumably batch and channel -- TODO confirm
    # against what audio_codec.vae_encode expects).
    inp, tgt = inp.to(device)[None, None, ...], tgt.to(device)[None, None, ...]
    # The same starting noise is used for every generated sample.
    noise = torch.randn(1, 8, 64, 16, device=device)
    with torch.no_grad():
        inp = audio_codec.vae_encode(inp)
        tgt = audio_codec.vae_encode(tgt)
        # The input and target latents never change inside the loop, so
        # decode them to waveforms once instead of once per iteration.
        input_wav = audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(inp))[0][0]
        target_wav = audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(tgt))[0][0]

    # Build the resampler once; it is applied to the conditioning audio,
    # converting it from 48 kHz to SAMPLE_RATE before it is written to disk.
    resample = torchaudio.transforms.Resample(48000, SAMPLE_RATE)

    progress = tqdm(test_dataloader)
    for i, (_, _, cond) in enumerate(progress):
        # Stop *before* doing any work for clips that will not be saved, so
        # no sampling run is spent on a discarded result.
        if i >= num_saved:
            break

        with torch.no_grad():
            # CLAP embedding of the (un-resampled) conditioning audio, with an
            # extra axis inserted at position 1 before it is given to the sampler.
            cond_embed = torch.from_numpy(
                clap_model.get_audio_embedding_from_data(x=cond.squeeze(1).cpu().numpy())
            ).unsqueeze(1).to(device)

            cond = resample(cond)[0]

            sample = model.sample(inp[0], noise, embedding=cond_embed, num_steps=num_steps)
            generated_wav = audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(sample))[0][0]

            sf.write(os.path.join(wav_dir, f'test_input_sound_{i}.wav'), input_wav, SAMPLE_RATE)
            # NOTE: the misspelling in this file name is kept on purpose so
            # existing outputs and downstream tooling keep matching.
            sf.write(os.path.join(wav_dir, f'test_condiontional_sound_{i}.wav'), cond[0].cpu(), SAMPLE_RATE)
            sf.write(os.path.join(wav_dir, f'test_generated_sound_{i}.wav'), generated_wav, SAMPLE_RATE)
            sf.write(os.path.join(wav_dir, f'test_target_sound_{i}.wav'), target_wav, SAMPLE_RATE)


def parse_args():
    """Parse the command-line options of this script.

    Options:
        --checkpoint: root directory that holds one sub-directory per run.
        --run_id: name of the run sub-directory to load; required, because
            the script cannot locate a checkpoint without it.

    Returns:
        argparse.Namespace with the attributes ``checkpoint`` and ``run_id``.

    Raises:
        SystemExit: raised by argparse when ``--run_id`` is missing or the
            arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint",
        type=str,
        default='/mnt/localssd/texture/checkpoints',
        help="root directory containing one sub-directory per run",
    )
    # Required: omitting it previously surfaced later as an UnboundLocalError
    # instead of a usage message.
    parser.add_argument(
        "--run_id",
        type=str,
        required=True,
        help="name of the run sub-directory under --checkpoint",
    )  # e.g. 'ldm_audio_condition_no_overlap'
    return parser.parse_args()


if __name__ == "__main__":
    main()