

# Hydra config directory plus the HalfCheetah-v2 experiment configs it holds.
# The active experiment is whichever name is passed as ``config_name`` to
# ``hydra.main`` further down this file.
# NOTE(review): PATH is not referenced by the visible code; the hydra.main
# decorator spells out its own config_path — confirm before relying on it.
PATH = "./configs/environment/"
CONFIG_NAME1, CONFIG_NAME2, CONFIG_NAME3 = (
    "HalfCheetahv2-1_4.yaml",
    "HalfCheetahv2-2_4_7.yaml",
    "HalfCheetahv2-3_4_7_4.yaml",
)



import hydra
from stable_baselines3 import TD3, TD3_auto
from stable_baselines3.common.vec_env import DummyVecEnv,  DummyVecEnv, subproc_vec_env
from stable_baselines3.common.monitor import Monitor
import wandb
from wandb.integration.sb3 import WandbCallback
from stable_baselines3.common.callbacks import CallbackList
import numpy as np 
from stable_baselines3.common.env_util import make_vec_env, make_vec_env2 , unwrap_wrapper_dummy , unwrap_wrapper
from utils import  SuperEvalCallback
from omegaconf import DictConfig
from run_test import evaluate, Wrapper_DR, Wrapper_auto_DR, Wrapper_auto
import os


# Put wandb in offline mode for every run launched from this script
# (set through the WANDB_MODE environment variable before wandb.init runs).
os.environ.update(WANDB_MODE="offline")



def _build_train_env(args: DictConfig):
    """Create the training VecEnv for the (``args.dr``, ``args.auto``) combination.

    Every variant builds ``args.n_envs`` copies of ``args["env_name"]`` inside a
    ``DummyVecEnv`` seeded with ``args.seed``:

    * dr and auto -> ``make_vec_env`` wrapped in ``Wrapper_auto_DR``
    * dr only     -> ``make_vec_env`` wrapped in ``Wrapper_DR``
    * auto only   -> ``make_vec_env`` wrapped in ``Wrapper_auto``
    * neither     -> ``make_vec_env2`` with no wrapper

    The branches are exhaustive, so a VecEnv is always returned.
    """
    env_id = args["env_name"]
    shared = dict(n_envs=args.n_envs, seed=args.seed, vec_env_cls=DummyVecEnv)
    if args.dr:
        # Both domain-randomisation wrappers receive the full config plus the seed.
        wrapper = Wrapper_auto_DR if args.auto else Wrapper_DR
        return make_vec_env(
            env_id,
            wrapper_class=wrapper,
            wrapper_kwargs={"arg": args, "seed": args.seed},
            **shared,
        )
    if args.auto:
        return make_vec_env(env_id, wrapper_class=Wrapper_auto, **shared)
    # Plain environment: built with make_vec_env2 and no wrapper.
    return make_vec_env2(env_id, **shared)


def _build_model(args: DictConfig, vec_env, run_id: str):
    """Instantiate ``TD3_auto`` (when ``args.auto``) or ``TD3`` on ``vec_env``.

    Both algorithms share the same fixed hyper-parameters (batch 100,
    70k learning-starts, gamma 0.99, train/gradient steps every 100, 300k replay
    buffer) plus ``args.learning_rate`` and ``args.expectile``. ``TD3_auto`` also
    receives ``args.bandit_lr``. TensorBoard logs go to ``runs/<run_id>``.
    """
    hyperparams = dict(
        verbose=1,
        # Exactly runs/<run id>: no trailing whitespace in the directory name.
        tensorboard_log=f"runs/{run_id}",
        batch_size=100,
        learning_starts=70000,
        gamma=0.99,
        train_freq=100,
        gradient_steps=100,
        learning_rate=args.learning_rate,
        buffer_size=300000,
        expectile=args.expectile,
    )
    if args.auto:
        return TD3_auto("MlpPolicy", vec_env, bandit_lr=args.bandit_lr, **hyperparams)
    return TD3("MlpPolicy", vec_env, **hyperparams)


@hydra.main(config_path="configs/environment/", config_name=CONFIG_NAME3)
def main(args: DictConfig):
    """Train TD3 / TD3_auto on the configured environment, then evaluate it.

    Steps:
      1. start a wandb run whose project name encodes auto/dr/env/model;
      2. build the training env for the (dr, auto) combination;
      3. train with the wandb callback plus ``SuperEvalCallback``, which is
         given ``eval_freq=20000`` and a best-model save directory;
      4. evaluate the final model and the best checkpoint reloaded from that
         directory.

    Args:
        args: Hydra config. Fields read here: ``auto``, ``dr``, ``env_name``,
            ``model``, ``n_envs``, ``seed``, ``total_timesteps``,
            ``learning_rate``, ``expectile``, ``expectile_max``,
            ``bandit_lr`` (auto only) and ``evaluate_num``.
    """
    # Local import: only DictConfig is imported from omegaconf at module level.
    from omegaconf import OmegaConf

    print(args)
    run = wandb.init(
        project=f"Auto_{args.auto}_DR_{args.dr}_{args.env_name}_{args.model}",
        # Documented conversion to plain Python containers, rather than reading
        # the private ``_content`` node dict of the DictConfig.
        config=OmegaConf.to_container(args, resolve=True),
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=False,  # do not upload videos of the agent
        save_code=False,  # optional
    )
    wandb_callback = WandbCallback(model_save_path=f"models/{run.id}", verbose=2)

    vec_env = _build_train_env(args)

    # Computed once so the save path and the reload path cannot drift apart.
    best_model_dir = f"./model/best_model_{args.env_name}_{args.model}_{run.id}/"
    # NOTE(review): the training env doubles as the evaluation env. If
    # SuperEvalCallback resets/steps it (as SB3's EvalCallback does), that
    # interrupts in-progress training episodes — consider a separate eval env.
    supercallback = SuperEvalCallback(
        eval_env=vec_env,
        eval_freq=20000,
        total_timesteps=args.total_timesteps,
        verbose=1,
        auto=args.auto,
        expectile_max=args.expectile_max,
        best_model_save_path=best_model_dir,
    )
    callback = CallbackList([wandb_callback, supercallback])

    model = _build_model(args, vec_env, run.id)
    model.learn(
        total_timesteps=args.total_timesteps,
        callback=callback,
        progress_bar=True,
    )

    # Reload the best checkpoint with the same algorithm class that trained it,
    # so a TD3_auto checkpoint is restored as a TD3_auto instance.
    algo_cls = TD3_auto if args.auto else TD3
    best_model = algo_cls.load(best_model_dir + "best_model.zip")

    evaluate(args, evaluate_num=args.evaluate_num, seed=args.seed, best_model=model,
             auto=args.auto, expectile_max=args.expectile_max)
    evaluate(args, evaluate_num=args.evaluate_num, seed=args.seed, best_model=best_model,
             best=True, auto=args.auto, expectile_max=args.expectile_max)

    # Close the wandb run explicitly so sequential Hydra multirun jobs executed
    # in one process each get their own run.
    run.finish()


if __name__ == "__main__":
    # Hydra parses the CLI/config and supplies ``args``.
    main()
   
