import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components, floyd_warshall, shortest_path

import sys
from pathlib import Path
parent_folder = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_folder))

from utils import save_dict

sys.path.remove(str(parent_folder))

##############################################################################

def is_stochastic(P:np.ndarray)->bool:
    """Return True if ``P`` is a row-stochastic (transition) matrix.

    A row-stochastic matrix is square, has only non-negative entries, and every
    row sums to 1 (up to ``np.allclose`` default tolerances).

    Args:
        P: Candidate transition matrix.

    Returns:
        True if all three conditions hold, False otherwise (including for
        non-2-D or non-square input).
    """
    P = np.asarray(P)
    if P.ndim != 2 or P.shape[0] != P.shape[1]:
        return False
    # Row sums alone are not enough: e.g. [1.5, -0.5] sums to 1 but is not a distribution.
    return bool(np.all(P >= 0) and np.allclose(P.sum(axis=1), 1))

def is_irreducible(P:np.ndarray)-> bool:
    """Check whether the transition graph of ``P`` is strongly connected.

    An edge i -> j is present whenever ``P[i, j] > 0``. The chain is irreducible
    exactly when this directed graph has a single strongly connected component.
    """
    adjacency = csr_matrix(P > 0)
    num_strong_components = connected_components(
        adjacency,
        directed=True,
        connection="strong",
        return_labels=False,
    )
    return num_strong_components == 1

def is_aperiodic(P: np.ndarray) -> bool:
    """Return True if every state of the chain with transition matrix ``P`` has period 1.

    An edge i -> j exists iff ``P[i, j] > 0``. The period of a state is the gcd of
    the lengths of all closed walks through it. All states in the same strongly
    connected component (SCC) share one period, and cycles never leave an SCC, so
    the period is computed once per SCC.

    Algorithm (per SCC): take BFS levels ``L`` from one root state. The period
    equals the gcd, over all edges (u, v) inside the SCC, of ``L[u] + 1 - L[v]``.
    Summing these terms along any closed walk telescopes to the walk's length.
    This runs in O(n + m) per root instead of the O(n^3) of all-pairs distances.

    A state that can never return to itself has no period. Such a state is a
    singleton SCC without a self-loop, and it makes the function return False.

    Args:
        P: Square transition (or adjacency) matrix.

    Returns:
        True if all states have period 1, else False. An empty matrix is
        vacuously aperiodic.
    """
    adjacency = np.asarray(P) > 0
    if adjacency.shape[0] == 0:
        return True
    graph = csr_matrix(adjacency)
    n_components, labels = connected_components(
        csgraph=graph, directed=True, connection="strong", return_labels=True
    )
    # roots[c] = first state carrying SCC label c (labels are 0..n_components-1).
    _, roots = np.unique(labels, return_index=True)
    # levels[c, s] = BFS depth of state s from the root of SCC c. This is finite
    # for every s in that SCC, because paths between two states of one SCC stay inside it.
    levels = shortest_path(graph, method="D", directed=True, unweighted=True, indices=roots)
    src, dst = np.nonzero(adjacency)
    inside = labels[src] == labels[dst]  # edges between SCCs lie on no cycle
    src, dst = src[inside], dst[inside]
    comp = labels[src]
    offsets = (levels[comp, src] + 1 - levels[comp, dst]).astype(np.int64)
    # gcd(0, x) == x, so SCCs with no internal edge keep period 0 (i.e. "no period").
    periods = np.zeros(n_components, dtype=np.int64)
    np.gcd.at(periods, comp, offsets)
    return bool(np.all(periods == 1))

def is_ergodic(P:np.ndarray)->bool:
    """Check whether ``P`` defines an irreducible and aperiodic Markov chain.

    Args:
        P: Transition matrix to test.

    Returns:
        ``is_irreducible(P) and is_aperiodic(P)``.

    Raises:
        ValueError: if ``P`` does not pass ``is_stochastic``.
    """
    if is_stochastic(P):
        # Short-circuit: aperiodicity is only evaluated for irreducible chains.
        return is_irreducible(P) and is_aperiodic(P)
    raise ValueError("The input matrix is not a stochastic matrix.")
  
####################################################################


def generate_c_graph_with_random_nb_neighbors(num_nodes: int, proba_max_add_nodes: float = 0.5, random_proba_add_nodes: bool = True) -> np.ndarray:
    """Draw a random directed 0/1 adjacency matrix containing the path 0 -> 1 -> ... -> n-1.

    Each node ``i`` additionally gets an out-edge to every node ``j`` (itself
    included) independently with probability ``q_i``. Here ``q_i`` is drawn from
    U(0, proba_max_add_nodes) when ``random_proba_add_nodes`` is True, and equals
    ``proba_max_add_nodes`` otherwise. If the last node ends up with no out-edge,
    one random out-edge is added to it (possibly a self-loop). As a result, every
    row has at least one non-zero entry. Randomness comes from the global
    ``np.random`` state.

    Args:
        num_nodes: Number of nodes (>= 2).
        proba_max_add_nodes: Upper bound (or fixed value) of the extra-edge probability, in [0, 1].
        random_proba_add_nodes: Whether each node draws its own extra-edge probability.

    Returns:
        ``(num_nodes, num_nodes)`` integer adjacency matrix.

    Raises:
        ValueError: if ``num_nodes < 2`` or ``proba_max_add_nodes`` is outside [0, 1].
    """
    if num_nodes < 2:
        raise ValueError("Number of nodes must be at least 2.")
    if not (0 <= proba_max_add_nodes <= 1):
        raise ValueError("proba_max_add_nodes must be between 0 and 1.")
    if random_proba_add_nodes:
        edge_proba = np.random.uniform(0, proba_max_add_nodes, num_nodes)
    else:
        edge_proba = np.full(num_nodes, proba_max_add_nodes)
    # Backbone path i -> i+1 lives on the first superdiagonal.
    backbone = np.eye(num_nodes, k=1, dtype=bool)
    # One uniform draw per (row, column), taken row by row from the RNG stream.
    extra_edges = np.random.rand(num_nodes, num_nodes) < edge_proba[:, np.newaxis]
    adjacency = (backbone | extra_edges).astype(int)
    # The last node has no backbone out-edge; give it one if none was drawn.
    if not adjacency[-1].any():
        adjacency[-1, np.random.randint(num_nodes)] = 1
    return adjacency


def generate_aperiodic_graph_with_random_nb_neighbors(num_nodes:int, proba_max_add_nodes:float=0.5, random_proba_add_nodes:bool=True, verbose:bool=True)->np.ndarray:
    """Rejection-sample random graphs until one is irreducible and aperiodic.

    Candidates come from ``generate_c_graph_with_random_nb_neighbors`` with the
    same arguments. Sampling repeats until ``is_irreducible`` and ``is_aperiodic``
    both hold.

    Args:
        num_nodes: Number of nodes (>= 2).
        proba_max_add_nodes: Extra-edge probability bound, in (0, 1].
        random_proba_add_nodes: Whether each node draws its own extra-edge probability.
        verbose: If True, print a running count of rejected candidates.

    Returns:
        Integer adjacency matrix of an irreducible, aperiodic directed graph.

    Raises:
        ValueError: for arguments rejected by the candidate generator, or when
            ``proba_max_add_nodes == 0``, for which no candidate can ever qualify.
    """
    graph = generate_c_graph_with_random_nb_neighbors(num_nodes=num_nodes, proba_max_add_nodes=proba_max_add_nodes, random_proba_add_nodes=random_proba_add_nodes)
    if proba_max_add_nodes == 0:
        # Without extra edges a candidate is the path 0 -> ... -> n-1 plus one out-edge
        # of the last node. It is irreducible only when that edge closes the n-cycle
        # (period n >= 2), so the loop below would never terminate.
        raise ValueError("proba_max_add_nodes must be > 0 to obtain an aperiodic irreducible graph.")
    i = 1
    # Irreducibility is the cheaper test and is checked first, so the aperiodicity
    # test only ever runs on strongly connected candidates.
    while not (is_irreducible(graph) and is_aperiodic(graph)):
        if verbose:
            print("\rTry {}".format(i), end="")
        graph = generate_c_graph_with_random_nb_neighbors(num_nodes=num_nodes, proba_max_add_nodes=proba_max_add_nodes, random_proba_add_nodes=random_proba_add_nodes)
        i += 1
    if verbose:
        print("\n")
    return graph
        
####################################################################

def create_markov_chain_from_graph(graph: np.ndarray, proba_min_one_transition:float=0.8, random_proba_min_one_transition:bool=True) -> np.ndarray:
    """Build a row-stochastic transition matrix supported on the edges of ``graph``.

    For each state ``i``, the probability mass is spread over the states ``j``
    with ``graph[i, j] != 0``. All other entries of row ``i`` are 0. A state with
    a single successor moves there with probability 1. Randomness comes from the
    global ``np.random`` state.

    Args:
        graph: Square adjacency matrix; any non-zero entry is an edge.
        proba_min_one_transition: If None, each row's weights are i.i.d. uniform
            draws normalised to sum to 1. Otherwise a randomly chosen successor
            receives a dominant probability, and the remaining mass is split
            randomly among the other successors.
        random_proba_min_one_transition: If True, the dominant probability is
            drawn from ``np.random.uniform(low=proba_min_one_transition)``, i.e.
            in [proba_min_one_transition, 1). Otherwise it equals
            ``proba_min_one_transition`` exactly.

    Returns:
        ``float32`` matrix with the same shape as ``graph`` whose rows sum to 1.

    Raises:
        ValueError: if ``graph`` is not a square 2-D matrix, if a state has no
            outgoing edge, or if ``proba_min_one_transition`` is outside [0, 1].

    Note:
        With ``proba_min_one_transition=1`` and
        ``random_proba_min_one_transition=False``, the other successors get
        probability 0. The resulting chain is then supported on fewer edges than
        ``graph``.
    """
    graph = np.asarray(graph)
    if graph.ndim != 2 or graph.shape[0] != graph.shape[1]:
        raise ValueError("The input adjacency matrix must be square.")
    if proba_min_one_transition is not None and not (0 <= proba_min_one_transition <= 1):
        raise ValueError("proba_min_one_transition must be between 0 and 1.")
    num_nodes = graph.shape[0]
    P = np.zeros((num_nodes, num_nodes), dtype=np.float32)
    for i in range(num_nodes):
        connected_nodes = np.flatnonzero(graph[i])
        n_successors = connected_nodes.size
        if n_successors == 0:
            raise ValueError(f"State {i} has no outgoing edge; its row cannot be stochastic.")
        if n_successors == 1:
            P[i, connected_nodes[0]] = 1.0
            continue
        if proba_min_one_transition is None:
            proba_transition = np.random.rand(n_successors)
        else:
            # Work in float64 so a dominant probability close to 1 does not round to
            # exactly 1 and silently give the other successors probability 0.
            proba_transition = np.empty(n_successors, dtype=np.float64)
            if random_proba_min_one_transition:
                proba_transition[0] = np.random.uniform(low=proba_min_one_transition)
            else:
                proba_transition[0] = proba_min_one_transition
            others = np.random.uniform(size=n_successors - 1)
            proba_transition[1:] = (1 - proba_transition[0]) * others / others.sum()
        proba_transition /= proba_transition.sum()
        # Shuffle so the dominant transition goes to a random successor, not the first.
        np.random.shuffle(proba_transition)
        P[i, connected_nodes] = proba_transition
    return P

def create_ergodic_mc_with_random_nb_neighbors_states(num_states:int, proba_max_add_states:float=0.5, random_proba_add_states:bool=True,
                                proba_min_one_transition:float=0.8, random_proba_min_one_transition:bool=True,  verbose:bool=True) -> np.ndarray:
    """Sample a random transition matrix on an irreducible, aperiodic support graph.

    First draws the support with ``generate_aperiodic_graph_with_random_nb_neighbors``,
    using the ``*_states`` arguments as its ``*_nodes`` counterparts. Then assigns
    transition probabilities to its edges with ``create_markov_chain_from_graph``.

    Returns:
        ``(num_states, num_states)`` transition matrix.
    """
    support_graph = generate_aperiodic_graph_with_random_nb_neighbors(
        num_nodes=num_states,
        proba_max_add_nodes=proba_max_add_states,
        random_proba_add_nodes=random_proba_add_states,
        verbose=verbose,
    )
    transition_matrix = create_markov_chain_from_graph(
        graph=support_graph,
        proba_min_one_transition=proba_min_one_transition,
        random_proba_min_one_transition=random_proba_min_one_transition,
    )
    return transition_matrix
        
####################################################################
    
if __name__ == "__main__":
    # Demo: sample a 500-state chain, report its structural properties and save it.
    num_states = 500
    proba_max_add_states = 0.5
    # None -> rows with several successors get i.i.d. uniform weights, normalised to 1.
    proba_min_one_transition = None
    P = create_ergodic_mc_with_random_nb_neighbors_states(
        num_states=num_states,
        proba_max_add_states=proba_max_add_states,
        random_proba_add_states=True,
        proba_min_one_transition=proba_min_one_transition,
    )
    print("Transition Probability Matrix generated: \n", P)
    print("\nGraph: \n", (P > 0).astype(int))
    print("\nIs irreducible ? ", is_irreducible(P=P))
    print("Is aperiodic ? ", is_aperiodic(P=P))
    print("Is ergodic ? ", is_ergodic(P=P))
    # NOTE(review): P is a NumPy float32 array; confirm utils.save_dict can serialise it to JSON.
    save_dict(path="env.json", data=dict(P=P))