import argparse
import json
from pathlib import Path
from typing import List

import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional as F
from torchaudio.pipelines import SQUIM_OBJECTIVE
from tqdm import tqdm
import numpy as np


def get_args():
    """Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace: Holds ``gen_wav_path`` and ``output_path`` (both
        converted to ``Path``) plus the float options ``min_duration``
        (default 4.0) and ``max_duration`` (default 10.0).
    """
    cli = argparse.ArgumentParser()

    # Required positional arguments, both parsed as filesystem paths.
    for positional in ("gen_wav_path", "output_path"):
        cli.add_argument(positional, type=Path)

    # Optional duration bounds, parsed as floats.
    for flag, fallback in (("--min_duration", 4.0), ("--max_duration", 10.0)):
        cli.add_argument(flag, type=float, default=fallback)

    return cli.parse_args()


class SquimObjective(nn.Module):
    """Single-branch Speech Quality and Intelligibility Measures (SQUIM) model.

    Chains an encoder, a DPRNN module and exactly one branch, so a forward
    pass yields a single score tensor (one **objective** metric, e.g. STOI,
    PESQ, or SI-SDR) instead of a list with one entry per metric.

    Args:
        encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
        dprnn (torch.nn.Module): DPRNN module to model sequential feature.
        branch (torch.nn.Module): The single branch applied to the DPRNN output to estimate one
            objective metric score.
    """

    def __init__(
        self,
        encoder: nn.Module,
        dprnn: nn.Module,
        branch: nn.Module,
    ):
        super().__init__()
        self.encoder = encoder
        self.dprnn = dprnn
        self.branch = branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of waveforms with the configured branch.

        Args:
            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.

        Returns:
            torch.Tensor: The branch output with dimension 1 squeezed out,
            i.e. `(batch,)` when the branch emits `(batch, 1)`.

        Raises:
            ValueError: If ``x`` is not a 2D tensor.

        Note:
            An all-zero waveform has zero RMS, so the normalisation below
            divides zero by zero and that row becomes NaN.
        """
        if x.ndim != 2:
            raise ValueError(
                f"The input must be a 2D Tensor. Found dimension {x.ndim}."
            )
        # Normalise every waveform independently by 20x its root-mean-square
        # amplitude before encoding.
        rms = torch.mean(x**2, dim=1, keepdim=True) ** 0.5
        x = x / (rms * 20)
        out = self.encoder(x)
        out = self.dprnn(out)
        score = self.branch(out).squeeze(dim=1)
        return score


def get_new_model(model):
    """Wrap *model* as a single-branch ``SquimObjective`` in eval mode.

    The new model reuses ``model.encoder``, ``model.dprnn`` and the branch at
    index 1 of ``model.branches`` (the same module objects, not copies).

    Args:
        model: Object exposing ``encoder``, ``dprnn`` and an indexable
            ``branches`` attribute.

    Returns:
        SquimObjective: The wrapper, already switched to eval mode.

    Raises:
        AssertionError: If the wrapper's branch state dict differs from that
            of ``model.branches[1]`` in any compared key or tensor.
    """
    selected_branch = model.branches[1]
    wrapper = SquimObjective(model.encoder, model.dprnn, selected_branch)
    new_model = wrapper.eval()

    # Sanity-check that the pretrained branch weights were carried over.
    wrapped_state = new_model.branch.state_dict()
    source_state = model.branches[1].state_dict()
    for (wrapped_key, wrapped_tensor), (source_key, source_tensor) in zip(
        wrapped_state.items(), source_state.items()
    ):
        same_entry = wrapped_key == source_key and torch.equal(
            wrapped_tensor, source_tensor
        )
        assert same_entry, "Tensors are not equal"

    return new_model


# Sample rate every clip is resampled to before scoring (presumably the rate
# the SQUIM objective model expects -- confirm against the torchaudio docs).
_TARGET_SAMPLE_RATE = 16000


def _score_audio_files(gen_wav_path):
    """Score every ``*.wav`` file directly under *gen_wav_path*; return one dict per file."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    objective_model = get_new_model(SQUIM_OBJECTIVE.get_model()).to(device)

    # Sorted so the results (and the cache file) come out in a reproducible order.
    audio_lst = sorted(gen_wav_path.glob("*.wav"))
    print("Number of audio files: ", len(audio_lst))

    output = []
    for audio_path in tqdm(audio_lst):
        audio, sr = torchaudio.load(
            audio_path, format="wav", backend="soundfile"
        )
        if sr != _TARGET_SAMPLE_RATE:
            audio = F.resample(audio, sr, _TARGET_SAMPLE_RATE)

        with torch.no_grad():
            pesq_hyp = objective_model(audio.to(device))

        output.append({
            "original_path": str(audio_path),
            "duration": audio.size(1) / _TARGET_SAMPLE_RATE,
            # The model scores each row of `audio` separately; only the score
            # of the first row is kept.
            "pesq": float(pesq_hyp.cpu().numpy()[0]),
        })
    return output


def _load_cached_scores(output_path):
    """Read previously written JSON-lines results, ignoring blank lines."""
    with open(output_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _report_statistics(output, min_duration, max_duration):
    """Print PESQ mean/std over entries whose duration lies strictly inside the bounds."""
    pesq_lst = [
        line["pesq"]
        for line in output
        if min_duration < line["duration"] < max_duration
    ]
    print("Min & Max duration: ", min_duration, max_duration)
    if not pesq_lst:
        # np.mean/np.std of an empty list would only emit a RuntimeWarning
        # and print a meaningless value, so say what happened instead.
        print("No audio within the duration range; mean and std are undefined.")
        return
    print("Mean PESQ: ", np.mean(pesq_lst))
    print("Std PESQ: ", np.std(pesq_lst))


def main(
    gen_wav_path,
    output_path,
    min_duration=4.0,
    max_duration=10.0,
):
    """Score generated audio and print PESQ statistics.

    When *output_path* does not exist, every ``*.wav`` file directly inside
    *gen_wav_path* is scored and the results are written to *output_path* as
    JSON lines (keys ``original_path``, ``duration``, ``pesq``). When it
    already exists, it is used as a cache and no audio is read. In both cases
    the mean and standard deviation of the scores are printed for entries
    with ``min_duration < duration < max_duration`` (both bounds exclusive).

    Args:
        gen_wav_path (Path): Directory searched (non-recursively) for ``*.wav``.
        output_path (Path): JSON-lines results file, created if missing.
        min_duration (float): Exclusive lower duration bound, in seconds.
        max_duration (float): Exclusive upper duration bound, in seconds.
    """
    if output_path.exists():
        output = _load_cached_scores(output_path)
    else:
        output = _score_audio_files(gen_wav_path)
        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(line) + "\n" for line in output)

    _report_statistics(output, min_duration, max_duration)


# Script entry point: parse the CLI arguments and forward them to main().
if __name__ == "__main__":
    cli_args = get_args()
    main(
        gen_wav_path=cli_args.gen_wav_path,
        output_path=cli_args.output_path,
        min_duration=cli_args.min_duration,
        max_duration=cli_args.max_duration,
    )
