import numpy as np
from collections import defaultdict 
from model_checking import pctl
from .dynamics_models import Transition_Matrix, Tabular_Dynamics

from .jax_helpers import to_jnp

"""NumPy helper functions: product state-space construction and grid-world transition priors"""

def construct_product_state_space(n_states, n_automaton_states):
    """Enumerate the product (automaton x environment) state space.

    Returns an integer array of shape (n_automaton_states, n_states) whose entry
    [q, s] is the flat product index ``q * n_states + s`` (automaton-major order).
    """
    n_product = n_automaton_states * n_states
    flat_indices = np.arange(n_product)
    return flat_indices.reshape(n_automaton_states, n_states)

def compute_product_state_action_transition_matrix(state_action_transition_matrix, automaton, labelling_fn):
    """Lift an environment transition matrix onto the product with an automaton.

    The input is indexed ``[next_state, state, action]`` (the convention used by the
    grid priors in this module). The result has shape
    ``(n_q * n_states, n_q * n_states, n_actions)`` with product index
    ``q * n_states + s``, matching ``construct_product_state_space``.

    For automaton states q -> q' (q != q') with an edge, the env transition into
    successor s is copied into block [q', q] when the edge guard is satisfied by
    ``labelling_fn[s]``, i.e. the label of the *successor* state (axis 0). The
    automaton stays in q for successors that satisfy none of q's outgoing
    (non-self) edges; explicit self-loop edges are not consulted.
    """
    q_states = automaton.states
    n_states = state_action_transition_matrix.shape[0]
    n_actions = state_action_transition_matrix.shape[2]
    n_product = len(q_states) * n_states

    def span(q_index):
        # contiguous product-index range for automaton state number q_index
        return slice(q_index * n_states, (q_index + 1) * n_states)

    product = np.zeros((n_product, n_product, n_actions))

    for src, src_state in enumerate(q_states):
        # 1 for successor env states that leave the automaton in src_state
        stay_mask = np.ones(n_states)
        for dst, dst_state in enumerate(q_states):
            if dst == src or not automaton.has_edge(src_state, dst_state):
                continue
            guard = automaton.edges[src_state][dst_state]
            move_mask = np.array([guard.sat(labelling_fn[s]) for s in range(n_states)], dtype=int)
            # successors that trigger this edge can no longer count as a self-loop
            stay_mask = np.clip(stay_mask - move_mask, 0, 1)
            product[span(dst), span(src), :] = move_mask[:, np.newaxis, np.newaxis] * state_action_transition_matrix

        # whatever no outgoing edge claimed keeps the automaton in place
        product[span(src), span(src), :] = stay_mask[:, np.newaxis, np.newaxis] * state_action_transition_matrix

    return product

# (dy, dx) grid offset for each action index, shared by the grid-world priors below.
# States are numbered row-major: state = y * grid_size + x; moves are clipped at the border.
# NOTE(review): the cardinal labels treat +y as "up" while the diagonal labels treat -y as
# "up"; confirm both against the environment's own action map.
_GRID_ACTION_MAP = {
    0: (0, -1),  # left
    1: (0, 1),  # right
    2: (1, 0),  # up
    3: (-1, 0),  # down
    4: (0, 0),  # stay
    5: (-1, -1),  # left up
    6: (1, -1),  # left down (previously (-1, 1), a duplicate of action 7)
    7: (-1, 1),  # right up
    8: (1, 1),  # right down
}


def make_informative_prior(grid_size, n_states, n_actions):
    """Deterministic grid-world transition prior.

    Returns a float array ``prior`` of shape (n_states, n_states, n_actions) with
    ``prior[next_state, state, a] == 1.0`` exactly when action ``a`` (see
    ``_GRID_ACTION_MAP``) moves ``state`` to ``next_state``, and 0.0 elsewhere.
    Moves that would leave the grid are clipped to the border.

    Raises:
        AssertionError: if ``n_states != grid_size**2`` or ``n_actions`` exceeds
            the number of actions in ``_GRID_ACTION_MAP``.
    """
    assert n_states == grid_size**2, f"n_states ({n_states}) must equal grid_size**2"
    # <= (not <): all actions in the map, including the last one, may be used
    assert n_actions <= len(_GRID_ACTION_MAP), f"n_actions ({n_actions}) exceeds {len(_GRID_ACTION_MAP)} grid actions"

    prior = np.zeros((n_states, n_states, n_actions))
    states = np.arange(n_states)
    ys, xs = np.divmod(states, grid_size)  # inverse of state = y * grid_size + x

    for a in range(n_actions):
        dy, dx = _GRID_ACTION_MAP[a]
        next_ys = np.clip(ys + dy, 0, grid_size - 1)
        next_xs = np.clip(xs + dx, 0, grid_size - 1)
        prior[next_ys * grid_size + next_xs, states, a] = 1.0

    return prior


def make_uninformative_prior(grid_size, n_states, n_actions):
    """Action-agnostic grid-world transition prior.

    Every action gets the same support: ``prior[next_state, state, a] == 1.0``
    whenever *some* action in ``range(n_actions)`` moves ``state`` to
    ``next_state``, and 0.0 elsewhere. Shape is (n_states, n_states, n_actions).

    Raises:
        AssertionError: under the same conditions as ``make_informative_prior``.
    """
    informative = make_informative_prior(grid_size, n_states, n_actions)
    # union of per-action supports, replicated for every action
    reachable = informative.any(axis=2, keepdims=True)
    return np.repeat(reachable, n_actions, axis=2).astype(float)

"""helper functions for setting up the model checker and dynamics model"""

def init_model_checker(env, properties, device, model_checking_type='exact', satisfaction_probability=0.9, shielding_type='task'):
    """Build a PCTL model checker over the product (automaton x environment) state space.

    The only atomic predicate is 'accepting' (index 0). It holds exactly in the
    product states whose automaton component is in ``properties.automaton.accepting``.

    Args:
        env: environment; only ``env.n_states`` is read here.
        properties: object exposing ``automaton`` and ``product_pctl_property(prob=...)``.
        device: forwarded unchanged to the model checker constructor.
        model_checking_type: 'exact' -> ``pctl.Exact_Model_Checker``,
            'mc' -> ``pctl.Monte_Carlo_Model_Checker``.
        satisfaction_probability: threshold passed to ``product_pctl_property``.
        shielding_type: must be 'action_cond_safe' or 'task_prod'.
            NOTE(review): the default 'task' is not supported and raises.

    Returns:
        The constructed model checker.

    Raises:
        NotImplementedError: for an unsupported ``shielding_type`` or ``model_checking_type``.
    """
    if shielding_type not in ('action_cond_safe', 'task_prod'):
        raise NotImplementedError(f"unsupported shielding_type: {shielding_type!r}")
    if model_checking_type not in ('exact', 'mc'):
        raise NotImplementedError(f"unsupported model_checking_type: {model_checking_type!r}")

    automaton = properties.automaton
    prod_state_space = construct_product_state_space(env.n_states, len(automaton.states))

    # Product states paired with an accepting automaton state. Built once as a set so the
    # per-state membership test below is O(1) instead of re-slicing and scanning each time.
    accepting_states = set(prod_state_space[automaton.accepting, :].flatten().tolist())
    # Vectorized labelling function: one row (the 'accepting' predicate) over all product states.
    vec_labelling_fn = to_jnp([[1.0 if s in accepting_states else 0.0 for s in prod_state_space.flatten().tolist()]])
    atomic_predicate_map = {'accepting': 0}
    formula = properties.product_pctl_property(prob=satisfaction_probability)

    if model_checking_type == 'exact':
        return pctl.Exact_Model_Checker(device, formula, vec_labelling_fn, atomic_predicate_map)
    return pctl.Monte_Carlo_Model_Checker(device, formula, vec_labelling_fn, atomic_predicate_map)

def init_dynamics_model(env, automaton, approximate_model=False, shielding_type="task", prior_type="none"):
    """Create the dynamics model over the product (automaton x environment) state space.

    Args:
        env: environment providing ``n_states``, ``n_actions``, ``labelling_fn`` and
            either ``transition_matrix`` (exact model) or ``grid_size`` (grid priors).
        automaton: automaton whose states are paired with the environment states.
        approximate_model: if True, return a ``Tabular_Dynamics`` over
            ``env.n_states * len(automaton.states)`` states; otherwise a
            ``Transition_Matrix`` lifted from ``env.transition_matrix``.
        shielding_type: must be 'action_cond_safe' or 'task_prod'; anything else
            (including the default 'task') raises NotImplementedError.
        prior_type: approximate model only: "none", "uninformative" or "informative".

    Raises:
        NotImplementedError: for an unsupported ``prior_type`` or ``shielding_type``.
    """
    supports_product = shielding_type in ('action_cond_safe', 'task_prod')

    if not approximate_model:
        if not supports_product:
            raise NotImplementedError
        env_labels = env.labelling_fn
        env_transitions = env.transition_matrix
        # exact model: lift the environment's known transitions onto the product space
        return Transition_Matrix(
            compute_product_state_action_transition_matrix(env_transitions, automaton, env_labels)
        )

    # the prior (if any) is chosen before the shielding type is validated
    if prior_type == "none":
        prior = None
    elif prior_type in ("uninformative", "informative"):
        build_prior = make_uninformative_prior if prior_type == "uninformative" else make_informative_prior
        prior = build_prior(env.grid_size, env.n_states, env.n_actions)
    else:
        raise NotImplementedError

    if not supports_product:
        raise NotImplementedError

    n_product_states = env.n_states * len(automaton.states)
    product_prior = (
        None if prior is None
        else compute_product_state_action_transition_matrix(prior, automaton, env.labelling_fn)
    )
    return Tabular_Dynamics(n_product_states, env.n_actions, prior=product_prior)

    


    