import collections
import dataclasses
import functools
import gc
import typing as tp

import numpy as np
import numpy.typing as npt
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
import tree


class ModelForwardOutputs(tp.TypedDict):
  logits: torch.Tensor
  attentions: tuple[torch.Tensor, ...] | None
  hidden_states: tuple[torch.Tensor, ...] | None


def compute_token_logprobs(
  input_ids: torch.Tensor,
  logits: torch.Tensor,
  temperature: float = 1,
):
  """Return the log-probability of each token in ``input_ids`` under ``logits``.

  Args:
    input_ids: Token ids to score. Any integer dtype is accepted: the ids are
      cast to int64 because ``torch.gather`` only accepts an int64 index. They
      are used (after ``unsqueeze(-1)``) as a gather index into ``logits``, so
      the shape is typically ``logits.shape[:-1]``.
    logits: Unnormalised scores whose last dimension indexes the vocabulary.
    temperature: Divisor applied to ``logits`` before the softmax; must be > 0.

  Returns:
    torch.Tensor: ``log_softmax(logits / temperature)`` gathered at
    ``input_ids`` along the last dimension (that dimension is squeezed away).

  Raises:
    ValueError: If ``temperature`` is not positive (zero would yield inf/NaN).
  """
  if temperature <= 0:
    raise ValueError(f"temperature must be positive, got {temperature}")
  logprobs = torch.nn.functional.log_softmax(logits / temperature, dim=-1)
  # gather() rejects non-int64 indices (e.g. int32 ids coming from numpy).
  index = input_ids.to(dtype=torch.long).unsqueeze(-1)
  return torch.gather(logprobs, dim=-1, index=index).squeeze(-1)


def score_forward_pass(
  input_ids: torch.Tensor,
  attention_mask: torch.Tensor,
  model: PreTrainedModel,
  pad_token_id: int,
  output_attentions: bool = False,
  output_hidden_states: bool = False,
) -> ModelForwardOutputs:
  """Run ``model`` on the batch shifted right by one masked pad token.

  A ``pad_token_id`` column (attention mask 0) is prepended to every row before
  the forward pass, and the final sequence position is dropped from every
  output afterwards. The logits at position ``n`` are therefore the model's
  next-token prediction for input token ``n``, and all returned tensors have
  ``input_len`` positions again.

  Args:
    input_ids: Token ids, shape (batch_size, input_len).
    attention_mask: Attention mask for ``input_ids``.
    model: Model called with ``return_dict=True`` under ``torch.inference_mode``.
    pad_token_id: Token id used for the prepended column.
    output_attentions: Forwarded to the model.
    output_hidden_states: Forwarded to the model.

  Returns:
    ModelForwardOutputs holding CPU tensors:
      logits: (batch_size, input_len, vocab_size).
      attentions: per layer, the model's attention with the last query row and
        the prepended key column removed -> (batch_size, num_heads, input_len,
        input_len); None if the model returned no attentions.
      hidden_states: per entry, the model's hidden states with the last
        position removed -> (batch_size, input_len, hidden_size); None if the
        model returned none.
  """
  batch_size = input_ids.shape[0]
  # Same dtype/device as input_ids.
  shift_ids = input_ids.new_ones((batch_size, 1)) * pad_token_id
  shift_mask = torch.zeros_like(shift_ids, dtype=torch.bool)
  model_input_ids = torch.cat([shift_ids, input_ids], dim=1)  # [batch_size, input_len + 1]
  model_attention_mask = torch.cat([shift_mask, attention_mask], dim=1)

  with torch.inference_mode():
    raw = model(
      input_ids=model_input_ids,
      attention_mask=model_attention_mask,
      output_attentions=output_attentions,
      output_hidden_states=output_hidden_states,
      return_dict=True,
    )
    # Raw logits: [batch_size, input_len + 1, vocab_size]. Position n predicts
    # padded token n + 1 == input token n, so the last position is dropped.
    result: ModelForwardOutputs = {
      "logits": raw.logits[:, :-1, :].cpu(),
      "attentions": None,
      "hidden_states": None,
    }

    if "attentions" in raw:
      # Per layer: drop the last query row and the key column of the
      # prepended pad token.
      result["attentions"] = tuple(
        attn[:, :, :-1, 1:].cpu() for attn in raw["attentions"]
      )  # type: ignore

    if "hidden_states" in raw:
      result["hidden_states"] = tuple(
        states[:, :-1, :].cpu() for states in raw["hidden_states"]
      )

  return result


@dataclasses.dataclass
class ProcessedModelOutputs:
  token_ids: npt.NDArray[np.int32]
  token_logprobs: npt.NDArray[np.float32]
  attentions: tp.Sequence[npt.NDArray[np.float32]] | None = None
  hidden_states: tp.Sequence[npt.NDArray[np.float32]] | None = None
  logits: npt.NDArray[np.float32] | None = None


def postprocess_score_outputs(
  batch_input_ids: torch.Tensor,
  batch_logits: torch.Tensor,
  pad_token_id: int,
  padding_side: tp.Literal["left", "right"],
  batch_attention: tuple[torch.Tensor, ...] | None = None,
  batch_hidden_states: tuple[torch.Tensor, ...] | None = None,
) -> tp.Sequence[ProcessedModelOutputs]:
  """
  Postprocess batched score outputs into one entry per example.

  Computes the log-probability of every observed token, then splits the batch
  into examples and strips padding. One position slice is computed per row and
  applied to the token ids, token log-probabilities, logits, attentions and
  hidden states alike, so all per-position outputs of an example stay aligned.

  Args:
    batch_input_ids (torch.Tensor): Input IDs. Shape: (batch_size, input_len)
    batch_logits (torch.Tensor): Logits; the logits at position n are used to
      score token n. Shape: (batch_size, input_len, vocab_size)
    pad_token_id (int): The ID of the padding token.
    padding_side ("left" | "right"): Side of each row that holds padding. Any
      value other than "left" is treated as "right".
    batch_attention (tuple[torch.Tensor, ...] | None, optional): The attention tensors of the batch.
      Shape: (num_layers, batch_size, num_heads, input_len, input_len). Defaults to None.
    batch_hidden_states (tuple[torch.Tensor, ...] | None, optional): The hidden states of the batch.
      Shape: (num_layers, batch_size, input_len, hidden_size). Defaults to None.

  Returns:
    tp.Sequence[ProcessedModelOutputs]: One entry per row, each containing token IDs,
    token log probabilities, attentions, hidden states, and logits with padding removed.
  """
  # ---------------------------------------------------------------------------- #
  #                       compute per-token log-probabilities                    #
  # ---------------------------------------------------------------------------- #
  batch_input_ids = batch_input_ids.cpu().to(dtype=torch.long)
  batch_logits = batch_logits.cpu().to(dtype=torch.float)
  batch_logprobs = torch.nn.functional.log_softmax(batch_logits, dim=-1)
  # Log-probability assigned to the token actually present at each position.
  batch_token_logprobs = (
    torch.gather(batch_logprobs, -1, batch_input_ids.unsqueeze(-1)).squeeze(-1).numpy()
  )  # [batch_size, input_len]

  if batch_attention:
    batch_attention = tuple(
      layer_attn.cpu().to(dtype=torch.float32) for layer_attn in batch_attention
    )
  if batch_hidden_states:
    batch_hidden_states = tuple(
      layer_hidden_state.cpu().to(dtype=torch.float32)
      for layer_hidden_state in batch_hidden_states
    )

  # ---------------------------------------------------------------------------- #
  #                      Split batch into array of examples                      #
  # ---------------------------------------------------------------------------- #
  side = "left" if padding_side == "left" else "right"
  output_entries = []
  for i in range(batch_input_ids.shape[0]):
    input_ids = batch_input_ids[i].numpy()
    pad_end = find_padding_end(input_ids, pad_token_id, padding_side=side)
    # find_padding_end gives the first content index for left padding and the
    # last content index (inclusive) for right padding. The same slice is used
    # for every output below; previously the right-padded logits were cut one
    # position short of the tokens.
    keep = slice(pad_end, None) if side == "left" else slice(None, pad_end + 1)

    attentions = None
    if batch_attention is not None:
      attentions = tuple(
        layer_attn[i, :, keep, keep].numpy() for layer_attn in batch_attention
      )

    hidden_states = None
    if batch_hidden_states is not None:
      # (layer, batch_size, seq_len, hidden_size)
      hidden_states = tuple(
        layer_hidden_state[i, keep, :].numpy()
        for layer_hidden_state in batch_hidden_states
      )

    output_entries.append(
      ProcessedModelOutputs(
        token_ids=input_ids[keep],
        token_logprobs=batch_token_logprobs[i, keep],
        attentions=attentions,
        hidden_states=hidden_states,
        logits=batch_logits[i, keep, :].numpy(),
      )
    )

  return output_entries


def find_padding_end(
  token_ids: tp.Sequence[int],
  pad_token_id: int,
  padding_side: tp.Literal["left", "right"] = "left",
):
  """
  Finds the boundary between padding and content in a sequence of token IDs.

  For ``padding_side="left"`` this is the index of the first non-padding token,
  so ``token_ids[idx:]`` drops the leading padding. For ``padding_side="right"``
  it is the index of the last non-padding token, so ``token_ids[:idx + 1]`` drops
  the trailing padding. Any value other than "right" behaves like "left".

  Parameters:
    token_ids (Sequence[int]): The sequence of token IDs (list, numpy array, ...).
    pad_token_id (int): The ID of the padding token.
    padding_side (Literal["left", "right"], optional): Which side holds the padding.
      Defaults to "left".

  Returns:
    int: The first non-padding index ("left") or the last non-padding index
    ("right"). If every token is padding, returns ``len(token_ids)`` for "left"
    and ``-1`` for "right", so both slicing idioms above yield an empty sequence.

  Examples:
    >>> token_ids = [0, 0, 0, 1, 2, 3, 0, 0]
    >>> pad_token_id = 0
    >>> left_pad = find_padding_end(token_ids, pad_token_id, padding_side="left")
    >>> left_pad
    3
    >>> right_pad = find_padding_end(token_ids, pad_token_id, padding_side="right")
    >>> right_pad
    5
    >>> token_ids[left_pad:]
    [1, 2, 3, 0, 0]
    >>> token_ids[:right_pad + 1]
    [0, 0, 0, 1, 2, 3]
    >>> token_ids[left_pad:right_pad + 1]
    [1, 2, 3]
    >>> find_padding_end([0, 0], 0, padding_side="right")
    -1
  """
  if padding_side == "right":
    # Scan backwards by index instead of materialising a reversed copy.
    for idx in range(len(token_ids) - 1, -1, -1):
      if token_ids[idx] != pad_token_id:
        return idx
    # All padding: -1 makes ``token_ids[:idx + 1]`` empty rather than the whole row.
    return -1
  for idx, token_id in enumerate(token_ids):
    if token_id != pad_token_id:
      return idx
  return len(token_ids)


@dataclasses.dataclass
class ScoreOutputs:
  token_ids: npt.NDArray[np.int32]
  token_logprobs: npt.NDArray[np.float32]
  attentions: tp.Sequence[npt.NDArray[np.float32]] | None = None
  hidden_states: dict[int, npt.NDArray[np.float32]] | None = None
  logits: npt.NDArray[np.float32] | None = None


def extract_hidden_states(
  all_hidden_states: tp.Sequence[np.ndarray],
  hidden_states_to_return: tp.Sequence[int] | int | None = None,
):
  if isinstance(hidden_states_to_return, int):
    hidden_states_to_return = [hidden_states_to_return]
  if hidden_states_to_return is None:
    hidden_states_to_return = list(range(len(all_hidden_states)))
  hidden_states_to_return = [
    i if i >= 0 else len(all_hidden_states) + i for i in hidden_states_to_return
  ]

  filtered_hidden_states = []
  num_rows = all_hidden_states[0].shape[0]
  for row_idx in range(num_rows):
    hidden_states_for_row = {}
    for layer_idx in hidden_states_to_return:
      hidden_states_for_row[layer_idx] = all_hidden_states[layer_idx][row_idx]
    filtered_hidden_states.append(hidden_states_for_row)
  return filtered_hidden_states


def postprocess_batch_outputs2(
  batch_outputs, pad_token_id: int, padding_side: str = "left", to_numpy: bool = True
):
  """Split a ``score_batch2`` result into per-example dicts with padding removed.

  Padding is located per row in ``batch_outputs["input_ids"]`` by stripping
  ``pad_token_id`` from the padded side only. Pad-valued tokens inside the
  content are kept (e.g. when pad == eos). A row consisting entirely of
  padding yields empty outputs.

  Args:
    batch_outputs: Mapping as returned by ``score_batch2``. Every entry except
      "hidden_states"/"attentions" is indexed as ``v[row, positions]``.
      "hidden_states", if present, is a per-row mapping from layer index to an
      array indexed by position.
    pad_token_id: Padding token id.
    padding_side: "left" strips leading padding; any other value strips
      trailing padding.
    to_numpy: If True, convert every ``torch.Tensor`` in the result to numpy.

  Returns:
    A list with one dict per row of ``batch_outputs["input_ids"]``.

  Raises:
    NotImplementedError: If ``batch_outputs`` contains "attentions".
  """
  if "attentions" in batch_outputs:
    raise NotImplementedError("attentions not supported currently")

  def _content_slice(row_ids, pad_id, side):
    # Slice covering the row without its leading (left) / trailing (right) padding.
    ids = row_ids.tolist() if hasattr(row_ids, "tolist") else list(row_ids)
    start, stop = 0, len(ids)
    if side == "left":
      while start < stop and ids[start] == pad_id:
        start += 1
    else:
      while stop > start and ids[stop - 1] == pad_id:
        stop -= 1
    return slice(start, stop)

  special_cols = ("hidden_states", "attentions")
  all_input_ids = batch_outputs["input_ids"]
  outputs = []
  for i in range(len(all_input_ids)):
    keep = _content_slice(all_input_ids[i], pad_token_id, padding_side)
    output_row = {
      k: v[i, keep] for k, v in batch_outputs.items() if k not in special_cols
    }
    if "hidden_states" in batch_outputs:
      output_row["hidden_states"] = {
        layer: states[keep]
        for layer, states in batch_outputs["hidden_states"][i].items()
      }
    outputs.append(output_row)

  if to_numpy:
    outputs = tree.map_structure(
      lambda v: v.cpu().numpy() if isinstance(v, torch.Tensor) else v, outputs
    )
  return outputs


def score_batch2(
  model: PreTrainedModel,
  tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
  texts: list[str] | None = None,
  input_ids: list[np.ndarray] | torch.Tensor | None = None,
  attention_mask: torch.Tensor | None = None,
  output_attentions: bool = False,
  output_hidden_states: bool = False,
  hidden_states_to_return: int | list[int] | None = None,
):
  """Score a batch and return batched (still padded) CPU outputs.

  Inputs are taken from, in order of precedence: ``input_ids`` given as a
  tensor (used as-is, with ``attention_mask`` or an all-ones mask), then
  ``texts`` (tokenized with padding and truncation), then ``input_ids`` given
  as a list (padded with ``tokenizer.pad``; ``attention_mask`` is ignored).

  Args:
    model: Model to run.
    tokenizer: Tokenizer; must define a padding token.
    texts: Raw strings to tokenize.
    input_ids: Pre-tokenized ids, as a padded tensor or a list of sequences.
    attention_mask: Mask for a tensor ``input_ids``; moved to the model device.
    output_attentions: Keep per-layer attentions in the result.
    output_hidden_states: Keep hidden states in the result.
    hidden_states_to_return: Hidden-state index or indices to keep (negative
      values count from the end); None keeps all.

  Returns:
    dict with "logits" and "input_ids", plus "attentions" and/or
    "hidden_states" when requested; "hidden_states" is a list with one
    {layer_index: array} dict per row. Intended for ``postprocess_batch_outputs2``.

  Raises:
    ValueError: If no inputs are given or the tokenizer has no padding token.
  """
  if texts is None and input_ids is None:
    raise ValueError("Either texts or input_ids must be provided.")
  if tokenizer.pad_token is None or tokenizer.pad_token_id is None:
    raise ValueError(
      "The given tokenizer does not have a padding token. Please set the padding token before scoring."
    )
  from transformers.tokenization_utils import BatchEncoding

  if isinstance(input_ids, torch.Tensor):
    input_ids = input_ids.to(model.device)
    if attention_mask is None:
      attention_mask = torch.ones_like(input_ids, device=input_ids.device)
    else:
      # A caller-supplied mask may live on another device; score_forward_pass
      # concatenates it with tensors on input_ids' device.
      attention_mask = attention_mask.to(input_ids.device)
    batch_encoding = BatchEncoding(
      {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
      }
    )
  elif texts is not None:
    batch_encoding = tokenizer(
      texts, return_tensors="pt", padding=True, truncation=True
    ).to(model.device)
  else:
    # input_ids is a list of sequences
    assert input_ids is not None
    batch_encoding = tokenizer.pad(
      [{"input_ids": x} for x in input_ids],  # type: ignore
      return_tensors="pt",
      padding="longest",
      return_attention_mask=True,
    ).to(model.device)

  score_outputs = score_forward_pass(
    input_ids=batch_encoding.input_ids,
    attention_mask=batch_encoding.attention_mask,
    model=model,
    pad_token_id=tokenizer.pad_token_id,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
  )
  score_outputs = tree.map_structure(
    lambda v: v.cpu() if isinstance(v, torch.Tensor) else v, score_outputs
  )
  score_outputs["input_ids"] = batch_encoding.input_ids.cpu()
  if not output_attentions and "attentions" in score_outputs:
    score_outputs.pop("attentions")
  if not output_hidden_states and "hidden_states" in score_outputs:
    score_outputs.pop("hidden_states")

  # Hidden states may be requested yet absent (None) if the model returned none.
  if score_outputs.get("hidden_states") is not None:
    score_outputs["hidden_states"] = extract_hidden_states(
      score_outputs["hidden_states"], hidden_states_to_return
    )

  return tp.cast(dict, score_outputs)


def process_batch_outputs(
  batch: dict,
  padding_side="left",
  pad_token_id: int = 0,
):
  """Split a batch of stored scoring outputs into per-example lists without padding.

  One position slice is computed per row and applied to token ids, token
  log-probabilities, logits, attentions and hidden states alike, so all
  per-position outputs of an example stay aligned.

  Args:
    batch: Mapping with "input_ids" and "logits" (anything ``torch.tensor``
      accepts). It may also contain "attentions", an iterable of per-layer
      arrays indexed ``[row, head, query, key]``, and "hidden_states", one
      mapping per row from layer index to an array indexed by position.
    padding_side: "left" strips leading padding; any other value strips
      trailing padding.
    pad_token_id: Padding token id.

  Returns:
    collections.defaultdict(list) with "token_ids", "token_logprobs",
    "attentions" and "logits" (plus "hidden_states" when present), each holding
    one entry per row.
  """
  # .cpu() so the gather below and the .numpy() calls also work for GPU inputs.
  batch_input_ids = torch.tensor(batch["input_ids"]).cpu().to(dtype=torch.long)
  batch_logits = torch.tensor(batch["logits"]).cpu().to(dtype=torch.float32)
  batch_attention = batch.get("attentions", None)
  batch_hidden_states = batch.get("hidden_states", None)

  batch_logprobs = torch.nn.functional.log_softmax(batch_logits, dim=-1)
  # Log-probability assigned to the token actually present at each position.
  batch_token_logprobs = (
    torch.gather(batch_logprobs, -1, batch_input_ids.unsqueeze(-1)).squeeze(-1).numpy()
  )  # [actual_batch_size, input_len]

  # `is not None` rather than truthiness: a stacked numpy array has no truth value.
  if batch_attention is not None:
    batch_attention = tuple(
      torch.tensor(layer_attn).to(dtype=torch.float32) for layer_attn in batch_attention
    )

  # ---------------------------------------------------------------------------- #
  #                      Split batch into array of examples                      #
  # ---------------------------------------------------------------------------- #
  output_entries = collections.defaultdict(list)
  for i in range(batch_input_ids.shape[0]):
    input_ids = batch_input_ids[i].numpy()
    if padding_side == "left":
      pad_end = find_padding_end(input_ids, pad_token_id, padding_side="left")
      keep = slice(pad_end, None)
    else:
      pad_end = find_padding_end(input_ids, pad_token_id, padding_side="right")
      # pad_end is the last content index, so the stop is inclusive. Logits
      # were previously cut one position short here.
      keep = slice(None, pad_end + 1)

    attentions = None
    if batch_attention is not None:
      attentions = tuple(
        layer_attn[i, :, keep, keep].numpy() for layer_attn in batch_attention
      )

    if batch_hidden_states is not None:
      output_entries["hidden_states"].append(
        {
          layer_idx: layer_hidden_state[keep, :]
          for layer_idx, layer_hidden_state in batch_hidden_states[i].items()
        }
      )

    output_entries["token_ids"].append(input_ids[keep])
    output_entries["token_logprobs"].append(batch_token_logprobs[i, keep])
    output_entries["attentions"].append(attentions)
    output_entries["logits"].append(batch_logits[i, keep, :].numpy())

  return output_entries


def score_batch(
  model: PreTrainedModel,
  tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
  texts: list[str] | None = None,
  input_ids: list[np.ndarray] | None = None,
  output_attentions: bool = False,
  output_hidden_states: bool = False,
  hidden_states_to_return: int | list[int] | None = None,
) -> list[ScoreOutputs]:
  """Score a batch and return one padding-free ``ScoreOutputs`` per example.

  Args:
    model: Model to run.
    tokenizer: Tokenizer used to encode ``texts`` or pad ``input_ids``. It must
      define a padding token. Its ``padding_side`` decides which side of each
      row is stripped during postprocessing.
    texts: Raw strings to tokenize (padded and truncated). Takes precedence
      over ``input_ids`` when both are given.
    input_ids: Pre-tokenized sequences, padded to the longest one.
    output_attentions: Request per-layer attentions from the model.
    output_hidden_states: Request hidden states from the model.
    hidden_states_to_return: Hidden-state index or indices to keep (negative
      values count from the end); None keeps all of them.

  Returns:
    list[ScoreOutputs]: One entry per example; ``hidden_states`` is a dict from
    non-negative layer index to array, or None if not returned.

  Raises:
    ValueError: If neither ``texts`` nor ``input_ids`` is given, or the
      tokenizer has no padding token.
  """
  if texts is None and input_ids is None:
    raise ValueError("Either texts or input_ids must be provided.")
  if tokenizer.pad_token is None or tokenizer.pad_token_id is None:
    raise ValueError(
      "The given tokenizer does not have a padding token. Please set the padding token before scoring."
    )
  if texts is not None:
    batch_encoding = tokenizer(
      texts, return_tensors="pt", padding=True, truncation=True
    ).to(model.device)
  else:
    # input_ids is not None
    assert input_ids is not None
    batch_encoding = tokenizer.pad(
      [{"input_ids": x} for x in input_ids],  # type: ignore
      return_tensors="pt",
      padding="longest",
      return_attention_mask=True,
    ).to(model.device)

  score_outputs = score_forward_pass(
    input_ids=batch_encoding.input_ids,
    attention_mask=batch_encoding.attention_mask,
    model=model,
    pad_token_id=tokenizer.pad_token_id,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
  )

  processed_score_outputs = postprocess_score_outputs(
    batch_input_ids=batch_encoding.input_ids,
    batch_logits=score_outputs["logits"],
    pad_token_id=tokenizer.pad_token_id,
    # Strip padding from the side the tokenizer actually padded on; a
    # hard-coded "left" left pad tokens in right-padded batches.
    padding_side=getattr(tokenizer, "padding_side", "left"),
    batch_attention=score_outputs.get("attentions", None),
    batch_hidden_states=score_outputs.get("hidden_states", None),
  )

  # Normalise the selection once instead of rebinding the argument per row.
  if hidden_states_to_return is None:
    requested_layers = None
  elif isinstance(hidden_states_to_return, int):
    requested_layers = [hidden_states_to_return]
  else:
    requested_layers = list(hidden_states_to_return)

  output_entries = []
  for processed in processed_score_outputs:
    hidden_states = None
    if processed.hidden_states is not None:
      hidden_states = dict(enumerate(processed.hidden_states))
      if requested_layers is not None:
        num_layers = len(hidden_states)
        # Negative indices count from the end (e.g. -1 = last entry).
        keep_layers = {
          idx if idx >= 0 else num_layers + idx for idx in requested_layers
        }
        hidden_states = {
          idx: state for idx, state in hidden_states.items() if idx in keep_layers
        }
    output_entries.append(
      ScoreOutputs(
        token_ids=processed.token_ids,
        token_logprobs=processed.token_logprobs,
        attentions=processed.attentions,
        hidden_states=hidden_states,
        logits=processed.logits,
      )
    )
  return output_entries


# ---------------------------------------------------------------------------- #
#                utilities for loading tensors from numpy files                #
# ---------------------------------------------------------------------------- #
def load_hidden_states(filename):
  """Load per-layer hidden states from an ``.npz`` archive.

  Args:
    filename: Path or file-like object accepted by ``np.load``. The archive's
      array names must be integer strings (they become the dict keys).

  Returns:
    dict mapping ``int(array_name)`` to the loaded array.
  """
  # For .npz input np.load returns an NpzFile that keeps the file open; the
  # context manager closes it once every array has been read into memory.
  with np.load(filename) as archive:
    return {int(name): archive[name] for name in archive.files}


def load_hidden_states_single(filename):
  """Load the single hidden-state array stored in an ``.npz`` archive.

  Args:
    filename: Path or file-like object readable by ``load_hidden_states``.

  Returns:
    The only array in the archive.

  Raises:
    ValueError: If the archive holds no arrays or more than one.
  """
  hidden_states = load_hidden_states(filename)

  if not hidden_states:
    # Previously surfaced as a bare StopIteration from next(iter({})).
    raise ValueError("No hidden states found in file")
  if len(hidden_states) > 1:
    raise ValueError("More than one hidden state found in file")

  (only_hidden_state,) = hidden_states.values()
  return only_hidden_state
