#
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers.modeling_outputs import ModelOutput
from transformers import DynamicCache


@dataclass
class CausalLMOutputWithCrossAttentionsAndValue(ModelOutput):
    """Causal-LM output container with a ``values`` tensor and cross-attentions.

    Subclass of ``ModelOutput`` declaring the fields below. Every field
    defaults to ``None``, so a producer only needs to fill in what it
    actually computed.

    NOTE(review): the field set looks like a standard causal-LM output
    (loss / logits / cache / hidden states / attentions / cross-attentions)
    extended with ``values`` -- presumably the output of a value head.
    Confirm tensor shapes and semantics against the model that builds this.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    # Extra tensor carried next to the logits; defaults to None like the rest.
    values: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Only field that distinguishes this class from CausalLMOutputWithPastAndValue.
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class CausalLMOutputWithPastAndValue(ModelOutput):
    """Causal-LM output container with a ``values`` tensor and a key/value cache.

    Subclass of ``ModelOutput`` declaring the fields below. Every field
    defaults to ``None``, so a producer only needs to fill in what it
    actually computed.

    NOTE(review): ``values`` is presumably the output of a value head
    attached to the language model -- confirm shapes and semantics against
    the model that builds this object.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    # Extra tensor carried next to the logits; defaults to None like the rest.
    values: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class GenerationOutput:
    """Record of one language-model generation.

    Attributes:
        query: Prompt input for the LM.
        response: Response from the LM.
        values: List attached to the generation. The element type is not
            fixed here; presumably per-step value estimates -- confirm
            with the producer.
        prob: Float score attached to the response (presumably its
            probability -- confirm with the producer).
        log_probs: Optional list, ``None`` when the producer does not
            supply it.
    """

    query: str  # prompt input for LM
    response: str  # response from LM
    values: list
    prob: float
    # The default is None, so the annotation has to admit None as well.
    log_probs: Optional[list] = None


@dataclass
class GenerationOutputWithCache:
    """Record of one language-model generation together with a cache object.

    All four fields are required; none has a default.

    NOTE(review): ``prompt_cache`` is a transformers ``DynamicCache`` --
    presumably the key/value cache computed for ``query`` so it can be
    reused later; confirm with the code that builds this object.
    """

    query: str  # prompt input for LM
    response: str  # response from LM
    values: list  # element type not fixed here
    prompt_cache: DynamicCache  # cache object stored alongside the generation
