import os
import random
from typing import List, Any

import torch.distributed as dist
from datasets import load_from_disk, concatenate_datasets
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers.utils import logging

from src.utils.data_utils import get_audio_dataset
from src.utils.utils import QWEN2_START_TOKEN, QWEN2_END_TOKEN, IGNORE_INDEX, preprocess_text_for_training, remove_emoji, print_rank_0

logger = logging.get_logger(__name__)


class SpeechResponseDataset(Dataset):
    """Map-style dataset pairing audio with a prompt, a transcript and a response.

    For each entry of ``dataset_paths`` the ``split`` sub-directory is loaded
    from disk, rows whose ``finish_reason`` equals ``'length'`` are dropped,
    and the matching audio rows (looked up via the ``id`` column) are attached
    column-wise. Each source is then repeated and/or randomly sub-sampled
    according to its ratio, and all sources are concatenated row-wise.

    Only the ``audio``, ``text``, ``response`` and ``prompt`` columns are kept.
    """

    def __init__(
        self,
        dataset_paths: List[str] = None,
        ratios: List[float] = None,
        start_token: str = QWEN2_START_TOKEN,
        end_token: str = QWEN2_END_TOKEN,
        extractor: Any = None,
        tokenizer: Any = None,
        split: str = 'train'
    ):
        """Load, mix and concatenate the configured datasets.

        Args:
            dataset_paths: Directories that each contain a ``split``
                sub-directory readable by ``load_from_disk``. The last path
                component is used as the name passed to ``get_audio_dataset``.
            ratios: One mixing ratio per path. ``1.0`` keeps every sample, a
                value above ``1.0`` repeats the full dataset once per whole
                unit and samples the fractional remainder, and a value below
                ``1.0`` draws a random subset of that relative size.
            start_token: String prepended to the transcript for the decoder.
            end_token: String appended to the transcript for the decoder.
            extractor: Callable turning a waveform into ``input_features`` and
                an ``attention_mask`` (used in ``__getitem__``).
            tokenizer: Tokenizer providing ``__call__`` and
                ``apply_chat_template`` (used in ``__getitem__``).
            split: Name of the split sub-directory to load.

        Raises:
            ValueError: If ``dataset_paths`` or ``ratios`` is missing, if their
                lengths differ, or if any ratio is negative.
        """
        super().__init__()

        # Validate up-front with real exceptions: ``assert`` is stripped under
        # ``python -O`` and ``len(None)`` would otherwise raise a confusing
        # TypeError when the defaults are left in place.
        if dataset_paths is None or ratios is None:
            raise ValueError("dataset_paths and ratios must both be provided")
        if len(dataset_paths) != len(ratios):
            raise ValueError(
                f"dataset_paths and ratios must have the same length, got {len(dataset_paths)} and {len(ratios)}"
            )
        for dataset_path, ratio in zip(dataset_paths, ratios):
            if ratio < 0:
                raise ValueError(f"ratio for dataset '{dataset_path}' must be non-negative, got {ratio}")

        self.dataset_paths = dataset_paths
        self.start_token = start_token
        self.end_token = end_token

        columns_to_keep = ['audio', 'text', 'response', 'prompt']

        all_datasets = []
        for dataset_path, ratio in zip(dataset_paths, ratios):
            dataset = self._load_single_dataset(dataset_path, split, columns_to_keep)
            print_rank_0(f"@@@ Loaded dataset from '{dataset_path}' with {len(dataset)} samples")
            all_datasets.extend(self._expand_by_ratio(dataset, ratio, dataset_path))

        self.dataset = concatenate_datasets(all_datasets, axis=0)

        self.extractor = extractor
        self.tokenizer = tokenizer

        # Lazily filled cache for the ``lengths`` property.
        self._lengths = None

    @staticmethod
    def _load_single_dataset(dataset_path, split, columns_to_keep):
        """Load one text dataset, attach its audio columns and prune columns."""
        dataset = load_from_disk(os.path.join(dataset_path, split))
        dataset = dataset.filter(lambda x: x['finish_reason'] != 'length')

        # ``normpath`` strips a trailing separator so that "path/to/name/"
        # still resolves to "name" instead of an empty string.
        dataset_name = os.path.basename(os.path.normpath(dataset_path))
        audio_dataset = get_audio_dataset(dataset_name=dataset_name, split=split)

        # The ``id`` column is used as row indices into the audio dataset.
        audio_dataset = audio_dataset.select(dataset['id'])

        # Columns present in both tables are taken from ``dataset``; drop the
        # duplicates from the audio side before the column-wise concatenation.
        intersect_columns = set(dataset.column_names) & set(audio_dataset.column_names)
        audio_dataset = audio_dataset.remove_columns(list(intersect_columns))

        dataset = concatenate_datasets([dataset, audio_dataset], axis=1)

        # Remove everything that is not consumed downstream.
        return dataset.remove_columns([col for col in dataset.column_names if col not in columns_to_keep])

    @staticmethod
    def _expand_by_ratio(dataset, ratio, dataset_path):
        """Return the list of dataset pieces that realise ``ratio`` for ``dataset``."""
        pieces = []

        if ratio > 1.0:
            print_rank_0(f"*** Repeat all {len(dataset)} samples from dataset '{dataset_path}' for {ratio} times")

            # One full copy per whole unit above 1.0; what is left afterwards
            # is in (0, 1] and is handled by the branches below.
            while ratio > 1.0:
                pieces.append(dataset)
                ratio -= 1.0

        if ratio == 1.0:
            pieces.append(dataset)
            print_rank_0(f"+++ Using all {len(dataset)} samples from dataset '{dataset_path}'")
            return pieces

        # Fractional part: draw a random subset, kept in ascending index order.
        indices = sorted(random.sample(range(len(dataset)), int(len(dataset) * ratio)))
        subset = dataset.select(indices)
        print_rank_0(f"^^^ Sampled {len(subset)} samples from {dataset_path} with ratio {ratio}")

        pieces.append(subset)
        return pieces

    def __len__(self):
        """Return the number of samples after mixing."""
        return len(self.dataset)

    @property
    def lengths(self) -> List[int]:
        """Whitespace-token counts of ``prompt + text + response`` per sample.

        Computed on first access and cached. A progress bar is shown only on
        rank 0 (or when ``torch.distributed`` is not initialised).
        """
        if self._lengths is None:
            # Iterate without the audio column so only text fields are read.
            tmp_dataset = self.dataset.remove_columns(["audio", ])
            is_rank_0 = not dist.is_initialized() or dist.get_rank() == 0
            # Build the list fully before caching it, so an exception raised
            # mid-iteration cannot leave a truncated cache behind.
            self._lengths = [
                len((item["prompt"] + " " + item["text"] + " " + item["response"]).split())
                for item in tqdm(tmp_dataset, disable=not is_rank_0)
            ]
        return self._lengths[:len(self)]

    def __getitem__(self, index):
        """Build the model inputs for the sample at ``index``.

        Returns:
            A dict with:
              * ``input_ids`` / ``labels``: tokenized chat-formatted dialogue;
                label positions before the response are set to ``IGNORE_INDEX``.
              * ``input_features`` / ``encoder_attention_mask``: first element
                of the extractor outputs for the audio waveform.
              * ``decoder_input_ids`` / ``decoder_labels``: tokenized
                ``start_token + text + end_token`` and a clone of it.
              * ``audio_start_index`` / ``audio_end_index``: token offsets
                derived from the chat template with and without the transcript.
        """
        try:
            item = self.dataset[index]
        except Exception:
            # In rare cases reading a sample fails; fall back to the previous
            # one (index 0 wraps around to the last sample) instead of
            # aborting, but leave a trace so the bad sample can be found.
            logger.warning("Failed to read sample %s, falling back to sample %s", index, index - 1)
            item = self.dataset[index - 1]

        sample = item['audio']

        inputs = self.extractor(
            sample['array'],
            sampling_rate=sample['sampling_rate'],
            return_tensors="pt",
            return_attention_mask=True,
        )
        input_features = inputs.input_features[0]
        encoder_attention_mask = inputs.attention_mask[0]

        text = preprocess_text_for_training(item['text'])
        response = remove_emoji(item['response'])

        decoder_text = self.start_token + text + self.end_token

        decoder_input_ids = self.tokenizer(decoder_text, return_tensors="pt")['input_ids'][0]
        decoder_labels = decoder_input_ids.clone()

        prompt = item['prompt']
        dialogue_without_audio = [{'role': 'user', 'content': f'{prompt}'}]
        dialogue_with_audio = [{'role': 'user', 'content': f"{prompt}{text}"}]

        # NOTE(review): presumably the number of tokens the chat template
        # appends after the user content (e.g. end-of-turn marker + newline);
        # confirm against the tokenizer/template actually used.
        postfix_length = 2
        audio_start_index = len(
            self.tokenizer.apply_chat_template(dialogue_without_audio, tokenize=True, add_generation_prompt=False)
        ) - postfix_length
        audio_end_index = len(
            self.tokenizer.apply_chat_template(dialogue_with_audio, tokenize=True, add_generation_prompt=False)
        ) - postfix_length
        # Length of the user turn plus the generation prompt, i.e. the first
        # token position belonging to the response.
        response_start_index = len(
            self.tokenizer.apply_chat_template(dialogue_with_audio, tokenize=True, add_generation_prompt=True)
        )

        dialogue = [
            {'role': 'user', 'content': f'{prompt}{text}'},
            {'role': 'assistant', 'content': response},
        ]
        llm_text = self.tokenizer.apply_chat_template(dialogue, tokenize=False, add_generation_prompt=False).strip()

        input_ids = self.tokenizer(llm_text, return_tensors='pt')['input_ids'][0]
        labels = input_ids.clone()
        # Everything before the response is overwritten with IGNORE_INDEX.
        labels[:response_start_index] = IGNORE_INDEX

        return {
            'input_ids': input_ids,
            'labels': labels,
            'input_features': input_features,
            'encoder_attention_mask': encoder_attention_mask,
            'decoder_input_ids': decoder_input_ids,
            'decoder_labels': decoder_labels,
            'audio_start_index': audio_start_index,
            'audio_end_index': audio_end_index,
        }
