import numpy as np
from sklearn.decomposition import NMF


def compute_random_walk_transmat_sr(n_states: int, gamma: float):
    """Build a 1-D random-walk transition matrix and its successor representation.

    Each state gets weight 0.5 for staying put and 0.25 for each existing
    neighbour; rows are then normalised to sum to one, so the two end states
    (which have a single neighbour) end up with probabilities 2/3 and 1/3.

    Args:
        n_states: Number of states on the chain.
        gamma: Discount factor used in ``inv(I - gamma * T)``.

    Returns:
        A tuple ``(transmat, sr, sr_evals, sr_evecs)`` where ``transmat`` is
        the row-normalised transition matrix, ``sr`` is
        ``inv(I - gamma * transmat)``, ``sr_evals`` are the eigenvalues of
        ``sr`` multiplied by ``(1 - gamma)``, and ``sr_evecs`` are the
        matching eigenvectors as returned by ``np.linalg.eig``.
    """
    # Unnormalised weights: main diagonal plus the first super/sub-diagonals.
    weights = (
        0.5 * np.eye(n_states)
        + 0.25 * np.eye(n_states, k=1)
        + 0.25 * np.eye(n_states, k=-1)
    )
    row_totals = weights.sum(axis=-1, keepdims=True)
    transmat = weights / row_totals

    identity = np.eye(n_states)
    sr = np.linalg.inv(identity - gamma * transmat)

    eigenvalues, eigenvectors = np.linalg.eig(sr)
    # Rescale so an eigenvalue of 1 / (1 - gamma) maps to 1.
    scaled_eigenvalues = eigenvalues * (1 - gamma)
    return transmat, sr, scaled_eigenvalues, eigenvectors


def compute_random_policy_transmat(mdp):
    """Average the MDP's transition matrix over its second axis.

    Args:
        mdp: Object exposing ``get_transition_matrix()``. The returned array
            is presumably indexed ``[state, action, next_state]`` so that the
            mean over axis 1 corresponds to a uniform-random policy -- verify
            against the MDP implementation.

    Returns:
        The result of ``mdp.get_transition_matrix().mean(axis=1)``.
    """
    return mdp.get_transition_matrix().mean(axis=1)


def compute_sr_matrix(transmat_pi: np.ndarray, gamma: float = 0.99):
    """Return the successor representation ``inv(I - gamma * transmat_pi)``.

    Args:
        transmat_pi: Square transition matrix; its first dimension sets the
            size of the identity matrix.
        gamma: Discount factor applied to ``transmat_pi``.

    Returns:
        The matrix inverse of ``I - gamma * transmat_pi``.

    Raises:
        numpy.linalg.LinAlgError: If the matrix is singular or not square.
    """
    n_states = transmat_pi.shape[0]
    system = np.eye(n_states) - gamma * transmat_pi
    return np.linalg.inv(system)


def compute_policy_sr(
    env,
    q_agent,
    gamma_sr: float = 0.99,
    softmax: bool = False,
    softmax_temp: float = 1.0,
):
    """Compute the successor representation induced by an agent's Q-values.

    For every state a policy row is derived from the agent's Q-values, either
    greedily (ties share probability equally) or via a softmax. The policy is
    then used to marginalise the environment's transition tensor over actions,
    and the resulting state-to-state matrix is passed to ``compute_sr_matrix``.

    Args:
        env: Environment exposing ``num_states``, ``num_actions`` and
            ``_transition_matrix``. The transition tensor must broadcast
            against a ``(num_states, num_actions, 1)`` array; presumably it is
            indexed ``[state, action, next_state]`` -- verify against the
            environment.
        q_agent: Agent providing Q-values through ``q_values[s]`` or, when
            that lookup fails, through ``compute_q_values(s)``.
        gamma_sr: Discount factor forwarded to ``compute_sr_matrix``.
        softmax: If True use a softmax policy, otherwise a greedy one.
        softmax_temp: Temperature dividing the shifted Q-values in the
            softmax branch.

    Returns:
        The matrix returned by ``compute_sr_matrix`` for the policy-averaged
        transition matrix.
    """
    num_states, num_actions = env.num_states, env.num_actions
    # policy[s, a] is the probability of taking action a in state s.
    policy = np.zeros((num_states, num_actions))
    for s in range(num_states):
        try:
            q_values = q_agent.q_values[s]
        except Exception:
            # Best-effort fallback for agents without a usable Q-table entry.
            # ``Exception`` (rather than a bare ``except``) keeps
            # KeyboardInterrupt / SystemExit from being swallowed.
            q_values = q_agent.compute_q_values(s)
        if softmax:
            # Shift by the max for numerical stability. ``nanmax`` mirrors the
            # greedy branch so a NaN Q-value does not turn the whole row NaN.
            shifted = q_values - np.nanmax(q_values)
            exp_q = np.exp(shifted / softmax_temp)
            # NaN actions get zero probability, as they do in the greedy case.
            exp_q = np.where(np.isnan(exp_q), 0.0, exp_q)
            action_probs = exp_q / np.sum(exp_q)
        else:
            max_val = np.nanmax(q_values)
            # NaN never compares equal, so NaN actions are never selected.
            max_actions = np.where(q_values == max_val)[0]
            action_probs = np.zeros(num_actions)
            action_probs[max_actions] = 1.0 / len(max_actions)
        policy[s] = action_probs

    transmat = env._transition_matrix
    # Weight each action's transition row by its probability and sum over
    # the action axis to obtain a state-to-state matrix.
    transmat_pi = np.sum(policy[..., None] * transmat, axis=1)

    return compute_sr_matrix(transmat_pi, gamma=gamma_sr)


def compute_nmf_basis(C_pred, num_cells=32, alpha_W=0.0, lmd_l1=0.0, init='nndsvd', return_H=False):
    """Factorise ``C_pred`` with non-negative matrix factorisation.

    Negative entries of ``C_pred`` are clipped to zero before fitting. The
    model uses a fixed ``random_state=42`` and ``max_iter=5000``.

    Args:
        C_pred: Input matrix to factorise.
        num_cells: Number of NMF components.
        alpha_W: Passed to ``NMF`` as ``alpha_W``.
        lmd_l1: Passed to ``NMF`` as ``l1_ratio``.
        init: Initialisation method passed to ``NMF``.
        return_H: If True, also return the fitted ``components_`` matrix.

    Returns:
        ``W`` (the output of ``fit_transform``), or ``(W, H)`` with
        ``H = components_`` when ``return_H`` is True.
    """
    non_negative = np.clip(C_pred, 0, None)

    model = NMF(
        n_components=num_cells,
        init=init,
        random_state=42,
        max_iter=5000,
        alpha_W=alpha_W,
        l1_ratio=lmd_l1,
    )
    activations = model.fit_transform(non_negative)

    return (activations, model.components_) if return_H else activations