from __future__ import annotations
from dataclasses import dataclass
import time

from numpy.core.multiarray import array as array
from tqdm import tqdm
from fsc import FiniteMemoryPolicy
from interval_models import IPOMDP
from models import POMDPWrapper

import numpy as np
import random
import math

from policytree import QNode, RootVNode, VNode
import utils

from queue import Queue

class PolicySimulation:
    """
    Base class for rolling out a belief-based policy on a transition model.

    Subclasses implement `policy` (action selection from the current belief)
    and may override `update` (a per-step hook called after the environment
    has advanced).

    Attributes read by the simulation:
        ipomdp: provides `R`, `state_action_rewards`, `nS` and
            `pPOMDP.O` / `pPOMDP.observation_labels`.
        pomdp: provides `nS`, `nA`, `num_reward_models`, `initial_state` and
            `initial_observation`.
        T: indexed as `T[(state, action)]`; each entry is used as a mapping
            `next_state -> probability` (annotated as `np.ndarray`, but it is
            only ever accessed like a dict here).
        Q: `mdp_q_values` with NaNs replaced by `nan_fixer` (nS x nA).
    """

    def __init__(self, ipomdp : IPOMDP,
                        pomdp : POMDPWrapper,
                        T : np.ndarray,
                        mdp_q_values : np.array,
                        tau : float,
                        nan_fixer : float,
                        label_to_reach : str) -> None:
        self.ipomdp = ipomdp
        self.pomdp = pomdp
        self.T = T
        self.mdp_q_values = mdp_q_values
        self.tau = tau
        self.nan_fixer = nan_fixer
        self.label_to_reach = label_to_reach
        self.Q = np.nan_to_num(mdp_q_values, nan=nan_fixer) # nS x nA

    def policy(self, belief, batch_done, deterministic, batch_dim, observation):
        """
        Return `(policy, action)` for every batch entry; must be overridden.

        The simulation indexes both return values with a boolean batch mask,
        so they are expected to be arrays whose first axis is the batch.
        """
        raise ValueError("Must be implemented in class implementations!")

    def update(self, actions, observations, batch_done):
        """
        Hook called once per step with the chosen actions and the observations
        of the sampled successor states. The default does nothing.
        """
        pass

    def simulate_on_current_POMDP(self, batch_dim = None, deterministic = False, length = None, T_has_intervals = True, T_has_memory_dep = True, empirical_rewards_only = False, store_beliefs = False):
        """
        Simulate `batch_dim` trajectories of at most `length` steps.

        A trajectory is marked done once the observation of its sampled
        successor state carries `label_to_reach`; the loop stops early when
        every trajectory is done.

        Args:
            batch_dim: number of trajectories simulated in parallel.
            deterministic: forwarded to `policy`.
            length: maximum number of steps.
            T_has_intervals: unused by this method.
            T_has_memory_dep: unused by this method.
            empirical_rewards_only: when True only `rewards` and `dones` are
                recorded; every other returned entry is None.
            store_beliefs: when True (and not `empirical_rewards_only`) the
                belief at every step is recorded.

        Returns:
            `(beliefs, states, observations, policies, actions, rewards, dones)`.
            `dones[:, l]` holds the done flags at the start of step `l` and
            stays True for steps that were never executed. Unrecorded
            observation slots keep the value -1.

        Raises:
            KeyError: if a sampled `(state, action)` pair is missing from `T`.
        """
        record_trajectory = not empirical_rewards_only

        rewards = np.zeros((batch_dim, length, self.pomdp.num_reward_models), dtype = 'float32')
        # Trajectory buffers only exist when they are going to be returned.
        # The loop below guards every write, so the rewards-only mode no longer
        # touches unallocated arrays.
        beliefs = states = observations = policies = actions = None
        if record_trajectory:
            policies = np.zeros((batch_dim, length, self.pomdp.nA), dtype = 'float32')
            actions = np.zeros((batch_dim, length), dtype = 'int32')
            if store_beliefs:
                beliefs = np.zeros((batch_dim, length, self.ipomdp.nS))
            states = np.zeros((batch_dim, length), dtype = 'int32')
            observations = np.zeros((batch_dim, length), dtype = 'int32') - 1

        dones = np.ones((batch_dim, length), dtype=bool)

        state = np.array([np.squeeze(self.pomdp.initial_state) for b in range(batch_dim)], dtype = 'int32')
        observation = np.array([np.squeeze(self.pomdp.initial_observation) for b in range(batch_dim)], dtype = 'int32')

        belief = np.zeros((batch_dim, self.pomdp.nS))
        belief[:, self.pomdp.initial_state] = 1

        batch_done = np.zeros((batch_dim), bool)

        observation_of = self.ipomdp.pPOMDP.O
        labels_of = self.ipomdp.pPOMDP.observation_labels

        for l in range(length):
            dones[:, l] = batch_done

            # Snapshot of the trajectories that are still running at the start
            # of this step; `batch_done` itself is only modified further down.
            active = ~batch_done
            active_indices = active.nonzero()[0]

            if record_trajectory:
                if store_beliefs:
                    beliefs[active, l] = belief[active]
                observations[active, l] = observation[active]
                states[active, l] = state[active]

            policy, action = self.policy(belief, batch_done, deterministic, batch_dim, observation)

            if record_trajectory:
                policies[active, l] = policy[active]
                actions[active, l] = action[active]

            if self.ipomdp.state_action_rewards:
                step_rewards = self.ipomdp.R[state[active], action[active]]
            else:
                step_rewards = self.ipomdp.R[state[active]]
            rewards[active, l, :] = step_rewards[..., np.newaxis]

            # Sample the successor state of every running trajectory.
            for b in active_indices:
                try:
                    transitions = self.T[(state[b], action[b])]
                except KeyError as err:
                    # Raise instead of print + exit() so callers see the failure.
                    raise KeyError(f"{state[b]} {action[b]} is not in T!") from err
                possible_states = list(transitions.keys())
                probs = list(transitions.values())
                state[b] = random.choices(possible_states, weights=probs, k=1)[0]
            observation = observation_of[state]

            self.update(actions=action, observations=observation, batch_done=batch_done)

            # Belief update: push the belief through T under the chosen action
            # and keep only successors consistent with the new observation.
            next_belief = np.zeros((batch_dim, self.ipomdp.nS))
            for b in active_indices:
                if self.label_to_reach in labels_of[observation[b]]:
                    batch_done[b] = True
                for s in np.where(belief[b] > 0)[0]:
                    # `action[b]` is the action just taken; the recorded
                    # `actions` array may not exist in rewards-only mode.
                    for next_s, prob in self.T[s, action[b]].items():
                        if observation_of[next_s] == observation[b]:
                            next_belief[b, next_s] += prob * belief[b, s]

                if not math.isclose(next_belief[b].sum(), 1):
                    next_belief[b] = utils.normalize(next_belief[b])
            belief = next_belief

            if batch_done.all():
                break

        return beliefs, states, observations, policies, actions, rewards, dones

class QMDP(PolicySimulation):
    """
    Action selection from belief-weighted MDP Q-values (`self.Q`, nS x nA).

    Lower Q-values are preferred: every branch below minimises.
    """

    def policy(self, belief, batch_done, deterministic, batch_dim, observation):
        """
        Return `(policy, action)` for the whole batch.

        Three modes:
          * `deterministic`: one-hot policy on the argmin of `belief @ Q`.
          * `tau > 0`: distribution from `utils.nan_soft_max_norm` over
            `belief @ Q`; the action is drawn by `utils.choice_from_md` using
            the policy mask of the current observation.
          * otherwise: every state in the belief support votes, with its
            belief mass, for its own argmin action; the action is drawn from
            the resulting distribution by `utils.choice_from_md`.

        The policy rows of all unfinished batch entries must sum to one.
        """
        if deterministic:
            expected_q = belief @ self.Q # B x nA
            action = np.argmin(expected_q, axis=-1)
            policy = np.zeros((batch_dim, self.ipomdp.pPOMDP.nA))
            policy[np.arange(batch_dim), action] = 1
        elif self.tau > 0:
            expected_q = belief @ self.Q # B x nA
            soft_policy = utils.nan_soft_max_norm(expected_q, minimize=True, axis=-1, tau=self.tau)
            policy = np.nan_to_num(soft_policy, nan=0.0)
            action_mask = self.pomdp.policy_mask[observation]
            action = utils.choice_from_md(policy.copy(), batch_dim, mask=action_mask)
        else:
            policy = np.zeros((batch_dim, self.pomdp.nA), dtype=float)
            for b in np.flatnonzero(~batch_done):
                for s in np.flatnonzero(belief[b]):
                    greedy_action = np.argmin(self.Q[s], axis=-1)
                    assert np.isfinite(greedy_action).all(), greedy_action
                    assert np.isfinite(belief[b][s]).all(), belief[b][s]
                    policy[b, greedy_action] += belief[b, s]
            action_mask = self.pomdp.policy_mask[observation]
            action = utils.choice_from_md(policy.copy(), batch_dim, mask=action_mask)
            assert np.isfinite(action).all(), action
            assert np.isfinite(policy).all(), policy

        row_sums = np.sum(policy[~batch_done], axis=-1)
        assert np.allclose(row_sums, 1), policy

        return policy, action

class FIB(PolicySimulation):
    """
    Action selection from Fast Informed Bound (FIB) alpha-vectors.

    `self.alphas` (nS x nA) is computed once at construction. Entries for
    actions that the policy mask disallows under the state's observation are
    replaced by `nan_fixer`. Lower values are preferred: every branch of
    `policy` minimises.
    """

    def __init__(self, ipomdp: IPOMDP, pomdp: POMDPWrapper, T: np.ndarray, mdp_q_values: np.array, tau: float, nan_fixer: float, label_to_reach: str) -> None:
        super().__init__(ipomdp, pomdp, T, mdp_q_values, tau, nan_fixer, label_to_reach)

        alphas = self.compute_FIB(ipomdp, pomdp, T)
        # Mark disallowed (state, action) pairs as NaN so that they all receive
        # the same `nan_fixer` value below.
        for s in range(pomdp.nS):
            for a in range(pomdp.nA):
                if not pomdp.policy_mask[pomdp.O[s], a]:
                    alphas[s, a] = np.nan

        self.alphas = np.nan_to_num(alphas, nan=nan_fixer) # nS x nA

    def get_value(self, state):
        """Return the smallest alpha value over all actions in `state`."""
        return np.nanmin(self.alphas[state])

    @staticmethod
    def compute_FIB(ipomdp : IPOMDP, pomdp : POMDPWrapper, T, max_iterations = 10_000, tolerance=1e-5, discount=1, verbose=False):
        """
        Iterate the Fast Informed Bound update until convergence.

        For every `(s, a)` in `T`:

            alpha[s, a] = r(s, a)
                          + discount * sum_o min_{a'} sum_{s' : O[s'] = o} T[s, a][s'] * alpha_old[s', a']

        where `a'` ranges over the actions allowed by `pomdp.policy_mask[o]`.
        The minimum is taken per observation, not per successor state; taking
        it per successor state would reduce the update to the QMDP backup.

        Args:
            ipomdp: provides `R` and `state_action_rewards`.
            pomdp: provides `nS`, `nA`, `O`, `policy_mask`, `observation_labels`.
            T: mapping `(s, a) -> {next_s: probability}`.
            max_iterations: upper bound on the number of sweeps.
            tolerance: stop once the largest absolute change in a sweep is
                below this value.
            discount: discount factor applied to the successor term.
            verbose: print the residual after every sweep.

        Returns:
            Array of shape (nS, nA); entries for pairs absent from `T` stay 0.
        """
        nS = pomdp.nS
        nA = pomdp.nA
        alphas = np.zeros((nS, nA), dtype=float)

        for i in range(max_iterations):

            old_alphas = alphas.copy()
            residual = 0.0

            for (s, a), next_s_dict in T.items():

                r = ipomdp.R[s, a] if ipomdp.state_action_rewards else ipomdp.R[s]
                if "goal" in pomdp.observation_labels[pomdp.O[s]]:
                    assert r == 0

                # Accumulate prob * alpha_old[next_s, :] per observation so
                # that the minimisation sees all successors sharing it.
                weighted_by_observation = {}
                for next_s, prob in next_s_dict.items():
                    o = int(pomdp.O[next_s])
                    contribution = prob * old_alphas[next_s]
                    if o in weighted_by_observation:
                        weighted_by_observation[o] += contribution
                    else:
                        weighted_by_observation[o] = contribution

                o_sum = 0.0
                for o, weighted in weighted_by_observation.items():
                    allowed = pomdp.policy_mask[o].nonzero()[0]
                    # No allowed action leaves the minimum at +inf, which the
                    # finiteness assertion below reports.
                    o_sum += weighted[allowed].min() if allowed.size else np.inf

                alphas[s, a] = r + discount * o_sum
                alpha_diff = abs(alphas[s, a] - old_alphas[s, a])
                residual = max(alpha_diff, residual)

                # Only this entry changed, so checking it alone is sufficient.
                assert np.isfinite(alphas[s, a]), (r, o_sum)

            if verbose: print(i, residual)
            if residual < tolerance: break
        return alphas

    def policy(self, belief, batch_done, deterministic, batch_dim, observation):
        """
        Return `(policy, action)` for the whole batch.

        Three modes:
          * `deterministic`: one-hot policy on the argmin of `belief @ alphas`.
          * `tau > 0`: distribution from `utils.nan_soft_max_norm` over
            `belief @ alphas`; the action is drawn by `utils.choice_from_md`
            using the policy mask of the current observation.
          * otherwise: every state in the belief support votes, with its
            belief mass, for its own argmin action; the action is drawn from
            the resulting distribution by `utils.choice_from_md`.

        The policy rows of all unfinished batch entries must sum to one.
        """
        if deterministic:
            FIB = belief @ self.alphas # B x nA
            action = np.argmin(FIB, axis=-1)
            policy = np.zeros((batch_dim, self.ipomdp.pPOMDP.nA))
            policy[np.arange(batch_dim), action] = 1
        elif self.tau > 0:
            FIB = belief @ self.alphas # B x nA
            policy = np.nan_to_num(utils.nan_soft_max_norm(FIB, minimize=True, axis=-1, tau=self.tau), nan=0.0)
            action = utils.choice_from_md(policy.copy(), batch_dim, mask=self.pomdp.policy_mask[observation])
        else:
            policy = np.zeros((batch_dim, self.pomdp.nA), dtype=float)
            for b in (~batch_done).nonzero()[0]:
                for s in belief[b].nonzero()[0]:
                    a = self.alphas[s].argmin(axis=-1)
                    assert np.isfinite(belief[b][s]).all(), belief[b][s]
                    policy[b][a] += belief[b][s]
            action = utils.choice_from_md(policy.copy(), batch_dim, mask=self.pomdp.policy_mask[observation])
            assert np.isfinite(action).all(), action
            assert np.isfinite(policy).all(), policy

        assert np.allclose(np.sum(policy[~batch_done], axis=-1), 1), policy

        return policy, action
    

class POUCT(PolicySimulation):
    """
    Online planner based on Partially Observable UCT (Monte-Carlo tree search).

    One search tree is kept per batch entry in `self.roots`. The tree search
    maximises (UCB selection, `argmax` at the root), so `step` returns the
    negated model reward; the other policies in this file minimise that same
    reward signal.
    """

    def __init__(self, ipomdp: IPOMDP, pomdp: POMDPWrapper, T: np.ndarray, mdp_q_values: np.array, tau: float, nan_fixer: float, label_to_reach: str) -> None:
        super().__init__(ipomdp, pomdp, T, mdp_q_values, tau, nan_fixer, label_to_reach)
        self.c = math.sqrt(2) # hyperparameter, set to the standard value
        self.max_depth = 25 # hyperparameter, set arbitrarily
        self.roots = None # one search tree per batch entry, allocated lazily in `policy`

    def policy(self, belief, batch_done, deterministic, batch_dim, observation):
        """
        Choose an action for every unfinished batch entry.

        If exactly one action is allowed under the current observation it is
        taken directly; otherwise a tree search is run from the entry's belief.

        Returns:
            `(policy, action)` where `policy` is one-hot on the chosen action
            for unfinished entries and all-zero for finished ones.

        Raises:
            NotImplementedError: if `deterministic` is False and at least one
                batch entry is unfinished.
        """
        # (Re)allocate the trees when missing or when the batch size changed.
        if getattr(self, 'roots', None) is None or len(self.roots) != batch_dim:
            self.roots = [None] * batch_dim
        action = np.zeros(batch_dim, dtype=int)
        policy = np.zeros((batch_dim, self.pomdp.nA), dtype=float)
        for b in (~batch_done).nonzero()[0]:
            if not deterministic:
                # Fail before spending time on a search whose result is unusable.
                raise NotImplementedError()
            available_actions = self.pomdp.policy_mask[observation[b]].nonzero()[0]
            if available_actions.size == 1:
                action[b] = available_actions.item()
            else:
                action[b] = self.search(belief[b], batch_index=b)
            policy[b, action[b]] = 1
        return policy, action

    def update(self, actions, observations, batch_done):
        """
        Advance each unfinished entry's tree to the subtree matching the
        executed action and received observation, or drop the tree if that
        subtree does not exist.
        """
        if getattr(self, 'roots', None) is None:
            return
        for b in (~batch_done).nonzero()[0]:
            action = actions[b]
            observation = observations[b]
            root = self.roots[b]
            if root is not None and action in root and observation in root[action]:
                assert root[action][observation] is not None
                self.roots[b] = RootVNode.from_vnode(root[action][observation], None)
            else:
                self.roots[b] = None

    def search(self, belief : np.ndarray, batch_index : int, max_simulations = 1000, max_seconds = 1.0):
        """
        Run up to `max_simulations` simulations (or until `max_seconds` have
        elapsed) from states sampled from `belief`, then return the root's
        `argmax()`.
        """
        sampled_states = random.choices(range(self.pomdp.nS), belief, k=max_simulations)

        # Monotonic clock: unaffected by system clock adjustments.
        start = time.monotonic()

        for s in sampled_states:
            # The root is looked up on every iteration because the first
            # simulation creates it.
            self.simulate(s, node=self.roots[batch_index], parent=None, batch_index=batch_index)

            if time.monotonic() - start > max_seconds:
                break

        return self.roots[batch_index].argmax()

    def step(self, s, a):
        """
        Sample one transition from `(s, a)`.

        While the reached state is non-terminal and has exactly one allowed
        action, that action is taken immediately and its reward is added
        (undiscounted) to this step's reward.

        Returns:
            `(next_state, observation, negated_reward, terminal)`.

        Raises:
            KeyError: if `(s, a)` is missing from `T`.
        """
        try:
            transitions = self.T[(s, a)]
        except KeyError as err:
            # Raise instead of print + exit() so callers see the failure.
            raise KeyError(f"{s} {a} is not in T!") from err
        possible_states, probs = zip(*transitions.items())
        s_ = random.choices(possible_states, weights=probs, k=1)[0]
        observation = self.ipomdp.pPOMDP.O[s_]
        reward = self.ipomdp.R[s, a] if self.ipomdp.state_action_rewards else self.ipomdp.R[s]
        terminal_state = self.label_to_reach in self.ipomdp.pPOMDP.observation_labels[observation]
        if np.count_nonzero(self.pomdp.policy_mask[observation]) == 1 and not terminal_state:
            s_, observation, c, terminal_state = self.step(s_, self.pomdp.policy_mask[observation].nonzero()[0].item())
            # `c` is already negated by the recursive call; subtracting it
            # accumulates the follow-up reward with the correct sign.
            reward -= c
        return s_, observation, -reward, terminal_state

    def simulate(self, state, node : VNode, parent, batch_index : int, depth = 0, gamma=0.99):
        """
        Run one simulation from `state` through the tree rooted at `node`.

        A missing node is created, attached to `parent` (or stored as the root
        of this batch entry), expanded, and evaluated with a rollout. Existing
        nodes select the child with the highest UCB score, recurse, and update
        the visit counts and the running-mean value of the chosen child.

        Returns:
            The discounted return of this simulation from `state` onwards.
        """
        if depth > self.max_depth:
            return 0

        observation = self.pomdp.O[state]

        if node is None:
            if self.roots[batch_index] is None:
                node = self._construct_VNode(root=True)
                self.roots[batch_index] = node
            else:
                node = self._construct_VNode()

            if parent is not None:
                parent[observation] = node

            self._expand_VNode(node, mask=self.pomdp.policy_mask[observation].astype(bool))
            return self.rollout(state, depth)

        # `argmax` yields a position in the list of children; map it back to
        # the action key, since the children need not be the actions 0..k-1.
        candidate_actions = list(node.children)
        ucb_scores = [self.UCB(node, node[child]) for child in candidate_actions]
        action = candidate_actions[int(np.argmax(ucb_scores))]

        next_state, next_observation, c, done = self.step(state, action)

        if done:
            total_reward = c
        else:
            total_reward = c + gamma * self.simulate(next_state, node[action][next_observation], node[action], batch_index, depth=depth+1, gamma=gamma)

        node.num_visits += 1
        node[action].num_visits += 1
        node[action].value = node[action].value + ((total_reward - node[action].value) / (node[action].num_visits))

        return total_reward

    def rollout(self, state, depth : int, discount = 0.99, use_Q=True):
        """
        Rollout function to determine the value of a new node.

        Follows the greedy (argmin) action of `self.Q` when `use_Q` is set,
        otherwise a uniformly random allowed action, until `max_depth` or a
        terminal state is reached. The k-th reward is weighted by
        `discount ** k`, starting at k = 0.
        """
        if depth == 0: # Rollouts at the first layer are worthless as the result is discarded.
            return 0
        total_discounted_reward = 0
        running_discount = 1.0
        done = False
        while depth < self.max_depth and not done:
            if use_Q:
                action = np.argmin(self.Q[state])
            else:
                actions = self.pomdp.policy_mask[self.pomdp.O[state]].astype(bool).nonzero()[0]
                action = random.choice(actions)
            next_state, _, reward, done = self.step(state, action)
            depth += 1
            total_discounted_reward += reward * running_discount
            # Geometric discounting: multiply by the factor, do not square it.
            running_discount *= discount
            state = next_state
        return total_discounted_reward

    def UCB(self, root : VNode, child : QNode) -> float:
        """
        Upper Confidence Bound for Upper Confidence Trees algorithm.
        """
        return (child.value) + self.c * math.sqrt(math.log(root.num_visits + 1) / (child.num_visits + 1))

    def _construct_VNode(self, root=False, **kwargs):
        """
        Construct a (root) VNode for this tree and initialise its values.
        """
        return RootVNode(0, None) if root else VNode(0)

    def _expand_VNode(self, vnode : VNode, mask : np.ndarray = None, prior=None, value=None):
        """
        Add a QNode child to `vnode` for every action set in `mask` that does
        not have one yet. New children start with zero visits and with value
        `value` (0 when `value` is falsy); `prior[action]`, when available, is
        passed on as `prob`.
        """
        initial_value = value if value else 0
        for action in mask.nonzero()[0]:
            if vnode[action] is not None:
                continue
            if prior is not None and prior[action] is not None:
                vnode[action] = QNode(0, initial_value, prob=prior[action], parent=vnode)
            else:
                vnode[action] = QNode(0, initial_value, parent=vnode)
