import os
import argparse
import gymnasium as gym
from stable_baselines3 import SAC, PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from envs import * 

# Root directory under which each experiment gets its own subdirectory;
# None when the BASELINE_DIR environment variable is not set.
BASELINE_DIR = os.environ.get("BASELINE_DIR")

def parse_args(argv=None):
    """Parse command-line options for a training run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic use).

    Returns:
        argparse.Namespace with ``env``, ``algo``, ``exp_name``,
        ``total_timesteps``, ``eval_episodes`` and ``eval_freq``.
    """
    parser = argparse.ArgumentParser(description='Train a RL agent on a specified environment')
    parser.add_argument('--env', type=str, default='Pendulum-v1',
                        help='Environment ID (default: Pendulum-v1)')
    parser.add_argument('--algo', type=str, choices=['sac', 'ppo'], default='sac',
                        help='Algorithm to use: sac or ppo (default: sac)')
    parser.add_argument('--exp_name', type=str, default=None,
                        help='Experiment name (default: {algo}_model)')
    parser.add_argument('--total_timesteps', type=int, default=1_000_000,
                        help='Total timesteps for training (default: 1,000,000)')
    # Help strings below previously advertised defaults (5 and 10000) that
    # did not match the actual `default=` values.
    parser.add_argument('--eval_episodes', type=int, default=10,
                        help='Number of episodes for the final evaluation after training (default: 10)')
    parser.add_argument('--eval_freq', type=int, default=1000,
                        help='Frequency (in steps) of evaluation and checkpointing during training (default: 1000)')
    return parser.parse_args(argv)

def create_agent(algo, env, tensorboard_path, policy=None):
    """Create either a SAC or PPO agent for ``env``.

    Args:
        algo: Algorithm name, ``'sac'`` or ``'ppo'`` (case-insensitive).
        env: Environment passed straight to the stable-baselines3 constructor.
        tensorboard_path: Directory for TensorBoard logs.
        policy: Policy name to use. When omitted, ``"MultiInputPolicy"`` is
            chosen if ``env.observation_space`` is a ``gym.spaces.Dict`` and
            ``"MlpPolicy"`` otherwise.

    Returns:
        The constructed SAC or PPO model.

    Raises:
        ValueError: If ``algo`` is not ``'sac'`` or ``'ppo'``.
    """
    algo_key = algo.lower()
    # Validate before inspecting ``env`` so a bad name fails fast and clearly.
    if algo_key not in ("sac", "ppo"):
        raise ValueError(f"Unsupported algorithm: {algo}. Choose 'sac' or 'ppo'.")

    if policy is None:
        # MlpPolicy handles flat (e.g. Box) observations and MultiInputPolicy is
        # meant for Dict observations; pick from the env rather than hard-coding
        # a different policy per algorithm (which broke PPO on Box envs and SAC
        # on Dict envs).
        obs_space = getattr(env, "observation_space", None)
        policy = "MultiInputPolicy" if isinstance(obs_space, gym.spaces.Dict) else "MlpPolicy"

    agent_cls = SAC if algo_key == "sac" else PPO
    return agent_cls(
        policy=policy,
        env=env,
        verbose=1,
        tensorboard_log=tensorboard_path,
    )
 
def _make_vec_env(env_id):
    """Return a single-env ``DummyVecEnv`` around ``Monitor(gym.make(env_id))``."""
    # Build the env inside the factory instead of capturing an outer variable
    # that is immediately rebound to the VecEnv itself (late-binding hazard).
    return DummyVecEnv([lambda: Monitor(gym.make(env_id))])


def main():
    """Train the selected agent, save checkpoints/best/final models, then evaluate.

    Raises:
        SystemExit: If the BASELINE_DIR environment variable is not set.
    """
    args = parse_args()

    if BASELINE_DIR is None:
        # os.path.join(None, ...) would otherwise fail with an opaque TypeError.
        raise SystemExit(
            "BASELINE_DIR environment variable is not set; "
            "set it to the directory where experiments should be saved."
        )

    # Set default experiment name if not provided
    if args.exp_name is None:
        args.exp_name = f"{args.algo}_model"

    # Per-experiment directory holding the model, logs, checkpoints, etc.
    exp_path = os.path.join(BASELINE_DIR, f"{args.env}_{args.algo}_{args.exp_name}")
    model_path = os.path.join(exp_path, "model")
    tensorboard_path = os.path.join(exp_path, "tensorboard")
    # Creates exp_path as well; exist_ok keeps re-runs into the same dir safe.
    os.makedirs(tensorboard_path, exist_ok=True)

    # Monitor-wrapped training env, plus a separate instance for evaluation.
    env = _make_vec_env(args.env)
    eval_env = _make_vec_env(args.env)

    try:
        model = create_agent(args.algo, env, tensorboard_path)

        # Periodic evaluation; keeps the best-scoring model on disk.
        eval_callback = EvalCallback(
            eval_env,
            best_model_save_path=os.path.join(exp_path, "best_model"),
            log_path=os.path.join(exp_path, "eval_results"),
            eval_freq=args.eval_freq,
            # Fixed at 5 episodes; --eval_episodes only controls the final
            # evaluation after training.
            n_eval_episodes=5,
            deterministic=True,
            verbose=1,
        )

        # Checkpoint callback for regular model saving
        checkpoint_callback = CheckpointCallback(
            save_freq=args.eval_freq,
            save_path=os.path.join(exp_path, "checkpoints"),
            name_prefix=f"{args.algo}_model",
            verbose=1,
        )

        print(f"Training {args.algo.upper()} on {args.env} for {args.total_timesteps} timesteps")

        model.learn(
            total_timesteps=args.total_timesteps,
            callback=[eval_callback, checkpoint_callback],
        )

        # Save final model
        model.save(model_path)
        print(f"Final model saved to {model_path}")

        # Evaluate the final (not the best) model
        print(f"Evaluating for {args.eval_episodes} episodes...")
        mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=args.eval_episodes)
        print(f"Mean reward: {mean_reward:.2f} ± {std_reward:.2f}")
    finally:
        # Release environment resources even if training is interrupted.
        env.close()
        eval_env.close()


if __name__ == "__main__":
    main()
    

# Example usage (requires the BASELINE_DIR environment variable to be set):
# python baseline.py --env=ant_maze --algo=sac --exp_name="sac" --total_timesteps=3000000 --eval_freq=10000