#!/usr/bin/env python3

import argparse
import logging
import os
import sys

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn

import latency_env
import latency_env.training.utils as utils
import latency_env.training.sac as sac
from latency_env.modules.make import q_critic_mlp, gaussian_policy_mlp

import latency_env.misc.argparser_types as at
from latency_env.misc import Argument as Arg
from latency_env.misc import ArgumentList as ArgList
from latency_env.misc import ArgumentParser

from baseline_config import (
    CONTINUOUS_ENVS as ENVS,
)

# Command-line interface. Script-specific options come first, followed by the SAC
# hyperparameters (with this script's defaults), the delayed-MDP training-loop
# options, and the shared logging options.
parser = ArgumentParser(description="Train a baseline environment using Soft Actor-Critic")
parser += ArgList(
    env = Arg(choices=ENVS.keys(),
              _auto_default=False,
              help="The environment to train SAC on."),
    delay = Arg("-D", "--delay", type=latency_env.delayed_mdp.delay_from_string, default=None,
                help="How much delay in terms of timesteps to use when training the environment."),
    horizon = Arg("--hzn", "--horizon", type=at.posint, default=1,
                  help="Horizon to use for the action buffer in the delayed environment."),
    net_dims = Arg("--netdims", type=at.nonnegint_commalist, default=[256,256],
                   help="Comma separated list of the hidden dimensions in the network."),
    pi_activation = Arg("--pi-activation", type=str.lower, default="relu",
                        help="Activation to use between hidden layers in the policy network."),
    q_activation = Arg("--q-activation", type=str.lower, default="relu",
                       help="Activation to use between hidden layers in the critic network."),
    # When given, this value replaces --hzn as the action-buffer horizon when the
    # environment is built (see mk_delayed_mdp).
    memorize_actions = Arg("--memorize", metavar="N", type=at.posint, default=None,
                           help="Memorize the last N sent actions."),
)
# entropy_threshold defaults to None here: the real default (-dim(A)) depends on the
# action space, so it is filled in later in this script once the environment exists.
# NOTE(review): _auto_default=False presumably suppresses the automatically generated
# default text in --help, replaced by the explicit "(Default: -dim(A))" note — confirm.
parser += sac.arguments.with_defaults(
    initial_alpha = 0.2,
    lr_q = 3e-4,
    lr_pi = 3e-4,
    lr_alpha = 3e-4,
    entropy_threshold = None,
    optim_q = "adam",
    optim_pi = "adam",
    discount = 0.99,
    replay_size = 1_000_000,
    batch_size = 256,
    max_trajectory = 10000,
    polyak = 0.005,
).modify("entropy_threshold",
    _auto_default=False,
    help=lambda h: h + " (Default: -dim(A))",
)
parser += latency_env.delayed_mdp.training_arguments.with_defaults(
    trainer_iterations=100,
    trainer_itersteps=10000,
    trainer_eval_runs=10,
    trainer_save_interval=100000,
)
parser += utils.logging_arguments
args = parser.parse_args()

env_spec = ENVS[args.env]

# Set once the --memorize/--hzn mismatch warning has been logged, so building the
# environment repeatedly does not repeat the same warning.
_horizon_mismatch_warned = False


def mk_delayed_mdp(delay):
    """Build the selected environment wrapped in a SimulatedInteractionLayer.

    The action-buffer horizon is ``args.horizon`` unless ``--memorize N`` was
    given, in which case N is used instead and the layer is additionally wrapped
    in an ActionMemorizer with that same horizon.

    Args:
        delay: Forwarded as ``delay=`` to SimulatedInteractionLayer. At module
            level this is called with ``args.delay``. The same function is also
            handed to ``training_loop`` as ``mk_denv``, which may call it again.

    Returns:
        The wrapped environment.
    """
    # LOG is a module global assigned later in this script, before the first call.
    global _horizon_mismatch_warned

    memorize = args.memorize_actions
    horizon = args.horizon if memorize is None else memorize

    # --memorize overrides --hzn. Tell the user about a mismatch, but only on the
    # first construction rather than on every call.
    if memorize is not None and memorize != args.horizon and not _horizon_mismatch_warned:
        LOG.warning(f"args.memorize_actions != args.horizon: {args.memorize_actions} != {args.horizon}")
        _horizon_mismatch_warned = True

    e = latency_env.SimulatedInteractionLayer(
        env_spec.make(),
        delay=delay,
        horizon=horizon,
    )
    if memorize is not None:
        e = latency_env.ActionMemorizer(e, horizon=horizon)
    return e

# Run name; also used as the logger name.
NAME = f"baseline-sac-{args.env}"

# Fixed seed 0 for reproducibility.
# NOTE(review): seed_all presumably seeds the python/numpy/torch RNGs — confirm in utils.
utils.seed_all(0)
LOG = logging.getLogger(NAME)
# A NullHandler ensures the logger always has a handler, so records are not sent
# to logging's last-resort stderr handler when nothing else is configured.
LOG.addHandler(logging.NullHandler())

denv = mk_delayed_mdp(args.delay)
LOG.debug(f"Observation space: {denv.observation_space}")
LOG.debug(f"Action space: {denv.action_space}")

if args.entropy_threshold is None:
    # A good rule of thumb is -dim(A): minus the number of scalar action components,
    # i.e. the product of the action-space shape (1 for an empty shape).
    import math

    args.entropy_threshold = -math.prod(denv.action_space.shape)

LOG.debug(f"Arguments: {args}")

def sac_mkargs():
    """Build the SAC networks and bundle them with the environment and CLI args.

    Creates one Gaussian policy MLP and two Q-critic MLPs. All three use the
    ``--netdims`` hidden sizes. The policy uses ``--pi-activation`` and the
    critics use ``--q-activation``.

    Returns:
        The tuple ``(pi, q1, q2, denv, args)``. This function is handed to
        ``training_loop`` as ``cls_argmaker`` alongside ``cls=sac.SAC``.
        NOTE(review): the tuple is presumably used as SAC's constructor
        arguments — confirm in training_loop.
    """
    shared = {"hidden_sizes": args.net_dims}

    # Construction order (policy, first critic, second critic) is kept as-is:
    # each build presumably draws weight-init values from the seeded RNG.
    policy = gaussian_policy_mlp(denv, **shared, activation=args.pi_activation)
    q_first, q_second = (
        q_critic_mlp(denv, **shared, activation=args.q_activation)
        for _ in range(2)
    )

    LOG.debug(f"pi network:\n{policy}")
    LOG.debug(f"q network:\n{q_first}")
    return policy, q_first, q_second, denv, args

LOG.debug("Starting training loop")

# Hand control to the shared delayed-MDP training loop. It receives:
#   - the algorithm class (sac.SAC) and sac_mkargs, which builds the networks;
#   - the environment built above plus mk_delayed_mdp, so it can build
#     environments itself;
#   - the parsed CLI args and the run name.
run_training = latency_env.delayed_mdp.training_loop
run_training(
    cls=sac.SAC,
    cls_argmaker=sac_mkargs,
    denv=denv,
    mk_denv=mk_delayed_mdp,
    args=args,
    name=NAME,
)
