import os
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.distributed as dist
from transformers import AutoModelForCausalLM
from transformers.cache_utils import EncoderDecoderCache
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.whisper.modeling_whisper import shift_tokens_right

from src.model.audio_model.modeling_alignchat import AlignChatForConditionalGeneration, Seq2SeqLMOutputWithLastHiddenState
from src.utils.loss_utils import LM_LOSSES, NORM_LOSSES, SIM_LOSSES
from src.utils.utils import MODEL_DIR, IGNORE_INDEX


class AlignChatForEndToEndResponse(nn.Module):
    """Chain an AlignChat audio model into a causal text LM for end-to-end training.

    The audio model produces logits and decoder hidden states. When both
    ``decoder_labels`` and ``labels`` are given, the hidden states are written
    into each sample's audio placeholder span of the text model's input
    embeddings. The total loss is a weighted sum of the audio-side losses
    (ASR / norm / similarity), plus the text model's LM loss.
    """

    # A loss line is written to the logging file every ``log_interval`` training steps.
    log_interval: int = 50

    # ``from_pretrained`` placement strategies that are not valid ``torch.load``
    # ``map_location`` values.
    _DEVICE_MAP_STRATEGIES = ('auto', 'balanced', 'balanced_low_0', 'sequential')

    def __init__(
        self,
        audio_model_path: str,
        audio_model_type: str,
        text_model_path: str,
        device_map: str = 'auto',
        torch_dtype = None,
        attn_implementation: Optional[str] = None
    ):
        """Load the audio model, the text model and the frozen embed/projection artifacts.

        Args:
            audio_model_path: passed to ``AlignChatForConditionalGeneration.from_pretrained``.
            audio_model_type: sub-directory of ``MODEL_DIR/alignchat`` containing
                ``embed_tokens.pt`` and ``proj_out.pt``.
            text_model_path: passed to ``AutoModelForCausalLM.from_pretrained``.
            device_map: forwarded to both ``from_pretrained`` calls; also decides
                where the ``.pt`` artifacts are loaded (see ``_resolve_map_location``).
            torch_dtype: dtype for the text model only (the audio model is loaded
                without an explicit dtype).
            attn_implementation: attention backend for the text model.
        """
        super().__init__()

        self.audio_model = AlignChatForConditionalGeneration.from_pretrained(audio_model_path, device_map=device_map)
        self.text_model = AutoModelForCausalLM.from_pretrained(text_model_path, device_map=device_map, torch_dtype=torch_dtype, attn_implementation=attn_implementation)

        map_location = self._resolve_map_location(device_map)
        artifact_dir = os.path.join(MODEL_DIR, 'alignchat', audio_model_type)

        # NOTE(review): weights_only=False unpickles arbitrary objects; these files
        # must come from a trusted MODEL_DIR.
        self.embed_tokens = torch.load(os.path.join(artifact_dir, 'embed_tokens.pt'), map_location=map_location, weights_only=False)
        self.embed_tokens.requires_grad_(False)

        self.proj_out = torch.load(os.path.join(artifact_dir, 'proj_out.pt'), map_location=map_location, weights_only=False)
        self.proj_out.requires_grad_(False)

        self._current_step = 0
        self._logging_file = None

    def _resolve_map_location(self, device_map):
        """Translate a ``from_pretrained`` ``device_map`` into a ``torch.load`` ``map_location``.

        ``None``, concrete device strings (e.g. ``'cpu'``, ``'cuda:0'``) and
        ``torch.device`` objects are returned unchanged. Placement strategies such
        as ``'auto'``, per-module dicts and bare GPU indices are not valid
        ``map_location`` values, so the audio model's device is used for those.
        """
        if device_map is None or isinstance(device_map, torch.device):
            return device_map
        if isinstance(device_map, str) and device_map not in self._DEVICE_MAP_STRATEGIES:
            return device_map
        return self.audio_model.device

    @staticmethod
    def _is_main_process() -> bool:
        """Return True when not running distributed, or on distributed rank 0."""
        return not dist.is_initialized() or dist.get_rank() == 0

    def _set_logging_file(self, filename):
        """Open ``filename`` in append mode as the loss log (main process only).

        A previously opened log file is closed first, so repeated calls do not
        leak file handles. On non-main ranks this is a no-op.
        """
        if not self._is_main_process():
            return
        if self._logging_file is not None:
            self._logging_file.close()
        self._logging_file = open(filename, mode='a', encoding='utf-8')

    def _compute_audio_losses(self, audio_lm_logits, audio_hidden_states, decoder_labels, loss_types, num_items_in_batch, loss_kwargs):
        """Return ``(asr_loss, norm_loss, sim_loss)`` for the audio model outputs.

        ``loss_types[0]``, ``[1]`` and ``[2]`` select entries of ``LM_LOSSES``,
        ``NORM_LOSSES`` and ``SIM_LOSSES`` respectively.
        """
        asr_loss = LM_LOSSES[loss_types[0]](audio_lm_logits, decoder_labels, self.audio_model.config.vocab_size, num_items_in_batch=num_items_in_batch, **loss_kwargs)

        # The end token is masked out of the representation-alignment targets.
        embedding_labels = decoder_labels.clone()
        embedding_labels[embedding_labels == self.audio_model.config.vocab_configs['end_token_id']] = IGNORE_INDEX

        norm_loss = NORM_LOSSES[loss_types[1]](audio_hidden_states, embedding_labels, self.embed_tokens, num_items_in_batch=num_items_in_batch, **loss_kwargs)
        sim_loss = SIM_LOSSES[loss_types[2]](audio_hidden_states, embedding_labels, self.embed_tokens, num_items_in_batch=num_items_in_batch, **loss_kwargs)
        return asr_loss, norm_loss, sim_loss

    @staticmethod
    def _fill_audio_placeholders(inputs_embeds, audio_hidden_states, audio_start_indices, audio_end_indices):
        """Write audio hidden states into each sample's placeholder span of ``inputs_embeds`` (in place).

        Sample ``i`` receives ``audio_hidden_states[i, :end - start]`` at positions
        ``[start, end)``. Samples with an empty span (``start >= end``, e.g. the
        audio was truncated away) are left untouched. ``inputs_embeds`` is mutated
        in place, including when it was supplied by the caller.

        Raises:
            ValueError: if a span cannot be filled, e.g. it is longer than the
                available audio hidden states or runs past the sequence end.
        """
        # ``Tensor.to`` returns a new tensor; the result must be used.
        audio_states = audio_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)

        for i in range(inputs_embeds.size(0)):
            start, end = audio_start_indices[i], audio_end_indices[i]
            if start >= end:
                continue
            try:
                # NOTE(review): the ``* 0.0`` term presumably keeps the placeholder
                # embeddings in the autograd graph (e.g. so every embedding parameter
                # receives a gradient); confirm before simplifying.
                inputs_embeds[i, start:end] = inputs_embeds[i, start:end] * 0.0 \
                    + audio_states[i, :end - start] * 1.0
            except (RuntimeError, IndexError) as err:
                raise ValueError(
                    f"Cannot fill audio placeholder for sample {i}: "
                    f"audio_start_indices[{i}]={start}, audio_end_indices[{i}]={end}, "
                    f"inputs_embeds.size()={tuple(inputs_embeds.size())}, "
                    f"hidden_states.size()={tuple(audio_hidden_states.size())}"
                ) from err

    @staticmethod
    def _per_batch_scale(label_ids, num_items_in_batch) -> float:
        """Return the factor that turns a globally-normalised loss into this micro-batch's mean.

        Losses normalised by ``num_items_in_batch`` (e.g. across gradient
        accumulation) are multiplied by ``num_items_in_batch / #labels != -100``.
        Returns 1.0 when no global count was supplied, and NaN when the
        micro-batch has no supervised tokens.
        """
        if num_items_in_batch is None:
            return 1.0
        num_items_current = int(label_ids.ne(-100).sum())
        if num_items_current == 0:
            return float('nan')
        return float(num_items_in_batch) / num_items_current

    def _log_losses(self, asr_loss, norm_loss, sim_loss, text_loss, decoder_labels, labels, num_items_in_batch_decoder_labels, num_items_in_batch_labels):
        """Append one ``step: N | name: value ...`` line to the logging file and flush it."""
        audio_scale = self._per_batch_scale(decoder_labels, num_items_in_batch_decoder_labels)
        text_scale = self._per_batch_scale(labels, num_items_in_batch_labels)
        loss_log = {
            'asr_loss': asr_loss.item() * audio_scale,
            'norm_loss': norm_loss.item() * audio_scale,
            'sim_loss': sim_loss.item() * audio_scale,
            'text_loss': text_loss.item() * text_scale,
        }
        msg = f"step: {self._current_step}" + "".join(f" | {k}: {v:.6f}" for k, v in loss_log.items())
        self._logging_file.write(msg + '\n')
        self._logging_file.flush()

    def forward(
        self,

        # text_model inputs
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,

        # audio_model inputs
        input_features: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
        decoder_labels: Optional[torch.LongTensor] = None,
        # === Added by AlignChat ===
        loss_ratios: Sequence[float] = (1.0, 1.0, 5.0),
        loss_types: Sequence[str] = ("lm", "l1", "cos_l2"),
        audio_start_indices: Optional[List[int]] = None,
        audio_end_indices: Optional[List[int]] = None,
        # ===

        # rest parameters
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        """Run the audio model and, when both label sets are given, the end-to-end text loss.

        The loss is computed only when ``decoder_labels`` and ``labels`` are both
        provided:
        ``loss_ratios[0]*asr + loss_ratios[1]*norm + loss_ratios[2]*sim + text_loss``.
        In that case ``audio_start_indices``/``audio_end_indices`` give each
        sample's audio placeholder span in the text input. ``loss_kwargs`` may
        carry ``num_items_in_batch`` as an indexable pair
        ``(decoder_labels count, labels count)``.

        Returns:
            ``Seq2SeqLMOutputWithLastHiddenState`` built from the audio model
            outputs (plus ``loss``), or the equivalent tuple when
            ``return_dict`` is false.

        Raises:
            ValueError: if ``decoder_labels`` exceeds the audio model's
                ``max_target_positions``, if the placeholder indices are missing
                while computing the loss, or if a placeholder span cannot be filled.
        """
        return_dict = return_dict if return_dict is not None else self.audio_model.config.use_return_dict

        if decoder_labels is not None:
            if decoder_labels.shape[1] > self.audio_model.max_target_positions:
                raise ValueError(
                    f"Input length of {decoder_labels.shape[1]} is longer than the maximum length for this model ({self.audio_model.max_target_positions})"
                )
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    decoder_labels, self.audio_model.config.pad_token_id, self.audio_model.config.decoder_start_token_id
                )

        # Set labels to None to avoid loss computation inside the audio model;
        # the losses are computed below.
        audio_outputs = self.audio_model(
            input_features=input_features,
            attention_mask=encoder_attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=decoder_past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            loss_ratios=None,
            loss_types=None,
            embed_tokens=self.embed_tokens,
            proj_out=self.proj_out,
        )

        audio_lm_logits = audio_outputs.logits
        audio_hidden_states = audio_outputs.decoder_last_hidden_state

        loss = None
        if decoder_labels is not None and labels is not None:
            if audio_start_indices is None or audio_end_indices is None:
                raise ValueError(
                    "audio_start_indices and audio_end_indices are required when both labels and decoder_labels are given"
                )

            # num_items_in_batch (if present) is indexed as (decoder_labels count, labels count).
            num_items_in_batch = loss_kwargs.pop("num_items_in_batch", None)
            if num_items_in_batch is not None:
                num_items_in_batch_decoder_labels = num_items_in_batch[0]
                num_items_in_batch_labels = num_items_in_batch[1]
            else:
                num_items_in_batch_decoder_labels = None
                num_items_in_batch_labels = None

            asr_loss, norm_loss, sim_loss = self._compute_audio_losses(
                audio_lm_logits, audio_hidden_states, decoder_labels, loss_types,
                num_items_in_batch_decoder_labels, loss_kwargs,
            )
            loss = loss_ratios[0] * asr_loss + loss_ratios[1] * norm_loss + loss_ratios[2] * sim_loss

            if inputs_embeds is None:
                inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
            self._fill_audio_placeholders(inputs_embeds, audio_hidden_states, audio_start_indices, audio_end_indices)

            # The text model computes the end-to-end LM loss on the spliced embeddings.
            text_outputs = self.text_model(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                num_logits_to_keep=num_logits_to_keep,
                num_items_in_batch=num_items_in_batch_labels,
                **loss_kwargs,
            )

            text_loss = text_outputs.loss
            loss = loss + text_loss

            if (
                self._is_main_process()
                and self._logging_file is not None
                and self._current_step % self.log_interval == 0
            ):
                self._log_losses(
                    asr_loss, norm_loss, sim_loss, text_loss,
                    decoder_labels, labels,
                    num_items_in_batch_decoder_labels, num_items_in_batch_labels,
                )

            self._current_step += 1

        if not return_dict:
            output = (audio_lm_logits,) + audio_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutputWithLastHiddenState(
            loss=loss,
            logits=audio_lm_logits,
            past_key_values=audio_outputs.past_key_values,
            decoder_last_hidden_state=audio_hidden_states,
            decoder_hidden_states=audio_outputs.decoder_hidden_states,
            decoder_attentions=audio_outputs.decoder_attentions,
            cross_attentions=audio_outputs.cross_attentions,
            encoder_last_hidden_state=audio_outputs.encoder_last_hidden_state,
            encoder_hidden_states=audio_outputs.encoder_hidden_states,
            encoder_attentions=audio_outputs.encoder_attentions,
        )
