"""Utils for model executor."""
import random
from typing import Any, Dict, Optional

import numpy as np
import torch

from dataclasses import dataclass
from torch import Tensor
@dataclass
class HiddenStatesWithEmbedding:
    """Container pairing a ``last_hidden_states`` tensor with an ``embedding`` tensor.

    Fields are positional in the order declared below.
    """

    # NOTE(review): shapes/dtypes are not constrained here — confirm with producers.
    last_hidden_states: torch.Tensor
    embedding: torch.Tensor

@dataclass
class LogitsWithEmbedding:
    """Container pairing a ``logits`` tensor with an ``embedding`` tensor.

    Fields are positional in the order declared below.
    """

    # NOTE(review): shapes/dtypes are not constrained here — confirm with producers.
    logits: torch.Tensor
    embedding: torch.Tensor

def set_random_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    When CUDA is available, every CUDA device's generator is seeded too.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Only touch the CUDA generators when a CUDA runtime is actually present.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: Optional[Dict[str, Any]],
) -> None:
    """Set attributes on a weight tensor.

    Each key/value pair in ``weight_attrs`` is attached to ``weight`` with
    ``setattr``. Existing attributes are never overwritten.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
            ``None`` (or an empty dict) means there is nothing to set.

    Raises:
        AssertionError: If ``weight`` already has an attribute named by one of
            the keys. Keys iterated before the offending one have already been
            set when this is raised.
    """
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        # Explicit raise rather than ``assert``: ``python -O`` strips assert
        # statements, which would silently break the no-overwrite guarantee.
        # AssertionError is kept so existing callers see the same exception.
        if hasattr(weight, key):
            raise AssertionError(
                f"Overwriting existing tensor attribute: {key}")
        setattr(weight, key, value)
