import math
from typing import NamedTuple
import numpy as np
import torch
from torch import Tensor

from core.model import ActorOutput
from torch.distributions import Normal


class ActionCapsule(NamedTuple):
    """An action, its repeat count and one-hot repeat encoding, used as a child key.

    Equality compares ``action`` element-wise and ``repeat`` exactly;
    ``repeat_one_hot`` is not compared. Ordering is a partial order on
    ``action`` only: ``a > b`` holds when every element of ``a.action`` is
    greater than the matching element of ``b.action`` (likewise ``>=``, ``<``,
    ``<=``).

    Hashing is inherited from ``tuple``. ``torch.Tensor`` hashes by object
    identity, so two capsules that compare equal generally hash differently
    and remain distinct dictionary keys.
    """
    action: Tensor
    repeat: int
    repeat_one_hot: Tensor

    def __eq__(self, other):
        if not isinstance(other, ActionCapsule):
            return NotImplemented
        return bool((self.action == other.action).all().item()) and self.repeat == other.repeat

    def __ne__(self, other):
        # tuple.__ne__ would compare the tensors element-wise and call bool() on a
        # multi-element tensor, so derive != from __eq__ instead.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __gt__(self, other):
        return bool((self.action > other.action).all().item())

    def __ge__(self, other):
        return bool((self.action >= other.action).all().item())

    def __lt__(self, other):
        return bool((self.action < other.action).all().item())

    def __le__(self, other):
        return bool((self.action <= other.action).all().item())


class MinMaxStats(object):
    """Tracks the minimum and maximum values seen in the tree, to rescale values."""

    def __init__(self, min_value_bound=None, max_value_bound=None):
        """Optionally seed the tracked range with known bounds.

        Args:
            min_value_bound: known lower bound on values, or ``None`` if unknown.
            max_value_bound: known upper bound on values, or ``None`` if unknown.
        """
        # ``is not None`` so that a legitimate bound of 0 is not discarded.
        self.maximum = max_value_bound if max_value_bound is not None else -float('inf')
        self.minimum = min_value_bound if min_value_bound is not None else float('inf')

    def update(self, value: float):
        """Widen the tracked range to include ``value``."""
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value: float) -> float:
        """Map ``value`` linearly so ``minimum`` -> 0 and ``maximum`` -> 1.

        Returns ``value`` unchanged until a non-degenerate range is known.
        """
        if self.maximum > self.minimum:
            # We normalize only when we have set the maximum and minimum values.
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value


class Node(object):
    """A node of the search tree.

    Holds visit statistics (``visit_count``, ``value_sum``), the model outputs
    stored by ``expand`` (hidden state, reward, actor output), and a
    ``children`` dict mapping ``ActionCapsule`` keys to child nodes.
    """

    # Multiplier applied to each child's ``action_log_prob`` before the softmax
    # in ``update_prior``: weight = exp(0.25 * lp) == (e ** 0.25) ** lp.
    _PRIOR_LOG_PROB_SCALE = 0.25

    def __init__(self, action_log_prob: float, root=False):
        """Create an unexpanded node.

        Args:
            action_log_prob: score of the action leading to this node; the
                parent's ``update_prior`` turns it into this node's ``prior``.
            root: whether this node is the search root.
        """
        self.visit_count = 0
        self.root = root
        self.prior = None
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0
        self.actor_output = None
        self.action_log_prob = action_log_prob

    def expanded(self) -> bool:
        """Return True once the node has at least one child."""
        return len(self.children) > 0

    def value(self) -> float:
        """Mean backed-up value, or 0 for an unvisited node."""
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def expand(self, model, actor_output, action_log_prob, action, action_repeat, action_repeat_one_hot, state, reward,
               progressive=False, proposal_action_sample_n=20):
        """Store model outputs on this node and create its children.

        Args:
            model: project model; ``model.actor.action_dist`` and
                ``model.actor.action_sample`` are used in non-progressive mode.
            actor_output: actor output at this node (has ``mu`` and ``std_dev``).
            action_log_prob: score for ``action``; only used when ``progressive``.
            action: action for the single child; only used when ``progressive``.
            action_repeat: repeat count attached to every child created here.
            action_repeat_one_hot: one-hot repeat encoding attached to every child.
            state: hidden state stored on this node (``MCTS.run`` passes
                ``(belief, state)``).
            reward: reward stored on this node.
            progressive: if True add exactly one child for ``action``; otherwise
                sample ``proposal_action_sample_n`` noisy actions as children,
                all created with log-prob 0 (equal priors).
            proposal_action_sample_n: number of children sampled when not
                progressive.
        """
        self.hidden_state = state
        self.reward = reward
        self.actor_output = actor_output

        if progressive:
            self.children[ActionCapsule(action, action_repeat, action_repeat_one_hot)] = Node(action_log_prob)
        else:
            # Tile the distribution parameters into one row per proposal; the
            # reshape only works if mu / std_dev have a leading dimension of 1.
            _mu = actor_output.mu.repeat(1, proposal_action_sample_n). \
                reshape(proposal_action_sample_n, actor_output.mu.shape[1])
            _std = actor_output.std_dev.repeat(1, proposal_action_sample_n). \
                reshape(proposal_action_sample_n, actor_output.std_dev.shape[1])
            _actor_output = ActorOutput(_mu, _std)
            actor_dist = model.actor.action_dist(_actor_output)
            _actions = model.actor.action_sample(_actor_output)
            # Extra Gaussian jitter (std 0.3) to spread the proposals.
            noise_dist = Normal(torch.zeros(_actions.shape).to(_actions.device),
                                torch.ones(_actions.shape).to(_actions.device) * 0.3)
            _actions = _actions + noise_dist.sample()
            # NOTE(review): computed but never used — every child below gets
            # log-prob 0 (uniform priors). Confirm whether these values were meant
            # to seed the children.
            child_action_log_prob = - actor_dist.entropy(_actions)

            for child_action in _actions:
                self.children[ActionCapsule(child_action.unsqueeze(0).unsqueeze(0), action_repeat, action_repeat_one_hot)] = Node(0)
        self.update_prior()

    def update_prior(self):
        """Set each child's prior to a softmax of ``0.25 * action_log_prob``.

        Mathematically ``exp(0.25 * lp_i) / sum_j exp(0.25 * lp_j)``. The largest
        exponent is subtracted first, so large log-probs do not overflow and very
        negative ones do not drive the normaliser to zero. No-op without children.
        """
        if not self.children:
            return
        scale = self._PRIOR_LOG_PROB_SCALE
        children = list(self.children.values())
        shift = max(scale * child.action_log_prob for child in children)
        weights = [math.exp(scale * child.action_log_prob - shift) for child in children]
        total = sum(weights)
        for child, weight in zip(children, weights):
            child.prior = weight / total

    def add_exploration_noise(self, dirichlet_alpha, exploration_fraction):
        """Blend Dirichlet noise into the children's priors.

        Each prior becomes ``(1 - exploration_fraction) * prior +
        exploration_fraction * noise`` with ``noise ~ Dir(dirichlet_alpha)`` over
        the current children. No-op when the node has no children.
        """
        actions = list(self.children.keys())
        if not actions:
            return
        noise = np.random.dirichlet([dirichlet_alpha] * len(actions))
        frac = exploration_fraction
        for a, n in zip(actions, noise):
            self.children[a].prior = self.children[a].prior * (1 - frac) + n * frac

    def show(self):
        """Plot the subtree rooted at this node with igraph + plotly.

        Vertices are numbered breadth-first with the root as 0; hovering shows
        value, reward, visit count, prior and the first row of the action.
        Requires the optional ``igraph`` and ``plotly`` packages.
        """
        from collections import deque
        from igraph import Graph
        import plotly.graph_objects as go

        #############
        # Create Tree
        #############
        g = Graph()
        g.add_vertices(1)  # add root
        hover_labels = ['Id: {} Value: {} Visit Count: {} prior: {}'.format(0,
                                                                            round(self.value(), 2),
                                                                            round(self.visit_count, 2),
                                                                            (None if self.prior is None
                                                                             else round(self.prior, 2)))]

        # Exactly one vertex per node, so vertex ids match the BFS numbering.
        total_nodes = 0
        unprocessed_nodes = deque((0, action_cap, child) for action_cap, child in self.children.items())
        while unprocessed_nodes:
            parent_idx, action_cap, node = unprocessed_nodes.popleft()
            total_nodes += 1
            child_idx = total_nodes
            g.add_vertices(1)
            g.add_edges([(parent_idx, child_idx)])
            hover_labels.append('Id: {} Value: {} reward:{} Visit Count: {} Prior: {} Action:{}'.
                                format(child_idx,
                                       round(node.value(), 2),
                                       round(node.reward, 2),
                                       round(node.visit_count, 2),
                                       None if node.prior is None else round(node.prior, 2),
                                       action_cap.action.detach().cpu().numpy()[0, 0]))
            unprocessed_nodes.extend((child_idx, _action_cap, _child)
                                     for _action_cap, _child in node.children.items())

        total_nodes += 1  # incr for root index
        lay = g.layout('tree', root=0)

        ##########
        # Plot
        ##########
        position = {k: lay[k] for k in range(total_nodes)}
        M = max(lay[k][1] for k in range(total_nodes))

        # Flip y so the root is drawn at the top.
        Xn = [position[k][0] for k in range(total_nodes)]
        Yn = [2 * M - position[k][1] for k in range(total_nodes)]
        Xe = []
        Ye = []
        for src, dst in (e.tuple for e in g.es):
            Xe += [position[src][0], position[dst][0], None]
            Ye += [2 * M - position[src][1], 2 * M - position[dst][1], None]

        v_label = list(map(str, range(total_nodes)))

        fig = go.Figure()

        def make_annotations(pos, text, font_size=10, font_color='rgb(250,250,250)'):
            """Build one plotly annotation per vertex, placing ``text[k]`` at ``pos[k]``."""
            if len(text) != len(pos):
                raise ValueError('The lists pos and text must have the same len')
            return [
                dict(
                    text=text[k],
                    x=pos[k][0], y=2 * M - pos[k][1],
                    xref='x1', yref='y1',
                    font=dict(color=font_color, size=font_size),
                    showarrow=False)
                for k in range(len(pos))
            ]

        fig.add_trace(go.Scatter(x=Xe,
                                 y=Ye,
                                 mode='lines',
                                 name='action',
                                 line=dict(color='rgb(210,210,210)', width=1),
                                 hoverinfo='none'
                                 ))
        fig.add_trace(go.Scatter(x=Xn,
                                 y=Yn,
                                 mode='markers',
                                 name='state',
                                 marker=dict(symbol='circle-dot',
                                             size=18,
                                             color='#6175c1',  # '#DB4551',
                                             line=dict(color='rgb(50,50,50)', width=1)
                                             ),
                                 text=hover_labels,
                                 hoverinfo='text',
                                 opacity=0.8
                                 ))

        axis = dict(showline=False,  # hide axis line, grid, ticklabels and  title
                    zeroline=False,
                    showgrid=True,
                    showticklabels=True,
                    )

        fig.update_layout(annotations=make_annotations(position, v_label),
                          font_size=12,
                          showlegend=True,
                          xaxis=axis,
                          yaxis=axis,
                          margin=dict(l=40, r=40, b=85, t=100),
                          hovermode='closest',
                          plot_bgcolor='rgb(248,248,248)'
                          )
        fig.show()


class MCTS(object):
    """Monte-Carlo tree search over a learned model, with optional progressive widening.

    Args:
        config: object providing ``mcts_cpw``, ``mcts_alpha``, ``pb_c_base``,
            ``pb_c_init`` and ``gamma``; plus ``root_dirichlet_alpha`` and
            ``root_exploration_fraction`` when ``exploration`` is set.
        exploration: add Dirichlet noise to the root's priors each time the root
            gains a child (only happens in progressive mode).
        progressive: in ``select_child``, add new children lazily (progressive
            widening) instead of always choosing among existing ones.
    """

    def __init__(self, config, exploration=False, progressive=False):
        self.config = config
        self.exploration = exploration
        self.progressive = progressive

    def run(self, root, model, num_simulations=50):
        """Run ``num_simulations`` simulations from an already-expanded ``root``.

        Each simulation selects down the tree to a leaf, queries ``model`` for
        the transition from the leaf's parent, value, reward and next action,
        expands the leaf, and backs the value up the search path.

        Raises:
            ValueError: if simulations are requested but ``root`` has no children
                (a leaf root has no parent hidden state to transition from).
        """
        if num_simulations > 0 and not root.expanded():
            raise ValueError('MCTS.run requires an expanded root node')
        min_max_stats = MinMaxStats()

        for _ in range(num_simulations):
            node = root
            search_path = [node]

            # Selection: descend through expanded nodes until a leaf is reached.
            while node.expanded():
                action_cap, node = self.select_child(model, node, min_max_stats)
                search_path.append(node)

            # Inside the search tree we use the dynamics function to obtain the next
            # hidden state given an action and the previous hidden state.
            parent = search_path[-2]
            belief, state = parent.hidden_state
            transition_output = model.transition(belief, state, action_cap.action, action_cap.repeat_one_hot)

            next_belief, next_state = transition_output.belief.squeeze(0), transition_output.prior_state.squeeze(0)
            actor_output = model.actor(next_belief, next_state)
            value = model.value(next_belief, next_state)
            action = model.actor.action_sample(actor_output, deterministic=False)
            actor_repeat_output = model.actor_repeat(next_belief, next_state, action)
            action_repeat_one_hot, action_repeat = model.actor_repeat.sample(actor_repeat_output, deterministic=False)
            reward = model.reward(next_belief, next_state)

            actor_dist = model.actor.action_dist(actor_output)
            action_log_prob = - actor_dist.entropy(action)

            # NOTE(review): ``progressive`` is not forwarded here, so leaves are
            # always expanded with sampled proposals even when self.progressive
            # is True — confirm this is intended.
            node.expand(model, actor_output, action_log_prob.item(), action.unsqueeze(0),
                        action_repeat.int().item(),
                        action_repeat_one_hot.unsqueeze(0),
                        (next_belief, next_state), reward.item())

            self.backpropagate(search_path, value.item(), min_max_stats)

    def select_child(self, model, node, min_max_stats):
        """Choose the edge of ``node`` to descend along.

        Without progressive widening, or once ``node`` has at least
        ``mcts_cpw * visit_count ** mcts_alpha`` children, this returns the child
        with the highest UCB score; ties go to the first such child in dict
        order. Otherwise it samples a new action, adds it as a child, refreshes
        the priors and returns it. Up to 100 resamples are tried to avoid an
        action that, rounded to 2 decimals, matches an existing child.

        Returns:
            ``(action_cap, child)`` for the selected edge.
        """
        widen_limit = self.config.mcts_cpw * (node.visit_count) ** self.config.mcts_alpha
        if not self.progressive or widen_limit <= len(node.children):
            # key= avoids comparing ActionCapsule / Node objects when scores tie.
            action_cap, child = max(node.children.items(),
                                    key=lambda item: self.ucb_score(node, item[1], min_max_stats))
            return action_cap, child

        belief, state = node.hidden_state
        actor_output = node.actor_output

        def rounded(arr):
            return np.around(arr.detach().cpu().numpy(), 2).tolist()

        # Children do not change while we resample, so round their actions once.
        existing_actions = [rounded(cap.action.squeeze(0)) for cap in node.children]

        # add an action to the node
        action = model.actor.action_sample(actor_output, deterministic=False)
        action_sample_attempt = 0
        while action_sample_attempt < 100 and rounded(action) in existing_actions:
            action = model.actor.action_sample(actor_output, deterministic=False)
            action_sample_attempt += 1

        # Warn only when the retries were exhausted and the action is still a duplicate.
        if rounded(action) in existing_actions:
            print('Something is wrong in sampling')

        actor_repeat_output = model.actor_repeat(belief, state, action)
        action_repeat_one_hot, action_repeat = model.actor_repeat.sample(actor_repeat_output, deterministic=False)
        action_cap = ActionCapsule(action.unsqueeze(0), action_repeat.int().item(),
                                   action_repeat_one_hot.unsqueeze(0))

        actor_dist = model.actor.action_dist(actor_output)
        action_log_prob = - actor_dist.entropy(action)

        child = Node(action_log_prob.item())
        node.children[action_cap] = child
        node.update_prior()

        if self.exploration and node.root:
            node.add_exploration_noise(self.config.root_dirichlet_alpha, self.config.root_exploration_fraction)

        return action_cap, child

    def ucb_score(self, parent, child, min_max_stats) -> float:
        """pUCT score of ``child``: prior-weighted exploration bonus plus normalised value.

        The exploration weight is
        ``(log((N + pb_c_base + 1) / pb_c_base) + pb_c_init) * sqrt(N) / (n + 1)``,
        where ``N`` and ``n`` are the parent and child visit counts.
        """
        pb_c = math.log((parent.visit_count + self.config.pb_c_base + 1) / self.config.pb_c_base)
        pb_c += self.config.pb_c_init
        pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

        prior_score = pb_c * child.prior
        # NOTE(review): MuZero's reference pseudocode uses
        # child.reward + discount * normalize(child.value()) here (and 0 for
        # unvisited children); this uses the normalised value alone — confirm.
        value_score = min_max_stats.normalize(child.value())

        return prior_score + value_score

    def backpropagate(self, search_path, value, min_max_stats):
        """Back a leaf value up ``search_path``, walking from the leaf to the root.

        Each node accumulates the discounted return from its own state. Moving
        to its parent, that return becomes ``node.reward + gamma * value``.
        """
        for node in reversed(search_path):
            node.value_sum += value
            node.visit_count += 1
            min_max_stats.update(node.value())

            value = node.reward + self.config.gamma * value
