import argparse
import os
import torch
import yaml
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils.model import get_model, get_vocoder, get_param_num, get_dpp_model
from utils.tools import to_device , get_mask_from_lengths
from dataset import Dataset   

from dpp_evaluate import evaluate 
from utils.dpp_tools import get_random_phrase 
from model.diversity_loss import diversity_loss   
from model.diversity_loss import ref_diversity_loss 
import math 

# Use the default CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def main(args, configs):
    """Estimate the mean density ("quality") of a prosody predictor on the training set.

    Runs the text front-end of the restored TTS model (embedding -> prenet ->
    encoder) over every utterance in the LJSpeech training filelist. It scores
    the ground-truth prosody feature with the selected predictor's
    ``density_estimation`` method and prints the summed score divided by the
    dataset size.

    Args:
        args: Parsed CLI namespace. ``args.prosody_feature`` selects the
            predictor (``"duration"`` or ``"pitch"``); the whole namespace is
            also forwarded to ``get_model``.
        configs: ``(preprocess_config, model_config, train_config)`` tuple of
            parsed YAML configs.

    Raises:
        ValueError: if ``args.prosody_feature`` is not ``"duration"`` or ``"pitch"``.
    """
    # Validate before any expensive setup. An assert inside the batch loop only
    # fired after the dataset and model were loaded, and is stripped under -O.
    if args.prosody_feature not in ("duration", "pitch"):
        raise ValueError(
            "prosody_feature must be 'duration' or 'pitch', got {!r}".format(
                args.prosody_feature
            )
        )
    use_duration = args.prosody_feature == "duration"

    print("Estimating average density...")

    preprocess_config, model_config, train_config = configs

    # Get dataset
    dataset = Dataset(
        "filelist/ljs_audio_train_filelist.txt", preprocess_config, train_config, sort=True, drop_last=False)

    batch_size = train_config['optimizer']['batch_size']
    group_size = 4
    assert group_size * batch_size < len(dataset)
    # Each loader step draws batch_size * group_size items; the dataset's
    # collate_fn is iterated below as a sequence of sub-batches.
    loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn)

    # Prepare model
    tts = get_model(args, configs, device, train=False)
    embed, prenet = tts.embedding, tts.prenet
    text_encoder = tts.encoder
    # Choose the predictor once instead of branching on every sub-batch.
    predictor = tts.variance_adaptor.dp if use_duration else tts.variance_adaptor.pp

    quality_sum = 0
    # Pure inference: no_grad avoids building an autograd graph per sub-batch.
    with torch.no_grad():
        for batchs in tqdm(loader):
            for batch in batchs:
                batch = to_device(batch, device)
                # NOTE(review): batch[3] / batch[4] are presumably the text token
                # ids and text lengths; confirm against Dataset.collate_fn.
                # NOTE(review): 192 is hard-coded; presumably it must equal the
                # encoder hidden size in model_config. Confirm.
                x = embed(batch[3]) * math.sqrt(192)
                mask = get_mask_from_lengths(batch[4])
                x = prenet(x, mask)
                x = text_encoder(x, batch[4])               # shape = [B,C,T]
                # The predictors receive the inverted mask.
                valid_mask = ~mask.unsqueeze(1)
                if use_duration:
                    quality = predictor.density_estimation(
                        x=x.transpose(-1, -2), x_mask=valid_mask, logw=batch[-1].float())
                else:
                    quality = predictor.density_estimation(x, x_mask=valid_mask, p=batch[-3])
                quality_sum += torch.sum(quality)

    # float() makes the tensor -> Python scalar conversion explicit (and works
    # when the loop never ran and quality_sum is still the int 0).
    mean = float(quality_sum) / len(dataset)

    print("Density estimation result: mean qualtiy: {:.5f}".format(mean))
if __name__ == "__main__":
    # Command-line entry point: parse CLI options, load the three YAML
    # configs and run the density estimation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=0)
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    # Restrict to the features main() supports, so a typo fails immediately
    # with a usage message instead of after the model has been loaded.
    parser.add_argument(
        "-f",
        "--prosody_feature",
        type=str,
        required=True,
        choices=["duration", "pitch"],
        help="which prosody feature?",
    )
    args = parser.parse_args()

    # Read Config. `with` closes each file handle; the originals were leaked.
    with open(args.preprocess_config, "r") as f:
        preprocess_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.model_config, "r") as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.train_config, "r") as f:
        train_config = yaml.load(f, Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)

    main(args, configs)