import argparse
import os

import torch
import yaml
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils.model import get_model, get_vocoder, get_param_num
from utils.tools import to_device, log, synth_one_sample
from model import DPPLoss
from dataset import Dataset

from evaluate import evaluate
import numpy as np  

# Compute device shared by the whole script: CUDA when a GPU is visible to
# PyTorch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def main(args, configs):
    """Run the model over ``val3.txt`` and save one ``.npy`` file per batch.

    Builds a ``Dataset`` from ``"val3.txt"``, obtains the model through
    ``get_model(..., train=False)``, runs a gradient-free forward pass for
    every batch (batch size is 1) and stores ``output[0]`` of that pass in
    ``hifi-gan-master/ft_dataset/<basename>.npy``. The output directory is
    created when it does not exist yet.

    Args:
        args: Parsed command-line namespace; passed unchanged to
            ``get_model``.
        configs: Tuple ``(preprocess_config, model_config, train_config)``.
            The whole tuple is passed to ``get_model``; the preprocess and
            train entries are additionally passed to ``Dataset``.

    Raises:
        AssertionError: If the dataset does not contain more than
            ``batch_size * group_size`` samples.
    """
    print("Prepare transforming ...")

    # Only the preprocess and train configs are used directly here; the model
    # config reaches the model through the full ``configs`` tuple below.
    preprocess_config, _, train_config = configs

    # Get dataset
    dataset = Dataset(
        "val3.txt", preprocess_config, train_config, sort=False, drop_last=False
    )
    batch_size = 1
    group_size = 1  # Set this larger than 1 to enable sorting in Dataset
    assert batch_size * group_size < len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

    # Prepare model
    model = get_model(args, configs, device, train=False)

    # Resolve the output directory once and make sure it exists; np.save does
    # not create missing parent directories and would fail on the first file.
    output_dir = "hifi-gan-master" + "/ft_dataset"
    os.makedirs(output_dir, exist_ok=True)

    # Each item yielded by the loader is itself iterated as a collection of
    # batches, so two loops are needed.
    for batchs in tqdm(loader):
        for batch in batchs:
            batch = to_device(batch, device)
            # Forward pass without gradient tracking. The first two batch
            # entries are not fed to the model; batch[0] supplies the name of
            # the output file below.
            with torch.no_grad():
                output = model(*(batch[2:]))
            basename = str(batch[0][0])
            full_path = os.path.join(output_dir, "{}.npy".format(basename))
            # NOTE(review): output[0] is presumably the predicted
            # mel-spectrogram (including the batch dimension) - confirm
            # against the model's return value.
            np.save(full_path, output[0].detach().cpu().numpy())


if __name__ == "__main__":
    # Command-line interface: three mandatory YAML config paths plus an
    # optional integer ``--restore_step`` (default 0). The parsed namespace is
    # handed to main() as-is.
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=0)
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    args = parser.parse_args()

    # Read Config
    # Each file is opened in a ``with`` block so its handle is closed as soon
    # as parsing finishes instead of being left to the garbage collector.
    # NOTE(review): yaml.FullLoader accepts more than plain data tags; only
    # point this script at trusted config files. yaml.safe_load is probably
    # sufficient if the configs are plain mappings - confirm before switching.
    with open(args.preprocess_config, "r") as f:
        preprocess_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.model_config, "r") as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.train_config, "r") as f:
        train_config = yaml.load(f, Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)

    main(args, configs)