from typing import List, NamedTuple, Tuple

import numpy as np
import torch
from torch.cuda.amp import autocast as autocast

from core.model import BaseNet, NetworkOutput
from core.config import BaseConfig
from core.mcts.ctree.ctree_sampled import cytree

class SearchOutput(NamedTuple):
    """Root statistics returned by ``SampledMCTS.batch_search``.

    After the search, the fields are filled in declaration order from the batched
    tree object's ``get_roots_*`` getters.
    NOTE(review): the per-field shapes, and the length and meaning of each list
    field, are determined by cytree. List-typed fields presumably hold one array
    per root (the number of sampled actions can vary); confirm against cytree.
    """
    value: np.ndarray  # trees.get_roots_values()
    marginal_visit_count: np.ndarray  # trees.get_roots_marginal_visit_count()
    marginal_priors: np.ndarray  # trees.get_roots_marginal_priors()
    sampled_actions: List[np.ndarray]  # trees.get_roots_sampled_actions()
    sampled_visit_count: List[np.ndarray]  # trees.get_roots_sampled_visit_count()
    sampled_pred_probs: List[np.ndarray]  # trees.get_roots_sampled_pred_probs()
    sampled_beta: List[np.ndarray]  # trees.get_roots_sampled_beta()
    sampled_beta_hat: List[np.ndarray]  # trees.get_roots_sampled_beta_hat()
    sampled_priors: List[np.ndarray]  # trees.get_roots_sampled_priors()
    sampled_imp_ratio: List[np.ndarray]  # trees.get_roots_sampled_imp_ratio()
    sampled_pred_values: List[np.ndarray]  # trees.get_roots_sampled_pred_values()
    sampled_mcts_values: List[np.ndarray]  # trees.get_roots_sampled_mcts_values()
    sampled_rewards: List[np.ndarray]  # trees.get_roots_sampled_rewards()
    sampled_qvalues: List[np.ndarray]  # trees.get_roots_sampled_qvalues(discount)

    # Adaptive per-root parameters: trees.get_roots_adaptive_theta().
    roots_adaptive_theta: np.ndarray

class SampledMCTS(object):
    """Batched sampled-MCTS driver around ``cytree.Tree_batch``.

    ``batch_search`` does the following:
    - builds one tree per batch entry from a root ``NetworkOutput``;
    - runs ``config.num_simulations`` rounds of selection, model expansion and
      backup;
    - optionally refines adaptive node statistics with ``model.reward_model``;
    - returns the root statistics as a ``SearchOutput``.
    """

    def __init__(self, config: BaseConfig, np_random: np.random.RandomState = None):
        """
        Args:
            config: search hyper-parameters (``pb_c_base``, ``pb_c_init``, ``discount``,
                ``num_simulations``, ``sampled_action_times``, Dirichlet noise settings, ...).
            np_random: RNG used for the root Dirichlet noise and the tree seed. When
                omitted, the global ``np.random`` module is used.
        """
        self.config = config
        self.np_random = np.random if np_random is None else np_random

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over the last axis."""
        probs = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
        return probs / np.sum(probs, axis=-1, keepdims=True)

    @staticmethod
    def _gather_hidden_states(hidden_states_pool, index_x_lst, index_y_lst) -> torch.Tensor:
        """Stack ``hidden_states_pool[x][y]`` for every ``(x, y)`` pair into one tensor.

        ``hidden_states_pool[0]`` holds the root hidden states. Entry ``k`` holds
        the states produced at simulation ``k``. The ``(x, y)`` pairs come from the
        trees and presumably address (simulation, batch row); confirm against cytree.
        """
        return torch.vstack([hidden_states_pool[ix][iy] for ix, iy in zip(index_x_lst, index_y_lst)])

    def _adaptive_update(self, model: BaseNet, trees, hidden_states_pool, device, discount) -> None:
        """Compute Q-targets for the nodes the trees request and feed them back.

        ``trees.adaptive_get_batch_inputs`` returns node indices plus five action
        sets. Each requested node's hidden state is evaluated by ``model.reward_model``
        under every action set. ``reward + discount * value`` is then passed to
        ``trees.adaptive_batch_update``. Does nothing when no node is requested.
        """
        (index_x_lst, index_y_lst, actions, last_actions,
         last_actions_u, last_actions_v, last_actions_u_v) = trees.adaptive_get_batch_inputs()
        if len(index_x_lst) == 0:
            return

        node_states = self._gather_hidden_states(hidden_states_pool, index_x_lst, index_y_lst)
        action_sets = [
            torch.from_numpy(a).to(device)
            for a in (actions, last_actions, last_actions_u, last_actions_v, last_actions_u_v)
        ]
        # One copy of the node states per action set, stacked along the batch axis.
        # NOTE(review): the actions are joined on the last axis and then reshaped to
        # (-1, num_agents). That lines up with this batch-axis repeat only if cytree
        # returns each action set flat; confirm against cytree.
        batch_hidden_states = torch.cat([node_states] * len(action_sets), dim=0)
        batch_actions_all = torch.cat(action_sets, dim=-1).reshape(-1, self.config.num_agents)

        with autocast():
            network_output = model.reward_model(batch_hidden_states, batch_actions_all)

        batch_rewards = network_output.reward
        batch_values = network_output.value
        if isinstance(batch_rewards, torch.Tensor):
            batch_rewards = batch_rewards.detach().cpu().numpy()
        if isinstance(batch_values, torch.Tensor):
            batch_values = batch_values.detach().cpu().numpy()

        num_inputs = batch_hidden_states.shape[0]
        batch_rewards = batch_rewards.reshape(num_inputs).astype(np.float32)
        batch_values = batch_values.reshape(num_inputs).astype(np.float32)
        trees.adaptive_batch_update(batch_rewards + discount * batch_values)

    def batch_search(
        self,
        model: BaseNet,
        network_output: NetworkOutput,
        legal_actions_lst: np.ndarray = None,
        device: torch.device = None,
        add_noise: bool = False,
        sampled_tau: float = 1.0,
        sampled_actions_res: Tuple[np.ndarray, np.ndarray] = None,
        use_adaptive: bool = True,
    ) -> SearchOutput:
        """Create a batch of root nodes from ``network_output`` and run MCTS on them in parallel.

        Args:
            model: network providing ``recurrent_inference`` and ``get_hypernet_params``,
                plus ``reward_model`` when ``use_adaptive`` is set.
            network_output: root inference output. ``value`` must have shape
                ``(batch, 1)`` and ``policy_logits`` shape
                ``(batch, num_agents, action_space_size)``.
            legal_actions_lst: optional mask multiplied into the root priors and noise
                (presumably 0/1). Legal entries get a 1e-4 floor before renormalising.
            device: device that action and hidden-state tensors are moved to.
            add_noise: mix Dirichlet noise into the root sampling distribution with
                weight ``config.root_exploration_fraction``.
            sampled_tau: temperature. The sampling distribution ``beta`` is
                ``probs ** (1 / sampled_tau)``, renormalised.
            sampled_actions_res: must be ``None``; pre-sampled root actions are not
                supported.
            use_adaptive: when True:
                - hyper-network params from ``model.get_hypernet_params`` are given
                  to the trees for roots and leaves;
                - the adaptive reward-model update runs on every simulation from the
                  third onwards.
                When False, zero params are passed.

        Returns:
            ``SearchOutput`` filled from the trees' root getters.

        Raises:
            NotImplementedError: if ``sampled_actions_res`` is not ``None``.
            AssertionError: (when assertions are enabled) on unexpected root output
                shapes, or if the legal-action mask leaves an agent with no legal action.
        """
        if sampled_actions_res is not None:
            raise NotImplementedError("batch_search does not support pre-sampled root actions (sampled_actions_res)")

        config = self.config
        pb_c_base, pb_c_init, discount = config.pb_c_base, config.pb_c_init, config.discount
        rho, lam = config.mcts_rho, config.mcts_lambda
        noise_alpha, noise_epsilon = config.root_dirichlet_alpha, config.root_exploration_fraction
        num_agents, action_space_size = config.num_agents, config.action_space_size
        sampled_times, num_simulations = config.sampled_action_times, config.num_simulations

        batch_hidden_states = network_output.hidden_state
        batch_size = batch_hidden_states.shape[0]
        batch_rewards = network_output.reward
        batch_values = network_output.value
        batch_policy_logits = network_output.policy_logits
        assert batch_values.shape == (batch_size, 1) and batch_policy_logits.shape == (batch_size, num_agents, action_space_size)

        # Switch to eval mode before the root hyper-network query. Root and leaf
        # parameters then come from the same model mode.
        model.eval()

        batch_policy_probs = self._softmax(batch_policy_logits)

        # Noise is always drawn, so RNG consumption does not depend on add_noise.
        # A zero epsilon disables it.
        noises = self.np_random.dirichlet([noise_alpha] * action_space_size, batch_size * num_agents) \
            .astype(np.float32).reshape(batch_size, num_agents, action_space_size)
        if not add_noise:
            noise_epsilon = 0.0

        if legal_actions_lst is not None:
            # Zero out illegal actions, then give every legal action a small floor
            # so that no legal action ends up with exactly zero mass.
            batch_policy_probs *= legal_actions_lst
            batch_policy_probs += legal_actions_lst * 1e-4
            assert np.all(np.sum(batch_policy_probs, axis=-1) != 0), "every agent needs at least one legal action"
            batch_policy_probs = batch_policy_probs / np.sum(batch_policy_probs, axis=-1, keepdims=True)

            noises *= legal_actions_lst
            noises += legal_actions_lst * 1e-4
            noises = noises / np.sum(noises, axis=-1, keepdims=True)

        hidden_states_pool = [batch_hidden_states]

        trees = cytree.Tree_batch(batch_size, num_agents, action_space_size, sampled_times, num_simulations,
                                  config.tree_value_stat_delta_lb, self.np_random.choice(256), rho, lam)
        trees.set_use_adaptive(use_adaptive)

        # Root sampling distribution: noisy prior, tempered by tau, restricted to
        # legal actions and renormalised.
        batch_beta = batch_policy_probs * (1 - noise_epsilon) + noises * noise_epsilon
        batch_beta = batch_beta ** (1 / sampled_tau)
        if legal_actions_lst is not None:
            batch_beta *= legal_actions_lst
            assert np.all(np.sum(batch_beta, axis=-1) != 0), "every agent needs at least one legal action"
        batch_beta = batch_beta / np.sum(batch_beta, axis=-1, keepdims=True)

        if use_adaptive:
            with torch.no_grad():
                root_states = batch_hidden_states if isinstance(batch_hidden_states, torch.Tensor) \
                    else torch.from_numpy(batch_hidden_states).to(device)
                root_hypernet_params = model.get_hypernet_params(root_states).cpu().numpy().astype(np.float32)
        else:
            root_hypernet_params = np.zeros((batch_size, num_agents, action_space_size), dtype=np.float32)

        trees.prepare(
            batch_rewards.reshape(batch_size).astype(np.float32),
            batch_values.reshape(batch_size).astype(np.float32),
            batch_policy_probs.astype(np.float32),
            batch_beta.astype(np.float32),
            sampled_times, noise_epsilon, noises, root_hypernet_params,
        )

        with torch.no_grad():
            for index_simulation in range(num_simulations):
                index_x_lst, index_y_lst, batch_actions = trees.batch_selection(pb_c_base, pb_c_init, discount)
                batch_hidden_states = self._gather_hidden_states(hidden_states_pool, index_x_lst, index_y_lst)
                batch_actions = torch.from_numpy(batch_actions).to(device)

                with autocast():
                    network_output = model.recurrent_inference(batch_hidden_states, batch_actions)

                batch_hidden_states = network_output.hidden_state
                hidden_states_pool.append(batch_hidden_states)

                # Leaves are expanded without noise or legal-action masking.
                batch_policy_probs = self._softmax(network_output.policy_logits)
                batch_beta = batch_policy_probs ** (1 / sampled_tau)
                batch_beta = batch_beta / np.sum(batch_beta, axis=-1, keepdims=True)

                if use_adaptive:
                    leaf_hypernet_params = model.get_hypernet_params(batch_hidden_states).cpu().numpy().astype(np.float32)
                else:
                    leaf_hypernet_params = np.zeros((batch_size, num_agents, action_space_size), dtype=np.float32)

                trees.batch_expansion_and_backup(
                    index_simulation + 1, discount, sampled_times,
                    network_output.reward.reshape(batch_size).astype(np.float32),
                    network_output.value.reshape(batch_size).astype(np.float32),
                    batch_policy_probs.astype(np.float32),
                    batch_beta.astype(np.float32),
                    leaf_hypernet_params,
                )

                if use_adaptive and index_simulation >= 2:
                    self._adaptive_update(model, trees, hidden_states_pool, device, discount)

                # Release cached allocator blocks periodically during long searches.
                # This must sit inside the loop: outside it, the check ran at most
                # once and broke when num_simulations == 0.
                if (index_simulation + 1) % 20 == 0:
                    torch.cuda.empty_cache()

        torch.cuda.empty_cache()

        # Keyword arguments are evaluated left to right, so the getters run in the
        # same order as the SearchOutput fields.
        return SearchOutput(
            value=trees.get_roots_values(),
            marginal_visit_count=trees.get_roots_marginal_visit_count(),
            marginal_priors=trees.get_roots_marginal_priors(),
            sampled_actions=trees.get_roots_sampled_actions(),
            sampled_visit_count=trees.get_roots_sampled_visit_count(),
            sampled_pred_probs=trees.get_roots_sampled_pred_probs(),
            sampled_beta=trees.get_roots_sampled_beta(),
            sampled_beta_hat=trees.get_roots_sampled_beta_hat(),
            sampled_priors=trees.get_roots_sampled_priors(),
            sampled_imp_ratio=trees.get_roots_sampled_imp_ratio(),
            sampled_pred_values=trees.get_roots_sampled_pred_values(),
            sampled_mcts_values=trees.get_roots_sampled_mcts_values(),
            sampled_rewards=trees.get_roots_sampled_rewards(),
            sampled_qvalues=trees.get_roots_sampled_qvalues(discount),
            roots_adaptive_theta=trees.get_roots_adaptive_theta(),
        )

import numpy as np

def squash_reward(r, method="softsign", s=5.0, K=10.0, rmin=None, rmax=None):
    """Map raw rewards into the interval [0, 1].

    Args:
        r: scalar or array-like of rewards; converted to float32.
        method: one of
            - ``"softsign"``: ``0.5 * (r / (1 + |r|) + 1)`` (does not use ``s``);
            - ``"logistic"``: ``1 / (1 + exp(-r / s))``;
            - ``"tanh"``: ``0.5 * (tanh(r / s) + 1)``;
            - ``"clip-linear"``: clip to ``[-K, K]``, then map linearly onto ``[0, 1]``;
            - ``"minmax"``: ``(r - rmin) / (rmax - rmin)`` clipped to ``[0, 1]``.
        s: positive scale for ``"logistic"`` and ``"tanh"``.
        K: positive clipping bound for ``"clip-linear"``.
        rmin, rmax: bounds for ``"minmax"``. Both are required, with ``rmin < rmax``.

    Returns:
        float32 ndarray with the same shape as ``r`` (0-d for scalar input).

    Raises:
        ValueError: unknown ``method``, missing or invalid ``minmax`` bounds, or a
            non-positive ``s`` / ``K`` for the methods that use them.
    """
    x = np.asarray(r, dtype=np.float32)
    if method in ("logistic", "tanh") and not float(s) > 0:
        # s == 0 divides by zero, and s < 0 silently inverts the mapping.
        raise ValueError(f"method {method!r} requires s > 0, got {s}")
    if method == "clip-linear" and not float(K) > 0:
        # K == 0 would divide by zero when rescaling the clipped range.
        raise ValueError(f"method 'clip-linear' requires K > 0, got {K}")

    if method == "softsign":
        y = 0.5 * (x / (1.0 + np.abs(x)) + 1.0)
    elif method == "logistic":
        # For very negative rewards exp(-x / s) overflows to inf, and 1 / (1 + inf)
        # is the correct limit 0. Silence the spurious overflow warning.
        with np.errstate(over="ignore"):
            y = 1.0 / (1.0 + np.exp(-x / float(s)))
    elif method == "tanh":
        y = 0.5 * (np.tanh(x / float(s)) + 1.0)
    elif method == "clip-linear":
        K = float(K)
        y = np.clip(x, -K, K)
        y = (y + K) / (2.0 * K)
    elif method == "minmax":
        if rmin is None or rmax is None or rmax <= rmin:
            raise ValueError("minmax requires both rmin and rmax with rmin < rmax")
        y = (x - float(rmin)) / (float(rmax) - float(rmin))
        y = np.clip(y, 0.0, 1.0)
    else:
        raise ValueError(f"unknown method: {method}")
    return y.astype(np.float32)
