import sys, os, time
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Wall-clock timestamp captured when the script starts loading.
t0 = time.time()

# Identifiers for this run: the process id, and the SLURM job id when the
# script is launched under SLURM (otherwise the process id is reused).
pid = str(os.getpid())
jobid = str(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else pid

# Set before the brax import below: use a 0.6 memory fraction instead of the
# default 0.9, see https://github.com/google/jax/issues/13504
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(0.6)
from brax import envs
from absl import flags, app
from ml_collections import config_flags
from ml_collections import ConfigDict
import pickle

import src.brax.svginf.train as svg
import src.brax.arm.train as arm
import src.brax.sac.train as sac
from src.misc.helper_methods import moving_avg

from src.brax.custom_envs.double_pendulum import DoublePendulum, get_goal_state

from src.utils import logger, system

# Global absl flag registry; values are read inside main().
FLAGS = flags.FLAGS

# Config-file flags: each one points at a Python config file. lock_config=False
# is passed because main() adds and overwrites fields on the loaded RL config.
config_flags.DEFINE_config_file('config_env',  './configs/envs/exp4_toy_ca.py', 'file path to the environment configuration.', lock_config=False)
config_flags.DEFINE_config_file('config_rl', './configs/rl/arm_default.py', 'file path to the RL algorithm configuration.', lock_config=False)

# --debug additionally enables the 'stdout' and 'logs' logger formats in main().
flags.DEFINE_boolean('debug', False, 'debug')
# NOTE(review): --seed is registered here but main() seeds from the process id
# instead of reading this flag — confirm whether that is intended.
flags.DEFINE_integer('seed', 42, 'seed int')

def remove_key(cfg, k):
    """Return a new ConfigDict holding every entry of `cfg` except `k`.

    The input config is not modified; it is converted to a plain dict,
    the key is dropped, and a fresh ConfigDict is built from the rest.

    Raises:
        KeyError: if `k` is not present in `cfg`.
    """
    remaining = cfg.to_dict()
    del remaining[k]
    return ConfigDict(remaining)

def progress_fn(step: int, metrics) -> None:
    """Log one row of training metrics.

    Records `step` under 'env_steps', records every (key, value) pair of
    `metrics` as a tabular entry, then dumps the accumulated row.
    Passed to the training routine as its progress callback.
    """
    logger.record_step('env_steps', step)
    for entry in metrics.items():
        # entry is a (name, value) pair
        logger.record_tabular(*entry)
    logger.dump_tabular()

def main(argv):
    """Configure logging, build the double-pendulum envs and launch training.

    Steps:
      1. Set up the logger under ./logs/double-pend-final-final/<run name> and
         dump the key flags there (flags.txt and flags.pkl).
      2. Create identical training and evaluation DoublePendulum environments.
      3. Select the training module from ``config_rl.alg`` ('arm', 'svg' or
         'sac') and apply the algorithm-specific config adjustments.
      4. Call ``alg.train`` with the remaining RL config as keyword arguments.

    Args:
        argv: argument list handed over by ``app.run``; ``argv[0]`` is used as
            the module name when looking up this script's key flags.

    Raises:
        ValueError: if ``config_rl.alg`` is not one of 'arm', 'svg', 'sac'.
    """
    config_env = FLAGS.config_env
    config_rl = FLAGS.config_rl

    print('saving files')

    # --- logger -------------------------------------------------------------
    env_name = "{0}-{1}".format(config_env.env_type, config_env.length)
    run_name = f"{env_name}/"
    run_name += f"{system.now_str()}+{jobid}-{pid}"
    format_strs = ['csv']
    if FLAGS.debug:
        format_strs.extend(['stdout', 'logs'])
    log_path = os.path.join('./logs/double-pend-final-final', run_name)
    logger.configure(dir=log_path, format_strs=format_strs)
    # write flags to a txt
    key_flags = FLAGS.get_key_flags_for_module(argv[0])
    with open(os.path.join(log_path, "flags.txt"), "w") as text_file:
        text_file.write("\n".join(f.serialize() for f in key_flags) + "\n")
    # write flags to pkl
    with open(os.path.join(log_path, "flags.pkl"), "wb") as f:
        pickle.dump(FLAGS.flag_values_dict(), f)

    print('creating env')

    # --- environments -------------------------------------------------------
    # The goal action is given scaled by 180 here and divided by 180 again
    # when handed to the environment constructor.
    goal_action = [-0.4 * 180, 0.3 * 180]
    goal_state, goal_action = get_goal_state(config_env.length, config_env.dt, goal_action)
    env = DoublePendulum(goal_state, goal_action/180., config_env.length, config_env.dt)
    eval_env = DoublePendulum(goal_state, goal_action/180., config_env.length, config_env.dt)

    # --- algorithm selection and per-algorithm config -----------------------
    if config_rl.alg == 'arm':
        alg = arm
        # max unroll, no critics
        if config_rl.unroll_length == -1:
            config_rl.unroll_length = config_env.length

        # Exactly one branch must set sequence_model_params. This has to be
        # an if/elif/else chain: with two independent `if`s the generic
        # `else` below would overwrite the transformer settings for 'gpt'.
        if config_rl.sequence_model_name == 'gpt':
            config_rl.sequence_model_params = {'name': config_rl.sequence_model_name,
                                            "transformer_nlayers" : 2,
                                            "transformer_nheads": 3,
                                            "transformer_pdrop": 0.1}
        elif config_rl.sequence_model_name == 'true_pg':
            config_rl.sequence_model_params = {'name': 'lstm'} #placeholder not used
            config_rl.true_pg = True
            config_rl.dynamics_update_every = 10000000
        else:
            config_rl.sequence_model_params = {'name': config_rl.sequence_model_name}

        # sequence_model_name is dropped so it is not forwarded to
        # alg.train(**config_rl) below.
        config_rl = remove_key(config_rl, 'sequence_model_name')

    elif config_rl.alg == 'svg':
        alg = svg
        # Make sure imagined trajectories dont get to terminal states
        config_rl.chunk_length = config_rl.unroll_length + 1

        # full unroll, no critic
        if config_rl.unroll_length == -1:
            config_rl.unroll_length = config_env.length
            config_rl.chunk_length = config_env.length
            config_rl.bootstrap = 0.

    elif config_rl.alg == 'sac':
        alg = sac

    else:
        # Fail with a clear message instead of an unbound `alg` further down.
        raise ValueError(
            f"Unknown config_rl.alg {config_rl.alg!r}; expected 'arm', 'svg' or 'sac'.")

    # for logging purposes, it only works when evaluation steps coincide with update steps
    if config_rl.eval_every == -1:
        config_rl.eval_every = config_rl.policy_update_every * 2

    # removing unnecessary key ('alg' is not forwarded to alg.train)
    config_rl = remove_key(config_rl, 'alg')

    # NOTE(review): the seed is taken from the process id, so the --seed flag
    # has no effect and runs are not reproducible from it — confirm intended.
    config_rl.seed = int(pid)

    _ = alg.train(env, eval_env, config_env.length,
               progress_fn=progress_fn,
               **config_rl)

# Script entry point: hand main() to absl's app.run when executed directly.
if __name__ == '__main__':
    app.run(main)
