#!/usr/bin/env python3

import argparse
import logging
import os
import sys

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn

import latency_env
import latency_env.misc.argparser_types as at
import latency_env.training.utils as utils
import latency_env.training.mbpac as mbpac
from latency_env.misc import Argument as Arg
from latency_env.misc import ArgumentList as ArgList
from latency_env.misc import ArgumentParser
from latency_env.modules.make import q_critic_mlp, gaussian_policy_mlp

from baseline_config import (
    CONTINUOUS_ENVS as ENVS,
)

# Command-line interface for this training script. Script-specific options are
# declared first; the MBPAC, delayed-MDP trainer and logging argument groups
# defined in the latency_env package are appended afterwards.
parser = ArgumentParser(description="Train a baseline environment using Model-Based Predictive Actor-Critic (MBPAC)")
parser += ArgList(
    # Which baseline environment to train on; must be a key of CONTINUOUS_ENVS.
    # NOTE(review): no flag strings are given here, so the option name is
    # derived by ArgList from the keyword -- confirm in latency_env.misc.
    env = Arg(choices=ENVS.keys(),
              _auto_default=False,
              help="The environment to train MBPAC on."),
    # Delay simulation settings, forwarded to SimulatedInteractionLayer by
    # mk_delayed_mdp further down in this script.
    delay = Arg("-D", "--delay", type=latency_env.delayed_mdp.delay_from_string, default=None,
                  help=f"How much delay in terms of timesteps to use when training the environment."),
    horizon = Arg("--hzn", "--horizon", type=at.posint, default=1,
                  help=f"Horizon to use for the action buffer in the delayed environment."),
    # Hidden layer sizes used for the policy and both Q-critic networks.
    net_dims = Arg("--netdims", type=at.nonnegint_commalist, default=[256,256],
                   help="Comma separated list of the hidden dimensions in the network."),
    # Architecture options of the learned predictive model (see mbpac_mkargs,
    # which reads these when constructing the emitter and recurrent model).
    model_state_embed_layers = Arg("--msel", "--model-state-embed-layers", type=at.nonnegint, default=2,
                                   help="Number of hidden embedding layers for the input state (s0 -> h0)."),
    model_action_embed_layers = Arg("--mael", "--model-action-embed-layers", type=at.nonnegint, default=2,
                                    help="Number of hidden embedding layers for the input actions."),
    model_recurrent_layers = Arg("--mrecl", "--model-recurrent-layers", type=at.posint, default=1,
                                 help="Number of layers for the recurrence."),
    model_activation = Arg("--mact", "--model-activation", type=str, default="clipsilu",
                           help="Which activation to use for the model."),
    model_ignore_h0 = Arg("--mih0", "--model-ignore-h0", action=argparse.BooleanOptionalAction, default=False,
                          help="Ignore emissions from the first hidden state in the model, instead output directly the input state. (Only for rec models!)"),
    # Not read directly in this script; it travels with `args` to the trainer.
    model_include_emit_in_loss = Arg("--mieil", "--model-include-emit-in-loss", action=argparse.BooleanOptionalAction, default=True,
                          help="For RSSMLite models, include emissions in loss."),
    # Also used as the policy's input size (override_obs_dim) in mbpac_mkargs.
    model_recursive_size = Arg("--msz-rec", "--model-recursive-size", type=at.posint, default=384,
                               help="Size of the recursive layer for the model."),
    model_hidden_size = Arg("--msz-hid", "--model-hidden-size", type=at.posint, default=256,
                            help="Size of the hidden layers in the model (excluding the recursive layer)."),
    # In mbpac_mkargs this single value is passed as the emitter's `dropout`
    # and as both `dropout` and `emit_dropout` of the recurrent model.
    model_emit_dropout = Arg("--me-dropout", "--model-emit-dropout", type=at.floatr_c0_1c, default=0.0,
                             help="Dropout for model emissions during training. The probability that the loss from an emission is ignored during training."),
)
# MBPAC algorithm hyperparameters for this baseline.
# NOTE(review): `with_defaults` presumably overrides the library defaults of
# the named arguments -- confirm in latency_env.training.mbpac. Options that
# are read later but not declared above (e.g. args.entropy_threshold) must come
# from this group or the ones appended below.
parser += mbpac.arguments.with_defaults(
    initial_alpha = 0.2,
    lr_q = 3e-4,
    lr_pi = 3e-4,
    lr_alpha = 3e-4,
    lr_model = 1e-4,
    optim_q = "adam",
    optim_pi = "adam",
    optim_alpha = "adam",
    optim_model = "adam",
    discount = 0.99,
    replay_size = 1_000_000,
    batch_size = 256,
    max_trajectory = 10000,
    polyak = 0.005,
)
# Settings for the generic delayed-MDP training loop invoked at the end of
# this script.
parser += latency_env.delayed_mdp.training_arguments.with_defaults(
    trainer_iterations=100,
    trainer_itersteps=10000,
    trainer_eval_runs=10,
    trainer_save_interval=100000,
)
parser += utils.logging_arguments
# Parsing happens at import time: this file is meant to be run as a script.
args = parser.parse_args()

# Spec of the selected environment; used below to create the base environment
# (env_spec.make()) and to supply the observation metrics (env_spec.metrics).
env_spec = ENVS[args.env]
def mk_delayed_mdp(delay):
    """Create a new delayed environment for the selected baseline.

    A fresh base environment is made from the module-level ``env_spec`` and
    wrapped in ``latency_env.SimulatedInteractionLayer``.

    Args:
        delay: Delay specification forwarded unchanged as the wrapper's
            ``delay`` argument.

    Returns:
        The ``SimulatedInteractionLayer`` wrapping the new base environment,
        configured with ``args.horizon`` as its horizon.
    """
    return latency_env.SimulatedInteractionLayer(
        env_spec.make(),
        delay=delay,
        horizon=args.horizon,
    )

# Run name: used for the logger below and passed to the training loop.
NAME = f"baseline-mbpac-{args.env}"

LOG = logging.getLogger(NAME)
# NullHandler is a no-op handler, so this script installs no output of its own.
LOG.addHandler(logging.NullHandler())

# Delayed environment instance used to size the networks and the model.
denv = mk_delayed_mdp(args.delay)
LOG.debug(f"Observation space: {denv.observation_space}")
LOG.debug(f"Action space: {denv.action_space}")

if args.entropy_threshold is None:
    # A good rule of thumb is -dim(A): minus the total number of action
    # dimensions. The action space is taken from `denv`; this script defines
    # no `env` name, so the previous `env.action_space` lookup raised
    # NameError whenever no entropy threshold was given on the command line.
    import math

    # math.prod of an empty shape is 1, matching the old reduce(..., 1).
    args.entropy_threshold = -math.prod(denv.action_space.shape)

LOG.debug(f"Arguments: {args}")

# Fixed seed (0) for every run of this baseline.
utils.seed_all(0)

def mbpac_mkargs():
    """Build the objects handed to the MBPAC trainer.

    This function is given to ``training_loop`` as ``cls_argmaker`` together
    with ``cls=mbpac.MBPAC``. It reads the module-level ``denv``, ``env_spec``
    and ``args``.

    Returns:
        The tuple ``(pi, q1, q2, denv, model, obs_metric, args)``.

    Raises:
        AssertionError: If the observation or action space of ``denv`` is not
            one-dimensional.
    """
    hidden_sizes = args.net_dims

    # The policy's observation size is overridden with the model's recurrent
    # size -- presumably it consumes the model's recurrent state rather than
    # raw observations (confirm in the MBPAC implementation).
    pi = gaussian_policy_mlp(
        denv,
        hidden_sizes=hidden_sizes,
        override_obs_dim=args.model_recursive_size,
    )
    # Two independently constructed critics with identical configuration.
    q1, q2 = (q_critic_mlp(denv, hidden_sizes=hidden_sizes) for _ in range(2))
    LOG.debug(f"pi network:\n{pi}")
    LOG.debug(f"q network:\n{q1}")

    # The model below takes a single size per space, so both must be flat.
    obs_shape = denv.observation_space.shape
    assert len(obs_shape) == 1
    act_shape = denv.action_space.shape
    assert len(act_shape) == 1

    model_lib = latency_env.modules.model
    activation = args.model_activation
    emit_dropout = args.model_emit_dropout

    emitter = model_lib.NNGaussianProbabilisticEmitter(
        state_size=obs_shape[0],
        embed_size=args.model_recursive_size,
        hidden_size=args.model_hidden_size,
        common_layers=2,
        head_layers=1,
        dropout=emit_dropout,
        activation=activation,
    )

    # The same emission-dropout value is used for both dropout parameters.
    model = model_lib.NNPredictiveRecurrent(
        state_size=obs_shape[0],
        action_size=act_shape[0],
        emitter=emitter,
        hidden_size=args.model_hidden_size,
        hidden_rec_size=args.model_recursive_size,
        dropout=emit_dropout,
        emit_dropout=emit_dropout,
        state_embed_layers=args.model_state_embed_layers,
        action_embed_layers=args.model_action_embed_layers,
        recurrent_layers=args.model_recurrent_layers,
        ignore_h0=args.model_ignore_h0,
        activation=activation,
    )
    LOG.debug(f"Model structure:\n{model}")

    # Distance metric over observations, built from the environment spec.
    obs_metric = latency_env.distance.DistanceMetric(
        env_spec.metrics,
        denv.observation_space,
    )

    return (pi, q1, q2, denv, model, obs_metric, args)

LOG.debug("Starting training loop")

# Hand control to the shared delayed-MDP training loop. It receives the
# trainer class, the factory for its arguments, the environment built above
# and the environment factory.
training_loop = latency_env.delayed_mdp.training_loop
training_loop(
    cls=mbpac.MBPAC,
    cls_argmaker=mbpac_mkargs,
    denv=denv,
    mk_denv=mk_delayed_mdp,
    args=args,
    name=NAME,
)
