import torch
import numpy as np
from typing import List, Dict, Union
from lhotse import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset
from torch.utils.data.dataloader import default_collate


class Speech2TextDataset(K2SpeechRecognitionDataset):
    """Dataset for general speech-to-text tasks (currently ASR and AST).

    Each supervision in the manifest is expected to provide:
    - supervision.text = '你好，世界！'  # speech transcription
    - supervision.language = 'zh'  # speech language
    - supervision.translation: optional list of dicts, each with
        - text = 'Hello, world!'  # speech translation
        - language = 'en'  # speech translation language

    The resulting batch is a dict with:
    - "inputs": the collated inputs after all input transforms
    - "supervisions": a dict holding, with one entry per supervision:
        - "task": 'transcribe' or 'translate'
        - "text": transcription text
        - "language": speech language
        - "text_translated": selected translation text (equal to "text"
          when the task is 'transcribe')
        - "language_translated": selected translation language (equal to
          "language" when the task is 'transcribe')
      plus the supervision interval entries returned by the input strategy
      and, when ``self.return_cuts`` is truthy, "cut" (one cut per
      supervision).
    """

    # Probability of picking the 'translate' task for a supervision that has
    # at least one translation available.
    s2t_translate_ratio: float = 0.5

    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """Collate a ``CutSet`` into a single speech-to-text batch.

        Args:
            cuts: the cuts that make up this batch.

        Returns:
            A dict with "inputs" and "supervisions" entries, as described in
            the class docstring.
        """
        # NOTE(review): validate_for_asr(cuts) is disabled here - confirm
        # that skipping the validation is deliberate.

        self.hdf5_fix.update()

        # Sort the cuts by duration so that the first one determines the
        # batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms - e.g. padding, or speed perturbation
        # that adjusts the supervision boundaries.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Transforms may have changed the durations, so sort again.
        cuts = cuts.sort_by_duration(ascending=False)

        # Get a tensor with batched feature matrices, shape (B, T, F).
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault-tolerant audio reading mode:
            # "cuts" may be a subset of the original "cuts" variable that
            # only has cuts for which we successfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about
        # supervisions in the batch of feature matrices. The tensors are
        # named "sequence_idx", "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Apply all available transforms on the inputs, i.e. either audio or
        # features. This could be feature extraction, global MVN,
        # SpecAugment, etc.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)

        # One info dict per supervision, in cut order.
        infos = [
            self._build_supervision_info(supervision)
            for cut in cuts
            for supervision in cut.supervisions
        ]

        batch = {
            "inputs": inputs,
            "supervisions": default_collate(infos),
        }
        # Update the 'supervisions' field with sequence_idx and
        # start/num frames/samples.
        batch["supervisions"].update(supervision_intervals)
        if self.return_cuts:
            # Repeat each cut once per supervision so that the list lines up
            # with the other per-supervision entries.
            batch["supervisions"]["cut"] = [
                cut for cut in cuts for _ in cut.supervisions
            ]

        return batch

    def _build_supervision_info(self, supervision) -> Dict[str, str]:
        """Dynamically pick the task for one supervision and build its info dict.

        The task is 'translate' with probability ``s2t_translate_ratio`` when
        the supervision carries at least one translation, else 'transcribe'.

        Args:
            supervision: an object exposing ``text`` and ``language`` and,
                optionally, ``translation`` (a list of dicts with "text" and
                "language" keys).

        Returns:
            A dict with "task", "text", "language", "text_translated" and
            "language_translated".
        """
        text = supervision.text
        language = supervision.language

        # A translation attribute that is present but None is treated the
        # same as a missing one, instead of failing on len(None).
        translations = getattr(supervision, 'translation', None)
        if translations is None:
            translations = []

        # The random draw only happens when a translation is available, so
        # transcription-only supervisions do not consume random numbers.
        task = 'transcribe'
        if len(translations) > 0:
            if np.random.uniform(0, 1) < self.s2t_translate_ratio:
                task = 'translate'

        if task == 'translate':
            # Pick one of the available translations uniformly at random.
            translation = np.random.choice(translations)
            text_translated = translation['text']
            language_translated = translation['language']
        else:
            # For transcription the "translated" fields mirror the source.
            text_translated = text
            language_translated = language

        return {
            "task": task,
            "text": text,
            "language": language,
            "text_translated": text_translated,
            "language_translated": language_translated,
        }

