from typing import Literal, Sequence
import torch
from torch.optim.lr_scheduler import LambdaLR
import numpy as np

# Closed set of single-letter labels. Their meaning is not defined here.
# NOTE(review): presumably an O/B/S/I (outside/begin/single/inside) style
# tagging scheme -- confirm against the code that produces these labels.
AtomicGroup = Literal["O", "B", "S", "I"]
# Type alias for ``torch.Tensor``. Despite the NumPy-style name it is NOT
# ``numpy.ndarray``; it is only used for annotations (e.g. ``to_dtype``).
NDArray = torch.Tensor


class InverseStepScheduler(LambdaLR):
    """LambdaLR whose multiplier decays as ``1 / k`` after an optional hold.

    The multiplier applied to each base learning rate at scheduler step
    ``step`` is ``1 / max(step - offset + 1, 1)``, with
    ``offset = max(transition_begin, 0)``. As a result:

    * ``transition_begin <= 0``: the factor is ``1 / (step + 1)`` from the start.
    * ``transition_begin > 0``: the factor stays at ``1.0`` for every step up to
      and including ``transition_begin``, then continues as
      ``1/2, 1/3, ...``.

    ``self.transition_begin`` is read each time the multiplier is evaluated,
    so reassigning the attribute later changes the schedule.

    Args:
        optimizer: Wrapped optimizer.
        transition_begin: Step at which the ``1/k`` decay starts counting.
        last_epoch: Index of the last step, forwarded to ``LambdaLR``.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        transition_begin: int = 0,
        last_epoch: int = -1,
    ):
        # Assigned before ``super().__init__`` because the base scheduler
        # constructor already evaluates the multiplier (initial step).
        self.transition_begin = transition_begin

        def lr_lambda(step: int) -> float:
            # Negative offsets are treated exactly like 0.
            offset = max(self.transition_begin, 0)
            # Before ``offset`` this count is <= 0; the clamp below turns it
            # into 1, which holds the LR at its base value.
            count = float(step - offset + 1)
            return 1.0 / max(count, 1.0)

        super().__init__(optimizer, lr_lambda, last_epoch)


def to_numpy(value):
    """Convert ``value`` into a NumPy array, or a Python scalar for 0-d tensors.

    Args:
        value: A ``torch.Tensor`` or anything ``np.array`` accepts.

    Returns:
        * 0-d tensor: the plain Python number from ``Tensor.item()``.
        * Any other tensor: ``value`` detached from autograd, moved to CPU and
          viewed as a NumPy array. If ``value`` already lives on CPU, the
          array shares memory with it.
        * Non-tensor input: ``np.array(value)``, a new array.
    """
    if isinstance(value, torch.Tensor):
        tensor = value.detach().cpu()
        if tensor.ndim == 0:
            # Unwrap scalars. The original computed ``.item()`` but discarded
            # the result, so 0-d tensors fell through to ``.numpy()``.
            return tensor.item()
        return tensor.numpy()
    return np.array(value)


def flatten(value: object):
    """Recursively yield the leaves of arbitrarily nested lists/tuples.

    Lists and tuples are descended into in order. Every other object,
    including strings and dicts, is treated as a single leaf. A non-list/tuple
    ``value`` at the top level yields just itself.

    Args:
        value: A (possibly nested) list/tuple, or a single leaf.

    Yields:
        The leaves in depth-first, left-to-right order.
    """
    if not isinstance(value, (list, tuple)):
        # Leaf: yield it and stop. Without this return the generator would go
        # on to iterate the leaf itself. That raises TypeError for ints and
        # recurses without bound for strings, whose chars are strings again.
        yield value
        return
    for item in value:
        yield from flatten(item)


def to_dtype(values: Sequence[NDArray], dtype: torch.dtype | None = None):
    """Cast every tensor in ``values`` to ``dtype``.

    Args:
        values: Sequence of tensors to convert.
        dtype: Target dtype, or ``None`` to skip the conversion.

    Returns:
        When ``dtype`` is ``None``: ``values`` itself, the same object, neither
        copied nor converted to a tuple. Otherwise: a new tuple holding
        ``t.to(dtype)`` for each ``t`` in ``values``, in order.
    """
    if dtype is not None:
        return tuple(tensor.to(dtype) for tensor in values)
    return values
