import wandb
from rich import print

from harvest_sed import setup_config_and_run_dir, setup_experiment
from harvest_sed.config import Config
from harvest_sed.training.collection import run_training_episode
from harvest_sed.utils.logger import logger

def main():
    """Run the full training loop for one experiment.

    Builds the config and run directory, sets up the experiment objects, and
    then runs ``args.eps_per_tax_rate * args.total_principal_steps`` episodes.
    Each episode:

    1. optionally saves the principal's parameters,
    2. asks the principal for this episode's tax values,
    3. runs one training episode with those tax values, and
    4. hands the resulting episode buffer back to the principal.

    The wandb run is finished once all episodes have completed.
    """
    args: Config = setup_config_and_run_dir()
    print(args)

    logger.cfg(args)

    ctx, envs, agent_trajectory = setup_experiment(args)

    total_episodes = args.eps_per_tax_rate * args.total_principal_steps
    for _ in range(total_episodes):
        # Save model parameters. Only model params are saved here, not
        # trajectories; add that in if needed.
        # NOTE(review): ctx.episode_number is not advanced in this function;
        # presumably it is updated by one of the calls below - confirm.
        if args.save_model and ctx.episode_number % args.save_model_freq == 0:
            ctx.principal.save_params(ctx.episode_number)

        # Set new tax rates and return them.
        tax_vals_per_game = ctx.principal.set_tax_vals(ctx, envs)

        # Run a training episode - collect trajectories and step agent nets
        # in chunks of sampling_horizon.
        episode_buffer = run_training_episode(
            ctx, envs, args, agent_trajectory, tax_vals_per_game, log_prefix="train/"
        )

        # Principal-dependent post-episode logic.
        ctx.principal.after_episode(ctx, envs, episode_buffer, tax_vals_per_game)

    # envs.close()
    wandb.finish()


if __name__ == "__main__":
    # Run the training loop only when this file is executed as a script,
    # not when it is imported as a module.
    main()
