from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from src.model.modeling_utils import reduce_mean, reduce_sum


class GoldLlamaForCausalLM(LlamaForCausalLM):
    """Llama causal LM trained with a GOLD-style, self-weighted token loss.

    Every target token's cross-entropy is multiplied by a *detached* weight
    ``min(score ** gold_alpha, gold_beta)``. ``score`` is the model's own
    probability of that target token (``score_mode == "softmax"``) or its
    log-probability (``score_mode == "log"``). Because the weight is detached
    it rescales the gradient of each token without being differentiated itself.
    """

    def __init__(self, config, *model_args, **model_kargs):
        """Build the model and read the GOLD hyper-parameters.

        Args:
            config: Model configuration, forwarded to ``LlamaForCausalLM``.
            *model_args: Accepted for ``from_pretrained`` compatibility; unused.
            **model_kargs: Must contain ``model_args``, an object exposing
                ``gold_alpha`` (exponent applied to the token score) and
                ``gold_beta`` (upper clamp on the resulting weight).

        Raises:
            KeyError: If no ``model_args`` keyword argument is supplied.
        """
        super().__init__(config)
        try:
            self.model_args = model_kargs["model_args"]
        except KeyError:
            raise KeyError(
                "GoldLlamaForCausalLM requires a `model_args` keyword argument "
                "providing `gold_alpha` and `gold_beta`"
            ) from None
        self.gold_alpha = self.model_args.gold_alpha
        self.gold_beta = self.model_args.gold_beta
        # "softmax" weights tokens by p(y_t); "log" weights them by log p(y_t).
        self.score_mode = "softmax"

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the base model and, if ``labels`` are given, the GOLD loss.

        Args:
            input_ids ... return_dict: Same meaning as in
                ``LlamaForCausalLM.forward``. ``labels`` uses -100 for
                positions that must not contribute to the loss.

        Returns:
            A ``CausalLMOutputWithPast`` whose ``loss`` is the GOLD loss (or
            ``None`` without labels). If ``return_dict`` resolves to False,
            the same fields are returned as a tuple, ``loss`` first when set.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        # Keyword arguments keep this call correct across transformers versions
        # whose positional parameter order differs. labels=None stops the base
        # class from computing its own unweighted loss, and return_dict=True
        # guarantees a ModelOutput whose fields can be read by name.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        logits = outputs.logits
        loss = None
        if labels is not None:
            loss = self._gold_loss(logits, labels)

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return output if return_dict else output.to_tuple()

    def _gold_loss(
        self, logits: torch.Tensor, labels: torch.LongTensor
    ) -> torch.Tensor:
        """Return the batch-mean GOLD-weighted next-token cross-entropy."""
        # Shift so that the logits at position t are scored against label t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        # Move labels to the logits' device (needed under model parallelism).
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
        loss_mask = shift_labels.ne(-100)

        # Per-token cross-entropy; positions labelled -100 yield 0.
        loss_fct = CrossEntropyLoss(reduction="none")
        token_nll = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        ).view(shift_labels.shape)
        token_nll = token_nll * loss_mask

        # CE = -log_softmax(logits)[label], so the target token's log-prob and
        # prob come straight from it. This avoids materialising a second
        # full-vocabulary softmax tensor and always scores the *label* token.
        token_logp = -token_nll.detach()
        if self.score_mode == "log":
            # NOTE(review): log-probs are <= 0, so a non-integer gold_alpha
            # produces NaN here; confirm the intended sign/transform.
            scores = token_logp
        else:
            scores = token_logp.exp()
        scores = scores * loss_mask

        weights = torch.clamp(scores**self.gold_alpha, max=self.gold_beta)
        weighted_nll = token_nll * weights.detach()

        # reduce_sum comes from src.model.modeling_utils; presumably a masked
        # sum over axis 1 (the sequence axis). The result is averaged over
        # the batch.
        return reduce_sum(weighted_nll, loss_mask, axis=1).mean()
