import numpy as np
import scipy.stats
from scipy.special import logsumexp
import math
import time


'''
Inference of the log marginal likelihood using the following state-space model:

                  f(s_t, a_t) := s_t + a_t
                        p(s_0) = Normal(0, 1)
                  p(a_t | s_t) = Normal(s_t * 0.5, 1)
         p(s_{t+1} | s_t, a_t) = δ_{f(s_{t},a_t)}(s_{t+1})
log p(O_t | s_t, a_t, s_{t+1}) = 0 if -1e-2 <= s_{t+1} <= 1e-2 else -10000
                               ≈ Q̃(s_t, a_t)   (heuristic critic approximating the log-optimality term)
                               = -1000*|s_t + a_t| + ε
'''


# Initial state distribution: s_0 ~ Normal(0, 1).
p_s0 = scipy.stats.norm(loc=0.0, scale=1.0)


def p_at_given_st(st):
    """Prior policy: return the frozen distribution a_t | s_t ~ Normal(0.5 * s_t, 1)."""
    return scipy.stats.norm(loc=st*0.5, scale=1)


def c_st(st):
    """Optimality condition: 0 where |s| <= 1e-2 (inclusive bounds), -10000 elsewhere."""
    within_band = np.logical_and(st >= -1e-2, st <= 1e-2)
    return np.where(within_band, 0, -10000)


def critic_func(st, at, eps=1e-8):
    """Heuristic critic: -1000 * |s_t + a_t| + eps (larger when s_t + a_t is near 0)."""
    distance = np.abs(st + at)
    return -1000*distance + eps


# Slow to execute deterministic state transition function
def p_next_s_given_s_and_a(st, at):
    original_shape = st.shape
    st = st.reshape(-1)
    at = at.reshape(-1)
    next_st = []
    for i in range(len(st)):
        time.sleep(0.0005)
        next_st.append(st[i] + at[i])
    return np.array(next_st, dtype=st.dtype).reshape(original_shape)


def multinomial_resampling(log_w):
    """Draw one ancestor per global particle from all (local, global) candidates.

    Args:
        log_w: 2-D array of unnormalized log weights indexed ``[local, global]``
            (in ``critic_smc``: ``[putative_action, particle]``).

    Returns:
        ancestral_indices: int array of shape ``(global, 2)``. Each row is a
            ``(local_index, global_index)`` pair into ``log_w``.
        log_w: shape-``(global,)`` array whose every entry is
            ``log(1 / global) + log_z``, so the total weight mass is ``exp(log_z)``.
        log_z: shape-``(1,)`` logsumexp over every entry of the input.
    """
    # Put the global axis first so flat index = global_index * n_local + local_index.
    by_particle = np.transpose(log_w, (1, 0))
    n_global, n_local = by_particle.shape[0], by_particle.shape[1]
    flat_log_w = by_particle.reshape(n_global*n_local)

    log_z = logsumexp(flat_log_w, keepdims=True)
    normalized = flat_log_w - log_z
    candidates = np.arange(n_local*n_global)
    flat_indices = np.random.choice(candidates, n_global, p=np.exp(normalized))

    # After resampling every survivor carries an equal share of the total mass.
    equal_share = np.ones(n_global, dtype=normalized.dtype) / n_global
    resampled_log_w = np.log(equal_share) + log_z

    local_idx = (flat_indices % n_local).astype(int)
    global_idx = (flat_indices // n_local).astype(int)
    ancestral_indices = np.stack([local_idx, global_idx], axis=-1)
    return ancestral_indices, resampled_log_w, log_z


def critic_smc(s0, critic=True, max_timesteps=10, num_particles=1, num_putative_action_particles=128):
    """Estimate the log marginal likelihood of the toy model with SMC.

    Args:
        s0: Initial state. ``np.repeat`` copies it to every particle (the
            script passes a shape-(1,) array).
        critic: If True (Critic SMC), each particle is extended with
            ``num_putative_action_particles`` prior actions. The candidates are
            weighted by ``critic_func`` and resampled *before* the environment
            is stepped, and the critic factor is divided back out afterwards.
            If False, standard SMC with bootstrap (prior-policy) proposals is
            run instead.
        max_timesteps: Number of SMC steps.
        num_particles: Number of particles kept after each resampling step.
        num_putative_action_particles: Candidate actions sampled per particle
            per step.

    Returns:
        (logw, logz): final per-particle log weights, shape
        ``(num_particles,)``, and the log marginal likelihood estimate as a
        scalar.
    """
    st = np.repeat(s0, num_particles)
    # Each particle starts with weight 1, so the total initial mass is num_particles.
    logw = np.zeros(num_particles, dtype=np.float32)
    for _ in range(max_timesteps):
        at = p_at_given_st(st).rvs((num_putative_action_particles, num_particles))
        st = np.repeat(st[None, :], num_putative_action_particles, axis=0)
        logw = np.repeat(logw[None, :], num_putative_action_particles, axis=0)
        logw = logw - math.log(num_putative_action_particles)
        if critic:  # Using heuristic factor
            critic_value = critic_func(st, at)  # Evaluate actions with critic
            logw = logw + critic_value  # Update logw with heuristic factor (critic value)
            ancestral_indices, logw, _ = multinomial_resampling(logw)  # Resampling
            at = at[ancestral_indices[:, 0], ancestral_indices[:, 1]]
            st = st[ancestral_indices[:, 0], ancestral_indices[:, 1]]
            st = p_next_s_given_s_and_a(st, at)  # Step environment after resampling
            logw = logw + c_st(st)  # Evaluate optimality condition
            selected_critic_value = critic_value[ancestral_indices[:, 0], ancestral_indices[:, 1]]
            logw = logw - selected_critic_value  # Decouple resampling probabilities from particle weights
        else:  # Standard SMC using bootstrap proposals
            st = p_next_s_given_s_and_a(st, at)  # Step environment
            logw = logw + c_st(st)  # Evaluate optimality condition
            # multinomial_resampling returns three values; the old code unpacked
            # two and then read an undefined `logz_t`, so this branch always crashed.
            ancestral_indices, logw, _ = multinomial_resampling(logw)  # Resampling
            st = st[ancestral_indices[:, 0], ancestral_indices[:, 1]]
    # Resampling sets every weight to log_z - log(num_particles), so the running
    # normalizer already lives inside logw and each step's log_z is cumulative.
    # Summing those per-step values (as before) double-counted earlier steps and
    # leaked critic values into the estimate. Read it off the final weights instead,
    # dividing out the initial total mass of num_particles.
    logz = logsumexp(logw) - math.log(num_particles)
    return logw, logz


s0 = p_s0.rvs(1)  # Sample an initial state (shape-(1,) array)
print("Initial state:", s0)

num_particles = 10
num_putative_action_particles = 1024
use_critic = True  # If False then standard SMC is executed

print("Num particles:", num_particles)
print("Num putative action particles:", num_putative_action_particles)
print("Using critic:", use_critic)

# perf_counter is monotonic and high-resolution. time.time() follows the wall
# clock and can jump on clock adjustments, which would corrupt the duration.
begin = time.perf_counter()
logw, logz = critic_smc(s0, num_particles=num_particles, num_putative_action_particles=num_putative_action_particles,
                        critic=use_critic)
end = time.perf_counter()
print("Logw:", logw)
print("Logz:", logz)
print("Execution time in seconds (s):", end - begin)
