import sys, os

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.nn as nn
import numpy as np
import argparse

from tqdm import tqdm
from functools import partial
from argparse import RawTextHelpFormatter
from multiprocessing.pool import ThreadPool

from speaker.models.lstm import LSTMSpeakerEncoder
from speaker.config import SpeakerEncoderConfig
from speaker.utils.audio import AudioProcessor
from speaker.infer import read_json


def get_spk_wavs(dataset_path, output_path):
    """Collect wav file paths from a dataset and prepare the output tree.

    Scans ``dataset_path`` one level deep: ``.wav`` files directly inside it
    and ``.wav`` files inside each of its immediate sub-directories are
    collected. ``output_path`` is created if missing, and a same-named
    sub-directory is created under it for every dataset sub-directory.

    Args:
        dataset_path: Directory to scan for wav files.
        output_path: Directory in which the mirrored folders are created.

    Returns:
        A list of wav paths, each built as ``dataset_path`` + "/" + relative
        part (plain string concatenation, no path normalisation).
    """
    os.makedirs(f"{output_path}", exist_ok=True)
    collected = []
    for entry in os.listdir(dataset_path):
        entry_path = f"{dataset_path}/{entry}"
        if not os.path.isdir(entry_path):
            # Loose file at the dataset root: keep it only if it is a wav.
            if entry.endswith(".wav"):
                collected.append(entry_path)
            continue
        # Sub-directory: mirror it on the output side, then gather its wavs.
        os.makedirs(f"{output_path}/{entry}", exist_ok=True)
        collected.extend(
            f"{entry_path}/{name}"
            for name in os.listdir(entry_path)
            if name.endswith(".wav")
        )
    return collected


def process_wav(
    wav_file, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder
):
    """Compute the speaker embedding of one wav file and save it to disk.

    The wav is loaded and converted to a mel spectrogram by
    ``speaker_encoder_ap``, fed to ``speaker_encoder.compute_embedding`` and
    the squeezed result is written with ``numpy.save``.

    The output location mirrors the input: the first occurrence of
    ``dataset_path`` in ``wav_file`` is swapped for ``output_path`` and a
    trailing ``.wav`` becomes ``.spk``. ``numpy.save`` then appends ``.npy``,
    so ``<dataset>/spk/a.wav`` ends up as ``<output>/spk/a.spk.npy``.

    Args:
        wav_file: Path of the wav file to process.
        dataset_path: Dataset root that ``wav_file`` was listed from.
        output_path: Root directory for the saved embeddings; the target
            sub-directory must already exist.
        args: Object with a boolean ``use_cuda`` attribute.
        speaker_encoder_ap: Audio processor exposing ``sample_rate``,
            ``load_wav`` and ``melspectrogram``.
        speaker_encoder: Model exposing ``compute_embedding``.
    """
    waveform = speaker_encoder_ap.load_wav(wav_file, sr=speaker_encoder_ap.sample_rate)
    spec = speaker_encoder_ap.melspectrogram(waveform)
    # Transposed before being handed to the encoder; presumably this turns
    # (n_mels, frames) into (frames, n_mels) -- TODO confirm with AudioProcessor.
    spec = torch.from_numpy(spec.T)
    if args.use_cuda:
        spec = spec.cuda()
    # Add a leading dimension of size 1 in front of the spectrogram.
    spec = spec.unsqueeze(0)
    embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy()
    embed = embed.squeeze()

    # Swap only the first occurrence of the dataset root: an unbounded replace
    # would also rewrite any later part of the path that happens to contain
    # the same text (e.g. dataset_path="wav" vs. ".../x.wav").
    embed_path = wav_file.replace(dataset_path, output_path, 1)
    # Swap only the trailing extension so that directory or file names that
    # merely contain ".wav" are left untouched.
    if embed_path.endswith(".wav"):
        embed_path = embed_path[: -len(".wav")] + ".spk"
    np.save(embed_path, embed, allow_pickle=False)


def extract_speaker_embeddings(
    wav_files,
    dataset_path,
    output_path,
    args,
    speaker_encoder_ap,
    speaker_encoder,
    concurrency,
):
    """Run ``process_wav`` over every wav file using a thread pool.

    Args:
        wav_files: Sequence of wav paths to process.
        dataset_path: Forwarded unchanged to ``process_wav``.
        output_path: Forwarded unchanged to ``process_wav``.
        args: Forwarded unchanged to ``process_wav``.
        speaker_encoder_ap: Forwarded unchanged to ``process_wav``.
        speaker_encoder: Forwarded unchanged to ``process_wav``.
        concurrency: Number of worker threads handed to ``ThreadPool``.
    """

    def _embed_one(wav_path):
        # Everything except the wav path is identical for every task.
        return process_wav(
            wav_path,
            dataset_path=dataset_path,
            output_path=output_path,
            args=args,
            speaker_encoder_ap=speaker_encoder_ap,
            speaker_encoder=speaker_encoder,
        )

    with ThreadPool(concurrency) as workers:
        results = workers.imap(_embed_one, wav_files)
        # imap is lazy: draining it drives the progress bar and re-raises any
        # exception thrown inside a worker.
        for _ in tqdm(results, total=len(wav_files)):
            pass


class SpkEncoderHelper(nn.Module):
    """Pretrained LSTM speaker encoder wrapped as a module over wav paths.

    The checkpoint and its config are read from ``speaker_pretrain/`` (below
    ``root_path`` when one is given). Calling the module with a list of wav
    paths returns one embedding row per file.

    For reference, this file is also runnable as a script:
        python prepare/preprocess_speaker.py data_svc/waves-16k/ data_svc/speaker
    """

    def __init__(self, root_path=None, use_cuda=None):
        """Load the speaker encoder checkpoint and build the audio processor.

        Args:
            root_path: Optional directory that contains ``speaker_pretrain``.
                When omitted the paths are relative to the working directory.
            use_cuda: Whether to load the encoder with CUDA. ``None`` (the
                default) enables CUDA only when ``torch.cuda.is_available()``
                reports a usable device, so construction also works on
                CPU-only hosts.
        """
        super().__init__()
        # model
        self.model_path = os.path.join("speaker_pretrain", "best_model.pth.tar")
        self.config_path = os.path.join("speaker_pretrain", "config.json")
        if root_path:
            self.model_path = os.path.join(root_path, self.model_path)
            self.config_path = os.path.join(root_path, self.config_path)
        # config
        self.config_dict = read_json(self.config_path)

        # model
        self.config = SpeakerEncoderConfig(self.config_dict)
        self.config.from_dict(self.config_dict)

        model_params = self.config.model_params
        self.speaker_encoder = LSTMSpeakerEncoder(
            model_params["input_dim"],
            model_params["proj_dim"],
            model_params["lstm_dim"],
            model_params["num_lstm_layers"],
        )
        # Previously hard-coded to True, which required a CUDA device.
        if use_cuda is None:
            use_cuda = torch.cuda.is_available()
        self.use_cuda = bool(use_cuda)
        self.speaker_encoder.load_checkpoint(
            self.model_path, eval=True, use_cuda=self.use_cuda
        )
        # preprocess
        self.speaker_encoder_ap = AudioProcessor(**self.config.audio)
        # normalize the input audio level and trim silences
        self.speaker_encoder_ap.do_sound_norm = True
        self.speaker_encoder_ap.do_trim_silence = True

    def forward(self, wav_files, infer=False):
        """Compute one speaker embedding per wav file.

        Args:
            wav_files: Sequence of wav file paths.
            infer: Forwarded as-is to ``speaker_encoder.compute_embedding``.

        Returns:
            A tensor of shape ``(len(wav_files), proj_dim)`` allocated with
            ``torch.zeros`` on the default device; row ``i`` holds the
            embedding of ``wav_files[i]``.
        """
        embeds = torch.zeros(len(wav_files), self.speaker_encoder.proj_dim)
        # Follow the encoder's actual placement instead of the use_cuda flag,
        # so the input always lands on the same device as the weights.
        device = next(self.speaker_encoder.parameters()).device
        for i, wav_file in enumerate(wav_files):
            waveform = self.speaker_encoder_ap.load_wav(
                wav_file, sr=self.speaker_encoder_ap.sample_rate
            )
            spec = self.speaker_encoder_ap.melspectrogram(waveform)
            # Transpose, move to the encoder's device, add a leading dim of 1.
            spec = torch.from_numpy(spec.T).to(device).unsqueeze(0)
            embeds[i] = self.speaker_encoder.compute_embedding(spec, infer=infer)
        return embeds


def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false`` or ``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so the flag needs an explicit parser.

    Args:
        value: The raw option value; an actual ``bool`` is returned unchanged.

    Returns:
        The parsed boolean. The empty string maps to False.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="""Compute embedding vectors for each wav file in a dataset.""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("dataset_path", type=str, help="Path to dataset waves.")
    parser.add_argument(
        "output_path", type=str, help="path for output speaker/speaker_wavs.npy."
    )
    # str2bool instead of bool: "--use_cuda False" must really disable CUDA.
    parser.add_argument(
        "--use_cuda", type=str2bool, help="flag to set cuda.", default=True
    )
    parser.add_argument(
        "-t",
        "--thread_count",
        help="thread count to process, set 0 to use all cpu cores",
        dest="thread_count",
        type=int,
        default=1,
    )
    args = parser.parse_args()
    dataset_path = args.dataset_path
    output_path = args.output_path
    thread_count = args.thread_count

    # Requesting CUDA on a host without it would fail on the first .cuda()
    # call; fall back to CPU and tell the user instead.
    if args.use_cuda and not torch.cuda.is_available():
        print("CUDA is not available; falling back to CPU.", file=sys.stderr)
        args.use_cuda = False

    # model
    args.model_path = os.path.join("speaker_pretrain", "best_model.pth.tar")
    args.config_path = os.path.join("speaker_pretrain", "config.json")
    # config
    config_dict = read_json(args.config_path)

    # model
    config = SpeakerEncoderConfig(config_dict)
    config.from_dict(config_dict)

    speaker_encoder = LSTMSpeakerEncoder(
        config.model_params["input_dim"],
        config.model_params["proj_dim"],
        config.model_params["lstm_dim"],
        config.model_params["num_lstm_layers"],
    )

    speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda)

    # preprocess
    speaker_encoder_ap = AudioProcessor(**config.audio)
    # normalize the input audio level and trim silences
    speaker_encoder_ap.do_sound_norm = True
    speaker_encoder_ap.do_trim_silence = True

    wav_files = get_spk_wavs(dataset_path, output_path)

    # A thread count of 0 means "one worker per CPU core".
    process_num = os.cpu_count() if thread_count == 0 else thread_count

    extract_speaker_embeddings(
        wav_files,
        dataset_path,
        output_path,
        args,
        speaker_encoder_ap,
        speaker_encoder,
        process_num,
    )
