from typing import Optional, Callable, Tuple, Union, List
import os
import json
import torch
import torchaudio
import logging
from omegaconf import DictConfig, OmegaConf

from pado.core.base.dataset import PadoDataset
from pado.data.datasets import register_dataset

from pado.nn.utils import pad_sequence
from pado.tasks.asr.tools.text_preprocess import english_cleaners

__all__ = ["PadoLibriSpeech"]

logger = logging.getLogger("pado")


@register_dataset("PadoLibriSpeech")
class PadoLibriSpeech(PadoDataset):
    """LibriSpeech speech-recognition dataset driven by a cached JSON manifest.

    The manifest (one dict per utterance: file name, speaker/chapter/utterance
    ids, raw and cleaned script, sample rate, frame and channel counts) is read
    from ``<data_dir>/LibriSpeech/pado-<mode>.manifest.json``. When that file
    is missing it is built by scanning the audio tree and then written to disk,
    but only by the process whose ``LOCAL_RANK`` is 0 (or unset).
    """

    # Sub-directories of ``<data_dir>/LibriSpeech`` selectable through ``mode``.
    _SUPPORTED_MODES = (
        "dev-clean", "dev-other", "test-clean", "test-other",
        "train-clean-100", "train-clean-360", "train-other-500",
    )

    def __init__(self,
                 data_dir: str,
                 mode: str = "test-clean",
                 audio_transform: Optional[Callable] = None,
                 text_transform: Optional[Callable] = None,
                 max_seconds: Optional[float] = None,
                 clean_script: bool = True,
                 script_only: bool = False,
                 audio_ext: str = ".flac") -> None:
        """
        :param data_dir: directory that contains the ``LibriSpeech`` folder.
        :param mode: subset name; lower-cased and ``_`` is replaced by ``-`` before validation.
        :param audio_transform: callable applied to the loaded waveform. Required unless ``script_only``.
        :param text_transform: callable applied to the script. Required unless ``script_only``.
        :param max_seconds: utterances whose duration is not below this value are dropped
            (only when positive and not ``script_only``). ``None`` selects 16.7 for
            train subsets and -1 (no filtering) otherwise.
        :param clean_script: if True, each entry's ``script`` is replaced by its cleaned form.
        :param script_only: if True, ``__getitem__`` returns only the script string.
        :param audio_ext: audio file extension; a leading ``.`` is added when missing.
        :raises ValueError: on missing transforms, unsupported mode, missing data
            directory or empty ``audio_ext``.
        :raises RuntimeError: when the manifest is missing and this process is not local rank 0.
        """
        super().__init__()

        if not script_only:
            if audio_transform is None:
                raise ValueError("LibriSpeech dataset requires input audio_transform.")
            if text_transform is None:
                raise ValueError("LibriSpeech dataset requires target text_transform.")

        self.audio_transform = audio_transform
        self.text_transform = text_transform

        mode = mode.lower().replace("_", "-")
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(f"LibriSpeech mode {mode} is not supported.")
        self.mode = mode

        self.data_root = os.path.join(data_dir, "LibriSpeech", mode)
        if not os.path.isdir(self.data_root):
            raise ValueError(f"LibriSpeech {mode} path {self.data_root} does not exist.")

        # An empty extension used to fail with an opaque IndexError on audio_ext[0].
        if not audio_ext:
            raise ValueError("LibriSpeech audio_ext must be a non-empty string.")
        if not audio_ext.startswith("."):
            audio_ext = "." + audio_ext
        self.audio_ext = audio_ext

        manifest_path = os.path.join(data_dir, "LibriSpeech", f"pado-{mode}.manifest.json")
        if os.path.isfile(manifest_path):
            with open(manifest_path, "r", encoding="utf-8") as f:
                manifest = json.load(f)
        else:
            logger.info(f"Creating LibriSpeech {mode} manifest to {manifest_path}.")
            # Environment values are strings: without int() a real LOCAL_RANK of "0"
            # never compared equal to 0, so no process ever built the manifest.
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
            if local_rank == 0:  # local rank not given or is 0.
                manifest = self._build_manifest(self.data_root, self.mode, self.audio_ext)
                with open(manifest_path, "w", encoding="utf-8") as f:
                    json.dump(manifest, f)
            else:
                raise RuntimeError(f"LibriSpeech manifest for {mode} not yet created.")

        # filter by max_seconds, heuristic
        if max_seconds is None:  # auto selection
            max_seconds = 16.7 if ("train" in mode) else -1

        self.max_seconds = max_seconds
        if (max_seconds > 0) and (not script_only):
            manifest = [m for m in manifest if (m["num_frames"] / m["sample_rate"] < max_seconds)]

        manifest = sorted(manifest, key=lambda x: x["num_frames"])  # ascending

        self.clean_script = clean_script
        if self.clean_script:  # pre-cleaning script
            for m in manifest:
                if "clean_script" in m:
                    m["script"] = m["clean_script"]
                else:  # backward compatibility: manifest without a stored cleaned script
                    m["script"] = english_cleaners(m["script"], lowercase=False)

        self.dataset = manifest
        self.script_only = script_only

    @staticmethod
    def _build_manifest(data_root: str, mode: str, audio_ext: str) -> list:
        """
        Scan ``data_root/<speaker>/<chapter>/`` and describe every utterance found.

        :param data_root: directory of one LibriSpeech subset.
        :param mode: subset name, used only in progress log messages.
        :param audio_ext: audio file extension including the leading dot.
        :return: list of dicts with keys ``file_name``, ``speaker``, ``chapter``,
            ``utterance``, ``script``, ``clean_script``, ``sample_rate``,
            ``num_frames`` and ``num_channels``.
        :raises RuntimeError: when a chapter directory has no audio file, does not have
            exactly one ``*.trans.txt`` file, or holds a file whose name does not
            match its speaker/chapter directory.
        """
        manifest = []

        # Sorted for a reproducible manifest; non-directory entries are ignored.
        speaker_ids = sorted(d for d in os.listdir(data_root)
                             if os.path.isdir(os.path.join(data_root, d)))
        for a, speaker in enumerate(speaker_ids):
            speaker_dir = os.path.join(data_root, speaker)
            chapter_ids = sorted(d for d in os.listdir(speaker_dir)
                                 if os.path.isdir(os.path.join(speaker_dir, d)))
            for b, chapter in enumerate(chapter_ids):
                chapter_dir = os.path.join(speaker_dir, chapter)
                logger.info(f"... {mode}: speaker {a}/{len(speaker_ids)}, chapter {b}/{len(chapter_ids)}.")

                # each utterance file has format of:
                # speaker_id/chapter_id/speaker-chapter-utterance.flac
                # ... we only have to save speech name.
                entries = sorted(os.listdir(chapter_dir))
                # Match on the suffix; a substring test would also pick up e.g. "x.flac.bak".
                utterance_files = [e for e in entries if e.endswith(audio_ext)]
                if not utterance_files:
                    raise RuntimeError(f"No {audio_ext} file found in {chapter_dir}.")
                script_files = [e for e in entries if e.endswith(".trans.txt")]
                if len(script_files) != 1:
                    raise RuntimeError(
                        f"Expected exactly one .trans.txt file in {chapter_dir}, found {len(script_files)}.")
                script_path = os.path.join(chapter_dir, script_files[0])

                scripts = {}
                with open(script_path, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if not line:  # tolerate blank lines
                            continue
                        # "672-11235-8774 I HAVE A CAT" -> key, text (text is "" if absent)
                        key, _, text = line.partition(" ")
                        scripts[key] = text

                for u in utterance_files:
                    u_key = u[:-len(audio_ext)]  # 672-11235-8744.flac -> 672-11235-8744
                    u_speaker, u_chapter, u_utterance = u_key.split("-")  # 672, 11235, 8744
                    if (u_speaker != speaker) or (u_chapter != chapter):
                        raise RuntimeError(
                            f"File {u} does not belong to speaker {speaker} / chapter {chapter}.")

                    audio_path = os.path.join(chapter_dir, u)
                    audio_info = torchaudio.info(audio_path)
                    # this line may not work with torchaudio <= 0.7.0.

                    s = scripts[u_key]
                    clean_s = english_cleaners(s.upper(), lowercase=False)

                    manifest.append({
                        "file_name": u,
                        "speaker": int(u_speaker),
                        "chapter": int(u_chapter),
                        "utterance": int(u_utterance),
                        "script": s,  # before cleaning, raw text
                        "clean_script": clean_s,
                        "sample_rate": audio_info.sample_rate,
                        "num_frames": audio_info.num_frames,
                        "num_channels": audio_info.num_channels,
                    })
        return manifest

    def __len__(self):
        """Number of utterances kept after filtering."""
        return len(self.dataset)

    def __getitem__(self, index: int) -> Union[Tuple[torch.Tensor, Union[torch.Tensor, str], str, str], str]:
        """
        :return: only the script string when ``script_only`` is set, otherwise the tuple
            waveform -> audio_transform(waveform)               (1, feature_dim, feature_length)    float
            target -> text_transform(script) -> indices         (token_length)      long
            script
            audio_path (data_root/speaker/chapter/file_name)
        :raises ValueError: when the loaded audio is not sampled at 16 kHz.
        """
        info = self.dataset[index]
        script = info["script"]  # already cleaned in __init__ when clean_script is set

        if self.script_only:
            return script

        audio_path = os.path.join(self.data_root, str(info["speaker"]), str(info["chapter"]), info["file_name"])
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != 16000:
            raise ValueError(f"Expected 16000 Hz audio but {audio_path} has sample rate {sample_rate}.")

        feature = self.audio_transform(waveform)  # (1, wave_length) -> (1, num_filter_banks, num_windows)
        target = self.text_transform(script)  # text -> (num_tokens,)

        return feature, target, script, audio_path

    @staticmethod
    def collate_fn(batch) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
                                   List[str], List[str]]:
        """
        Merge ``[feature, target, script, audio_path]`` samples into one padded batch.

        Not usable with ``script_only`` samples, which are plain strings.

        :return:
            features:           (batch_size, max_feature_length, feature_dim)   float
            feature_length:     (batch_size,)                                   long
            targets:            (batch_size, max_target_length)                 long
            target_length:      (batch_size,)                                   long
            utterances:         list of script strings
            paths:              list of audio paths
        """
        features = [b[0].squeeze(0).transpose(0, 1) for b in batch]
        # feature of LibriSpeech MelSpectrogram have channel at very first, which should be 1.
        # (1, 80, len) -> (80, len) ->  (len, 80)

        assert isinstance(features[0], torch.Tensor)
        # Lengths are recorded before padding so callers can mask the padded positions.
        feature_length = torch.tensor([f.shape[0] for f in features], dtype=torch.long)

        targets = [b[1] for b in batch]
        assert isinstance(targets[0], torch.Tensor)  # should be encoded and wrapped as Tensor.
        target_length = torch.tensor([t.shape[0] for t in targets], dtype=torch.long)

        utterances = [b[2] for b in batch]
        paths = [b[3] for b in batch]
        assert len(features) == len(targets) == len(utterances) == len(paths)  # should be same as batch_size.

        features = pad_sequence(features, output_batch_first=True, padding_value=0.0, pad_to_multiple=1)
        targets = pad_sequence(targets, output_batch_first=True, padding_value=0, pad_to_multiple=1)

        return features, feature_length, targets, target_length, utterances, paths

    @classmethod
    def from_config(cls, cfg: DictConfig, transform=None, target_transform=None) -> "PadoLibriSpeech":
        """
        Build the dataset from an OmegaConf config.

        :param cfg: config whose resolved keys are forwarded as constructor keyword arguments.
        :param transform: passed as ``audio_transform``.
        :param target_transform: passed as ``text_transform``.
        """
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(audio_transform=transform, text_transform=target_transform, **cfg)
