from typing import Optional, List, Dict, Any
import numpy as np
from scipy.ndimage import gaussian_filter
from tqdm import trange

from general_utils import (
    compute_random_policy_transmat, 
    compute_sr_matrix, 
)


def intrinsic_reward_from_eigenvector(mdp, eigenvector: np.ndarray):
    """Build the intrinsic reward table induced by an eigenvector.

    For every (state, action) pair the reward is the change in the
    eigenvector's value along the transition:
    ``eigenvector[s'] - eigenvector[s]``, where ``s'`` is the first element
    of the 4-tuple returned by ``mdp.step(s, a)``.

    Args:
        mdp: Object exposing ``num_states``, ``num_actions`` and
            ``step(state, action)``.
        eigenvector: Per-state values, indexable by state.

    Returns:
        Float array of shape ``(mdp.num_states, mdp.num_actions)``.
    """
    n_states, n_actions = mdp.num_states, mdp.num_actions
    rewards = np.zeros((n_states, n_actions))

    # NOTE(review): ``mdp.step`` is queried exactly once per pair, so if it
    # samples stochastically the table reflects a single draw -- confirm.
    for state, act in np.ndindex(n_states, n_actions):
        successor, _, _, _ = mdp.step(state, act)
        rewards[state, act] = eigenvector[successor] - eigenvector[state]

    return rewards


def solve_option_specific_policy(
    mdp, 
    r_e: np.ndarray, 
    gamma: float = 0.99, 
    tol: float = 1e-6,
    max_iters: int = 5000, 
    soft_beta: bool = False,
    filter_sigma: float = 1.0,
):
    """Run value iteration on an intrinsic reward and derive an option.

    The backup is ``Q[s, a] = r_e[s, a] + gamma * V[s']`` with ``s'`` the
    first element returned by ``mdp.step(s, a)``. Iteration stops when the
    largest change in ``V`` drops below ``tol`` or after ``max_iters``
    sweeps. The option policy is uniform over the greedy actions of the last
    computed ``Q``; the option terminates wherever ``max_a Q[s, a] <= 0``.

    Args:
        mdp: Object exposing ``num_states``, ``num_actions`` and
            ``step(state, action)``.
        r_e: Reward table indexed as ``r_e[state, action]``.
        gamma: Discount factor used in the backup.
        tol: Convergence threshold on the sup-norm change of ``V``.
        max_iters: Maximum number of sweeps; must be at least 1.
        soft_beta: If true, smooth the 0/1 termination vector with a 1-D
            Gaussian filter over the state index.
        filter_sigma: Standard deviation passed to ``gaussian_filter``.

    Returns:
        ``(pi, beta)``: ``pi`` is a ``(num_states, num_actions)`` float array
        of action probabilities; ``beta`` is a boolean vector, or a float
        vector when ``soft_beta`` is set.

    Raises:
        ValueError: If ``max_iters`` is smaller than 1 (no ``Q`` would ever
            be computed).
    """
    if max_iters < 1:
        raise ValueError(f"max_iters must be at least 1, got {max_iters}")

    num_states, num_actions = mdp.num_states, mdp.num_actions
    rewards = np.asarray(r_e, dtype=float)[:num_states, :num_actions]

    # Query the model once instead of on every sweep.
    # NOTE(review): this assumes ``mdp.step(s, a)`` is a deterministic,
    # side-effect-free lookup, which the single-successor backup below
    # already presupposes -- confirm against the MDP implementation.
    next_states = np.empty((num_states, num_actions), dtype=int)
    for s in range(num_states):
        for a in range(num_actions):
            next_s, _, _, _ = mdp.step(s, a)
            next_states[s, a] = next_s

    V = np.zeros((num_states, ))
    for _ in range(max_iters):
        # Q is built from the previous V, so after the loop it lags the
        # final V by one sweep (same as a classic synchronous backup).
        Q = rewards + gamma * V[next_states]
        V_new = np.max(Q, axis=1)
        converged = np.max(np.abs(V_new - V)) < tol
        V = V_new
        if converged:
            break

    # Uniform distribution over all actions tied for the maximum Q-value.
    greedy = Q == np.max(Q, axis=1, keepdims=True)
    pi = greedy / greedy.sum(axis=1, keepdims=True)

    beta = np.max(Q, axis=1) <= 0.0
    if soft_beta:
        beta = gaussian_filter(beta.astype(float), sigma=filter_sigma)

    return pi, beta


def compute_option_transition_matrix(
    mdp, 
    pi_option: np.ndarray, 
    beta_option: np.ndarray, 
    max_steps: Optional[int] = None, 
):
    """Roll an option out once from every state and record where it ends.

    From each start state the option policy is followed (actions sampled
    with ``np.random.choice``) until a state with a truthy ``beta_option``
    entry is reached or ``max_steps`` primitive steps have been taken. A
    start state that is itself terminating maps to itself.

    Args:
        mdp: Object exposing ``num_states``, ``num_actions`` and
            ``step(state, action)``.
        pi_option: Per-state action probabilities, ``pi_option[state]``.
        beta_option: Per-state termination indicator; any truthy entry
            stops the roll-out.
        max_steps: Roll-out cap; defaults to ``4 * mdp.num_states``.

    Returns:
        ``(num_states, num_states)`` float array with a single 1.0 per row,
        in the column of the state where that row's roll-out stopped.
    """
    n_states = mdp.num_states
    horizon = 4 * n_states if max_steps is None else max_steps

    outcome = np.zeros((n_states, n_states))

    for start in range(n_states):
        current = start
        for _ in range(horizon):
            # Covers both "already terminal at the start" and "reached a
            # terminal state mid roll-out".
            if beta_option[current]:
                break
            primitive = np.random.choice(mdp.num_actions, p=pi_option[current])
            current, _, _, _ = mdp.step(current, primitive)

        # Only one sampled roll-out per start state is recorded.
        outcome[start, current] = 1.0

    return outcome


def build_augmented_transmat(
    transmat_pi: np.ndarray,
    transmats_option: List[np.ndarray], 
    alpha: float = 0.5, 
):
    """Mix a base transition matrix with the average of option matrices.

    The result is ``alpha * transmat_pi + (1 - alpha) * mean(options)``.
    With no options a copy of ``transmat_pi`` is returned unchanged.

    Args:
        transmat_pi: Base state-to-state transition matrix.
        transmats_option: Option transition matrices with the same shape as
            ``transmat_pi``.
        alpha: Weight given to ``transmat_pi``; the option average receives
            ``1 - alpha``.

    Returns:
        The mixed matrix, or a copy of ``transmat_pi`` when
        ``transmats_option`` is empty.
    """
    if len(transmats_option) == 0:
        return transmat_pi.copy()

    # Accumulate in a floating dtype: inheriting an integer dtype from
    # ``transmat_pi`` would make the in-place add/divide below fail.
    if np.issubdtype(transmat_pi.dtype, np.inexact):
        accum_dtype = transmat_pi.dtype
    else:
        accum_dtype = np.float64

    option_mean = np.zeros(transmat_pi.shape, dtype=accum_dtype)
    for transmat_option in transmats_option:
        option_mean += transmat_option
    option_mean /= len(transmats_option)

    return alpha * transmat_pi + (1 - alpha) * option_mean


def diffusion_times_per_goal(transmat_augmented: np.ndarray):
    """Compute, for every goal state, the mean expected hitting time.

    For a goal ``g`` the goal row and column are removed from the transition
    matrix and ``(I - P_sub) v = 1`` is solved; ``v`` holds the expected
    number of steps to reach ``g`` from each other state. The goal itself
    contributes 0, and the value reported for ``g`` is the mean over all
    states (goal included).

    Args:
        transmat_augmented: Square state-to-state transition matrix.

    Returns:
        1-D array with one entry per goal state.

    Raises:
        numpy.linalg.LinAlgError: Propagated from ``np.linalg.solve`` when a
            reduced system is singular.
    """
    n_states = transmat_augmented.shape[0]
    # Row ``g`` of this matrix is a boolean mask selecting every state but g.
    not_goal = ~np.eye(n_states, dtype=bool)

    def _mean_hitting_time(goal):
        keep = not_goal[goal]
        reduced = transmat_augmented[keep][:, keep]
        system = np.eye(n_states - 1) - reduced
        rhs = np.ones((n_states - 1, ))

        hitting = np.zeros(n_states)  # entry for the goal stays at 0
        hitting[keep] = np.linalg.solve(system, rhs)
        return hitting.mean()

    return np.array([_mean_hitting_time(goal) for goal in range(n_states)])


def compute_mean_median_diffusion(
    mdp, 
    transmats_option: List[np.ndarray], 
    alpha: float = 0.5, 
):
    """Summarise per-goal diffusion times for a set of options.

    The matrix returned by ``compute_random_policy_transmat(mdp)`` is mixed
    with the option matrices via ``build_augmented_transmat`` and the result
    is passed to ``diffusion_times_per_goal``.

    Args:
        mdp: Environment model forwarded to ``compute_random_policy_transmat``
            (presumably yielding the random-policy transition matrix --
            confirm in ``general_utils``).
        transmats_option: Option transition matrices; may be empty.
        alpha: Mixing weight forwarded to ``build_augmented_transmat``.

    Returns:
        ``(mean, median, per_goal)`` where ``per_goal`` is the array of
        diffusion times and the first two are its mean and median.
    """
    base_transmat = compute_random_policy_transmat(mdp)
    mixed_transmat = build_augmented_transmat(
        base_transmat,
        transmats_option,
        alpha=alpha,
    )
    per_goal = diffusion_times_per_goal(mixed_transmat)

    summary = (np.mean(per_goal), np.median(per_goal), per_goal)
    return summary


def compute_eigenoptions(
    transmat_pi: np.ndarray, 
    k: Optional[int] = None, 
    gamma: float = 0.99, 
):
    """Eigendecompose the matrix from ``compute_sr_matrix``, sorted descending.

    Args:
        transmat_pi: Transition matrix forwarded to ``compute_sr_matrix``.
        k: If given, keep only the first ``k`` eigenpairs after sorting.
        gamma: Discount forwarded to ``compute_sr_matrix``.

    Returns:
        ``(eigenvalues, eigenvectors)`` with eigenvectors stored as columns,
        ordered from largest to smallest eigenvalue according to
        ``np.argsort``.
    """
    sr_matrix = compute_sr_matrix(transmat_pi, gamma=gamma)

    # ``np.linalg.eig`` makes no symmetry assumption, so it may return a
    # complex dtype.
    eigenvalues, eigenvectors = np.linalg.eig(sr_matrix)

    descending = np.argsort(eigenvalues)[::-1]
    eigenvalues = eigenvalues[descending]
    eigenvectors = eigenvectors[:, descending]

    if k is None:
        return eigenvalues, eigenvectors
    return eigenvalues[:k], eigenvectors[:, :k]


def compute_eigenoptions_diffusion_time(
    mdp, 
    max_K: int = 30, 
    alpha: float = 0.5, 
    gamma_eigenoptions: float = 0.99, 
    tol: float = 1e-6, 
):
    """Track diffusion time as eigenoptions are added one at a time.

    Eigenvectors come from ``compute_eigenoptions`` on the random-policy
    transition matrix (called with its default ``gamma``). Columns
    ``1 .. max_K`` are used in order -- column 0 is skipped -- and each is
    turned into an option whose transition matrix is appended to a growing
    list. After each addition the mean/median diffusion time is recomputed.

    Args:
        mdp: Environment model passed through to the helper functions.
        max_K: Number of eigenoptions to add.
        alpha: Mixing weight forwarded to ``compute_mean_median_diffusion``.
        gamma_eigenoptions: Discount passed to
            ``solve_option_specific_policy`` (not to ``compute_eigenoptions``).
        tol: Convergence tolerance for ``solve_option_specific_policy``.

    Returns:
        ``((baseline_mean, baseline_median), means, medians, per_goal)``
        where the baseline uses no options and the three lists each have one
        entry per added option.

    Raises:
        IndexError: If ``max_K`` is not smaller than the number of
            eigenvector columns available.
    """
    random_walk = compute_random_policy_transmat(mdp)
    _, eigenvectors = compute_eigenoptions(random_walk)

    # Reference point: diffusion under the base matrix without any option.
    base_mean, base_median, _ = compute_mean_median_diffusion(
        mdp, 
        transmats_option=[], 
        alpha=alpha, 
    )

    option_matrices = []
    means, medians, per_goal_times = [], [], []

    for index in range(1, max_K + 1):
        reward_table = intrinsic_reward_from_eigenvector(mdp, eigenvectors[:, index])
        policy, termination = solve_option_specific_policy(
            mdp, 
            reward_table, 
            gamma=gamma_eigenoptions, 
            tol=tol, 
        )
        option_matrices.append(
            compute_option_transition_matrix(mdp, policy, termination)
        )

        # Evaluate with every option discovered so far.
        mean_k, median_k, times_k = compute_mean_median_diffusion(
            mdp, 
            transmats_option=option_matrices, 
            alpha=alpha, 
        )
        means.append(mean_k)
        medians.append(median_k)
        per_goal_times.append(times_k)

    return (base_mean, base_median), means, medians, per_goal_times


def execute_action_or_option(
    env, 
    agent, 
    state: int, 
    action: int, 
    options: List[Any], 
    num_primitive_actions: int,
    max_steps_remaining: int, 
):
    """Execute one primitive action or run an option until it stops.

    Actions below ``num_primitive_actions`` are single environment steps.
    Larger indices select ``options[action - num_primitive_actions]``, which
    is followed until the environment reports ``done``, the step budget is
    used up, or a termination draw against ``beta`` at the newly reached
    state succeeds. Option rewards are accumulated with discount
    ``agent.gamma``.

    Args:
        env: Object exposing ``step(state, action)`` returning a 4-tuple
            whose first three items are next state, reward and done flag.
        agent: Object exposing ``gamma``.
        state: State the action is started from.
        action: Index into primitives followed by options.
        options: Dicts with keys ``'pi'`` and ``'beta'``.
        num_primitive_actions: Number of primitive actions.
        max_steps_remaining: Cap on primitive steps taken by an option.

    Returns:
        ``(end_state, reward, duration, done)``. For options ``reward`` is
        the discounted sum and ``duration`` is at least 1 even when no step
        was taken.
    """
    gamma = agent.gamma

    if action < num_primitive_actions:
        successor, reward, finished, _ = env.step(state, action)
        return successor, reward, 1, finished

    chosen = options[action - num_primitive_actions]
    policy, termination = chosen['pi'], chosen['beta']

    current = state
    discounted_return = 0.0
    weight = 1.0
    elapsed = 0
    finished = False

    while elapsed < max_steps_remaining and not finished:
        primitive = np.random.choice(num_primitive_actions, p=policy[current])
        current, reward, finished, _ = env.step(current, primitive)

        discounted_return += weight * reward
        weight *= gamma
        elapsed += 1

        # Termination is tested at the state just entered, never at the
        # state the option was started from.
        if np.random.rand() < termination[current]:
            break

    return current, discounted_return, max(elapsed, 1), finished


def precompute_eigenoptions(
    mdp, 
    max_K: int, 
    gamma_sr: float = 0.99, 
    gamma_intrinsic: float = 0.9, 
    soft_beta: bool = False,
    filter_sigma: float = 1.0,
    soft_beta_alpha: float = 10.0,
    soft_beta_tau: float = 1.0,
):
    """Build ``max_K`` eigenoptions from the random-policy eigenvectors.

    Eigenvector columns ``1 .. max_K`` (column 0 is skipped) are each turned
    into an intrinsic reward and solved with
    ``solve_option_specific_policy``. Progress is shown through ``trange``.

    Args:
        mdp: Environment model passed through to the helper functions.
        max_K: Number of options to build.
        gamma_sr: Discount forwarded to ``compute_eigenoptions``.
        gamma_intrinsic: Discount used when solving each option's policy.
        soft_beta: Forwarded to ``solve_option_specific_policy``.
        filter_sigma: Forwarded to ``solve_option_specific_policy``.
        soft_beta_alpha: Currently unused by this function.
        soft_beta_tau: Currently unused by this function.

    Returns:
        List of dicts with keys ``'pi'`` and ``'beta'``, one per option.

    Raises:
        IndexError: If ``max_K`` is not smaller than the number of
            eigenvector columns available.
    """
    random_walk = compute_random_policy_transmat(mdp)
    _, eigenvectors = compute_eigenoptions(random_walk, gamma=gamma_sr)

    built = []
    for index in trange(1, max_K + 1):
        reward_table = intrinsic_reward_from_eigenvector(mdp, eigenvectors[:, index])
        policy, termination = solve_option_specific_policy(
            mdp, 
            reward_table, 
            gamma=gamma_intrinsic, 
            soft_beta=soft_beta, 
            filter_sigma=filter_sigma, 
        )
        built.append({'pi': policy, 'beta': termination})

    return built


def compute_option_specific_primitive_transmat_sr(
    mdp, 
    option: Dict[str, np.ndarray], 
    gamma_sr: float = 0.99, 
):
    """Compute an option's continuing transition matrix and its SR.

    The option policy is averaged over ``mdp._transition_matrix`` (summing
    out axis 1; presumably a ``(state, action, next_state)`` tensor --
    confirm), then every column ``s'`` is scaled by ``1 - beta[s']`` so that
    only transitions into non-terminating states remain. The successor
    representation is ``inv(I - gamma_sr * P)`` for that masked matrix.

    Args:
        mdp: Object exposing ``num_states`` and ``_transition_matrix``.
        option: Dict with ``'pi'`` (per-state action probabilities) and
            ``'beta'`` (per-state termination values).
        gamma_sr: Discount used in the successor representation.

    Returns:
        ``(masked_transmat, successor_representation)``, both
        ``(num_states, num_states)``.
    """
    policy, termination = option['pi'], option['beta']
    n_states = mdp.num_states

    # Expected next-state distribution under the option's policy.
    weighted = policy[..., None] * mdp._transition_matrix
    continuing = weighted.sum(axis=1)

    # Drop probability mass that lands in states where the option stops.
    keep_going = 1 - termination
    continuing *= keep_going[None, :]

    resolvent = np.eye(n_states) - gamma_sr * continuing
    successor = np.linalg.inv(resolvent)

    return continuing, successor


def compute_primitive_option_transmat_sr(
    mdp, 
    action: int, 
    gamma_sr: float = 0.99,
):
    """Treat a primitive action as a never-terminating option.

    Builds a policy that picks ``action`` with probability 1 in every state
    and an all-zero termination vector, then delegates to
    ``compute_option_specific_primitive_transmat_sr``.

    Args:
        mdp: Object exposing ``num_states`` and ``num_actions``.
        action: Index of the primitive action.
        gamma_sr: Discount forwarded to the delegate.

    Returns:
        Whatever ``compute_option_specific_primitive_transmat_sr`` returns
        for that option (a transition matrix and its SR).
    """
    n_states, n_actions = mdp.num_states, mdp.num_actions

    always_action = np.zeros((n_states, n_actions))
    always_action[:, action] = 1.0
    never_stop = np.zeros((n_states, ))

    as_option = {'pi': always_action, 'beta': never_stop}
    return compute_option_specific_primitive_transmat_sr(
        mdp, 
        as_option, 
        gamma_sr=gamma_sr, 
    )


def compute_hierarchical_sr(
    mdp, 
    agent, 
    options: List[Dict[str, np.ndarray]],
    gamma_sr: float = 0.99,
    high_level_uniform: bool = False, 
):
    """Combine per-action/option SRs under the agent's high-level policy.

    Each macro-action (primitives first, then options) contributes

    * an occupancy matrix: the identity for a primitive, the option's SR
      from ``compute_option_specific_primitive_transmat_sr`` for an option;
    * a continuation kernel: ``gamma_sr * P_a`` for a primitive and
      ``gamma_sr * (SR_o @ diag(beta_o))`` for an option.

    Both are averaged row-wise with the high-level policy into ``B_mu`` and
    ``G_mu``, and the result is the solution ``X`` of
    ``(I - G_mu) X = B_mu``.

    Args:
        mdp: Object exposing ``num_states`` and ``num_actions``.
        agent: Object exposing ``num_actions`` and, unless
            ``high_level_uniform`` is set, ``q_values``, ``softmax`` and
            ``softmax_temp``.
        options: Dicts with keys ``'pi'`` and ``'beta'``.
        gamma_sr: Discount used for every SR and continuation kernel.
        high_level_uniform: Use a uniform high-level policy instead of one
            derived from ``agent.q_values``.

    Returns:
        ``(num_states, num_states)`` array.

    Raises:
        AssertionError: If ``agent.num_actions`` differs from the number of
            primitives plus the number of options.
    """
    num_states, num_primitive_actions = mdp.num_states, mdp.num_actions
    num_options = len(options)

    assert agent.num_actions == num_primitive_actions + num_options

    identity = np.eye(num_states)
    occupancy_list = []     # within-execution state occupancy per macro-action
    continuation_list = []  # discounted hand-over kernel per macro-action

    for a in range(num_primitive_actions):
        primitive_transmat, _ = compute_primitive_option_transmat_sr(
            mdp, 
            a, 
            gamma_sr=gamma_sr, 
        )
        # A primitive lasts one step: it only occupies its start state.
        occupancy_list.append(identity)
        continuation_list.append(gamma_sr * primitive_transmat)

    for option in options:
        _, option_sr = compute_option_specific_primitive_transmat_sr(
            mdp, 
            option, 
            gamma_sr=gamma_sr, 
        )
        occupancy_list.append(option_sr)
        continuation_list.append(gamma_sr * (option_sr @ np.diag(option['beta'])))

    if high_level_uniform:
        high_level_policy = np.ones((num_states, agent.num_actions)) / agent.num_actions
    else:
        high_level_policy = np.zeros((num_states, agent.num_actions))
        for s in range(num_states):
            q_values = agent.q_values[s]
            if agent.softmax:
                # Shift by the maximum before exponentiating: the softmax is
                # unchanged, but large Q / temperature ratios no longer
                # overflow to inf and produce NaN probabilities.
                scaled_q = np.asarray(q_values) / agent.softmax_temp
                exp_q = np.exp(scaled_q - np.max(scaled_q))
                high_level_policy[s] = exp_q / np.sum(exp_q)
            else:
                max_q = np.max(q_values)
                greedy_actions = np.where(q_values == max_q)[0]
                high_level_policy[s, greedy_actions] = 1.0 / len(greedy_actions)

    # Policy-weighted averages; row s of each macro-action's matrices is
    # scaled by the probability of picking that macro-action in s.
    B_mu = np.zeros((num_states, num_states))
    G_mu = np.zeros((num_states, num_states))
    for a in range(agent.num_actions):
        action_weight = high_level_policy[:, a][:, None]
        B_mu += action_weight * occupancy_list[a]
        G_mu += action_weight * continuation_list[a]

    # Solve the linear system directly rather than forming an explicit
    # inverse and multiplying.
    hierarchical_sr = np.linalg.solve(identity - G_mu, B_mu)

    return hierarchical_sr


def execute_macro_action_with_trajectory(
    env,
    state,
    action,
    options,
    num_primitive_actions,
    max_steps_remaining=1000,
):
    """Execute a primitive action or an option and record visited states.

    Actions below ``num_primitive_actions`` are single environment steps.
    Larger indices select ``options[action - num_primitive_actions]``, which
    is followed until the environment reports ``done``, the step budget is
    used up, or a termination draw against ``beta`` at the newly reached
    state succeeds. Rewards are summed without discounting.

    Args:
        env: Object exposing ``step(state, action)`` returning a 4-tuple
            whose first three items are next state, reward and done flag.
        state: State the action is started from.
        action: Index into primitives followed by options.
        options: Dicts with keys ``"pi"`` and ``"beta"``.
        num_primitive_actions: Number of primitive actions.
        max_steps_remaining: Cap on primitive steps taken by an option.

    Returns:
        ``(end_state, total_reward, num_steps, done, trajectory)`` where
        ``trajectory`` lists every state entered (the start state is not
        included) and ``num_steps`` is 0 if an option took no step.
    """
    if action < num_primitive_actions:
        successor, reward, finished, _ = env.step(state, action)
        return successor, float(reward), 1, bool(finished), [successor]

    macro = options[action - num_primitive_actions]

    visited = []
    reward_sum = 0.0
    n_steps = 0
    current = state
    finished = False

    while (n_steps < max_steps_remaining) and (not finished):
        action_probs = macro["pi"][current]
        primitive = int(np.random.choice(len(action_probs), p=action_probs))

        current, reward, finished, _ = env.step(current, primitive)
        reward_sum += float(reward)
        visited.append(current)
        n_steps += 1

        # Termination is sampled at the state just entered.
        stop_prob = float(macro["beta"][current])
        if np.random.rand() < stop_prob:
            break

    return current, reward_sum, n_steps, bool(finished), visited
