#!/usr/bin/env python3

import argparse
import logging
import os
import sys

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn

import latency_env
import latency_env.training.utils as utils
import latency_env.training.bpql as bpql
from latency_env.modules.make import q_critic_mlp, gaussian_policy_mlp

import latency_env.misc.argparser_types as at
from latency_env.misc import Argument as Arg
from latency_env.misc import ArgumentList as ArgList
from latency_env.misc import ArgumentParser

from baseline_config import (
    CONTINUOUS_ENVS as ENVS,
)

# Command-line interface: script-specific options, the BPQL hyper-parameters
# (with baseline defaults), the shared delayed-MDP training options and the
# logging options.
parser = ArgumentParser(description="Train a baseline environment using Belief-Projection-Based Q-learning")
parser += ArgList(
    env = Arg(choices=ENVS.keys(),
              _auto_default=False,
              help="The environment to train BPQL on."),
    delay = Arg("-D", "--delay", type=latency_env.delayed_mdp.delay_from_string, default=None,
                  help="How much delay in terms of timesteps to use when training the environment."),
    # NOTE(review): in this script mk_delayed_mdp sizes the action buffers with
    # args.assumed_delay, not args.horizon; --horizon is only compared against
    # it to emit a warning. Other consumers of `args` may still read it.
    horizon = Arg("--hzn", "--horizon", type=at.posint, default=1,
                  help="Horizon to use for the action buffer in the delayed environment."),
    net_dims = Arg("--netdims", type=at.nonnegint_commalist, default=[256,256],
                   help="Comma separated list of the hidden dimensions in the network."),
    pi_activation = Arg("--pi-activation", type=str.lower, default="relu",
                        help="Activation to use between hidden layers in the policy network."),
    q_activation = Arg("--q-activation", type=str.lower, default="relu",
                       help="Activation to use between hidden layers in the critic network."),
)
# BPQL hyper-parameters. The entropy threshold is left unset (None) here and
# filled in as -dim(A) once the environment's action space is known.
parser += bpql.arguments.with_defaults(
    initial_alpha = 0.2,
    lr_q = 3e-4,
    lr_pi = 3e-4,
    lr_alpha = 3e-4,
    entropy_threshold = None,
    optim_q = "adam",
    optim_pi = "adam",
    discount = 0.99,
    replay_size = 1_000_000,
    batch_size = 256,
    max_trajectory = 10000,
    polyak = 0.005,
).modify("entropy_threshold",
    _auto_default=False,
    help=lambda h: h + " (Default: -dim(A))",
)
# Shared training-loop options. `args.assumed_delay`, used below, is not
# declared in this file; presumably it comes from this argument set — confirm.
parser += latency_env.delayed_mdp.training_arguments.with_defaults(
    trainer_iterations=100,
    trainer_itersteps=10000,
    trainer_eval_runs=10,
    trainer_save_interval=100000,
)
parser += utils.logging_arguments
args = parser.parse_args()

# Environment specification selected on the command line.
env_spec = ENVS[args.env]
def mk_delayed_mdp(delay):
    """Create a delayed environment instance for BPQL training/evaluation.

    A fresh base environment from ``env_spec.make()`` is wrapped first in a
    ``SimulatedInteractionLayer`` using ``delay``. It is then wrapped in an
    ``ActionMemorizer`` with ``obs_passthrough=True``.

    Both wrappers use ``args.assumed_delay`` as their horizon, not
    ``args.horizon``. This keeps the buffer length consistent with the policy
    input size computed in ``bpql_mkargs``. A warning is logged whenever the
    two settings disagree.

    Args:
        delay: Delay handed to ``SimulatedInteractionLayer``. At start-up this
            is ``args.delay``, which defaults to ``None``.

    Returns:
        The ``ActionMemorizer``-wrapped environment.
    """
    buffer_horizon = args.assumed_delay
    simulated = latency_env.SimulatedInteractionLayer(
        env_spec.make(),
        delay=delay,
        horizon=buffer_horizon,
    )
    # --horizon is overridden by the assumed delay; make the discrepancy visible.
    if buffer_horizon != args.horizon:
        LOG.warning(f"args.assumed_delay != args.horizon: {args.assumed_delay} != {args.horizon}")
    return latency_env.ActionMemorizer(simulated, horizon=buffer_horizon, obs_passthrough=True)

# Run name, also used as the logger name and passed to the training loop.
NAME = "baseline-bpql-{}".format(args.env)

# Script-level logger. Attaching a NullHandler means records from this logger
# are never routed to logging's last-resort stderr handler; actual output
# relies on handlers configured elsewhere (e.g. via the logging arguments).
LOG = logging.getLogger(NAME)
LOG.addHandler(logging.NullHandler())

# Build the delayed environment once. It is used to size the networks in
# bpql_mkargs and is handed to the training loop.
denv = mk_delayed_mdp(args.delay)
LOG.debug(f"Observation space: {denv.observation_space}")
LOG.debug(f"Action space: {denv.action_space}")

if args.entropy_threshold is None:
    # Rule of thumb: -dim(A), where dim(A) is the total number of action
    # components (product of the action shape; 1 for a scalar shape ()).
    # int() keeps the value a plain Python int, as the previous reduce-based
    # product did.
    args.entropy_threshold = -int(np.prod(denv.action_space.shape))

LOG.debug(f"Arguments: {args}")

# Fixed seed for reproducibility. Note that it is applied after the
# environment above has been constructed.
utils.seed_all(0)

def bpql_mkargs():
    """Build the networks and argument tuple for ``bpql.BPQL``.

    Passed to the training loop as ``cls_argmaker`` alongside
    ``cls=bpql.BPQL``. The tuple is presumably consumed as BPQL's
    constructor arguments; confirm against ``training_loop``.

    The policy's input size is overridden to
    ``dim(obs) + dim(A) * args.assumed_delay``. Presumably this is the
    environment observation augmented with the ``args.assumed_delay`` actions
    buffered by ``ActionMemorizer`` (TODO confirm the exact layout).

    Both critics are built from ``denv``'s own spaces, with no input-size
    override.

    Returns:
        Tuple ``(pi, q1, q2, denv, args)``.
    """
    obs_size = int(np.prod(denv.observation_space.shape))
    act_size = int(np.prod(denv.action_space.shape))
    # One buffered action per step of assumed delay, matching the
    # ActionMemorizer horizon used in mk_delayed_mdp.
    obs_dim = obs_size + act_size * args.assumed_delay

    pi = gaussian_policy_mlp(
        denv,
        hidden_sizes=args.net_dims,
        activation=args.pi_activation,
        override_obs_dim=obs_dim,
    )
    # Twin critics with identical architecture.
    q1 = q_critic_mlp(denv, hidden_sizes=args.net_dims, activation=args.q_activation)
    q2 = q_critic_mlp(denv, hidden_sizes=args.net_dims, activation=args.q_activation)
    LOG.debug(f"pi network:\n{pi}")
    LOG.debug(f"q network:\n{q1}")
    return (pi, q1, q2, denv, args)

# The agent may be trained under a delay assumption (--assumed-delay) that
# differs from the delay actually simulated (--delay); flag that case.
delay_mismatch = args.assumed_delay != args.delay
if delay_mismatch:
    LOG.warning("Training under the assumption of a different delay (--assumed-delay) to the one that is being used in the environment.")

LOG.debug("Starting training loop")

# Hand control to the shared delayed-MDP training loop. It receives the
# prebuilt environment, a factory for further delayed environments, and the
# network/argument builder for BPQL.
latency_env.delayed_mdp.training_loop(
    cls=bpql.BPQL,
    cls_argmaker=bpql_mkargs,
    denv=denv,
    mk_denv=mk_delayed_mdp,
    args=args,
    name=NAME,
)
