import argparse
import os
import torch
import yaml
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils.model import get_model, get_vocoder, get_param_num, get_dpp_model, get_pdpp_model
from utils.dpp_tools import to_device, log, dpp_inference, synthesize  
from model.dpp import DPP_model 
from dpp_dataset import Dataset  

from dpp_evaluate import evaluate 
from utils.dpp_tools import get_random_phrase 
from model.diversity_loss import diversity_loss   
from model.diversity_loss import ref_diversity_loss 
from make_sample_batch import make_sample_batch, sample_batch_preprocess
from utils.tools import get_mask_from_lengths  
from model.dpp_helper import DPP_helper

# Module-wide compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _log_synthesized_sample(tts, dpp_model_p, dpp_helper, spp, sample_batch, vocoder,
                            noise_scale, model_config, preprocess_config, logger, step):
    """Synthesize the fixed sample batch and log the result to ``logger``.

    ``dpp_model_p`` is switched to eval mode for the duration of the call and
    switched back to train mode before returning. Everything in between runs
    under ``torch.no_grad()``.

    Args:
        tts: TTS model providing ``encode_text``.
        dpp_model_p: DPP model being trained.
        dpp_helper: ``nn.DataParallel`` wrapper; its ``.module`` provides
            ``inference``, ``adapt`` and ``decode``.
        spp: Predictor taken from ``tts.variance_adaptor.pp`` and handed to
            ``dpp_model_p``.
        sample_batch: Batch built by ``make_sample_batch``, already on device.
        vocoder: Vocoder passed through to ``synthesize``.
        noise_scale: Noise scale forwarded to ``dpp_model_p`` and ``adapt``.
        model_config: Model configuration, passed through to ``synthesize``.
        preprocess_config: Preprocessing configuration; supplies the sampling rate.
        logger: Summary writer that receives the logged values.
        step: Current global training step (used for logging and audio tags).
    """
    dpp_model_p.eval()
    with torch.no_grad():
        texts, text_lens, np_ids, cw_ids, np_nums = sample_batch_preprocess(
            sample_batch[3], sample_batch[4], sample_batch[-2], sample_batch[-1]
        )

        random_phrases = get_random_phrase(np_nums, np_ids, cw_ids)
        x, mask = tts.encode_text(texts, text_lens, sample_batch[5])
        pkernel, _, pitch_vector, pitch_seq, h_seq, h_mask = dpp_model_p(
            *random_phrases, spp, x, mask, noise_scale=noise_scale
        )

        # MAP inference
        # 1 for the first and last entry, 2 for every entry in between.
        # NOTE(review): presumably a per-phrase context width - confirm against DPP_helper.inference.
        cw = [1] + [2] * (len(np_nums) - 2) + [1]
        seq = dpp_helper.module.inference(pkernel, pitch_vector, pitch_seq, np_ids, cw)

        # Generate mel-spectrograms with predicted prosodic features.
        # Only the first two entries of each sequence are adapted and decoded.
        predictions = dpp_helper.module.adapt(
            h_seq[:2], h_mask[:2], duration=None, pitch=seq[:2], noise_scale=noise_scale
        )
        mel_lens, durations, pitches = predictions[-2], predictions[3], predictions[1]
        output = dpp_helper.module.decode(predictions[0], predictions[-1])

        log(logger, step, duration=durations)
        log(logger, step, pitch=pitches)

        wav_vanilla, wav = synthesize(output, mel_lens, vocoder, model_config, preprocess_config)
        sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
        log(
            logger,
            audio=wav_vanilla,
            sampling_rate=sampling_rate,
            tag="Training/step_{}_{}_synthesized".format(step, "wav_vanilla"),
        )
        log(
            logger,
            audio=wav,
            sampling_rate=sampling_rate,
            tag="Training/step_{}_{}_synthesized".format(step, "wav"),
        )

    dpp_model_p.train()


def main(args, configs):
    """Train the pitch DPP model with the reference diversity loss.

    Builds the dataset/loader, loads the TTS model (``train=False``), the DPP
    model with its optimizer, and the vocoder, then iterates over the data
    until the global step reaches ``total_step``. Along the way it logs the
    loss every ``log_step`` steps, synthesizes a fixed sample sentence every
    ``synth_step`` steps and writes a checkpoint every ``save_step`` steps.

    Args:
        args: Parsed command-line namespace. ``train_file`` and ``dpp_step``
            are read here; the namespace is also forwarded to ``get_model``
            and ``get_pdpp_model``.
        configs: Tuple ``(preprocess_config, model_config, train_config)``.

    Returns:
        None, once the final step has been processed. Progress bars and
        TensorBoard writers are closed before returning, also on error.
    """
    print("Prepare training...")

    preprocess_config, model_config, train_config = configs

    # Get dataset
    dataset = Dataset(
        args.train_file, preprocess_config, train_config, sort=False, drop_last=True)

    batch_size = train_config['optimizer']['batch_size']
    # The loader draws group_size * batch_size items at once; the collate
    # function hands them back as an iterable of batches (see the inner loop).
    group_size = 4
    assert group_size * batch_size < len(dataset), \
        "dataset must contain more than group_size * batch_size items"
    loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn)

    # Prepare model
    tts = get_model(args, configs, device, train=False)
    embed, prenet, encoder, spp, variance_adaptor, decoder, mel_linear = (
        tts.embedding,
        tts.prenet,
        tts.encoder,
        tts.variance_adaptor.pp,
        tts.variance_adaptor,
        tts.decoder,
        tts.mel_linear,
    )

    # DPP helper (wrapped in DataParallel; its methods are reached through .module)
    dpp_helper = DPP_helper(embed, prenet, encoder, variance_adaptor, decoder, mel_linear)
    dpp_helper = nn.DataParallel(dpp_helper)
    # DPP model
    dpp_model_p, optimizer_p = get_pdpp_model(args, configs, device, train=True)
    # Scheduler. The deprecated `verbose` argument is omitted; False was its default.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer_p, gamma=0.999875, last_epoch=-1)
    # Load vocoder
    vocoder = get_vocoder(model_config, device)

    # Init logger
    for p in train_config['path'].values():
        os.makedirs(p, exist_ok=True)
    train_log_path = os.path.join(train_config["path"]["log_path"], "train")
    val_log_path = os.path.join(train_config["path"]["log_path"], "val")
    os.makedirs(train_log_path, exist_ok=True)
    os.makedirs(val_log_path, exist_ok=True)
    train_logger = SummaryWriter(train_log_path)
    # Nothing is written to the validation writer in this function; it is
    # still created (as before) and closed on exit.
    val_logger = SummaryWriter(val_log_path)

    # Training
    step = args.dpp_step + 1
    epoch = 1
    grad_acc_step = train_config["optimizer"]["grad_acc_step"]
    grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
    total_step = train_config["step"]["total_step"]
    log_step = train_config["step"]["log_step"]
    save_step = train_config["step"]["save_step"]
    synth_step = train_config["step"]["synth_step"]
    noise_scale = model_config["variance_predictor"]["noise_scale"]

    # Sample batch for synthesis
    text = "The Fremen are a tribal society of fierce warriors who have adapted to survive their environment, living in cave systems and riding giant sandworms."
    sample_batch = make_sample_batch(text, preprocess_config)
    sample_batch = to_device(sample_batch, device)

    # Loss function
    d_loss = ref_diversity_loss()

    outer_bar = tqdm(total=total_step, desc="Training", position=0)
    outer_bar.n = args.dpp_step
    # This extra increment offsets the final step, which exits before
    # reaching outer_bar.update(1) at the bottom of the loop.
    outer_bar.update()
    inner_bar = None
    try:
        while True:
            inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
            for batchs in loader:
                for batch in batchs:
                    batch = to_device(batch, device)

                    # Get idxs for random phrases
                    random_phrases = get_random_phrase(batch[-3], batch[-2], batch[-1])

                    # Get hidden seqs
                    x, mask = tts.encode_text(batch[3], batch[4], batch[5])
                    # Get kernel and kernel mask
                    pkernel, pkernel_mask, _, _, _, _ = dpp_model_p(
                        *random_phrases, spp, x, mask, noise_scale=noise_scale)
                    # Calculate diversity loss, scaled down because gradients
                    # are accumulated over grad_acc_step iterations.
                    loss = d_loss(pkernel, pkernel_mask, num_cw=2)
                    loss = loss / grad_acc_step

                    loss.backward()

                    if step % grad_acc_step == 0:
                        # Clipping gradients to avoid gradient explosion
                        nn.utils.clip_grad_norm_(dpp_model_p.parameters(), grad_clip_thresh)
                        # Update weights
                        optimizer_p.step()
                        optimizer_p.zero_grad()
                        scheduler.step()

                    if step % log_step == 0:
                        # The reported value is the loss after division by grad_acc_step.
                        loss_value = loss.item()
                        message1 = "Step {}/{}, ".format(step, total_step)
                        message2 = "Total loss: {:.4f} ".format(loss_value)

                        outer_bar.write(message1 + message2)
                        log(train_logger, step, loss=loss_value)

                    if step % synth_step == 0:
                        _log_synthesized_sample(
                            tts, dpp_model_p, dpp_helper, spp, sample_batch, vocoder,
                            noise_scale, model_config, preprocess_config, train_logger, step,
                        )

                    if step % save_step == 0:
                        torch.save(
                            {
                                "pmodel": dpp_model_p.state_dict(),
                                "optimizer_p": optimizer_p.state_dict(),
                            },
                            os.path.join(
                                train_config["path"]["dpp_ckpt_path"],
                                "{}.pth.tar".format(step),
                            ),
                        )

                    # `>=` rather than `==` so that resuming with
                    # dpp_step >= total_step terminates instead of looping forever.
                    if step >= total_step:
                        return
                    step += 1
                    outer_bar.update(1)
                inner_bar.update(1)
            # One bar per epoch: close it before the next epoch creates a new one.
            inner_bar.close()
            epoch += 1
    finally:
        # Release the progress bars and close the TensorBoard writers so that
        # pending events are written out, whether training finished or failed.
        if inner_bar is not None:
            inner_bar.close()
        outer_bar.close()
        train_logger.close()
        val_logger.close()

if __name__ == "__main__":
    # Script entry point: parse the command line, load the three YAML
    # configuration files and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=0)
    parser.add_argument("--dpp_step", type=int, default=0)
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    parser.add_argument(
        "-tf", "--train_file", type=str, required=False, default="filelist/ljs_audio_train_filelist1.txt", 
        help="path to train filelist")

    parser.add_argument(
        "-vf", "--val_file", type=str, required=False, default="filelist/ljs_audio_val_filelist.txt",
        help="path to validation filelist")

    args = parser.parse_args()

    # Read Config
    # Each file is opened in a `with` block so its handle is closed as soon as
    # the YAML has been parsed.
    # NOTE: yaml.FullLoader is meant for trusted files only; these paths come
    # from the command line of whoever launches training.
    with open(args.preprocess_config, "r") as config_file:
        preprocess_config = yaml.load(config_file, Loader=yaml.FullLoader)
    with open(args.model_config, "r") as config_file:
        model_config = yaml.load(config_file, Loader=yaml.FullLoader)
    with open(args.train_config, "r") as config_file:
        train_config = yaml.load(config_file, Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)

    main(args, configs)


