# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import SimpleNamespace
from typing import NamedTuple, Optional, Sequence, Tuple

import chex
import jax
import jax.numpy as jnp
from jax import lax

from jumanji import specs
from jumanji.env import Environment
from jumanji.types import TimeStep, restart, termination, transition
from jumanji.viewer import Viewer
from jumanji.wrappers import AutoResetWrapper
from enzyme.src.bayes.actor import GenericActor, RandomActor, make_actor
from enzyme.src.bayes.config import THETAS

from enzyme.src.mcmc import find_end_N_cons_GOs, generate_trajs_nb, get_poisson_trials


# Length of one discrete time step; construct_transition_atoms multiplies
# every rate by it.
dt = 1.0

# Default value of the `ITI` argument of MouseWorld.
ITI = 0.0
# Small constant added to the denominators of the reward-rate helpers
# (presumably to avoid dividing by zero at t = 0).
reg = 1e-10

# Rate used by construct_transition_atoms for the x = 0 likelihood in the
# s = 1 row.
theta_safe = 5e-4
# Number of trials the Monte-Carlo helpers request from get_poisson_trials.
N_trials = 10_000

@chex.dataclass
class State:
    """Environment state threaded through `MouseWorld.reset` and `MouseWorld.step`.

    s: latent binary flag. `step` switches it to True with probability
        `lmbd_s` and sets it back to False whenever the agent acts.
    x: binary observation produced by the last step; True when `s` is True
        or when a uniform draw falls below `theta`.
    r: reward produced by the last step (0 right after `reset`).
    theta: value drawn from `THETAS`; `step` redraws it with probability
        `lmbd_theta`.
    key: random key used to generate random numbers at each step and for auto-reset.
    step_count: number of `step` calls since the last `reset`.
    """

    s: chex.Scalar
    x: chex.Scalar
    r: chex.Scalar
    theta: chex.Scalar
    key: chex.PRNGKey  # (2,)
    step_count: chex.Numeric  # ()


class Observation(NamedTuple):
    """Observation handed to the agent: the single binary signal `x`."""

    x: chex.Scalar  # False after reset; set by MouseWorld.step afterwards

class TransitionAtoms(NamedTuple):
    """Arrays returned by `construct_transition_atoms`.

    With n the number of theta grid points:

    T_S: 2x2 matrix [[1 - lmbd_s*dt, 0], [lmbd_s*dt, 1]].
    T_theta: (n, n) matrix, (1 - lmbd_theta*dt) times the identity plus
        lmbd_theta*dt / n in every entry.
    T_0: rank-4 array einsum('SG,Ss,Gg->SGsg', llh[0], T_S, T_theta).
    T_1: rank-4 array einsum('SG,Ss,Gg->SGsg', llh[1], T_S, T_theta).
    llh: (2, 2, n) stack of the two likelihood tables; the leading axis
        selects llh_0 / llh_1 (presumably the observation x = 0 / x = 1).
    """

    T_S: chex.Array
    T_theta: chex.Array
    T_0: chex.Array
    T_1: chex.Array
    llh: chex.Array


def construct_transition_atoms(lmbd_s, lmbd_theta, theta_safe, thetas=None):
    """Build the likelihood tables and transition matrices over (s, theta).

    All rates are multiplied by the module-level time step `dt`.

    Args:
        lmbd_s: rate entering the 2x2 matrix `T_S`.
        lmbd_theta: rate with which theta is mixed uniformly over the grid
            in `T_theta`.
        theta_safe: rate used for the second row of the likelihood tables
            (`theta_safe*dt` in `llh[0]`, `1 - theta_safe*dt` in `llh[1]`).
        thetas: optional 1-D grid of theta values. Defaults to the
            module-level `THETAS`.

    Returns:
        A `TransitionAtoms` tuple `(T_S, T_theta, T_0, T_1, llh)`.
    """
    grid = THETAS if thetas is None else jnp.asarray(thetas)
    n_theta = grid.size

    # Likelihood tables indexed [row, theta]; the first row depends on theta,
    # the second row is constant across the grid.
    llh_0 = jnp.array([1 - grid*dt, jnp.full_like(grid, theta_safe*dt)])
    llh_1 = jnp.array([grid*dt, jnp.full_like(grid, 1.-theta_safe*dt)])
    llh = jnp.stack([llh_0, llh_1], axis=0)

    # Keep theta with weight 1 - lmbd_theta*dt, otherwise spread the mass
    # uniformly over the whole grid.
    T_theta = (
        jnp.eye(n_theta) * (1 - lmbd_theta*dt)
        + jnp.ones((n_theta, n_theta)) * lmbd_theta*dt / n_theta
    )
    T_S = jnp.array(
        [[1 - lmbd_s*dt, 0],
         [lmbd_s*dt, 1]]
    )

    # Likelihood-weighted joint transitions, one per likelihood table.
    T_1 = jnp.einsum('SG,Ss,Gg->SGsg', llh[1], T_S, T_theta)
    T_0 = jnp.einsum('SG,Ss,Gg->SGsg', llh[0], T_S, T_theta)

    return TransitionAtoms(T_S=T_S, T_theta=T_theta, T_0=T_0, T_1=T_1, llh=llh)


class MouseWorld(Environment[State]):
    """Environment with a latent binary state `s` and a binary observation `x`.

    Dynamics implemented by `step`:
      * `s` switches to True with probability `lmbd_s` per step;
      * a truthy action yields `R_SUCCESS` if `s` is True and `R_FAIL`
        otherwise, and sets `s` back to False; a falsy action yields 0 and
        leaves `s` unchanged;
      * `theta` is redrawn from the module-level `THETAS` with probability
        `lmbd_theta` per step;
      * `x` is True when `s` is True or when a uniform draw falls below
        `theta`.

    The constructor also attaches a set of helper functions (`p_s1`,
    `p_s1__x1`, `p_s0__x1`, `t_N_cons`, `t_N_cons_inv`, `r_t`, `R_t__x1`,
    `r_t__X1`, `t_star__X1`) computed either in closed form or by simulation,
    depending on `backend`.
    """

    def __init__(
        self, lmbd_s=.1, lmbd_theta=.01, ITI=ITI, r_rat=.1, gamma=1., episode_length=100, THETAS=THETAS, backend="thy"
    ) -> None:
        """Store the parameters and build the helper functions.

        Args:
            lmbd_s: per-step probability that `s` switches to True.
            lmbd_theta: per-step probability that `theta` is redrawn.
            ITI: offset used by `p_s1_thy` (zero before `t >= ITI`) and
                added to the denominator of `r_t__X1`.
            r_rat: the failure reward is `-R_SUCCESS * r_rat`.
            gamma: factor applied as `gamma**t` to the success term of the
                reward helpers.
            episode_length: stored on the instance; not used elsewhere in
                this class.
            THETAS: stored as `self.THETAS`.
                NOTE(review): `reset`, `step` and the call to
                `construct_transition_atoms` below do not read this
                attribute, so a custom grid currently has no effect on them.
            backend: "thy" for the closed-form helpers, "mc" for the
                simulation-based ones.

        Raises:
            ValueError: if `backend` is neither "mc" nor "thy".
        """
        self.lmbd_s = lmbd_s
        self.lmbd_act = lmbd_s
        self.lmbd_theta = lmbd_theta
        self.theta_safe = theta_safe
        self.ITI = ITI

        self.tau_safe = self.lmbd_s**-1
        self.tau_act = (self.lmbd_act + 1e-10)**-1
        self.tau_theta = (self.lmbd_theta + 1e-10)**-1

        self.r_rat = r_rat
        R_SUCCESS = 1.
        R_FAIL = - R_SUCCESS * r_rat
        self.R_SUCCESS = R_SUCCESS
        self.R_FAIL = R_FAIL

        self.gamma = gamma

        self.THETAS = THETAS

        self.episode_length = episode_length
        self.backend = backend

        # transition matrices and observation generation
        # The likelihood can be viewed as a diagonal matrix with a multiindex (s,theta; s',theta')
        self.T_S, self.T_theta, self.T_0, self.T_1, self.llh = construct_transition_atoms(
            self.lmbd_s, self.lmbd_theta, self.theta_safe
        )

        # ---- closed-form ("thy") helpers --------------------------------

        @jnp.vectorize
        def p_s0__x1_(n, theta):
            # b**n / (1 - c * sum_{i<n} b**i); presumably P(s = 0 | n
            # consecutive x = 1 observations).
            b, c = (1-lmbd_s)*theta, (1-lmbd_s)*(1-theta)
            sum_n_ = lax.fori_loop(0, n, lambda n_, val: val + c*b**n_, 0.0)
            return (b**n) / (1 - sum_n_)

        @jnp.vectorize
        def p_s1__x1_(t, theta):
            return 1. - p_s0__x1_(t, theta)

        # ensure that the probabilities sum to 1
        def p_s0__x1_thy(t, theta):
            return p_s0__x1_(t, theta) / (p_s0__x1_(t, theta) + p_s1__x1_(t, theta))

        def p_s1__x1_thy(t, theta):
            return p_s1__x1_(t, theta) / (p_s0__x1_(t, theta) + p_s1__x1_(t, theta))

        @jnp.vectorize
        def p_s1_thy(t):
            # 1 - exp(-lmbd_s * (t - ITI)) once t >= ITI, zero before.
            return jnp.where(t >= ITI, 1. - jnp.exp(-lmbd_s*(t - ITI)), 0.)

        @jnp.vectorize
        def t_N_cons_thy(n, theta):
            # (n*b**n + sum_{i<n} b**i * (n*a + c*(i+1))) / (1 - c * sum_{i<n} b**i)
            a, b, c = lmbd_s, (1-lmbd_s)*theta, (1-lmbd_s)*(1-theta)
            sum_a = lax.fori_loop(0, n,
                                  lambda i, val: val + b**i*(n*a + c*(i + 1)),
                                  0.0)
            sum_b = lax.fori_loop(0, n,
                                  lambda i, val: val + b**i,
                                  0.0)

            return (n*b**n + sum_a) / (1 - c*sum_b)

            # ts = jnp.linspace(0, n, 1000)
            # dt = ts[1] - ts[0]
            # return (n*b**n + jnp.sum(b**ts*(n*a + c*(ts + 1)))*dt) / (1 - c*jnp.sum(b**ts)*dt)

        @jnp.vectorize
        def t_N_cons_inv_thy(n, theta):
            # Solves lhs - rhs = 0 for A_inv with scipy's scalar root finder.
            a, b, c = lmbd_s, (1-lmbd_s)*theta, (1-lmbd_s)*(1-theta)
            i = jnp.arange(n)
            from scipy.optimize import root_scalar

            def zero(A_inv):
                lhs = A_inv
                rhs = ((1/n)*b**n + jnp.sum(b**i*((1/n)*a + c*(1/(i + 1 + A_inv**-1)))))
                return lhs - rhs

            n_max = 1/a
            sol = root_scalar(zero, x0=(n_max + n + 2)**-1, x1=1e2)
            assert sol.converged
            A_inv = sol.root
            return A_inv

        # ---- simulation-based ("mc") helpers ----------------------------
        # Each call draws N_trials trials of length int(10 / lmbd_s) from
        # get_poisson_trials.

        def p_s1_mc(t):
            x, s = get_poisson_trials(N_trials, int(self.lmbd_s**-1*10), theta=0, lmbd_s=self.lmbd_s)
            return jnp.mean(s, axis=0)[t]

        @jnp.vectorize
        def t_N_cons_mc(t, theta):
            x, s = get_poisson_trials(N_trials, int(self.lmbd_s**-1*10), theta, lmbd_s)
            return find_end_N_cons_GOs(x, t).mean(0)

        @jnp.vectorize
        def p_s0__x1_mc(t, theta):
            x, s = get_poisson_trials(N_trials, int(self.lmbd_s**-1*10), theta, lmbd_s)
            # get the index where x is 1 for t steps
            idx = jnp.where(jnp.sum(x[:, :t], axis=1) == t)[0]
            # get the corresponding s
            ps = s[idx, t].mean()
            return 1. - ps

        def p_s1__x1_mc(t, theta):
            return 1 - p_s0__x1_mc(t, theta)

        # make the functions available to the class
        if self.backend == "mc":
            p_s1 = p_s1_mc
            p_s1__x1 = p_s1__x1_mc
            p_s0__x1 = p_s0__x1_mc
            t_N_cons = t_N_cons_mc
            t_N_cons_inv = t_N_cons_inv_thy
        elif self.backend == "thy":
            p_s1 = p_s1_thy
            p_s1__x1 = p_s1__x1_thy
            p_s0__x1 = p_s0__x1_thy
            t_N_cons = t_N_cons_thy
            t_N_cons_inv = t_N_cons_inv_thy
        else:
            raise ValueError("backend not recognized")

        self.p_s1 = p_s1
        self.p_s1__x1 = p_s1__x1
        self.p_s0__x1 = p_s0__x1
        self.t_N_cons = t_N_cons
        self.t_N_cons_inv = t_N_cons_inv

        @jnp.vectorize
        def r_t(t):
            # Expected reward of acting at time t, divided by (t + reg).
            return (1/(t+reg)) * (gamma**t * R_SUCCESS * p_s1(t) + (1. - p_s1(t))*(R_FAIL))

        # https://www.wolframalpha.com/input?i=maximize+%282*%281-e%5E%28-x%29%29+%2B+-5*e%5E%28-x%29%29%2F%28x%29
        # W = lambertw
        # e = jnp.exp(1)
        # def t_star():
        #     return -lmbd_s**-1 * W(z=-1/e * (1 - jnp.abs(r_rat)), k=-1) - 1

        # t_act = lambda t: t - ITI

        def R_t__x1(t, theta):
            # Success and failure rewards weighted by p_s1__x1 / p_s0__x1.
            return gamma**t * R_SUCCESS * p_s1__x1(t, theta) + p_s0__x1(t, theta)*R_FAIL

        def r_t__X1(t, theta):
            # R_t__x1 divided by t_N_cons(t, theta) + reg + ITI.
            t_act_ = t_N_cons(t, theta)
            r_t_ = (1/(t_act_ + reg + ITI)) * R_t__x1(t, theta)
            # jax.debug.print("t_act: {}", (t, t_act_, r_t_))
            return r_t_

        # R, r = rm, r = rp/rm
        # r_t__X1 = lambda t, theta: (1/(t_act(t) + reg + ITI)) * R*(1 - p_s1__x1_theta(t, theta)*(1 - r))

        def t_star__X1(theta):
            # Index t in 0..99 that maximises r_t__X1(t, theta).
            t_star = jnp.argmax(r_t__X1(jnp.arange(100)[:, None], theta), axis=0)
            # jax.debug.print("t_star: {}", t_star)
            return t_star

        self.r_t = r_t

        self.R_t__x1 = R_t__x1
        self.r_t__X1 = r_t__X1
        self.t_star__X1 = t_star__X1

    def observation_spec(self) -> specs.Spec[Observation]:
        """Return a two-valued discrete spec named "x".

        NOTE(review): the annotation promises a spec of `Observation`, but a
        plain `DiscreteArray` is returned - confirm what callers expect.
        """
        return specs.DiscreteArray(2, name="x")

    def action_spec(self) -> specs.DiscreteArray:
        """Return a two-valued discrete spec named "action"."""
        return specs.DiscreteArray(2, name="action")

    def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
        """Start a new episode.

        Args:
            key: random key; split into the key stored in the state and a
                subkey used to draw the initial `theta`.

        Returns:
            state: initial state with `s`, `x` False, `r` and `step_count`
                zero, and `theta` drawn from `THETAS`.
            timestep: the first timestep, carrying the initial observation.
        """
        # Draw theta with a dedicated subkey so that the key stored in the
        # state (and split again by `step`) is never consumed twice.
        key, theta_key = jax.random.split(key)
        theta = jax.random.choice(theta_key, THETAS)
        x = False
        s = False
        obs = Observation(x=x)

        state = State(
            s=s,
            x=x,
            r=0.,
            theta=theta,
            key=key,
            step_count=0,
        )

        timestep = restart(observation=obs, extras={})

        return state, timestep

    def step(
        self, state: State, action: chex.Array
    ) -> Tuple[State, TimeStep[Observation]]:
        """Updates the environment state after the agent takes an action.

        Args:
            state: the current state of the environment.
            action: the action taken by the agent; used as the predicate of
                a `lax.cond`, so truthy means "act".

        Returns:
            state: the new state of the environment.
            timestep: the next timestep.
        """
        key, k_x, k_th, k_lmbd_s, k_lmbd_th = jax.random.split(state.key, 5)

        # s update: switch on with probability lmbd_s, otherwise keep it.
        s = jax.lax.cond(
            jax.random.uniform(k_lmbd_s) < self.lmbd_s,
            lambda s: True,
            lambda s: s,
            operand=state.s,
        )

        # get reward and update state depending on action: acting pays out
        # according to s and resets it, not acting leaves s untouched.
        s, r = lax.cond(
            action,
            lambda sr: (
                lax.cond(sr[0],
                    lambda _: (False, self.R_SUCCESS),
                    lambda _: (False, self.R_FAIL),
                    None
                )
            ),
            lambda sr: (sr[0], 0.0),
            (s, 0.0)
        )

        # theta update: redraw from THETAS with probability lmbd_theta.
        theta = jax.lax.cond(
            jax.lax.lt(jax.random.uniform(k_lmbd_th), self.lmbd_theta),
            lambda x: jax.random.choice(k_th, THETAS),
            lambda x: state.theta,
            operand=None,
        )

        # Observation uses the post-action s and the updated theta.
        x = jax.lax.cond(
            jax.lax.bitwise_or(s, jax.lax.lt(jax.random.uniform(k_x), theta)),
            lambda x: True,
            lambda x: False,
            operand=None,
        )

        # Build the state.
        state = State(
            s=s,
            x=x,
            r=r,
            step_count=state.step_count + 1,
            theta=theta,
            key=key,
        )

        # Generate the observation from the environment state.
        observation = Observation(
            x=x
        )

        # The episode is never flagged as finished here, so the `transition`
        # branch is always the one taken.
        done = False
        extras = dict()

        timestep = jax.lax.cond(
            done,
            lambda: termination(
                reward=r,
                observation=observation,
                extras=extras,
            ),
            lambda: transition(
                reward=r,
                observation=observation,
                extras=extras,
            ),
        )

        return state, timestep
    
       


def environment_loop(env, actor, n_steps, n_batch=1):
    """Roll out `actor` in `env` for `n_steps` steps over `n_batch` parallel copies.

    The environment is wrapped in `AutoResetWrapper`, and all random keys are
    derived from the fixed seed 0.

    Args:
        env: the environment to run.
        actor: object providing `observe_first`, `_init` and `_policy`.
        n_steps: number of scan iterations per environment copy.
        n_batch: number of environment/actor copies run under `vmap`.

    Returns:
        A `RolloutSlice` of arrays (x, r, s, theta, a, p_S_TH) recorded at
        every step, with all size-1 axes squeezed out.
    """
    env = AutoResetWrapper(env)     # Automatically reset the environment when an episode terminates

    key_env, key_actor = jax.random.split(jax.random.PRNGKey(0))

    # Show the actor an initial timestep before rolling out.
    _, first_timestep = env.reset(key=jax.random.PRNGKey(0))
    actor.observe_first(first_timestep)

    class RolloutSlice(NamedTuple):
        x: chex.Array
        r: chex.Array
        s: chex.Array
        theta: chex.Array
        a: chex.Array
        p_S_TH: chex.Array

    def step_fn(state, key):
        # `key` is only the scan input; the environment draws its randomness
        # from the key stored inside its own state.
        world_state, actor_state = state
        x, r, s, theta = world_state.x, world_state.r, world_state.s, world_state.theta
        action, actor_state_ = actor._policy([], observation=(x, r, s, theta), state=actor_state)
        world_state_, _ = env.step(world_state, action)
        rollout_slice = RolloutSlice(x=x, r=r, s=s, theta=theta, a=action, p_S_TH=actor_state_.p_S_TH)
        return (world_state_, actor_state_), rollout_slice

    def run_n_steps(state_0, key, n):
        random_keys = jax.random.split(key, n)
        _, rollout = jax.lax.scan(step_fn, state_0, random_keys)
        return rollout

    # Instantiate a batch of environment states
    keys_env = jax.random.split(key_env, n_batch)
    env_state_0, _ = jax.vmap(env.reset)(keys_env)

    # Collect a batch of rollouts
    keys_actor = jax.random.split(key_actor, n_batch)
    actor_state_0 = jax.vmap(actor._init)(keys_actor)
    rollout = jax.vmap(run_n_steps, in_axes=(0, 0, None))((env_state_0, actor_state_0), keys_env, n_steps)

    # Drop every size-1 axis (e.g. the batch axis when n_batch == 1).
    return jax.tree_util.tree_map(lambda leaf: leaf.squeeze(), rollout)



from jax_tqdm import scan_tqdm
if __name__ == "__main__":
    # Script entry point: build the default environment and an actor for it,
    # then collect a single rollout of one million steps.
    env = MouseWorld()
    actor_core = make_actor(env)
    actor = GenericActor(actor_core, jit=True, random_key=jax.random.PRNGKey(1))

    # Show the actor an initial timestep before handing it to the loop.
    _, timestep = env.reset(key=jax.random.PRNGKey(0))
    actor.observe_first(timestep)

    rollout = environment_loop(env, actor, int(1e6))