import os
import torch
from torch import nn
from transformers.trainer_pt_utils import LabelSmoother
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from .utils import replace_whisper_encoder_forward

# Label id used to mark target positions that carry no supervision (padding and
# prompt tokens are overwritten with it in `train_preprocess`, and
# `compute_accuracy` skips targets equal to it).
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
# Placeholder token registered with the tokenizer; positions holding it are
# replaced by projected speech features when text and speech are merged.
DEFAULT_SPEECH_TOKEN = "<speech>"

class EncoderProjector(nn.Module):
    """
    Map speech-encoder outputs into the embedding dimension of the language model.

    Groups of ``downsample_rate`` consecutive frames are concatenated along the feature
    axis (shortening the sequence by that factor) and passed through a two-layer MLP
    (Linear -> ReLU -> Linear).
    Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
    Args:
        encoder_dim (:obj:`int`): The dimension of the encoder outputs.
        llm_dim (:obj:`int`): The dimension of the language model.
        downsample_rate (:obj:`int`, `optional`, defaults to 5): Number of consecutive frames
            stacked into one output frame.
    """

    def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
        super().__init__()
        self.downsample_rate = downsample_rate
        self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(llm_dim, llm_dim)

    def forward(self, x):
        """
        Project a 3-D tensor ``(batch, frames, feat_dim)`` to ``(batch, frames // downsample_rate, llm_dim)``.

        Trailing frames that do not fill a complete group are dropped.
        """
        batch_size, num_frames, feat_dim = x.size()
        stride = self.downsample_rate

        # Keep only the largest prefix whose length is a multiple of the stride.
        usable_frames = num_frames - num_frames % stride
        stacked = (
            x[:, :usable_frames, :]
            .contiguous()
            .view(batch_size, usable_frames // stride, feat_dim * stride)
        )

        return self.linear2(self.relu(self.linear1(stacked)))


class AsrLLMModel(nn.Module):
    """
    The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
    The encoder is used to extract speech features from the input speech signal.
    The encoder projector is used to project the encoder outputs to the same dimension as the language model.
    The language model is used to generate the text from the speech features.
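    Args:
        config: Configuration object from which the encoder, the language model and the encoder
            projector are built. It is read for ``speech_encoder_type``, ``speech_encoder_path``,
            ``llm_path``, ``use_flash_attn``, ``stage``, ``pretrained_stage1_model_path`` and
            ``encoder_projector_ds_rate``.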
    """

    def make_pad_mask(self, lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
        """
        Args:
        lengths:
            A 1-D tensor containing sentence lengths.
        max_len:
            The length of masks.
        Returns:
        Return a 2-D bool tensor, where masked positions
        are filled with `True` and non-masked positions are
        filled with `False`.

        >>> lengths = torch.tensor([1, 3, 2, 5])
        >>> make_pad_mask(lengths)
        tensor([[False,  True,  True,  True,  True],
                [False, False, False,  True,  True],
                [False, False,  True,  True,  True],
                [False, False, False, False, False]])
        """
        assert lengths.ndim == 1, lengths.ndim
        max_len = max(max_len, lengths.max())
        n = lengths.size(0)
        seq_range = torch.arange(0, max_len, device=lengths.device)
        expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)

        return expaned_lengths >= lengths.unsqueeze(-1)
    
    def __init__(
        self,
        config,
    ):
        """
        Build the frozen speech encoder, the tokenizer/LLM pair and the encoder projector.

        Args:
            config: Object exposing ``speech_encoder_type`` (``"whisper"`` or ``"zipformer"``),
                ``speech_encoder_path``, ``llm_path``, ``use_flash_attn``, ``stage`` (1 or 2),
                ``pretrained_stage1_model_path``, ``encoder_projector_ds_rate`` and, when
                ``stage == 2``, ``lora_config`` (with ``lora_rank``, ``lora_alpha``,
                ``target_modules``, ``lora_dropout`` and ``task_type``).
        Raises:
            ValueError: If ``config.speech_encoder_type`` or ``config.stage`` is not supported.
        """
        super().__init__()
        self.config = config
        self.speech_encoder_type = config.speech_encoder_type
        self.speech_encoder_path = config.speech_encoder_path
        self.llm_path = config.llm_path
        self.use_flash_attn = config.use_flash_attn
        self.stage = config.stage
        self.pretrained_stage1_model_path = config.pretrained_stage1_model_path
        self.encoder_projector_ds_rate = config.encoder_projector_ds_rate

        # Validate up front so a bad config fails with a clear message before any
        # model weights are loaded (previously this surfaced later as an
        # AttributeError on `self.encoder` / `self.unfreeze_llm`).
        if self.speech_encoder_type not in ("whisper", "zipformer"):
            raise ValueError(
                f"Unsupported speech_encoder_type {self.speech_encoder_type!r}; "
                "expected 'whisper' or 'zipformer'."
            )
        if self.stage not in (1, 2):
            raise ValueError(f"Unsupported stage {self.stage!r}; expected 1 or 2.")

        if self.speech_encoder_type == "whisper":
            import whisper

            replace_whisper_encoder_forward()
            whisper_model = whisper.load_model(self.speech_encoder_path, "cpu")
            self.encoder = whisper_model.encoder
            self.encoder_dim = whisper_model.dims.n_audio_state
        else:  # "zipformer"
            from auden.auto.auto_model import AutoModel

            model_dir, model_filename = os.path.split(self.speech_encoder_path)
            zipformer_model = AutoModel.from_pretrained(
                exp_dir=model_dir,
                checkpoint_filename=model_filename,
            )
            # NOTE(review): only `self.encoder` is frozen below; `encoder_embed`
            # keeps its original requires_grad flags - confirm this is intended.
            self.encoder_embed = zipformer_model.encoder_embed
            self.encoder = zipformer_model.encoder
            self.encoder_dim = zipformer_model.encoder_out_dim

        # The speech encoder is never trained.
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.encoder.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(self.llm_path)
        if self.use_flash_attn:
            self.attn_implementation = "flash_attention_2"
            self.tokenizer.padding_side = "left"
        else:
            self.attn_implementation = "sdpa"
            self.tokenizer.padding_side = "right"

        self.llm = AutoModelForCausalLM.from_pretrained(
            self.llm_path,
            attn_implementation=self.attn_implementation,
            torch_dtype=torch.float16,
        )

        # Stage 1 keeps the LLM frozen; stage 2 wraps it with LoRA adapters.
        self.use_lora = self.stage == 2
        self.unfreeze_llm = self.stage == 2

        if not self.unfreeze_llm:
            for param in self.llm.parameters():
                param.requires_grad = False
            self.llm.eval()
        elif self.use_lora:
            lora_config = LoraConfig(
                r=config.lora_config.lora_rank,
                lora_alpha=config.lora_config.lora_alpha,
                target_modules=config.lora_config.target_modules,
                lora_dropout=config.lora_config.lora_dropout,
                task_type=config.lora_config.task_type,
            )
            self.llm = get_peft_model(self.llm, lora_config)
            self.llm.print_trainable_parameters()

        # Register the speech placeholder and record the ids the merge step relies on.
        special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
        self.tokenizer.add_special_tokens(special_tokens_dict)
        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
        self.llm.config.default_speech_token_id = self.tokenizer.convert_tokens_to_ids(
            DEFAULT_SPEECH_TOKEN
        )

        self.encoder_projector = EncoderProjector(
            self.encoder_dim,
            self.llm.config.hidden_size,
            self.encoder_projector_ds_rate,
        )

        if self.pretrained_stage1_model_path and self.stage == 2:
            # NOTE(security): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            checkpoint = torch.load(self.pretrained_stage1_model_path, map_location="cpu")
            load_result = self.encoder_projector.load_state_dict(checkpoint, strict=False)
            if load_result.missing_keys:
                # strict=False would otherwise hide a checkpoint whose keys do not
                # match the projector, leaving it randomly initialised.
                import warnings

                warnings.warn(
                    f"Checkpoint {self.pretrained_stage1_model_path!r} did not provide "
                    f"encoder projector weights for: {load_result.missing_keys}"
                )

    def _merge_input_ids_with_speech_features(
        self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
    ):
        """
        Splice projected speech features into the text embedding sequence.

        Every position of ``input_ids`` that holds the speech placeholder token
        (``self.llm.config.default_speech_token_id``) is expanded into ``speech_len`` positions
        filled with rows of ``speech_features``. Text embeddings, attention mask values and labels
        are copied to their shifted positions, and all rows are padded to a common length.
        Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
        Args:
            speech_features (:obj:`torch.Tensor`): 3-D tensor ``(num_speechs, speech_len, embed_dim)``
                of speech features to insert.
            inputs_embeds (:obj:`torch.Tensor`): The embeddings of ``input_ids``.
            input_ids (:obj:`torch.Tensor`): 2-D tensor ``(batch_size, sequence_length)`` of token ids.
            attention_mask (:obj:`torch.Tensor`): The attention mask belonging to ``input_ids``.
            labels (:obj:`torch.Tensor`, `optional`): Labels aligned with ``input_ids``.
        Returns:
            :obj:`Tuple(torch.Tensor)`: The merged embeddings, the merged attention mask, the merged
            labels (``None`` when ``labels`` is not given) and the position ids.
        Raises:
            ValueError: If the number of positions reserved for speech does not equal
                ``num_speechs * speech_len``.
        """
        num_speechs, speech_len, embed_dim = speech_features.shape
        batch_size, sequence_length = input_ids.shape
        # The batch is treated as left-padded when no row ends with the pad token.
        left_padding = not torch.sum(
            input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
        )
        # 1. Create a mask to know where special speech tokens are
        special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
        num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
        # Length of the merged sequence: each speech token grows from 1 to `speech_len`
        # positions, sized for the row with the most speech tokens.
        max_embed_dim = (
            num_special_speech_tokens.max() * (speech_len - 1)
        ) + sequence_length
        batch_indices, non_speech_indices = torch.where(
            input_ids != self.llm.config.default_speech_token_id
        )

        # 2. Compute the positions where text should be written
        # Each speech token occupies `speech_len` positions in the merged sequence and each
        # text token occupies one, so the cumulative sum gives, per original token, the index
        # one past the last position it occupies; subtracting 1 turns that into the index of
        # its last position (for text tokens: its only position).
        new_token_positions = (
            torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
        )
        # Number of merged positions left unused at the end of each row.
        nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_speech_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(
            batch_size,
            max_embed_dim,
            dtype=attention_mask.dtype,
            device=inputs_embeds.device,
        )
        if labels is not None:
            # Positions that receive no text label (speech and padding) stay ignored.
            final_labels = torch.full(
                (batch_size, max_embed_dim),
                IGNORE_TOKEN_ID,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
        # The index tensors were computed on the device of `input_ids`; move them (and the
        # attention mask) to the device of `inputs_embeds`, where the outputs live.
        target_device = inputs_embeds.device
        batch_indices, non_speech_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_speech_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Copy the text entries to their shifted positions. E.g. for
        # ["hey", "<speech>", "how", "are"] the text lands on indices
        # [0, speech_len + 1, speech_len + 2, speech_len + 3] and the speech
        # features on [1, ..., speech_len].
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
            batch_indices, non_speech_indices
        ]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
            batch_indices, non_speech_indices
        ]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[
                batch_indices, non_speech_indices
            ]

        # 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
        speech_to_overwrite = torch.full(
            (batch_size, max_embed_dim),
            True,
            dtype=torch.bool,
            device=inputs_embeds.device,
        )
        speech_to_overwrite[batch_indices, text_to_overwrite] = False
        # Per row, exclude the first `nb_speech_pad` non-text slots (running count via cumsum);
        # only the remaining non-text slots receive speech features.
        speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
            :, None
        ].to(target_device)

        # The number of slots to fill must match the number of speech feature rows.
        if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
            raise ValueError(
                f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
                f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[speech_to_overwrite] = (
            speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
        )
        # Speech positions are always attended.
        final_attention_mask |= speech_to_overwrite
        # Position ids count attended positions from 0; non-attended positions are set to 1.
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
            (final_attention_mask == 0), 1
        )

        # 6. Zero the embeddings at the merged positions of tokens that were pad tokens in `input_ids`.
        batch_indices, pad_indices = torch.where(
            input_ids == self.llm.config.pad_token_id
        )
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

        if labels is None:
            final_labels = None

        return final_embedding, final_attention_mask, final_labels, position_ids

    def train_preprocess(
        self,
        messages,
        tokenizer,
        max_len: int,
    ):
        """
        Tokenize a batch of conversations and build training targets.

        Each element of ``messages`` is rendered with a fixed chat template and tokenized via
        ``tokenizer.apply_chat_template``; the resulting id lists are padded to the longest one
        on the side given by ``tokenizer.padding_side``.

        Returns:
            ``(input_ids, attention_mask, target_ids)`` where ``attention_mask`` is ``False`` at
            pad positions and ``target_ids`` is a copy of ``input_ids`` in which pad positions and
            the prompt are replaced by ``IGNORE_TOKEN_ID``.
        """
        TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
        texts = [
            tokenizer.apply_chat_template(
                msg,
                tokenize=True,
                chat_template=TEMPLATE,
                add_generation_prompt=False,
                padding="longest",  # FIX me change padding to longest
                max_length=max_len,
                truncation=True,
            )
            for msg in messages
        ]

        # Pad every id list with tokenizer.pad_token_id up to the longest one.
        longest = max(len(text) for text in texts)
        pad_on_right = tokenizer.padding_side == "right"
        padded_texts = []
        for text in texts:
            filler = [tokenizer.pad_token_id] * (longest - len(text))
            padded_texts.append(text + filler if pad_on_right else filler + text)
        input_ids = torch.tensor(padded_texts, dtype=torch.int)

        target_ids = input_ids.clone()
        target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID

        # Mask the prompt: everything up to and including each "assistant" token and the
        # token right after it (the '\n' of the template, per the original author's note).
        assistant_rows, assistant_cols = torch.where(
            input_ids == tokenizer.convert_tokens_to_ids("assistant")
        )
        for row, col in zip(assistant_rows.tolist(), assistant_cols.tolist()):
            target_ids[row, : col + 2] = IGNORE_TOKEN_ID
        # At this point only the ground-truth tokens in target_ids are not the ignore id.
        attention_mask = input_ids.ne(tokenizer.pad_token_id)

        return input_ids, attention_mask, target_ids
    
    def forward(
        self,
        fbank: torch.Tensor = None,
        feature_lens: torch.Tensor = None,
        messages= list,
    ):
        if self.speech_encoder_type == "whisper" or self.speech_encoder_type == "finetuned_whisper":
            fbank = fbank.transpose(1, 2) # (N, C, T)
            encoder_outs = self.encoder(fbank)
        elif self.speech_encoder_type == "zipformer":
            encoder_outs, encoder_outs_lens = self.encoder_embed(fbank, feature_lens)
            src_key_padding_mask = self.make_pad_mask(encoder_outs_lens)
            encoder_outs = encoder_outs.permute(1, 0, 2)
            encoder_outs, _, _ = self.encoder(encoder_outs, encoder_outs_lens, src_key_padding_mask)
            encoder_outs = encoder_outs.permute(1, 0, 2)
        speech_features = self.encoder_projector(encoder_outs)
        speech_features = speech_features.to(torch.float16)
        
        input_ids, attention_mask, target_ids = self.train_preprocess(messages, self.tokenizer, max_len=128)
        
        attention_mask = attention_mask.to(speech_features.device)
        target_ids = target_ids.type(torch.LongTensor).to(speech_features.device)
        input_ids = input_ids.type(torch.LongTensor).to(speech_features.device)
        labels = target_ids

        inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        (
            inputs_embeds,
            attention_mask,
            labels,
            _,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask, labels
        )
        
        model_outputs = self.llm(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
        )

        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            acc = compute_accuracy(
                preds.detach()[:, :-1],
                labels.detach()[:, 1:],
                ignore_label=IGNORE_TOKEN_ID,
            )
        return model_outputs, acc

    def decode_preprocess(
        self,
        messages,
        tokenizer,
        max_len: int = 128,
    ):
        """Preprocesses the data for supervised fine-tuning."""
        texts = []
        TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
        for i, msg in enumerate(messages):
            texts.append(
                tokenizer.apply_chat_template(
                    msg,
                    tokenize=True,
                    add_generation_prompt=False,
                    chat_template=TEMPLATE,
                    padding="longest",
                    max_length=max_len,
                    truncation=True,
                )
            )
        max_len_texts = max([len(text) for text in texts])
        if tokenizer.padding_side == "right":
            texts = [
                text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
                for text in texts
            ]
        else:
            texts = [
                [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
                for text in texts
            ]

        input_ids = torch.tensor(texts, dtype=torch.int)

        attention_mask = input_ids.ne(tokenizer.pad_token_id)

        return input_ids, attention_mask
    
    def decode(
        self,
        fbank: torch.Tensor = None,
        feature_lens: torch.Tensor = None,
        speech_encoder_type: str = "whisper",
        messages= list,
        **kwargs,
    ):

        if speech_encoder_type == "whisper" or speech_encoder_type == "finetuned_whisper":
            fbank = fbank.transpose(1, 2)
            encoder_outs = self.encoder(fbank)
        elif speech_encoder_type == "zipformer":
            encoder_outs, encoder_outs_lens = self.encoder_embed(fbank, feature_lens)
            src_key_padding_mask = self.make_pad_mask(encoder_outs_lens.to(torch.int32))
            encoder_outs = encoder_outs.permute(1, 0, 2)
            encoder_outs, _ = self.encoder(encoder_outs, encoder_outs_lens, src_key_padding_mask)
            encoder_outs = encoder_outs.permute(1, 0, 2)
            
        speech_features = self.encoder_projector(encoder_outs)
        speech_features = speech_features.to(torch.float16)
        input_ids, attention_mask = self.decode_preprocess(messages, self.tokenizer, max_len=128)
        input_ids = input_ids.type(torch.LongTensor).to(speech_features.device)
        attention_mask = attention_mask.to(speech_features.device)
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        (
            inputs_embeds,
            attention_mask,
            _,
            position_ids,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask
        )
        generated_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=kwargs.get("max_new_tokens", 200),
            num_beams=kwargs.get("num_beams", 1),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            bos_token_id=self.llm.config.bos_token_id,
            eos_token_id=self.llm.config.eos_token_id,
            pad_token_id=self.llm.config.pad_token_id,
        )

        return generated_ids


def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy over the positions whose target is not ``ignore_label``.

    Adapted from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
    Args:
        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        torch.Tensor: 0-dim float tensor with the accuracy (0.0 - 1.0). When every
        target equals ``ignore_label`` the result is 0.0 instead of NaN.

    """
    valid = pad_targets != ignore_label
    num_correct = torch.sum((pad_outputs == pad_targets) & valid)
    # Clamp so that a batch without any valid target yields 0 / 1 = 0.0, not 0 / 0.
    num_valid = torch.sum(valid).clamp(min=1)
    return num_correct.float() / num_valid.float()
