import transformers
from torch import Tensor, LongTensor, FloatTensor
from torch import nn
from transformers import AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy


class HiddenVariableAndRewardHead(nn.Module):
    """Project a hidden state into candidate hidden variables and score them.

    The head is two linear layers applied in sequence, each preceded by
    dropout (or an identity layer when the dropout probability is falsy):

    * ``hidden_projection`` maps ``hidden_size`` features to
      ``hidden_size * n_candidates`` features.
    * ``reward`` maps that projection to ``n_candidates`` scalars.

    Both ``summary_dropout_prob`` and ``n_candidates`` are looked up on
    ``config`` first, then in ``kwargs``, and finally fall back to ``0.1``
    and ``1`` respectively.
    """

    def __init__(self, config, **kwargs):
        super().__init__()

        # An attribute on the config wins over the keyword argument.
        fallback_prob = kwargs.get("summary_dropout_prob", 0.1)
        dropout_prob = getattr(config, "summary_dropout_prob", fallback_prob)
        if dropout_prob:
            self.dropout1 = nn.Dropout(dropout_prob)
            self.dropout2 = nn.Dropout(dropout_prob)
        else:
            self.dropout1 = nn.Identity()
            self.dropout2 = nn.Identity()

        fallback_candidates = kwargs.get("n_candidates", 1)
        self.n_candidates = getattr(config, "n_candidates", fallback_candidates)

        in_width = self._resolve_hidden_size(config)
        projected_width = in_width * self.n_candidates
        self.hidden_projection = nn.Linear(in_width, projected_width)
        self.reward = nn.Linear(projected_width, self.n_candidates)

    @staticmethod
    def _resolve_hidden_size(config):
        """Return the feature width the head should consume for ``config``."""
        # Some models such as OPT (e.g. OPT-350m) have a projection layer
        # before the word embeddings, so that dimension takes precedence.
        if hasattr(config, "word_embed_proj_dim"):
            return config.word_embed_proj_dim

        if getattr(config, "is_encoder_decoder", False):
            decoder_config = getattr(config, "decoder", None)
            if hasattr(decoder_config, "hidden_size"):
                return decoder_config.hidden_size

        return config.hidden_size

    def forward(self, hidden_states):
        """Return ``(hidden_variable, rewards)`` computed from ``hidden_states``."""
        projection_input = self.dropout1(hidden_states)
        hidden_variable = self.hidden_projection(projection_input)

        reward_input = self.dropout2(hidden_variable)
        rewards = self.reward(reward_input)

        return hidden_variable, rewards


def load_model_with_hidden_variable_and_reward_head(
    pretrained_model_name_or_path: str, n_candidates: int, use_fixed_liger: bool = False, **kwargs
):
    """Load a checkpoint as a causal LM extended with a hidden-variable/reward head.

    The base model class is looked up on the ``transformers`` package by the
    first entry of ``config.architectures``. When that entry is the wrapper
    class defined here, ``config.original_model_class`` is used instead, so a
    checkpoint whose config names the wrapper resolves to its underlying
    architecture.

    Side effect: the ``cross_entropy`` attribute of the ``nn.functional``
    namespace exposed by ``transformers.loss.loss_utils`` is replaced with
    ``liger_cross_entropy``. This happens on every call, regardless of
    ``use_fixed_liger``.

    Args:
        pretrained_model_name_or_path: Forwarded to ``AutoConfig.from_pretrained``
            and to the wrapper's ``from_pretrained``.
        n_candidates: Forwarded to ``from_pretrained`` as a keyword argument.
            The wrapper prefers an ``n_candidates`` attribute on the loaded
            config when one exists.
        use_fixed_liger: When true, the second forward pass goes through
            ``lce_forward`` instead of the base class ``forward``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        Whatever ``from_pretrained`` returns for the dynamically built wrapper class.
    """
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)

    wrapper_name = "AutoModelForCausalLMWithHiddenVariableAndRewardHead"
    declared_architecture = config.architectures[0]
    model_class_name = (
        config.original_model_class if declared_architecture == wrapper_name else declared_architecture
    )
    model_class = getattr(transformers, model_class_name)

    # Swap in the Liger cross entropy on the namespace used by transformers' loss utilities.
    from transformers.loss.loss_utils import nn as loss_utils_nn

    loss_utils_nn.functional.cross_entropy = liger_cross_entropy

    class AutoModelForCausalLMWithHiddenVariableAndRewardHead(model_class):
        def __init__(self, *args, **kwargs):
            # Always strip ``n_candidates`` from kwargs before handing them to
            # the base class; the value stored on the loaded config wins.
            requested_candidates = kwargs.pop("n_candidates", 1)
            self.n_candidates = getattr(config, "n_candidates", requested_candidates)
            super().__init__(*args, **kwargs)
            self.hidden_variable_and_reward_head = HiddenVariableAndRewardHead(
                self.config, n_candidates=self.n_candidates
            )
            # Record the underlying architecture on the model's config.
            self.config.original_model_class = model_class_name

        def forward(
            self,
            prompt_input_ids: Tensor,
            sequence_input_ids: Tensor,
            hidden_variable_positions: Tensor,
            prompt_attention_mask: Tensor | None = None,
            sequence_attention_mask: Tensor | None = None,
            labels: Tensor | None = None,
            **kwargs,
        ):
            """Run the prompt pass, inject hidden variables, then run the sequence pass.

            Args:
                prompt_input_ids: Must be padded on the left side, because the
                    hidden state at the last position is fed to the head.
                sequence_input_ids: Must be padded on the left side.
                hidden_variable_positions: Per-row position in the sequence
                    embeddings that is overwritten with the hidden variable.
                prompt_attention_mask: Attention mask for the prompt pass.
                sequence_attention_mask: Attention mask for the sequence pass.
                labels: Labels forwarded to the sequence pass.
                **kwargs: Extra keyword arguments for the sequence pass.

            Returns:
                A ``(loss, rewards)`` tuple, where ``loss`` is the ``loss``
                attribute of the sequence-pass output.
            """
            embed_tokens = self.get_input_embeddings()

            # Pass 1: encode the prompt and keep only the final position.
            prompt_outputs = self.model(
                inputs_embeds=embed_tokens(prompt_input_ids),
                attention_mask=prompt_attention_mask,
                return_dict=True,
            )
            final_prompt_state = prompt_outputs.last_hidden_state[:, -1, :]
            hidden_variable, rewards = self.hidden_variable_and_reward_head(final_prompt_state)

            # Overwrite one embedding per row with that row's hidden variable.
            sequence_embeddings = embed_tokens(sequence_input_ids)
            batch_size = sequence_embeddings.shape[0]
            sequence_embeddings[range(batch_size), hidden_variable_positions] = hidden_variable.view(
                batch_size, -1
            )

            # Pass 2: run the responses on the edited embeddings.
            second_pass_kwargs = dict(
                kwargs,
                inputs_embeds=sequence_embeddings,
                attention_mask=sequence_attention_mask,
                labels=labels,
            )
            if use_fixed_liger:
                sequence_outputs = lce_forward(self, **second_pass_kwargs)
            else:
                sequence_outputs = super().forward(**second_pass_kwargs)

            return sequence_outputs.loss, rewards

    wrapper_class = AutoModelForCausalLMWithHiddenVariableAndRewardHead
    return wrapper_class.from_pretrained(pretrained_model_name_or_path, n_candidates=n_candidates, **kwargs)


def lce_forward(
    self,
    input_ids: Tensor | None = None,
    attention_mask: Tensor | None = None,
    position_ids: LongTensor | None = None,
    past_key_values: list[FloatTensor] | None = None,
    inputs_embeds: FloatTensor | None = None,
    labels: LongTensor | None = None,
    use_cache: bool | None = None,
    output_attentions: bool | None = None,
    output_hidden_states: bool | None = None,
    return_dict: bool | None = None,
    cache_position: LongTensor | None = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> CausalLMOutputWithPast | tuple:
    """Causal-LM forward that computes the loss with Liger fused linear cross entropy.

    The backbone ``self.model`` is run on the inputs. When ``labels`` is
    given, the shifted next-token loss is computed directly from the hidden
    states and ``self.lm_head.weight``, without materialising logits. The
    returned ``logits`` field is therefore always ``None``.

    Args:
        self: The causal-LM model; must expose ``model``, ``lm_head`` and ``config``.
        input_ids, attention_mask, position_ids, past_key_values, inputs_embeds,
        use_cache, cache_position: Forwarded unchanged to ``self.model``.
        labels: Target token ids. When ``None`` no loss is computed.
        output_attentions, output_hidden_states, return_dict: Fall back to the
            corresponding ``self.config`` values when ``None``.
        num_logits_to_keep: Accepted for signature compatibility; unused
            because logits are never computed.
        **loss_kwargs: If ``num_items_in_batch`` is present and truthy, the
            loss is summed and divided by it instead of being averaged.

    Returns:
        A ``CausalLMOutputWithPast`` when ``return_dict`` resolves to true,
        otherwise its ``to_tuple()`` form (fields that are ``None`` are omitted).
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Always request a ModelOutput from the backbone. The fields read below
    # (`past_key_values`, `hidden_states`, `attentions`) do not exist on the
    # plain tuple returned for `return_dict=False`, which used to raise
    # AttributeError. The caller's `return_dict` is honoured on the way out.
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    # Don't materialize logits
    if labels is not None:
        # We do the same thing as ForCausalLMLoss but using Liger FLCE:
        # position t predicts token t + 1.
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten tokens. Use the actual feature width rather than
        # `config.hidden_size` so the reshape cannot silently mis-group
        # features if the two ever differ.
        shift_hidden_states = shift_hidden_states.view(-1, shift_hidden_states.size(-1))
        # No-op when labels already live on the same device as the hidden states.
        shift_labels = shift_labels.view(-1).to(shift_hidden_states.device)

        num_items_in_batch = loss_kwargs.get("num_items_in_batch")
        reduction = "sum" if num_items_in_batch else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            # Out-of-place division: avoid mutating the tensor returned by the fused kernel.
            loss = loss / num_items_in_batch

    result = CausalLMOutputWithPast(
        loss=loss,
        logits=None,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
    return result if return_dict else result.to_tuple()
