import copy
from abc import abstractmethod
from dataclasses import dataclass
from typing import Iterable, List, Optional

import torch
from tensordict import NestedKey, TensorDict, TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential

from tensordict.nn.common import TensorDictModuleBase
from torch.distributions.dirichlet import _Dirichlet
from torchrl.envs import EnvBase
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
from torchrl.objectives.value import TDLambdaEstimator, ValueEstimatorBase
from torchrl.objectives.value.functional import reward2go

from .mcts_node import MCTSNode


@dataclass
class AlphaZeroConfig:
    """Hyper-parameters of the AlphaZero-style tree search.

    Field meanings below are presumably how the values are forwarded to
    ``SimulatedSearchPolicy`` / ``PuctSelectionPolicy`` / ``DirichletNoiseModule``
    by the caller — TODO confirm against the code that consumes this config.
    """

    # Number of simulated rollouts per real step (presumably SimulatedSearchPolicy.num_simulations).
    num_simulations: int = 100
    # Max steps of each simulated rollout; the original note suggests 30 for chess.
    simulation_max_steps: int = 20
    # Max number of real steps (presumably SimulatedSearchPolicy.max_steps).
    max_steps: int = 20
    # Weight of the prior-driven exploration bonus (presumably PuctSelectionPolicy.cpuct).
    c_puct: float = 1.0
    # Weight of the mean value term (presumably PuctSelectionPolicy.bpuct).
    b_puct: float = 0.1
    # Dirichlet concentration for root noise. Optional[...] rather than `float | None`
    # so the class can be defined on Python < 3.10 (dataclass annotations are
    # evaluated at class-definition time).
    dirichlet_alpha: Optional[float] = 0.03
    # Whether to inject exploration noise at the root.
    inject_noise: bool = True
    # Whether to bootstrap leaf values with a value network.
    use_value_network: bool = False
    # Whether to reuse the subtree of the chosen action as the next root.
    reutilize_tree: bool = False


class UcbSelectionPolicy(TensorDictModuleBase):
    """
    A policy to score every action of a tree node using UCB estimation.
    See Section 2.7 "Upper-Confidence-Bound Action Selection" in
    Sutton, Richard S., and Andrew G. Barto. 2018. "Reinforcement Learning: An Introduction (Second Edition)."
    http://incompleteideas.net/book/RLbook2020.pdf

    For every action ``a`` the estimate is ``Q(s, a) + cucb * sqrt(ln(N) / N(s, a))``,
    where ``N`` is the sum of all action counts of the node. Actions that were never
    selected (``N(s, a) == 0``) receive ``+inf`` so that, as in the reference, they are
    treated as maximizing actions and get explored first.

    Args:
        cucb: The exploration constant ``c``. Defaults to ``2.0``.
        action_value_under_uncertainty_key: The output key receiving the estimated action value.
            Defaults to ``action_value_under_uncertainty``.
        action_value_key: The input key representing the mean of Q(s, a) for every action `a` at state `s`.
            Defaults to ``action_value``.
        action_count_key: The input key representing the number of times action `a` is selected at state `s`.
            Defaults to ``action_count``.
    """

    def __init__(
        self,
        cucb: float = 2.0,
        action_value_under_uncertainty_key: NestedKey = "action_value_under_uncertainty",
        action_value_key: NestedKey = "action_value",
        action_count_key: NestedKey = "action_count",
    ):
        self.in_keys = [action_value_key, action_count_key]  # type: ignore
        self.out_keys = [action_value_under_uncertainty_key]  # type: ignore
        super().__init__()
        self.cucb = cucb
        self.action_value_key = action_value_key
        self.action_count_key = action_count_key
        self.action_value_under_uncertainty_key = action_value_under_uncertainty_key

    def forward(self, node: TensorDictBase) -> TensorDictBase:
        """Return a shallow copy of ``node`` with the UCB estimate written to the output key."""
        node = node.clone(False)
        x_hat = node[self.action_value_key]
        n_sa = node[self.action_count_key]
        visited = n_sa != 0
        n = torch.sum(n_sa)
        # Unvisited actions are maximizing by definition (S&B Sec. 2.7).
        # Assumes x_hat is a floating-point tensor — TODO confirm for all callers.
        optimism_estimation = torch.full_like(x_hat, float("inf"))
        optimism_estimation[visited] = x_hat[visited] + self.cucb * torch.sqrt(
            torch.log(n) / n_sa[visited]
        )
        node[self.action_value_under_uncertainty_key] = optimism_estimation
        return node  # type: ignore


class ActionExplorationModule(TensorDictModuleBase):
    """Pick the next action from the statistics stored in a tree node.

    * Under ``ExplorationType.RANDOM`` (or when no exploration type is set), the action
      with the highest ``node["score"]`` is chosen, breaking ties uniformly at random.
    * Under ``ExplorationType.MODE``, the greedy action is chosen: among the children
      with the highest ``children_rewards``, the one with the highest ``children_values``.

    Args:
        action_key: key under which the selected action is written into the returned
            tensordict. Defaults to ``"action"``.
    """

    def __init__(
        self,
        action_key: str = "action",
    ):
        super().__init__()
        self.action_key = action_key
        # Number of candidate actions returned by ``batch_explore_action``.
        self.top_k = 16

    def forward(self, node: MCTSNode) -> TensorDictBase:
        """Return a shallow copy of ``node["state"]`` with the selected action set."""
        tensordict = node["state"].clone(False)

        current_type = exploration_type()
        if current_type is None or current_type == ExplorationType.RANDOM:
            tensordict[self.action_key] = self.explore_action(node)
        elif current_type == ExplorationType.MODE:
            tensordict[self.action_key] = self.get_greedy_action(node)
        # NOTE(review): any other exploration type leaves the action unset.

        return tensordict

    def get_greedy_action(self, node: MCTSNode) -> torch.Tensor:
        """Index of the highest-value child among the children with maximal reward."""
        rewards = node["children_rewards"]
        values = node["children_values"]
        is_best_reward = rewards == torch.max(rewards)
        # Mask with -inf instead of multiplying by the boolean mask. The product
        # sets non-best children to 0, so whenever the best children's values
        # are <= 0 (e.g. all zeros when no value network is used) argmax would
        # return a child whose reward is not maximal (typically index 0).
        masked_values = torch.where(
            is_best_reward, values, torch.full_like(values, float("-inf"))
        )
        return torch.argmax(masked_values)

    def explore_action(self, node: MCTSNode) -> torch.Tensor:
        """Index of a maximal ``node["score"]`` entry, ties broken uniformly at random."""
        action_score = node["score"]
        is_best = action_score == torch.max(action_score)
        # torch.rand_like draws from [0, 1); non-maximal entries get -1 so they can
        # never win, even if a random draw happens to be exactly 0.
        tie_break = torch.where(
            is_best, torch.rand_like(action_score), torch.full_like(action_score, -1.0)
        )
        return torch.argmax(tie_break)

    def batch_explore_action(self, node: MCTSNode) -> torch.Tensor:
        """The ``top_k`` highest-scoring actions along the last dim, with a trailing unit dim."""
        action_score = node["score"]

        _, actions = torch.topk(action_score, self.top_k, dim=-1)

        return actions.unsqueeze(-1)

    def set_node(self, node: MCTSNode) -> None:
        self.node = node


class BatchedActionExplorationModule(ActionExplorationModule):
    """Propose the ``batch_size`` highest-scoring actions of a node at once.

    The node's state is expanded to ``batch_size`` copies, and copy ``i`` is paired
    with the ``i``-th best action according to ``node["score"]`` (``topk`` order).

    Args:
        action_key: key under which the actions are written. Defaults to ``"action"``.
        batch_size: number of actions / state copies to produce. Defaults to ``1``.
    """

    def __init__(
        self,
        action_key: str = "action",
        batch_size: int = 1,
    ):
        super().__init__(action_key=action_key)
        self.batch_size = batch_size

    def forward(self, node: MCTSNode) -> TensorDictBase:
        best_actions = node["score"].topk(self.batch_size, dim=-1).indices
        expanded_state = node["state"].clone(False).expand(self.batch_size)
        expanded_state[self.action_key] = best_actions.unsqueeze(-1)
        return expanded_state

    def set_node(self, node: MCTSNode) -> None:
        self.node = node


class UpdateTreeStrategy:
    """
    The strategy to update the tree after each simulated rollout.

    For every rollout, two discounted targets (computed with ``reward2go``) are
    back-propagated along the path of children selected by the rollout's actions:

    * a *reward* target, seeded with the last step's ``("next", "reward")`` when the
      last step's ``("next", "done")`` is set, and zero otherwise. Intermediate
      rewards are not used.
    * a *value* target, seeded with ``value_network``'s ``"state_value"`` for the
      last ``"next"`` state when ``use_value_network`` is ``True``, and zero otherwise.

    Each visited child's ``value`` and ``reward`` are updated as running means of
    these targets, and its ``visits`` counter is incremented.

    Args:
        value_network: module evaluated on the last ``"next"`` state of a rollout;
            its output must contain ``"state_value"``.
        action_key: key in the rollout tensordict holding the selected actions.
            Defaults to ``"action"``.
        use_value_network: whether to seed the value target with ``value_network``.
            Defaults to ``True``.
        gamma: discount factor passed to ``reward2go``. Defaults to ``0.95``.
    """

    # noinspection PyTypeChecker
    def __init__(
        self,
        value_network: TensorDictModuleBase,
        action_key: NestedKey = "action",
        use_value_network: bool = True,
        gamma: float = 0.95,
    ):
        self.action_key = action_key
        self.value_network = value_network
        self.root: MCTSNode
        self.use_value_network = use_value_network
        self.gamma = gamma

    def update(self, rollout: TensorDictBase) -> None:
        """Back-propagate the targets of ``rollout`` into the nodes it visited.

        Args:
            rollout: a rollout that starts at ``self.root``. It is presumably 1-D of
                length ``T`` (the code indexes its first dim and reads ``batch_size[-1]``);
                ``rollout[action_key][t]`` is the action taken at step ``t``.
        """
        num_steps = rollout.batch_size[-1]
        # T + 1 slots: slot t < T for each step, slot T holds the leaf seed.
        target_value = torch.zeros(num_steps + 1, dtype=torch.float32)
        target_reward = torch.zeros(num_steps + 1, dtype=torch.float32)

        # A single trajectory: only the trailing slot terminates it.
        done = torch.zeros_like(target_value, dtype=torch.bool)
        done[-1] = True

        last_step = rollout[-1]
        if last_step[("next", "done")]:
            target_reward[-1] = last_step[("next", "reward")]

        # NOTE(review): the terminal reward only seeds the reward target. The value
        # target is either the network estimate (evaluated even on terminal states)
        # or 0. The previous code also wrote the reward into target_value[-1], but
        # that store was always overwritten just below; confirm this is intended.
        if self.use_value_network:
            target_value[-1] = self.value_network(last_step["next"])["state_value"]

        target_value = reward2go(target_value, done, gamma=self.gamma, time_dim=-1)
        target_reward = reward2go(target_reward, done, gamma=self.gamma, time_dim=-1)

        # Read the actions once: on lazily-stacked rollouts each key access re-stacks.
        actions = rollout[self.action_key]
        node = self.root
        for idx in range(num_steps):
            node = node.get_child(actions[idx])
            node.value = (node.value * node.visits + target_value[idx]) / (
                node.visits + 1
            )
            node.reward = (node.reward * node.visits + target_reward[idx]) / (
                node.visits + 1
            )
            node.visits += 1

    def start_simulation(self, root_node=None, device=None) -> None:
        """Use ``root_node`` as the tree root, or create a fresh root on ``device``."""
        if root_node is not None:
            self.root = root_node
        else:
            self.root = MCTSNode.root(device)


class ExpansionStrategy(TensorDictModuleBase):
    """
    Base class for policies that initialize a tree node the first time it is visited.

    Subclasses implement :meth:`expand`, which writes into the node the statistics
    that the selection policy reads later.
    """

    def __init__(self):
        super().__init__()

    def forward(self, node: MCTSNode) -> TensorDictBase:
        """
        Expand ``node`` if it has not been expanded yet.

        Args:
            node: the tree node to (possibly) expand.

        Returns:
            The same ``node``, now carrying the statistics used to select actions.
        """
        if not node.expanded:
            self.expand(node)

        return node

    @abstractmethod
    def expand(self, node: MCTSNode) -> None:
        """Initialize the statistics of ``node``. Must be overridden by subclasses."""
        # @abstractmethod is only enforced when the metaclass is ABCMeta, which may
        # not be the case for this module hierarchy; fail loudly instead of silently
        # leaving the node unexpanded.
        raise NotImplementedError(
            f"{type(self).__name__} must implement expand(node)"
        )

    def set_node(self, node: MCTSNode) -> None:
        self.node = node


class BatchedLeafExpansionStrategy(ExpansionStrategy):
    """
    Expansion strategy that initializes a node's child statistics from a policy network.

    On expansion, ``policy_module`` is evaluated on the node's ``"state"``. Its
    ``module_action_value_key`` output is stored as the children's priors, and the
    children's values, visit counts and rewards are initialized to zeros of the same shape.

    Args:
        policy_module: a TensorDictModule producing the prior over children.
        action_count_key: node key for the children's visit counts.
        action_value_key: node key for the children's mean values.
        action_reward_key: node key for the children's mean rewards.
        module_action_value_key: key in ``policy_module``'s output holding the prior.
        prior_action_value_key: node key for the children's priors.

    Raises:
        ValueError: if ``module_action_value_key`` is not in ``policy_module.out_keys``.
    """

    def __init__(
        self,
        policy_module: TensorDictModule,
        action_count_key: str = "children_visits",
        action_value_key: str = "children_values",
        action_reward_key: str = "children_rewards",
        module_action_value_key: str = "action_value",
        prior_action_value_key: str = "children_priors",
    ):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if module_action_value_key not in policy_module.out_keys:
            raise ValueError(
                f"{module_action_value_key!r} is not an output key of policy_module "
                f"(out_keys={policy_module.out_keys})"
            )
        self.policy_module = policy_module
        self.action_value_key = module_action_value_key
        self.q_sa_key = action_value_key
        self.p_sa_key = prior_action_value_key
        self.n_sa_key = action_count_key
        self.r_sa_key = action_reward_key

    def expand(self, node: MCTSNode) -> None:
        policy_td = node["state"].select(*self.policy_module.in_keys)
        policy_td = self.policy_module(policy_td)
        p_sa = policy_td[self.action_value_key]
        node.set(self.p_sa_key, p_sa)  # prior over children
        node.set(self.q_sa_key, torch.zeros_like(p_sa))  # mean value per child
        node.set(self.n_sa_key, torch.zeros_like(p_sa))  # visit count per child
        node.set(self.r_sa_key, torch.zeros_like(p_sa))  # mean reward per child


class AlphaZeroExpansionStrategy(ExpansionStrategy):
    """
    An AlphaZero-style expansion: initialize a node's child statistics from a policy network.

    On expansion, ``policy_module`` is evaluated on the node's ``"state"``. Its
    ``module_action_value_key`` output is stored as the children's priors, and the
    children's values, visit counts and rewards are initialized to zeros of the same shape.

    Args:
        policy_module: a TensorDictModule producing the prior over children.
        module_action_value_key: key in ``policy_module``'s output holding the prior.
            Defaults to ``"action_value"``.
        prior_action_value_key: node key for the children's priors. Defaults to ``"children_priors"``.
        action_value_key: node key for the children's mean values. Defaults to ``"children_values"``.
        action_count_key: node key for the children's visit counts. Defaults to ``"children_visits"``.
        action_reward_key: node key for the children's mean rewards. Defaults to ``"children_rewards"``.

    Raises:
        ValueError: if ``module_action_value_key`` is not in ``policy_module.out_keys``.
    """

    def __init__(
        self,
        policy_module: TensorDictModule,
        module_action_value_key: str = "action_value",
        prior_action_value_key: str = "children_priors",
        action_value_key: str = "children_values",
        action_count_key: str = "children_visits",
        action_reward_key: str = "children_rewards",
    ):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if module_action_value_key not in policy_module.out_keys:
            raise ValueError(
                f"{module_action_value_key!r} is not an output key of policy_module "
                f"(out_keys={policy_module.out_keys})"
            )
        self.policy_module = policy_module
        self.module_action_value_key = module_action_value_key
        self.p_sa_key = prior_action_value_key
        self.q_sa_key = action_value_key
        self.n_sa_key = action_count_key
        self.r_sa_key = action_reward_key

    def expand(self, node: MCTSNode) -> None:
        policy_td = node["state"].select(*self.policy_module.in_keys)
        policy_td = self.policy_module(policy_td)
        p_sa = policy_td[self.module_action_value_key]
        node.set(self.p_sa_key, p_sa)  # prior over children
        node.set(self.q_sa_key, torch.zeros_like(p_sa))  # mean value per child
        node.set(self.n_sa_key, torch.zeros_like(p_sa))  # visit count per child
        node.set(self.r_sa_key, torch.zeros_like(p_sa))  # mean reward per child


class PuctSelectionPolicy(TensorDictModuleBase):
    """
    Score every child of a node with a PUCT-style formula, following the AlphaZero paper:
    https://discovery.ucl.ac.uk/id/eprint/10069050/1/alphazero_preprint.pdf

    The score written to ``node["score"]`` is::

        score = children_rewards + bpuct * children_values
                + cpuct * children_priors * sqrt(N + 1) / (1 + children_visits)

    where ``N`` is the sum of ``children_visits`` over the last dimension (computed
    per row for a batched node).

    Args:
        cpuct: weight of the prior-driven exploration bonus. Defaults to ``1.0``.
        bpuct: weight of the mean value term. Defaults to ``0.1``.
    """

    def __init__(
        self,
        cpuct: float = 1.0,
        bpuct: float = 0.1,
    ):
        super().__init__()
        self.cpuct = cpuct
        self.bpuct = bpuct
        self.node: MCTSNode

    def forward(self, node: MCTSNode) -> TensorDictBase:
        """Write the PUCT score of every child into ``node["score"]`` and return ``node``."""
        # we will always add 1, to avoid zero U values in the first visit of the node. See:
        # https://ai.stackexchange.com/questions/25451/how-does-alphazeros-mcts-work-when-starting-from-the-root-node
        # for a discussion on this topic.
        # TODO: investigate MuZero paper, AlphaZero paper and Bandit based monte-carlo planning to understand what
        # is the right implementation. Also check this discussion:
        # https://groups.google.com/g/computer-go-archive/c/K9XHb64JSqU
        visits = node["children_visits"]
        # keepdim=True: for a batched node of shape [B, A] the parent count must be
        # [B, 1] to broadcast per row; a plain [B] would align with the A axis.
        n = torch.sum(visits, dim=-1, keepdim=True) + 1
        u_sa = self.cpuct * node["children_priors"] * torch.sqrt(n) / (1 + visits)

        node["score"] = (
            node["children_rewards"] + self.bpuct * node["children_values"] + u_sa
        )

        return node


class DirichletNoiseModule(TensorDictModuleBase):
    """Mix symmetric Dirichlet noise into a node's prior.

    ``p <- (1 - epsilon) * p + epsilon * eta``, with ``eta ~ Dir(alpha, ..., alpha)``
    sampled with the same shape as the prior (one Dirichlet draw per slice along the
    last dimension).

    Args:
        alpha: concentration of the symmetric Dirichlet; must be > 0. Defaults to ``0.3``.
        epsilon: mixing weight of the noise. Defaults to ``0.25``.
        prior_action_value_key: node key of the prior that is overwritten.
            Defaults to ``"children_priors"``.
    """

    def __init__(
        self,
        alpha: float = 0.3,
        epsilon: float = 0.25,
        prior_action_value_key: str = "children_priors",
    ):
        self.in_keys = [prior_action_value_key]
        self.out_keys = [prior_action_value_key]
        super().__init__()
        self.alpha = alpha
        self.epsilon = epsilon
        self.prior_action_value_key = prior_action_value_key

    def forward(self, node: MCTSNode) -> TensorDictBase:
        """Overwrite ``node[prior_action_value_key]`` with the noisy prior and return ``node``."""
        p_sa = node[self.prior_action_value_key]
        # Sampling is done on CPU for MPS tensors (presumably Dirichlet sampling is
        # not supported on that backend) and moved back afterwards.
        sample_device = torch.device("cpu") if p_sa.device.type == "mps" else p_sa.device
        concentration = torch.full_like(p_sa, self.alpha, device=sample_device)
        # Public API instead of the private torch.distributions.dirichlet._Dirichlet.
        noise = torch.distributions.Dirichlet(concentration).sample().to(p_sa.device)
        node[self.prior_action_value_key] = (1 - self.epsilon) * p_sa + self.epsilon * noise
        return node


class MctsPolicy(TensorDictModuleBase):
    """
    An implementation of the MCTS in-tree policy.

    Each call expands the current node if needed, scores its children, picks an action
    and moves the current node to the corresponding child. ``set_node`` must be called
    (e.g. with the tree root) before the first call.

    Args:
        expansion_strategy: a policy to initialize the stats of a node at its first visit.
        selection_strategy: a policy that scores the children of a node.
            Defaults to a fresh ``PuctSelectionPolicy()`` per instance.
        exploration_strategy: a policy that picks an action from the scores (exploration
            vs. exploitation). Defaults to a fresh ``ActionExplorationModule()`` per instance.
        batch_size: number of nodes tracked in parallel. When > 1 the current node is
            indexed per element and re-stacked after each step. Defaults to ``1``.
    """

    def __init__(
        self,
        expansion_strategy: ExpansionStrategy,
        selection_strategy: Optional[TensorDictModuleBase] = None,
        exploration_strategy: Optional[ActionExplorationModule] = None,
        batch_size: int = 1,
    ):
        super().__init__()
        # Build defaults per instance: module instances used as default arguments
        # would be shared (and registered as submodules) by every MctsPolicy.
        if selection_strategy is None:
            selection_strategy = PuctSelectionPolicy()
        if exploration_strategy is None:
            exploration_strategy = ActionExplorationModule()
        self.expansion_strategy = expansion_strategy
        self.selection_strategy = selection_strategy
        self.exploration_strategy = exploration_strategy
        self.node: MCTSNode
        self.batch_size = batch_size

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Select an action for the current node and descend to the chosen child."""
        if not self.node.expanded:
            self.node["state"] = tensordict
        self.expansion_strategy.forward(self.node)
        self.selection_strategy.forward(self.node)
        tensordict = self.exploration_strategy.forward(self.node)

        # Read the action where the exploration strategy actually wrote it.
        action_key = getattr(self.exploration_strategy, "action_key", "action")

        if self.batch_size > 1:
            children = []
            for i in range(self.batch_size):
                node = self.node[i]
                step_td = tensordict[i]
                # Terminated elements stay on their current node.
                if not step_td["terminated"]:
                    node = node.get_child(step_td[action_key])
                children.append(node)
            self.set_node(torch.stack(children))  # type: ignore
        else:
            self.set_node(self.node.get_child(tensordict[action_key]))

        return tensordict

    def set_node(self, node: MCTSNode) -> None:
        self.node = node


class SimulatedSearchPolicy(TensorDictModuleBase):
    """
    A simulated search policy. At each call it runs ``num_simulations`` rollouts of at
    most ``simulation_max_steps`` steps with ``policy``, restoring the environment to
    its initial state after each one. It backs up the results with ``tree_updater``
    and finally returns the greedy action at the root.

    Args:
        policy: the MCTS policy used to select actions in each simulated rollout.
        tree_updater: the strategy that back-propagates rollout results and owns the root.
        env: the environment used to simulate rollouts; it must provide
            ``copy_state()`` and ``set_state()``.
        num_simulations: the number of simulated rollouts per call.
        simulation_max_steps: the max steps of each simulated rollout.
        max_steps: the call count at which the returned tensordict is marked ``truncated``.
        noise_module: a module to inject noise in the root node for exploration.
            Defaults to a fresh ``DirichletNoiseModule()``; pass ``None`` to disable.
        reutilize_tree: if ``True``, the child reached by the chosen action becomes the
            root of the next search. Defaults to ``True``.
    """

    # Sentinel default: gives every instance its own DirichletNoiseModule while
    # keeping ``noise_module=None`` available to mean "no noise".
    _DEFAULT_NOISE_MODULE = object()

    def __init__(
        self,
        policy: MctsPolicy,
        tree_updater: UpdateTreeStrategy,
        env: EnvBase,
        num_simulations: int,
        simulation_max_steps: int,
        max_steps: int,
        noise_module: Optional[DirichletNoiseModule] = _DEFAULT_NOISE_MODULE,  # type: ignore[assignment]
        reutilize_tree: bool = True,
    ):
        super().__init__()
        if noise_module is self._DEFAULT_NOISE_MODULE:
            noise_module = DirichletNoiseModule()
        self.policy = policy
        self.tree_updater = tree_updater
        self.env = env
        self.num_simulations = num_simulations
        self.simulation_max_steps = simulation_max_steps
        self.max_steps = max_steps
        self.noise_module = noise_module
        self.reutilize_tree = reutilize_tree
        self.root_priors = torch.tensor([])
        # NOTE(review): every searched root is appended here and never cleared, so the
        # whole search tree of every step stays alive; confirm who consumes/clears this.
        self.root_list = []
        self.init_state: torch.Tensor
        # NOTE(review): _steps and new_root are never reset here, so across episodes
        # truncation fires only once and a stale subtree may be reused; verify callers.
        self._steps = 0
        self.new_root = None
        self.top_k = 1  # currently unused in this class

    def forward(self, tensordict: TensorDictBase):
        """Run the search from ``tensordict``'s state and return it with the chosen action."""
        tensordict = tensordict.clone(False)
        self._steps += 1

        with torch.no_grad():
            self.start_simulation(tensordict)

            with set_exploration_type(ExplorationType.RANDOM):
                for _ in range(self.num_simulations):
                    self.simulate()

            with set_exploration_type(ExplorationType.MODE):
                root = self.tree_updater.root
                tensordict = self.policy.exploration_strategy(root)
                self.root_list.append(root)

                if self.reutilize_tree:
                    action_key = getattr(
                        self.policy.exploration_strategy, "action_key", "action"
                    )
                    self.new_root = root.get_child(tensordict[action_key])

            if self._steps == self.max_steps:
                tensordict["truncated"] = torch.ones_like(tensordict["truncated"])

            return tensordict

    def simulate(self) -> None:
        """Run one rollout from the root, restore the env, then back up the result."""
        self.reset_simulation()

        try:
            rollout = self.env.rollout(
                max_steps=self.simulation_max_steps,
                policy=self.policy,
                return_contiguous=False,
            )
        finally:
            # Always put the environment back in the search's initial state, even if
            # the rollout fails, so the env is not left mid-simulation.
            self.env.set_state(self.init_state.clone(True))  # type: ignore

        # update the nodes visited during the simulation
        self.tree_updater.update(rollout)  # type: ignore

    def start_simulation(self, tensordict) -> None:
        """Set up the root (reused or fresh), snapshot the env, expand and add noise."""
        if self.new_root is not None:
            self.tree_updater.start_simulation(root_node=self.new_root)
        else:
            self.tree_updater.start_simulation(device=tensordict.device)

        # snapshot of the env, restored after every simulated rollout
        self.init_state = self.env.copy_state()

        # initialize and expand the root
        self.tree_updater.root["state"] = tensordict
        self.policy.expansion_strategy(self.tree_updater.root)

        # inject exploration noise into the root priors
        if self.noise_module is not None:
            self.noise_module(self.tree_updater.root)

    def reset_simulation(self) -> None:
        """Point the in-tree policy back at the root (expanded if it is batched)."""
        if self.policy.batch_size > 1:
            self.policy.set_node(self.tree_updater.root.expand(self.policy.batch_size))
        else:
            self.policy.set_node(self.tree_updater.root)
