from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch
from torch.nn.utils.rnn import pad_sequence

from src.utils.utils import AUDIO_MAX_LENGTH, LLM_MAX_LENGTH, IGNORE_INDEX


@dataclass
class DataCollatorForSpeechRecognition:
    """Collate per-example speech-recognition features into one padded batch."""

    # Feature extractor; only ``pad(..., return_tensors="pt")`` is called on it here.
    extractor: Any
    # Tokenizer; only its ``pad_token_id`` attribute is read here.
    tokenizer: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build a batch from a list of examples.

        Args:
            features: Examples, each providing the keys ``input_features``,
                ``attention_mask``, ``decoder_input_ids`` and ``labels``.
                NOTE(review): the last three are handed straight to
                ``torch.stack`` / ``pad_sequence``, so they presumably must
                already be tensors -- confirm against the dataset code.

        Returns:
            A dict with the keys ``input_features``, ``attention_mask``,
            ``decoder_input_ids`` and ``labels``. The last two are padded to
            the longest example in the batch and then cut to at most
            ``AUDIO_MAX_LENGTH`` positions along dim 1.
        """

        def pluck(key):
            # One entry per example, in batch order.
            return [example[key] for example in features]

        # The extractor pads the audio features and returns a mapping; keep
        # only its padded "input_features" entry.
        audio_items = [{"input_features": value} for value in pluck("input_features")]
        padded_audio = self.extractor.pad(audio_items, return_tensors="pt")
        audio_batch = padded_audio['input_features']

        # Masks are stacked as-is, without any padding.
        mask_batch = torch.stack(pluck("attention_mask"))

        # Decoder inputs are filled with the tokenizer's pad id ...
        decoder_batch = pad_sequence(
            pluck("decoder_input_ids"),
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )

        # ... while label padding uses IGNORE_INDEX.
        label_batch = pad_sequence(
            pluck('labels'),
            batch_first=True,
            padding_value=IGNORE_INDEX,
        )

        return {
            "input_features": audio_batch,
            "attention_mask": mask_batch,
            "decoder_input_ids": decoder_batch[:, :AUDIO_MAX_LENGTH],
            "labels": label_batch[:, :AUDIO_MAX_LENGTH],
        }


@dataclass
class DataCollatorForSpeechResponse:
    """Collate per-example speech-response features into one padded batch."""

    # Feature extractor; only ``pad(..., return_tensors="pt")`` is called on it here.
    extractor: Any
    # Tokenizer; only its ``pad_token_id`` attribute is read here.
    tokenizer: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build a batch from a list of examples.

        Args:
            features: Examples, each providing the keys ``input_ids``,
                ``labels``, ``input_features``, ``encoder_attention_mask``,
                ``decoder_input_ids``, ``decoder_labels``,
                ``audio_start_index`` and ``audio_end_index``.
                NOTE(review): the sequence entries are handed straight to
                ``torch.stack`` / ``pad_sequence``, so they presumably must
                already be tensors -- confirm against the dataset code.

        Returns:
            A dict with the keys ``input_ids``, ``labels``,
            ``input_features``, ``encoder_attention_mask``,
            ``decoder_input_ids``, ``decoder_labels``,
            ``audio_start_indices`` and ``audio_end_indices``. The two
            ``audio_*_indices`` values are plain Python lists with one entry
            per example, not tensors. No sequence is truncated here.
        """

        def pluck(key):
            # One entry per example, in batch order.
            return [example[key] for example in features]

        pad_id_fill = dict(batch_first=True, padding_value=self.tokenizer.pad_token_id)
        ignore_fill = dict(batch_first=True, padding_value=IGNORE_INDEX)

        # Token ids are filled with the tokenizer's pad id; labels with IGNORE_INDEX.
        ids_batch = pad_sequence(pluck('input_ids'), **pad_id_fill)
        label_batch = pad_sequence(pluck('labels'), **ignore_fill)

        # The extractor pads the audio features and returns a mapping; keep
        # only its padded "input_features" entry.
        audio_items = [{"input_features": value} for value in pluck("input_features")]
        audio_batch = self.extractor.pad(audio_items, return_tensors="pt")['input_features']

        # Encoder masks are stacked as-is, without any padding.
        encoder_mask_batch = torch.stack(pluck("encoder_attention_mask"))

        # Decoder-side sequences follow the same padding scheme as above.
        decoder_ids_batch = pad_sequence(pluck("decoder_input_ids"), **pad_id_fill)
        decoder_label_batch = pad_sequence(pluck("decoder_labels"), **ignore_fill)

        # Start positions are forwarded unchanged.
        start_indices = pluck("audio_start_index")

        # Each end position is clamped to the smallest of: the recorded end,
        # LLM_MAX_LENGTH, and start + AUDIO_MAX_LENGTH.
        end_indices = []
        for example in features:
            recorded_end = example['audio_end_index']
            window_limit = example['audio_start_index'] + AUDIO_MAX_LENGTH
            end_indices.append(min(recorded_end, LLM_MAX_LENGTH, window_limit))

        return {
            "input_ids": ids_batch,
            "labels": label_batch,
            "input_features": audio_batch,
            "encoder_attention_mask": encoder_mask_batch,
            "decoder_input_ids": decoder_ids_batch,
            "decoder_labels": decoder_label_batch,
            "audio_start_indices": start_indices,
            "audio_end_indices": end_indices,
        }
