import hydra
from stable_baselines3 import TD3
from stable_baselines3.common.vec_env import DummyVecEnv,  DummyVecEnv
import wandb
from wandb.integration.sb3 import WandbCallback
from stable_baselines3.common.callbacks import CallbackList
import numpy as np 
from stable_baselines3.common.env_util import make_vec_env
from utils import  SuperEvalCallback
from omegaconf import DictConfig
from run_test import evaluate

# File name of the Hydra config selected by the ``@hydra.main`` decorator below.
CONFIG_NAME = "Antv2-1_3.yaml"

# Absolute location of an environment config file, read by the evaluation
# calls in ``main``. The value already ends with the config file name.
# NOTE(review): hard-coded to one developer's machine; adjust before running
# anywhere else.
PATH = "/Users/pierreclavier/Documents/These/Code/Expectile/Expectile_RL/test_expectile/configs/environment/Antv2-1_3.yaml"

@hydra.main(config_path="configs/environment/", config_name=CONFIG_NAME)
def main(args: DictConfig) -> None:
    """Train a TD3 agent, track it with Weights & Biases, then evaluate it.

    The function builds the vectorised training environment, trains a TD3
    model with a W&B callback and a ``SuperEvalCallback``, reloads the best
    checkpoint saved during training and finally runs ``evaluate`` on both the
    last model and the best model.

    Args:
        args: Hydra configuration. The fields read here are
            ``total_timesteps``, ``env_name``, ``model``, ``n_envs``, ``seed``
            and ``evaluate_num``.
    """
    print(args)

    config = {
        "policy_type": "MlpPolicy",
        "total_timesteps": args.total_timesteps,
        "env": args.env_name,
    }

    run = wandb.init(
        project=f"DR_{args.env_name}_{args.model}",
        config=config,
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=False,  # video upload of the agent is disabled
        save_code=False,  # optional
    )
    wandb_callback = WandbCallback(
        model_save_path=f"models/{run.id}",
        verbose=2,
    )

    vec_env = make_vec_env(
        config["env"],
        n_envs=args.n_envs,
        seed=args.seed,
        vec_env_cls=DummyVecEnv,
    )

    # Built once so the directory the callback saves into and the directory
    # the best model is reloaded from can never drift apart.
    best_model_dir = "./model/" + f"best_model_{args.env_name}_{args.model}_{run.id}/"

    # NOTE(review): the training env is also used as the evaluation env;
    # confirm SuperEvalCallback tolerates sharing it with the learner.
    eval_callback = SuperEvalCallback(
        eval_env=vec_env,
        eval_freq=50,
        total_timesteps=args.total_timesteps,
        verbose=1,
        best_model_save_path=best_model_dir,
    )

    callback = CallbackList([wandb_callback, eval_callback])

    # NOTE(review): ``expectile`` is not an argument of upstream
    # stable-baselines3 TD3; presumably a patched build is installed - confirm.
    model = TD3(
        config["policy_type"],
        vec_env,
        verbose=1,
        # No trailing whitespace: it would end up in the directory name.
        tensorboard_log=f"runs/{run.id}",
        batch_size=256,
        learning_starts=10000,
        gamma=0.98,
        train_freq=16,
        tau=0.005,
        gradient_steps=1,
        learning_rate=3e-4,
        buffer_size=300000,
        expectile=0.5,
    )
    model.learn(
        total_timesteps=args.total_timesteps,
        callback=callback,
        progress_bar=True,
    )

    # Presumably SuperEvalCallback writes "best_model.zip" into
    # best_model_save_path - verify against utils.SuperEvalCallback.
    best_model = TD3.load(best_model_dir + "best_model.zip")

    # PATH may already point at the config file itself; appending CONFIG_NAME
    # again would produce a non-existent "...yamlAntv2-1_3.yaml" path.
    config_path = PATH if PATH.endswith(CONFIG_NAME) else PATH + CONFIG_NAME

    evaluate(
        config_path,
        evaluate_num=args.evaluate_num,
        seed=args.seed,
        best_model=model,
    )
    evaluate(
        config_path,
        evaluate_num=args.evaluate_num,
        seed=args.seed,
        best_model=best_model,
        best=True,
    )

    # Close the W&B run explicitly so a later run in the same process (e.g.
    # Hydra multirun) starts from a clean state.
    run.finish()

# Script entry point: main() is called without arguments because the
# @hydra.main decorator builds and injects the `args` configuration.
if __name__ == "__main__":
    main()
   
   
