import warnings
warnings.filterwarnings('ignore')
import os
import argparse

import metaworld
import numpy as np

#import sb3_jax
#from sb3_jax.common.norm_layers import RunningNormLayer
#from sb3_jax.common.evaluation import evaluate_policy
import stable_baselines3 as sb3
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
from stable_baselines3.common.vec_env import VecNormalize, sync_envs_normalization

from utils import check_env, make_env, make_vec_env, evaluate_env, render_env, get_expert

# Fixed seed that main() passes to the MT10 benchmark, to both vectorized
# envs and to the SAC model. main() uses this constant, not args.seed.
SEED = 0

def main(args):
    """Train, evaluate, render and/or collect data for SAC on one MT10 task.

    The stages are independent and selected by flags on ``args``; they run
    in the order train -> expert -> eval -> render -> collect.

    Args:
        args: Parsed command-line namespace. This function reads ``env``,
            ``tag``, ``timesteps``, ``checkpoint``, ``fov`` and the boolean
            flags ``train``, ``expert``, ``eval``, ``render``, ``collect``.
            The whole namespace is also forwarded to the ``utils`` helpers
            (``check_env``, ``make_vec_env``, ``get_expert``, ``render_env``).

    Returns:
        None. Returns early (after printing a message) when ``check_env``
        rejects ``args.env``.
    """
    # seeding
    benchmark = metaworld.MT10(seed=SEED)

    # check valid env
    print(f"<=======Env {args.env}========>")
    if not check_env(args, benchmark):
        print(f"Not a valid environment {args.env}")
        return

    # make gym env
    env = make_vec_env(args, benchmark, seed=SEED)
    eval_env = make_vec_env(args, benchmark, seed=SEED, training=False, norm_reward=False)

    # make path: checkpoints, evaluation logs and env statistics all live here
    path = f'./models/sac/{args.env}/{args.tag}'
    os.makedirs(path, exist_ok=True)

    # Training
    if args.train:
        policy_kwargs = dict(net_arch=dict(pi=[256, 256], qf=[256, 256]))
        model = sb3.SAC(
            policy='MlpPolicy',
            env=env,
            learning_starts=1000,
            learning_rate=3e-4,
            buffer_size=1000000,
            batch_size=128,
            gamma=0.99,
            tau=5e-3,
            policy_kwargs=policy_kwargs,
            verbose=1,
            seed=SEED,
        )

        # Checkpoint and evaluate 20 times over the run. Clamp to at least 1:
        # for timesteps < 20 the division yields 0, and a zero frequency
        # would break the periodic (modulo-based) checkpoint trigger.
        eval_freq = max(1, args.timesteps // 20)
        callbacks = [
            CheckpointCallback(eval_freq, path, verbose=1),
            EvalCallback(eval_env, n_eval_episodes=10, eval_freq=eval_freq, log_path=path),
        ]

        print("==== Before Training ====")
        evaluate_env(eval_env, model.policy, n_eval_episodes=10, verbose=True)
        model.learn(total_timesteps=args.timesteps, callback=callbacks, log_interval=10)
        print("==== After Training =====")
        evaluate_env(eval_env, model.policy, n_eval_episodes=10, verbose=True)
        # Saved under the same naming scheme the eval stage loads with
        # --checkpoint final (the default).
        model.save(path=path+'/rl_model_final_steps.zip')
        env.save(path+'/env.zip')

    # Policy handed to render/collect; stays None unless --expert or --eval
    # sets it (--eval wins when both are given, since it runs later).
    policy = None

    # If Expert?
    if args.expert:
        expert = get_expert(args)
        policy = expert
        print("Evaluating Expert !!!")
        evaluate_env(eval_env, expert, n_eval_episodes=10, verbose=True)

    # Testing
    # Drop the reference to any freshly trained model; evaluation always
    # reloads the requested checkpoint from disk.
    model = None
    if args.eval:
        model = sb3.SAC.load(path + f'/rl_model_{args.checkpoint}_steps.zip')
        # NOTE(review): presumably VecNormalize.load(load_path, venv). This
        # passes the current eval_env as the wrapped venv -- confirm against
        # make_vec_env that this does not stack two normalization wrappers.
        eval_env = eval_env.load(path+'/env.zip', eval_env)
        eval_env.training, eval_env.norm_reward = False, False
        policy = model.policy
        if not args.render:
            print("Evaluating Policy !!!")
            evaluate_env(eval_env, model.policy, n_eval_episodes=10, verbose=True)

    # NOTE(review): with neither --expert nor --eval, policy is None here;
    # verify that render_env accepts that.
    if args.render:  # expert or policy
        render_env(args, eval_env, policy, args.fov, path, verbose=True)

    if args.collect:
        render_env(args, eval_env, policy, args.fov, path, collect=True, verbose=True)
    

def build_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser exposing the environment, training,
        domain-factor and stage-selection options consumed by ``main``.
    """
    parser = argparse.ArgumentParser()
    # env
    parser.add_argument("--env", type=str, default='reach-v2')
    # training
    parser.add_argument("--timesteps", type=int, default=2_000_000)
    parser.add_argument("--tag", type=str, default='v0')
    # misc
    # Parsed as int so a CLI-supplied value has the same type as the default
    # (previously type=str produced "500" from the CLI but 500 by default).
    parser.add_argument("--max_step", type=int, default=500)
    parser.add_argument("--terminate", action="store_true", default=False)
    # domain factors
    parser.add_argument("--domain_factor", type=str, default='TEST')
    parser.add_argument("--fov", type=str, default='cam0-0')
    parser.add_argument("--xwind_id", type=int, default=0)
    parser.add_argument("--gravity_id", type=int, default=0)
    # run control: seed, stage flags and the checkpoint to load for --eval
    parser.add_argument("--seed", type=int, default=777)
    parser.add_argument("--train", action="store_true", default=False)
    parser.add_argument("--eval", action="store_true", default=False)
    parser.add_argument("--expert", action="store_true", default=False)
    parser.add_argument("--render", action="store_true", default=False)
    parser.add_argument("--collect", action="store_true", default=False)
    parser.add_argument("--checkpoint", type=str, default='final')
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()

    main(args)
