import types
import torch

from typing import Dict
from transformers import (
    AutoModelForCausalLM,
    GPT2LMHeadModel,
    LlamaForCausalLM,
)


def register_model_forward_function(
    model: AutoModelForCausalLM,
):
    """Bind the matching ``compute_*`` function onto ``model`` as ``model.compute``.

    GPT-2 LM-head models receive ``compute_gpt2`` and Llama causal-LM models
    receive ``compute_llama2``. A model of any other type is left untouched,
    i.e. no ``compute`` attribute is set on it.

    Args:
        model: The model instance to patch in place.
    """
    # Checked in order; the first matching class wins.
    dispatch = (
        (GPT2LMHeadModel, compute_gpt2),
        (LlamaForCausalLM, compute_llama2),
    )
    for model_cls, forward_fn in dispatch:
        if isinstance(model, model_cls):
            # MethodType makes ``model`` the implicit ``self`` of ``forward_fn``.
            model.compute = types.MethodType(forward_fn, model)
            return


def compute_gpt2(self: GPT2LMHeadModel, batch: Dict[str, torch.Tensor]):
    """Compute LM-head logits for a GPT-2 model from ``batch["input_ids"]``.

    Only ``"input_ids"`` is read from the batch and passed to the backbone.

    Args:
        self: The model this function is bound to; it must expose
            ``transformer`` and ``lm_head``.
        batch: Mapping containing at least an ``"input_ids"`` entry.

    Returns:
        ``self.lm_head`` applied to the first element of the backbone output.
    """
    token_ids = batch["input_ids"]
    # The first element of the backbone output is used as the hidden states.
    features = self.transformer(token_ids)[0]
    return self.lm_head(features)


def compute_gpt2_with_positions(self: GPT2LMHeadModel, batch: Dict[str, torch.Tensor]):
    """Compute GPT-2 logits with explicit position ids, for the label span only.

    The backbone is run on ``batch["input_ids"]`` with ``batch["position_ids"]``.
    Only the last ``batch["labels"].size(1)`` positions along dimension 1 of
    the hidden states are projected through the LM head.

    Args:
        self: The model this function is bound to; it must expose
            ``transformer`` and ``lm_head``.
        batch: Mapping with ``"input_ids"``, ``"position_ids"`` and
            ``"labels"`` entries. ``"labels"`` must have at least two
            dimensions, since its size along dimension 1 is read.

    Returns:
        ``self.lm_head`` applied to the trailing slice of the hidden states.
        If the labels are longer than the hidden-state sequence, the whole
        sequence is kept; if the labels are empty, the slice is empty.
    """
    transformer_outputs = self.transformer(
        batch["input_ids"],
        position_ids=batch["position_ids"],
    )
    hidden_states = transformer_outputs[0]
    seq_len = batch["labels"].size(1)
    # Use an explicit start index rather than ``-seq_len:``. With
    # ``seq_len == 0`` the negative form becomes ``-0:``, which selects the
    # entire sequence instead of nothing. Clamping at 0 keeps the original
    # behaviour when the labels are longer than the sequence.
    start = max(hidden_states.size(1) - seq_len, 0)
    hidden_states = hidden_states[:, start:]
    lm_logits = self.lm_head(hidden_states)
    return lm_logits


def compute_llama2(self: LlamaForCausalLM, batch: Dict[str, torch.Tensor]):
    """Compute LM-head logits for a Llama model from ``batch["input_ids"]``.

    Only ``"input_ids"`` is read from the batch; it is passed to the backbone
    as a keyword argument.

    Args:
        self: The model this function is bound to; it must expose ``model``
            (the backbone) and ``lm_head``.
        batch: Mapping containing at least an ``"input_ids"`` entry.

    Returns:
        ``self.lm_head`` applied to the first element of the backbone output.
    """
    backbone_out = self.model(input_ids=batch["input_ids"])
    # The first element of the backbone output is used as the hidden states.
    return self.lm_head(backbone_out[0])

def compute_llama2_test(self: LlamaForCausalLM, batch: Dict[str, torch.Tensor]):
    """Compute Llama logits once and return them as a two-element tuple.

    The backbone is run a single time on ``batch["input_ids"]``; both tuple
    elements refer to the same logits object.

    Args:
        self: The model this function is bound to; it must expose ``model``
            (the backbone) and ``lm_head``.
        batch: Mapping containing at least an ``"input_ids"`` entry.

    Returns:
        ``(logits, logits)``, the same object in both positions.
    """
    backbone_out = self.model(input_ids=batch["input_ids"])
    logits = self.lm_head(backbone_out[0])
    # Deliberately the same object twice, not a copy.
    return logits, logits


# Earlier single-output variant of compute_llama2_with_positions, kept for
# reference; it is superseded by the two-output definition below.
#
# def compute_llama2_with_positions(self: LlamaForCausalLM, batch: Dict[str, torch.Tensor]):
#     outputs = self.model(
#         input_ids=batch["input_ids"],
#         position_ids=batch["position_ids"],
#     )
#     hidden_states = outputs[0]
#     seq_len = batch['labels'].size(1)
#     hidden_states = hidden_states[:, -seq_len:]
#     lm_logits = self.lm_head(hidden_states)
#     return lm_logits

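# Adapted for the double loss in mixed_training: returns two sets of logits,
# one for the position-aware inputs and one for the "standard" inputs.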
def compute_llama2_with_positions(self: LlamaForCausalLM, batch: Dict[str, torch.Tensor]):
    """Compute Llama logits for a position-aware batch and a standard batch.

    Two backbone passes are made, in this order:

    1. ``batch["input_ids"]`` with ``batch["position_ids"]``. Only the last
       ``batch["labels"].size(1)`` positions along dimension 1 of the hidden
       states are projected through the LM head.
    2. ``batch["standard"]["input_ids"]`` without position ids. All positions
       are projected.

    Args:
        self: The model this function is bound to; it must expose ``model``
            (the backbone) and ``lm_head``.
        batch: Mapping with ``"input_ids"``, ``"position_ids"``, ``"labels"``
            and ``"standard"`` entries. Despite the annotation,
            ``batch["standard"]`` is itself a mapping with an ``"input_ids"``
            entry.

    Returns:
        ``(lm_logits, lm_logits_std)``: logits for the trailing label span of
        the position-aware pass, and logits for the full standard pass. If
        the labels are longer than the hidden-state sequence, the whole
        sequence is kept; if the labels are empty, the first element is empty.
    """
    outputs = self.model(
        input_ids=batch["input_ids"],
        position_ids=batch["position_ids"],
    )
    outputs_std = self.model(
        input_ids=batch["standard"]["input_ids"],
    )

    hidden_states = outputs[0]
    seq_len = batch["labels"].size(1)
    # Use an explicit start index rather than ``-seq_len:``. With
    # ``seq_len == 0`` the negative form becomes ``-0:``, which selects the
    # entire sequence instead of nothing. Clamping at 0 keeps the original
    # behaviour when the labels are longer than the sequence.
    start = max(hidden_states.size(1) - seq_len, 0)
    hidden_states = hidden_states[:, start:]
    lm_logits = self.lm_head(hidden_states)

    hidden_states_std = outputs_std[0]
    lm_logits_std = self.lm_head(hidden_states_std)

    return lm_logits, lm_logits_std


