from .jax_helpers import *
from .abstract import PCTL_Model_Checker
from .matrix_sampler import Matrix_Sampler

import jax.numpy as jnp
import jax
from jax import lax

@jit_decorator
def local_next(one_hot_matrix, mapped_next_states, subformula_sat):
    """jitted helper function for estimating the satisfaction probabilities for Next

    Inputs:
        one_hot_matrix: matrix whose rows are looked up by the entries of mapped_next_states
        mapped_next_states: integer array of sampled paths; index 1 along the third axis is read as the successor state
        subformula_sat: satisfaction vector of the subformula, combined with the looked-up rows via matmul

    Returns:
        the mean over axis 1 of the indicator that the successor's subformula value is >= 1.0
    """
    encoded_paths = one_hot_matrix[mapped_next_states]
    # only the entry at position 1 of each sampled path matters for Next
    successor_encoding = encoded_paths[:, :, 1, :]
    successor_sat = jnp.matmul(successor_encoding, subformula_sat)
    satisfied = (successor_sat >= 1.0).astype(jnp.float32)
    return jnp.mean(satisfied, axis=1)

class Next:

    """Next: satisfied if the subformula is satisfied in the next state with probability above the threshold prob

    Input attributes:
        prob: the probability threshold for satisfaction
        subformula: the subformula to check in the next state
    """

    def __init__(self, prob, subformula):
        self.subformula = subformula
        self.prob = prob

    def _bound(self):
        """one time step for the Next operator on top of the bound of the subformula"""
        return self.subformula._bound() + 1

    def _exact_sat(self, **kwargs):
        """using the state transition matrix, compute the satisfaction set for the Next operator on the entire state space

        Reads 'matrix' and 'cond_action_matrix' from kwargs. A state is marked
        satisfied (1.0) when its one-step probability is >= prob.
        """
        matrix = kwargs['matrix']
        cond_action_matrix = kwargs['cond_action_matrix']

        # a given conditional action matrix takes the place of the transition matrix for this single step
        step_matrix = matrix if cond_action_matrix is None else cond_action_matrix

        # the conditional action matrix is never forwarded to the subformula
        kwargs['cond_action_matrix'] = None
        subformula_sat = self.subformula._exact_sat(**kwargs)

        step_probs = jnp.matmul(subformula_sat, step_matrix)
        return (step_probs >= self.prob).astype(jnp.float32)

    def _local_sat(self, states, **kwargs):
        """by sampling from the model, estimate the satisfaction set for the local set of states

        Reads 'num_samples', 'model', 'cond_actions' and 'key' from kwargs and
        returns the updated key together with the estimated satisfaction vector.
        A state is marked satisfied (1.0) when its estimated probability is > prob.
        """
        model = kwargs['model']

        # paths of length 2: the start state followed by one sampled successor
        key, sampled_paths = model.sample_paths(
            kwargs['key'], states, kwargs['num_samples'], 2, cond_actions=kwargs['cond_actions'])

        # collapse the sampled states to their unique values so the subformula is evaluated once per state
        unique_states, inverse = jnp.unique(sampled_paths.flatten(), return_inverse=True)
        num_unique = unique_states.shape[0]

        # the subformula is evaluated without conditional actions and with the fresh key
        kwargs['cond_actions'] = None
        kwargs['key'] = key
        key, subformula_sat = self.subformula._local_sat(unique_states, **kwargs)

        # re-express every sampled state as its position in unique_states
        mapped_paths = jnp.arange(0, num_unique)[inverse].reshape(sampled_paths.shape)

        # estimate the satisfaction probabilities using the jitted helper function
        sat_probs = local_next(jnp.eye(num_unique), mapped_paths, subformula_sat)
        return key, jnp.array(sat_probs > self.prob).astype(jnp.float32)
        
class Always:

    """Always: satisfied if the subformula holds for the next n time steps with probability above the threshold prob

    Input attributes:
        prob: the probability threshold for satisfaction
        bound: number of time steps n
        subformula: the subformula that must keep holding
    """

    def __init__(self, prob, bound, subformula):
        self.subformula = subformula
        self.bound = bound
        self.prob = prob

    def _bound(self):
        """own bound plus the bound of the subformula"""
        return self.bound + self.subformula._bound()

    def _as_negated_until(self):
        """rewrite Always as the negation of reaching a violating state: Neg(Until(1 - prob, bound, Truth, Neg(subformula)))"""
        violation = Neg(self.subformula)
        reaches_violation = Until(1.0 - self.prob, self.bound, Truth(), violation)
        return Neg(reaches_violation)

    def _exact_sat(self, **kwargs):
        """compute the satisfaction set on the entire state space"""
        return self._as_negated_until()._exact_sat(**kwargs)

    def _local_sat(self, states, **kwargs):
        """estimate the satisfaction set for the local set of states"""
        return self._as_negated_until()._local_sat(states, **kwargs)

class Eventually:

    """Eventually: satisfied if the subformula holds at some point in the next n time steps with probability above the threshold prob

    Input attributes:
        prob: the probability threshold for satisfaction
        bound: number of time steps n
        subformula: the subformula that must hold at some point
    """

    def __init__(self, prob, bound, subformula):
        self.subformula = subformula
        self.bound = bound
        self.prob = prob

    def _bound(self):
        """own bound plus the bound of the subformula"""
        return self.bound + self.subformula._bound()

    def _as_until(self):
        """rewrite Eventually as Until(prob, bound, Truth, subformula)"""
        return Until(self.prob, self.bound, Truth(), self.subformula)

    def _exact_sat(self, **kwargs):
        """compute the satisfaction set on the entire state space"""
        return self._as_until()._exact_sat(**kwargs)

    def _local_sat(self, states, **kwargs):
        """estimate the satisfaction set for the local set of states"""
        return self._as_until()._local_sat(states, **kwargs)
    
@jit_decorator
def exact_until(matrix, bound, subformula_1_sat, subformula_2_sat, sat_probs, cond_action_matrix=None):
    """jitted helper function for computing the satisfaction probabilities for bounded Until

    The probabilities are computed by backward iteration: every iteration
    extends the horizon by one time step, so the LAST iteration corresponds to
    the FIRST transition taken from the start state.

    Inputs:
        matrix: state transition matrix, applied as jnp.matmul(vector, matrix)
        bound: number of time steps
        subformula_1_sat: satisfaction vector of the subformula that must hold until the target is reached
        subformula_2_sat: satisfaction vector of the target subformula
        sat_probs: initial probability vector for the iteration (the visible caller passes zeros)
        cond_action_matrix: optional matrix used in place of matrix for the first
            transition from the start state (presumably the transition matrix under
            the conditional action -- confirm against compute_cond_action_matrix)

    Returns:
        vector of probabilities of satisfying the bounded Until from each state
    """
    # S_1: states where the target subformula already holds
    S_1 = subformula_2_sat
    # S_Q: states where subformula_1 holds but the target does not (yet)
    S_Q = jnp.clip(subformula_1_sat - S_1, 0.0, 1.0)

    # probability of stepping from an S_Q state directly into S_1 (loop invariant)
    next_sat_probs = jnp.matmul(S_1, matrix) * S_Q

    def loop_body(args):
        i, probs = args
        return i + 1, jnp.matmul(probs, matrix) * S_Q + next_sat_probs

    if cond_action_matrix is None:
        _, sat_probs = lax.while_loop(lambda args: args[0] < bound, loop_body, (0, sat_probs))
    else:
        # Previously the conditional matrix was applied in the very first
        # iteration, where it multiplies the initial (zero) vector and therefore
        # has no influence on the result. The conditional action governs the
        # first transition in time, i.e. the last backward iteration.
        _, policy_probs = lax.while_loop(lambda args: args[0] < bound - 1, loop_body, (0, sat_probs))
        cond_probs = (jnp.matmul(policy_probs, cond_action_matrix) * S_Q
                      + jnp.matmul(S_1, cond_action_matrix) * S_Q)
        # with a zero bound no transition is taken, so the conditional step is skipped
        sat_probs = jnp.where(bound > 0, cond_probs, policy_probs)

    return sat_probs + S_1

@jit_decorator
def local_until(one_hot_matrix, mapped_paths, bound, subformula_1_sat, subformula_2_sat):
    """jitted helper function for estimating the satisfaction probabilities for Until from sampled paths

    Inputs:
        one_hot_matrix: matrix whose rows are looked up by the entries of mapped_paths
        mapped_paths: integer array of sampled paths; the third axis indexes the position along the path
        bound: number of path positions to inspect (the visible caller passes the path length)
        subformula_1_sat: satisfaction vector of the subformula that must hold until the target is reached
        subformula_2_sat: satisfaction vector of the target subformula

    Returns:
        the mean over axis 1 of the indicator that a sampled path satisfies the Until
    """
    one_hot_encoded_states = one_hot_matrix[mapped_paths]

    # satisfaction of both subformulae at every position of every sampled path
    paths_sat_1 = jnp.matmul(one_hot_encoded_states, subformula_1_sat)
    paths_sat_2 = jnp.matmul(one_hot_encoded_states, subformula_2_sat)

    def loop_body(args):
        i, prefix_sat_1, paths_sat = args
        # prefix_sat_1 is the product of paths_sat_1 over positions 0..i-1; a path
        # is credited at position i when subformula_1 held at every earlier
        # position and the target holds at position i
        paths_sat = paths_sat + prefix_sat_1 * paths_sat_2[:, :, i]
        # extend the running product by position i instead of recomputing it
        # from scratch in a nested loop on every iteration
        return i + 1, prefix_sat_1 * paths_sat_1[:, :, i], paths_sat

    # position 0: the path is satisfied immediately if the target holds at the start state
    init = (1, paths_sat_1[:, :, 0], paths_sat_2[:, :, 0])
    _, _, paths_sat = lax.while_loop(lambda args: args[0] < bound, loop_body, init)

    # a path may be credited at several positions, hence >= 1.0 rather than == 1.0
    return jnp.mean((paths_sat >= 1.0).astype(jnp.float32), axis=1)

class Until:

    """Until: satisfied if subformula_1 holds until subformula_2 holds at some point in the next n time steps with probability above the threshold prob

    Input attributes:
        prob: the probability threshold for satisfaction
        bound: number of time steps n
        subformula_1: the subformula that must hold until subformula_2 holds
        subformula_2: the subformula that must hold at some point within the bound
    """
    
    def __init__(self, prob, bound, subformula_1, subformula_2):
        self.prob = prob
        self.bound = bound
        self.subformula_1 = subformula_1
        self.subformula_2 = subformula_2

    def _bound(self):
        """own bound plus the bounds of both subformulae"""
        # previously referenced the non-existent `subformula_1.bound()` and
        # `self.subformula2`, so this method always raised AttributeError
        return self.bound + self.subformula_1._bound() + self.subformula_2._bound()
        
    def _exact_sat(self, **kwargs):
        """using the state transition matrix, compute the satisfaction set on the entire state space

        Reads 'matrix' and 'cond_action_matrix' from kwargs. A state is marked
        satisfied (1.0) when its probability is >= prob.
        """

        matrix = kwargs['matrix']
        cond_action_matrix = kwargs['cond_action_matrix']
        num_states = matrix.shape[0]
        
        # the conditional action matrix is not forwarded to the subformulae
        kwargs['cond_action_matrix'] = None

        # compute the satisfaction set of the two subformulae
        subformula_1_sat = self.subformula_1._exact_sat(**kwargs)
        subformula_2_sat = self.subformula_2._exact_sat(**kwargs)

        # compute the satisfaction probabilities using the jitted helper function
        sat_probs = exact_until(matrix, to_jnp(self.bound), subformula_1_sat, subformula_2_sat, jnp.zeros(num_states), cond_action_matrix=cond_action_matrix)
        
        return (sat_probs >= self.prob).astype(jnp.float32)

    def _local_sat(self, states, **kwargs):
        """by sampling from the model, estimate the satisfaction set for the local set of states

        Reads 'num_samples', 'model', 'cond_actions' and 'key' from kwargs and
        returns the updated key together with the estimated satisfaction vector.
        A state is marked satisfied (1.0) when its estimated probability is > prob.
        """

        num_samples = kwargs['num_samples']
        model = kwargs['model']
        cond_actions = kwargs['cond_actions']
        key = kwargs['key']

        # the start state plus one sampled state per time step
        path_length = self.bound + 1
        # sample a batch of paths from each start state
        key, paths = model.sample_paths(key, states, num_samples, path_length, cond_actions=cond_actions)
                
        # compute the unique set of states so each subformula is evaluated once per state
        flattened_paths = paths.flatten()
        unique_states, indices = jnp.unique(flattened_paths,return_inverse=True)
        
        # the conditional actions are not forwarded to the subformulae
        kwargs['cond_actions'] = None
        # set the new key
        kwargs['key'] = key

        # compute the satisfaction set of the two subformulae, threading the key through both calls
        key, subformula_1_sat = self.subformula_1._local_sat(unique_states, **kwargs)
        # set the next key
        kwargs['key'] = key
        key, subformula_2_sat = self.subformula_2._local_sat(unique_states, **kwargs)

        # map the paths to (0, unique_states.shape[0])
        flattened_mapped_paths = jnp.arange(0, unique_states.shape[0])[indices]
        mapped_paths = flattened_mapped_paths.reshape(paths.shape)
        one_hot_matrix = jnp.eye(unique_states.shape[0])
        
        # compute the satisfaction probabilities using the jitted helper function
        sat_probs = local_until(one_hot_matrix, mapped_paths, to_jnp(self.bound+1), subformula_1_sat, subformula_2_sat)
        
        return key, jnp.array(sat_probs > self.prob).astype(jnp.float32)

class Truth:
    """Truth: always satisfied"""
    
    def __init__(self):
        pass

    def _bound(self):
        """Truth needs no look-ahead"""
        return 0
    
    def _exact_sat(self, **kwargs):
        """compute the satisfaction set on the entire state space

        Reads 'matrix' from kwargs only to determine the number of states.
        """
        num_states = kwargs['matrix'].shape[0]
        return jnp.ones(num_states, dtype=jnp.float32)

    def _local_sat(self, states, **kwargs):
        """compute the satisfaction set for the local set of states

        Returns the unchanged key together with a float32 vector of ones shaped like states.
        """
        # force float32 so the result does not inherit the dtype of `states`
        # and matches the float32 vectors returned by _exact_sat
        return kwargs['key'], jnp.ones_like(states, dtype=jnp.float32)

class Atom:
    """Atom: satisfied for each state that has atom in its set of labels

    Input attributes:
        atom: the atomic predicate, used as a key into atomic_predicate_map
    """
    
    def __init__(self, atom):
        self.atom = atom

    def _bound(self):
        """an atom needs no look-ahead"""
        return 0
        
    def _exact_sat(self, **kwargs):
        """compute the satisfaction set on the entire state space

        Reads 'labelling_fn' and 'atomic_predicate_map' from kwargs. The
        transition matrix is not needed here, so it is no longer looked up.
        """
        labelling_fn = kwargs['labelling_fn']
        atomic_predicate_map = kwargs['atomic_predicate_map']
        # return the vectorized labelling function at the corresponding index in atomic_predicate_map
        return labelling_fn[atomic_predicate_map[self.atom]]

    def _local_sat(self, states, **kwargs):
        """compute the satisfaction set for the local set of states

        Reads 'labelling_fn', 'atomic_predicate_map' and 'key' from kwargs and
        returns the unchanged key together with the labels of the given states.
        """
        labelling_fn = kwargs['labelling_fn']
        atomic_predicate_map = kwargs['atomic_predicate_map']
        # select the vectorized labelling function at the corresponding index in atomic_predicate_map
        # and index it at the input states
        return kwargs['key'], labelling_fn[atomic_predicate_map[self.atom]][states]
    
class Neg:
    """Negation: satisfied by each state that doesn't satisfy the subformula

    Input attributes:
        subformula: the subformula to negate
    """

    def __init__(self, subformula):
        self.subformula = subformula

    def _bound(self):
        """negation adds no look-ahead of its own"""
        return self.subformula._bound()

    def _exact_sat(self, **kwargs):
        """compute the satisfaction set on the entire state space"""
        inner_sat = self.subformula._exact_sat(**kwargs)
        # flip the satisfaction vector of the subformula
        return 1.0 - inner_sat

    def _local_sat(self, states, **kwargs):
        """compute the satisfaction set for the local set of states

        Returns the key handed back by the subformula together with the flipped vector.
        """
        key, inner_sat = self.subformula._local_sat(states, **kwargs)
        flipped = 1.0 - inner_sat
        return key, flipped

class And:
    """And: satisfied by each state that satisfies both subformulae

    Input attributes:
        subformula_1: the first conjunct
        subformula_2: the second conjunct
    """

    def __init__(self, subformula_1, subformula_2):
        self.subformula_1, self.subformula_2 = subformula_1, subformula_2

    def _bound(self):
        """the larger of the two subformula bounds"""
        first_bound = self.subformula_1._bound()
        second_bound = self.subformula_2._bound()
        return max(first_bound, second_bound)

    def _exact_sat(self, **kwargs):
        """compute the satisfaction set on the entire state space"""
        first_sat = self.subformula_1._exact_sat(**kwargs)
        second_sat = self.subformula_2._exact_sat(**kwargs)
        # conjunction as the elementwise product of the two satisfaction vectors
        return first_sat * second_sat

    def _local_sat(self, states, **kwargs):
        """compute the satisfaction set for the local set of states

        The key returned by the first subformula is handed to the second one,
        and the key returned by the second is returned to the caller.
        """
        key, first_sat = self.subformula_1._local_sat(states, **kwargs)
        kwargs['key'] = key
        key, second_sat = self.subformula_2._local_sat(states, **kwargs)
        return key, first_sat * second_sat
    
class Or:
    """Or: satisfied by each state that satisfies either of the subformulae

    Input attributes:
        subformula_1: the first disjunct
        subformula_2: the second disjunct
    """

    def __init__(self, subformula_1, subformula_2):
        self.subformula_1, self.subformula_2 = subformula_1, subformula_2

    def _bound(self):
        """the larger of the two subformula bounds"""
        first_bound = self.subformula_1._bound()
        second_bound = self.subformula_2._bound()
        return max(first_bound, second_bound)

    def _de_morgan(self):
        """rewrite Or via De Morgan as Neg(And(Neg(subformula_1), Neg(subformula_2)))"""
        neither = And(Neg(self.subformula_1), Neg(self.subformula_2))
        return Neg(neither)

    def _exact_sat(self, **kwargs):
        """compute the satisfaction set on the entire state space"""
        return self._de_morgan()._exact_sat(**kwargs)

    def _local_sat(self, states, **kwargs):
        """compute the satisfaction set for the local set of states"""
        return self._de_morgan()._local_sat(states, **kwargs)

class Implies:
    """Implies: satisfied by each state that satisfies subformula_2 when subformula_1 is satisfied

    Input attributes:
        subformula_1: the premise
        subformula_2: the conclusion
    """

    def __init__(self, subformula_1, subformula_2):
        self.subformula_1, self.subformula_2 = subformula_1, subformula_2

    def _bound(self):
        """the larger of the two subformula bounds"""
        premise_bound = self.subformula_1._bound()
        conclusion_bound = self.subformula_2._bound()
        return max(premise_bound, conclusion_bound)

    def _as_or(self):
        """rewrite the implication as Or(Neg(subformula_1), subformula_2)"""
        premise_fails = Neg(self.subformula_1)
        return Or(premise_fails, self.subformula_2)

    def _exact_sat(self, **kwargs):
        """compute the satisfaction set on the entire state space"""
        return self._as_or()._exact_sat(**kwargs)

    def _local_sat(self, states, **kwargs):
        """compute the satisfaction set for the local set of states"""
        return self._as_or()._local_sat(states, **kwargs)

class Exact_Model_Checker(PCTL_Model_Checker):
    """Model checker that evaluates a PCTL formula exactly from the transition matrices

    Input attributes:
        device: JAX device on which to do most of the computation
        formula: the PCTL formula to check
        labelling_fn: vectorized labelling function used for model checking
        atomic_predicate_map: dictionary mapping each atomic predicate to an index of the labelling_fn
    """

    def __init__(self, device, formula, labelling_fn, atomic_predicate_map):
        super().__init__(device, formula, labelling_fn, atomic_predicate_map)

    def check(self, key, policy, state_action_transition_matrix, state, action, **kwargs):
        """check if the formula is satisfied from the given state

        Inputs:
            key: returned unchanged alongside the result
            policy: the policy to check
            state_action_transition_matrix: the transition matrix to check
            state: the state to check from
            action: the conditional action to be played from the state, or None
            kwargs: any additional kwargs (accepted but not used by this method)

        Returns:
            the tuple (key, satisfaction value of the formula at state)
        """
        with jax.default_device(self.device):
            policy = to_jnp(policy)
            state = to_jnp(state)
            state_action_transition_matrix = to_jnp(state_action_transition_matrix)

            if action is None:
                # no conditional action: only the transition matrix under the policy is needed
                cond_action_matrix = None
                state_transition_matrix = compute_state_transition_matrix(state_action_transition_matrix, policy)
            else:
                action = to_jnp(action)
                state_transition_matrix, cond_action_matrix = compute_cond_action_matrix(
                    state_action_transition_matrix, policy, state, action)

            # everything the formula classes read from their kwargs during exact checking
            sat_kwargs = {
                'matrix': state_transition_matrix,
                'labelling_fn': self.labelling_fn,
                'atomic_predicate_map': self.atomic_predicate_map,
                'cond_action_matrix': cond_action_matrix,
            }

            satisfaction = self.formula._exact_sat(**sat_kwargs)
            return key, satisfaction[state]

class Monte_Carlo_Model_Checker(PCTL_Model_Checker):

    """Model checker that estimates satisfaction of a PCTL formula by sampling paths

    Input attributes:
        device: JAX device on which to do most of the computation
        formula: the PCTL formula to check
        labelling_fn: vectorized labelling function used for model checking
        atomic_predicate_map: dictionary mapping each atomic predicate to an index of the labelling_fn
    """

    def __init__(self, device, formula, labelling_fn, atomic_predicate_map):
        super().__init__(device, formula, labelling_fn, atomic_predicate_map)

    def check(self, key, policy, state_action_transition_matrix, state, action, num_samples=512, **kwargs):
        """check if the formula is satisfied from the given state

        Inputs:
            key: key handed to the formula for sampling
            policy: the policy to check
            state_action_transition_matrix: the transition matrix to check
            state: the state to check from
            action: the conditional action to be played from the state, or None
            num_samples: number of samples to take during model checking
            kwargs: any additional kwargs (accepted but not used by this method)

        Returns:
            whatever the formula's _local_sat returns for the single-element state batch
        """
        with jax.default_device(self.device):
            policy = to_jnp(policy)
            # the formula classes work on batches of states, so wrap the single state in a list
            state = to_jnp([state])
            state_action_transition_matrix = to_jnp(state_action_transition_matrix)
            if action is not None:
                action = to_jnp([action])
            matrix_sampler = Matrix_Sampler(state_action_transition_matrix, policy)

            # check whether a given state satisfies the pctl property
            sat_kwargs = {
                'model': matrix_sampler,
                'num_samples': num_samples,
                'labelling_fn': self.labelling_fn,
                'atomic_predicate_map': self.atomic_predicate_map,
                'cond_actions': action,
                'key': key,
            }
            return self.formula._local_sat(state, **sat_kwargs)

        
        

        
        




        