import time
import numpy as np
from scipy.spatial.distance import mahalanobis
from scipy.optimize import minimize
import itertools
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from .misc_helpers import rollout_policy_in_env

# float32 machine epsilon (~1.19e-7). Used as a small positive offset/floor so that np.log
# never sees an exact zero (transition probabilities, tiny confidence levels delta).
EPS = np.finfo(np.float32).eps


# Online-specific computations
class PrecomputedPhiBhats:
    """Plain record of the per-policy quantities built by ``precompute_phi_Bhats``.

    The constructor arguments are stored unchanged as attributes:
      - phi: the policy embedding,
      - bhat_online_confset: bonus used in the online confidence-set condition,
      - bhat_uncertainty_score: bonus used in the uncertainty score.
    """

    def __init__(self, phi, bhat_online_confset, bhat_uncertainty_score):
        self.phi, self.bhat_online_confset, self.bhat_uncertainty_score = (
            phi,
            bhat_online_confset,
            bhat_uncertainty_score,
        )


def precompute_phi_Bhats(
    N_states,
    N_actions,
    t,
    phi_func,
    confset_offline,
    env,
    eta,
    delta_online,
    episode_length,
    N_ts,
    xi_formula,
    n_samples=100,
):
    """Precompute the embedding and both empirical bonuses for every offline-confset policy.

    For each policy, ``n_samples`` trajectories are rolled out once in ``env`` and reused for:
      - ``phi``: mean of ``phi_func`` over the sampled trajectories (policy embedding),
      - ``bhat_online_confset``: Bhat(policy, eta, delta_online / (2 * N_actions**N_states)),
        floored at 1e-10, used by the online confidence-set condition,
      - ``bhat_uncertainty_score``: Bhat(policy, eta, delta_online), used by the uncertainty
        score.

    Returns:
        dict mapping ``hash(policy.matrix.tobytes())`` to a ``PrecomputedPhiBhats``;
        access e.g. ``policy_values[policy_hash].phi``.
    """
    policy_values = {}
    try:
        # int() keeps the power exact even if the sizes arrive as fixed-width numpy integers
        # (which would silently wrap around).
        delta_online_confset = max(
            delta_online / (2 * (int(N_actions) ** int(N_states))), 1e-10
        )
    except OverflowError:
        # N_actions**N_states is beyond the float range (large state spaces); the quotient
        # would be far below the 1e-10 floor anyway.
        delta_online_confset = 1e-10

    for policy in confset_offline:
        policy_hash = hash(policy.matrix.tobytes())  # assumes policy.matrix is an np.ndarray

        # One set of rollouts is shared by the embedding and both Bhat estimates.
        sampled_trajs = [rollout_policy_in_env(env, policy) for _ in range(n_samples)]

        values = PrecomputedPhiBhats(
            # phi(policy)
            _calc_policy_embedding(
                phi_func, policy, env, sampled_trajs, n_trajectories=n_samples
            ),
            # Bhat(policy, eta, delta/(2A^S)) for the online confset condition
            _calc_Bhat(
                t,
                policy,
                env,
                eta,
                delta_online_confset,
                episode_length,
                N_ts,
                xi_formula,
                sampled_trajs,
                n_trajectories=None,
            ),
            # Bhat(policy, eta, delta) for the uncertainty score
            _calc_Bhat(
                t,
                policy,
                env,
                eta,
                delta_online,
                episode_length,
                N_ts,
                xi_formula,
                sampled_trajs,
                n_trajectories=None,
            ),
        )
        policy_values[policy_hash] = values
    return policy_values


def _calc_policy_embedding(phi_func, policy, env, provided_trajs=None, n_trajectories=1000):
    """Monte-Carlo estimate of a policy embedding.

    The embedding is E_{traj ~ policy in env}[phi_func(traj)].

    If ``provided_trajs`` is given, the estimate is the mean of ``phi_func`` over those
    trajectories. Otherwise ``n_trajectories`` fresh rollouts of ``policy`` in ``env`` are
    generated and averaged.

    Raises:
        ValueError: if ``provided_trajs`` is empty, or if no trajectories are provided and
            ``n_trajectories`` < 1 (either case would divide by zero / produce NaN).
    """
    if provided_trajs is not None:
        if len(provided_trajs) == 0:
            raise ValueError("provided_trajs is empty; cannot estimate the policy embedding")
        return np.mean([phi_func(traj) for traj in provided_trajs], axis=0)

    if n_trajectories < 1:
        raise ValueError(f"n_trajectories must be >= 1, got {n_trajectories}")

    # The first rollout fixes the embedding shape; phi_func is evaluated once per trajectory.
    first_phi = phi_func(rollout_policy_in_env(env, policy))
    embedding_sum = np.zeros(np.shape(first_phi))
    embedding_sum += first_phi
    for _ in range(n_trajectories - 1):
        embedding_sum += phi_func(rollout_policy_in_env(env, policy))
    return embedding_sum / n_trajectories


def _calc_Bhat(
    t,  # outer loop idx: 0, ..., N_iterations-1
    policy,
    env,
    eta,
    delta,
    episode_length,
    N_ts,
    xi_formula,
    provided_trajs=None,
    n_trajectories=100,
):
    """Monte-Carlo estimate of the empirical bonus Bhat for ``policy``.

    Bhat = E_traj[ sum_h xi(t, s_h, a_h, eta, delta, episode_length, N_states, N_actions) ]

    The expectation is over trajectories [s_0, a_0, r_0, s_1, a_1, r_1, ...] of ``policy``
    in ``env``. Callers pass a learned-model env when the bonus should use the model at time t.
    Each trajectory contributes ``_calc_xi`` summed over all of its (state, action) steps
    except the last one.

    Args:
        t: outer loop index; passed to ``_calc_xi`` to select the online counts ``N_ts[t]``.
        provided_trajs: if not None, these trajectories are used for the estimate.
        n_trajectories: otherwise, this many fresh rollouts of ``policy`` in ``env``.

    Raises:
        ValueError: if there are no trajectories to average over.
    """
    N_states, N_actions = env.N_states, env.N_actions

    def _trajectory_xi_sum(traj):
        # Stride 3 over the flat [s, a, r, ...] layout; the stop at len - 3 skips the final step.
        xi_sum = 0
        for h in range(0, len(traj) - 3, 3):
            xi_sum += _calc_xi(
                t,
                traj[h],  # state at time h
                traj[h + 1],  # action at time h
                eta,
                delta,
                episode_length,
                N_states,
                N_actions,
                N_ts,
                xi_formula,
            )
        return xi_sum

    if provided_trajs is not None:
        trajs = provided_trajs
    else:
        if n_trajectories is None or n_trajectories < 1:
            raise ValueError(f"n_trajectories must be >= 1, got {n_trajectories}")
        # Lazily generated, so each rollout is processed before the next one is sampled.
        trajs = (rollout_policy_in_env(env, policy) for _ in range(n_trajectories))

    total_xi_sum = 0
    n_seen = 0
    for traj in trajs:
        total_xi_sum += _trajectory_xi_sum(traj)
        n_seen += 1
    if n_seen == 0:
        raise ValueError("no trajectories provided; cannot estimate Bhat")
    return total_xi_sum / n_seen


def _calc_xi(t, s, a, eta, delta, episode_length, N_states, N_actions, N_ts, xi_formula):
    """State-action uncertainty term xi_{t,s,a}(eta, delta) that feeds the empirical bonus Bhat.

    Let n = N_ts[0][s, a] (+ N_ts[t][s, a] if t > 0) be the total visitation count.
    ``xi_formula`` selects the formula:
      'full':          2*eta if n <= 1, else min(2*eta, 4*eta*sqrt(U / n)),
                       U = episode_length*log(N_states*N_actions) + log(6*log(n)) - log(delta)
      'smaller_start': log(N_states*N_actions) if n <= 1,
                       else min(log(N_states*N_actions), sqrt(log(N_states*N_actions) / n))

    Notes:
      - 'smaller_start' is meant for small offline datasets: same scaling in n as 'full',
        but it starts off smaller.
      - N_ts is a list of count matrices N(s, a). N_ts[0] holds the offline counts and
        N_ts[t] the online counts up to time t.
      - For n <= 1, the term log(log(n)) is -inf or undefined, so the 'full' bound is vacuous
        and xi is capped at 2*eta. This is returned explicitly instead of via NaN.

    Raises:
        ValueError: for an unknown ``xi_formula``.
    """
    if t == 0:
        visitation_sum = N_ts[0][s, a]
    else:
        visitation_sum = N_ts[0][s, a] + N_ts[t][s, a]

    if xi_formula == "smaller_start":
        log_sa = np.log(N_states * N_actions)
        if visitation_sum <= 1:  # counts are non-negative, so this means n in {0, 1}
            return log_sa
        return min(log_sa, np.sqrt(log_sa / visitation_sum))

    if xi_formula == "full":
        if visitation_sum <= 1:
            return 2 * eta
        safe_delta = max(delta, EPS * 10)  # delta can be VERY small; floor it before the log
        # log(6*log(n)) is kept separate from -log(delta) for numerical stability
        U = (
            episode_length * np.log(N_states * N_actions)
            + np.log(6 * np.log(visitation_sum))
            - np.log(safe_delta)
        )
        return min(2 * eta, 4 * eta * np.sqrt(U / visitation_sum))

    raise ValueError(f"Invalid xi_formula: {xi_formula}. options: 'full', 'smaller_start'")


def calc_gamma_t(
    t,  # outer loop idx: 0, ..., N_iterations-1
    kappa,
    lambda_param,
    B,
    W,
    N_iterations,
    d,
    delta_param,
    eta,
    episode_length,
    N_ts,
    xi_formula,
    policy_pairs,
    env_learned,
    verbose=(),
):
    """Compute gamma_t.

    gamma_t = sqrt(2)*(4*kappa*beta_t(delta) + alpha_{d,N_iterations}(delta)) + 1/t + 2*sqrt(U)
    where t is 1-based in the formula (t + 1 here), and
        U = sum_i Bhat(t, pi1_i, env, eta, delta'', episode_length)**2
                + Bhat(t, pi2_i, env, eta, delta'', episode_length)**2
    with (pi1_i, pi2_i) = policy_pairs[i][0], policy_pairs[i][1].
    Every Bhat is evaluated at the current iteration t, not at the pair's index.
    If ``policy_pairs`` is None, the U term is 0.

    Note: env_learned must use the learned transitions at t (not the true ones).
    Note: Bhats are estimated from fresh rollouts rather than precomputed, because the number
    of policy pairs can be much smaller than the confidence-set size.
    """
    N_states, N_actions = env_learned.N_states, env_learned.N_actions
    t_formula = t + 1  # for formulas, t=1, ..., N_iterations.
    epsilon = 1 / (t_formula**2 * kappa * lambda_param + 4 * B**2 * t_formula**3)
    if epsilon < 1e-10 and ("full" in verbose or "warnings" in verbose):
        print("gamma_calc: epsilon is too small, hitting safety floor")
    delta_prime = delta_param / ((1 + 4 * W) / max(epsilon, 1e-10)) ** d  # floor prevents /0
    alpha_d_T = (
        20 * B * W * np.sqrt(d * np.log((N_iterations * (1 + 2 * N_iterations)) / delta_param))
    )
    beta_t = np.sqrt(lambda_param) * W + np.sqrt(
        np.log(1 / delta_param)
        + 2 * d * np.log(1 + (t_formula * B) / (kappa * lambda_param * delta_param))
    )

    if policy_pairs is not None:
        # delta'' is identical for every pair, so compute it once.
        try:
            # int() keeps the power exact even for fixed-width numpy integer inputs.
            delta_primeprime = max(
                delta_prime / (8 * int(t_formula) ** 3 * int(N_actions) ** int(N_states)),
                1e-10,
            )
        except OverflowError:
            # N_actions**N_states exceeds the float range; the quotient is below the floor.
            delta_primeprime = 1e-10

        Bhat1s = []
        Bhat2s = []
        for pair in policy_pairs:
            pi1, pi2 = pair[0], pair[1]
            Bhat1s.append(
                _calc_Bhat(
                    t,
                    pi1,
                    env_learned,
                    eta,
                    delta_primeprime,
                    episode_length,
                    N_ts,
                    xi_formula,
                    provided_trajs=None,
                    n_trajectories=100,
                )
                ** 2
            )
            Bhat2s.append(
                _calc_Bhat(
                    t,
                    pi2,
                    env_learned,
                    eta,
                    delta_primeprime,
                    episode_length,
                    N_ts,
                    xi_formula,
                    provided_trajs=None,
                    n_trajectories=100,
                )
                ** 2
            )
        B_term = 2 * np.sqrt(np.sum(Bhat1s) + np.sum(Bhat2s))
    else:  # iteration t=0
        B_term = 0

    alpha_beta_term = np.sqrt(2) * (4 * kappa * beta_t + alpha_d_T)
    gamma_t = alpha_beta_term + 1 / t_formula + B_term
    if "full" in verbose:
        print(f"calc gamma_t:")
        print(f"  epsilon: {epsilon:.5f}")
        print(f"  alpha_d_T: {alpha_d_T:.3f}")
        print(f"  beta_t: {beta_t:.3f}, 4*kappa*beta_t: {4 * kappa * beta_t:.3f}")
        print(f"  B_term: {B_term:.3f}")
        print(
            f"  gamma_t: {gamma_t:.3f} = alpha&beta term {alpha_beta_term:.3f} + 1/t {1 / t_formula:.3f} + B_term {B_term:.3f}"
        )
    return gamma_t


def calc_empirical_counts(traj_list, N_states, N_actions):
    """
    Count state-action visits across trajectories: N_t(s, a) = occurrences of (s, a) so far.

    Every trajectory is a flat list [s0, a0, r0, s1, a1, r1, ...]; each (s, a, r) triple adds
    one to N[s, a]. Returns an int32 array of shape (N_states, N_actions).
    """
    counts = np.zeros((N_states, N_actions), dtype=np.int32)
    for trajectory in traj_list:
        n_entries = len(trajectory)
        offset = 0
        while offset < n_entries:
            state, action, _reward = (
                trajectory[offset],
                trajectory[offset + 1],
                trajectory[offset + 2],
            )
            counts[state, action] += 1
            offset += 3
    return counts


def calc_online_confset_t(
    confset_offline,
    precomputed_phi_bhats,
    w_proj_t,
    gamma_t,
    V_t_inv,
    online_confset_recalc_phi,  # default should be False
    online_confset_bonus_multiplier,  # default should be 1
    phi_func=None,
    env=None,
    verbose=(),
):
    """Compute Pi_t, the online confidence set at time t.

    Pi_t is the subset of ``confset_offline`` whose policies pi satisfy, for every pi' in
    ``confset_offline``:
        <phi(pi) - phi(pi'), w_proj_t> + gamma_t * ||phi(pi) - phi(pi')||_{V_t_inv}
            + online_confset_bonus_multiplier * (bhat_online_confset(pi) + bhat_online_confset(pi')) >= 0

    Embeddings and bonuses come from ``precomputed_phi_bhats``, keyed by
    ``hash(policy.matrix.tobytes())``. If ``online_confset_recalc_phi`` is True, each policy's
    embedding is instead re-estimated once from 100 fresh rollouts in ``env``.

    Raises:
        ValueError: if the resulting Pi_t is empty (including an empty ``confset_offline``).
    """
    calc_Pi_t_time = time.time()
    if len(confset_offline) == 0:
        raise ValueError("Pi_t is empty")  # nothing to filter; avoids dividing by 0 below

    # Gather each policy's embedding and bonus once instead of once per (pi, pi') pair.
    phis = []
    bhats = []
    for policy in confset_offline:
        entry = precomputed_phi_bhats[hash(policy.matrix.tobytes())]
        if online_confset_recalc_phi:
            phis.append(
                _calc_policy_embedding(
                    phi_func, policy, env, provided_trajs=None, n_trajectories=100
                )
            )
        else:
            phis.append(entry.phi)
        bhats.append(entry.bhat_online_confset)

    Pi_t = []
    min_estimate_win_probability_all_pi = np.inf
    avg_estimate_winprob_size = 0
    avg_B_term_size = 0
    avg_norm_size = 0
    avg_sum_term = 0
    tot_num_calcs = 0

    for pi, pi_phi, pi_bhat_online_confset in zip(confset_offline, phis, bhats):
        all_conditions_satisfied = True
        for pi2_phi, pi2_bhat_online_confset in zip(phis, bhats):
            estimate_winprob = np.dot(pi_phi - pi2_phi, w_proj_t)
            norm_term = mahalanobis(pi_phi, pi2_phi, V_t_inv)
            uncertainty_w_estimate = gamma_t * norm_term
            B_term_size = pi_bhat_online_confset + pi2_bhat_online_confset
            sum_term = (
                estimate_winprob
                + uncertainty_w_estimate
                + online_confset_bonus_multiplier * B_term_size
            )

            # Diagnostics only cover the comparisons actually evaluated (early break below).
            min_estimate_win_probability_all_pi = min(
                min_estimate_win_probability_all_pi, estimate_winprob
            )
            avg_estimate_winprob_size += estimate_winprob
            avg_norm_size += norm_term
            avg_B_term_size += B_term_size
            avg_sum_term += sum_term
            tot_num_calcs += 1
            if sum_term < 0:
                all_conditions_satisfied = False
                break

        if all_conditions_satisfied:
            Pi_t.append(pi)

    avg_estimate_winprob_size /= tot_num_calcs
    avg_norm_size /= tot_num_calcs
    avg_B_term_size /= tot_num_calcs
    avg_sum_term /= tot_num_calcs
    if "full" in verbose or "online-confset" in verbose:
        print(f" -- confset --")
        print(
            f"  formula: <estimate_winprob> ({avg_estimate_winprob_size:.3f}) + gamma_t ({gamma_t:.2f}) * norm-term ({avg_norm_size:.2f}) + bonus_multiplier ({online_confset_bonus_multiplier:.3f}) * B_term ({avg_B_term_size:.3f})"
        )
        print(f"  policy deleted if above <0")
        print(f"  avg estimate winprob size: {avg_estimate_winprob_size:.3f}")
        print(f"  avg norm size: {avg_norm_size:.3f}")
        print(f"  avg B term size: {avg_B_term_size:.3f}")
        print(f"  avg sum term: {avg_sum_term:.3f}")
        print(f"  min win prob: {min_estimate_win_probability_all_pi:.3f} (all pi)")

    # The first offline-confset entry is treated as the BC/MLE (or optimal) policy.
    policy_BC_or_opt = confset_offline[0]
    bc_or_opt_in_pi_t = policy_BC_or_opt in Pi_t
    if not bc_or_opt_in_pi_t and ("full" in verbose or "warnings" in verbose):
        print("WARNING: BC/optimal policy is not in Pi_t!")

    if "full" in verbose or "online-confset" in verbose:
        print(f"calc Pi_t: success, {len(Pi_t)} policies, in {time.time() - calc_Pi_t_time:.2f}s")
        print(f"  BC policy in Pi_t: {bc_or_opt_in_pi_t}")

    if len(Pi_t) > 0:
        return Pi_t
    raise ValueError("Pi_t is empty")  # at least pi_MLE should be in Pi_t.


def get_policy_pair_that_maximizes_uncertainty(
    Pi_t, precomputed_phi_bhats, gamma_t, V_t_inv, verbose=()
):
    """
    Return the unordered pair of policies in Pi_t with maximal uncertainty, and that uncertainty.

    For a pair (pi1, pi2) the uncertainty is
        gamma_t * ||phi(pi1) - phi(pi2)||_{V_t_inv}
            + 2*bhat_uncertainty_score(pi1) + 2*bhat_uncertainty_score(pi2)
    using entries of ``precomputed_phi_bhats`` keyed by ``hash(policy.matrix.tobytes())``.
    Ties keep the first pair in ``itertools.combinations`` order.

    Returns:
        ((pi1, pi2), max_uncertainty)

    Raises:
        ValueError: if Pi_t has fewer than 2 policies, or every uncertainty is NaN.
    """
    calc_uncertainty_time = time.time()
    if len(Pi_t) < 2:
        raise ValueError(f"Pi_t needs at least 2 policies to form a pair, got {len(Pi_t)}")

    # Look up each policy's precomputed entry once rather than once per pair.
    entries = [precomputed_phi_bhats[hash(pi.matrix.tobytes())] for pi in Pi_t]

    max_uncertainty = -float("inf")
    best_pair = None
    for (idx1, pi1), (idx2, pi2) in itertools.combinations(enumerate(Pi_t), 2):
        entry1, entry2 = entries[idx1], entries[idx2]
        uncertainty = (
            gamma_t * mahalanobis(entry1.phi, entry2.phi, V_t_inv)
            + 2 * entry1.bhat_uncertainty_score
            + 2 * entry2.bhat_uncertainty_score
        )
        if uncertainty > max_uncertainty:
            max_uncertainty = uncertainty
            best_pair = (pi1, pi2)

    if best_pair is None:
        raise ValueError("could not determine a maximal-uncertainty pair (all values NaN)")

    if "full" in verbose:
        print(
            f"max. uncertainty pair calc'd: uncertainty {max_uncertainty:.3f}, in {time.time() - calc_uncertainty_time:.2f}s"
        )

    return best_pair, max_uncertainty


def generate_policy_pair_rollouts(
    rollout_env1,
    policy_1,
    rollout_env2,
    policy_2,
    num_rollouts=10,
):
    """Sample ``num_rollouts`` trajectory pairs.

    Each pair is one rollout of policy_1 in rollout_env1 followed by one rollout of policy_2
    in rollout_env2.

    Returns:
        list [[traj1, traj2]_1, ..., [traj1, traj2]_num_rollouts], where each traj is
        [s1, a1, r1, s2, ..., sN, aN, rN].
    """
    # A list display evaluates left to right, so the policy_1 rollout precedes policy_2's.
    return [
        [
            rollout_policy_in_env(rollout_env1, policy_1),
            rollout_policy_in_env(rollout_env2, policy_2),
        ]
        for _ in range(num_rollouts)
    ]


def annotate_buffer(traj_pairs_buffer, oracle_env, N_rollouts):
    """Annotate trajectory pairs with preference labels from an oracle env (usually the true env).

    Only the first N_rollouts pairs of the buffer are annotated. The transitions, rewards and
    discount factor of ``oracle_env`` are used.

    IN: buffer (list) of trajectory pairs [[traj1, traj2]_1, ..., [traj1, traj2]_bufferlength]
    OUT: [[traj1, traj2, y_T, y_R]_1, ..., [., ., ., .]_N_rollouts] with
        y_T: 1 if log-likelihood(traj1) >= log-likelihood(traj2), else 0
        y_R: 1 if disc.reward(traj1) > disc.reward(traj2), 0 if smaller,
             and a uniformly random 0/1 on ties.
    If bufferlength < N_rollouts, annotated pairs are duplicated (sampled uniformly with
    replacement) until N_rollouts entries exist.

    Raises:
        ValueError: if the buffer is empty but N_rollouts > 0 (nothing to duplicate).
    """
    list_trajs = []
    for i in range(min(len(traj_pairs_buffer), N_rollouts)):
        traj_1, traj_2 = traj_pairs_buffer[i]
        P_1 = _compute_log_likelihood_traj(traj_1, oracle_env.transitions)
        R_1 = compute_rewards_traj(traj_1, oracle_env.rewards, oracle_env.discount_factor)
        P_2 = _compute_log_likelihood_traj(traj_2, oracle_env.transitions)
        R_2 = compute_rewards_traj(traj_2, oracle_env.rewards, oracle_env.discount_factor)

        y_t = int(P_1 >= P_2)
        if R_1 == R_2:
            # randint's upper bound is exclusive: (0, 2) draws 0 or 1 (the old (0, 1) was always 0)
            y_r = int(np.random.randint(0, 2))
        else:
            y_r = int(R_1 > R_2)
        list_trajs.append([traj_1, traj_2, y_t, y_r])

    if len(traj_pairs_buffer) < N_rollouts:
        print("Warning: not enough samples in buffer")
        if not list_trajs:
            raise ValueError("traj_pairs_buffer is empty; cannot pad to N_rollouts")
        # Pick by index: np.random.choice cannot sample from a list of ragged records.
        while len(list_trajs) < N_rollouts:
            list_trajs.append(list_trajs[np.random.randint(len(list_trajs))])
    return list_trajs


def _compute_log_likelihood_traj(traj, transition_matrix):
    """Log-likelihood of a trajectory's transitions under ``transition_matrix``.

    IN:
    - traj: flat [s0, a0, r0, ..., sH, aH, rH] with integer states/actions (rewards ignored)
    - transition_matrix: np.ndarray indexed [action, state, next_state]

    OUT:
    log(P(traj | T)) = SUM_{a,s,s'} log(T[a, s, s'] + EPS) * count(a, s, s')

    The counts are collected in an array with the same [action, state, next_state] layout, so
    flattening both operands in C order lines up matching entries for the final dot product.
    """
    n_actions, n_states, _ = transition_matrix.shape

    transition_counts = np.zeros((n_actions, n_states, n_states))
    for idx in range(0, len(traj) - 3, 3):
        # traj[idx + 2] is the reward; the next state sits three slots later
        state, action, next_state = traj[idx], traj[idx + 1], traj[idx + 3]
        transition_counts[action, state, next_state] += 1

    log_transitions = np.log(transition_matrix + EPS)
    return log_transitions.flatten().dot(transition_counts.flatten())


def compute_rewards_traj(traj, reward_vector, discount_factor):
    """Discounted return of a single trajectory under a per-state reward.

    Args:
        traj: one flat trajectory [s0, a0, r0, s1, a1, r1, ...]; only the states are read
            (the stored rewards are ignored).
        reward_vector: reward per state, np.ndarray of shape (N_states,)
        discount_factor: float gamma

    Returns:
        sum_h gamma**h * reward_vector[s_h], where h is the time step.
    """
    state_weights = np.zeros(reward_vector.shape[0])
    # Entries are laid out in (s, a, r) triples, so the time step is the triple index,
    # not the flat index (discounting by the flat index would apply gamma**(3h)).
    for step, idx in enumerate(range(0, len(traj), 3)):
        state_weights[traj[idx]] += discount_factor**step
    return reward_vector.dot(state_weights)


def learn_w_MLE(
    annotated_traj_buffer,
    phi_func,
    dim,
    w=None,
    w_initialization="random",
    sigmoid_slope=1,
    W_norm=1,
    lambda_param=0.01,
    n_epochs=10,
    lr=0.01,
    verbose=(),
):
    """Fit the preference parameter w by full-batch Adam on a regularised logistic loss.

    Data: [[traj1, traj2, y_T, y_R]_pair1, ...], each traj [s0, a0, r0, s1, ...]. Only the
    reward label y_R is used (1 means traj1 preferred). With z_i = sigmoid_slope *
    <phi(traj1_i) - phi(traj2_i), w>, the per-epoch loss is
        sum_i [-log sigmoid(z_i) if y_R == 1 else -log(1 - sigmoid(z_i))]
        + n_pairs * (lambda_param / 2) * ||w||_2^2
    The L2 term is counted once per pair, as in the original per-pair formulation.
    Typical uses: warm-starting on the offline dataset, or refitting on collected online
    preferences.

    Args:
        w: optional warm start (tensor or array-like of length ``dim``); it is copied, so the
            caller's object is not modified.
        w_initialization: "uniform" or "random", used when ``w`` is None. The initial vector
            is rescaled to 2-norm ``W_norm``.

    Returns:
        Detached float32 tensor of shape (dim,).

    Raises:
        ValueError: empty buffer, phi dimension != dim, or unknown ``w_initialization``.
    """
    if len(annotated_traj_buffer) == 0:
        raise ValueError("annotated_traj_buffer is empty")
    if len(phi_func(annotated_traj_buffer[0][0])) != dim:
        raise ValueError(f"phi_func output does not have dimension dim={dim}")

    if w is None:
        if w_initialization == "uniform":
            w = torch.ones(dim)
        elif w_initialization == "random":
            w = torch.randn(dim)
        else:
            raise ValueError(f"w_initialization {w_initialization} not supported")
        w = w / torch.norm(w, p=2) * W_norm  # normalize to have 2-norm W_norm
    else:
        # Train a private copy so the optimizer does not overwrite the caller's tensor.
        w = torch.as_tensor(w, dtype=torch.float32).detach().clone()
    w.requires_grad_(True)

    # phi(t1) - phi(t2) and the labels do not depend on w: build them once, not every epoch.
    phi_diffs = torch.stack(
        [
            torch.tensor(phi_func(pair[0]), dtype=torch.float32)
            - torch.tensor(phi_func(pair[1]), dtype=torch.float32)
            for pair in annotated_traj_buffer
        ]
    )
    prefers_first = torch.tensor([bool(pair[3] == 1) for pair in annotated_traj_buffer])
    n_pairs = len(annotated_traj_buffer)

    optimizer = torch.optim.Adam([w], lr=lr)
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        logits = sigmoid_slope * (phi_diffs @ w)
        # -log(sigmoid(z)) == softplus(-z) and -log(1 - sigmoid(z)) == softplus(z); softplus
        # stays finite when the sigmoid saturates, where log(sigmoid(.)) would give -inf.
        nll = torch.where(
            prefers_first,
            nn.functional.softplus(-logits),
            nn.functional.softplus(logits),
        )
        reg_term = (lambda_param / 2) * torch.norm(w, p=2) ** 2
        total_loss = nll.sum() + n_pairs * reg_term

        total_loss.backward()
        optimizer.step()

        if ("full" in verbose or "losses" in verbose) and epoch % 2 == 0:
            print(f"  Epoch {epoch}, Loss: {total_loss.item():.4f}")
    w = w.detach()
    if "full" in verbose:
        print(f"  w_MLE_t: {w}, 2-norm: {torch.norm(w).item():.3f}")
    return w


def project_w(w_MLE, W, V_t_inv, annotated_traj_buffer, phi_func, lambda_param, verbose=()):
    """Project the learned w_MLE onto the 2-norm ball of radius W, giving w_proj.

    Formula: w_proj_t = argmin_{||w||_2 <= W} ||g_t(w) - g_t(w_MLE)||_{V_t^{-1}}
    where the norm is the Mahalanobis norm with respect to V_t^{-1} and
        g_t(w) = sum_i sigmoid(<phi(t1_i) - phi(t2_i), w>) * (phi(t1_i) - phi(t2_i)) + lambda_param * w
    over the pairs (t1_i, t2_i) of ``annotated_traj_buffer``
    ([[traj1, traj2, y_T, y_R]_pair1, ...]).

    If w_MLE already lies in the ball it is returned unchanged (as float64 numpy). If SLSQP
    fails or raises, the radial projection W * w_MLE / ||w_MLE|| is returned instead.
    """
    if isinstance(w_MLE, torch.Tensor):  # if w is a torch tensor, convert to np
        w_MLE = w_MLE.detach().cpu().numpy().astype(np.float64)
    else:
        w_MLE = np.array(w_MLE, dtype=np.float64)

    w_norm = np.linalg.norm(w_MLE)
    if w_norm <= W:
        if "full" in verbose:
            print("project w_MLE: success (was already ||.||<W)")
        return w_MLE

    V_t_inv = np.array(V_t_inv, dtype=np.float64)

    # phi(t1_i) - phi(t2_i) does not depend on w: compute once, not on every objective call.
    if len(annotated_traj_buffer) > 0:
        phi_diffs = np.stack(
            [
                np.asarray(phi_func(pair[0]), dtype=np.float64)
                - np.asarray(phi_func(pair[1]), dtype=np.float64)
                for pair in annotated_traj_buffer
            ]
        )
    else:
        phi_diffs = np.zeros((0, w_MLE.size), dtype=np.float64)

    def g_t(w_):
        """g_t(w) as defined in the docstring (vectorised over all pairs)."""
        w_ = np.asarray(w_, dtype=np.float64)
        # 0.5 * (1 + tanh(z / 2)) equals the logistic sigmoid and cannot overflow.
        sigmoid_vals = 0.5 * (1.0 + np.tanh(0.5 * (phi_diffs @ w_)))
        return sigmoid_vals @ phi_diffs + lambda_param * w_

    g_t_w_MLE = g_t(w_MLE)

    def objective_func(w_):
        return mahalanobis(g_t(w_), g_t_w_MLE, V_t_inv)

    # Constraint: ||w||_2 <= W, expressed as W - ||w|| >= 0
    constraint = {"type": "ineq", "fun": lambda w_: float(W - np.linalg.norm(w_))}

    # w_MLE lies outside the ball here; start from (and fall back to) its radial projection.
    radial_projection = W * (w_MLE / w_norm)

    try:
        result = minimize(
            objective_func, radial_projection.copy(), constraints=constraint, method="SLSQP"
        )
    except Exception as e:  # deliberate best effort: any solver failure -> radial projection
        if "full" in verbose:
            print(f"Error during optimization: {e}")
        return radial_projection

    if result.success:
        if "full" in verbose:
            print(f"  w_proj_t: {result.x}, 2-norm: {np.linalg.norm(result.x):.3f} <= W: {W}")
        return result.x

    if "full" in verbose:
        print("failed to project w_MLE, returning normalized w_MLE")
    return radial_projection


def _sigmoid(x, slope=1):
    """Scalar logistic function 1 / (1 + exp(-slope * x)).

    slope > 1 makes it steeper, 0 < slope < 1 shallower, and a negative slope mirrors it.
    One of two algebraically equal forms is chosen by the sign of ``x``. For slope >= 0,
    np.exp then only ever sees a non-positive argument and cannot overflow.
    """
    if x < 0:
        exp_term = np.exp(slope * x)
        return exp_term / (1 + exp_term)
    return 1 / (1 + np.exp(-slope * x))


def find_most_preferred_policy(w, policy_list, phi_traj_func, env, verbose=[]):
    """Return the policy in ``policy_list`` with the highest estimated score <phi(policy), w>.

    Each policy is scored with ``_score_func_policy`` (Monte-Carlo rollouts in ``env``). The
    first policy reaching the maximum wins ties.

    Returns:
        (best_policy, policy_scores): best_policy is None for an empty list. policy_scores
        maps hash(policy.matrix.tobytes()) to {"policy", "score", "reward"}.
    """
    policy_scores = {}
    best_policy, best_score = None, -np.inf

    for candidate in policy_list:
        key = hash(candidate.matrix.tobytes())
        score, reward = _score_func_policy(phi_traj_func, candidate, w, env, verbose=verbose)
        policy_scores[key] = {"policy": candidate, "score": score, "reward": reward}
        if "full" in verbose:
            print(f"  score: {score:.3f}, avg reward: {reward:.3f}")
        if score > best_score:
            best_policy, best_score = candidate, score

    if "full" in verbose:
        print(f"  best policy score: {best_score:.3f}")
    return best_policy, policy_scores


def calc_avg_reward(policy, env, N_samples=1000):
    """Mean discounted return of ``policy`` in ``env`` over ``N_samples`` sampled trajectories."""
    total_reward = sum(
        compute_rewards_traj(
            rollout_policy_in_env(env, policy), env.rewards, env.discount_factor
        )
        for _ in range(N_samples)
    )
    return total_reward / N_samples


def _score_func_policy(phi_traj_func, policy, w, env, N_samples=100, verbose=[]):
    """Monte-Carlo estimate of the policy score <phi(policy), w> and of its mean reward.

    Over ``N_samples`` rollouts of ``policy`` in ``env``, returns
    (mean of <phi_traj_func(traj), w>, mean discounted reward).
    With "full" verbosity, the embedding and reward of the last rollout are printed.
    """
    score_total = 0
    reward_total = 0
    for _ in range(N_samples):
        trajectory = rollout_policy_in_env(env, policy)
        trajectory_reward = compute_rewards_traj(
            trajectory, env.rewards, env.discount_factor
        )
        score_total += np.dot(phi_traj_func(trajectory), w)
        reward_total += trajectory_reward
    if "full" in verbose:
        print(f"phi(traj): {phi_traj_func(trajectory)}, reward: {trajectory_reward:.3f}")  # print one
    mean_score = score_total / N_samples
    mean_reward = reward_total / N_samples
    return mean_score, mean_reward


def calc_regret(w, policy_test, policy_true_opt, phi_traj_func, env, N_samples=1000, verbose=[]):
    """Regret of ``policy_test`` versus ``policy_true_opt``.

    Regret is the difference in mean discounted reward, estimated from ``N_samples``
    rollouts each in ``env``.

    Returns:
        (regret, score_test, score_trueopt, avg_reward_test, avg_reward_trueopt)
    """
    (
        avg_reward_test,
        score_test,
        avg_reward_trueopt,
        score_trueopt,
    ) = _calc_avg_rewards_scores(w, policy_test, policy_true_opt, phi_traj_func, env, N_samples)

    regret = avg_reward_trueopt - avg_reward_test
    if "full" in verbose:
        comparison_glyph = "<" if regret > 0 else ">"
        print(
            f"  avg rewards: TEST {avg_reward_test:.3f} {comparison_glyph} TRUE {avg_reward_trueopt:.3f},\n  regret: {regret:.3f}\n  scores test {score_test:.3f} vs true opt {score_trueopt:.3f}"
        )
    return regret, score_test, score_trueopt, avg_reward_test, avg_reward_trueopt


def _calc_avg_rewards_scores(w, policy_test, policy_true_opt, phi_traj_func, env, N_samples=1000):
    """Score and mean discounted reward of a test policy and of the true optimal policy.

    Both are estimated with ``_score_func_policy`` from ``N_samples`` rollouts in ``env``; the
    test policy is evaluated first.

    Returns:
        (avg_reward_test, score_test, avg_reward_trueopt, score_trueopt)
    """
    test_result = _score_func_policy(phi_traj_func, policy_test, w, env, N_samples)
    opt_result = _score_func_policy(phi_traj_func, policy_true_opt, w, env, N_samples)
    score_test, avg_reward_test = test_result
    score_opt, avg_reward_opt = opt_result
    return avg_reward_test, score_test, avg_reward_opt, score_opt


def loop_iteration_logging(
    metrics,
    regret,
    uncertainty_t,
    best_policy_t,
    score_test,
    score_trueopt,
    avg_reward_test,
    avg_reward_trueopt,
    Pi_t,
    solution_pi_true,
    loop_start_time,
    t,
    verbose=None,
):
    """Append the metrics of online iteration ``t`` to ``metrics`` and optionally print a summary.

    Args:
        metrics: dict mapping metric names to lists. Exactly one value is appended
            to each of "regrets", "best_iteration_policy",
            "scores_best_iteration_policy", "scores_true_opt",
            "avg_rewards_best_iteration_policy", "avg_rewards_true_opt",
            "uncertainty_t", "pi_set_sizes" and "iteration_times".
            The dict is mutated in place.
        regret, uncertainty_t, best_policy_t, score_test, score_trueopt,
        avg_reward_test, avg_reward_trueopt: values recorded for this iteration.
        Pi_t: current policy set; its length is logged as the set size.
        solution_pi_true: true optimal policy. It is only used in verbose mode to
            report whether it is still contained in ``Pi_t``.
        loop_start_time: start timestamp of the iteration, on the same clock as
            ``time.time()``; the elapsed time is logged.
        t: iteration index, used only in printed output.
        verbose: collection of verbosity flags; "loop-summary" or "full" prints a
            summary. ``None`` (default) means silent.

    Returns:
        The same ``metrics`` dict.
    """
    verbose = verbose or ()
    metrics["regrets"].append(regret)
    metrics["best_iteration_policy"].append(best_policy_t)
    metrics["scores_best_iteration_policy"].append(score_test)
    metrics["scores_true_opt"].append(score_trueopt)
    metrics["avg_rewards_best_iteration_policy"].append(avg_reward_test)
    metrics["avg_rewards_true_opt"].append(avg_reward_trueopt)
    metrics["uncertainty_t"].append(uncertainty_t)
    metrics["pi_set_sizes"].append(len(Pi_t))
    metrics["iteration_times"].append(time.time() - loop_start_time)

    if "loop-summary" in verbose or "full" in verbose:
        # Policies are identified by hashing the raw bytes of their matrix.
        opt_hash = hash(solution_pi_true.matrix.tobytes())
        pi_t_hashes = {hash(policy.matrix.tobytes()) for policy in Pi_t}

        print(f"-- summary loop {t}:")
        print(
            f"  size of Pi_t: {len(Pi_t)}, opt in Pi_t: {opt_hash in pi_t_hashes}, uncertainty: {uncertainty_t}"
        )
        print(f"  rewards best policy: {avg_reward_test:.3f} vs true opt {avg_reward_trueopt:.3f}")
        print(f"  scores best policy: {score_test:.3f} vs true opt {score_trueopt:.3f}")
        print(f"  => regret: {regret:.3f}")
        print(f" ----- ending loop {t} (in {time.time() - loop_start_time:.1f} seconds) ----- ")

    return metrics


def initial_loop_earlystop_logging(
    metrics, Pi_t, w_MLE_t, w_proj_t, env_BC_learned, confset_offline, online_start_time
):
    """Record a trivial iteration 0 and build the result dicts when |Pi_t| == 1 triggers an early stop.

    With a single remaining policy there is nothing to compare. The function
    logs an uncertainty of 0, a set size of 1, and 0 for every regret, score
    and reward series. The lone policy ``Pi_t[0]`` is reported as the best
    policy.

    Args:
        metrics: dict of lists keyed by metric name; mutated in place.
        Pi_t: policy set; only its first element is read.
        w_MLE_t, w_proj_t, env_BC_learned, confset_offline: passed through
            unchanged into the returned ``final_objs`` dict.
        online_start_time: start timestamp; the elapsed ``time.time()`` since
            then is logged as the iteration time.

    Returns:
        ``(metrics, final_objs, final_values)``; every entry of
        ``final_values`` is 0.
    """
    for name, value in (("uncertainty_t", 0), ("pi_set_sizes", 1), ("regrets", 0)):
        metrics[name].append(value)
    sole_policy = Pi_t[0]
    metrics["best_iteration_policy"].append(sole_policy)
    for name in (
        "scores_best_iteration_policy",
        "scores_true_opt",
        "avg_rewards_best_iteration_policy",
        "avg_rewards_true_opt",
    ):
        metrics[name].append(0)
    metrics["iteration_times"].append(time.time() - online_start_time)

    final_objs = dict(
        w_MLE_final=w_MLE_t,
        w_proj_final=w_proj_t,
        final_best_policy=sole_policy,
        env_learned_t=env_BC_learned,
        Pi_t=Pi_t,
        confset=confset_offline,
    )
    final_values = dict.fromkeys(
        (
            "regret",
            "avg_reward_final_best_policy",
            "score_final_best_policy",
            "avg_reward_true_opt",
            "score_true_opt",
        ),
        0,
    )
    return metrics, final_objs, final_values


def final_iteration_logging(
    metrics,
    regret,
    final_best_policy,
    score_test,
    score_trueopt,
    avg_reward_test,
    avg_reward_trueopt,
    w_MLE_final,
    w_proj_final,
    env_learned_t,
    Pi_t,
    confset_offline,
    online_start_time,
    verbose=None,
):
    """Record the final evaluation and package the final objects and values.

    Appends one value to each of "regrets", "best_iteration_policy",
    "scores_best_iteration_policy", "scores_true_opt",
    "avg_rewards_best_iteration_policy" and "avg_rewards_true_opt" in
    ``metrics``, which is mutated in place. The "uncertainty_t", "pi_set_sizes"
    and "iteration_times" series are not touched here.

    Args:
        metrics: dict of lists keyed by metric name.
        regret, final_best_policy, score_test, score_trueopt, avg_reward_test,
        avg_reward_trueopt: final evaluation results.
        w_MLE_final, w_proj_final: final weight vectors; stored in ``final_objs``
            and printed together with their 2-norms in verbose mode.
        env_learned_t, Pi_t, confset_offline: passed through into ``final_objs``.
        online_start_time: start timestamp of the online phase, on the same
            clock as ``time.time()``; only used to report the runtime in
            verbose mode.
        verbose: collection of verbosity flags; "full" prints a summary.
            ``None`` (default) means silent.

    Returns:
        ``(metrics, final_objs, final_values)``.
    """
    verbose = verbose or ()
    metrics["regrets"].append(regret)
    metrics["best_iteration_policy"].append(final_best_policy)
    metrics["scores_best_iteration_policy"].append(score_test)
    metrics["scores_true_opt"].append(score_trueopt)
    metrics["avg_rewards_best_iteration_policy"].append(avg_reward_test)
    metrics["avg_rewards_true_opt"].append(avg_reward_trueopt)

    final_objs = {
        "w_MLE_final": w_MLE_final,
        "w_proj_final": w_proj_final,
        "final_best_policy": final_best_policy,
        "env_learned_t": env_learned_t,
        "Pi_t": Pi_t,
        "confset": confset_offline,
    }

    final_values = {
        "regret": regret,
        "avg_reward_final_best_policy": avg_reward_test,
        "score_final_best_policy": score_test,
        "avg_reward_true_opt": avg_reward_trueopt,
        "score_true_opt": score_trueopt,
    }

    online_runtime = time.time() - online_start_time

    if "full" in verbose:
        print(f"  w_MLE_final: {w_MLE_final}, 2-norm: {np.linalg.norm(w_MLE_final):.3f}")
        print(f"  w_proj_final: {w_proj_final}, 2-norm: {np.linalg.norm(w_proj_final):.3f}")
        print(f"  final best policy: {final_best_policy}, regret: {regret:.3f}")
        print(f"  final best policy score: {score_test:.3f}, true opt score: {score_trueopt:.3f}")
        print(
            f"  final best policy avg reward: {avg_reward_test:.3f}, true opt avg reward: {avg_reward_trueopt:.3f}"
        )
        print(f"  online runtime: {online_runtime:.1f} seconds")

    return metrics, final_objs, final_values
