from typing import Callable, Dict, List, Union

import torch
from torch.utils.data.dataloader import DataLoader, default_collate

from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from lhotse.workarounds import Hdf5MemoryIssueFix
from re import sub


class AudioCaptionDataset(torch.utils.data.Dataset):
    """
    Map-style PyTorch dataset that turns a :class:`CutSet` mini-batch into
    collated inputs plus normalized audio captions.

    Indexing the dataset with a ``CutSet`` returns a dict of the form:

    .. code-block::

        {
            'inputs': float tensor with shape determined by :attr:`input_strategy`:
                      - single-channel:
                        - features: (B, T, F)
                        - audio: (B, T)
                      - multi-channel: currently not supported
            'supervisions': {
                # One entry per supervision (S entries in total). Each entry is
                # a normalized caption string, or a list of normalized caption
                # strings when the supervision's ``caption`` is not a ``str``.
                'audio_caption': List of len S

                # Positional tensors returned by
                # ``input_strategy.supervision_intervals(cuts)``.
                # For feature input strategies
                'start_frame': Tensor[int] of shape (S,)
                'num_frames': Tensor[int] of shape (S,)

                # For audio input strategies
                'start_sample': Tensor[int] of shape (S,)
                'num_samples': Tensor[int] of shape (S,)

                # Optionally, when return_cuts=True
                'cut': List[AnyCut] of len S
            }
        }
    """

    # Not read anywhere in this class; kept so that external code referring to
    # ``AudioCaptionDataset.global_idx`` keeps working.
    global_idx = 0

    def __init__(
        self,
        return_cuts: bool = False,
        cut_transforms: List[Callable[[CutSet], CutSet]] = None,
        input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
        input_strategy: BatchIO = None,
    ):
        """
        Audio captioning dataset constructor.

        :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
            objects used to create that batch.
        :param cut_transforms: A list of transforms to be applied on each sampled batch,
            before converting cuts to an input representation (audio/features).
            Examples: cut concatenation, noise cuts mixing, etc.
        :param input_transforms: A list of transforms to be applied on each sampled batch,
            after the cuts are converted to audio/features. Each transform is called as
            ``tnfm(inputs, supervision_segments=segments)``, so it must accept that keyword.
            Examples: normalization, SpecAugment, etc.
        :param input_strategy: Converts cuts into a collated batch of audio/features.
            When ``None`` (the default), a fresh :class:`PrecomputedFeatures` is created,
            i.e. pre-computed features are read from disk.
        """
        super().__init__()
        self.return_cuts = return_cuts
        self.cut_transforms = ifnone(cut_transforms, [])
        self.input_transforms = ifnone(input_transforms, [])
        # Build the default strategy per instance rather than in the signature,
        # where it would be created once at import time and shared by every dataset.
        self.input_strategy = (
            input_strategy if input_strategy is not None else PrecomputedFeatures()
        )

        # This attribute is a workaround to constantly growing HDF5 memory
        # throughout the epoch. It regularly closes open file handles to
        # reset the internal HDF5 caches.
        self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)

    def _text_preprocess(self, sentence):
        """
        Normalize a single caption string.

        The caption is lower-cased, whitespace directly preceding a punctuation
        mark is dropped, the characters ``( , . ! ? ; : | * " )`` are replaced
        by spaces, and every run of consecutive spaces is collapsed to one.

        A single leading or trailing space left behind by removed punctuation
        is intentionally preserved (the text is not stripped).

        :param sentence: the raw caption text.
        :return: the normalized caption text.
        """
        sentence = sentence.lower()

        # Remove any forgotten whitespace before punctuation ("dog , cat" -> "dog, cat").
        sentence = sub(r'\s([,.!?;:"](?:\s|$))', r'\1', sentence)

        # Replace punctuation with spaces.
        sentence = sub('[(,.!?;:|*\")]', ' ', sentence)

        # Collapse runs of spaces of any length. A plain ``replace('  ', ' ')``
        # only halves a run, so e.g. "..." followed by a space left a double space.
        return sub(r' {2,}', ' ', sentence)

    def _preprocess_caption(self, caption):
        """Normalize one caption string, or each string of an iterable of captions."""
        if isinstance(caption, str):
            return self._text_preprocess(caption)
        return [self._text_preprocess(text) for text in caption]

    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, dict]]:
        """
        Build one batch from the given cuts.

        The batch composition is decided by whoever supplies ``cuts``; this
        method only sorts, transforms, loads and collates them.

        :param cuts: the cuts that make up this batch. Every supervision of
            every cut must expose a ``caption`` attribute holding either a
            string or an iterable of strings.
        :return: a dict with keys ``"inputs"`` (collated audio/features) and
            ``"supervisions"`` (a dict with ``"audio_caption"``, the entries of
            ``input_strategy.supervision_intervals(cuts)``, and ``"cut"`` when
            ``return_cuts`` is set).
        """
        self.hdf5_fix.update()

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
        # the supervision boundaries.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Sort the cuts again after transforms, as they may have changed durations.
        cuts = cuts.sort_by_duration(ascending=False)

        # Get a tensor with batched feature matrices, shape (B, T, F)
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault tolerant audio reading mode.
            # "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we successfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Apply all available transforms on the inputs, i.e. either audio or features.
        # This could be feature extraction, global MVN, SpecAugment, etc.
        # The interval tensors are stacked column-wise into one row per supervision.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)

        captions = [
            self._preprocess_caption(supervision.caption)
            for cut in cuts
            for supervision in cut.supervisions
        ]
        batch = {
            "inputs": inputs,
            "supervisions": {"audio_caption": captions},
        }
        # Update the 'supervisions' field with sequence_idx and start/num frames/samples
        batch["supervisions"].update(supervision_intervals)
        if self.return_cuts:
            # Repeat each cut once per supervision so the list lines up with
            # the per-supervision entries above.
            batch["supervisions"]["cut"] = [
                cut for cut in cuts for _ in cut.supervisions
            ]

        return batch
