from __future__ import annotations

from collections.abc import MutableMapping
from typing import List, Mapping, Sequence, TypeVar, Union

import torch

from numpy import shape

from tensordict import NestedKey, tensorclass, TensorDict, TensorDictBase


class MCTSNode(TensorDict):
    """A tree-search node stored as a scalar (``batch_size=[]``) ``TensorDict``.

    Each node holds the per-child statistics of *its own children* in the
    ``children_values``, ``children_priors``, ``children_visits`` and
    ``children_rewards`` entries, plus a ``children`` sub-tensordict keyed by
    the stringified action that leads to each child.  A node's own
    ``visits`` / ``value`` / ``reward`` are therefore not stored on the node:
    they are entry ``prior_action`` of the corresponding tensor on ``parent``.

    Attributes set on the instance (outside the tensordict entries):
        prior_action: the action that leads from ``parent`` to this node;
            used as the index into the parent's ``children_*`` tensors.
        parent: the node this one hangs from, or ``None`` for a root.
    """

    def __init__(
        self, action: int | torch.Tensor, parent: MCTSNode | None, device=None
    ):
        """Create a node with empty statistics.

        Args:
            action: action leading from ``parent`` to this node.
            parent: owning node, or ``None`` when building a root.
            device: forwarded to the underlying ``TensorDict`` constructors.
        """
        super().__init__(
            {
                "children_values": torch.tensor([]),
                "children_priors": torch.tensor([]),
                "children_visits": torch.tensor([]),
                "children_rewards": torch.tensor([]),
                "score": torch.tensor([]),
                "children": TensorDict({}, batch_size=[], device=device),
                "state": TensorDict({}, batch_size=[], device=device),
                "terminated": torch.tensor([False]),
            },  # type: ignore
            batch_size=[],
            device=device,
        )
        self.prior_action: int | torch.Tensor = action
        self.parent: MCTSNode | None = parent

    def _parent_stats(self, key: str) -> torch.Tensor:
        """Return the parent's per-child tensor stored under ``key``.

        Raises:
            AssertionError: if this node has no parent (i.e. it is a root).
        """
        parent = self.parent
        # Identity check on purpose: ``!=`` would dispatch to the tensordict's
        # own comparison operator instead of testing for ``None``.
        assert parent is not None, f"a node without parent has no {key!r} entry"
        return parent[key]

    @property
    def visits(self) -> torch.Tensor:
        """This node's entry in the parent's ``children_visits`` tensor."""
        return self._parent_stats("children_visits")[self.prior_action]

    @visits.setter
    def visits(self, x) -> None:
        # Indexed assignment writes in place into the tensor held by the parent.
        self._parent_stats("children_visits")[self.prior_action] = x

    @property
    def value(self) -> torch.Tensor:
        """This node's entry in the parent's ``children_values`` tensor."""
        return self._parent_stats("children_values")[self.prior_action]

    @value.setter
    def value(self, x) -> None:
        self._parent_stats("children_values")[self.prior_action] = x

    @property
    def reward(self) -> torch.Tensor:
        """This node's entry in the parent's ``children_rewards`` tensor."""
        return self._parent_stats("children_rewards")[self.prior_action]

    # Must be ``reward.setter``: decorating with ``value.setter`` rebinds
    # ``reward`` to a property that reuses the *value* getter, so reading
    # ``node.reward`` would return the value instead of the reward.
    @reward.setter
    def reward(self, x) -> None:
        self._parent_stats("children_rewards")[self.prior_action] = x

    @property
    def expanded(self) -> bool:
        """Whether ``children_priors`` holds at least one element."""
        return self["children_priors"].numel() > 0

    def get_child(self, action: int | torch.Tensor) -> MCTSNode:
        """Return the child reached through ``action``, creating it if absent.

        Args:
            action: action identifying the child; it is converted to an
                integer and its string form is the key under ``children``.

        Returns:
            The entry stored under ``children`` for that action.
        """
        action_str = str(torch.sym_int(action))
        children = self["children"]
        # Children are nested tensordicts rather than leaf tensors, so the
        # membership test must run against all keys, not a leaves-only view;
        # otherwise an existing child may be missed and silently replaced.
        if action_str not in children.keys():
            children[action_str] = MCTSNode(action, self, self.device)
        return children[action_str]

    @classmethod
    def root(cls, device=None) -> MCTSNode:
        """Build a parentless node; its ``prior_action`` is the placeholder -1."""
        return cls(-1, None, device)