import torch
import numpy as np
from env import UnicycleEnv
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from sac import Actor

def test(model_path="trained_model/actor_sac_e2e.pt"):
    """Roll out one episode with a trained SAC actor and render the trajectory.

    Builds a ``UnicycleEnv`` and loads the actor weights from *model_path*.
    It runs the policy until the environment reports ``done``. It then prints
    the episode return and the last ``info`` dict, and passes the recorded
    state trajectory to ``env.render``.

    Args:
        model_path: Path to the saved actor ``state_dict``. Relative paths
            are resolved against the current working directory, not this
            script's directory.
    """
    # Set up the environment; the observation size is taken from a sample observation.
    env = UnicycleEnv()
    s_dim = len(env._get_obs())
    a_dim = 2  # NOTE(review): hard-coded action dimension — assumed to match UnicycleEnv

    # Load the trained policy. map_location="cpu" lets weights saved from a GPU
    # run load on a CPU-only machine; the actor itself is constructed on CPU here.
    actor = Actor(s_dim, a_dim)
    actor.load_state_dict(torch.load(model_path, map_location="cpu"))
    actor.eval()

    # Roll out a single episode, recording env.x after reset and after every step.
    s = env.reset()
    traj = [env.x.copy()]
    done, info = False, {}
    ep_ret = 0
    while not done:
        with torch.no_grad():
            obs = torch.as_tensor(s, dtype=torch.float32).unsqueeze(0)
            # The actor's output is unpacked as (action, logp); only the action is used.
            # NOTE(review): if Actor's forward samples stochastically, this evaluates
            # the stochastic policy — confirm whether a deterministic (mean) action
            # is intended for testing.
            a, _ = actor(obs)
            a = a[0].cpu().numpy()
        s, r, done, info = env.step(a)
        ep_ret += r
        traj.append(env.x.copy())

    print("Episode return:", ep_ret)
    print("Episode info:", info)

    # Visualize the recorded trajectory.
    env.render(traj)

if __name__ == "__main__":
    # Run the single-episode evaluation only when executed as a script, not on import.
    test()