"""
Experiment runner for preference-based reinforcement learning.

This module contains the main run_experiment function that coordinates
the entire experimental pipeline.
"""

import numpy as np
import random
import torch
from tqdm import tqdm
from datetime import datetime
import platform

from environments import StarMDP_with_random_flinging, StarMDP_with_random_staying, Gridworld
from models.transition_models import sanity_check_transitions
from utils.embeddings import create_embeddings
from core.offline_learning import offline_learning
from utils.offline_helpers import generate_offline_trajectories
from core.online_learning import online_learning
from utils.misc_helpers import short_env_name

from concurrent.futures import ProcessPoolExecutor, as_completed

import multiprocessing as mp

# Worker processes are always started with the "spawn" start method.
# An earlier platform-dependent choice ("spawn" on Windows, "fork" elsewhere)
# was unconditionally overridden by this call, so it has been removed as dead
# code; the effective start method is unchanged.
# MP_COMMENT: 'fork' sometimes hangs, try at your own risk
mp.set_start_method("spawn", force=True)


def run_experiment(params):
    """
    Run an experiment sequentially, one seed after another.

    The environment and the trajectory embedding are built once from
    ``params`` and then shared by every seed; each seed is executed through
    ``_run_single_seed`` (optional offline learning followed by online
    learning).

    Args:
        params (dict): Experiment parameters. Keys read directly here:
            - N_experiments: Number of seeds to run (seeds are
              0, ..., N_experiments - 1)
            - verbose: Container that is checked for the "loop-summary" flag
            All other keys are consumed by ``_setup_environment``,
            ``_setup_embeddings`` and ``_run_single_seed``.

    Returns:
        tuple: (metrics_per_seed, final_objs_per_seed, final_values_per_seed)
            Each element is a list holding one entry per seed, in seed order.
            If N_experiments == 1, the three entries of that single seed are
            returned directly instead of one-element lists.
    """
    env_true, which_transition_model, solution_pi_true, N_search_space_samples = _setup_environment(
        params
    )
    # (phi, d, B, kappa, lambda_param, eta), forwarded positionally below
    embedding_setup = _setup_embeddings(params, env_true.N_states, env_true.N_actions)

    n_seeds = params["N_experiments"]
    per_seed_results = []
    for seed in tqdm(range(n_seeds), desc="Experiments"):
        metrics, final_objs, final_values = _run_single_seed(
            seed,
            params,
            env_true,
            which_transition_model,
            solution_pi_true,
            N_search_space_samples,
            *embedding_setup,
        )
        per_seed_results.append((metrics, final_objs, final_values))

        if "loop-summary" in params["verbose"]:
            print(f" ##### EXPERIMENT {seed}/{params['N_experiments']} END #####")

    # transpose [(metrics, objs, values), ...] into three per-seed lists
    metrics_per_seed = [entry[0] for entry in per_seed_results]
    final_objs_per_seed = [entry[1] for entry in per_seed_results]
    final_values_per_seed = [entry[2] for entry in per_seed_results]

    if n_seeds != 1:
        return metrics_per_seed, final_objs_per_seed, final_values_per_seed
    return metrics_per_seed[0], final_objs_per_seed[0], final_values_per_seed[0]


def run_experiment_multiprocessing(params):
    """
    Run an experiment with the seeds distributed over a process pool.

    Every seed in ``range(params["N_experiments"])`` is submitted to a
    ``ProcessPoolExecutor`` as a call to ``_run_single_seed_mp``, which does
    its own seeding, environment setup and embedding setup. A progress line
    is printed each time a seed finishes.

    Args:
        params (dict): Experiment parameters. Only "N_experiments" is read
            here; the dict is passed unchanged to every worker.

    Returns:
        tuple: (metrics_per_seed, final_objs_per_seed, final_values_per_seed)
            Each element is a list holding one entry per seed, ordered by
            seed (not by completion time). If N_experiments == 1, the three
            entries of that single seed are returned directly. If there are
            no seeds to run, three empty lists are returned, matching
            ``run_experiment``.

    Raises:
        Exception: Whatever a worker raised is re-raised by
            ``future.result()``.
    """
    n_experiments = params["N_experiments"]
    seeds = range(n_experiments)
    if not seeds:
        # Nothing to run: avoid spinning up a pool and avoid unpacking an
        # empty zip(), which would raise ValueError.
        return [], [], []

    results = []
    with ProcessPoolExecutor() as executor:
        futures = [executor.submit(_run_single_seed_mp, seed, params) for seed in seeds]
        for finished, future in enumerate(as_completed(futures), start=1):
            results.append(future.result())
            print(
                f"finished {finished}/{n_experiments} at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
            )

    # Workers finish in arbitrary order; each result starts with its seed,
    # so sorting on that restores seed order.
    results.sort(key=lambda result: result[0])
    _, metrics_per_seed, final_objs_per_seed, final_values_per_seed = zip(*results)

    if n_experiments == 1:
        return metrics_per_seed[0], final_objs_per_seed[0], final_values_per_seed[0]
    return list(metrics_per_seed), list(final_objs_per_seed), list(final_values_per_seed)


def _run_single_seed(
    seed,
    params,
    env_true,
    which_transition_model,
    solution_pi_true,
    N_search_space_samples,
    phi,
    d,
    B,
    kappa,
    lambda_param,
    eta,
):
    """Run one seed in the single-process setting.

    Sets the random seeds for ``seed`` and then runs the learning pipeline.
    Unlike ``_run_single_seed_mp``, environment setup and embedding setup are
    NOT done here; the caller passes their results in as arguments.

    Returns:
        tuple: (metrics, final_objs, final_values) as returned by
            ``online_learning``.
    """
    _set_random_seeds(seed)
    return _run_offline_then_online(
        params,
        env_true,
        which_transition_model,
        solution_pi_true,
        N_search_space_samples,
        phi,
        d,
        B,
        kappa,
        lambda_param,
        eta,
    )


def _run_single_seed_mp(seed, params):
    """Run one seed in the multiprocessing setting.

    Self-contained worker entry point: sets the random seeds for ``seed``,
    then builds the environment and the embeddings from ``params`` (so only
    ``seed`` and ``params`` have to be sent to the worker process), then runs
    the learning pipeline.

    Returns:
        tuple: (seed, metrics, final_objs, final_values). The seed is included
            so the caller can restore seed order after out-of-order completion.
    """
    _set_random_seeds(seed)
    env_true, which_transition_model, solution_pi_true, N_search_space_samples = _setup_environment(
        params
    )
    N_states = env_true.N_states
    N_actions = env_true.N_actions
    phi, d, B, kappa, lambda_param, eta = _setup_embeddings(params, N_states, N_actions)

    metrics, final_objs, final_values = _run_offline_then_online(
        params,
        env_true,
        which_transition_model,
        solution_pi_true,
        N_search_space_samples,
        phi,
        d,
        B,
        kappa,
        lambda_param,
        eta,
    )
    return seed, metrics, final_objs, final_values


def _run_offline_then_online(
    params,
    env_true,
    which_transition_model,
    solution_pi_true,
    N_search_space_samples,
    phi,
    d,
    B,
    kappa,
    lambda_param,
    eta,
):
    """Run (optional) offline learning followed by online learning.

    Shared by ``_run_single_seed`` and ``_run_single_seed_mp``. Does no
    seeding itself; the caller must have set the random seeds already.

    Returns:
        tuple: (metrics, final_objs, final_values) as returned by
            ``online_learning``.
    """
    # (optional) offline learning
    # if baseline uses augmented ball, we use this confset to generate the baseline search space by diluting it with random policies
    needs_offline_confset = (
        params["do_offline_BC"] or params["baseline_search_space"] == "augmented_ball"
    )
    if needs_offline_confset:
        offline_trajs, _ = generate_offline_trajectories(
            env_true, solution_pi_true, n_samples=params["N_offline_trajs"]
        )
        confset_offline, _, env_T_MLE = offline_learning(
            offline_trajs,
            env_true,
            params["episode_length"],
            params["delta_offline"],
            solution_pi_true,
            params["N_confset_size"],
            N_search_space_samples,  # only needed for which_confset_construction_method=="rejection-sampling-from-sample"
            params["which_confset_construction_method"],
            params["which_hellinger_calc"],
            which_transition_model,
            params["n_transition_model_epochs_offline"],
            params["offlineradius_formula"],
            params["offlineradius_override_value"],
            params["replace_mle_with_optimal_policy_in_offline_confset"],
            verbose=params["verbose"],
        )
    else:
        # online learning receives None for all offline artefacts
        confset_offline = None
        offline_trajs = None
        env_T_MLE = None

    # (mandatory) online learning
    metrics, final_objs, final_values = online_learning(
        confset_offline,
        offline_trajs,
        env_T_MLE,
        solution_pi_true,
        env_true,
        params["N_rollouts"],
        params["N_iterations"],
        params["episode_length"],
        params["delta_online"],
        phi,
        B,
        params["W"],
        d,
        kappa,
        lambda_param,
        eta,
        params["w_MLE_epochs"],
        params["w_initialization"],
        params["w_sigmoid_slope"],
        params["xi_formula"],
        which_transition_model,
        params["n_transition_model_epochs_online"],
        params["online_confset_recalc_phi"],
        params["online_confset_bonus_multiplier"],
        params["use_true_T_in_online"],
        params["gamma_t_hardcoded_value"],
        params["do_offline_BC"],
        params["baseline_search_space"],
        params["N_confset_size"],
        short_env_name(params["env"]),
        params["verbose"],
    )
    return metrics, final_objs, final_values


def _setup_environment(params):
    """Build the true environment selected by params["env"] and solve it.

    Supported names are "StarMDP_with_random_flinging",
    "StarMDP_with_random_staying" and "Gridworld". The StarMDP variants use
    the "MLE" transition model; Gridworld uses "MLP" and additionally sets a
    search-space sample count.

    Returns:
        tuple: (env_true, which_transition_model, solution_pi_true,
            N_search_space_samples). N_search_space_samples is None except
            for Gridworld.

    Raises:
        ValueError: For "MiniGrid" (not implemented), for any other
            unsupported environment name, and for Gridworld combined with
            which_confset_construction_method == "rejection-sampling-from-all".
    """
    env_name = params["env"]
    N_search_space_samples = None

    # Reject unsupported names up front.
    if env_name == "MiniGrid":
        # TODO: implement minigrid. needs to be tabular MDP.
        raise ValueError("MiniGrid not yet implemented")
    if env_name not in (
        "StarMDP_with_random_flinging",
        "StarMDP_with_random_staying",
        "Gridworld",
    ):
        raise ValueError(f"Environment {env_name} not supported")

    if env_name == "Gridworld":
        env_true = Gridworld(
            width=4,
            height=4,
            episode_length=params["episode_length"],
            discount_factor=0.99,
            random_action_prob=1 - params["env_move_prob"],
        )
        which_transition_model = "MLP"
        N_search_space_samples = int(5e5)  # for augmented ball confset
        if params["which_confset_construction_method"] == "rejection-sampling-from-all":
            raise ValueError(
                "rejection-sampling-from-all not supported for Gridworld (policy space too big!). Use rejection-sampling-from-sample, or random_sample instead."
            )
    else:
        # Both StarMDP variants take the same constructor arguments.
        if env_name == "StarMDP_with_random_flinging":
            star_mdp_cls = StarMDP_with_random_flinging
        else:
            star_mdp_cls = StarMDP_with_random_staying
        env_true = star_mdp_cls(
            discount_factor=0.99,
            episode_length=params["episode_length"],
            move_prob=params["env_move_prob"],
        )
        which_transition_model = "MLE"

    # ensure transitions are valid distribution
    env_true = sanity_check_transitions(env_true, fix=True)
    # optimal policy for true environment
    solution_pi_true = env_true.get_lp_solution()

    return env_true, which_transition_model, solution_pi_true, N_search_space_samples


def _setup_embeddings(params, N_states, N_actions):
    """Create the trajectory embedding and the constants derived from it.

    Args:
        params (dict): Reads "episode_length", "phi_name" and "W".
        N_states: Number of states of the environment.
        N_actions: Number of actions of the environment.

    Returns:
        tuple: (phi, d, B, kappa, lambda_param, eta) where
            - phi: embedding function returned by ``create_embeddings``
            - d: embedding dimension for the chosen phi_name
            - B: embedding bound for the chosen phi_name
            - kappa: 1 / sigmoid'(W * B)
            - lambda_param: B / kappa
            - eta: 2 * W * B

    Raises:
        KeyError: If params["phi_name"] is not one of "id_long", "id_short",
            "state_counts", "final_state" (unless ``create_embeddings``
            raises first).
    """
    horizon = params["episode_length"]
    # phi_name -> (embedding dimension d, embedding bound B)
    embedding_specs = {
        "id_long": (horizon * (N_states + N_actions), np.sqrt(2 * horizon)),
        "id_short": (N_states + N_actions, np.sqrt(2) * horizon),
        "state_counts": (N_states, horizon),
        # final_state is one-hot vector of size N_states, so norm=1
        "final_state": (N_states, 1),
    }

    phi_name = params["phi_name"]
    # embedding functions operating on trajectories (!)
    phi = create_embeddings(phi_name, N_states, N_actions)
    d, B = embedding_specs[phi_name]

    W = params["W"]
    sigmoid_WB = 1 / (1 + np.exp(-W * B))  # sigmoid(W*B)
    sigmoid_slope = sigmoid_WB * (1 - sigmoid_WB)  # d/dx(sigmoid(W*B))
    kappa = 1 / sigmoid_slope
    lambda_param = B / kappa
    eta = 2 * W * B

    return phi, d, B, kappa, lambda_param, eta


def _set_random_seeds(seed):
    """Seed Python's ``random``, NumPy and PyTorch with ``42 + seed``."""
    shifted_seed = 42 + seed
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(shifted_seed)
