from utils.model import get_pdpp_model, get_model, get_dpp_model, get_vocoder
from model.dpp_helper import DPP_helper
from make_sample_batch import make_sample_batch, sample_batch_preprocess
from scipy.io.wavfile import write 
from utils.dpp_tools import to_device, synthesize, get_random_phrase
from utils.tools import synth_samples
from utils.tools import to_device as to_device_tts
from tqdm import tqdm 
import os 
import torch
import argparse
import yaml
from synthesize import get_text
import numpy as np 
import pyworld as pw 
import torch.nn as nn 
from utils.model import vocoder_infer
from scipy.interpolate import interp1d

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def determinant(set):
    """Return the determinant of the pairwise cosine-similarity matrix.

    Every row of ``set`` is treated as one item. Entry ``(i, j)`` of the
    similarity matrix is the cosine similarity between row ``i`` and row
    ``j``; the determinant of that ``(b, b)`` matrix is returned.

    Args:
        set: 2-D tensor of shape ``(b, t)`` with one item per row. The name
            shadows the builtin ``set``; it is kept unchanged so existing
            keyword callers keep working.

    Returns:
        A 0-dim float32 tensor on the CPU.
    """
    cos = nn.CosineSimilarity(dim=-1)
    # Broadcasting (b, 1, t) against (1, b, t) scores all b * b row pairs in a
    # single call instead of a Python double loop of scalar writes. The
    # temporary is O(b * b * t) in size.
    similarity = cos(set.unsqueeze(1), set.unsqueeze(0))
    # Keep the result on the CPU in float32 so the returned tensor's device
    # and dtype do not depend on those of the input.
    similarity = similarity.to(device="cpu", dtype=torch.float32)
    return torch.det(similarity)

def phone_pitch(pitch, duration):
    """Average a frame-level pitch contour over consecutive phoneme spans.

    Zero-valued frames (presumably unvoiced frames) are first filled by
    linear interpolation between the surrounding non-zero frames; frames
    before the first / after the last non-zero frame take that frame's value.
    The filled contour is then cut into consecutive spans of ``duration[i]``
    frames and each span is replaced by its mean.

    Args:
        pitch: 1-D array-like of per-frame pitch values. It is not modified.
        duration: sequence of integer frame counts, one per phoneme.

    Returns:
        1-D float array of length ``len(duration)``. Entries are 0 for
        phonemes whose duration is not positive or whose span lies entirely
        past the end of ``pitch``.
    """
    pitch = np.asarray(pitch)
    nonzero_ids = np.where(pitch != 0)[0]

    if nonzero_ids.size < 2:
        # interp1d cannot be built from fewer than two support points. With a
        # single non-zero frame the fill is that constant; with none the
        # contour is all zeros.
        fill = pitch[nonzero_ids[0]] if nonzero_ids.size else 0.0
        frames = np.full(len(pitch), fill, dtype=np.float64)
    else:
        interp_fn = interp1d(
            nonzero_ids,
            pitch[nonzero_ids],
            fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
            bounds_error=False,
        )
        frames = interp_fn(np.arange(0, len(pitch)))

    # Phoneme-level average.
    # Results go into a separate array: writing them back into ``frames``
    # would overwrite frame values that a later phoneme still has to read
    # whenever zero-length phonemes make the phoneme index overtake the frame
    # position.
    averaged = np.zeros(len(duration), dtype=frames.dtype)
    pos = 0
    for i, d in enumerate(duration):
        if d > 0:
            segment = frames[pos : pos + d]
            if segment.size:
                averaged[i] = segment.mean()
        pos += d

    return averaged

def _load_yaml(path):
    """Parse the YAML file at ``path``, closing the file afterwards."""
    with open(path, "r") as f:
        # NOTE(review): FullLoader is what this script has always used; switch
        # to yaml.safe_load if config files may come from untrusted sources.
        return yaml.load(f, Loader=yaml.FullLoader)


def main(args, configs):
    """Synthesize ten sampled variants of one sentence and print a determinant.

    For the sentence in ``args.text`` the function runs ten rounds. Each round:

    1. encodes the text with the TTS model,
    2. runs the model returned by ``get_pdpp_model`` and feeds the inferred
       sequence back into the hidden states through ``DPP_helper.adapt2``,
    3. runs the model returned by ``get_dpp_model``, rescales and rounds up
       the inferred sequence and uses it as durations to expand and decode
       the hidden sequence,
    4. vocodes the result and writes it to ``sens2/dsamples<round>.wav``.

    The first duration row of every round is collected, and ``determinant``
    of the collected matrix is printed at the end.

    Args:
        args: parsed command-line namespace. Must carry a non-empty ``text``
            attribute; it is also forwarded to the model loaders.
        configs: tuple ``(preprocess_config, model_config, train_config)``.

    Raises:
        ValueError: if ``args.text`` is missing or empty, or if
            ``make_sample_batch`` returns nothing for the text.
    """
    preprocess_config, model_config, _train_config = configs

    # The input sentence used to be read from an undefined global name, which
    # made every run fail with NameError. Check it before loading any model.
    input_text = getattr(args, "text", None)
    if not input_text:
        raise ValueError("no input text given: pass --text or set args.text")

    num_samples = 10
    sampling_rate = 22050
    output_dir = "sens2"

    # Load model
    tts = get_model(args, configs, device, train=False)
    pdpp_model = get_pdpp_model(args, configs, device, train=False)
    dpp_model = get_dpp_model(args, configs, device, train=False)
    sdp = tts.variance_adaptor.dp
    spp = tts.variance_adaptor.pp

    # DPP helper
    dpp_helper = DPP_helper(
        tts.embedding,
        tts.prenet,
        tts.encoder,
        tts.variance_adaptor,
        tts.decoder,
        tts.mel_linear,
    )
    # Load vocoder
    vocoder = get_vocoder(model_config, device)

    # Make sample batch
    sample_batch = make_sample_batch(input_text, preprocess_config)
    if not sample_batch:
        raise ValueError("make_sample_batch produced no batch for the given text")
    sample_batch = to_device(sample_batch, device)
    text, text_lens, np_ids, cw_ids, np_nums = sample_batch_preprocess(
        sample_batch[3], sample_batch[4], sample_batch[-2], sample_batch[-1]
    )

    random_phrases = get_random_phrase(np_nums, np_ids, cw_ids)

    # One row per round, holding that round's first duration sequence.
    features = torch.zeros(num_samples, text_lens[0])

    # scipy's wavfile.write does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for i in range(num_samples):
            x, mask = tts.encode_text(text, text_lens, sample_batch[5])

            pkernel, _, pitch_vector, pitch_seq, _, _ = pdpp_model(
                *random_phrases, spp, x, mask, noise_scale=1.0
            )
            # MAP inference
            cw = [1] + [2] * (len(np_nums) - 2) + [1]
            seq = dpp_helper.inference(pkernel, pitch_vector, pitch_seq, np_ids, cw)
            x = dpp_helper.adapt2(x, mask, pitch=seq[0].expand(3, -1))

            dkernel, _, duration_vector, duration_seq, h_seq, h_mask = dpp_model(
                *random_phrases, sdp, x, mask, noise_scale=0.8
            )
            # MAP inference (the list is rebuilt rather than reused in case
            # the helper modifies it).
            cw = [1] + [2] * (len(np_nums) - 2) + [1]
            seq = dpp_helper.inference(
                dkernel, duration_vector, duration_seq, np_ids, cw, half=True
            )

            # Rescale the first sequence so that its total equals the total of
            # the second one, then round every entry up.
            rate = torch.sum(seq[1]) / torch.sum(seq[0])
            seq[0] = seq[0] * rate
            durations = torch.ceil(seq[:2])
            features[i] = durations[0]

            # Expand hidden sequences with predicted durations
            x, mel_lens, mel_masks = dpp_helper.expand_seq(
                h_seq[:2], durations, mask[:2]
            )
            output = dpp_helper.decode(x, mel_masks)

            _, wav = synthesize(
                output, mel_lens, vocoder, model_config, preprocess_config
            )
            write(
                os.path.join(output_dir, f"dsamples{i+1}.wav"), sampling_rate, wav
            )

    print(determinant(features))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=0)
    parser.add_argument("--dpp_step", type=int, default=0)
    parser.add_argument("--pdpp_step", type=int, default=0)
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="sentence to synthesize",
    )
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )

    args = parser.parse_args()

    # Read Config
    preprocess_config = _load_yaml(args.preprocess_config)
    model_config = _load_yaml(args.model_config)
    train_config = _load_yaml(args.train_config)
    configs = (preprocess_config, model_config, train_config)

    main(args, configs)

