import jax
from typing import *
import jax.numpy as jnp
from .modelbasedagent import ModelBasedAgent
import numpy as np
import time


class VBRBAgent(ModelBasedAgent):
    """Model-based agent that plans on ``reward + exploration bonus``.

    Transition model: visit counts plus ``dirichlet_param`` are passed to the
    base-class ``dirichlet_mean`` / ``dirichlet_var`` helpers. The exploration
    bonus is the scaled square root of the transition variance summed over next
    states. When rewards are learned with a conjugate prior, it also includes
    the scaled reward posterior variance. Planning uses the base-class value
    iteration (JAX or NumPy); actions are picked greedily from ``value_table``.

    Reward handling, in order of precedence:

    * ``env_reward`` given: the reward table is known; a copy is stored.
    * ``reward_param`` given: every (s, a, s') entry starts at ``reward_param``
      and is overwritten by the first reward observed for it.
    * otherwise: sufficient statistics are accumulated and the reward is the
      posterior mean under a Normal (known precision) or Normal-Gamma prior.
    """

    def __init__(self, dirichlet_param, reward_param, tau, precision, beta, use_jax: bool, rng_key=None, transition_var_scale=1.0, reward_var_scale=1.0, use_normal_gamma_prior=False, instant_reward=True, env_reward=None, **kwargs):
        """Configure priors, reward handling and the exploration-bonus weights.

        Args:
            dirichlet_param: Pseudo-count added to the transition counts.
            reward_param: Placeholder reward for unseen (s, a, s'), or None to
                learn rewards with a conjugate prior.
            tau: Prior precision of the reward mean (Normal, known precision).
            precision: Known observation precision of rewards (Normal prior).
            beta: Used as both ``lam`` and ``beta`` of the Normal-Gamma prior.
            use_jax: Use the JAX planning / action-selection helpers.
            rng_key: JAX PRNG key; required when ``use_jax`` is True.
            transition_var_scale: Weight of the transition term of the bonus.
            reward_var_scale: Weight of the reward-variance term of the bonus.
            use_normal_gamma_prior: Use a Normal-Gamma prior instead of a
                Normal prior with known precision.
            instant_reward: Learned reward statistics are indexed by
                (s, a, s') when True and by (s, a) when False.
            env_reward: Known reward table. It is copied so that zeroing
                terminal rewards never mutates the caller's array.
            **kwargs: Forwarded to ``ModelBasedAgent``.
        """
        super().__init__(**kwargs)
        self.dirichlet_param = dirichlet_param
        self.reward_param = reward_param
        self.rng_key = rng_key
        if use_jax:
            assert rng_key is not None, "rng key is None when using jax."
        self.use_jax = use_jax

        self.use_normal_gamma_prior = use_normal_gamma_prior
        self.transition_var_scale = transition_var_scale
        self.reward_var_scale = reward_var_scale
        self.instant_reward = instant_reward

        self.known_reward = env_reward is not None
        if self.known_reward:
            # _compute_policy zeroes terminal rows of self.reward in place; copy
            # so the caller's (e.g. the environment's) table is left untouched.
            self.reward = np.array(env_reward, copy=True)
        elif self.reward_param is not None:
            # Placeholder value, replaced the first time each (s, a, s') is observed.
            self.reward = np.full((self.num_states, self.num_actions, self.num_states), self.reward_param, dtype=np.float32)
        else:
            # Conjugate priors: [Normal with known precision, Normal-Gamma].
            if self.instant_reward:
                stat_shape = (self.num_states, self.num_actions, self.num_states)
            else:
                stat_shape = (self.num_states, self.num_actions)
            self.reward = np.zeros(stat_shape, dtype=np.float32)
            self.reward_observations = np.zeros(stat_shape, dtype=np.float32)

            if self.use_normal_gamma_prior:
                self.reward_squared_observations = np.zeros(stat_shape, dtype=np.float32)
                self.mu = 0.0
                self.lam = beta
                self.alpha = 1.0
                self.beta = beta
            else:
                self.mu = 0
                self.tau = tau
                self.precision = precision

        self.reward_bonus = None

    def reset(self):
        """Clear the transition counts and the value table.

        NOTE(review): reward statistics (``reward_observations``,
        ``reward_squared_observations``, or a table learned via
        ``reward_param``) are not cleared here. The transition counts they are
        normalised by in ``_compute_policy`` are cleared. If ``reset()`` can run
        after learning has started, confirm this is intended.
        """
        super().reset()
        self.transition_observations = np.zeros((self.num_states, self.num_actions, self.num_states), dtype=np.float32)
        self.value_table = np.zeros((self.num_states, self.num_actions), dtype=np.float32)

    def update_model(self, reward, next_state):
        """Record the step (last_state, last_action) -> next_state with ``reward``.

        Nothing is recorded when ``reward`` is None.
        """
        if reward is None:
            return

        s, a = self.last_state, self.last_action
        if not self.known_reward:
            if self.reward_param is not None:
                # Overwrite the placeholder only the first time (s, a, s') is seen.
                # Compare in the table's dtype: under NumPy < 2 a float32 element
                # compared with a Python float is promoted to float64, so values
                # such as 0.1 would never compare equal.
                placeholder = self.reward.dtype.type(self.reward_param)
                if self.reward[s, a, next_state] == placeholder:
                    self.reward[s, a, next_state] = reward
            else:
                key = (s, a, next_state) if self.instant_reward else (s, a)
                self.reward_observations[key] += reward
                if self.use_normal_gamma_prior:
                    self.reward_squared_observations[key] += reward**2

        # Count the observed transition.
        self.transition_observations[s, a, next_state] += 1

    def interact(self, reward, next_state):
        """Update the model with the latest step and return the next action.

        The policy is recomputed when ``policy_step`` reaches ``T``. The action
        is the argmax of ``value_table[next_state]`` via the base-class
        tie-breaking argmax helpers.
        """
        self.update_model(reward, next_state)

        if self.policy_step == self.T:
            self._compute_policy()

        action_values = self.value_table[next_state]
        if self.use_jax:
            next_action = self.jax_argmax_breaking_ties_randomly(action_values)
        else:
            next_action = self._argmax_breaking_ties_randomly(action_values)

        self.policy_step += 1
        self.last_state = next_state
        self.last_action = next_action

        return self.last_action

    def _reward_posterior(self):
        """Return ``(reward_mean, reward_var)`` from the accumulated reward statistics.

        Counts are per (s, a, s') when ``instant_reward`` is True and per
        (s, a) otherwise, matching the shape of ``reward_observations``.
        """
        if self.instant_reward:
            count = self.transition_observations
        else:
            count = np.sum(self.transition_observations, axis=-1)

        if self.use_normal_gamma_prior:
            # Avoid division by zero for unvisited entries (their numerators are 0).
            count_safe = np.where(count == 0, 1, count)
            lam_new = self.lam + count
            alpha_new = self.alpha + count / 2
            # (biased) sample variance
            s = (self.reward_squared_observations - self.reward_observations**2 / count_safe) / count_safe
            beta_new = self.beta + (count * s + ((self.lam * (self.reward_observations - count * self.mu)**2) / (count_safe * (self.lam + count)))) / 2

            reward_var = beta_new / (lam_new * alpha_new)
            reward_mean = (self.lam * self.mu + self.reward_observations) / (self.lam + count)
        else:
            reward_var = 1 / (self.tau + self.precision * count)
            reward_mean = (self.tau * self.mu + self.precision * self.reward_observations) / (self.tau + self.precision * count)
        return reward_mean, reward_var

    def _compute_policy(self):
        """Recompute ``value_table`` from the current model and reset ``policy_step``."""
        self.policy_step = 0

        param = self.transition_observations + self.dirichlet_param
        transition_probs = self.dirichlet_mean(param)[0]
        var = self.dirichlet_var(param)

        # Transition term of the bonus: summed over s', presumably S x A.
        bonus = np.sqrt(np.sum(var, axis=-1)) * self.transition_var_scale

        if not self.known_reward and self.reward_param is None:
            self.reward, reward_var = self._reward_posterior()
            if self.instant_reward:
                # E_s'[Var r(s, a, s')] under the mean transition model -> per (s, a).
                reward_var = np.sum(reward_var * transition_probs, axis=-1)
            bonus = bonus + reward_var * self.reward_var_scale

        # The bonus is per (s, a); give it a trailing s' axis exactly once when
        # the reward table is indexed by (s, a, s'), so the sum broadcasts.
        if np.ndim(self.reward) == 3:
            bonus = bonus[..., np.newaxis]
        self.reward_bonus = bonus

        # terminal states self looping
        if self.terminal_indexes is not None:
            transition_probs[self.terminal_indexes] = 0
            transition_probs[self.terminal_indexes, :, self.terminal_indexes] = 1

            # no reward at terminal states
            self.reward[self.terminal_indexes] = 0
            self.reward_bonus[self.terminal_indexes] = 0

        rewards = self.reward + self.reward_bonus

        if self.use_jax:
            self.jax_value_iteration(jnp.array(rewards), transition_probs)
        else:                   # fall back to numpy
            self._value_iteration(rewards, transition_probs)
