import torch
import torchaudio
import gc
import argparse
import os
from tqdm import tqdm
import wandb
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler, StyleVDiffusion, StyleVSampler
from audioldm.pipeline import build_model
from audio_data_pytorch import AllTransform
# from wav_dataset import WAVDataset
from wav_dataset_adobe import WAVDataset
import soundfile as sf
import random
import laion_clap

# Sample rate (Hz) passed to both WAVDatasets and used when writing wavs.
SAMPLE_RATE = 16000
# Batch size shared by the train and test DataLoaders.
BATCH_SIZE = 32
# Length of each random crop: 2.56 s of audio at SAMPLE_RATE. Presumably chosen
# so the AudioLDM VAE produces 256 latent time steps (main() sets
# audio_codec.latent_t_size = 256) — TODO confirm against the codec hop size.
NUM_SAMPLES = int(2.56 * SAMPLE_RATE)
# Earlier crop length, kept for reference:
# NUM_SAMPLES = 2 ** 15


def create_model():
    """Build the U-Net diffusion model used for latent style transfer.

    The network is a 2-D U-Net (``dim=2``) with 16 input and 8 output
    channels. Presumably the 16 input channels are an 8-channel conditioning
    latent stacked with the 8-channel noisy target latent — TODO confirm
    against ``StyleVDiffusion``. Training and sampling use the custom
    ``StyleVDiffusion`` / ``StyleVSampler`` pair.

    Cross-attention is enabled on the last three levels and attends to a
    512-feature embedding, which ``main()`` supplies from CLAP. The
    classifier-free-guidance options (``use_embedding_cfg``,
    ``embedding_max_length``) are not passed, so they stay at the library
    defaults.

    Returns:
        DiffusionModel: a freshly initialised model.
    """
    # Per-level U-Net layout; each list has one entry per resolution level.
    level_channels = [256, 512, 768, 1280, 1280]     # feature channels
    level_factors = [2, 2, 2, 2, 1]                  # down/up-sampling factor
    level_items = [2, 2, 2, 2, 2]                    # repeated blocks
    level_self_attention = [0, 0, 1, 1, 1]           # self-attention on/off
    level_cross_attention = [0, 0, 1, 1, 1]          # cross-attention on/off

    return DiffusionModel(
        net_t=UNetV0,
        dim=2,
        in_channels=16,
        out_channels=8,
        channels=level_channels,
        factors=level_factors,
        items=level_items,
        attentions=level_self_attention,
        attention_heads=8,
        attention_features=64,
        diffusion_t=StyleVDiffusion,
        sampler_t=StyleVSampler,
        embedding_features=512,
        cross_attentions=level_cross_attention,
    )

def _clap_embed(clap_model, audio, device):
    """Return CLAP audio embeddings for a batch of conditioning clips.

    The clips are squeezed on dim 1 and given to CLAP as a NumPy array. A
    length-1 axis is then inserted at dim 1 so the result can be passed as
    ``embedding=`` to the diffusion model, and the tensor is moved to
    ``device``. Callers run this under ``torch.no_grad()``.
    """
    embedding = clap_model.get_audio_embedding_from_data(x=audio.squeeze(1).cpu().numpy())
    return torch.from_numpy(embedding).unsqueeze(1).to(device)


def _latent_to_waveform(audio_codec, latent):
    """Decode a VAE latent to a waveform and return element ``[0][0]`` of it."""
    return audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(latent))[0][0]


def _log_samples(model, audio_codec, clap_model, test_dataset, device, wav_dir,
                 step, epoch_progress, train_loss):
    """Generate one sample from a random test item, write wavs and log them to W&B.

    Writes four wavs (input, conditioning, generated, target) into ``wav_dir``,
    tagged with ``step``. Logs them together with the running training loss.
    Leaves ``model`` in eval mode; the caller switches it back to train mode.
    """
    model.eval()
    inp, tgt, cue = random.choice(test_dataset)
    # Add two leading axes so the single item matches the (B, 1, ...) layout
    # the training loop builds with unsqueeze(1).
    inp, tgt = inp.to(device)[None, None, ...], tgt.to(device)[None, None, ...]

    # Everything below is inference only. Keep autograd off, so that the
    # 200-step sampler and the decoders build no graph.
    with torch.no_grad():
        inp = audio_codec.vae_encode(inp)
        tgt = audio_codec.vae_encode(tgt)
        cue_embed = _clap_embed(clap_model, cue, device)

        # CLAP receives the cue as loaded. It is resampled from 48 kHz only for
        # writing it next to the other wavs.
        cue = torchaudio.transforms.Resample(48000, SAMPLE_RATE)(cue)

        # Starting noise. Its shape presumably matches one target latent
        # (8 channels, as in create_model's out_channels) — TODO confirm.
        noise = torch.randn(1, 8, 64, 16, device=device)
        sample = model.sample(inp[0], noise, embedding=cue_embed, num_steps=200)

        input_wav = _latent_to_waveform(audio_codec, inp)
        generated_wav = _latent_to_waveform(audio_codec, sample)
        target_wav = _latent_to_waveform(audio_codec, tgt)

    def wav_path(kind):
        return os.path.join(wav_dir, f'test_{kind}_sound_{step}.wav')

    sf.write(wav_path('input'), input_wav, SAMPLE_RATE)
    # NOTE: "condiontional" is the established on-disk file name; kept as-is.
    sf.write(wav_path('condiontional'), cue[0].cpu(), SAMPLE_RATE)
    sf.write(wav_path('generated'), generated_wav, SAMPLE_RATE)
    sf.write(wav_path('target'), target_wav, SAMPLE_RATE)

    wandb.log({
        "step": step,
        "epoch": epoch_progress,
        "loss": train_loss,
        "input_audio": wandb.Audio(wav_path('input'), caption="Input audio", sample_rate=SAMPLE_RATE),
        "conditioned_audio": wandb.Audio(wav_path('condiontional'), caption="Conditioned audio", sample_rate=SAMPLE_RATE),
        "generated_audio": wandb.Audio(wav_path('generated'), caption="Generated audio", sample_rate=SAMPLE_RATE),
        "target_audio": wandb.Audio(wav_path('target'), caption="Target audio", sample_rate=SAMPLE_RATE),
    })


def _evaluate(model, audio_codec, clap_model, test_dataloader, device, epoch):
    """Return the mean diffusion loss over ``test_dataloader``.

    Puts ``model`` in eval mode (the caller restores train mode) and runs
    without gradients. Returns ``float("inf")`` when the loader yields no
    batches, so an empty test set can never count as a new best.
    """
    model.eval()
    test_loss = 0.0
    test_loss_step = 0
    test_avg_loss = float("inf")
    test_progress = tqdm(test_dataloader, ncols=80)
    with torch.no_grad():
        for i, (test_audio, test_target, test_cond) in enumerate(test_progress):
            test_audio = audio_codec.vae_encode(test_audio.unsqueeze(1).to(device))
            test_target = audio_codec.vae_encode(test_target.unsqueeze(1).to(device))
            test_cond_embed = _clap_embed(clap_model, test_cond, device)
            test_loss += model(test_audio, test_target, embedding=test_cond_embed).item()
            test_loss_step += 1
            test_avg_loss = test_loss / test_loss_step
            test_progress.set_postfix(
                test_loss=test_avg_loss,
                epoch=epoch + i / len(test_dataloader),
            )
    return test_avg_loss


def _save_checkpoint(path, epoch, model, optimizer, best_test_loss):
    """Write model/optimizer state, the completed-epoch count and the best test loss to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_test_loss': best_test_loss,
    }, path)


def main():
    """Train the CLAP-conditioned latent diffusion model.

    Input and target audio are VAE-encoded with the AudioLDM codec, and the
    conditioning clip is embedded with CLAP. The diffusion model is trained
    with AdamW. Logging and evaluation happen at fixed points:

    * Every 100 steps the running training loss is logged to W&B.
    * Every 500 steps one generated sample is written to
      ``<checkpoint>/<run_id>/wavs`` and logged.
    * After each epoch the test loss is computed. ``latest.pt`` is always
      saved, ``best_test_loss.pt`` is saved when the test loss improves, and
      both are uploaded to W&B.

    With ``--resume``, training continues from ``latest.pt``. The local copy is
    used when present; otherwise the file is restored from W&B.
    """
    args = parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_dataset = WAVDataset(
        path="/mnt/localssd/actionkid/train_speech_non_speech_actionkid_subset_timesteps.json",
        random_crop_size=NUM_SAMPLES,
        sample_rate=SAMPLE_RATE,
        transforms=AllTransform(
            mono=True,
        ),
        is_training=True
    )

    test_dataset = WAVDataset(
        path="/mnt/localssd/actionkid/test_speech_non_speech_actionkid_subset_timesteps.json",
        random_crop_size=NUM_SAMPLES,
        sample_rate=SAMPLE_RATE,
        transforms=AllTransform(
            mono=True,
        ),
        is_training=False
    )

    print(f"Dataset length: {len(train_dataset)}")

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=32,
        pin_memory=True,
    )

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=32,
        pin_memory=True,
    )

    model = create_model().to(device)
    audio_codec = build_model().to(device)
    clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base').to(device)
    clap_model.load_ckpt("/mnt/localssd/texture/checkpoints/music_speech_audioset_epoch_15_esc_89.98.pt")
    audio_codec.latent_t_size = 256

    optimizer = torch.optim.AdamW(params=list(model.parameters()), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3)

    run_id = args.run_id if args.run_id is not None else wandb.util.generate_id()
    print(f"Run ID: {run_id}")

    wandb.init(project="ldm", resume=args.resume, id=run_id)

    # Per-run directory holding the checkpoints and, under wavs/, the sample audio.
    checkpoint_path = os.path.join(args.checkpoint, run_id)
    wav_dir = os.path.join(checkpoint_path, 'wavs')
    os.makedirs(wav_dir, exist_ok=True)

    start_epoch = 0
    step = 0
    best_test_loss = float("inf")

    if wandb.run.resumed:
        # checkpoint_path is a directory (created just above); resume from the
        # latest.pt file written at the end of every epoch.
        resume_file = os.path.join(checkpoint_path, "latest.pt")
        if not os.path.exists(resume_file):
            # Uploaded files are named relative to args.checkpoint (see the
            # wandb.save call below). wandb.restore downloads the file and
            # returns an open handle; torch.load gets the on-disk path instead.
            restored = wandb.restore(f"{run_id}/latest.pt")
            if restored is None:
                raise FileNotFoundError(f"No latest.pt checkpoint found for run {run_id}")
            restored.close()
            resume_file = restored.name
        checkpoint = torch.load(resume_file, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        step = start_epoch * len(train_dataloader)
        # Checkpoints written before best_test_loss was stored lack the key.
        best_test_loss = checkpoint.get('best_test_loss', best_test_loss)

    model.train()
    for epoch in range(start_epoch, 2500):
        avg_loss = 0
        avg_loss_step = 0
        progress = tqdm(train_dataloader, ncols=80)
        for i, (audio, target, cond) in enumerate(progress):
            optimizer.zero_grad()
            audio = audio.unsqueeze(1).to(device)
            target = target.unsqueeze(1).to(device)
            # The codec and CLAP are frozen feature extractors; only `model` trains.
            with torch.no_grad():
                audio = audio_codec.vae_encode(audio)
                target = audio_codec.vae_encode(target)
                cond_embed = _clap_embed(clap_model, cond, device)
            loss = model(audio, target, embedding=cond_embed)
            avg_loss += loss.item()
            avg_loss_step += 1
            loss.backward()
            optimizer.step()

            epoch_progress = epoch + i / len(train_dataloader)
            progress.set_postfix(
                loss=avg_loss / avg_loss_step,
                epoch=epoch_progress,
            )

            if step % 500 == 0:
                _log_samples(model, audio_codec, clap_model, test_dataset, device, wav_dir,
                             step, epoch_progress, avg_loss / avg_loss_step)
                model.train()

            if step % 100 == 0:
                wandb.log({
                    "step": step,
                    "epoch": epoch_progress,
                    "loss": avg_loss / avg_loss_step,
                })
                # The logged loss is a running mean over the last logging window.
                avg_loss = 0
                avg_loss_step = 0

            step += 1

        test_avg_loss = _evaluate(model, audio_codec, clap_model, test_dataloader, device, epoch)

        # Checkpoints store the number of completed epochs, i.e. the epoch a
        # resumed run starts from.
        completed_epochs = epoch + 1
        if test_avg_loss < best_test_loss:
            best_test_loss = test_avg_loss
            _save_checkpoint(os.path.join(checkpoint_path, "best_test_loss.pt"),
                             completed_epochs, model, optimizer, best_test_loss)
            wandb.log({"best_test_loss": best_test_loss})

        _save_checkpoint(os.path.join(checkpoint_path, "latest.pt"),
                         completed_epochs, model, optimizer, best_test_loss)
        # Upload the checkpoint files themselves (not the directory path). With
        # base_path they are stored as "<run_id>/<name>.pt", the name the
        # resume branch restores.
        wandb.save(os.path.join(checkpoint_path, "*.pt"), base_path=args.checkpoint)
        model.train()


def parse_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is how ``main()`` calls it.

    Returns:
        argparse.Namespace with three attributes:

        * ``checkpoint`` (str): root directory for per-run checkpoint folders.
        * ``resume`` (bool): whether to resume a W&B run.
        * ``run_id`` (str or None): the W&B run id; a fresh id is generated
          when it is None.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default='/mnt/localssd/texture/checkpoints/',
                        help="root directory; each run writes into <checkpoint>/<run_id>/")
    parser.add_argument("--resume", action="store_true",
                        help="resume the W&B run given by --run_id from its latest checkpoint")
    parser.add_argument("--run_id", type=str,
                        help="W&B run id (a new one is generated when omitted)")
    return parser.parse_args(argv)


# Start training only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()