import torch
import transformers
import numpy

from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
from usw_rbf import OTLossKernel

class NewTrainer(transformers.Trainer):
    """Trainer whose loss mixes the model loss with an ``OTLossKernel`` term.

    Each batch is split along the sequence axis at the position of
    ``boundary_token_id`` in ``input_ids``. The input embeddings before that
    position and the last-layer hidden states from that position onwards are
    passed to ``OTLossKernel``; its result is combined with the usual
    (optionally label-smoothed) model loss.
    """

    # Token id whose position in ``input_ids`` marks the split point.
    boundary_token_id = 29992
    # Value forwarded as ``pos_dim`` to ``OTLossKernel``.
    ot_pos_dim = 4096
    # Weight ``alpha`` of the kernel term; the model loss is weighted ``1 - alpha``.
    sim_loss_weight = 0.1

    def _boundary_index(self, input_ids):
        """Return the sequence position of ``boundary_token_id`` in ``input_ids``.

        Every row must contain the marker exactly once and all rows must share
        the same position, because a single slice index is applied to the
        whole batch.

        Raises:
            ValueError: if the marker is missing, repeated, or located at
                different positions across rows.
        """
        columns = torch.where(input_ids == self.boundary_token_id)[1]
        batch_size = input_ids.shape[0]
        if columns.numel() != batch_size or torch.unique(columns).numel() != 1:
            raise ValueError(
                f"Expected token id {self.boundary_token_id} exactly once per row of "
                f"input_ids, at the same position in every row; found it at column(s) "
                f"{columns.tolist()} for a batch of {batch_size} row(s)."
            )
        return int(columns[0].item())

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined training loss for one batch.

        Args:
            model: Model to run. It must expose ``get_base_model()`` (as PEFT
                wrappers do) and accept ``output_hidden_states=True``.
            inputs: Keyword arguments for ``model``. ``inputs["input_ids"]``
                must contain ``boundary_token_id``. When a label smoother is
                configured, ``"labels"`` is popped from this dict.
            return_outputs: If true, return ``(loss, outputs)`` instead of
                just ``loss``.
            num_items_in_batch: Accepted only so that ``Trainer`` versions
                which pass this keyword can call the override; it is unused.

        Returns:
            ``(1 - alpha) * model_loss - alpha * kernel_loss`` with
            ``alpha = sim_loss_weight``, optionally paired with the raw model
            outputs.

        Raises:
            ValueError: if the boundary token cannot be located unambiguously,
                or if the model returned no loss and no label smoother is set.
        """
        input_ids = inputs["input_ids"]
        # Embeddings are looked up directly so the pre-boundary slice is taken
        # from the raw token embeddings rather than from the forward pass.
        emb_input = model.get_base_model().get_input_embeddings()(input_ids)
        # NOTE(review): the kernel is rebuilt on every call; if OTLossKernel
        # holds parameters or expensive state it should probably be created
        # once and reused -- confirm against usw_rbf.
        usw_rbf = OTLossKernel(pos_dim=self.ot_pos_dim)
        boundary = self._boundary_index(input_ids)

        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs, output_hidden_states=True)
        last_hidden_state = outputs["hidden_states"][-1]

        prefix_states = emb_input[:, :boundary, :]
        suffix_states = last_hidden_state[:, boundary:, :]
        sim_loss = usw_rbf(prefix_states, None, suffix_states, None)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            # Imported here because the module never brings ``unwrap_model``
            # into scope; without it this branch raised NameError.
            from transformers.modeling_utils import unwrap_model

            unwrapped = unwrap_model(model)
            # Look through a PEFT-style wrapper so the name check sees the
            # underlying architecture instead of the wrapper class.
            if hasattr(unwrapped, "get_base_model"):
                unwrapped = unwrapped.get_base_model()
            if unwrapped._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        alpha = self.sim_loss_weight
        # NOTE(review): the kernel term is subtracted, so minimising the total
        # loss maximises it -- confirm this is the intended sign.
        loss = (1 - alpha) * loss - alpha * sim_loss
        return (loss, outputs) if return_outputs else loss