

from queue import Queue
import numpy as np
from fsc import FiniteMemoryPolicy
from interval_models import IPOMDP
from models import POMDPWrapper

import utils
import math

class TreeNode:
    """
    Common base for the nodes of the search tree.

    Successors are kept in the ``children`` dict. The node itself supports
    indexing, item assignment and ``in`` as shortcuts for that dict.
    """

    def __init__(self):
        self.children = dict()

    def __contains__(self, key):
        """Tell whether a child is stored under ``key``."""
        return key in self.children

    def __setitem__(self, key, value):
        """Store ``value`` as the child for ``key``, replacing any previous one."""
        self.children[key] = value

    def __getitem__(self, key):
        """Return the child stored under ``key``, or None when there is none."""
        return self.children.get(key)

class QNode(TreeNode):
    """
    The action nodes of the tree.

    Besides the visit counter and the value estimate, a node stores the
    ``prob`` and ``parent`` it was created with and a ``children`` dict
    that is keyed by observation.
    """

    def __init__(self, num_visits, value, prob=1, parent=None):
        " Action node "
        self.children = {}  # o -> VNode
        self.prob = prob
        self.parent = parent
        self.value = value
        self.num_visits = num_visits

    def __str__(self):
        # Visit count, value and the observations that already have a child.
        fields = (self.num_visits, self.value, str(self.children.keys()))
        return "QNode" + "(%.3f, %.3f | %s)" % fields

    def __repr__(self):
        # The debug representation is the same text as str().
        return self.__str__()

class VNode(TreeNode):
    """
    The observation nodes of the tree.

    Children are stored in ``children``, keyed by action. Every method
    below reads a ``value`` attribute from the children.
    """
    def __init__(self, num_visits, parent : QNode = None):
        " Observation / history node "
        self.num_visits = num_visits
        self.parent = parent
        self.children = {} # a -> QNode

    def __str__(self):
        # A node without children has no value yet (the ``value`` property
        # raises for it), so show nan instead of failing while printing.
        shown_value = self.value if self.children else float("nan")
        return ("VNode") + "(%.3f, %.3f | %s)" % (self.num_visits,
                                                  shown_value,
                                                  str(self.children.keys()))
    def __repr__(self):
        return self.__str__()

    def print_children_value(self):
        """Print one line per child with its action and value."""
        for action, child in self.children.items():
            print("   action %s: %.3f" % (str(action), child.value))

    def argmin(self):
        """argmin(VNode self)
        Returns the action of the child with lowest value.

        Returns None when there are no children or when no child value
        compares below +inf. On ties the first such child wins."""
        best_value = float("inf")
        best_action = None
        for action, child in self.children.items():
            if child.value < best_value:
                best_action = action
                best_value = child.value
        return best_action

    def argmax(self):
        """argmax(VNode self)
        Returns the action of the child with highest value.

        Returns None when there are no children or when no child value
        compares above -inf. On ties the first such child wins."""
        best_value = float("-inf")
        best_action = None
        for action, child in self.children.items():
            if child.value > best_value:
                best_action = action
                best_value = child.value
        return best_action

    @property
    def value(self):
        """Highest value among the children.

        Raises:
            ValueError: if the node has no children.
        """
        if not self.children:
            # max() of an empty iterable raises ValueError as well; raise it
            # explicitly so that the message names the actual problem.
            raise ValueError("VNode has no children, so its value is undefined")
        return max(child.value for child in self.children.values())


class RootVNode(VNode):
    """
    The root node of the search tree.

    In addition to the VNode attributes it stores the ``history`` it was
    created with.
    """

    def __init__(self, num_visits, history):
        super().__init__(num_visits)
        self.history = history

    @classmethod
    def from_vnode(cls, vnode, history):
        """Turn ``vnode`` into a root node for ``history``.

        The visit count is copied. The ``children`` dict is shared with
        ``vnode``, not copied.
        """
        # NOTE(review): this always builds a RootVNode; ``cls`` is not used.
        root = RootVNode(vnode.num_visits, history)
        root.children = vnode.children
        return root


class PolicyTree:
        
    @staticmethod
    def belief_update(belief, action, observation, pomdp, T):
        # NEW
        next_belief = np.zeros_like(belief)
        for s in np.where(belief > 0)[0]:
            for next_s, prob in T[s, action].items():
            # for next_s in next_possible_states:
                if pomdp.O[next_s] == observation:
                # if next_s in ipomdp.T[s, actions[b, l]]: 
                    next_belief[next_s] += prob * belief[s]

        if not math.isclose(next_belief.sum(), 1):
            next_belief = utils.normalize(next_belief)

        return np.array(next_belief)
    
    def eqCP(n1, n2, phi, psi):
        if phi[n1] != phi[n2]:
            return False
        if n1 not in psi: # added in addition to pseudo-code
            return True # added in addition to pseudo-code
        for o, j in psi[n1].items():
            if o not in psi[n2]:
                return False
            if not PolicyTree.eqCP(psi[n1][o], psi[n2][o], phi, psi):
                return False
        return True

    
    def qmdp_to_fsc_from_pseudo(**kwargs):       
        """
        Build the QMDP policy tree and merge nodes that eqCP reports as equivalent.

        All keyword arguments are forwarded unchanged to
        PolicyTree.create_policy_tree.

        Returns:
            The tuple (N, phi, psi) after merging: N is the set of node ids
            that were not removed, phi maps node id -> action and psi maps
            node id -> {observation -> successor node id}. The phi and psi
            entries of removed nodes are deleted.
        """
        N, phi, psi = PolicyTree.create_policy_tree(**kwargs)
        print("FSC from QMDP with", len(N), "memory nodes after creating policy tree.")
        # Ids dropped while merging; their phi/psi entries are deleted after the loops.
        removed = set()
        # Both loops walk a sorted snapshot of N: the outer snapshot is taken
        # once, the inner one is taken again for every i.
        for i in sorted(list(N)):
            for j in sorted(list(N)):
                # Only compare i against nodes with a smaller id.
                if not j < i:
                    continue
                if PolicyTree.eqCP(i, j, phi, psi):
                    # NOTE(review): the pseudo-code removes "descendents(n_i)"; only i
                    # and its direct successors are collected here. Confirm that this
                    # is the intended meaning, and that a successor which another kept
                    # node also points to may be removed.
                    descendants_to_remove = {i}.union(psi[i].values())
                    N -= descendants_to_remove
                    removed = removed.union(descendants_to_remove)
                    # Redirect every edge that pointed at i so that it points at j.
                    for n, next_dict in psi.items():
                        for o, n_ in next_dict.items():
                            if n_ == i:
                                psi[n][o] = j
        print(N, removed)
        # Drop the bookkeeping of all removed nodes.
        for n in sorted(list(removed)):
            del phi[n]
            del psi[n]
        return N, phi, psi
    
    @staticmethod
    def from_moore_to_mealy(moore_act_fun : dict[int, int], moore_upd_fun : dict[int, dict[int, int]], num_nodes, num_obs):
        mealy_act_fun, mealy_upd_fun = {}, {}
        assert len(set(list(moore_act_fun.keys())) - (set(list(moore_upd_fun.keys())))) == 0, (moore_act_fun.keys(), moore_upd_fun.keys())
        # for n in range(num_nodes):
        for n, moore_act in moore_act_fun.items():
            # print(n, ":", moore_act)
            mealy_act_fun[n] = {}
            mealy_upd_fun[n] = {}
            for o, next_n in moore_upd_fun[n].items():
                # print(n, "x", o, ':', next_n)
                mealy_act_fun[n][o] = moore_act
                mealy_upd_fun[n][o] = next_n

        return mealy_act_fun, mealy_upd_fun

        
    def create_policy_tree(ipomdp : IPOMDP, pomdp : POMDPWrapper, T, mdp_q_values : np.array, tau, nan_fixer, depth = 3, batch_dim = None, deterministic = False, length = None, T_has_memory_dep = True):
        """
        Unroll the QMDP policy into a tree of beliefs, breadth first.

        The root holds the point belief on ``pomdp.initial_state``. Every
        node picks the action given by the argmin of ``belief . Q``, where Q
        is ``mdp_q_values`` with NaNs replaced by ``nan_fixer`` (presumably
        of shape nS x nA). Nodes whose level is below ``depth`` get
        successors; a node at level ``depth`` is not expanded.

        Args:
            ipomdp: only ``ipomdp.nS`` is read (length of the belief vector).
            pomdp: provides ``initial_state``, ``O`` and ``policy_mask``.
            T: indexed by ``(state, action)``; the result must offer
                ``keys()`` and ``items()``.
            mdp_q_values: Q-values, see above.
            nan_fixer: replacement for NaN entries in ``mdp_q_values``.
            depth: level at which expansion stops.
            tau, batch_dim, deterministic, length, T_has_memory_dep: not
                used in this function.

        Returns:
            (N, phi, psi): the set of node ids, the mapping node id -> action
            and the mapping node id -> {observation -> successor node id}.
        """
        N = set()
        next_id = 0
        frontier = Queue()
        root_belief = np.zeros((ipomdp.nS))
        root_belief[pomdp.initial_state] = 1
        root_obs = pomdp.O[pomdp.initial_state].item()
        # Queue entries are (belief, level, node id, observation of the node).
        frontier.put((root_belief, 0, next_id, root_obs))

        Q = np.nan_to_num(mdp_q_values, nan=nan_fixer) # nS x nA

        phi = {}
        psi = {}

        while not frontier.empty():
            belief, level, node, obs = frontier.get()
            N.add(node)
            successors = {}
            psi[node] = successors
            action = np.dot(belief[np.newaxis, :], Q).argmin()
            phi[node] = action

            # The chosen action must be allowed for the observation of this node ...
            assert bool(pomdp.policy_mask[obs][action])
            assert bool(pomdp.policy_mask[obs][action]) == bool(pomdp.policy_mask[obs,action])

            # ... and for the observation of every state in the belief support.
            support = np.where(belief > 0)[0]
            for s in support:
                assert pomdp.policy_mask[pomdp.O[s], action] == 1, (s, pomdp.O[s], action, pomdp.policy_mask[pomdp.O[s], action])

            if level == depth:
                continue

            # Every (state, successor) pair gets a fresh node id, even when its
            # observation already has a successor; the latest id is the one
            # that stays in psi.
            for s in support:
                for next_s in T[(s, action)].keys():
                    next_obs = pomdp.O[next_s]
                    next_id += 1
                    child_belief = PolicyTree.belief_update(belief, action, next_obs, pomdp, T)
                    frontier.put((child_belief, level + 1, next_id, next_obs))
                    successors[next_obs] = next_id

        return N, phi, psi
    
    def qmdp_to_fsc(ipomdp : IPOMDP, pomdp : POMDPWrapper, **kwargs):
        """
        Turn the merged QMDP policy tree into a FiniteMemoryPolicy.

        ``ipomdp``, ``pomdp`` and all further keyword arguments are handed
        to PolicyTree.qmdp_to_fsc_from_pseudo. The remaining node ids are
        renumbered 0..len(N)-1 in sorted order. For a node n and an
        observation o with successor n', the action distribution of (n, o)
        is the point distribution on the action of n', and the next memory
        is the new index of n'. Entries that are never written keep the
        uniform action distribution and next memory 0.

        Raises:
            ValueError: if some action distribution does not sum to
                (approximately) 1.
        """
        N, phi, psi = PolicyTree.qmdp_to_fsc_from_pseudo(ipomdp=ipomdp, pomdp=pomdp, **kwargs)
        num_nodes = len(N)
        print("FSC from QMDP with", num_nodes, "memory nodes before minimizing.")
        action_distributions = np.ones((num_nodes, pomdp.nO, pomdp.nA)) / pomdp.nA
        next_memories = np.full((num_nodes, pomdp.nO), 0)

        # old node id -> new, dense node index
        mapping = dict(zip(sorted(N), range(num_nodes)))

        for n, successors in psi.items():
            if successors == {}:
                # NOTE(review): next_memories is created from the integer fill
                # value 0, so this float is presumably truncated on assignment.
                # Confirm what nodes without successors should map to.
                next_memories[mapping[n], :] = 1 / pomdp.nO
                continue
            for o, next_n in successors.items():
                a = phi[next_n]
                allowed = bool(pomdp.policy_mask[o, a])
                if not allowed:
                    # The action of the successor is expected to be allowed for o.
                    # The fallback below only runs when asserts are disabled.
                    assert False
                    action_distributions[mapping[n], o] = utils.normalize(pomdp.policy_mask[o].copy())
                else:
                    row = mapping[n]
                    action_distributions[row, o, :] = 0
                    action_distributions[row, o, a] = 1
                    assert np.isclose(action_distributions[row, o].sum(), 1), action_distributions[row, o]
                assert action_distributions[mapping[n]].shape == pomdp.policy_mask.shape
                next_memories[mapping[n], o] = mapping[next_n]

        sums = action_distributions.sum(axis=-1)
        if not np.isclose(sums, 1).all():
            raise ValueError(f'Distributions do not sum up to (close to) 0, 1, or are NaN. Sums are: \n{sums}')

        return FiniteMemoryPolicy(
            action_distributions, next_memories,
            make_greedy = False, reshape = True,
            initial_observation = pomdp.initial_observation)
    
    def forward_search(ipomdp : IPOMDP, pomdp : POMDPWrapper, T : np.ndarray, mdp_q_values : np.array, tau, nan_fixer, label_to_reach, batch_dim = None, deterministic = False, length = None, T_has_memory_dep = True):
        state = pomdp.initial_state
        observation = pomdp.initial_observation
        
        assert pomdp.O[state] == pomdp.initial_observation, (pomdp.O[state], pomdp.initial_observation)