# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset of music tracks with rich metadata.
"""
import argparse
import copy
import gzip
import json
import logging
import random
import sys
import typing as tp
from dataclasses import dataclass, field, fields
from functools import partial
from pathlib import Path

import h5py
import torch
import torch.nn.functional as F

from ..modules.conditioners import (AlphabetTokenizer, BPETokenizer,
                                    ConditioningAttributes,
                                    JointEmbedCondition, Tokenizer,
                                    WavCondition)
from .audio import audio_info, audio_read
from .audio_dataset import (DEFAULT_EXTS, AudioMeta, SegmentInfo,
                            _resolve_audio_meta, find_audio_files,
                            save_audio_meta)
from .audio_utils import convert_audio
from .info_audio_dataset import (AudioInfo, InfoAudioDataset,
                                 clusterify_all_meta)
from .zip import PathInZip

logger = logging.getLogger(__name__)


@dataclass(order=True)
class SpeechMeta(AudioMeta):
    """Per-file audio metadata extended with the utterances usable for sampling.

    ``order=True`` makes instances comparable field by field, in declaration order.
    """
    # Location of the audio file. The H5 loader in this module stores it as
    # "<h5 file>:<split>/<wav_id>/wav".
    path: str
    # Duration of the whole audio file and its sample rate.
    duration: float
    sample_rate: int
    # Peak absolute amplitude; only computed when the non-minimal metadata is requested.
    amplitude: tp.Optional[float] = None
    weight: tp.Optional[float] = None
    # Indices of the utterances (lines of the side-car JSON file, or keys of the
    # H5 detailed-info group) that passed filtering and may be sampled.
    use_idxes: tp.Optional[tp.List[int]] = None
    # info_path is used to load additional information about the audio file that is stored in zip files.
    # In this module it is read through its `zip_path` / `file_path` attributes to
    # open an H5 file and select the group holding the per-utterance entries.
    info_path: tp.Optional[PathInZip] = None


@dataclass
class SpeechInfo(AudioInfo):
    """Segment info augmented with speech metadata."""

    # Text of the sampled utterance (phoneme string when available, else transcript).
    text: tp.Optional[str] = None
    description: tp.Optional[str] = None
    # Concatenated transcripts of the utterances preceding the sampled one.
    text_history: tp.Optional[str] = None
    # original wav accompanying the metadata
    self_wav: tp.Optional[WavCondition] = None
    # dict mapping attributes names to tuple of wav, text and metadata
    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)

    @property
    def has_speech_meta(self) -> bool:
        """Whether an utterance text is attached to this segment."""
        # This class has no `transcript` field; the utterance text lives in `text`.
        return self.text is not None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Convert every dataclass field into a ``ConditioningAttributes`` entry.

        ``self_wav`` goes to the wav conditions, the entries of ``joint_embed``
        to the joint-embedding conditions, and every other field to the text
        conditions (list values are joined with single spaces).
        """
        out = ConditioningAttributes()
        for _field in fields(self):
            key, value = _field.name, getattr(self, _field.name)
            if key == "self_wav":
                out.wav[key] = value
            elif key == "joint_embed":
                for embed_attribute, embed_cond in value.items():
                    out.joint_embed[embed_attribute] = embed_cond
            else:
                if isinstance(value, list):
                    value = " ".join(value)
                out.text[key] = value
        return out

    @staticmethod
    def attribute_getter(attribute):
        """Return the preprocessing callable for ``attribute`` (currently always None)."""
        # NOTE: If you need to preprocess the attribute, add a function here
        preprocess_func = None
        return preprocess_func

    @classmethod
    def from_dict(cls, dictionary: dict, fields_required: bool = False):
        """Build an instance from the keys of ``dictionary`` that match dataclass fields.

        Args:
            dictionary (dict): Source values; keys that are not fields are ignored.
            fields_required (bool): If True, a missing field raises unless it is
                optional (``description``) or filled in after construction
                (``audio_tokens``, ``self_wav``, ``joint_embed``).
        Raises:
            KeyError: A required field is missing and ``fields_required`` is True.
        """
        _dictionary: tp.Dict[str, tp.Any] = {}
        # These attributes are populated after construction, never from the dictionary.
        post_init_attributes = ["audio_tokens", "self_wav", "joint_embed"]
        optional_fields = ["description"]

        for _field in fields(cls):
            name = _field.name
            if name in post_init_attributes:
                continue
            if name not in dictionary:
                if fields_required and name not in optional_fields:
                    raise KeyError(f"Unexpected missing key: {name}")
                continue
            preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(name)
            value = dictionary[name]
            if preprocess_func:
                value = preprocess_func(value)
            _dictionary[name] = value
        return cls(**_dictionary)


def load_speech_meta(
    path: tp.Union[str, Path], resolve: bool = True, fast: bool = True
) -> tp.List[SpeechMeta]:
    """Load the list of ``SpeechMeta`` described by a manifest file.

    Two formats are supported:

    * ``*.h5``: ``path`` is a *virtual* path of the form
      ``/path/to/h5/{split}/{detailed_info_name}/hoge.h5``; the file actually
      opened is ``/path/to/h5/hoge.h5``. Entries of the ``split`` group that
      lack the ``detailed_info_name`` sub-group, or whose
      ``"{detailed_info_name}:use_idxes"`` dataset is empty, are skipped.
    * anything else: a JSON-lines file, gzip-compressed when the name ends
      with ``.gz``, with one metadata dict per line. Blank lines are ignored.

    Args:
        path (str or Path): Manifest location (see above).
        resolve (bool): Whether to pass each JSON-lines entry through
            ``_resolve_audio_meta``. Not used for H5 manifests.
        fast (bool): Forwarded to ``_resolve_audio_meta``.
    Returns:
        list of SpeechMeta: One entry per usable audio file.
    """
    if str(path).lower().endswith(".h5"):
        # The last two directory components of the given path are not real
        # directories: they encode the detailed-info group name and the split.
        path = Path(path)
        detailed_info_name = path.parent.stem
        logger.debug("Loading detailed info from %s", detailed_info_name)
        path = path.parent.parent / path.name
        split = path.parent.stem
        assert split in ["train", "valid", "test"]
        path = path.parent.parent / path.name
        assert path.exists(), f"Path {path} does not exist."
        meta = []
        skipped_files_wo_detailed_info = 0
        skipped_files_wo_used_idxes = 0
        field_names = [f.name for f in fields(AudioMeta)]
        use_idxes_key = f"{detailed_info_name}:use_idxes"
        with h5py.File(path, 'r') as h5:
            phase_group = h5[split]
            for wav_id in phase_group.keys():
                entry = phase_group[wav_id]
                if detailed_info_name not in entry:
                    skipped_files_wo_detailed_info += 1
                    continue
                # Decide whether the entry is usable before reading its other fields.
                use_idxes = entry[use_idxes_key][()]
                if len(use_idxes) == 0:
                    skipped_files_wo_used_idxes += 1
                    continue
                data = {}
                for k in field_names:
                    if k not in entry:
                        continue
                    v = entry[k][()]
                    # Byte strings read from the H5 file are decoded to str.
                    data[k] = v.decode() if isinstance(v, bytes) else v
                data["path"] = f"{path}:{split}/{wav_id}/wav"
                data["info_path"] = f"{path}:{split}/{wav_id}/{detailed_info_name}"
                data["use_idxes"] = use_idxes
                meta.append(SpeechMeta.from_dict(data))
        logger.debug(
            "Skipped %s files due to missing detailed info: %s.",
            skipped_files_wo_detailed_info,
            detailed_info_name,
        )
        logger.debug("Skipped %s files due to zero use_idxes.", skipped_files_wo_used_idxes)
    else:
        open_fn = gzip.open if str(path).lower().endswith(".gz") else open
        meta = []
        with open_fn(path, "rb") as fp:  # type: ignore
            # Stream the file instead of materialising every line first.
            for line in fp:
                if not line.strip():
                    # Tolerate blank / trailing empty lines.
                    continue
                m = SpeechMeta.from_dict(json.loads(line))
                if resolve:
                    m = _resolve_audio_meta(m, fast=fast)
                meta.append(m)
    return meta


class SpeechDataset(InfoAudioDataset):
    """Dataset returning one speech utterance (waveform and ``SpeechInfo``) per item.

    For every item a file is drawn with ``sample_file``, then one of its usable
    utterances (``SpeechMeta.use_idxes``) is drawn, either uniformly or
    proportionally to the per-utterance ``weight``. The utterance text and
    timing come from the H5 group referenced by ``info_path`` when it is set,
    otherwise from the JSON-lines file next to the audio (same name, ``.json``).

    Args:
        meta (list of SpeechMeta): Files to sample from.
        segment_duration (float, optional): Must be None; fixed-size segments
            are not supported.
        shuffle (bool): Whether per-item seeds vary across epochs.
        num_samples (int, optional): Dataset length; defaults to ``len(meta)``.
        sample_rate (int): Target sample rate of returned audio.
        channels (int): Target number of channels of returned audio.
        pad (bool): Ignored, batches are always padded.
        sample_on_duration (bool): Stored for file sampling.
        sample_on_weight (bool): Stored for file sampling.
        sample_on_weight_for_utter (bool): Draw the utterance according to its
            ``weight`` instead of uniformly.
        min_segment_ratio (float, optional): Ignored.
        max_read_retry (int): Number of attempts before an error is re-raised.
        return_info (bool): Ignored, info is always returned.
        min_audio_duration (float, optional): Stored; no duration filtering is applied.
        max_audio_duration (float, optional): Stored; no duration filtering is applied.
        shuffle_seed (int): Added to the epoch when seeding the item generator.
        load_wav (bool): Stored on the instance.
        permutation_on_files (bool): Requires shuffle and no weighted sampling.
        text_history_length (int): Number of preceding utterances whose
            transcripts are concatenated into ``text_history``.
        use_current_text_for_history (bool): Also append the sampled
            utterance's transcript to ``text_history``.
        info_fields_required (bool): Forwarded to ``SpeechInfo.from_dict``.
        buffer_length (float): Extra seconds read after the utterance end.
        max_audio_duration_for_short (float, optional): When set, read at most
            this many seconds, starting at a random offset inside the utterance.
    """

    def __init__(
        self,
        meta: tp.List[SpeechMeta],
        segment_duration: tp.Optional[float] = None,
        shuffle: bool = True,
        num_samples: tp.Optional[int] = None,
        sample_rate: int = 48_000,
        channels: int = 2,
        pad: bool = True,
        sample_on_duration: bool = True,
        sample_on_weight: bool = True,
        sample_on_weight_for_utter: bool = False,
        min_segment_ratio: tp.Optional[float] = None,
        max_read_retry: int = 10,
        return_info: bool = False,
        min_audio_duration: tp.Optional[float] = None,
        max_audio_duration: tp.Optional[float] = None,
        shuffle_seed: int = 0,
        load_wav: bool = True,
        permutation_on_files: bool = False,
        # original
        text_history_length: int = 0,
        use_current_text_for_history: bool = False,
        info_fields_required: bool = True,
        buffer_length: float = 0.0,
        max_audio_duration_for_short: tp.Optional[float] = None,
    ):
        # The parent constructor is intentionally not called: its attributes are
        # set here directly.
        meta = clusterify_all_meta(meta)

        assert (
            len(meta) > 0
        ), "No audio meta provided to SpeechDataset. Please check loading of audio meta."
        assert segment_duration is None or segment_duration > 0
        self.segment_duration = segment_duration
        assert self.segment_duration is None, "we don't support segment_duration"
        self.min_segment_ratio = None  # we don't support min_segment_ratio
        self.max_audio_duration = max_audio_duration
        self.min_audio_duration = min_audio_duration
        if self.min_audio_duration is not None and self.max_audio_duration is not None:
            assert self.min_audio_duration <= self.max_audio_duration
        # Duration filtering is skipped for now: the meta list is used as given.
        self.meta = meta
        assert len(self.meta)  # Fail fast if all data has been filtered.
        self.total_duration = sum(d.duration for d in self.meta)

        self.num_samples = num_samples if num_samples is not None else len(self.meta)
        self.shuffle = shuffle
        self.sample_rate = sample_rate
        self.channels = channels
        self.pad = True  # we always pad
        self.sample_on_weight = sample_on_weight
        self.sample_on_duration = sample_on_duration
        self.sample_on_weight_for_utter = sample_on_weight_for_utter
        self.sampling_probabilities = self._get_sampling_probabilities()
        self.max_read_retry = max_read_retry
        self.return_info = True  # we always return info
        self.shuffle_seed = shuffle_seed
        self.current_epoch: tp.Optional[int] = None
        self.load_wav = load_wav
        if not load_wav:
            # NOTE(review): segment_duration was asserted to be None above, so
            # load_wav=False can never pass this check.
            assert segment_duration is not None
        self.permutation_on_files = permutation_on_files
        if permutation_on_files:
            assert not self.sample_on_duration
            assert not self.sample_on_weight
            assert self.shuffle

        self.text_history_length = text_history_length
        self.use_current_text_for_history = use_current_text_for_history
        self.info_fields_required = info_fields_required
        self.buffer_length = buffer_length
        self.max_audio_duration_for_short = max_audio_duration_for_short

    @classmethod
    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
        """Build the dataset from a manifest, loading it with ``load_speech_meta``."""
        return super().from_meta(root, load_speech_meta, **kwargs)

    def _draw_uniform_utterance(self, file_meta, rng):
        """Return one element of ``file_meta.use_idxes`` drawn uniformly with ``rng``."""
        position = int(
            torch.randint(len(file_meta.use_idxes), (1,), generator=rng).item()
        )
        return file_meta.use_idxes[position]

    @staticmethod
    def _draw_weighted_position(weights, rng) -> int:
        """Return a position in ``weights`` drawn proportionally to the weights."""
        probabilities = torch.tensor(weights)
        if not probabilities.is_floating_point():
            # In-place division is not possible on an integer tensor.
            probabilities = probabilities.float()
        probabilities /= probabilities.sum()
        return int(torch.multinomial(probabilities, 1, generator=rng).item())

    def _build_text_history(self, history_parts, current_transcript):
        """Join the history transcripts, optionally followed by the current one.

        Returns None when there is no history and the current transcript is not
        requested.
        """
        text_history = "".join(history_parts) if history_parts else None
        if self.use_current_text_for_history is True:
            text_history = (text_history or "") + current_transcript
        return text_history

    def _read_utterance_from_h5(self, file_meta, rng):
        """Sample an utterance from the H5 group referenced by ``file_meta.info_path``.

        Returns:
            tuple: ``(speech_data, start_time, end_time)`` where ``speech_data``
            holds the ``text`` and ``text_history`` entries.
        Raises:
            ValueError: Weighted sampling is requested and an utterance has no weight.
            KeyError: The sampled utterance is missing from the group.
        """
        info_path = file_meta.info_path
        weighted = self.sample_on_weight_for_utter is not False
        sample_data_idx = None
        if not weighted:
            sample_data_idx = int(self._draw_uniform_utterance(file_meta, rng))
        with h5py.File(info_path.zip_path, 'r') as h5:
            utterances = h5[info_path.file_path]
            if weighted:
                weights = []
                for idx in file_meta.use_idxes:
                    entry = utterances[str(idx)]
                    if "weight" not in entry:
                        raise ValueError(
                            f"weight is None: {info_path.zip_path}/{idx}"
                        )
                    weights.append(entry["weight"][()])
                position = self._draw_weighted_position(weights, rng)
                sample_data_idx = int(file_meta.use_idxes[position])
            # Look the history up by increasing index. Iterating the group keys
            # would follow their string order ("10" before "2") and scramble
            # the history.
            history_parts = []
            for hist_idx in range(
                sample_data_idx - self.text_history_length, sample_data_idx
            ):
                key = str(hist_idx)
                if key in utterances:
                    history_parts.append(
                        utterances[key]["transcript"][()].decode() + " "
                    )
            current = utterances[str(sample_data_idx)]
            text_key = "phoneme" if "phoneme" in current else "transcript"
            text = current[text_key][()].decode()
            start_time = current["start"][()]
            end_time = current["end"][()]
            current_transcript = ""
            if self.use_current_text_for_history is True:
                current_transcript = current["transcript"][()].decode()
        speech_data = {
            "text": text,
            "text_history": self._build_text_history(
                history_parts, current_transcript
            ),
        }
        return speech_data, start_time, end_time

    def _read_utterance_from_json(self, file_meta, rng):
        """Sample an utterance from the JSON-lines file next to the audio file.

        Returns:
            tuple: ``(speech_data, start_time, end_time)`` where ``speech_data``
            holds the ``text`` and ``text_history`` entries.
        Raises:
            ValueError: Weighted sampling is requested and a weight is missing or None.
            IndexError: The sampled utterance index is not a line of the file.
        """
        speech_info_path = Path(file_meta.path).with_suffix(".json")
        weighted = self.sample_on_weight_for_utter is not False
        sample_data_idx = None
        if not weighted:
            sample_data_idx = self._draw_uniform_utterance(file_meta, rng)
        with open(speech_info_path, "r") as f:
            lines = f.readlines()
        if weighted:
            # Weights are collected in `use_idxes` order so that the drawn
            # position maps back to the right utterance index.
            weights = []
            for idx in file_meta.use_idxes:
                _data = json.loads(lines[idx])
                if "weight" not in _data:
                    raise ValueError(f"weight is not found: {file_meta.path}")
                weight = _data["weight"]
                if weight is None:
                    raise ValueError(f"weight is None: {file_meta.path}")
                weights.append(float(weight))
            position = self._draw_weighted_position(weights, rng)
            sample_data_idx = file_meta.use_idxes[position]
        sample_data_idx = int(sample_data_idx)
        if not 0 <= sample_data_idx < len(lines):
            raise IndexError(
                f"Utterance index {sample_data_idx} not found in {speech_info_path}"
            )
        first_history_idx = max(0, sample_data_idx - self.text_history_length)
        history_parts = [
            json.loads(lines[hist_idx])["transcript"] + " "
            for hist_idx in range(first_history_idx, sample_data_idx)
        ]
        _data = json.loads(lines[sample_data_idx])
        text = _data["phoneme"] if "phoneme" in _data else _data["transcript"]
        start_time = _data["start"]
        end_time = _data["end"]
        current_transcript = ""
        if self.use_current_text_for_history is True:
            current_transcript = _data["transcript"]
        speech_data = {
            "text": text,
            "text_history": self._build_text_history(
                history_parts, current_transcript
            ),
        }
        return speech_data, start_time, end_time

    def __getitem__(self, index):
        """Return ``(wav, speech_info)`` for one randomly sampled utterance.

        Any exception raised while reading is logged and the sampling is
        retried, up to ``max_read_retry`` attempts; the last failure is re-raised.
        """
        rng = torch.Generator()
        if self.shuffle:
            # We use index, plus extra randomness, either totally random if we don't know the epoch.
            # otherwise we make use of the epoch number and optional shuffle_seed.
            if self.current_epoch is None:
                rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
            else:
                rng.manual_seed(
                    index + self.num_samples * (self.current_epoch + self.shuffle_seed)
                )
        else:
            # We only use index
            rng.manual_seed(index)
        for retry in range(self.max_read_retry):
            file_meta = self.sample_file(index, rng)
            try:
                if file_meta.info_path is not None:
                    speech_data, start_time, end_time = self._read_utterance_from_h5(
                        file_meta, rng
                    )
                else:
                    speech_data, start_time, end_time = self._read_utterance_from_json(
                        file_meta, rng
                    )
                # audio read
                if self.max_audio_duration_for_short is not None:
                    duration = end_time - start_time + self.buffer_length
                    max_seek = max(0, duration - self.max_audio_duration_for_short)
                    seek_time = torch.rand(1, generator=rng).item() * max_seek
                    wav, sr = audio_read(
                        file_meta.path,
                        start_time + seek_time,
                        min(self.max_audio_duration_for_short, duration),
                        pad=False,
                    )
                else:
                    wav, sr = audio_read(
                        file_meta.path,
                        start_time,
                        end_time - start_time + self.buffer_length,
                        pad=False,
                    )
                wav = convert_audio(wav, sr, self.sample_rate, self.channels)
                n_frames = wav.shape[-1]
                # NOTE(review): the recorded seek time is the utterance start, even
                # when a random offset was applied above - confirm this is intended.
                info = SegmentInfo(
                    file_meta,
                    start_time,
                    n_frames=n_frames,
                    total_frames=n_frames,
                    sample_rate=self.sample_rate,
                    channels=wav.shape[0],
                )

                info_data = info.to_dict()
                speech_data.update(info_data)
                speech_info = SpeechInfo.from_dict(
                    speech_data, fields_required=self.info_fields_required
                )

                speech_info.self_wav = WavCondition(
                    wav=wav[None],
                    length=torch.tensor([info.n_frames]),
                    sample_rate=[info.sample_rate],
                    path=[info.meta.path],
                    seek_time=[info.seek_time],
                )
            except Exception as exc:
                logger.warning("Error opening file %s: %r", file_meta.path, exc)
                if retry == self.max_read_retry - 1:
                    raise
            else:
                break
        return wav, speech_info

    def collater(self, samples):
        """The collater function has to be provided to the dataloader
        if AudioDataset has return_info=True in order to properly collate
        the samples of a batch.

        Every waveform is right-padded with zeros to the longest one of the
        batch; the infos are deep-copied and updated to match the padded audio.
        """
        assert self.return_info is True
        max_len = max([wav.shape[-1] for wav, _ in samples])

        def _pad_wav(wav):
            return F.pad(wav, (0, max_len - wav.shape[-1]))

        if len(samples) > 0:
            assert len(samples[0]) == 2
            assert isinstance(samples[0][0], torch.Tensor)
            assert isinstance(samples[0][1], SegmentInfo)

        wavs = [wav for wav, _ in samples]
        segment_infos = [copy.deepcopy(info) for _, info in samples]

        # Each wav could be of a different duration as they are not segmented.
        for i in range(len(samples)):
            # Determines the total length of the signal with padding, so we update here as we pad.
            segment_infos[i].total_frames = max_len
            wavs[i] = _pad_wav(wavs[i])
            segment_infos[i].self_wav = WavCondition(
                wav=wavs[i][None],
                length=torch.tensor([segment_infos[i].n_frames]),
                sample_rate=[segment_infos[i].sample_rate],
                path=[segment_infos[i].meta.path],
                seek_time=[segment_infos[i].seek_time],
            )
        wav = torch.stack(wavs)
        return wav, segment_infos


def _get_speech_meta(
    file_path: str,
    minimal: bool = True,
    min_utt_sec: float = 2.0,
    max_utt_sec: float = 20.0,
    min_text_len: int = 30,
    max_text_len: int = 9999,
    text_history_length: int = 0,
    g2p: tp.Optional[Tokenizer] = None,
    save_g2p: bool = False,
    pyannote_task: tp.Optional[str] = None,
    pyannote_path: tp.Optional[str] = None,
    g2p_override: bool = False,
) -> SpeechMeta:
    """SpeechMeta from a path to an audio file.

    The utterances are read from the JSON-lines file next to the audio file
    (same name, ``.json`` suffix). The index of every utterance passing all
    filters is recorded in ``use_idxes``.

    Args:
        file_path (str): Resolved path of valid audio file.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
        min_utt_sec (float): Utterances with a shorter ``duration`` are rejected.
        max_utt_sec (float): Utterances with a longer ``duration`` are rejected.
        min_text_len (int): Minimum number of comma-separated items in the text.
        max_text_len (int): Maximum number of comma-separated items in the text.
        text_history_length (int): The first ``text_history_length`` utterances
            are rejected, as they lack a full history.
        g2p (Tokenizer, optional): Applied to the transcript when no stored
            ``phoneme`` is used; its output is stored as comma-joined values.
        save_g2p (bool): Write the computed text back to the JSON-lines file
            under the ``phoneme`` key.
        pyannote_task (str, optional): When ``"overlap"``, utterances
            intersecting a region listed in the pyannote file are rejected.
        pyannote_path (str, optional): Folder holding ``<audio stem>.jsonl``
            files, one ``start``/``end`` region per line.
        g2p_override (bool): Ignore a stored ``phoneme`` and recompute it.
    Returns:
        SpeechMeta: Audio file path and its metadata.
    Raises:
        FileNotFoundError: The side-car JSON file does not exist.
        ValueError: No utterance passed the filters.
    """
    info = audio_info(file_path)
    amplitude: tp.Optional[float] = None
    if not minimal:
        wav, _ = audio_read(file_path)
        amplitude = wav.abs().max().item()

    speech_info_path = Path(file_path).with_suffix(".json")
    if not speech_info_path.exists():
        raise FileNotFoundError(f"Speech info file not found: {speech_info_path}")

    check_overlap = pyannote_task == "overlap"
    overlap_regions: tp.List[tp.Tuple[float, float]] = []
    if check_overlap:
        _pyannote_path = Path(pyannote_path) / (Path(file_path).stem + ".jsonl")
        if _pyannote_path.exists():
            with open(_pyannote_path, "r") as f:
                for line in f:
                    if not line.strip():
                        continue
                    # Lines use single quotes; swap them so they parse as JSON.
                    region = json.loads(line.replace("'", '"'))
                    overlap_regions.append(
                        (float(region["start"]), float(region["end"]))
                    )
        else:
            logger.error(
                "Pyannote file not found: %s. Skip checking of overlap.",
                _pyannote_path,
            )

    with open(speech_info_path, "r") as f:
        lines = f.readlines()

    use_idxes = []
    new_output = []
    # Index of the first region that may still intersect an upcoming utterance.
    # NOTE(review): this assumes regions and utterances are both sorted by time.
    _seg_id = 0
    for idx, raw in enumerate(lines):
        data = json.loads(raw)
        rejected = False

        # Check by pyannote
        if check_overlap:
            for start, end in overlap_regions[_seg_id:]:
                if end < data["start"]:
                    # Region lies entirely before this utterance: drop it for good.
                    _seg_id += 1
                    continue
                if start <= data["end"]:
                    # Intervals intersect. This also covers a region that fully
                    # contains the utterance.
                    rejected = True
                break

        # Check text history length
        if idx < text_history_length:
            rejected = True

        # Check duration
        duration = data["duration"]
        if duration < min_utt_sec or duration > max_utt_sec:
            rejected = True

        # Check text length
        if "phoneme" in data and not g2p_override:
            text = data["phoneme"]
        else:
            text = data["transcript"]
            if g2p:
                try:
                    text, _ = g2p([text])
                    assert len(text) == 1
                    text = ",".join([str(p) for p in text[0].numpy()])
                except Exception as e:
                    logger.error(f"Error in g2p: {e}")
                    text = "0,0"  # dummy like: [sos, eos]
                    rejected = True
        text_length = len(text.split(","))
        if text_length < min_text_len or text_length > max_text_len:
            rejected = True

        if save_g2p:
            # Every utterance is written back, including the rejected ones, so
            # that line numbers keep matching the recorded indices.
            data["phoneme"] = text
            new_output.append(data)

        if not rejected:
            use_idxes.append(idx)

    if save_g2p:
        with open(speech_info_path, "w") as f:
            for data in new_output:
                f.write(json.dumps(data) + "\n")

    if len(use_idxes) == 0:
        raise ValueError(f"No valid data found in {speech_info_path}")

    return SpeechMeta(
        file_path, info.duration, info.sample_rate, amplitude, use_idxes=use_idxes
    )


def main():
    """Scan a folder of audio files and write their speech metadata to a manifest.

    Command-line entry point: parses the arguments, builds the requested G2P
    tokenizer, collects one ``SpeechMeta`` per audio file found under ``root``
    with ``_get_speech_meta`` and saves the result to ``output_meta_file``.
    """
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    parser = argparse.ArgumentParser(
        prog="audio_dataset", description="Generate .jsonl files by scanning a folder."
    )
    parser.add_argument("root", help="Root folder with all the audio files")
    parser.add_argument("output_meta_file", help="Output file to store the metadata, ")
    parser.add_argument(
        "--complete",
        action="store_false",
        dest="minimal",
        default=True,
        help="Retrieve all metadata, even the one that are expansive "
        "to compute (e.g. normalization).",
    )
    parser.add_argument(
        "--resolve",
        action="store_true",
        default=False,
        help="Resolve the paths to be absolute and with no symlinks.",
    )
    parser.add_argument("--workers", default=10, type=int, help="Number of workers.")
    parser.add_argument(
        "--min_utt_sec", default=5.0, type=float, help="Minimum duration of utterance."
    )
    parser.add_argument(
        "--max_utt_sec", default=20.0, type=float, help="Maximum duration of utterance."
    )
    parser.add_argument(
        "--g2p_tokenizer", default="alphabet", type=str, help="G2P tokenizer."
    )
    parser.add_argument(
        "--min_text_len", default=50, type=int, help="Minimum length of text."
    )
    parser.add_argument(
        "--max_text_len", default=200, type=int, help="Maximum length of text."
    )
    parser.add_argument(
        "--text_history_length", default=10, type=int, help="Length of text history."
    )
    parser.add_argument(
        "--save_g2p", action="store_true", default=False, help="Save g2p tokens."
    )
    parser.add_argument(
        "--pyannote_task", default=None, help="overlap or powerset", choices=["overlap"]
    )
    parser.add_argument(
        "--pyannote_path", default=None, help="Path to pyannote model", type=str
    )
    parser.add_argument(
        "--tokenizer_file",
        default=None,
        help="Path to tokenizer file of BPETokenizer.",
        type=str,
    )
    parser.add_argument(
        "--g2p_override", default=False, action="store_true", help="Override phoneme"
    )
    args = parser.parse_args()

    # Fail fast: without a path the overlap check would raise for every file
    # inside the workers instead of reporting the misconfiguration once.
    if args.pyannote_task is not None and args.pyannote_path is None:
        parser.error("--pyannote_path is required when --pyannote_task is set.")

    if args.g2p_tokenizer == "alphabet":
        g2p = AlphabetTokenizer(use_g2p=True)
    elif args.g2p_tokenizer == "bpe":
        if args.tokenizer_file is None:
            g2p = BPETokenizer(use_g2p=True)
        else:
            g2p = BPETokenizer(use_g2p=True, tokenizer_file=args.tokenizer_file)
    else:
        raise ValueError(f"Unknown g2p tokenizer: {args.g2p_tokenizer}")
    logger.info("Args: %s", args)
    # Per-file processing with all the filtering options bound.
    process_fn = partial(
        _get_speech_meta,
        min_utt_sec=args.min_utt_sec,
        max_utt_sec=args.max_utt_sec,
        min_text_len=args.min_text_len,
        max_text_len=args.max_text_len,
        text_history_length=args.text_history_length,
        g2p=g2p,
        save_g2p=args.save_g2p,
        pyannote_task=args.pyannote_task,
        pyannote_path=args.pyannote_path,
        g2p_override=args.g2p_override,
    )
    meta = find_audio_files(
        args.root,
        DEFAULT_EXTS,
        progress=True,
        resolve=args.resolve,
        minimal=args.minimal,
        workers=args.workers,
        process_fn=process_fn,
    )
    save_audio_meta(args.output_meta_file, meta)


if __name__ == "__main__":
    main()
