"""
Offline learning module for preference-based reinforcement learning.

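This module contains the offline_learning function, which learns a transition
model and a behavioral cloning (BC) policy from offline trajectories and then
builds an offline confidence set of policies. The supporting utilities are
imported from models.* and utils.offline_helpers.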
"""

import time

from models.transition_models import train_transition_model_wrapper, sanity_check_transitions
from models.policies import train_tabular_BC_policy
from utils.offline_helpers import (
    calc_offlineradius,
    calc_d_pi_BC,
    generate_confidence_set_deterministic_via_noise_matrices,
    generate_confidence_set_deterministic_via_rejection_sampling,
)


def offline_learning(
    offline_trajs,  # exp setup
    env_true,  # env setup
    episode_length,  # params
    delta_offline,  # params
    solution_pi_true,  # env setup
    N_confset_size,  # params
    N_search_space_samples,  # params
    which_confset_construction_method,  # params
    which_hellinger_calc,  # params
    which_transition_model,  # env setup
    n_transition_model_epochs,  # params
    offlineradius_formula,  # params
    offlineradius_override_value,  # params
    replace_mle_with_optimal_policy_in_offline_confset,  # params
    verbose=None,
):
    """
    Perform offline learning using behavioral cloning and confidence set generation.

    Steps:
        1. Train a transition model on the offline trajectories, then
           sanity-check (and fix) that its transition rows sum to 1.
        2. Train a deterministic tabular behavioral cloning (BC) policy.
        3. Compute the radius of the offline confidence set.
        4. Compute d_pi_BC from the offline trajectories.
        5. Build the offline confidence set with the requested method.

    Args:
        offline_trajs: Offline trajectories. Used for the transition model,
            BC policy, radius and d_pi_BC computations.
        env_true: True environment. Must expose ``N_states`` and ``N_actions``.
        episode_length: Length of episodes.
        delta_offline: Confidence parameter for the offline radius.
        solution_pi_true: True optimal policy. Forwarded to the confidence set
            construction.
        N_confset_size: Target confidence set size. Used only by the
            'noise-matrices' method (passed as ``N_conf``).
        N_search_space_samples: Number of search-space samples. Used only by
            the 'rejection-sampling-from-sample' method.
        which_confset_construction_method: One of 'noise-matrices',
            'rejection-sampling-from-all' or 'rejection-sampling-from-sample'.
        which_hellinger_calc: 'local-avg' or 'bhattacharyya'. Used only by
            the rejection-sampling methods.
        which_transition_model: 'MLE', 'linear_classifier', or 'MLP'.
        n_transition_model_epochs: Number of epochs to train the transition model.
        offlineradius_formula: Passed to ``calc_offlineradius`` as
            ``formula_version``.
        offlineradius_override_value: Passed to ``calc_offlineradius`` as
            ``aux_input``. Presumably only consulted by override-style
            formulas; verify in ``calc_offlineradius``.
        replace_mle_with_optimal_policy_in_offline_confset: Forwarded to the
            'noise-matrices' construction only.
        verbose: List of verbosity options. "loop-summary" and "full" are
            checked here, and the list is forwarded to the helpers. ``None``
            (the default) means no options.

    Returns:
        Tuple ``(confset_offline, policy_BC_offline, env_T_learned)``:
            confset_offline: confidence set of offline policies.
            policy_BC_offline: policy obtained via behavioral cloning.
            env_T_learned: learned (sanity-checked) transition model.

    Raises:
        ValueError: If ``which_confset_construction_method`` is not supported.
            This is checked before any training so that a typo fails fast.
    """
    if verbose is None:
        # Fresh list per call: a shared mutable default would leak any
        # mutation made by the helpers this list is forwarded to.
        verbose = []

    valid_methods = (
        "noise-matrices",
        "rejection-sampling-from-all",
        "rejection-sampling-from-sample",
    )
    # Validate before the (potentially expensive) training steps below.
    if which_confset_construction_method not in valid_methods:
        raise ValueError(
            f"Unknown confset construction method: {which_confset_construction_method}"
        )

    N_states = env_true.N_states
    N_actions = env_true.N_actions

    if "loop-summary" in verbose or "full" in verbose:
        print("\n ----- starting OFFLINE part ----- ")

    offline_start_time = time.perf_counter()

    # Learn the transition model, then sanity-check that it is a valid
    # distribution (sum(rows) must be 1). fix=True lets the helper repair it.
    env_T_learned, _, _ = train_transition_model_wrapper(
        offline_trajs, env_true, which_transition_model, n_transition_model_epochs, verbose
    )
    env_T_learned = sanity_check_transitions(env_T_learned, fix=True, verbose=verbose)
    if "loop-summary" in verbose or "full" in verbose:
        print(f"Transition model: {which_transition_model}")
    if "full" in verbose:
        print(f"MLE transitions (offline):\n{env_T_learned.transitions}")

    # Learn a deterministic tabular behavioral cloning policy.
    policy_BC_offline = train_tabular_BC_policy(
        offline_trajs,
        N_states,
        N_actions,
        init="random",
        n_epochs=10,
        lr=0.01,
        make_deterministic=True,
        verbose=verbose,
    )

    # Radius of the offline confidence set.
    offlineradius = calc_offlineradius(
        offline_trajs,
        N_states,
        N_actions,
        episode_length,
        delta_offline,
        formula_version=offlineradius_formula,
        aux_input=offlineradius_override_value,
        verbose=verbose,
    )
    if "full" in verbose:
        print(f"offlineradius: {offlineradius:.3f}")

    # d_pi_BC is computed from the offline trajectories. Presumably this is
    # the empirical state distribution of the behavior policy; confirm in
    # calc_d_pi_BC.
    d_pi_BC = calc_d_pi_BC(offline_trajs, N_states)

    if which_confset_construction_method == "noise-matrices":
        confset_offline = generate_confidence_set_deterministic_via_noise_matrices(
            policy_BC_offline,
            d_pi_BC,
            offlineradius,
            N_states,
            N_actions,
            replace_mle_with_optimal_policy_in_offline_confset,
            solution_pi_true,
            method="knapsack-sampling",
            sample_func="proportional",
            N_conf=N_confset_size,
            max_attempts=1000,
            verbose=verbose,
        )
    else:
        # The two rejection-sampling variants differ only in the search-space
        # size. "from-all" passes None, which by its name presumably means the
        # whole policy space.
        search_space_size = (
            None
            if which_confset_construction_method == "rejection-sampling-from-all"
            else N_search_space_samples
        )
        confset_offline = generate_confidence_set_deterministic_via_rejection_sampling(
            policy_BC_offline,
            d_pi_BC,
            env_T_learned,
            episode_length,
            offlineradius,
            solution_pi_true,
            which_hellinger_calc,
            N_search_space_samples=search_space_size,
            verbose=verbose,
        )

    offline_runtime = time.perf_counter() - offline_start_time
    if "loop-summary" in verbose or "full" in verbose:
        print(f"--- offline runtime: {offline_runtime:.3f} seconds ---")
        print(f"Generated offline confset of size {len(confset_offline) if confset_offline else 0}")

    return confset_offline, policy_BC_offline, env_T_learned
