import argparse
import glob
import os
import sys
from pathlib import Path

import pandas as pd
import torch
from scipy import io
from tqdm import tqdm

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from extract_feature.WavLM.WavLM import WavLM, WavLMConfig


def extract_wavlm(model, wavfile, savefile, layer=24):
    """Extract per-layer WavLM representations for one input and save them.

    The waveform tensor is loaded from ``wavfile``, moved to the GPU, run
    through ``model.extract_features`` and the first ``layer`` layer outputs
    are written to ``savefile`` with ``scipy.io.savemat`` under the keys
    ``WavLM1`` .. ``WavLM{layer}``.

    Args:
        model: WavLM model exposing ``cfg.normalize``, ``cfg.encoder_layers``
            and ``extract_features``.
        wavfile: Path of a tensor file readable by ``torch.load``
            (presumably a 16 kHz waveform, judging by the variable name).
        savefile: Destination path handed to ``scipy.io.savemat``.
        layer (int): Number of layer outputs to store; varies from 1 to 24.

    Raises:
        RuntimeError: Re-raised from the forward pass after printing the
            offending file path.
        IndexError: If ``layer`` exceeds the number of layer outputs
            returned by the model.

    NOTE(review): if ``layer_results`` also contains the pre-encoder embedding
    at index 0, ``WavLM{k}`` holds the output of layer ``k - 1`` -- confirm
    against the WavLM implementation in use.
    NOTE(review): depending on the scipy version, ``savemat`` may append
    ``.mat`` to ``savefile`` when it does not already end with it -- confirm
    the on-disk name matches what callers expect.
    """
    wav_input_16khz = torch.load(wavfile).cuda()

    with torch.no_grad():
        if model.cfg.normalize:
            wav_input_16khz = torch.nn.functional.layer_norm(wav_input_16khz, wav_input_16khz.shape)
        model.cuda()
        try:
            # extract_features(...)[0] is the (representation, layer_results) pair.
            _, layer_results = model.extract_features(
                wav_input_16khz,
                output_layer=model.cfg.encoder_layers,
                ret_layer_results=True,
            )[0]
        except RuntimeError:
            print(f"Kernel size can't be greater than actual input size! \n {wavfile}")
            raise

        # layer_results: [(x, z), (x, z), ...] where z is the attention weight.
        layer_reps = [x.transpose(0, 1) for x, _ in layer_results]

    if layer > len(layer_reps):
        raise IndexError(
            f"layer={layer} exceeds the {len(layer_reps)} layer outputs returned by the model"
        )

    # Each entry is expected to be (t, 768) / (t, 1024) after dropping the batch axis.
    features = {
        f"WavLM{idx + 1}": layer_reps[idx].squeeze(dim=0).detach().cpu().numpy()
        for idx in range(layer)
    }
    io.savemat(savefile, features)


@torch.no_grad()
def main(args):
    """Extract WavLM features for every input tensor and save one file per input.

    Inputs come either from the ``trimmed_tensor_path`` column of
    ``args.csvfile`` or, when no CSV is given, from the entries of
    ``args.wavdir``. Inputs whose file name already exists as ``*.pt`` in
    ``args.savedir`` are skipped so an interrupted run can be resumed.

    Args:
        args: Parsed command-line namespace providing ``wavdir``, ``savedir``,
            ``WavLM`` (checkpoint path), ``layer``, ``csvfile`` and ``gpu``.
    """
    wavdir = args.wavdir
    savedir = args.savedir

    # Restrict the visible GPUs before the first CUDA call below.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Load onto the CPU first so loading does not depend on the device the
    # checkpoint was saved from; the model is moved to the GPU afterwards.
    checkpoint = torch.load(args.WavLM, map_location="cpu")
    cfg = WavLMConfig(checkpoint["cfg"])
    model = WavLM(cfg)
    model.load_state_dict(checkpoint["model"])
    model.eval().cuda()

    os.makedirs(savedir, exist_ok=True)

    if args.csvfile is not None:
        df = pd.read_csv(args.csvfile)
        # Non-string cells (e.g. missing values) are skipped.
        file_names = [p for p in df["trimmed_tensor_path"].tolist() if isinstance(p, str)]
        # CSV entries are appended to wavdir verbatim, so wavdir acts as a
        # plain prefix here (it may be empty when the CSV holds full paths).
        wav_paths = [f"{wavdir}{name}" for name in file_names]
    else:
        file_names = os.listdir(wavdir)
        # Join properly so wavdir works with or without a trailing separator.
        wav_paths = [os.path.join(wavdir, name) for name in file_names]

    # NOTE(review): this assumes savemat writes files named exactly "<name>.pt";
    # if the installed scipy appends ".mat", this resume check never matches.
    already_extracted = {Path(p).name for p in glob.glob(f"{savedir}/*.pt")}

    for name, wavfile in tqdm(
        zip(file_names, wav_paths),
        desc="Extracting features...",
        leave=False,
        total=len(file_names),
    ):
        out_name = Path(name).name
        if out_name in already_extracted:
            continue
        extract_wavlm(model, wavfile, f"{savedir}/{out_name}", layer=args.layer)
        torch.cuda.empty_cache()


if __name__ == "__main__":
    # Command-line entry point: parse the options and run the extraction.
    parser = argparse.ArgumentParser()
    # The three path options are mandatory: leaving any of them unset made
    # the run fail later with an unrelated-looking error.
    parser.add_argument(
        "--wavdir",
        type=str,
        required=True,
        help="input directory of waveform tensors (prepended verbatim to CSV paths, so include a trailing '/')",
    )
    parser.add_argument("--savedir", type=str, required=True, help="save directory")
    parser.add_argument("--WavLM", type=str, required=True, help="ckpt of WavLM model")
    parser.add_argument("--layer", type=int, default=24, help="number of layers to save, varies from 1 to 24")
    parser.add_argument(
        "--csvfile",
        type=str,
        default=None,
        help="csv file with a 'trimmed_tensor_path' column listing the inputs",
    )
    parser.add_argument("--gpu", type=str, default="0", help="gpu id")
    args = parser.parse_args()

    main(args)
