from typing import List, Optional, Tuple, Union
import torch
from torch.nn import CrossEntropyLoss
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from src.common.templates import DATA_TYPE_DICT


class MLEAugLlamaForCausalLM(LlamaForCausalLM):
    """LLaMA causal LM trained with a mixed MML / MLE loss over sequence groups.

    The per-token cross entropy is computed without reduction, summed per
    sequence, and the sequences are then grouped along the batch dimension.
    For every group two terms are computed (see ``get_mixed_mle_mml_loss``)
    and combined with the weights ``mml_lambda`` and ``mle_lambda``.
    """

    def __init__(self, config, *model_args, **model_kargs):
        """Build the model and read the loss settings from ``model_args``.

        Args:
            config: Model configuration forwarded to ``LlamaForCausalLM``.
            *model_args: Ignored positional arguments.
            **model_kargs: Must contain the key ``"model_args"``, an object
                exposing the attributes ``mle_aug_norm`` and ``mle_aug``.

        Raises:
            KeyError: If ``"model_args"`` is missing from ``model_kargs`` or
                ``model_args.mle_aug`` is not a key of ``DATA_TYPE_DICT``.
        """
        super().__init__(config)
        self.model_args = model_kargs["model_args"]
        # When truthy, the summed term of each group is divided by the number
        # of unmasked label tokens in that group.
        self.mle_aug_norm = self.model_args.mle_aug_norm
        # Only ``len(self.data_types)`` is used in this class: it is the group
        # size along the batch dimension in ``forward``.
        # NOTE(review): presumably one entry per augmented variant of an
        # example -- confirm against DATA_TYPE_DICT.
        self.data_types = DATA_TYPE_DICT[self.model_args.mle_aug]
        # Scale applied to sequence log-probs before the logsumexp and undone
        # afterwards.
        self.beta_smoothing = 1.0
        # Weights of the marginal (MML) and the summed (MLE) terms. With these
        # defaults only the summed term contributes to the loss.
        self.mml_lambda = 0.0
        self.mle_lambda = 1.0

    def get_mixed_mle_mml_loss(
        self,
        unreduced_loss: torch.Tensor,
        cached_program_nums: Union[int, List[int]],
        loss_mask: Optional[torch.Tensor] = None,
        log_prob_dist: bool = True,
    ) -> torch.Tensor:
        """Compute the weighted mix of the MML and MLE losses.

        Args:
            unreduced_loss: Per-token loss; it is summed over ``dim=1`` to get
                one value per row, so it needs at least two dimensions.
            cached_program_nums: Passed to ``torch.split`` along dim 0: either
                a single chunk size or a list with one size per group.
            loss_mask: Mask of the label positions that count towards the
                loss. Required when ``self.mle_aug_norm`` is truthy, where its
                per-group sum is used as the normaliser; unused otherwise.
            log_prob_dist: Currently unused; kept so existing callers that
                pass it keep working.

        Returns:
            Scalar tensor: the mean over groups of
            ``mml_lambda * marginal + mle_lambda * summed``.

        Raises:
            ValueError: If ``self.mle_aug_norm`` is truthy but ``loss_mask``
                is ``None``.
        """
        if self.mle_aug_norm and loss_mask is None:
            raise ValueError("loss_mask is required when mle_aug_norm is enabled")

        # Scaled log-probability of every sequence, grouped along the batch.
        scaled_log_probs = -self.beta_smoothing * torch.sum(unreduced_loss, dim=1)
        grouped_log_probs = torch.split(scaled_log_probs, cached_program_nums)

        # MML term: negative log of the summed probabilities within a group.
        marginal_log_probs = torch.stack(
            [
                -1.0 * torch.logsumexp(log_probs, dim=0) / self.beta_smoothing
                for log_probs in grouped_log_probs
            ]
        )

        # MLE term: negated sum of the scaled log-probs within a group.
        sum_log_probs = torch.stack(
            [-torch.sum(log_probs, dim=0) for log_probs in grouped_log_probs]
        )
        if self.mle_aug_norm:
            token_counts = torch.stack(
                [mask.sum() for mask in torch.split(loss_mask, cached_program_nums)]
            )
            # A group whose labels are all masked has a zero numerator as well;
            # clamping the count avoids 0 / 0 = NaN poisoning the batch mean.
            sum_log_probs = sum_log_probs / token_counts.clamp(min=1)

        return torch.mean(
            self.mml_lambda * marginal_log_probs + self.mle_lambda * sum_log_probs
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the base model and replace its loss with the mixed MML/MLE loss.

        All arguments are forwarded to ``LlamaForCausalLM.forward`` except
        ``labels``, which is withheld so the parent computes no loss of its
        own. When ``labels`` is given, the batch is split into consecutive
        groups of ``len(self.data_types)`` rows for the loss.
        NOTE(review): this assumes rows of one group are adjacent in the
        batch -- confirm against the data collator.

        Returns:
            A ``CausalLMOutputWithPast``, or, when ``return_dict`` resolves to
            false, a tuple ``(loss, logits, ...)`` with ``loss`` omitted when
            no labels were given.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Keyword arguments keep the call correct even if the parent signature
        # is reordered. A ModelOutput is always requested here because the
        # fields are read by name below; the caller's choice is applied on
        # return.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        logits = outputs.logits
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            # Flatten the tokens; keep the per-token loss (no reduction)
            loss_fct = CrossEntropyLoss(reduction="none")
            unreduced_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            # -100 marks label positions that must not contribute to the loss.
            loss_mask = shift_labels.ne(-100)
            unreduced_loss = unreduced_loss.view(shift_labels.shape) * loss_mask
            loss = self.get_mixed_mle_mml_loss(
                unreduced_loss, len(self.data_types), loss_mask
            )

        if not return_dict:
            # The parent loss is None, so index 0 of the tuple is the logits.
            output = (logits,) + outputs.to_tuple()[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
