from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import torch.distributed as dist
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass


@dataclass
class ClsCausalLMOutputWithPast(CausalLMOutputWithPast):
    """Causal-LM output extended with an extra classification-logits field.

    Carries every field of ``CausalLMOutputWithPast`` plus ``cls_logits``.
    The new field defaults to ``None`` (e.g. when the producing model did not
    run a classification head), so it is typed as optional.
    """

    # Classification logits, presumably one row per sequence in the batch;
    # None when not produced.
    cls_logits: Optional[torch.FloatTensor] = None


class ClsLlamaForCausalLM(LlamaForCausalLM):
    """Llama causal LM with an auxiliary two-way classification head.

    A bias-free linear ``score`` head maps the final hidden states to
    ``num_labels`` (2) logits. The logits at the last non-padding position of
    each sequence are trained with cross-entropy against ``reward``. That loss
    is scaled by ``model_args.cls_weight`` and added to the LM loss. The head
    is skipped entirely when ``cls_weight <= 0``.
    """

    def __init__(self, config, *model_args, **model_kargs):
        """Build the language model and the classification head.

        Args:
            config: Llama model configuration.
            *model_args: Unused; accepted for constructor-signature compatibility.
            **model_kargs: Must contain ``model_args``, an object exposing a
                numeric ``cls_weight`` attribute (read in ``forward``).

        Raises:
            KeyError: if ``model_args`` is not supplied.
        """
        super().__init__(config)
        self.model_args = model_kargs["model_args"]
        self.num_labels = 2
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        self.post_init()

    def _pool_cls_logits(self, hidden_states, input_ids, inputs_embeds):
        """Score ``hidden_states`` and keep each sequence's last non-pad position.

        The position is ``count(non-pad tokens) - 1``, which assumes right
        padding. Without a pad token, or when only ``inputs_embeds`` is given,
        the final position (-1) is used.

        Raises:
            ValueError: if no pad token is configured and the batch size is not 1.
        """
        cls_logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if pad_token_id is None or input_ids is None:
            sequence_lengths = -1
        else:
            sequence_lengths = (
                torch.ne(input_ids, pad_token_id).sum(-1) - 1
            ).to(cls_logits.device)

        return cls_logits[
            torch.arange(batch_size, device=cls_logits.device), sequence_lengths
        ]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        reward: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the LM and, when enabled, the classification head.

        Args:
            input_ids, attention_mask, position_ids, past_key_values,
            inputs_embeds, labels, use_cache, output_attentions:
                Forwarded to ``LlamaForCausalLM.forward``.
            reward: Class targets (indices in ``[0, num_labels)``) for the
                classification loss; ignored when ``cls_weight <= 0``.
            output_hidden_states: Whether hidden states are returned
                (defaults to ``config.output_hidden_states``). They are
                computed internally whenever the classification head runs.
            return_dict: Accepted for API compatibility. A
                ``ClsCausalLMOutputWithPast`` is always returned.

        Returns:
            ``ClsCausalLMOutputWithPast``. ``loss`` is the LM loss plus
            ``cls_weight * cls_loss`` when ``reward`` is given, otherwise just
            the LM loss (``None`` if nothing could be computed).
            ``cls_logits`` is ``None`` when the head is disabled.
        """
        cls_weight = self.model_args.cls_weight
        use_cls_head = cls_weight > 0.0
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # Pass keywords: the parent's positional order differs across
        # transformers releases.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            # The classification head needs the final hidden states.
            output_hidden_states=bool(output_hidden_states) or use_cls_head,
            # Attribute access below (outputs.loss, ...) needs a ModelOutput.
            return_dict=True,
        )

        # Keep the LM loss even when no reward is supplied.
        loss = outputs.loss
        pooled_logits = None
        if use_cls_head:
            pooled_logits = self._pool_cls_logits(
                outputs.hidden_states[-1], input_ids, inputs_embeds
            )

            if reward is not None:
                reward = reward.to(pooled_logits.device)
                loss_fct = CrossEntropyLoss()
                cls_loss = loss_fct(
                    pooled_logits.view(-1, self.num_labels), reward.view(-1)
                )
                weighted_cls_loss = cls_weight * cls_loss
                # outputs.loss is None when no labels were given.
                if outputs.loss is None:
                    loss = weighted_cls_loss
                else:
                    loss = outputs.loss + weighted_cls_loss

                # Log from rank 0 only; also works without a process group.
                is_distributed = dist.is_available() and dist.is_initialized()
                if not is_distributed or dist.get_rank() == 0:
                    lm_value = None if outputs.loss is None else outputs.loss.item()
                    print({"lm": lm_value, "cls": cls_loss.item()})

        return ClsCausalLMOutputWithPast(
            loss=loss,
            logits=outputs.logits,
            cls_logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            # Only surface hidden states when the caller asked for them.
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
