from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class EarlyExit:
    """Track per-step prediction-stability metrics used for early-exit decisions.

    The first call to :meth:`update_state` only stores a reference step. Every
    later call compares the new log-probabilities with the previous step's and
    appends one value per text to each history:

    * ``patience`` -- number of consecutive steps in which the argmax of every
      token was unchanged; reset to 0 as soon as any token's argmax changes;
    * ``entropy`` -- entropy over the last dim, averaged over dim 1;
    * ``kl`` -- KL(current || previous) over the last dim, averaged over the
      next-to-last dim.

    Logits are assumed to be ``(num_texts, num_tokens, vocab_size)`` -- TODO
    confirm against callers. ``vocab_size`` and ``num_tokens`` are not read by
    this class.
    """

    num_texts: int
    device: torch.device
    vocab_size: int
    num_tokens: int
    # Thresholds used by ``update_exit_mask``; None means the strategy is unset.
    patience_threshold: Optional[float] = None
    entropy_threshold: Optional[float] = None
    kl_threshold: Optional[float] = None

    def __post_init__(self):
        # One numpy array of shape (num_texts,) is appended per step.
        self.patience_history = []
        self.entropy_history = []
        self.kl_history = []
        # Log-probabilities of the previous step (None before the first call).
        self.prev_log_probs = None
        self.prev_patience = torch.zeros(self.num_texts, device=self.device, dtype=torch.long)
        # True for texts that have already exited. Must be a tensor, not a tuple.
        self.exit_mask = torch.zeros(self.num_texts, device=self.device, dtype=torch.bool)

    @staticmethod
    def _update_patience(prev_patience, log_probs, prev_log_probs):
        """Return the new per-text patience counter.

        A text "continues" when the argmax over the last dim is identical for
        every position along the remaining trailing dim. Continuing texts get
        ``prev_patience + 1``; all others are reset to 0.
        """
        prev_argmax = prev_log_probs.argmax(-1)
        argmax = log_probs.argmax(-1)
        continue_mask = (argmax == prev_argmax).all(-1)
        # Multiplying by the bool mask zeroes the counter for texts that changed.
        return (prev_patience + continue_mask) * continue_mask

    @staticmethod
    def _update_entropy(probs, log_probs):
        """Return the entropy over the last dim, averaged over dim 1."""
        return -(probs * log_probs).sum(-1).mean(1)

    @staticmethod
    def _update_kl(log_probs, prev_log_probs):
        """Return KL(current || previous) over the last dim, averaged over the next dim.

        ``kl_div(input, target, log_target=True)`` computes
        ``exp(target) * (target - input)`` pointwise, so with ``target`` set to
        the current step this is KL(P_current || P_prev).
        """
        return torch.nn.functional.kl_div(
            input=prev_log_probs,
            target=log_probs,
            log_target=True,
            reduction="none",
        ).sum(-1).mean(-1)

    def update_state(self, logits: torch.Tensor) -> None:
        """Consume one step of ``logits`` and append metrics to the histories.

        Args:
            logits: Unnormalised scores; softmax is taken over the last dim.
                Dim 0 must broadcast against ``num_texts``.
        """
        # Detach for two reasons: ``.numpy()`` fails on tensors that require
        # grad, and the stored prev_log_probs must not keep a graph alive.
        logits = logits.detach()
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        if self.prev_log_probs is None:
            # First step: nothing to compare against yet.
            self.prev_log_probs = log_probs
            return

        probs = torch.nn.functional.softmax(logits, dim=-1)
        patience = self._update_patience(self.prev_patience, log_probs, self.prev_log_probs)
        entropy = self._update_entropy(probs, log_probs)
        kl = self._update_kl(log_probs, self.prev_log_probs)

        self.patience_history.append(patience.cpu().numpy())
        self.entropy_history.append(entropy.cpu().numpy())
        self.kl_history.append(kl.cpu().numpy())
        self.prev_patience = patience
        self.prev_log_probs = log_probs

    # NOTE(review): originally flagged as never tested or used; it does not
    # modify ``self.exit_mask``.
    def update_exit_mask(self, strategy: str):
        """Evaluate ``strategy``'s exit criterion on the latest history step.

        Criteria:

        * ``"kl"`` -- ``kl < kl_threshold``;
        * ``"entropy"`` -- ``entropy < entropy_threshold``;
        * ``"patience"`` -- ``patience > patience_threshold``.

        Args:
            strategy: Which criterion to evaluate.

        Returns:
            A bool tensor on ``self.device`` with one entry per text whose
            ``exit_mask`` is False, True where the criterion holds. The mask is
            all False for an unknown strategy or an empty history.

        Raises:
            ValueError: If the threshold for ``strategy`` is None.
        """
        active = ~self.exit_mask
        new_exit_mask = torch.zeros(int(active.sum()), device=self.device, dtype=torch.bool)

        if strategy == "kl":
            history, threshold, compare = self.kl_history, self.kl_threshold, torch.lt
        elif strategy == "entropy":
            history, threshold, compare = self.entropy_history, self.entropy_threshold, torch.lt
        elif strategy == "patience":
            history, threshold, compare = self.patience_history, self.patience_threshold, torch.gt
        else:
            return new_exit_mask

        if not history:
            return new_exit_mask
        if threshold is None:
            raise ValueError(f"{strategy}_threshold must be set to use the {strategy!r} strategy")

        # History entries are numpy arrays; bring the latest back to the device.
        latest = torch.as_tensor(history[-1], device=self.device)
        return compare(latest[active], threshold)

    @staticmethod
    def get_exit_steps(strategy, history, criterion, device):
        """Return, per text, the first step index at which ``criterion`` holds.

        Args:
            strategy: Unused; kept for interface compatibility.
            history: Sequence of per-step arrays (e.g. a list from
                :meth:`get_history`), each of length num_texts.
            criterion: Elementwise callable mapping one step's values to a
                boolean mask of the same length.
            device: Device for the returned tensor.

        Returns:
            ``torch.int`` tensor of length num_texts. Texts that never satisfy
            the criterion get ``len(history)``.

        Raises:
            ValueError: If ``history`` is empty.
        """
        if len(history) == 0:
            raise ValueError("history is empty; at least one recorded step is required")

        num_steps = len(history)
        num_texts = len(history[0])
        exit_mask = torch.zeros(num_texts, device=device, dtype=torch.bool)
        # Default for texts that never exit is the total number of steps.
        exit_steps = torch.full((num_texts,), num_steps, device=device, dtype=torch.int)

        for step, values in enumerate(history):
            met = torch.as_tensor(criterion(values), dtype=torch.bool, device=device)
            # Only the first step at which the criterion holds is recorded.
            newly_exited = met & ~exit_mask
            exit_steps[newly_exited] = step
            exit_mask |= newly_exited

        return exit_steps

    def get_history(self):
        """Return the live (not copied) metric history lists, keyed by metric name."""
        return {
            "entropy": self.entropy_history,
            "kl": self.kl_history,
            "patience": self.patience_history,
        }