from typing import List, Tuple
from torch import FloatTensor, LongTensor, Tensor
# from torch._C import FloatTensor, LongTensor
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

class llama_GA(LlamaForCausalLM):
    """Llama causal language model whose training loss has its sign flipped.

    The forward pass is delegated unchanged to ``LlamaForCausalLM``; the only
    difference is that any loss it returns is multiplied by -1, so an
    optimizer that minimizes the returned loss maximizes the original one.
    NOTE(review): "GA" presumably stands for gradient ascent -- confirm.
    """

    def __init__(self, config):
        """Build the underlying Llama model from ``config``."""
        super().__init__(config)

    def forward(
        self,
        input_ids: LongTensor = None,
        attention_mask: Tensor | None = None,
        position_ids: LongTensor | None = None,
        past_key_values: List[FloatTensor] | None = None,
        inputs_embeds: FloatTensor | None = None,
        labels: LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> Tuple | CausalLMOutputWithPast:
        """Run the parent forward pass and negate the loss it returns.

        All named arguments are passed through to
        ``LlamaForCausalLM.forward`` unchanged. Extra keyword arguments are
        forwarded as well, so they must be ones the installed parent
        implementation accepts.

        Returns:
            Whatever the parent returns (an output object or a plain tuple),
            with the loss negated when one is present.
        """
        # Forward by keyword rather than by position so the call does not
        # depend on the parent's parameter order.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        if isinstance(outputs, tuple):
            # Tuple outputs have no ``.loss`` attribute. This assumes the
            # parent puts the loss first in the tuple exactly when labels
            # were supplied -- TODO confirm for the installed transformers
            # version.
            if labels is not None:
                outputs = (-outputs[0],) + outputs[1:]
            return outputs

        if outputs.loss is not None:
            # Negate out of place so the tensor produced by the parent's loss
            # computation is not mutated.
            outputs.loss = -outputs.loss
        return outputs