import random
from typing import List, Any

from datasets import concatenate_datasets
from torch.utils.data import Dataset
from transformers.utils import logging

from src.utils.data_utils import get_audio_dataset
from src.utils.utils import QWEN2_START_TOKEN, QWEN2_END_TOKEN, preprocess_text_for_training, print_rank_0

# Module-level logger, created through the transformers logging utilities
# and named after this module.
logger = logging.get_logger(__name__)


class SpeechRecognitionDataset(Dataset):
    """Map-style dataset that mixes several audio corpora for speech recognition.

    Every path in ``dataset_paths`` is loaded with ``get_audio_dataset`` and
    weighted by the matching entry of ``ratios``:

    * the integer part of a ratio is the number of full copies of the
      dataset that are added to the mix;
    * the fractional part adds a random subset containing
      ``int(len(dataset) * fraction)`` samples (kept in their original order).

    All pieces are concatenated row-wise into ``self.dataset``.

    Each item is expected to expose ``item['audio']['array']``,
    ``item['audio']['sampling_rate']`` and ``item['text']``; these are the
    only fields ``__getitem__`` reads.
    """

    def __init__(
        self,
        dataset_paths: List[str] = None,
        ratios: List[float] = None,
        start_token: str = QWEN2_START_TOKEN,
        end_token: str = QWEN2_END_TOKEN,
        extractor: Any = None,
        tokenizer: Any = None,
        split: str = 'train',
        seed: int = None,
    ):
        """Load, re-weight and concatenate the requested datasets.

        Args:
            dataset_paths: Locations passed one by one to ``get_audio_dataset``.
            ratios: Sampling ratio for each dataset path (same length as
                ``dataset_paths``). Values above 1 repeat the dataset, values
                below 1 subsample it. Must be non-negative.
            start_token: String prepended to every transcript.
            end_token: String appended to every transcript.
            extractor: Callable that turns a raw waveform into model inputs;
                called in ``__getitem__``.
            tokenizer: Callable that turns the transcript into token ids;
                called in ``__getitem__``.
            split: Dataset split forwarded to ``get_audio_dataset``.
            seed: Optional seed for the fractional subsampling. When ``None``
                (the default) the global ``random`` state is used, so results
                depend on whatever seeding the caller did beforehand.

        Raises:
            ValueError: If ``dataset_paths`` or ``ratios`` is missing, if their
                lengths differ, or if any ratio is negative.
        """
        super().__init__()

        # Explicit exceptions instead of ``assert``: assertions are stripped
        # under ``python -O`` and the ``None`` defaults would otherwise fail
        # with an unhelpful ``TypeError`` on ``len(None)``.
        if dataset_paths is None or ratios is None:
            raise ValueError("dataset_paths and ratios must both be provided")
        if len(dataset_paths) != len(ratios):
            raise ValueError(
                f"dataset_paths and ratios must have the same length, got {len(dataset_paths)} and {len(ratios)}"
            )
        # Reject bad ratios before any (potentially slow) dataset loading.
        for dataset_path, ratio in zip(dataset_paths, ratios):
            if ratio < 0:
                raise ValueError(f"ratio for dataset '{dataset_path}' must be non-negative, got {ratio}")

        self.dataset_paths = dataset_paths
        self.start_token = start_token
        self.end_token = end_token
        self.extractor = extractor
        self.tokenizer = tokenizer

        # A dedicated generator keeps seeded subsampling independent of the
        # global random state; without a seed we fall back to the module-level
        # functions so existing callers see unchanged behaviour.
        rng = random if seed is None else random.Random(seed)

        all_datasets = []
        for dataset_path, ratio in zip(dataset_paths, ratios):
            dataset = get_audio_dataset(dataset_path=dataset_path, split=split)
            num_samples = len(dataset)

            print_rank_0(f"@@@ Loaded dataset from '{dataset_path}' with {num_samples} samples")

            if ratio > 1.0:
                print_rank_0(f"*** Repeat all {num_samples} samples from dataset '{dataset_path}' for {ratio} times")

            # Integer part -> number of full copies, fractional part -> size
            # of the extra random subset.
            full_copies = int(ratio)
            remainder = ratio - full_copies
            all_datasets.extend([dataset] * full_copies)

            if full_copies and remainder == 0:
                print_rank_0(f"+++ Using all {num_samples} samples from dataset '{dataset_path}'")
                continue

            # Sorting keeps the selected rows in their original dataset order.
            indices = sorted(rng.sample(range(num_samples), int(num_samples * remainder)))
            subset = dataset.select(indices)
            print_rank_0(f"^^^ Sampled {len(subset)} samples from {dataset_path} with ratio {remainder}")

            all_datasets.append(subset)

        self.dataset = concatenate_datasets(all_datasets, axis=0)

    def __len__(self):
        """Return the number of samples in the concatenated dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Build the model inputs for the sample at ``index``.

        Returns:
            A dict with:
              * ``input_features``: first entry of the extractor's
                ``input_features`` output;
              * ``attention_mask``: first entry of the extractor's
                ``attention_mask`` output;
              * ``decoder_input_ids``: token ids of
                ``start_token + text + end_token``;
              * ``labels``: an independent copy of ``decoder_input_ids``.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        try:
            item = self.dataset[index]
        except IndexError:
            # Out-of-range access must propagate: swallowing it would break
            # the sequence iteration protocol and hide caller bugs.
            raise
        except Exception as exc:
            # In rare cases reading/decoding a sample fails; fall back to the
            # previous sample instead of aborting the run. For index 0 this
            # wraps around to the last sample via negative indexing.
            logger.warning(
                "Failed to read sample %s (%s: %s); falling back to sample %s",
                index, type(exc).__name__, exc, index - 1,
            )
            item = self.dataset[index - 1]

        sample = item['audio']

        # Wrap the normalised transcript in the start and end tokens.
        text = preprocess_text_for_training(item['text'])
        text = self.start_token + text + self.end_token

        # The tokenizer returns a batch of one; take the single sequence.
        decoder_input_ids = self.tokenizer(text, return_tensors="pt")['input_ids'][0]
        # Labels are an unshifted copy here; any shifting/masking is
        # presumably handled downstream (collator or model) - TODO confirm.
        labels = decoder_input_ids.clone()

        inputs = self.extractor(
            sample['array'],
            sampling_rate=sample['sampling_rate'],
            return_tensors="pt",
            return_attention_mask=True,
        )

        return {
            'input_features': inputs.input_features[0],
            'attention_mask': inputs.attention_mask[0],
            'decoder_input_ids': decoder_input_ids,
            'labels': labels,
        }
