from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
from transformers import PreTrainedModel


class IRunner:
    """
    Base interface for objects that run a model on embedded inputs.

    ``run_with_embeddings``, ``one_hot_to_embed`` and ``int_ids_to_embed``
    raise ``NotImplementedError`` here and are meant to be overridden.
    ``setup`` is an optional hook that does nothing by default, and
    ``calc_xentropy`` is a ready-to-use loss helper.
    """

    def setup(self, input_ids: Tensor) -> None:
        """
        Optional setup method called at the beginning of EPO.

        The base implementation does nothing.

        Args:
            input_ids: Initial token IDs used for optimization
        """

    def run_with_embeddings(
        self,
        input_embeddings: Tensor,  # shape: [batch_size, seq_length, embedding_dim]
    ) -> Tuple[
        Tensor,  # shape: [batch_size]
        Tensor,  # shape: [batch_size, seq_length, vocab_size]
        Dict[str, Any],
    ]:
        """
        Run the model using the provided embedded input.

        Args:
            input_embeddings: Tensor of embedded input

        Returns:
            A tuple containing:
            - target: Tensor representing the target values (batch_size,)
            - logits: Tensor representing the model logits
            - optional_outputs: Dictionary containing any additional optional outputs from the model

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def one_hot_to_embed(
        self,
        one_hot: Tensor,  # shape: [batch_size, seq_length, vocab_size]
    ) -> Tensor:  # shape: [batch_size, seq_length, embedding_dim]
        """
        Convert one-hot encoded input to embedded input.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def int_ids_to_embed(
        self,
        int_ids: Tensor,  # shape: [batch_size, seq_length]
    ) -> Tensor:  # shape: [batch_size, seq_length, embedding_dim]
        """
        Convert integer token IDs to embedded input.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def calc_xentropy(
        self,
        logits: Tensor,  # shape: [batch_size, seq_length, vocab_size]
        target_ids: Tensor,  # shape: [batch_size, seq_length]
        nonfixed_positions: Optional[Union[Tensor, Sequence[int]]] = None,
    ) -> Tensor:  # shape: [batch_size]
        """
        Compute the per-sequence mean next-token cross-entropy.

        The logits at sequence position ``i`` are scored against the target
        token at position ``i + 1``, so the last logit and the first target
        token are never used.

        Args:
            logits: Model logits.
            target_ids: Integer token IDs the logits are scored against.
            nonfixed_positions: Sequence positions (indices along the
                sequence axis of ``target_ids``) whose tokens should be
                scored, given as a tensor or a sequence of ints. Position 0
                is dropped because no logit predicts the first token. When
                ``None``, every position from 1 onwards is scored.

        Returns:
            One loss value per batch element: the mean cross-entropy over the
            scored positions. If no position is left to score, the mean is
            taken over zero elements, which PyTorch reports as NaN.
        """
        # Shift by one: logits[:, i] predicts target_ids[:, i + 1].
        pred_logits = logits[:, :-1]
        next_ids = target_ids[:, 1:]

        if nonfixed_positions is not None:
            # Normalise to a tensor first. On a plain Python list the
            # expression ``positions != 0`` is just ``True``, which would
            # silently select a single element instead of filtering.
            positions = torch.as_tensor(nonfixed_positions, device=logits.device)
            # Drop position 0 (nothing predicts it), then move into the
            # shifted coordinate system used by pred_logits / next_ids.
            shifted = positions[positions != 0] - 1
            pred_logits = pred_logits[:, shifted]
            next_ids = next_ids[:, shifted]

        batch_size, num_scored, vocab_size = pred_logits.shape
        per_token = torch.nn.functional.cross_entropy(
            pred_logits.reshape(-1, vocab_size),
            next_ids.reshape(-1),
            reduction="none",
        )
        return per_token.view(batch_size, num_scored).mean(dim=-1)


def transfomer_embed_one_hot_input(
    one_hot: Tensor,  # shape: [batch_size, seq_length, vocab_size]
    model: PreTrainedModel,
) -> Tensor:  # shape: [batch_size, seq_length, embedding_dim]
    """
    Embed one-hot encoded tokens with a model's input embedding matrix.

    The result is the matrix product of ``one_hot`` with the weight of
    ``model.get_input_embeddings()``; for an exact one-hot row this picks out
    the embedding row of the corresponding token.

    Args:
        one_hot: One-hot encoded tokens. It is cast to the embedding weight's
            dtype before multiplying, so integer one-hot tensors and
            precision mismatches (e.g. float32 input against a reduced- or
            double-precision model) are accepted.
        model: Model whose input embedding weight is used.

    Returns:
        The embedded input, in the embedding weight's dtype.
    """
    embed_weight = model.get_input_embeddings().weight
    # torch.matmul rejects operands of different dtypes; the cast is a no-op
    # when the dtypes already agree.
    return torch.matmul(one_hot.to(embed_weight.dtype), embed_weight)
