from __future__ import annotations

import argparse
import os
import numpy as np
import torch

from gdc_project.utils.helpers import set_seed, make_env
from gdc_project.gdc.sac_agent import SACAgent, SACConfig
from gdc_project.gdc.gdc_sac_agent import GDCSACAgent, GDCSACConfig


def _run_episode(env, agent) -> tuple[float, float]:
    """Roll out a single episode using the agent's deterministic action.

    Args:
        env: Environment whose ``reset()`` returns ``(obs, info)`` and whose
            ``step(action)`` returns ``(obs, reward, terminated, truncated, info)``.
        agent: Object providing ``select_action(obs, deterministic=True)``.

    Returns:
        ``(episode_return, episode_cost)`` as plain Python floats. The cost is
        the sum of ``info["cost"]`` over the episode; steps whose ``info`` has
        no ``"cost"`` key contribute ``0.0``.
    """
    obs, _ = env.reset()
    ep_ret, ep_cost = 0.0, 0.0
    done = False
    while not done:
        act = agent.select_action(obs, deterministic=True)
        obs, rew, term, trunc, info = env.step(act)
        # Convert to Python floats so the running sums do not inherit a
        # lower-precision scalar type from the environment.
        ep_ret += float(rew)
        ep_cost += float(info.get("cost", 0.0))
        done = term or trunc
    return ep_ret, ep_cost


def main() -> None:
    """Evaluate a SAC or GDC-SAC agent and print per-episode and average metrics.

    Command-line options:
        --env_id    Environment id passed to ``make_env`` (default "OptimalTrap-v0").
        --episodes  Number of evaluation episodes (default 10).
        --seed      Seed passed to ``set_seed`` and ``make_env`` (default 42).
        --agent     "sac" or "a_gdc_sac" (default "sac").
        --ckpt      Optional path to a checkpoint loaded via ``agent.load_state_dict``.

    If ``--ckpt`` is given but is not an existing file, a warning is written to
    stderr and the freshly constructed agent is evaluated as-is.
    """
    import sys  # local import: only needed for the stderr warning below

    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", type=str, default="OptimalTrap-v0")
    parser.add_argument("--episodes", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--agent", type=str, default="sac", choices=["sac", "a_gdc_sac"])
    parser.add_argument("--ckpt", type=str, default=None)
    args = parser.parse_args()

    set_seed(args.seed)
    env = make_env(args.env_id, args.seed)

    try:
        # Flatten observation/action shapes to scalar dimensions for the agent configs.
        obs_dim = int(np.prod(env.observation_space.shape))
        act_dim = int(np.prod(env.action_space.shape))
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if args.agent == "a_gdc_sac":
            agent = GDCSACAgent(GDCSACConfig(obs_dim=obs_dim, act_dim=act_dim, hidden_dims=(256, 256)), device)
        else:
            agent = SACAgent(SACConfig(obs_dim=obs_dim, act_dim=act_dim, hidden_dims=(256, 256)), device)

        if args.ckpt:
            if os.path.isfile(args.ckpt):
                # NOTE: torch.load unpickles the file; only load checkpoints
                # that come from a trusted source.
                sd = torch.load(args.ckpt, map_location=device)
                agent.load_state_dict(sd)
            else:
                # Previously this case was skipped silently, which made the
                # printed metrics look like they belonged to the checkpoint.
                print(
                    f"Warning: checkpoint '{args.ckpt}' not found; "
                    "evaluating an agent with freshly initialised weights.",
                    file=sys.stderr,
                )

        returns, costs = [], []
        for ep in range(args.episodes):
            ep_ret, ep_cost = _run_episode(env, agent)
            returns.append(ep_ret)
            costs.append(ep_cost)
            print(f"Episode {ep}: return={ep_ret:.2f} cost={ep_cost:.2f}")

        if returns:
            print(f"Avg return={np.mean(returns):.2f}±{np.std(returns):.2f}, Avg cost={np.mean(costs):.2f}")
        else:
            # np.mean/np.std on an empty list would print nan with RuntimeWarnings.
            print("No episodes were run; nothing to average.")
    finally:
        # Release environment resources even if evaluation fails part-way.
        # make_env is defined elsewhere, so only call close() if it exists.
        close_env = getattr(env, "close", None)
        if callable(close_env):
            close_env()


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
