import numpy as np
from emdp.gridworld import build_simple_grid

from ergodic_mc import is_ergodic, is_irreducible, is_aperiodic

import sys
from pathlib import Path
parent_folder = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_folder))

from utils import save_dict

sys.path.remove(str(parent_folder))

########################################################################

def get_transition_matrix(Psas: np.ndarray, policy: np.ndarray) -> np.ndarray:
    """Collapse the action axis of an MDP transition tensor under a policy.

    Computes ``P[s, t] = sum_a policy[s, a] * Psas[s, a, t]``, the
    state-to-state matrix of the Markov chain obtained by following ``policy``.

    Args:
        Psas: 3-D transition tensor indexed ``Psas[state, action, next_state]``.
        policy: 2-D action weights indexed ``policy[state, action]``.

    Returns:
        2-D matrix indexed ``P[state, next_state]``.
    """
    # `a` (the action index) appears in both operands but not in the output,
    # so einsum sums over it.
    contraction = "sa,sat->st"
    return np.einsum(contraction, policy, Psas)

def generate_transition_probability_matrix(
    n: int,
    p_success: float = 0.99,
    *,
    rng: "np.random.Generator | None" = None,
) -> np.ndarray:
    """Build the Markov chain induced by a random stochastic policy on a grid world.

    A grid world is created with ``build_simple_grid``. A random policy is drawn
    (uniform weights per state/action, normalized per state), and the action
    axis is marginalized with ``get_transition_matrix``.

    Args:
        n: Grid size, forwarded to ``build_simple_grid`` as ``size``.
        p_success: Forwarded unchanged to ``build_simple_grid`` (presumably the
            probability that an action has its intended effect; see emdp).
        rng: Optional NumPy ``Generator`` used to draw the policy. When ``None``
            (the default), the legacy global ``np.random`` state is used, exactly
            as before. Pass e.g. ``np.random.default_rng(0)`` for reproducible
            output.

    Returns:
        Square matrix indexed ``P[state, next_state]``.
    """
    Psas = build_simple_grid(size=n, p_success=p_success)
    num_states, num_actions = Psas.shape[0], Psas.shape[1]
    shape = (num_states, num_actions)

    # Non-negative weights in [0, 1); each row is then normalized so that
    # policy[s, :] is a probability distribution over actions.
    if rng is None:
        weights = np.random.rand(*shape)  # legacy global-RNG path (unchanged)
    else:
        weights = rng.random(shape)
    policy = weights / weights.sum(axis=1, keepdims=True)

    return get_transition_matrix(Psas=Psas, policy=policy)


########################################################################


if __name__ == "__main__":
    n = 12
    p_success = 0.99
    # Relative to the current working directory, as before.
    out_path = "res/gridworld/env.json"

    print("Generating transition probabilities...")
    P = generate_transition_probability_matrix(n=n, p_success=p_success)
    print("Number of states: ", P.shape[0])
    print("Is Irreducible ? ", is_irreducible(P))
    print("Is Aperiodic ? ", is_aperiodic(P))
    print("Is Ergodic ? ", is_ergodic(P))

    # Nothing in this script creates the output directory. Ensure it exists so
    # the run does not fail at the very end on a fresh checkout.
    # exist_ok=True keeps this a no-op when the directory is already present.
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): P is a NumPy ndarray. Confirm that save_dict serializes
    # ndarrays to JSON (plain json.dump would reject them).
    save_dict(path=out_path, data=dict(P=P))