import dataclasses
from typing import Dict, Optional

import torch


def cat_if_not_none(a, b):
    """Concatenate two tensors along dim 0, propagating ``None``.

    Returns ``torch.cat((a, b), dim=0)`` when both inputs are present.
    If either input is ``None`` the result is ``None``; note that the
    non-``None`` side is discarded rather than returned on its own.
    """
    both_present = a is not None and b is not None
    if not both_present:
        return None
    return torch.cat((a, b), dim=0)


@dataclasses.dataclass
class State:
    """Bundle of tensors that are concatenated and row-indexed together.

    ``cat`` joins two States along dim 0 and ``subset`` applies the same
    index to every field. ``token_grads`` and the values of ``extra`` may be
    ``None``; ``cat`` and ``subset`` keep such entries as ``None`` instead of
    failing. The remaining fields are always treated as tensors.
    """

    # NOTE(review): all tensors are assumed to share the same leading (row)
    # dimension so that cat/subset keep them aligned -- confirm with callers.
    ids: torch.Tensor
    target: torch.Tensor
    xentropy: torch.Tensor
    final_token: torch.Tensor
    token_grads: Optional[torch.Tensor]
    extra: Dict[str, Optional[torch.Tensor]]

    def cat(self, state2):
        """Return a new State with ``state2``'s rows appended after this one's.

        Every field is concatenated along dim 0. For ``token_grads`` and each
        ``extra`` entry the result is ``None`` if either side is ``None``.
        The keys of ``extra`` are taken from ``self``: a key missing from
        ``state2.extra`` raises ``KeyError``, and keys present only in
        ``state2.extra`` are dropped.
        """
        return State(
            ids=torch.cat((self.ids, state2.ids), dim=0),
            target=torch.cat((self.target, state2.target), dim=0),
            xentropy=torch.cat((self.xentropy, state2.xentropy), dim=0),
            final_token=torch.cat((self.final_token, state2.final_token), dim=0),
            token_grads=cat_if_not_none(self.token_grads, state2.token_grads),
            extra={
                k: cat_if_not_none(self.extra[k], state2.extra[k]) for k in self.extra
            },
        )

    def subset(self, keep):
        """Return a new State holding only the rows selected by ``keep``.

        ``keep`` is applied with ordinary tensor indexing to every field
        (presumably a boolean mask or an index tensor -- confirm with
        callers). ``None`` values in ``token_grads`` and ``extra`` stay
        ``None``. ``cat`` can produce such values, so ``subset`` must
        accept them.
        """
        return State(
            ids=self.ids[keep],
            target=self.target[keep],
            xentropy=self.xentropy[keep],
            final_token=self.final_token[keep],
            token_grads=(
                None if self.token_grads is None else self.token_grads[keep]
            ),
            # Previously `self.extra[k][keep]`, which raised TypeError on None.
            extra={
                k: None if v is None else v[keep] for k, v in self.extra.items()
            },
        )

    def update_from(self, other_state):
        """
        Update this state in-place with values from another state.

        Attributes are rebound, not copied: afterwards ``self`` refers to the
        same tensor objects and the same ``extra`` dict as ``other_state``.
        Returns self for method chaining.
        """
        self.ids = other_state.ids
        self.target = other_state.target
        self.xentropy = other_state.xentropy
        self.final_token = other_state.final_token
        self.token_grads = other_state.token_grads
        self.extra = other_state.extra
        return self
