import os
import sys

import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F

from env import UnicycleEnv

# Make the parent directory importable so `sac` can be found; this must run
# before the `from sac import ...` line below.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from sac import SAC, Replay

def train(epochs=500, warmup=400, save_path="trained_model/actor_sac_e2e.pt"):
    """Train a SAC agent on ``UnicycleEnv`` and save the actor's weights.

    Runs ``epochs`` episodes. Every transition is stored in a ``Replay``
    buffer, and once the buffer holds more than ``warmup`` transitions the
    agent performs one ``agent.update(buf)`` call per environment step.
    The episode return and the last ``info`` dict are printed per episode.

    Args:
        epochs: Number of episodes to run (default 500, as before).
        warmup: Buffer size that must be exceeded before gradient updates
            start (default 400, as before).
        save_path: File the actor's ``state_dict`` is written to with
            ``torch.save``. Missing parent directories are created.

    Returns:
        The trained ``SAC`` agent.
    """
    env = UnicycleEnv()
    # Observation size is read from the env's observation helper. The action
    # size is hard-coded to 2 — NOTE(review): confirm it matches UnicycleEnv.
    s_dim = len(env._get_obs())
    a_dim = 2
    agent = SAC(s_dim, a_dim)
    buf = Replay(s_dim, a_dim)

    for ep in range(epochs):
        s = env.reset()
        ep_ret = 0
        done = False
        info = {}
        while not done:
            a = agent.act(s)
            s2, r, done, info = env.step(a)
            ep_ret += r
            buf.store(s, a, r, s2, float(done))
            s = s2
            # Only start learning once the buffer has enough transitions.
            if buf.size > warmup:
                agent.update(buf)
        print(f"Epoch {ep}, Return {ep_ret:.2f}, Done info {info}")

    # torch.save does not create directories; without this a long training
    # run would crash at the very end if the target folder is missing.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(agent.actor.state_dict(), save_path)
    return agent



if __name__ == "__main__":
    # Script entry point: run a full training session with default settings.
    train()
