import numpy as np
from models.policies import (
    TabularPolicy,
    generate_random_tabular_policies,
    generate_random_tabular_policies_vectorized,
)
from utils.misc_helpers import rollout_policy_in_env
import itertools
import time


###### Confidence set generation (offline-specific variants) ######
## main confset generation function ##
def calc_offlineradius(
    offline_trajs,
    N_states,
    N_actions,
    episode_length,
    delta,
    formula_version="full",
    aux_input=None,
    verbose=[],
):
    """
    Calculates the offline radius for the confidence set when using behavioral cloning.
    Formula: alpha/sqrt(N) + beta/sqrt(N) * offlineradius_outerbracket. Alpha and beta are MDP-specific,
    and the bracket term represents uncertainty in the transition estimate.

    Each trajectory is read as a flat sequence in which index ``3*t`` holds the state and
    index ``3*t + 1`` holds the action of timestep ``t``; only the first ``episode_length``
    timesteps are read. Entries that are not integers in ``[0, N_states)`` / ``[0, N_actions)``
    do not contribute to the visitation estimate.

    Args:
        offline_trajs: List of offline trajectories (must be non-empty)
        N_states: Number of states
        N_actions: Number of actions
        episode_length: Length of episodes
        delta: Confidence parameter for offline radius
        formula_version: Version of the formula to use
        aux_input: Auxiliary input for the formula (required by the 'hardcode_*' versions)
        verbose: Verbosity options ("full", "radius-calc", "offline-confset" trigger prints)

    Returns:
        Offline radius

    Raises:
        ValueError: if formula_version is unknown, if offline_trajs is empty, if a
            'hardcode_*' version is requested without aux_input, or if no in-range
            (state, action) pair was visited (gamma_min undefined).

    Formula versions:
    - 'full': full formula -> alpha/sqrt(N) + beta/sqrt(N) * outer_bracket_term
    - 'ignore_bracket': ignore the bracket term (set=1) -> (alpha+beta)/sqrt(N)
    - 'only_alpha': only use alpha term -> alpha/sqrt(N)
    - 'hardcode_radius_scaled': -> (aux_input)/sqrt(N)
    - 'hardcode_radius': -> aux_input
    """
    # fail fast on bad configuration, before doing any counting work
    known_versions = (
        "full",
        "ignore_bracket",
        "only_alpha",
        "hardcode_radius_scaled",
        "hardcode_radius",
    )
    if formula_version not in known_versions:
        raise ValueError(f"Unknown formula version: {formula_version}. See docstring for options.")
    if formula_version in ("hardcode_radius_scaled", "hardcode_radius") and aux_input is None:
        raise ValueError(f"formula_version '{formula_version}' requires aux_input to be set.")

    N_offline_trajs = len(offline_trajs)
    if N_offline_trajs == 0:
        raise ValueError("Offline trajectories are empty")

    # calculate state-action visitation frequencies in ONE pass over the data:
    # dhat_t(s,a) = 1/N_offline_trajs * sum_{i=1}^N_offline_trajs 1_{s_t = s and a_t = a}
    visit_counts = np.zeros((episode_length, N_states, N_actions))
    for traj in offline_trajs:
        for t in range(episode_length):
            state, action = traj[t * 3], traj[t * 3 + 1]
            s_idx, a_idx = int(state), int(action)
            # only exact, in-range integer matches count (same as comparing against range())
            if (
                s_idx == state
                and a_idx == action
                and 0 <= s_idx < N_states
                and 0 <= a_idx < N_actions
            ):
                visit_counts[t, s_idx, a_idx] += 1
    dhat_t = (1 / N_offline_trajs) * visit_counts

    # calculate smallest nonzero visitation probability under optimal policy pi* and true model P*. here we approximate it with offline trajectories.
    visited = dhat_t[dhat_t > 0]
    if visited.size == 0:
        raise ValueError(
            "No in-range (state, action) pair was visited in the offline trajectories; "
            "gamma_min is undefined."
        )
    gamma_min = np.min(visited)

    if "full" in verbose or "radius-calc" in verbose:
        print(f"gamma_min (coverage): {gamma_min:.6f}")

    # Calculate radius components
    deltahalf = delta / 2
    sqrt_N = np.sqrt(N_offline_trajs)
    alpha = 2 * np.sqrt(N_states * np.log(N_actions / deltahalf))  # TODO: was sqrt(6)

    beta = 2 * np.sqrt(
        (N_states**2) * N_actions * np.log(N_offline_trajs * episode_length / deltahalf)
    )  # TODO: was sqrt(6)

    inner_bracket_term = 1 + (2 * alpha / (gamma_min * sqrt_N))
    outer_bracket_term = 1 + np.sqrt(episode_length * inner_bracket_term)

    alphaterm = alpha / sqrt_N
    betaterm = beta / sqrt_N

    if "full" in verbose or "radius-calc" in verbose or "offline-confset" in verbose:
        print(f"alpha/sqrt(N): {alphaterm:.3f}, beta/sqrt(N): {betaterm:.3f}")
        print(
            f"outer_bracket_term: {outer_bracket_term:.3f}, inner_bracket_term: {inner_bracket_term:.3f}"
        )

    if formula_version == "full":
        return alphaterm + betaterm * outer_bracket_term
    if formula_version == "ignore_bracket":
        return alphaterm + betaterm
    if formula_version == "only_alpha":
        return alphaterm
    if formula_version == "hardcode_radius_scaled":
        return aux_input / sqrt_N
    # formula_version == "hardcode_radius" (validated above)
    return aux_input


## confset generation: noise matrices ##
def generate_confidence_set_deterministic_via_noise_matrices(
    policy_BC,
    d_pi_BC,
    radius,
    N_states,
    N_actions,
    replace_mle_with_optimal_policy_in_offline_confset,  # from overrides/params
    solution_pi_true,
    method="knapsack-sampling",
    sample_func="proportional",
    N_conf=100,
    max_attempts=1000,
    verbose=[],
):
    """Build a confidence set of deterministic policies by perturbing the BC policy.

    Repeatedly draws a noise matrix, adds it to the BC policy matrix and keeps the
    result, until `N_conf` policies are collected or `max_attempts` draws were made.
    Every kept policy is checked to have an average local squared Hellinger distance
    to the BC policy of at most radius^2.

    Note 1: the distance used here is an APPROXIMATION of the squared Hellinger dist!

    Note 2: the first element of the set is policy_BC, unless
    `replace_mle_with_optimal_policy_in_offline_confset` is truthy, in which case it is
    `solution_pi_true`. Noise is always added to policy_BC.

    Note 3: a candidate is appended before the duplicate-matrix check, so if that
    check trips, the colliding policy is already in the returned list.

    Args:
        policy_BC: BC policy; must expose `.matrix` (a torch-backed matrix is converted to numpy)
        d_pi_BC: per-state weights used for the distance budget
        radius: confidence-set radius (compared as radius^2)
        N_states, N_actions: sizes forwarded to the noise generator
        replace_mle_with_optimal_policy_in_offline_confset: seed the set with solution_pi_true
        solution_pi_true: policy used as seed when the flag above is set
        method, sample_func: forwarded to the noise generator
        N_conf: target size of the confidence set
        max_attempts: upper bound on noise draws
        verbose: "full" / "offline-confset" enable progress prints

    Returns:
        confidence_set: list of policies (TabularPolicy)

    Raises:
        ValueError: if a generated policy exceeds the radius^2 budget.
    """
    # matrices exposing .detach() (torch parameters) are rebuilt on a plain numpy array
    if hasattr(policy_BC.matrix, "detach"):
        policy_BC = TabularPolicy(policy_BC.matrix.detach().numpy())

    seed_policy = (
        solution_pi_true if replace_mle_with_optimal_policy_in_offline_confset else policy_BC
    )
    confidence_set = [seed_policy]

    log_everything = "full" in verbose
    log_confset = log_everything or "offline-confset" in verbose

    seen_noise_hashes = set()
    seen_matrix_hashes = set()
    bc_matrix = policy_BC.matrix
    attempts = 0

    while len(confidence_set) < N_conf and attempts < max_attempts:
        attempts += 1
        noise_matrix, noise_hash = _generate_deterministic_noise_in_confidence_set(
            policy_BC, d_pi_BC, radius, N_states, N_actions, method, sample_func
        )

        # generator signals "nothing can be changed" with a None hash: give up
        if noise_hash is None:
            if log_confset:
                print("couldn't generate a valid policy (radius too small or N_actions=1)")
            break

        # same noise drawn before: spend the attempt and try again
        if noise_hash in seen_noise_hashes:
            if log_everything:
                print(f"{attempts}: duplicate noise hash")
            continue

        # new noise: create & add the perturbed policy
        seen_noise_hashes.add(noise_hash)
        candidate_matrix = bc_matrix + noise_matrix
        candidate = TabularPolicy(candidate_matrix)
        _sanity_check_policy_matrix(candidate)  # TODO: remove for speed
        confidence_set.append(candidate)

        # weird duplicate cases that shouldn't happen (ensure the hashed tuple contains only np.int!)
        matrix_hash = hash(candidate_matrix.tobytes())
        if matrix_hash in seen_matrix_hashes:
            if log_confset:
                print(f"{attempts}: issue, matrix colliding but noise not!")
            break
        seen_matrix_hashes.add(matrix_hash)

        # sanity check: the new policy must stay within radius^2 of policy_BC
        h2_to_bc = _calculate_squared_hellinger_distance_local_avg(policy_BC, candidate, d_pi_BC)
        if log_confset:
            print(
                f"adding policy\n{candidate.matrix}\nwith avg Hellinger dist ({h2_to_bc:.5f}) vs radius^2 {radius**2}"
            )
        if h2_to_bc > radius**2:
            if log_confset:
                print(f"avg Hellinger dist {h2_to_bc} > radius^2 {radius**2}")
            raise ValueError("avg Hellinger dist > radius^2")

    if log_confset:
        print(f"generated {len(confidence_set)} policies in {attempts} attempts")

    return confidence_set


def _generate_deterministic_noise_in_confidence_set(
    policy_BC,  # expecting N_states x N_actions matrix, each row one-hot for deterministic policy
    d_pi_BC,  # stationary distribution of pi_MLE
    radius,  # Hellinger distance radius of confidence set
    N_states,
    N_actions,
    method="knapsack-sampling",  # sampling method
    sample_func="proportional",  # how to sample states w.r.t. pi_MLE's stationary distribution of that state
):
    """Generates one noise matrix s.t. squared hellinger distance of pi_MLE + noise stays within `radius` of pi_MLE.
    Assumes deterministic distributions, finite state/actions
    Uses the average squared hellinger distance, so we have condition
    H2_avg(pi_MLE, pi) = sum_{s: pi(s) != pi_MLE(s)} d_pi_BC(s) <= radius^2
    Methods:
        knapsack-sampling: iteratively randomly selects states s to change,
            keeping track of the added d_pi_BC(s), until the 'budget' of R^2 is reached.
            choice of sampling:
                uniform (each state same prob),
                greedy (select state w/ lowest d_pi_BC(s)),
                proportional (select state w/ p ~ d_pi_BC(s); uniform fallback if the
                    eligible mass is <= 1e-9)
        (a standalone 'greedy' method is not implemented; only 'knapsack-sampling' is accepted)

    The returned hash and the returned matrix are always built from the same list of
    recorded changes, so a non-None hash always corresponds to a non-zero noise matrix.

    Returns:
        noise_matrix: N_states x N_actions matrix with -1 at the old action and +1 at the new
            action of every changed state. add this to pi_MLE to get a new policy
        noise_matrix_hash: hashable representation, for easy comparison
        Both are None if no state could be changed (radius too small or N_actions=1).

    Raises:
        ValueError: for an unsupported `method` or `sample_func`.
    """
    if method != "knapsack-sampling":
        raise ValueError(
            f"method {method} not supported. use 'knapsack-sampling' (greedy not yet implemented)"
        )

    original_actions = np.argmax(policy_BC.matrix, axis=1)
    budget = radius**2

    S_candidates = list(range(N_states))  # candidates for states to change action
    change_details = []  # (state, old action, new action) for every committed change
    cost = 0  # sum of d_pi_BC(s) over committed changes (must stay <= R^2)

    while S_candidates:
        remaining_budget = budget - cost
        S_eligible = [s for s in S_candidates if d_pi_BC[s] <= remaining_budget]
        if not S_eligible:
            break  # no more states can fit into budget

        # sample next state to change (S_eligible is non-empty here)
        if sample_func == "uniform":
            s_next = np.random.choice(S_eligible)
        elif sample_func == "greedy":
            s_next = min(S_eligible, key=lambda s: d_pi_BC[s])
        elif sample_func == "proportional":
            eligible_probs = d_pi_BC[S_eligible]
            prob_sum = np.sum(eligible_probs)
            if prob_sum > 1e-9:  # avoid division by zero if all eligible probs tiny
                s_next = np.random.choice(S_eligible, p=eligible_probs / prob_sum)
            else:  # (near-)zero mass: fall back to a uniform pick
                s_next = np.random.choice(S_eligible)
        else:
            raise ValueError(
                f"sample_func {sample_func} not supported. use 'uniform', 'greedy', or 'proportional'."
            )

        S_candidates.remove(s_next)

        # floating-point guard: eligibility was tested via subtraction, so the sum can
        # still round above the budget. stop WITHOUT recording the state, so that hash
        # and matrix stay consistent.
        new_cost = cost + d_pi_BC[s_next]
        if new_cost > budget:
            break

        # choose new action for selected state
        original_action = original_actions[s_next]
        possible_new_actions = [a for a in range(N_actions) if a != original_action]
        if not possible_new_actions:
            # this happens if N_actions=1, policy can't be changed; nothing is committed
            continue

        new_action = np.random.choice(possible_new_actions)
        change_details.append((s_next, original_action, new_action))
        cost = new_cost  # commit the cost only together with the change

    if not change_details:
        return None, None  # no valid policy could be generated

    noise_matrix = np.zeros((N_states, N_actions))
    for s, old_a, new_a in change_details:
        noise_matrix[s, old_a] = -1
        noise_matrix[s, new_a] = 1

    # create hashable representation of noise_matrix (order-independent)
    sorted_S_diff = tuple(sorted(s for s, _, _ in change_details))
    sorted_change_details = tuple(sorted(change_details))  # sort by state index
    noise_matrix_hash = hash((sorted_S_diff, sorted_change_details))
    return noise_matrix, noise_matrix_hash


def _sanity_check_policy_matrix(policy):
    """Validate that each row of a TabularPolicy's matrix sums to 1 (within np.allclose tolerance).

    Returns True on success; raises ValueError if `policy` is not a TabularPolicy
    or if any row sum deviates from 1.
    """
    if not isinstance(policy, TabularPolicy):
        raise ValueError("Input must be a TabularPolicy")

    weights = policy.matrix
    if not isinstance(weights, np.ndarray):
        # anything that is not a numpy array is expected to offer .detach().numpy()
        weights = weights.detach().numpy()

    row_totals = np.sum(weights, axis=1)
    if np.allclose(row_totals, 1):
        return True
    raise ValueError("Policy matrix rows must sum to 1")


## confset generation: rejection sampling ##
def generate_confidence_set_deterministic_via_rejection_sampling(
    policy_BC,
    d_pi_BC,
    env_emp,
    episode_length,
    offlineradius,
    solution_pi_true,
    which_hellinger_calc="exact",
    N_search_space_samples=None,
    verbose=[],
):
    """Generate an offline confidence set of deterministic policies by rejection sampling.

    Two modes, selected by `N_search_space_samples`:

    - None: enumerate ALL N_actions**N_states deterministic policies (intractable for
      anything but tiny MDPs) and keep those whose squared Hellinger distance to
      policy_BC is <= offlineradius^2. The distance is the local-average approximation
      for which_hellinger_calc="approx" and the Bhattacharyya-based value for "exact".
    - an int: draw that many random deterministic policies and keep those within
      offlineradius^2 using the Bhattacharyya-based distance (`which_hellinger_calc` is
      not consulted in this mode). policy_BC is then appended unconditionally, and
      solution_pi_true is appended if it lies within the radius. Requires offlineradius < 1.

    A warning is printed (regardless of `verbose`) if solution_pi_true lies outside the radius.

    Returns:
        list of policies forming the confidence set.

    Raises:
        ValueError: for an unsupported `which_hellinger_calc` (enumeration mode), or when
            sampling is requested with offlineradius >= 1.
    """
    start_time = time.time()

    if hasattr(policy_BC.matrix, "detach"):
        policy_BC = TabularPolicy(policy_BC.matrix.detach().numpy())

    confidence_set = []
    N_states, N_actions = env_emp.N_states, env_emp.N_actions

    def exact_h2_to_bc(other_policy):
        # both trajectory distributions use the same (empirical) transition model
        return _calculate_squared_hellinger_distance_bhattacharyya(
            policy_BC,
            env_emp.transitions,
            other_policy,
            env_emp.transitions,
            env_emp.initial_state_distribution,
            episode_length,
        )

    true_dist_opt_MLE = exact_h2_to_bc(solution_pi_true)
    radius_sq = offlineradius**2
    if true_dist_opt_MLE > radius_sq:
        print(
            f"WARNING: (confset gen.) optimal policy not in confset! true H2-dist: {true_dist_opt_MLE} > rad^2: {radius_sq}"
        )

    log_progress = "offline-confset" in verbose

    # sampling mode only makes sense if the radius actually filters something
    if N_search_space_samples is not None and not offlineradius < 1:
        raise ValueError(
            "Trying to rejection sample from large sample (suggesting big MDP) but radius >= 1. Set radius < 1 to filter large sample to something manageable."
        )

    if N_search_space_samples is None:
        # exhaustive mode: one action per state, N_actions**N_states combinations
        total_policies = N_actions**N_states
        if log_progress:
            print(f"Confset: generating all {total_policies} possible deterministic policies...")

        one_hot_rows = np.eye(N_actions)  # row a is the one-hot encoding of action a
        for actions in itertools.product(range(N_actions), repeat=N_states):
            candidate_policy = TabularPolicy(one_hot_rows[list(actions)])

            if which_hellinger_calc == "approx":
                candidate_dist = _calculate_squared_hellinger_distance_local_avg(
                    policy_BC, candidate_policy, d_pi_BC
                )
            elif which_hellinger_calc == "exact":
                candidate_dist = exact_h2_to_bc(candidate_policy)
            else:
                raise ValueError(f"which_hellinger_calc {which_hellinger_calc} not supported")

            if candidate_dist <= radius_sq:
                confidence_set.append(candidate_policy)
    else:
        if log_progress:
            print(
                f"Confset: generating {N_search_space_samples} random policies, then rejection sampling them."
            )
        # draw random deterministic policies, then filter them with the exact distance
        sampled_policies = generate_random_tabular_policies_vectorized(
            N_states,
            N_actions,
            N_policies=N_search_space_samples,
            make_deterministic=True,
        )
        for sampled in sampled_policies:
            if exact_h2_to_bc(sampled) <= radius_sq:
                confidence_set.append(sampled)

        # add the MLE policy, because it's in the confset by definition.
        confidence_set.append(policy_BC)
        # add the optimal policy, if its bhatta dist is <= radius^2.
        if true_dist_opt_MLE <= radius_sq:
            confidence_set.append(solution_pi_true)

    if log_progress:
        print(
            f"Confset: generation success, {len(confidence_set)} deterministic policies in {time.time() - start_time:.2f} seconds"
        )

    return confidence_set


## misc helpers for confset generation ##
def calc_d_pi_BC(offline_trajs, N_states):
    """calculates the stationary distribution of policy_BC via visitation count of offline trajs.

    Trajectories are of form [s0, a0, r0, s1, a1, r1, ..., sH, aH, rH]; every third
    entry starting at index 0 is counted as a state visit.

    The counts are normalized by the number of visits actually counted, so the result
    sums to 1 even if trajectories differ in length or their length is not a multiple of 3.

    Args:
        offline_trajs: non-empty list of trajectories
        N_states: number of states (length of the returned vector)

    Returns:
        d_pi_BC: numpy vector of length N_states with visitation frequencies

    Raises:
        ValueError: if there are no trajectories or no state visits at all.
    """
    if not offline_trajs:
        raise ValueError("Offline trajectories are empty")

    visit_counts = np.zeros(N_states)
    for traj in offline_trajs:
        for state in traj[::3]:
            visit_counts[state] += 1

    total_visits = visit_counts.sum()
    if total_visits == 0:
        raise ValueError("Offline trajectories contain no state visits")

    return visit_counts / total_visits  # normalize to get distribution


###### Offline trajectory generation ######
def generate_offline_trajectories(env, policy, n_samples, verbose=[]):
    """
    Generate offline trajectories by rolling out a policy in an environment.

    Args:
        env: The environment to roll out in
        policy: The policy to roll out
        n_samples: Number of trajectories to generate
        verbose: "full" / "offline-trajs" print the unique trajectories and their count

    Returns:
        offline_trajs: List of trajectories as returned by rollout_policy_in_env,
            each traj [s0,a0,r0,s1,a1,r1,...]
        unique_trajs: Set of unique trajectories (as tuples, entries converted to int).
            NOTE: every non-int entry is passed through int(), so non-integer values
            (e.g. fractional rewards) are truncated before comparison.
    """
    offline_trajs = [rollout_policy_in_env(env, policy) for _ in range(n_samples)]

    def traj_to_int(traj):
        return [int(x) if not isinstance(x, int) else x for x in traj]

    # get only unique trajectories (compared after int conversion)
    unique_trajs = {tuple(traj_to_int(traj)) for traj in offline_trajs}

    if "full" in verbose or "offline-trajs" in verbose:
        # report the size of the set that is actually returned, so count and listing agree
        print(f"Number of unique trajectories: {len(unique_trajs)}")
        print(f"Unique trajectories: {unique_trajs}")

    return offline_trajs, unique_trajs


###### Hellinger distance calculations ######
## exact H2-distance via Bhattacharyya coefficient ##
def _calculate_squared_hellinger_distance_bhattacharyya(
    policy1,
    t_matrix1,
    policy2,
    t_matrix2,
    d0,  # initial dist, 1-d vec of N_states
    H,  # episode length
    verbose=[],
) -> float:
    """
    Squared Hellinger distance between two trajectory distributions, computed as
    1 - Bhattacharyya coefficient (clipped at 0).

    ASSUMES: finite MDP. deterministic, stationary policies (each policy's action in a
    state is taken to be the argmax of its matrix row).

    Each distribution is defined by a (policy, transition_model) pair. Instead of
    enumerating trajectories, the coefficient is propagated state-wise:

        B_0 = d0
        M[s', s] = sqrt(T1(s'|s,a) * T2(s'|s,a)) if pi1(s) == pi2(s) == a, else 0
        B_H = (M ** H) @ B_0
        coefficient = sum_s B_H(s)

    M does not depend on the timestep, hence a single matrix power is used.

    Args:
        policy1, policy2: policy objects with a .matrix attribute of shape (N_states, N_actions).
        t_matrix1, t_matrix2: transition matrices of shape (N_actions, N_states, N_states),
            indexed as [action, from_state, to_state].
        d0: initial state distribution, shape (N_states,).
        H: episode length (int).
        verbose: "full" prints the intermediate quantities.

    Returns:
        Squared Hellinger distance (never below 0).

    Raises:
        ValueError: if the shapes of policies, transition matrices or d0 do not fit together.
    """
    pi_a = policy1.matrix
    pi_b = policy2.matrix

    # validate inputs
    if (
        pi_a.shape != pi_b.shape
        or t_matrix1.shape != t_matrix2.shape
        or pi_a.shape[0] != t_matrix1.shape[1]
        or pi_a.shape[1] != t_matrix1.shape[0]
    ):
        raise ValueError("Dimension mismatch between policies and transition models.")
    if len(d0.shape) != 1 or d0.shape[0] != pi_a.shape[0]:
        raise ValueError("Dimension mismatch for initial state distribution d0.")

    debug = "full" in verbose
    n_states = pi_a.shape[0]

    # action chosen by each policy in every state
    chosen_a = np.argmax(pi_a, axis=1)
    chosen_b = np.argmax(pi_b, axis=1)

    # B_0(s) = d0(s); B_t(s) accumulates sqrt-probabilities of length-t trajectories ending in s
    start_mass = np.copy(d0)
    if debug:
        print(f"B_0: {start_mass}")

    # column s of the step matrix describes transitions out of state s;
    # it stays all-zero wherever the two policies disagree
    step = np.zeros((n_states, n_states))
    for src, (act_a, act_b) in enumerate(zip(chosen_a, chosen_b)):
        if act_a != act_b:
            continue
        step[:, src] = np.sqrt(t_matrix1[act_a, src, :] * t_matrix2[act_a, src, :])

    step_pow = np.linalg.matrix_power(step, H)
    end_mass = step_pow @ start_mass

    if debug:
        print(f"matrix M:\n{step}")
        print(f"M^H:\n{step_pow}")
        print(f"B = M^H @ B_0:\n{end_mass}")

    coefficient = np.sum(end_mass)  # sum over all states
    clipped_distance = max(0.0, 1.0 - coefficient)  # clip to handle FP inaccuracies
    if debug:
        print(f"Bhattacharyya coefficient: {coefficient}")
        print(f"returning max(0, 1-BhatC) = {clipped_distance}")

    return clipped_distance


## approximate H2-distance via local average ##
def _calculate_squared_hellinger_distance_local_avg(policy_BC, policy, d_pi_BC):
    """Approximate squared Hellinger distance: total d_pi_BC weight of states where the policies differ.

    Formula: H2_avg(pi_MLE, pi) = sum_{s: pi(s) != pi_MLE(s)} d_pi_BC(s)
    Both policies expose a `.matrix` indexed by state ([N_states, N_actions]); two states
    "differ" when their matrix rows are not element-wise equal. d_pi_BC is a vector of
    length N_states. Returns 0 if no row differs.
    """
    reference_rows = policy_BC.matrix  # .detach().numpy()
    candidate_rows = policy.matrix  # .detach().numpy()

    mismatch_mass = 0
    for state, weight in enumerate(d_pi_BC):
        if np.array_equal(reference_rows[state], candidate_rows[state]):
            continue  # same action distribution in this state: contributes nothing
        mismatch_mass += weight
    return mismatch_mass
