import argparse
import inspect
import logging
from pathlib import Path
from typing import Any, Dict, Optional
from functools import lru_cache
import os

import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    WeightedSimpleCutSampler,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
    PerturbSpeed,
)
from .dataset.sound_event_detection_dataset import SoundEventDetectionDataset
from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
    OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
import torch.distributed as dist

class _SeedWorkers:
    """Callable handed to ``DataLoader`` as ``worker_init_fn``.

    Each worker is seeded with the base seed offset by its worker id, so
    different workers do not share the same random state.
    """

    def __init__(self, seed: int):
        # Base seed; the per-worker seed is derived from it in __call__.
        self.seed = seed

    def __call__(self, worker_id: int):
        worker_seed = self.seed + worker_id
        fix_random_seed(worker_seed)


class SoundEventDetectionDatamodule:
    def __init__(self, cfg):
        self.cfg = cfg
        self.train_datasets = self._get_train_cuts()
        self.valid_datasets = self._get_valid_cuts()
        
        self.train_dl = self._get_train_dl(self.train_datasets)
        self.valid_dl = [self._get_valid_dl(valid_set) for valid_set in self.valid_datasets] # could be several valid sets
    
    def _get_train_cuts(self) -> CutSet:
        """Build the training cuts described by ``cfg.data.train_data_config``.

        The YAML file is read as a sequence of entries, each providing the
        keys ``manifest``, ``hours`` and ``weights``. Every manifest is
        loaded, resampled to 16 kHz and repeated (endlessly when
        ``cfg.data.use_infinite_dataset`` is set, otherwise ``weights``
        times). Cuts shorter than one second are filtered out of the result.

        Returns:
          A single CutSet when there is only one manifest, or when several
          manifests are mixed with ``CutSet.mux`` (``cfg.data.mix_dataset``,
          enabled when the option is absent). Otherwise a list holding one
          CutSet per manifest.

        Raises:
          ValueError: if the config file lists no training manifests.
        """
        import yaml

        config_path = self.cfg.data.train_data_config
        # NOTE(review): yaml.load with FullLoader is kept as-is; if this
        # config can ever come from an untrusted source, prefer
        # yaml.safe_load.
        with open(config_path, 'r') as file:
            train_data_config = yaml.load(file, Loader=yaml.FullLoader)

        # An empty file parses to None and an empty list has no first entry;
        # fail with a clear message instead of a TypeError/IndexError below.
        if not train_data_config:
            raise ValueError(f"No training manifests listed in {config_path}")

        cutset_list = []
        cutset_hours = []
        for train_set in train_data_config:
            logging.info(f"Getting {train_set['manifest']} cuts")
            cutset = CutSet.from_file(train_set['manifest']).resample(16000)
            hours = train_set['hours']
            weight = train_set['weights']
            if self.cfg.data.use_infinite_dataset:
                # Infinite iterator: this cut set never runs out.
                cutset = cutset.repeat()
            else:
                # Finite: the cut set is repeated `weight` times.
                cutset = cutset.repeat(weight)
            # Just to make sure we can get access to this cutset's audio.
            cutset[0].load_audio()
            cutset_hours.append(weight * hours)
            cutset_list.append(cutset)

        logging.info(f"Getting totally {sum(cutset_hours)} hours of training data from {len(cutset_hours)} manifests")

        if len(cutset_list) == 1:
            cutset_train = cutset_list[0]
        elif getattr(self.cfg.data, "mix_dataset", True):
            # Mix all datasets, sampling each in proportion to its weighted
            # hours.
            logging.info("Mixing cuts")
            cutset_train = CutSet.mux(
                *cutset_list,
                weights=cutset_hours,
                # The epoch stops when one of the datasets is exhausted.
                # With the default stop_early=False the last batches would
                # get extremely unbalanced in the end.
                stop_early=True,
            )
        else:
            # Treat each dataset separately in the lhotse dataloader.
            logging.info("Using multiple datasets separately")
            cutset_train = cutset_list

        def remove_short_utt(c):
            # Remove buggy short utterances (under one second).
            return not (c.duration < 1.0)

        if isinstance(cutset_train, CutSet):
            cutset_train = cutset_train.filter(remove_short_utt)
        elif isinstance(cutset_train, (tuple, list)):
            cutset_train = [cutset.filter(remove_short_utt) for cutset in cutset_train]
        else:
            raise ValueError(f"Unknown type of cutset_train: {type(cutset_train)}")

        return cutset_train
    
    def _get_valid_cuts(self):
        """Load one lazily-read CutSet per entry of ``cfg.data.valid_sets``.

        Every manifest is resampled to 16 kHz and stripped of cuts shorter
        than one second. When ``cfg.data.valid_data_path_prefix`` is truthy,
        the directory containing the manifest is applied as the features
        path prefix.

        Returns:
          A list of CutSets, in the same order as ``cfg.data.valid_sets``.
        """

        def keep_cut(cut):
            # Utterances under one second are buggy; keep everything else.
            return not (cut.duration < 1.0)

        valid_cutsets = []
        for dataset in self.cfg.data.valid_sets:
            logging.info(f"Getting {dataset} cuts for validation")
            valid_cuts = load_manifest_lazy(dataset).resample(16000)
            if self.cfg.data.valid_data_path_prefix:
                manifest_dir = os.path.dirname(dataset)
                valid_cuts = valid_cuts.with_features_path_prefix(manifest_dir)
            valid_cutsets.append(valid_cuts.filter(keep_cut))
        return valid_cutsets

    def _get_train_dl(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """Create the training dataloader.

        Args:
          cuts_train:
            CutSet for training, or a list/tuple of CutSets when several
            training sets are handed to the sampler separately.
          sampler_state_dict:
            The state dict for the training sampler. When given, it is
            loaded into the sampler before the dataloader is built.

        Returns:
          A DataLoader over ``SoundEventDetectionDataset`` driven by the
          sampler selected in ``cfg.data`` (``bucketing_sampler`` first,
          then ``weighted_sampler``, otherwise a plain SimpleCutSampler).

        Raises:
          ValueError: if ``weighted_sampler`` is selected while more than
            one training CutSet was passed.
        """
        data_cfg = self.cfg.data

        # Cut-level augmentation: optionally mix in MUSAN noise.
        transforms = []
        if data_cfg.enable_musan:
            logging.info("Enable MUSAN")
            logging.info("About to get Musan cuts")
            cuts_musan = load_manifest(data_cfg.musan)
            transforms.append(
                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
            )
        else:
            logging.info("Disable MUSAN")

        # Feature-level augmentation: optionally apply SpecAugment.
        input_transforms = []
        if data_cfg.enable_spec_aug:
            time_warp_factor = 80
            logging.info("Enable SpecAugment")
            logging.info(f"Time warp factor: {time_warp_factor}")
            # Set the value of num_frame_masks according to Lhotse's version.
            # In different Lhotse's versions, the default of num_frame_masks is
            # different.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=time_warp_factor,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        if not data_cfg.on_the_fly_feats:
            # NOTE(review): eval() runs whatever expression the config holds,
            # so this is only safe while the config is trusted. The name must
            # resolve in this module (e.g. PrecomputedFeatures).
            train = SoundEventDetectionDataset(
                input_strategy=eval(data_cfg.input_strategy)(),
                cut_transforms=transforms,
                input_transforms=input_transforms,
                return_cuts=True,
            )
        else:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            if data_cfg.enable_speed_perturb:
                transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
            # Drop feats to be on the safe side.
            train = SoundEventDetectionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                input_transforms=input_transforms,
                return_cuts=True,
            )

        # Handle the case when we have multiple datasets: the samplers take
        # the cut sets as positional arguments.
        if isinstance(cuts_train, (tuple, list)):
            cuts = tuple(cuts_train)
        else:
            cuts = (cuts_train,)

        if data_cfg.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                *cuts,
                max_duration=data_cfg.max_duration,
                shuffle=data_cfg.shuffle,
                num_buckets=data_cfg.num_buckets,
                buffer_size=data_cfg.num_buckets * 2000,
                shuffle_buffer_size=data_cfg.num_buckets * 5000,
                drop_last=data_cfg.drop_last,
                consistent_ids=False,  # do not use consistent ids for bucketing sampler
            )
        elif data_cfg.weighted_sampler:
            logging.info("Using weighted SimpleCutSampler")
            # The weighted sampler takes exactly one cut set followed by its
            # per-cut weights; several cut sets would shift the positional
            # arguments.
            if len(cuts) != 1:
                raise ValueError(
                    "weighted_sampler supports a single training CutSet, "
                    f"got {len(cuts)}"
                )
            weights = self.audioset_sampling_weights()
            # Settings come from cfg.data like everywhere else in this class;
            # there is no `self.args` on this datamodule.
            train_sampler = WeightedSimpleCutSampler(
                cuts[0],
                weights,
                num_samples=data_cfg.num_samples,
                max_duration=data_cfg.max_duration,
                shuffle=False,  # do not support shuffle
                drop_last=data_cfg.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                *cuts,
                max_duration=data_cfg.max_duration,
                shuffle=data_cfg.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        num_workers = data_cfg.num_workers
        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=num_workers,
            # DataLoader rejects persistent_workers=True when there are no
            # worker processes, so only enable it when workers exist.
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def _get_valid_dl(self, valid_dataset):
        """Create the dataloader for one validation CutSet.

        Args:
          valid_dataset:
            The validation cuts to draw batches from.

        Returns:
          A DataLoader using an unshuffled DynamicBucketingSampler limited by
          ``cfg.data.max_duration``, with 8 non-persistent workers.
        """
        logging.info("About to create dev dataset")

        # No cut transforms are applied at validation time. Features are
        # computed on the fly only when configured; otherwise the dataset's
        # default input strategy is used.
        dataset_kwargs = {"cut_transforms": [], "return_cuts": True}
        if self.cfg.data.on_the_fly_feats:
            dataset_kwargs["input_strategy"] = OnTheFlyFeatures(
                Fbank(FbankConfig(num_mel_bins=80))
            )
        dev_set = SoundEventDetectionDataset(**dataset_kwargs)

        dev_sampler = DynamicBucketingSampler(
            valid_dataset,
            max_duration=self.cfg.data.max_duration,
            shuffle=False,
        )

        logging.info("About to create dev dataloader")
        return DataLoader(
            dev_set,
            sampler=dev_sampler,
            batch_size=None,
            num_workers=8,
            persistent_workers=False,
        )
    
    def audioset_sampling_weights(self):
        logging.info(
            f"About to get the sampling weight for {audioset_subset} in AudioSet"
        )
        weights = []
        with open(
            self.cfg.data.class_weights_file,
            "r",
        ) as f:
            while True:
                line = f.readline()
                if not line:
                    break
                weight = float(line.split()[1])
                weights.append(weight)
        logging.info(f"Get the sampling weight for {len(weights)} cuts")
        return weights