import os
import json

import torch
import numpy as np

import hifigan
from model.dpptts import DPPTTS  

from model.dpp import DPP_model 
from model.pdpp import PDPP_model    

# Module-wide default compute device: the current CUDA device when PyTorch can
# see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_model(args, configs, device, train=False):
    """Build the DPPTTS acoustic model, optionally restoring a checkpoint.

    Args:
        args: Namespace-like object. A truthy ``args.restore_step`` loads
            ``<train_config["path"]["ckpt_path"]>/<restore_step>.pth.tar``;
            a falsy value starts from freshly initialised weights.
        configs: ``(preprocess_config, model_config, train_config)`` tuple.
        device: Device the model is moved to. Checkpoint tensors are mapped
            onto it too, so a checkpoint saved on another device still loads.
        train: When True, also build an AdamW optimizer from
            ``train_config["optimizer"]`` (restoring its state from the
            checkpoint if one was loaded) and put the model in train mode.

    Returns:
        ``(model, optimizer)`` when ``train`` is True; otherwise the model in
        eval mode with ``requires_grad`` disabled on all of its parameters.
    """
    preprocess_config, model_config, train_config = configs

    model = DPPTTS(preprocess_config, model_config).to(device)
    ckpt = None
    if args.restore_step:
        ckpt_path = os.path.join(
            train_config["path"]["ckpt_path"],
            "{}.pth.tar".format(args.restore_step),
        )
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt["model"])

    if train:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=train_config["optimizer"]["init_lr"],
            betas=train_config["optimizer"]["betas"],
            eps=train_config["optimizer"]["eps"],
        )
        if ckpt is not None:
            optimizer.load_state_dict(ckpt["optimizer"])
        model.train()
        return model, optimizer

    model.eval()
    # requires_grad_ is a method: it must be called. Assigning to it would only
    # shadow the method on this instance and leave every parameter trainable.
    model.requires_grad_(False)
    return model


def get_dpp_model(args, configs, device, train=False):
    """Build the DPP model, optionally restoring a checkpoint.

    Args:
        args: Namespace-like object. A truthy ``args.dpp_step`` loads
            ``<train_config["path"]["dpp_ckpt_path"]>/d<dpp_step>.pth.tar``;
            a falsy value starts from freshly initialised weights.
        configs: ``(preprocess_config, model_config, train_config)`` tuple;
            only ``model_config`` and ``train_config`` are used here.
        device: Device the model is moved to. Checkpoint tensors are mapped
            onto it too, so a checkpoint saved on another device still loads.
        train: When True, also build an AdamW optimizer from
            ``train_config["optimizer"]`` (restoring its state from the
            checkpoint if one was loaded) and put the model in train mode.

    Returns:
        ``(model, optimizer)`` when ``train`` is True; otherwise the model in
        eval mode with ``requires_grad`` disabled on all of its parameters.
    """
    _, model_config, train_config = configs

    model = DPP_model(model_config).to(device)
    ckpt = None
    if args.dpp_step:
        ckpt_path = os.path.join(
            train_config["path"]["dpp_ckpt_path"],
            "d{}.pth.tar".format(args.dpp_step),
        )
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt["model"])

    if train:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=train_config["optimizer"]["init_lr"],
            betas=train_config["optimizer"]["betas"],
            eps=train_config["optimizer"]["eps"],
        )
        if ckpt is not None:
            optimizer.load_state_dict(ckpt["optimizer"])
        model.train()
        return model, optimizer

    model.eval()
    # requires_grad_ is a method: it must be called. Assigning to it would only
    # shadow the method on this instance and leave every parameter trainable.
    model.requires_grad_(False)
    return model

def get_pdpp_model(args, configs, device, train=False):
    """Build the PDPP model, optionally restoring a checkpoint.

    Args:
        args: Namespace-like object. A truthy ``args.pdpp_step`` loads
            ``<train_config["path"]["dpp_ckpt_path"]>/<pdpp_step>.pth.tar``
            (weights under ``"pmodel"``, optimizer state under
            ``"optimizer_p"``). A missing or falsy ``pdpp_step`` starts from
            freshly initialised weights.
        configs: ``(preprocess_config, model_config, train_config)`` tuple;
            only ``model_config`` and ``train_config`` are used here.
        device: Device the model is moved to. Checkpoint tensors are mapped
            onto it too, so a checkpoint saved on another device still loads.
        train: When True, also build an AdamW optimizer from
            ``train_config["optimizer"]`` (restoring its state from the
            checkpoint if one was loaded) and put the model in train mode.

    Returns:
        ``(model, optimizer)`` when ``train`` is True; otherwise the model in
        eval mode with ``requires_grad`` disabled on all of its parameters.
    """
    _, model_config, train_config = configs

    model = PDPP_model(model_config).to(device)
    # Gate the restore on the same step that names the checkpoint file, so a
    # set dpp_step with an unset pdpp_step cannot produce "None.pth.tar".
    pdpp_step = getattr(args, "pdpp_step", None)
    ckpt = None
    if pdpp_step:
        ckpt_path = os.path.join(
            train_config["path"]["dpp_ckpt_path"],
            "{}.pth.tar".format(pdpp_step),
        )
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt["pmodel"])

    if train:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=train_config["optimizer"]["init_lr"],
            betas=train_config["optimizer"]["betas"],
            eps=train_config["optimizer"]["eps"],
        )
        if ckpt is not None:
            optimizer.load_state_dict(ckpt["optimizer_p"])
        model.train()
        return model, optimizer

    model.eval()
    # requires_grad_ is a method: it must be called. Assigning to it would only
    # shadow the method on this instance and leave every parameter trainable.
    model.requires_grad_(False)
    return model

def get_param_num(model):
    """Return the total element count over every tensor in ``model.parameters()``.

    Trainable and frozen parameters are counted alike; an empty model yields 0.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total


def get_vocoder(config, device):
    """Load the pretrained vocoder selected by ``config["vocoder"]``.

    Args:
        config: Config dict. ``config["vocoder"]["model"]`` must be
            ``"MelGAN"`` or ``"HiFi-GAN"``, and ``config["vocoder"]["speaker"]``
            must be ``"LJSpeech"`` or ``"universal"``.
        device: Device the vocoder is moved to.

    Returns:
        For MelGAN, the object returned by ``torch.hub.load``; its ``mel2wav``
        module is put in eval mode and moved to ``device``. For HiFi-GAN, a
        ``hifigan.Generator`` in eval mode with weight norm removed, on ``device``.

    Raises:
        ValueError: If the vocoder model or speaker is not a supported value.
    """
    name = config["vocoder"]["model"]
    speaker = config["vocoder"]["speaker"]

    if name == "MelGAN":
        melgan_variants = {"LJSpeech": "linda_johnson", "universal": "multi_speaker"}
        if speaker not in melgan_variants:
            raise ValueError(
                "Unsupported MelGAN speaker {!r}; expected one of {}".format(
                    speaker, sorted(melgan_variants)
                )
            )
        vocoder = torch.hub.load(
            "descriptinc/melgan-neurips", "load_melgan", melgan_variants[speaker]
        )
        vocoder.mel2wav.eval()
        vocoder.mel2wav.to(device)
        return vocoder

    if name == "HiFi-GAN":
        hifigan_ckpts = {
            "LJSpeech": "hifigan/g_02570000",
            "universal": "hifigan/generator_universal.pth.tar",
        }
        # Validate before touching the filesystem so a bad config fails fast.
        if speaker not in hifigan_ckpts:
            raise ValueError(
                "Unsupported HiFi-GAN speaker {!r}; expected one of {}".format(
                    speaker, sorted(hifigan_ckpts)
                )
            )
        with open("hifigan/config.json", "r", encoding="utf-8") as f:
            hifigan_config = hifigan.AttrDict(json.load(f))
        vocoder = hifigan.Generator(hifigan_config)
        # Load onto the CPU first so GPU-saved weights work on any host; the
        # module is moved to `device` below.
        ckpt = torch.load(hifigan_ckpts[speaker], map_location="cpu")
        vocoder.load_state_dict(ckpt["generator"])
        vocoder.eval()
        vocoder.remove_weight_norm()
        vocoder.to(device)
        return vocoder

    raise ValueError(
        "Unsupported vocoder model {!r}; expected 'MelGAN' or 'HiFi-GAN'".format(name)
    )


def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None):
    """Run the vocoder on a batch of mel spectrograms and return int16 audio.

    Args:
        mels: Batch of mel spectrograms fed to the vocoder.
        vocoder: Vocoder matching ``model_config["vocoder"]["model"]``,
            presumably as returned by ``get_vocoder``.
        model_config: Config dict; ``["vocoder"]["model"]`` selects the
            inference path (``"MelGAN"`` or ``"HiFi-GAN"``).
        preprocess_config: Config dict; ``["preprocessing"]["audio"]["max_wav_value"]``
            scales the vocoder's float output to the integer sample range.
        lengths: Optional per-item lengths; waveform ``i`` is truncated to its
            first ``lengths[i]`` samples.

    Returns:
        A list with one 1-D ``int16`` numpy array per batch item.

    Raises:
        ValueError: If the vocoder model name is not supported.
    """
    name = model_config["vocoder"]["model"]
    with torch.no_grad():
        if name == "MelGAN":
            # Dividing by ln(10) rescales natural-log values to log10. Assumes
            # mels are natural-log magnitudes and MelGAN expects log10; confirm.
            wavs = vocoder.inverse(mels / np.log(10))
        elif name == "HiFi-GAN":
            wavs = vocoder(mels).squeeze(1)
        else:
            raise ValueError(
                "Unsupported vocoder model {!r}; expected 'MelGAN' or 'HiFi-GAN'".format(name)
            )

    max_wav_value = preprocess_config["preprocessing"]["audio"]["max_wav_value"]
    int16_info = np.iinfo(np.int16)
    # Clip before casting. A full-scale sample (e.g. 1.0 * 32768) does not fit
    # in int16 and would otherwise wrap around to the opposite extreme.
    scaled = np.clip(
        wavs.cpu().numpy() * max_wav_value, int16_info.min, int16_info.max
    )
    wavs = list(scaled.astype(np.int16))

    if lengths is not None:
        wavs = [wav[: lengths[i]] for i, wav in enumerate(wavs)]

    return wavs
