"""
Online learning module for preference-based reinforcement learning.

This module contains the online_learning function and all supporting
utilities for preference learning and policy optimization.
"""

import copy
import pickle
import os
import time
import numpy as np

from utils.online_helpers import (
    precompute_phi_Bhats,
    calc_gamma_t,
    calc_empirical_counts,
    calc_online_confset_t,
    get_policy_pair_that_maximizes_uncertainty,
    generate_policy_pair_rollouts,
    annotate_buffer,
    learn_w_MLE,
    project_w,
    find_most_preferred_policy,
    calc_regret,
    loop_iteration_logging,
    initial_loop_earlystop_logging,
    final_iteration_logging,
)
from utils.offline_helpers import (
    generate_offline_trajectories,
    _calculate_squared_hellinger_distance_bhattacharyya,
)

from models.policies import (
    TabularPolicy,
    generate_random_tabular_policies,
    generate_all_deterministic_stationary_policies,
    generate_random_tabular_policies_vectorized,
    generate_random_tabular_policies,
    train_tabular_BC_policy,
)
from models.transition_models import train_transition_model_wrapper, sanity_check_transitions


def _policy_hash(policy):
    """Return the lookup key used for a policy in `precomputed_phi_bhats` (hash of its matrix bytes)."""
    return hash(policy.matrix.tobytes())


def _build_baseline_search_space(
    confset_offline,
    solution_pi_true,
    N_states,
    N_actions,
    baseline_search_space,
    N_confset_size,
    env_name,
    verbose,
):
    """Build the policy search space used when no offline BC was performed.

    Args:
        confset_offline: Incoming policy list; only read by the "augmented_ball" option
            (it is copied, never mutated).
        solution_pi_true: True optimal policy; prepended by the "random_sample" option.
        N_states, N_actions: MDP sizes.
        baseline_search_space: "random_sample", "all_policies" or "augmented_ball".
        N_confset_size: Target size of the search space.
        env_name: Used to name the on-disk cache of all policies. If None, no caching is done.
        verbose: Collection of verbosity flags.

    Returns:
        list: The new search space.

    Raises:
        ValueError: If `baseline_search_space` is unknown, or is "all_policies" while
            N_actions**N_states >= 1e4 (guard against enumerating a huge policy space).
    """
    if baseline_search_space == "random_sample":
        # true optimum + N_confset_size random deterministic policies
        search_space = [solution_pi_true]
        search_space.extend(
            generate_random_tabular_policies_vectorized(
                N_states,
                N_actions,
                N_policies=N_confset_size,
                make_deterministic=True,
            )
        )
        return search_space

    # avoid kernel crash by checking size of policy space
    if baseline_search_space == "all_policies" and N_actions**N_states < 1e4:
        n_expected = N_actions**N_states
        # Only cache when the environment is named: a shared "None_all_policies.pkl"
        # would silently hand one environment's policies to another.
        cache_path = (
            f"exps/all_policies/{env_name}_all_policies.pkl" if env_name is not None else None
        )
        all_policies = None
        if cache_path is not None and os.path.exists(cache_path):
            # NOTE: pickle is only safe on trusted files; this path is a local cache written below.
            with open(cache_path, "rb") as f:
                all_policies = pickle.load(f)
            if len(all_policies) != n_expected:
                print(
                    f"Warning: cached policies at {cache_path} have size {len(all_policies)}, expected {n_expected}. Regenerating."
                )
                all_policies = None
            elif "full" in verbose or "online-confset" in verbose:
                print(f"loaded {len(all_policies)} policies from {cache_path}")
        if all_policies is None:
            all_policies = generate_all_deterministic_stationary_policies(N_states, N_actions)
            if "full" in verbose or "online-confset" in verbose:
                print(f"Generated all {N_actions**N_states} possible deterministic policies...")
            if cache_path is not None:
                os.makedirs(os.path.dirname(cache_path), exist_ok=True)
                with open(cache_path, "wb") as f:
                    pickle.dump(all_policies, f)
        return list(all_policies)

    if baseline_search_space == "augmented_ball":
        # Assumes confset_offline is BRIDGE's H2-ball from offline; pad it with random
        # policies (diluting it) up to N_confset_size, or truncate it down to that size.
        search_space = list(confset_offline)  # copy: never mutate the caller's list
        num_to_be_filled = N_confset_size - len(search_space)
        if num_to_be_filled > 0:
            search_space.extend(
                generate_random_tabular_policies_vectorized(
                    N_states,
                    N_actions,
                    N_policies=num_to_be_filled,
                    make_deterministic=True,
                )
            )
            if "full" in verbose or "online-confset" in verbose:
                print(
                    f"Augmented ball with {num_to_be_filled} random policies to hit size {len(search_space)} (target {N_confset_size})"
                )
        elif num_to_be_filled < 0:
            # an exactly-sized ball (num_to_be_filled == 0) is fine and needs no warning
            print(
                f"Warning: Search space creation (augmented ball): N_search_space {N_confset_size} is smaller than H2-ball (current size {len(search_space)}). Decrease radius, or increase N_search_space."
            )
            search_space = search_space[:N_confset_size]
        return search_space

    raise ValueError("baseline_search_space invalid or MDP too big for all_policies")


def _train_and_project_w(
    annotated_online_buffer,
    phi,
    d,
    V_t,
    W,
    lambda_param,
    w_initialization,
    w_sigmoid_slope,
    w_MLE_epochs,
    verbose,
):
    """Fit w_MLE on the preference buffer, then project it using V_t^{-1}.

    Returns:
        tuple: (w_MLE, w_proj, V_t_inv)
    """
    w_MLE = learn_w_MLE(
        annotated_online_buffer,
        phi,
        d,
        w=None,  # TODO: don't change this
        w_initialization=w_initialization,
        sigmoid_slope=w_sigmoid_slope,
        W_norm=W,
        lambda_param=lambda_param,
        n_epochs=w_MLE_epochs,
        lr=0.01,
        verbose=verbose,
    )
    V_t_inv = np.linalg.inv(V_t)
    w_proj = project_w(w_MLE, W, V_t_inv, annotated_online_buffer, phi, lambda_param, verbose)
    return w_MLE, w_proj, V_t_inv


def _collect_online_round(
    Pi_t,
    precomputed_phi_bhats,
    gamma_t,
    V_t,
    V_t_inv,
    confset_offline,
    env_true,
    N_rollouts,
    N_states,
    N_actions,
    offline_trajs,
    online_traj_list,
    annotated_online_buffer,
    online_policy_pairs,
    N_ts,
    which_transition_model,
    n_transition_model_epochs,
    verbose,
):
    """Run one round of data collection on the online confidence set Pi_t.

    The round does the following, in order:
    - Selects the policy pair in Pi_t that maximizes uncertainty.
    - Rolls the pair out N_rollouts times in env_true and annotates the trajectory pairs.
    - Adds phi(pi1) - phi(pi2) as an outer product to V_t.
    - Appends new visitation counts to N_ts.
    - Retrains the transition model on offline + online trajectories.

    Mutates `online_policy_pairs`, `online_traj_list`, `annotated_online_buffer` and
    `N_ts` in place.

    Returns:
        tuple: (uncertainty_t, updated V_t, retrained env_learned_t)
    """
    policy_pair_t, uncertainty_t = get_policy_pair_that_maximizes_uncertainty(
        Pi_t, precomputed_phi_bhats, gamma_t, V_t_inv, verbose=verbose
    )
    online_policy_pairs.append(policy_pair_t)
    pi1, pi2 = policy_pair_t[0], policy_pair_t[1]

    # optional: check if BC policy (first entry of confset_offline) is in the pair via the hash
    if "full" in verbose or "warnings" in verbose:
        mle_hash = _policy_hash(confset_offline[0])
        pair_hashes = [_policy_hash(pi1), _policy_hash(pi2)]
        print(f"MLE/Opt policy is in policy_pair_t: {mle_hash in pair_hashes}")

    # sample traj1 ~ pi1, traj2 ~ pi2 N_rollouts times: [[t1, t2]_1, ...], then label them
    traj_pairs_t = generate_policy_pair_rollouts(env_true, pi1, env_true, pi2, N_rollouts)
    for traj_pair in traj_pairs_t:
        online_traj_list.append(traj_pair[0])
        online_traj_list.append(traj_pair[1])
    annotated_online_buffer.extend(annotate_buffer(traj_pairs_t, env_true, N_rollouts))

    # update design matrix V_t and visitation counts N_ts
    phi_diff = (
        precomputed_phi_bhats[_policy_hash(pi1)].phi - precomputed_phi_bhats[_policy_hash(pi2)].phi
    )
    V_t = V_t + np.outer(phi_diff, phi_diff)
    N_ts.append(calc_empirical_counts(online_traj_list, N_states, N_actions))

    # retrain transition model on all offline + online trajectories
    all_trajs = offline_trajs + online_traj_list
    env_learned_t, _, _ = train_transition_model_wrapper(
        all_trajs, env_true, which_transition_model, n_transition_model_epochs, verbose
    )
    env_learned_t = sanity_check_transitions(env_learned_t, fix=True)
    return uncertainty_t, V_t, env_learned_t


def online_learning(
    confset_offline,
    offline_trajs,
    env_BC_learned,
    solution_pi_true,
    env_true,
    N_rollouts,  # params
    N_iterations,  # params
    episode_length,  # params
    delta_online,  # params
    phi,
    B,
    W,  # params
    d,
    kappa,
    lambda_param,
    eta,
    w_MLE_epochs,  # params
    w_initialization,  # params
    w_sigmoid_slope,  # params
    xi_formula,  # params
    which_transition_model,
    n_transition_model_epochs,  # params
    online_confset_recalc_phi,  # params
    online_confset_bonus_multiplier,  # params
    use_true_T_in_online,  # params
    gamma_t_hardcoded_value,  # params
    do_offline_BC=False,  # params
    baseline_search_space="random_sample",  # params
    N_confset_size=1000,  # params
    env_name=None,
    verbose=None,  # params
):
    """
    Perform online learning using preference feedback and confidence sets.

    Args:
        confset_offline: Offline confidence set of policies, used as the search space.
            When do_offline_BC is False, it is replaced according to
            baseline_search_space. The "augmented_ball" option reads it, but the
            caller's list is never mutated.
        offline_trajs: List of offline trajectories (ignored / reset to [] without offline BC).
        env_BC_learned: Environment with learned transitions, used at t=0. Without
            offline BC it is replaced by a copy of env_true with uniform transitions.
        solution_pi_true: True optimal policy.
        env_true: True environment (rollouts, annotation, regret).
        N_rollouts: Number of rollout pairs per iteration.
        N_iterations: Number of online iterations.
        episode_length: Length of episodes.
        delta_online: Confidence parameter for online learning.
        phi: Embedding function.
        B: Embedding bound.
        W: Weight bound.
        d: Embedding dimension.
        kappa: Kappa parameter.
        lambda_param: Lambda parameter.
        eta: Eta parameter.
        w_MLE_epochs: Number of epochs for w_MLE training.
        w_initialization: Weight initialization method.
        w_sigmoid_slope: Sigmoid slope parameter.
        xi_formula: Formula selector forwarded to precompute_phi_Bhats and calc_gamma_t.
        which_transition_model: Transition model type.
        n_transition_model_epochs: Number of epochs for transition model training.
        online_confset_recalc_phi: Whether to recalculate phi inside online confset (or use precomputed one).
        online_confset_bonus_multiplier: Bonus multiplier for online confset (to manually scale B-term).
        use_true_T_in_online: If truthy, iterations t >= 1 use env_true instead of the
            learned transition model for phi/Bhat precomputation and gamma_t.
        gamma_t_hardcoded_value: If truthy, used as gamma_t in every iteration instead of
            calling calc_gamma_t.
        do_offline_BC: Whether offline behavioral cloning was performed and offline data is available.
        baseline_search_space: Only used when do_offline_BC is False.
            - "random_sample": solution_pi_true plus N_confset_size random deterministic policies.
            - "all_policies": every deterministic stationary policy. Only allowed when
              N_actions**N_states < 1e4. Cached under exps/all_policies/ when env_name is given.
            - "augmented_ball": confset_offline padded with random deterministic policies
              up to N_confset_size, or truncated to N_confset_size.
        N_confset_size: Target size of the baseline search space.
        env_name: Environment name, used for the all-policies cache file name.
        verbose: Collection of verbosity flags ("full", "loop-summary",
            "online-confset", "warnings"). Defaults to no output flags.

    Returns:
        tuple: (metrics, final_objs, final_values), as produced by final_iteration_logging.
            If t=0 stops early, they are produced by initial_loop_earlystop_logging.

    Raises:
        ValueError: If baseline_search_space is invalid, or is "all_policies" for an MDP
            that is too big. Also raised if the t=0 search space has size 0, or has
            size 1 without offline BC.

    Main logic:
    for t=0, ..., N_iterations-1:
    - train w_MLE and project it to W-ball
    - precalculate phi(pi) and bhat(pi) values for all policies in confset_offline
    - calculate gamma_t
    - calculate online confset Pi_t (uses precalculated values)
    - select (pi1, pi2) policy pair that maximizes uncertainty (uses precalculated values)
    - sample trajectories from (pi1, pi2), annotate and add to buffer
    - update V_t, N_ts and retrain transition model on all offline + online trajectories
    - calculate metrics: uncertainty_t, regret_t, etc.

    The start and end of the loop are slightly different. The start requires
    initialization and skips some steps because there is no data yet. The end
    requires extra metrics calculation.
    """
    verbose = [] if verbose is None else verbose
    N_states = env_true.N_states
    N_actions = env_true.N_actions

    if "loop-summary" in verbose or "full" in verbose:
        print("\n ----- starting ONLINE loop ----- ")

    online_start_time = time.time()

    ##### INITIALIZATION #####
    # Without offline BC there are no offline trajectories (so no visitation counts) and no
    # offline confidence set: build a baseline search space and a uniform transition model.
    if not do_offline_BC:
        confset_offline = _build_baseline_search_space(
            confset_offline,
            solution_pi_true,
            N_states,
            N_actions,
            baseline_search_space,
            N_confset_size,
            env_name,
            verbose,
        )
        env_BC_learned = copy.deepcopy(env_true)
        env_BC_learned.transitions = np.ones((N_actions, N_states, N_states)) / N_states
        env_BC_learned = sanity_check_transitions(env_BC_learned, fix=True)
        offline_trajs = []

    # Metrics are logged at two different sets of loop indices. When plotting at iteration
    # t in {initial=0, 1, ..., N-1}:
    # - "best policy prediction" metrics use the w trained on the dataset of loop t.
    #   They are logged at loops {1, ..., N-1, final=N}.
    # - metrics from the in-loop calculations of loop t are logged at loops
    #   {initial=0, 1, ..., N-1}.
    metrics = {
        # logged at loops {1, ..., N-1, final}
        "regrets": [],
        "best_iteration_policy": [],
        "scores_best_iteration_policy": [],
        "scores_true_opt": [],
        "avg_rewards_best_iteration_policy": [],
        "avg_rewards_true_opt": [],
        # logged at loops {initial, 1, ..., N-1}
        "pi_set_sizes": [],  # 0, 1...N-1
        "uncertainty_t": [],
        "iteration_times": [],
    }

    annotated_online_buffer = []
    online_traj_list = []
    online_policy_pairs = []
    # 0-based loop index [0...N_iterations-1]; in math formulas t_ := t+1 (t_=1, ..., N_iterations).
    t = 0
    ##### END INITIALIZATION #####

    ##### LOOP t=0 #####
    # data matrix V_t: V_1 is a scaled identity matrix
    V_t = kappa * lambda_param * np.eye(d)
    V_t_inv = np.linalg.inv(V_t)

    # N_ts[0]: state-action counts from offline_trajs. Each later entry is computed from
    # online_traj_list (all online trajectories collected so far).
    N_ts = [calc_empirical_counts(offline_trajs, N_states, N_actions)]

    # no labeled data yet: random Gaussian direction, scaled onto the boundary of the W-ball
    w_MLE_t = np.random.randn(d)
    w_proj_t = w_MLE_t / np.linalg.norm(w_MLE_t) * W

    # Precompute phi and Bhats for all policies in confset_offline. These are used by the
    # online confset condition and by the uncertainty argmax. gamma_t uses separate values,
    # because it only needs the Bhats of the policies sampled so far.
    # Iteration t=0 always uses env_BC_learned here.
    precomputed_phi_bhats = precompute_phi_Bhats(
        N_states,
        N_actions,
        t,
        phi,
        confset_offline,
        env_BC_learned,
        eta,
        delta_online,
        episode_length,
        N_ts,
        xi_formula,
        n_samples=100,
    )

    if gamma_t_hardcoded_value:
        if "full" in verbose or "online-confset" in verbose:
            print(f"overriding gamma_{t}: {gamma_t_hardcoded_value}")
        gamma_t = gamma_t_hardcoded_value
    else:
        gamma_t = calc_gamma_t(
            t,
            kappa,
            lambda_param,
            B,
            W,
            N_iterations,
            d,
            delta_online,
            eta,
            episode_length,
            N_ts,
            xi_formula,
            policy_pairs=None,  # 1st iteration: no online policy pairs yet. set B-term to 0.
            env_learned=env_BC_learned,
            verbose=verbose,
        )

    # At t=0 there is no preference data (w is random), so no online confset filtering is
    # done: Pi_0 is the whole search space.
    Pi_t = confset_offline

    if "full" in verbose or "online-confset" in verbose:
        print(
            f"calculating online confset with gamma = {gamma_t}, B multiplier = {online_confset_bonus_multiplier}"
        )

    if len(Pi_t) == 0:
        raise ValueError("len(Pi_t) == 0. This should never happen.")

    # early stopping if |Pi_t| = 1: BC was enough to solve the problem (return pi_optimal)
    if len(Pi_t) == 1:
        if not do_offline_BC:
            raise ValueError("len(Pi_t) = 1, but we didn't do offline BC. Shouldn't happen!")
        BC_policy_hash = _policy_hash(confset_offline[0])
        opt_policy_hash = _policy_hash(solution_pi_true)
        pi_t_hashes = [_policy_hash(policy) for policy in Pi_t]
        print(
            f"len(Pi_t) = 1, terminating at t=0. Pi_t contains... BC_policy: {BC_policy_hash in pi_t_hashes}, opt_policy: {opt_policy_hash in pi_t_hashes}, BC == opt: {BC_policy_hash == opt_policy_hash}"
        )
        # since we're returning early, the logging helper pads the metrics
        return initial_loop_earlystop_logging(
            metrics, Pi_t, w_MLE_t, w_proj_t, env_BC_learned, confset_offline, online_start_time
        )

    uncertainty_t, V_t, env_learned_t = _collect_online_round(
        Pi_t,
        precomputed_phi_bhats,
        gamma_t,
        V_t,
        V_t_inv,
        confset_offline,
        env_true,
        N_rollouts,
        N_states,
        N_actions,
        offline_trajs,
        online_traj_list,
        annotated_online_buffer,
        online_policy_pairs,
        N_ts,
        which_transition_model,
        n_transition_model_epochs,
        verbose,
    )

    # Loop 0 skips metrics that rely on w, because w is random here. Those are logged in
    # loops 1, ..., N-1 and final.
    metrics["pi_set_sizes"].append(len(Pi_t))
    metrics["uncertainty_t"].append(uncertainty_t)
    metrics["iteration_times"].append(time.time() - online_start_time)

    if "loop-summary" in verbose or "full" in verbose:
        print(f"-- summary loop {t}:")
        print(f"  size of Pi_t: {len(Pi_t)}, uncertainty: {uncertainty_t}")
        print(" ----- ending 0-th loop ----- ")
    ##### END OF LOOP 0 #####

    ##### LOOP t=1, ..., N_iterations-1 #####
    for t in range(1, N_iterations):
        loop_start_time = time.time()
        if "loop-summary" in verbose or "full" in verbose:
            print(f"\n ----- starting loop {t} ----- ")

        _, w_proj_t, V_t_inv = _train_and_project_w(
            annotated_online_buffer,
            phi,
            d,
            V_t,
            W,
            lambda_param,
            w_initialization,
            w_sigmoid_slope,
            w_MLE_epochs,
            verbose,
        )

        # optionally, use true T in online computations
        env_online_computations = env_true if use_true_T_in_online else env_learned_t

        precomputed_phi_bhats = precompute_phi_Bhats(
            N_states,
            N_actions,
            t,
            phi,
            confset_offline,
            env_online_computations,
            eta,
            delta_online,
            episode_length,
            N_ts,
            xi_formula,
            n_samples=100,
        )

        if gamma_t_hardcoded_value:
            gamma_t = gamma_t_hardcoded_value
            if "full" in verbose or "online-confset" in verbose:
                print(f"overriding gamma_{t}: {gamma_t}")
        else:
            gamma_t = calc_gamma_t(
                t,
                kappa,
                lambda_param,
                B,
                W,
                N_iterations,
                d,
                delta_online,
                eta,
                episode_length,
                N_ts,
                xi_formula,
                online_policy_pairs,
                env_online_computations,
                verbose=verbose,
            )

        Pi_t = calc_online_confset_t(
            confset_offline,
            precomputed_phi_bhats,
            w_proj_t,
            gamma_t,
            V_t_inv,
            online_confset_recalc_phi,
            online_confset_bonus_multiplier,
            phi,
            env_true,
            verbose,
        )

        if len(Pi_t) > 1:
            uncertainty_t, V_t, env_learned_t = _collect_online_round(
                Pi_t,
                precomputed_phi_bhats,
                gamma_t,
                V_t,
                V_t_inv,
                confset_offline,
                env_true,
                N_rollouts,
                N_states,
                N_actions,
                offline_trajs,
                online_traj_list,
                annotated_online_buffer,
                online_policy_pairs,
                N_ts,
                which_transition_model,
                n_transition_model_epochs,
                verbose,
            )
        # NOTE(review): if len(Pi_t) <= 1 no pair is selected, so `uncertainty_t` still holds
        # the previous iteration's value and is logged again below — confirm this is intended.

        # current best policy, and its regret vs the true optimum (averaged over N_samples trajectories)
        best_policy_t, _ = find_most_preferred_policy(
            w_proj_t, confset_offline, phi, env_true, verbose=[]
        )
        (
            regret,
            score_test,
            score_trueopt,
            avg_reward_test,
            avg_reward_trueopt,
        ) = calc_regret(
            w_proj_t,
            best_policy_t,
            solution_pi_true,
            phi,
            env_true,
            N_samples=1000,
        )

        metrics = loop_iteration_logging(
            metrics,
            regret,
            uncertainty_t,
            best_policy_t,
            score_test,
            score_trueopt,
            avg_reward_test,
            avg_reward_trueopt,
            Pi_t,
            solution_pi_true,
            loop_start_time,
            t,
            verbose,
        )
        if len(Pi_t) == 1:
            print(f"Pi_t contains only 1 policy, terminating at t={t}/{N_iterations - 1}")
            break
        ##### END OF LOOP t #####

    ##### final iteration: train w_MLE -> w_proj one last time and compute final metrics #####
    w_MLE_final, w_proj_final, _ = _train_and_project_w(
        annotated_online_buffer,
        phi,
        d,
        V_t,
        W,
        lambda_param,
        w_initialization,
        w_sigmoid_slope,
        w_MLE_epochs,
        verbose,
    )

    final_best_policy, _ = find_most_preferred_policy(
        w_proj_final, confset_offline, phi, env_true, verbose=verbose
    )

    (
        regret,
        score_test,
        score_trueopt,
        avg_reward_test,
        avg_reward_trueopt,
    ) = calc_regret(
        w_proj_final,
        final_best_policy,
        solution_pi_true,
        phi,
        env_true,
        N_samples=1000,
    )

    metrics, final_objs, final_values = final_iteration_logging(
        metrics,
        regret,
        final_best_policy,
        score_test,
        score_trueopt,
        avg_reward_test,
        avg_reward_trueopt,
        w_MLE_final,
        w_proj_final,
        env_learned_t,
        Pi_t,
        confset_offline,
        online_start_time,
        verbose,
    )

    return metrics, final_objs, final_values
