import argparse
import inspect
import logging
from pathlib import Path
from typing import Any, Dict, Optional
from functools import lru_cache
import os

import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    WeightedSimpleCutSampler,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
    PerturbSpeed,
)
from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
    OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader


from .dataset.audio_tag_dataset import AudioTaggingDataset
from .audio_tag_data_module import AudioTagDatamodule



class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)



from torch.utils.data import Sampler
from collections import defaultdict
import random
import math

class GE2ECutSampler(Sampler):
    """Sampler yielding GE2E-style batches: N speakers x M utterances each.

    Cuts are grouped by the ``speaker`` field of their first supervision.
    Every step of iteration yields one ``CutSet`` that contains
    ``utts_per_speaker`` cuts (drawn without replacement) for each of up to
    ``speakers_per_batch`` speakers, ordered so that all cuts of a speaker
    are adjacent.

    Speakers that own fewer than ``utts_per_speaker`` cuts can never fill
    their slot, so they are excluded once, at construction time. This keeps
    ``len(sampler)`` equal to the number of batches ``__iter__`` yields and
    avoids discarding a whole batch because of a single short speaker.

    NOTE: shuffling and utterance selection use the global ``random`` module
    state. ``set_epoch`` only records the epoch number (it is saved by
    ``state_dict``); it does not reseed anything.
    """

    def __init__(
        self,
        cuts: "CutSet",
        speakers_per_batch: int,
        utts_per_speaker: int,
        drop_last: bool = True,
        shuffle: bool = True,
    ):
        """
        Args:
          cuts:
            The cuts to sample from. They are iterated once and every cut is
            kept in memory, grouped by ``cut.supervisions[0].speaker``.
          speakers_per_batch:
            Number of speakers in a full batch (N). Must be positive.
          utts_per_speaker:
            Number of cuts drawn per speaker in a batch (M). Must be positive.
          drop_last:
            If True, a trailing batch with fewer than ``speakers_per_batch``
            speakers is not yielded.
          shuffle:
            If True, the speaker order is shuffled at the start of every
            iteration.

        Raises:
          ValueError:
            If ``speakers_per_batch`` or ``utts_per_speaker`` is not positive.
        """
        if speakers_per_batch <= 0 or utts_per_speaker <= 0:
            raise ValueError(
                "speakers_per_batch and utts_per_speaker must be positive, "
                f"got {speakers_per_batch} and {utts_per_speaker}."
            )

        self.speakers_per_batch = speakers_per_batch
        self.utts_per_speaker = utts_per_speaker
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.epoch = 0

        # speaker id -> list of that speaker's cuts
        self.speaker_to_cuts = defaultdict(list)
        for cut in cuts:
            spk = cut.supervisions[0].speaker
            self.speaker_to_cuts[spk].append(cut)

        # All speakers seen in `cuts` (kept for backward compatibility).
        self.speakers = list(self.speaker_to_cuts.keys())
        # Only these speakers can supply `utts_per_speaker` distinct cuts;
        # batches are formed from this list exclusively.
        self.eligible_speakers = [
            spk
            for spk in self.speakers
            if len(self.speaker_to_cuts[spk]) >= utts_per_speaker
        ]

        num_excluded = len(self.speakers) - len(self.eligible_speakers)
        if num_excluded:
            logging.warning(
                "GE2ECutSampler: excluding %d of %d speakers with fewer than "
                "%d utterances.",
                num_excluded,
                len(self.speakers),
                utts_per_speaker,
            )

    def set_epoch(self, epoch: int):
        """Record the current epoch number (stored in ``state_dict``)."""
        self.epoch = epoch

    def _sample_batch(self, batch_spks):
        """Return ``utts_per_speaker`` random cuts per speaker as one list.

        The returned cuts are ordered by speaker id, so each speaker's cuts
        form one contiguous block.
        """
        picked = []
        for spk in batch_spks:
            utts = random.sample(self.speaker_to_cuts[spk], self.utts_per_speaker)
            picked.extend((spk, cut) for cut in utts)
        # Re-order by speaker; the sort is stable, so each speaker's cuts
        # stay together as a contiguous block.
        picked.sort(key=lambda pair: pair[0])
        return [cut for _, cut in picked]

    def __iter__(self):
        """Yield one ``CutSet`` per batch of eligible speakers."""
        speakers = list(self.eligible_speakers)
        if self.shuffle:
            random.shuffle(speakers)

        for start in range(0, len(speakers), self.speakers_per_batch):
            batch_spks = speakers[start:start + self.speakers_per_batch]
            if self.drop_last and len(batch_spks) < self.speakers_per_batch:
                # Only the trailing group can be short.
                continue
            yield CutSet.from_cuts(self._sample_batch(batch_spks))

    def __len__(self):
        """Return the number of batches one pass of ``__iter__`` yields."""
        num_speakers = len(self.eligible_speakers)
        if self.drop_last:
            return num_speakers // self.speakers_per_batch
        return math.ceil(num_speakers / self.speakers_per_batch)

    def state_dict(self):
        """Return sampler state for checkpoint saving."""
        return {
            "epoch": self.epoch,
            "shuffle": self.shuffle,
        }

    def load_state_dict(self, state):
        """Restore sampler state from checkpoint.

        Missing keys fall back to ``epoch=0`` and ``shuffle=True``.
        """
        self.epoch = state.get("epoch", 0)
        self.shuffle = state.get("shuffle", True)


# import math
# import random
# import logging
# from collections import defaultdict
# from typing import List

# from torch.utils.data import Sampler
# from lhotse import CutSet
# from lhotse.cut import Cut

# class GE2ECutSampler(Sampler):
#     """
#     A PyTorch Sampler specifically designed for GE2E loss training in speaker verification.
#     It yields batches containing a fixed number of speakers, each with a fixed number of utterances.
#     This version optimizes memory by only storing Cut IDs in __init__, and loading Cut objects
#     lazily during __iter__.
#     """
#     def __init__(
#         self,
#         cuts: CutSet, # This CutSet can be lazy (e.g., from .jsonl.gz)
#         speakers_per_batch: int,
#         utts_per_speaker: int,
#         drop_last: bool = True,
#         shuffle: bool = True,
#     ):
#         if not isinstance(cuts, CutSet):
#             raise TypeError(f"Expected cuts to be a Lhotse CutSet, but got {type(cuts)}")
#         if speakers_per_batch <= 0 or utts_per_speaker <= 0:
#             raise ValueError("speakers_per_batch and utts_per_speaker must be positive.")

#         self.speakers_per_batch = speakers_per_batch
#         self.utts_per_speaker = utts_per_speaker
#         self.drop_last = drop_last
#         self.shuffle = shuffle
#         self.epoch = 0

#         # Store the original CutSet for lazy lookup during iteration
#         self.original_cuts = cuts

#         # --- OPTIMIZED: Build speaker_to_cut_ids map instead of storing full Cut objects ---
#         self.speaker_to_cut_ids = defaultdict(list)
        
#         logging.info(f"GE2ECutSampler: Building speaker-to-cut-ID map ")
#         processed_count = 0
        
#         # Iterate over the CutSet once to build the map of speaker_id to cut_ids
#         # This will trigger lazy loading of Cut metadata as needed.
#         # It's important that speaker ID is directly accessible from Cut metadata without full audio load.
#         for cut in cuts:
#             spk = cut.supervisions[0].speaker
#             self.speaker_to_cut_ids[spk].append(cut.id) # Store cut.id instead of cut object
#             processed_count += 1
#             if processed_count % 100000 == 0: # Log progress for large datasets
#                 logging.info(f"GE2ECutSampler: Mapped {processed_count}")
        
#         logging.info(f"GE2ECutSampler: Finished mapping {processed_count} cuts. Found {len(self.speaker_to_cut_ids)} unique speakers.")

#         # Filter out speakers who don't have enough utterances
#         # This step helps prevent 'continue' in __iter__ for valid speaker batches
#         self.eligible_speakers = [
#             spk for spk in self.speaker_to_cut_ids
#             if len(self.speaker_to_cut_ids[spk]) >= self.utts_per_speaker
#         ]
#         logging.info(f"GE2ECutSampler: {len(self.eligible_speakers)} speakers are eligible (>= {self.utts_per_speaker} utts).")
        
#         # This will be the list of speakers from which we form batches
#         self.speakers_for_epoch = self.eligible_speakers.copy()


#     def set_epoch(self, epoch: int):
#         """Sets the epoch for the sampler, used for reproducible shuffling."""
#         self.epoch = epoch

#     def __iter__(self):
#         # Create a copy for shuffling for the current epoch
#         speakers_for_this_epoch = self.eligible_speakers.copy()
#         if self.shuffle:
#             # Use epoch to seed random for reproducibility
#             random.Random(self.epoch).shuffle(speakers_for_this_epoch)

#         num_batches = 0
#         for i in range(0, len(speakers_for_this_epoch), self.speakers_per_batch):
#             batch_spks = speakers_for_this_epoch[i:i + self.speakers_per_batch]

#             # Handle drop_last for the last incomplete batch of speakers
#             if len(batch_spks) < self.speakers_per_batch and self.drop_last:
#                 continue
            
#             # Since we pre-filtered eligible_speakers, we don't need to check here
#             # `if any(len(self.speaker_to_cuts[spk]) < self.utts_per_speaker for spk in batch_spks): continue`
#             # However, if self.utts_per_speaker changes dynamically, you might reconsider.

#             current_batch_cuts: List[Cut] = []
            
#             # --- Lazily load Cuts here for the current batch ---
#             for spk in batch_spks:
#                 # Select cut_ids first
#                 sampled_cut_ids = random.sample(self.speaker_to_cut_ids[spk], self.utts_per_speaker)
                
#                 # Retrieve Cut objects using find() from the original CutSet
#                 # This will trigger lazy loading of actual Cut data from manifest/disk
#                 retrieved_cuts = [self.original_cuts[cut_id] for cut_id in sampled_cut_ids]
#                 current_batch_cuts.extend(retrieved_cuts)
            
#             # Sort cuts within the batch to group by speaker ID
#             # This is typically required by GE2E Loss for efficient computation
#             current_batch_cuts.sort(key=lambda cut: cut.supervisions[0].speaker)

#             # Yield a CutSet for the current batch
#             yield CutSet.from_cuts(current_batch_cuts)
#             num_batches += 1

#         logging.info(f"GE2ECutSampler: Iteration completed for epoch {self.epoch}. Yielded {num_batches} batches.")


#     def __len__(self):
#         """Returns the total number of batches that will be yielded."""
#         num_speaker_groups = len(self.eligible_speakers) // self.speakers_per_batch
#         if not self.drop_last and len(self.eligible_speakers) % self.speakers_per_batch != 0:
#             num_speaker_groups += 1 # Add one for the last incomplete group

#         return num_speaker_groups

#     def state_dict(self):
#         """Return sampler state for checkpoint saving."""
#         return {
#             "epoch": self.epoch,
#             "shuffle": self.shuffle,
#             # For exact reproducibility of speaker order within an epoch,
#             # you might need to save the shuffled 'self.speakers_for_epoch' state.
#             # For now, just epoch and shuffle mode is enough.
#         }

#     def load_state_dict(self, state):
#         """Restore sampler state from checkpoint."""
#         self.epoch = state.get("epoch", 0)
#         self.shuffle = state.get("shuffle", True)
#         # Re-shuffle 'self.speakers_for_epoch' based on restored epoch if shuffle is True
#         if self.shuffle:
#             random.Random(self.epoch).shuffle(self.speakers_for_epoch)





class SpeakerVerificationDatamodule(AudioTagDatamodule):
    """Datamodule that builds GE2E-batched dataloaders for speaker verification.

    Both the training and the validation dataloader wrap an
    ``AudioTaggingDataset`` and draw their batches from a ``GE2ECutSampler``
    (``speakers_per_batch`` speakers x ``utts_per_speaker`` utterances).
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    # _get_train_cuts and _get_valid_cuts are the same as parent class

    def _get_train_dl(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """Build the training dataloader.

        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.

        Returns:
          A ``DataLoader`` over an ``AudioTaggingDataset`` whose batches are
          produced by a ``GE2ECutSampler``.

        Raises:
          ValueError:
            If ``cfg.data.ge2e_sampler`` is not enabled.
        """
        data_cfg = self.cfg.data
        label_field = getattr(data_cfg, "label_field", "audio_tag")

        # Cut-level augmentation: optionally mix in MUSAN noise.
        transforms = []
        if data_cfg.enable_musan:
            logging.info("Enable MUSAN")
            logging.info("About to get Musan cuts")
            cuts_musan = load_manifest(data_cfg.musan)
            transforms.append(
                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
            )
        else:
            logging.info("Disable MUSAN")

        # Feature-level augmentation: optionally apply SpecAugment.
        input_transforms = []
        if data_cfg.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info("Time warp factor: 80")
            # Set the value of num_frame_masks according to Lhotse's version.
            # In different Lhotse's versions, the default of num_frame_masks is
            # different.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=80,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        if not data_cfg.on_the_fly_feats:
            # NOTE(review): eval() executes whatever string the config holds
            # and can only resolve names visible in this module. It is safe
            # only for trusted configs; consider replacing it with an explicit
            # name -> class mapping.
            train = AudioTaggingDataset(
                input_strategy=eval(data_cfg.input_strategy)(),
                cut_transforms=transforms,
                input_transforms=input_transforms,
                return_cuts=True,
                label_field=label_field,
            )
        else:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            if data_cfg.enable_speed_perturb:
                transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
            # Compute 80-bin Fbank features on the fly.
            train = AudioTaggingDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                input_transforms=input_transforms,
                return_cuts=True,
                label_field=label_field,
            )

        if not data_cfg.ge2e_sampler:
            raise ValueError(
                "GE2E Loss requires GE2ESampler (N speakers * M utterances per batch). "
                "Please set 'cfg.data.ge2e_sampler' to True."
            )

        logging.info("Using GE2ESampler.")
        train_sampler = GE2ECutSampler(
            cuts_train,
            speakers_per_batch=data_cfg.speakers_per_batch,
            utts_per_speaker=data_cfg.utts_per_speaker,
            shuffle=data_cfg.shuffle,
            drop_last=data_cfg.drop_last,
        )

        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        num_workers = data_cfg.num_workers
        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            # The sampler already yields whole batches (one CutSet each), so
            # automatic batching is turned off.
            batch_size=None,
            num_workers=num_workers,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0, so only keep workers alive when there are any.
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def _get_valid_dl(self, valid_dataset) -> DataLoader:
        """Build a validation dataloader for one validation cut set.

        Args:
          valid_dataset:
            The validation cuts handed to ``GE2ECutSampler``.

        Returns:
          A ``DataLoader`` over an ``AudioTaggingDataset`` with no cut or
          input augmentation.
        """
        data_cfg = self.cfg.data
        label_field = getattr(data_cfg, "label_field", "audio_tag")
        transforms = []

        logging.info("About to create dev dataset")
        if data_cfg.on_the_fly_feats:
            validate = AudioTaggingDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                return_cuts=True,
                label_field=label_field,
            )
        else:
            # No input_strategy is passed here, so the dataset's own default
            # is used.
            validate = AudioTaggingDataset(
                cut_transforms=transforms,
                return_cuts=True,
                label_field=label_field,
            )

        # shuffle=False keeps the speaker order fixed; the sampler still picks
        # each speaker's utterances at random. drop_last is left at the
        # sampler's default.
        valid_sampler = GE2ECutSampler(
            valid_dataset,
            speakers_per_batch=data_cfg.speakers_per_batch,
            utts_per_speaker=data_cfg.utts_per_speaker,
            shuffle=False,
        )

        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=8,
            persistent_workers=False,
        )

        return valid_dl
    






# class SpkVeriDatamodule:
#     def __init__(self, cfg):
#         self.cfg = cfg
#         self.train_datasets = self._get_train_cuts()
#         self.valid_datasets = self._get_valid_cuts()
        
#         self.train_dl = self._get_train_dl(self.train_datasets)
#         self.valid_dl = [self._get_valid_dl(valid_set) for valid_set in self.valid_datasets] # could be several valid sets
       
#     def _get_train_cuts(self) -> CutSet:
#         import yaml
#         with open(self.cfg.data.train_data_config, 'r') as file:
#             train_data_config = yaml.load(file, Loader=yaml.FullLoader)
            
#         cutset_list = []
#         cutset_hours = []
#         for train_set in train_data_config:
#             logging.info(f"Getting {train_set['manifest']} cuts")
#             cutset = CutSet.from_file(train_set['manifest']).resample(16000)
#             hours = train_set['hours']
#             weight = train_set['weights']
#             if self.cfg.data.use_infinite_dataset:
#                 cutset = cutset.repeat() # this will result in infinite iterator that will never end
#             else:
#                 cutset = cutset.repeat(weight) # each of the cutset will be repeated infinitely to make sure the iterator will not stop by them
#             cutset[0].load_audio() # just to make sure we can get access to this cutset audio
#             cutset_hours.append(weight * hours)
#             cutset_list.append(cutset)
        
#         logging.info(f"Getting totally {sum(cutset_hours)} hours of training data from {len(cutset_hours)} manifests")
                
#         if len(cutset_list) > 1: # more than 1 dataset
#             logging.info("Mixing cuts")
#             cutset_train =  CutSet.mux(
#                     *cutset_list,
#                     weights=cutset_hours,
#                     stop_early=True, # the epoch will stop when one of the dataset is exhausted. If using default stop_early=False, the last batches will get extremly unbalanced in the end
#                 )
#         else:
#             cutset_train = cutset_list[0]
        
#         def remove_short_utt(c):
#             # remove buggy short utterance
#             if c.duration < 1.0:
#                 return False
#             return True
        
#         cutset_train = cutset_train.filter(remove_short_utt)
        
#         return cutset_train
    
#     def _get_valid_cuts(self):
#         cutset_list = []
#         for dataset in self.cfg.data.valid_sets:
#             logging.info(f"Getting {dataset} cuts for validation")
#             cuts = load_manifest_lazy(dataset)
#             cuts = cuts.resample(16000)
#             if self.cfg.data.valid_data_path_prefix:
#                 data_dir = os.path.dirname(dataset)
#                 cuts = cuts.with_features_path_prefix(data_dir)
                
#             def remove_short_utt(c):
#                 # remove buggy short utterance
#                 if c.duration < 1.0:
#                     return False
#                 return True
#             cuts = cuts.filter(remove_short_utt)
            
#             cutset_list.append(cuts)
#         return cutset_list
    
#     def _get_train_dl(
#         self,
#         cuts_train: CutSet,
#         sampler_state_dict: Optional[Dict[str, Any]] = None,
#     ) -> DataLoader:
#         """
#         Args:
#           cuts_train:
#             CutSet for training.
#           sampler_state_dict:
#             The state dict for the training sampler.
#         """
#         transforms = []
#         if self.cfg.data.enable_musan:
#             logging.info("Enable MUSAN")
#             logging.info("About to get Musan cuts")
#             cuts_musan = load_manifest(self.cfg.data.musan)
#             transforms.append(
#                 CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
#             )
#         else:
#             logging.info("Disable MUSAN")

#         input_transforms = []
#         if self.cfg.data.enable_spec_aug:
#             logging.info("Enable SpecAugment")
#             logging.info(f"Time warp factor: 80")
#             # Set the value of num_frame_masks according to Lhotse's version.
#             # In different Lhotse's versions, the default of num_frame_masks is
#             # different.
#             num_frame_masks = 10
#             num_frame_masks_parameter = inspect.signature(
#                 SpecAugment.__init__
#             ).parameters["num_frame_masks"]
#             if num_frame_masks_parameter.default == 1:
#                 num_frame_masks = 2
#             logging.info(f"Num frame mask: {num_frame_masks}")
#             input_transforms.append(
#                 SpecAugment(
#                     time_warp_factor=80,
#                     num_frame_masks=num_frame_masks,
#                     features_mask_size=27,
#                     num_feature_masks=2,
#                     frames_mask_size=100,
#                 )
#             )
#         else:
#             logging.info("Disable SpecAugment")

#         logging.info("About to create train dataset")
#         if not self.cfg.data.on_the_fly_feats:
#             train = AudioTaggingDataset(
#                 input_strategy=eval(self.cfg.data.input_strategy)(),
#                 cut_transforms=transforms,
#                 input_transforms=input_transforms,
#                 return_cuts=True,
#                 label_field=getattr(self.cfg.data, "label_field", "audio_tag"),
#             )

#         else:
#             # NOTE: the PerturbSpeed transform should be added only if we
#             # remove it from data prep stage.
#             # Add on-the-fly speed perturbation; since originally it would
#             # have increased epoch size by 3, we will apply prob 2/3 and use
#             # 3x more epochs.
#             # Speed perturbation probably should come first before
#             # concatenation, but in principle the transforms order doesn't have
#             # to be strict (e.g. could be randomized)
#             if self.cfg.data.enable_speed_perturb:
#                 transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
#             # Drop feats to be on the safe side.
#             train = AudioTaggingDataset(
#                 cut_transforms=transforms,
#                 input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
#                 input_transforms=input_transforms,
#                 return_cuts=True,
#                 label_field=getattr(self.cfg.data, "label_field", "audio_tag"),
#             )

#         if self.cfg.data.ge2e_sampler:
#             logging.info("Using GE2ESampler.")
#             train_sampler = GE2ECutSampler(
#                 cuts_train,
#                 speakers_per_batch=self.cfg.data.speakers_per_batch,
#                 utts_per_speaker=self.cfg.data.utts_per_speaker,
#                 shuffle=self.cfg.data.shuffle,
#                 drop_last=self.cfg.data.drop_last,
#             )
            
#         else:
#             raise ValueError(
#                 "GE2E Loss requires GE2ESampler (N speakers * M utterances per batch). "
#                 "Please set 'cfg.data.ge2e_sampler' to True."
#             )

#         logging.info("About to create train dataloader")

#         if sampler_state_dict is not None:
#             logging.info("Loading sampler state dict")
#             train_sampler.load_state_dict(sampler_state_dict)

#         # 'seed' is derived from the current random state, which will have
#         # previously been set in the main process.
#         seed = torch.randint(0, 100000, ()).item()
#         worker_init_fn = _SeedWorkers(seed)

#         train_dl = DataLoader(
#             train,
#             sampler=train_sampler,
#             batch_size=None,
#             num_workers=self.cfg.data.num_workers,
#             persistent_workers=True,
#             worker_init_fn=worker_init_fn,
#         )

#         return train_dl


#     def _get_valid_dl(self, valid_dataset):
#         transforms = []

#         logging.info("About to create dev dataset")
#         if self.cfg.data.on_the_fly_feats:
#             validate = AudioTaggingDataset(
#                 cut_transforms=transforms,
#                 input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
#                 return_cuts=True,
#                 label_field=getattr(self.cfg.data, "label_field", "audio_tag"),
#             )
#         else:
#             validate = AudioTaggingDataset(
#                 cut_transforms=transforms,
#                 return_cuts=True,
#                 label_field=getattr(self.cfg.data, "label_field", "audio_tag"),
#             )
            
#         valid_sampler = GE2ECutSampler(
#             valid_dataset,
#             speakers_per_batch=self.cfg.data.speakers_per_batch,
#             utts_per_speaker=self.cfg.data.utts_per_speaker,
#             shuffle=False,
#             )

#         logging.info("About to create dev dataloader")
#         valid_dl = DataLoader(
#             validate,
#             sampler=valid_sampler,
#             batch_size=None,
#             num_workers=8,
#             persistent_workers=False,
#         )

#         return valid_dl
    
    
#     # def audioset_sampling_weights(self):
#     #     logging.info(
#     #         f"About to get the sampling weight for {audioset_subset} in AudioSet"
#     #     )
#     #     weights = []
#     #     with open(
#     #         self.cfg.data.class_weights_file,
#     #         "r",
#     #     ) as f:
#     #         while True:
#     #             line = f.readline()
#     #             if not line:
#     #                 break
#     #             weight = float(line.split()[1])
#     #             weights.append(weight)
#     #     logging.info(f"Get the sampling weight for {len(weights)} cuts")
#     #     return weights