import yaml
from functools import partial
from .icefall_data_module import *
from .icefall_data_module import _SeedWorkers
from lhotse import WhisperFbank, WhisperFbankConfig

def _filter_cut(cut, language=None, clip=False):
    """Filter predicate that also normalises the cut's supervisions in place.

    Besides returning a keep/drop decision, this function mutates ``cut``:
    only the first supervision is kept, and, when ``language`` is given,
    that supervision's ``language`` attribute is overwritten.

    Args:
      cut:
        Object exposing ``supervisions`` (a list) and ``duration``.
      language:
        Language tag to write onto the first supervision. ``None`` leaves
        the supervision's language untouched.
      clip:
        If True, keep only cuts whose duration lies in the closed range
        [1.0, 30.0] (presumably seconds).

    Returns:
      True if the cut should be kept, False otherwise. A cut without any
      supervision is left unmodified and judged on duration only.
    """
    supervisions = cut.supervisions
    # Guard against cuts without supervisions: indexing [0] on an empty list
    # would raise IndexError whenever a language tag is requested.
    if supervisions:
        first = supervisions[0]
        cut.supervisions = [first]
        if language is not None:
            first.language = language
    if clip:
        return 1.0 <= cut.duration <= 30.0
    return True

def _text_normalization_on_cut(c):
    """Normalise the transcript of the cut's first supervision in place.

    The text is passed through ``text_normalization`` with lower-casing
    enabled and with CJK spacing and diacritics removal disabled.

    Args:
      c:
        Object exposing ``supervisions``; only the first entry is touched.

    Returns:
      The same cut object, so the function can be used with ``CutSet.map``.
      A cut without supervisions is returned unchanged.
    """
    # A cut with no supervisions has no text to normalise; indexing [0]
    # would raise IndexError, so pass it through untouched.
    if not c.supervisions:
        return c
    supervision = c.supervisions[0]
    supervision.text = text_normalization(
        supervision.text,
        case='lower',
        space_between_cjk=False,
        remove_diacritics=False,
    )
    return c


class WhisperAsrDatamodule(IcefallAsrDatamodule):
    """ASR data module that builds cut sets and dataloaders for Whisper fbank.

    Training and validation manifests are listed in the YAML files referenced
    by ``cfg.data.train_data_config`` and ``cfg.data.valid_data_config``.
    When ``cfg.data.on_the_fly_feats`` is set, features are computed on the
    fly with ``WhisperFbank`` on CPU.
    """

    def _whisper_input_strategy(self):
        """Build the on-the-fly Whisper fbank input strategy.

        Shared by the train and dev datasets so both extract features with
        the same configuration (``cfg.data.num_filters`` filters, on CPU).
        """
        fbank_config = WhisperFbankConfig(
            num_filters=self.cfg.data.num_filters, device="cpu"
        )
        return OnTheFlyFeatures(WhisperFbank(fbank_config))

    def _get_train_cuts(self) -> CutSet:
        """Load, weight and mix the training cut sets.

        Each entry of the YAML list in ``cfg.data.train_data_config`` must
        provide the keys ``manifest``, ``hours``, ``weights`` and
        ``language``.

        Returns:
          A single CutSet. With more than one manifest the sets are muxed
          with weights proportional to ``weights * hours``. Text
          normalization is applied when ``cfg.data.text_normalization`` is
          set.
        """
        with open(self.cfg.data.train_data_config, 'r') as file:
            # NOTE(review): yaml.safe_load is the recommended loader if this
            # file can ever come from an untrusted source.
            train_data_config = yaml.load(file, Loader=yaml.FullLoader)

        cutset_list = []
        cutset_hours = []
        for train_set in train_data_config:
            logging.info(f"Getting {train_set['manifest']} cuts")
            cutset = CutSet.from_file(train_set['manifest'])
            hours = train_set['hours']
            weight = train_set['weights']
            cutset = cutset.repeat(weight)  # repeat k times
            # Fail early if the audio behind this manifest is not reachable.
            cutset[0].load_audio()
            cutset_hours.append(weight * hours)

            # _filter_cut keeps only the first supervision, tags the
            # language and drops cuts outside the duration range.
            cutset = cutset.filter(
                partial(_filter_cut, language=train_set['language'], clip=True)
            )
            cutset_list.append(cutset)

        logging.info(f"Getting totally {sum(cutset_hours)} hours of training data from {len(cutset_hours)} manifests")

        if len(cutset_list) > 1:  # more than 1 dataset
            logging.info("Mixing cuts")
            cutset_train = CutSet.mux(
                *cutset_list,
                weights=cutset_hours,
                # The epoch stops when one of the datasets is exhausted.
                # With the default stop_early=False the last batches would
                # get extremely unbalanced at the end.
                stop_early=True,
            )
        else:
            cutset_train = cutset_list[0]

        if self.cfg.data.text_normalization:
            cutset_train = cutset_train.map(_text_normalization_on_cut)

        return cutset_train

    def _get_valid_cuts(self):
        """Load the validation cut sets, one per manifest.

        Each entry of the YAML list in ``cfg.data.valid_data_config`` must
        provide the keys ``manifest`` and ``language``.

        Returns:
          A list of CutSets, in the order of the YAML entries. Every set is
          optionally text-normalized, reduced to the first supervision per
          cut, tagged with the configured language and restricted to cuts
          with a duration in [1.0, 30.0].
        """
        with open(self.cfg.data.valid_data_config, 'r') as file:
            # NOTE(review): yaml.safe_load is the recommended loader if this
            # file can ever come from an untrusted source.
            valid_data_config = yaml.load(file, Loader=yaml.FullLoader)

        cutset_list = []
        for valid_set in valid_data_config:
            logging.info(f"Getting {valid_set['manifest']} cuts")
            cutset = CutSet.from_file(valid_set['manifest'])
            # Fail early if the audio behind this manifest is not reachable.
            cutset[0].load_audio()

            if self.cfg.data.text_normalization:
                cutset = cutset.map(_text_normalization_on_cut)

            # clip=True applies the same 1.0..30.0 duration bounds as the
            # training path in a single pass. A module-level predicate
            # wrapped in partial stays picklable, unlike a lambda.
            cutset = cutset.filter(
                partial(_filter_cut, language=valid_set['language'], clip=True)
            )
            cutset_list.append(cutset)

        return cutset_list

    def _get_train_dl(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """Create the training dataloader.

        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.

        Returns:
          A DataLoader that yields batches produced by
          K2SpeechRecognitionDataset; batching is done by the lhotse sampler
          (``batch_size=None``).
        """
        transforms = []
        if getattr(self.cfg.data, "enable_musan", False):
            logging.info("Enable MUSAN")
            logging.info("About to get Musan cuts")
            cuts_musan = load_manifest(self.cfg.data.musan)
            transforms.append(
                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
            )
        else:
            logging.info("Disable MUSAN")

        # if self.cfg.data.concatenate_cuts:
        #     logging.info(
        #         f"Using cut concatenation with duration factor "
        #         f"{1.0} and gap {1.0}."
        #     )
        #     # Cut concatenation should be the first transform in the list,
        #     # so that if we e.g. mix noise in, it will fill the gaps between
        #     # different utterances.
        #     transforms = [
        #         CutConcatenate(
        #             duration_factor=1.0, gap=1.0
        #         )
        #     ] + transforms

        input_transforms = []
        if self.cfg.data.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info("Time warp factor: 80")
            # Set the value of num_frame_masks according to Lhotse's version.
            # In different Lhotse's versions, the default of num_frame_masks is
            # different.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=80,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        if not self.cfg.data.on_the_fly_feats:
            train = K2SpeechRecognitionDataset(
                # NOTE(review): eval() executes the configured string as
                # Python code; only safe while the config is fully trusted.
                input_strategy=eval(self.cfg.data.input_strategy)(),
                cut_transforms=transforms,
                input_transforms=input_transforms,
                return_cuts=True,
            )
        else:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            if getattr(self.cfg.data, "enable_speed_perturb", False):
                transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
            train = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=self._whisper_input_strategy(),
                input_transforms=input_transforms,
                return_cuts=True,
            )

        if self.cfg.data.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.cfg.data.max_duration,
                shuffle=self.cfg.data.shuffle,
                num_buckets=self.cfg.data.num_buckets,
                buffer_size=self.cfg.data.num_buckets * 2000,
                shuffle_buffer_size=self.cfg.data.num_buckets * 5000,
                drop_last=self.cfg.data.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.cfg.data.max_duration,
                shuffle=self.cfg.data.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        num_workers = self.cfg.data.num_workers
        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=num_workers,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0, so only enable it with real worker processes.
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def _get_valid_dl(self, valid_dataset):
        """Create the validation dataloader for one validation cut set.

        Args:
          valid_dataset:
            CutSet to validate on; it is handed to a non-shuffling
            DynamicBucketingSampler.

        Returns:
          A DataLoader with 8 workers that yields batches produced by
          K2SpeechRecognitionDataset, without any augmentation.
        """
        transforms = []

        logging.info("About to create dev dataset")
        if self.cfg.data.on_the_fly_feats:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=self._whisper_input_strategy(),
                return_cuts=True,
            )
        else:
            # No input_strategy is passed here, so the dataset's default is
            # used (unlike the train path, which reads cfg.data.input_strategy).
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                return_cuts=True,
            )

        valid_sampler = DynamicBucketingSampler(
            valid_dataset,
            max_duration=self.cfg.data.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=8,
            persistent_workers=False,
        )

        return valid_dl
