import yaml
import argparse
import inspect
import logging
from pathlib import Path
from typing import Any, Dict, Optional
from functools import partial, lru_cache
from collections import defaultdict

import torch
from lhotse import (
    CutSet, 
    Fbank, 
    FbankConfig, 
    load_manifest, 
    load_manifest_lazy, 
    set_audio_duration_mismatch_tolerance,
    WhisperFbank,
    WhisperFbankConfig
)
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
    PerturbSpeed
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from ..utils.text_normalization import text_normalization
from .dataset.speech2text_dataset import Speech2TextDataset


def unified_language_code(c, lang):
    """Tag the first supervision of cut ``c`` with language ``lang``.

    When ``lang`` is ``None`` the cut is returned untouched, so whatever
    language the manifest already carries is kept. The cut is modified in
    place and also returned, which makes the function usable as a ``map``
    callback.
    """
    if lang is None:
        # Nothing to override: keep the language stored in the cut itself.
        return c
    first_supervision = c.supervisions[0]
    first_supervision.language = lang
    return c

def unified_sampling_rate(c, sampling_rate=16000):
    """Return cut ``c`` at ``sampling_rate`` Hz, resampling only when needed.

    Args:
      c:
        An object exposing ``sampling_rate`` and ``resample(rate)``.
      sampling_rate:
        Target sampling rate in Hz. Defaults to 16000, so single-argument
        callers (e.g. ``cutset.map(unified_sampling_rate)``) keep the
        previous behaviour.

    Returns:
      ``c`` itself when it already has the target rate, otherwise the
      result of ``c.resample(sampling_rate)``.
    """
    if c.sampling_rate == sampling_rate:
        return c
    return c.resample(sampling_rate)

def unified_text_normalize(text):
    """Normalize ``text`` with the project-wide fixed normalization options.

    Delegates to ``text_normalization`` with lower-casing enabled, no space
    inserted between CJK characters, diacritics kept and symbols removed.
    """
    return text_normalization(
        text,
        case='lower',
        space_between_cjk=False,
        remove_diacritics=False,
        remove_symbols=True,
    )

def text_normalization_on_cut(c):
    """Normalize the transcript (and any translations) of cut ``c`` in place.

    Only the first supervision's ``text`` is normalized. If the cut exposes
    a non-``None`` ``translation`` attribute, the ``'text'`` entry of every
    item in it is normalized as well. The same cut object is returned.
    """
    first_supervision = c.supervisions[0]
    first_supervision.text = unified_text_normalize(first_supervision.text)

    # deal with translation too
    translations = getattr(c, 'translation', None)
    if translations is None:
        return c
    for entry in translations:
        entry['text'] = unified_text_normalize(entry['text'])
    return c

def remove_short_and_long_utt(c):
    """Filter predicate: decide whether cut ``c`` is kept.

    A cut is rejected (``False``) when its duration is outside
    [1.0, 30.0] seconds, or when the text of its first supervision is empty
    or longer than ``duration * 30`` characters.

    Side effects on a kept cut:
      * if the cut has several supervisions, only the first one is retained;
      * if the cut has a non-``None`` ``translation`` attribute, entries whose
        ``'text'`` is empty or longer than ``duration * 30`` characters are
        removed from that list in place.

    This should run after text normalization, since normalization can change
    the text length.
    """
    # Keep only utterances with duration between 1 second and 30 seconds
    if c.duration < 1.0 or c.duration > 30.0:
        return False
    # some audio has multiple supervisions, just take the first one
    if len(c.supervisions) > 1:
        c.supervisions = [c.supervisions[0]]

    max_chars = c.duration * 30

    def _has_valid_length(text):
        # Non-empty and not abnormally long for the audio duration.
        return 0 < len(text) <= max_chars

    if not _has_valid_length(c.supervisions[0].text):
        return False

    # deal with translation too
    translations = getattr(c, 'translation', None)
    if translations is not None:
        # Rebuild the list in place instead of deleting by index while
        # looping over the original length: that skipped entries and raised
        # IndexError as soon as one entry had been removed.
        translations[:] = [
            entry for entry in translations if _has_valid_length(entry['text'])
        ]
    return True


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class AsrDatamodule:
    """Build train/validation cut sets and dataloaders from a config object.

    On construction the training and validation cuts are loaded from the YAML
    files named by ``cfg.data.train_data_config`` and
    ``cfg.data.valid_data_config``, and the dataloaders are created and stored
    as ``train_dl`` (a single ``DataLoader``) and ``valid_dl`` (a list with one
    ``DataLoader`` per validation manifest).
    """

    def __init__(self, cfg):
        # NOTE: some data contains minor inconsistency
        set_audio_duration_mismatch_tolerance(0.1)

        self.cfg = cfg
        # Speech2TextDataset is opt-in; the default is the plain ASR dataset.
        self.dataset_class = K2SpeechRecognitionDataset
        if self.cfg.data.get("enable_s2t_dataset", False):
            self.dataset_class = Speech2TextDataset
        self.train_datasets = self._get_train_cuts()
        self.valid_datasets = self._get_valid_cuts()

        self.train_dl = self._get_train_dl(self.train_datasets)
        self.valid_dl = [self._get_valid_dl(valid_set) for valid_set in self.valid_datasets]

    @staticmethod
    def _get_mux_cuts_from_yaml(yaml_file, use_infinite_dataset=False):
        """Load every training manifest listed in ``yaml_file`` and mix them.

        Each entry of the YAML list must provide ``manifest`` and ``hours``;
        ``weights`` (default 1) and ``lang`` (default ``'zh'``) are optional.
        Every cut set gets its language tag and sampling rate unified before
        mixing. With more than one manifest the cut sets are multiplexed with
        weights ``weights * hours``; with exactly one it is returned directly.

        Args:
          yaml_file:
            Path of the YAML file listing the training manifests.
          use_infinite_dataset:
            If True, each cut set is repeated without a limit; otherwise it
            is repeated ``weights`` times.

        Raises:
          ValueError: if the YAML file lists no manifests.
        """
        with open(yaml_file, 'r') as file:
            # NOTE(review): FullLoader can still build some Python objects on
            # older PyYAML releases; prefer yaml.safe_load if the config may
            # come from an untrusted source.
            train_data_config = yaml.load(file, Loader=yaml.FullLoader)

        if not train_data_config:
            # Fail with a clear message instead of an opaque IndexError /
            # TypeError further down.
            raise ValueError(f"No training manifests listed in {yaml_file}")

        cutset_list = []
        cutset_hours = []
        langs_hours = defaultdict(int)

        for train_set in train_data_config:
            logging.info(f"Getting {train_set['manifest']} cuts")
            cutset = CutSet.from_file(train_set['manifest'])
            hours = train_set['hours']
            weight = train_set.get('weights', 1)
            lang = train_set.get('lang', 'zh')
            if use_infinite_dataset:
                # No repeat count: the resulting iterator never ends.
                cutset = cutset.repeat()
            else:
                # Finite: the cut set is repeated `weight` times.
                cutset = cutset.repeat(weight)
            cutset[0].load_audio() # just to make sure we can get access to this cutset audio
            langs_hours[lang] += weight * hours
            cutset_hours.append(weight * hours)
            cutset = cutset.map(partial(unified_language_code, lang=lang))
            cutset = cutset.map(unified_sampling_rate)
            cutset_list.append(cutset)

        for lang in langs_hours:
            logging.info(f"Getting {langs_hours[lang]} hours of training data from {lang} language")
        logging.info(f"Getting totally {sum(cutset_hours)} hours of training data from {len(cutset_hours)} manifests")

        if len(cutset_list) > 1: # more than 1 dataset
            logging.info("Mixing cuts")
            cutset_train = CutSet.mux(
                    *cutset_list,
                    weights=cutset_hours,
                    stop_early=True, # the epoch will stop when one of the dataset iterator is exhausted.
                )
        else:
            cutset_train = cutset_list[0]

        return cutset_train

    def _get_train_cuts(self) -> CutSet:
        """Return the mixed training cuts, optionally normalized, then filtered."""
        cutset_train = self._get_mux_cuts_from_yaml(self.cfg.data.train_data_config, self.cfg.data.use_infinite_dataset)

        if self.cfg.data.text_normalization:
            cutset_train = cutset_train.map(text_normalization_on_cut)

        # Filtering comes after normalization because it checks text length.
        cutset_train = cutset_train.filter(remove_short_and_long_utt)

        return cutset_train

    def _get_valid_cuts(self):
        """Return a list with one prepared cut set per validation manifest.

        Each entry of the YAML list named by ``cfg.data.valid_data_config``
        must provide ``manifest``; ``lang`` is optional (default ``'zh'``).
        The same language / sampling-rate / normalization / filtering steps
        as for training are applied to every cut set.
        """
        with open(self.cfg.data.valid_data_config, 'r') as file:
            # NOTE(review): see _get_mux_cuts_from_yaml about FullLoader.
            valid_data_config = yaml.load(file, Loader=yaml.FullLoader)

        cutset_list = []
        for valid_set in valid_data_config:
            logging.info(f"Getting {valid_set['manifest']} cuts")
            cutset = CutSet.from_file(valid_set['manifest'])
            lang = valid_set.get('lang', 'zh')
            cutset = cutset.map(partial(unified_language_code, lang=lang))
            cutset = cutset.map(unified_sampling_rate)
            if self.cfg.data.text_normalization:
                cutset = cutset.map(text_normalization_on_cut)
            cutset = cutset.filter(remove_short_and_long_utt)
            cutset_list.append(cutset)

        return cutset_list

    def _get_on_the_fly_input_strategy(self):
        """Return the on-the-fly feature extractor selected by the config.

        Uses 80-filter Whisper fbank when ``cfg.data.whisper_fbank`` is set,
        otherwise 80-bin fbank. Shared by the train and validation loaders so
        that both always extract features the same way.
        """
        if self.cfg.data.get("whisper_fbank", False):
            return OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80)))
        return OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))

    def _get_train_dl(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """Create the training dataloader.

        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.

        Returns:
          A ``DataLoader`` over ``self.dataset_class`` driven by either a
          ``DynamicBucketingSampler`` or a ``SimpleCutSampler``, depending on
          ``cfg.data.bucketing_sampler``.
        """
        # NOTE: cut concatenation is currently not applied; if it is ever
        # enabled, CutConcatenate should be the first cut transform.
        transforms = []

        input_transforms = []
        if self.cfg.data.enable_spec_aug:
            time_warp_factor = 80
            logging.info("Enable SpecAugment")
            logging.info(f"Time warp factor: {time_warp_factor}")
            # Set the value of num_frame_masks according to Lhotse's version.
            # In different Lhotse's versions, the default of num_frame_masks is
            # different.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=time_warp_factor,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        if not self.cfg.data.on_the_fly_feats:
            # NOTE(review): eval() runs whatever expression the config holds
            # (expected to name an input-strategy class such as
            # PrecomputedFeatures). Only use with trusted configs; consider
            # replacing with an explicit name -> class lookup.
            input_strategy = eval(self.cfg.data.input_strategy)()
        else:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            if self.cfg.data.enable_speed_perturb:
                transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
            input_strategy = self._get_on_the_fly_input_strategy()

        train = self.dataset_class(
            cut_transforms=transforms,
            input_strategy=input_strategy,
            input_transforms=input_transforms,
            return_cuts=True,
        )

        if self.dataset_class is Speech2TextDataset:
            train.s2t_translate_ratio = self.cfg.data.s2t_translate_ratio

        if self.cfg.data.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.cfg.data.max_duration,
                shuffle=self.cfg.data.shuffle,
                num_buckets=self.cfg.data.num_buckets,
                buffer_size=self.cfg.data.num_buckets * 2000,
                shuffle_buffer_size=self.cfg.data.num_buckets * 5000,
                drop_last=self.cfg.data.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.cfg.data.max_duration,
                shuffle=self.cfg.data.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            # The sampler already yields whole batches, so DataLoader
            # batching is turned off.
            batch_size=None,
            num_workers=self.cfg.data.num_workers,
            persistent_workers=bool(self.cfg.data.num_workers > 0),
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def _get_valid_dl(self, valid_dataset):
        """Create the dataloader for one validation cut set.

        Features are always computed on the fly, no augmentation is applied
        and the sampler does not shuffle.

        Args:
          valid_dataset:
            CutSet of one validation manifest.
        """
        transforms = []

        logging.info("About to create dev dataset")
        validate = self.dataset_class(
            cut_transforms=transforms,
            input_strategy=self._get_on_the_fly_input_strategy(),
            return_cuts=True,
        )
        # enforce AST for validation
        if self.dataset_class is Speech2TextDataset:
            validate.s2t_translate_ratio = 1.0

        valid_sampler = DynamicBucketingSampler(
            valid_dataset,
            max_duration=self.cfg.data.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            # NOTE(review): fixed worker count, independent of
            # cfg.data.num_workers used for training.
            num_workers=8,
            persistent_workers=False,
        )

        return valid_dl
