import numpy as np
import ray
import torch
import random

from models import muzero_models
from self_play import MCTS, GameHistory, game_observation_to_inputs


@ray.remote
class Evaluate:
    """
    Ray actor that evaluates the latest network weights on every configured
    evaluation dataset and publishes the resulting metrics to the shared storage.
    """

    def __init__(self, Game, config):
        """
        Create one game per evaluation dataset and instantiate the network.

        Args:
            Game: callable invoked as ``Game(game_config)`` to build a game.
            config: configuration object. Here ``evaluation_datasets`` (mapping of
                dataset name to a dict of overrides), ``game_config``, ``network``
                and ``model_config`` are read; the other methods additionally read
                ``training_steps``, ``temperature_threshold`` and ``max_moves``.
        """
        self.config = config

        self.evaluation_datasets = {}
        for name, overrides in self.config.evaluation_datasets.items():
            # ``copy()`` is shallow, so the nested "dataset_config" dict has to be
            # copied too. Writing the overrides into the shared dict would leak
            # them into the base config and into every other evaluation dataset.
            game_config = config.game_config.copy()
            dataset_config = dict(game_config["dataset_config"])
            dataset_config.update(overrides)
            game_config["dataset_config"] = dataset_config

            self.evaluation_datasets[name] = Game(game_config)

        # Initialize the network
        self.model = muzero_models.MuZeroLinesModel(
            self.config.network, **self.config.model_config
        )
        self.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        self.model.eval()

    def continuous_self_evaluation(self, shared_storage):
        """
        Run ``evaluate`` in a loop until the stored "training_step" reaches
        ``config.training_steps`` or the stored "terminate" flag is set.

        Args:
            shared_storage: actor handle exposing ``get_info`` and ``set_info``.
        """
        while True:
            training_step = ray.get(shared_storage.get_info.remote("training_step"))
            if training_step >= self.config.training_steps:
                break
            if ray.get(shared_storage.get_info.remote("terminate")):
                break
            self.evaluate(shared_storage)

    def evaluate(self, shared_storage):
        """
        Load the current weights, play every game of every evaluation dataset
        and send all collected metrics to the shared storage in a single call.

        Args:
            shared_storage: actor handle exposing ``get_info`` and ``set_info``.
        """
        step = ray.get(shared_storage.get_info.remote("training_step"))
        self.model.set_weights(ray.get(shared_storage.get_info.remote("weights")))

        info = {"evaluation_last_step": step}
        for name, game in self.evaluation_datasets.items():
            info.update(self._evaluate_dataset(name, game))

        shared_storage.set_info.remote(info)

    def _evaluate_dataset(self, name, game):
        """
        Play every game of one evaluation dataset and build its metric entries.

        Args:
            name: dataset name, used in the ``evaluation_dataset_<name>/`` key prefix.
            game: game wrapper whose ``env`` provides ``dataset``, ``statistics()``,
                ``metrics()``, ``total_reward`` and ``render()``.

        Returns:
            Dict with, for every metric, the mean over all games under
            ``<prefix><metric>`` and the per-game values as a tensor under
            ``<prefix><metric>_hg`` (presumably plotted as histograms - confirm
            against the consumer), plus the frames of one randomly chosen game
            under ``<prefix>rendered_game``.
        """
        prefix = "evaluation_dataset_" + name + "/"
        num_games = len(game.env.dataset)

        # Per-game series; the insertion order defines the order of the
        # reported keys. The "mistages" spelling is part of the emitted key.
        series = {
            "total_steps": [],
            "total_actual_steps": [],
            "total_correct_edges": [],
            "total_overlapping_edges": [],
            "total_solved": [],
            "total_solved_no_mistages": [],
            "total_rewards": [],
            "total_mean_values": [],
            "total_wrong_edges": [],
        }
        extra_metrics = {}
        rendered_game = None

        # Exactly one randomly chosen game of the dataset gets rendered.
        rendered_game_index = random.randint(0, num_games - 1)

        for idx in range(num_games):
            game_history, rendered = self.play_game(
                game,
                0,
                self.config.temperature_threshold,
                idx,
                render=(idx == rendered_game_index),
            )

            (
                current_step,
                num_correct_edges,
                num_overlapping_edges,
                num_wrong_edges,
                actual_steps_taken,
                solved,
            ) = game.env.statistics()

            for key, value in game.env.metrics().items():
                extra_metrics.setdefault(key, []).append(float(value))

            series["total_steps"].append(current_step)
            series["total_actual_steps"].append(actual_steps_taken)
            series["total_correct_edges"].append(num_correct_edges)
            series["total_overlapping_edges"].append(num_overlapping_edges)
            series["total_solved"].append(solved)
            series["total_solved_no_mistages"].append(
                solved and num_overlapping_edges + num_wrong_edges == 0
            )
            series["total_rewards"].append(game.env.total_reward)
            # Only missing entries are skipped: a truthiness test would also drop
            # legitimate root values equal to 0 and bias the mean.
            series["total_mean_values"].append(
                np.mean(
                    [value for value in game_history.root_values if value is not None]
                )
            )
            series["total_wrong_edges"].append(num_wrong_edges)

            if rendered is not None:
                rendered_game = rendered

        info = {}
        for key, values in extra_metrics.items():
            info[prefix + key] = np.mean(values)
            info[prefix + key + "_hg"] = torch.tensor(values)

        info[prefix + "rendered_game"] = torch.tensor(rendered_game)

        for key, values in series.items():
            info[prefix + key] = np.mean(values)
            info[prefix + key + "_hg"] = torch.tensor(values)

        return info

    def play_game(
        self,
        game,
        temperature,
        temperature_threshold,
        idx,
        render=False,
    ):
        """
        Play one game, choosing every action from a Monte Carlo tree search.

        Args:
            game: game wrapper providing ``reset(idx=...)``, ``legal_actions()``,
                ``step(action)`` and ``env.render()``.
            temperature: temperature handed to ``select_action`` while
                ``temperature_threshold`` is falsy or the action history is
                shorter than it; afterwards 0 is used.
            temperature_threshold: history length from which the temperature
                drops to 0 (a falsy value disables the drop).
            idx: index forwarded to ``game.reset``.
            render: when true, one frame is collected after the reset and after
                every step.

        Returns:
            Tuple ``(game_history, frames)`` where ``frames`` is the stacked array
            of rendered frames, or ``None`` when ``render`` is false.
        """
        game_history = GameHistory()
        observation = game.reset(idx=idx)

        device = next(self.model.parameters()).device
        global_features = self.model.get_constant_features(
            game_observation_to_inputs(observation, device)
        )

        # The history starts with a placeholder action and reward for the
        # initial observation.
        game_history.action_history.append(0)
        game_history.observation_history.append(observation)
        game_history.reward_history.append(0)

        done = False
        frames = [game.env.render()] if render else None

        with torch.no_grad():
            while (
                not done and len(game_history.action_history) <= self.config.max_moves
            ):
                curr_observation = game_history.get_observation(-1)
                possible_actions = game.legal_actions()

                # Choose the action
                root, _mcts_info = MCTS(self.config).run(
                    self.model,
                    curr_observation,
                    True,
                    global_features=global_features,
                )

                if (
                    not temperature_threshold
                    or len(game_history.action_history) < temperature_threshold
                ):
                    step_temperature = temperature
                else:
                    step_temperature = 0
                action = self.select_action(root, step_temperature)

                observation, reward, done = game.step(action)

                game_history.store_search_statistics(root, possible_actions)

                # Next batch
                game_history.action_history.append(action)
                game_history.observation_history.append(observation)
                game_history.reward_history.append(reward)

                if render:
                    frames.append(game.env.render())

        return game_history, (np.stack(frames) if render else None)

    @staticmethod
    def select_action(node, temperature):
        """
        Select an action among the children of ``node`` from their visit counts.

        Args:
            node: search node whose ``children`` maps each action to a child
                exposing ``visit_count`` and ``prior``.
            temperature: 0 picks the child with the highest visit count (the
                priors are added to resolve ties), ``float("inf")`` picks
                uniformly at random, and any other value samples proportionally
                to ``visit_count ** (1 / temperature)``.

        Returns:
            The selected action.
        """
        actions = list(node.children)
        children = list(node.children.values())
        visit_counts = np.array(
            [child.visit_count for child in children], dtype="int32"
        )

        if temperature == 0:
            # to resolve equalities
            priors = np.array([child.prior for child in children])
            return actions[np.argmax(visit_counts + priors)]

        if temperature == float("inf"):
            return np.random.choice(actions)

        # See paper appendix Data Generation
        weights = visit_counts ** (1 / temperature)
        return np.random.choice(actions, p=weights / weights.sum())
