import argparse
import os
import time

import numpy as np

from rql.agents.ORBIT import ORBITAgent
from rql.agents.robust_q import RobustQAgent
from rql.envs.simulatedMDP import SimulatedMDP


def parse_args(argv=None) -> argparse.Namespace:
	"""Parse the command-line options of the training script.

	Args:
		argv: Optional list of argument strings to parse. When ``None``
			(the default) argparse reads ``sys.argv[1:]``, which is the
			behaviour existing callers get.

	Returns:
		The parsed namespace with attributes ``est_type``, ``est_div``,
		``lambda_val``, ``c``, ``episodes``, ``seed``, ``out_dir`` and
		``agent``. ``est_type``, ``est_div``, ``lambda_val`` and ``c`` have
		no default and are ``None`` when the option is omitted.
	"""
	parser = argparse.ArgumentParser(description="Train robust Q-learning on the SimulatedMDP.")
	parser.add_argument(
		"--est_type",
		choices=["regularized"],
		help="estimator type forwarded to the ORBIT agent (ignored by robustq)",
	)
	parser.add_argument(
		"--est_div",
		choices=["TV"],
		help="estimator divergence forwarded to the ORBIT agent (ignored by robustq)",
	)
	parser.add_argument("--lambda_val", type=float, help="lambda value forwarded to the agent")
	parser.add_argument(
		"--c",
		type=float,
		help="forwarded as step_size to the ORBIT agent or as c to the robustq agent",
	)
	parser.add_argument("--episodes", type=int, default=1000, help="number of training episodes (default: 1000)")
	parser.add_argument("--seed", type=int, default=0, help="random seed (default: 0)")
	parser.add_argument("--out_dir", type=str, default="res", help="output directory (default: res)")
	parser.add_argument(
		"--agent",
		choices=["orbit", "robustq"],
		default="orbit",
		help="which agent to train (default: orbit)",
	)
	# argv=None makes argparse fall back to sys.argv[1:].
	return parser.parse_args(argv)


def main():
	"""Train the selected agent on a SimulatedMDP and save what it returns.

	Reads the command-line options via ``parse_args()``, seeds NumPy's
	global RNG, builds either an ``ORBITAgent`` or a ``RobustQAgent`` around
	a fresh ``SimulatedMDP``, calls ``agent.train()`` and stores the returned
	object with ``np.save`` as ``<out_dir>/pi_seed_<seed>.npy``. The output
	path and the training time are printed to stdout.
	"""
	args = parse_args()
	# Seed the global NumPy RNG; the seed is also handed to the agent below.
	np.random.seed(args.seed)
	env = SimulatedMDP()
	if args.agent == "orbit":
		agent = ORBITAgent(
			env,
			est_type=args.est_type,
			est_div=args.est_div,
			lambda_val=args.lambda_val,
			step_size=args.c,  # --c is the step size for ORBIT
			n_episodes=args.episodes,
			seed=args.seed,
		)
	else:
		# Only other value argparse allows for --agent is "robustq".
		agent = RobustQAgent(
			env,
			lambda_val=args.lambda_val,
			c=args.c,
			n_episodes=args.episodes,
			seed=args.seed,
		)

	# perf_counter is monotonic, so the measured duration cannot be skewed
	# by system clock adjustments the way time.time() can.
	start = time.perf_counter()
	pi = agent.train()
	elapsed = time.perf_counter() - start

	os.makedirs(args.out_dir, exist_ok=True)
	# NOTE(review): the filename depends only on the seed, so runs that share
	# out_dir and seed but differ in agent or hyperparameters overwrite each
	# other. Left unchanged because downstream code may rely on this name.
	out_path = os.path.join(args.out_dir, f"pi_seed_{args.seed}.npy")
	np.save(out_path, pi)

	print(f"Saved policy to {out_path}")
	print(f"Training time: {elapsed:.2f}s")


# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
	main()
