"""
Main training script for DQN agent using a custom replay buffer (ReaPER) on a specified environment.

This script:
- Initializes environment-specific hyperparameters.
- Sets up a DQN model with a custom replay buffer.
- Trains the model using Stable-Baselines3.
- Logs results to TensorBoard.
"""

from datetime import date
from stable_baselines3.common.callbacks import EvalCallback

# Custom modules
from custom_buffers.reaper import ReaPER
from utils.storage import get_tb_storage_file_path
from utils.environment import get_environment_specific_settings, seed_everything

# === Configuration === #
# Experiment identity: target environment, model label, and the replay buffer
# implementation (plus any extra constructor kwargs) handed to the model.
ENVIRONMENT_NAME = "CartPole-v1"
MODEL_NAME = "CustomDDQN"
REPLAY_BUFFER_CLASS = ReaPER
REPLAY_BUFFER_KWARGS = {}

# Apply the seed before anything else is constructed (reproducibility)
SEED = 0
seed_everything(SEED)

# Experiment tag used to tell runs apart: "<YYYYMMDD>_<replay buffer class name>"
trial_start_date = f"{date.today():%Y%m%d}"
export_suffix = "_".join((trial_start_date, REPLAY_BUFFER_CLASS.__name__))

# Base path for TensorBoard logs, resolved by the storage helper
tb_log_path = get_tb_storage_file_path(ENVIRONMENT_NAME, REPLAY_BUFFER_CLASS, MODEL_NAME)

# === Load Environment-Specific Settings === #
# The helper's return value is unpacked positionally into 23 names, so the
# order of the targets below must stay in lockstep with what the helper returns.
(
    # Algorithm class, training/evaluation environments, policy
    model_class, env, eval_env, dqn_policy,
    # Training budget and evaluation setup
    max_timesteps, reward_threshold, num_evals, callback_on_new_best,
    # Optimisation schedule
    learning_rate, learning_starts, n_eval_episodes,
    # Epsilon-greedy exploration
    exploration_initial_eps, exploration_fraction, exploration_final_eps,
    # Replay buffer sizing and network configuration
    batch_size, buffer_size, policy_kwargs,
    # Update mechanics
    max_grad_norm, train_freq, target_update_interval, gradient_steps,
    # Discount factor and progress-bar toggle
    gamma, progress_bar,
) = get_environment_specific_settings(
    seed=SEED,
    n_envs=1,
    environment_name=ENVIRONMENT_NAME,
)

# === Initialize Model === #
# Every hyperparameter comes from the environment-specific settings loaded
# above; only the replay buffer class/kwargs, verbosity and seed are set here.
model = model_class(
    # Policy network
    policy=dqn_policy,
    policy_kwargs=policy_kwargs,

    # Environment and run-level options
    env=env,
    seed=SEED,
    verbose=1,

    # Custom replay buffer
    replay_buffer_class=REPLAY_BUFFER_CLASS,
    replay_buffer_kwargs=REPLAY_BUFFER_KWARGS,
    buffer_size=buffer_size,

    # Optimisation
    learning_rate=learning_rate,
    batch_size=batch_size,
    gamma=gamma,
    max_grad_norm=max_grad_norm,

    # Update schedule
    learning_starts=learning_starts,
    train_freq=train_freq,
    gradient_steps=gradient_steps,
    target_update_interval=target_update_interval,

    # Epsilon-greedy exploration
    exploration_initial_eps=exploration_initial_eps,
    exploration_final_eps=exploration_final_eps,
    exploration_fraction=exploration_fraction,

    # TensorBoard output: base log path followed by the per-run suffix
    tensorboard_log=tb_log_path + export_suffix + "/",
)

# === Evaluation Callback === #
# Spread roughly `num_evals` evaluations evenly over the training budget.
# The floor division yields 0 whenever num_evals exceeds max_timesteps, and
# EvalCallback only evaluates when eval_freq is positive, so a 0 would silently
# disable evaluation altogether. Clamp to at least 1 to keep evaluation active.
eval_callback = EvalCallback(
    eval_env=eval_env,
    eval_freq=max(max_timesteps // num_evals, 1),
    callback_on_new_best=callback_on_new_best,
    n_eval_episodes=n_eval_episodes,
    verbose=1
)

# === Begin Training === #
# log_interval is clamped to at least 1: the floor division is 0 whenever
# num_evals exceeds max_timesteps, and the training loop uses log_interval as a
# modulo divisor, so 0 would abort the run with a ZeroDivisionError.
# NOTE(review): for Stable-Baselines3 off-policy algorithms such as DQN,
# log_interval is counted in episodes, not timesteps — confirm that a value
# derived from the timestep budget is what is intended here.
model.learn(
    total_timesteps=max_timesteps,
    log_interval=max(max_timesteps // num_evals, 1),
    progress_bar=progress_bar,
    callback=eval_callback
)
