from __future__ import annotations
import random
import numpy as np
from algorithms.abstract.game_history import GameHistory
from typing import TYPE_CHECKING, List

from algorithms.mu_zero.types import MuZeroHistoryItem

from algorithms.utils.types import TrainingTarget
if TYPE_CHECKING:
    from algorithms.utils.types import StateFeature, GameSampleData
    from algorithms.utils.types import ActionImage


class MuZeroGameHistory(GameHistory):
    """Ordered record of one game's history items, used to draw training samples.

    Items are appended with :meth:`store` in the order they occurred.
    :meth:`sample_data` picks a random position and returns the state feature
    at that position together with the following ``k`` action images and
    ``k + 1`` training targets, padding with ``final_action_image`` and
    ``final_target`` once the recorded items run out.
    """

    def __init__(self, final_target: TrainingTarget, final_action_image: ActionImage):
        """Create an empty history.

        Args:
            final_target: Target used for every unrolled step that lies past
                the last stored item.
            final_action_image: Action image used to pad the action sequence
                for steps that lie past the last stored item.
        """
        self.final_target = final_target
        self.final_action_image = final_action_image
        self.items: List[MuZeroHistoryItem] = []

    def store(self, history_item: MuZeroHistoryItem) -> None:
        """Append ``history_item`` to the end of the history."""
        self.items.append(history_item)

    @staticmethod
    def _terminal_value(active_player: int, terminal_returns: List[float]) -> np.ndarray:
        """Return the one-element float32 value target for ``active_player``."""
        # A negative player id has no entry in ``terminal_returns``; such items
        # get a value of zero (presumably chance/terminal nodes - TODO confirm).
        terminal_return = terminal_returns[active_player] if active_player >= 0 else 0.0
        return np.array([terminal_return], dtype=np.float32)

    def update_for_board_games(self, terminal_returns: List[float]) -> None:
        """Replace every item's value with the final return of its active player.

        Each stored item is rebuilt with all fields copied except ``value``,
        which becomes ``terminal_returns[item.active_player]`` (or ``0.0`` when
        ``active_player`` is negative) as a one-element float32 array.

        Args:
            terminal_returns: Final return per player, indexed by player id.

        Raises:
            IndexError: If an item's non-negative ``active_player`` is not a
                valid index into ``terminal_returns``.
        """
        # The new list is built completely before it is assigned, so a failure
        # part-way through leaves ``self.items`` untouched.
        self.items = [
            MuZeroHistoryItem(state_feature=item.state_feature,
                              value=self._terminal_value(item.active_player, terminal_returns),
                              chance_target=item.chance_target,
                              choice_target=item.choice_target,
                              tau_target=item.tau_target,
                              action=item.action,
                              action_image=item.action_image,
                              state_string=item.state_string,
                              active_player=item.active_player,
                              node=item.node)
            for item in self.items
        ]

    def get_action_images(self, index: int, k: int) -> List[ActionImage]:
        """Return the action images of items ``index`` .. ``index + k - 1``.

        The result is shorter than ``k`` when the slice runs past the end of
        the history; no padding is applied here.
        """
        return [item.action_image for item in self.items[index: index + k]]

    def get_root_state_feature(self, index: int) -> StateFeature:
        """Return the state feature stored at position ``index``."""
        return self.items[index].state_feature

    def get_training_target(self, index: int) -> TrainingTarget:
        """Build the training target from the item stored at position ``index``."""
        history_item = self.items[index]
        return TrainingTarget(tau_policy=history_item.tau_target,
                              value=history_item.value,
                              chance_policy=history_item.chance_target,
                              choice_policy=history_item.choice_target)

    def sample_data(self, k: int) -> GameSampleData:
        """Sample a random start position and unroll ``k`` steps from it.

        Args:
            k: Number of steps to unroll after the sampled position.

        Returns:
            A tuple ``(root_state_feature, action_images, targets,
            state_strings)``. ``targets`` and ``state_strings`` hold one entry
            for each of the ``k + 1`` positions starting at the sampled index;
            positions past the end of the history use ``final_target`` and the
            string ``'Game ended'``. ``action_images`` is padded with
            ``final_action_image`` up to ``len(targets) - 1`` entries.

        Raises:
            ValueError: If the history contains no items.
        """
        if not self.items:
            # random.randint(0, -1) would raise an opaque "empty range"
            # ValueError; fail with a message that names the real problem.
            raise ValueError('Cannot sample data from an empty game history')

        num_items = len(self.items)
        index = random.randint(0, num_items - 1)
        root_state_feature = self.get_root_state_feature(index)
        action_images = self.get_action_images(index, k)

        targets = []
        state_strings = []
        for i in range(index, index + k + 1):
            if i < num_items:
                state_strings.append(self.items[i].state_string)
                targets.append(self.get_training_target(i))
            else:
                # Unrolled past the last recorded item.
                state_strings.append('Game ended')
                targets.append(self.final_target)

        # One action image is needed per transition between consecutive
        # targets. A non-positive count multiplies to an empty list, so
        # nothing is added when no padding is required.
        missing = len(targets) - 1 - len(action_images)
        action_images.extend([self.final_action_image] * missing)
        return root_state_feature, action_images, targets, state_strings

    def __len__(self):
        """Return the number of stored history items."""
        return len(self.items)

