import argparse
import os
import numpy as np

from rql.agents.ORBIT import ORBITAgent
from rql.agents.robust_q import RobustQAgent, evaluate_policy
from rql.envs.simulatedMDP import SimulatedMDP


def parse_args(argv=None):
	"""Parse command-line options for training and evaluating robust Q-learning policies.

	Args:
		argv: Optional list of argument strings to parse instead of
			``sys.argv[1:]``; lets tests and other scripts drive the CLI.

	Returns:
		argparse.Namespace holding the parsed options.

	Raises:
		SystemExit: via ``parser.error`` (exit status 2) when an option is out
			of range. Examples: a non-positive seed count, or a perturbation
			grid that would be empty or could not be built.
	"""
	parser = argparse.ArgumentParser(description="Train and evaluate robust Q-learning policies.")
	parser.add_argument("--est_type", choices=["regularized"], default="regularized")
	parser.add_argument("--est_div", choices=["TV"], default="TV")
	parser.add_argument("--lambda_val", type=float, default=0.1)
	# argparse only applies ``type`` to string defaults, so spell the default as a float.
	parser.add_argument("--c", type=float, default=1.0)
	parser.add_argument("--episodes", type=int, default=1000, help="Training episodes per seed.")
	parser.add_argument("--seeds", type=int, default=20, help="Number of training seeds.")
	parser.add_argument("--perturb_end", type=float, default=1.0, help="Max perturbation to test.")
	parser.add_argument("--perturb_step", type=float, default=0.05, help="Perturbation grid step.")
	parser.add_argument("--eval_seed", type=int, default=0, help="Seed for evaluation rollouts.")
	parser.add_argument("--out_dir", type=str, default="res", help="Output directory.")
	parser.add_argument("--agent", choices=["orbit", "robustq"], default="orbit")
	args = parser.parse_args(argv)

	# Range checks argparse's type conversion cannot express. The grid is built
	# with np.arange(0, perturb_end, perturb_step): a zero step cannot build it,
	# and a negative step or end yields an empty grid. Zero seeds means every
	# summary statistic is the mean of an empty list (NaN).
	# ``not x > 0`` form also rejects NaN.
	if args.seeds < 1:
		parser.error("--seeds must be a positive integer")
	if not args.perturb_step > 0:
		parser.error("--perturb_step must be positive")
	if not args.perturb_end >= 0:
		parser.error("--perturb_end must be non-negative")
	return args


def train_seeds(args):
	"""Train one policy per seed and save each one under ``args.out_dir``.

	For every seed in ``range(args.seeds)`` this does the following:
	- reseeds NumPy's global RNG;
	- builds a fresh ``SimulatedMDP``;
	- trains either an ``ORBITAgent`` or a ``RobustQAgent``, selected by ``args.agent``;
	- writes the returned policy to ``pi_seed_<seed>.npy``.

	Returns:
		List of trained policies, in seed order.
	"""
	os.makedirs(args.out_dir, exist_ok=True)
	trained = []
	for s in range(args.seeds):
		print(f"Training seed {s}")
		# Reseed before the env/agent are built so each seed's run is reproducible.
		np.random.seed(s)
		mdp = SimulatedMDP()
		shared_kwargs = dict(lambda_val=args.lambda_val, n_episodes=args.episodes, seed=s)
		if args.agent != "orbit":
			# est_type / est_div are only forwarded to ORBITAgent; RobustQAgent takes c as ``c``.
			learner = RobustQAgent(mdp, c=args.c, **shared_kwargs)
		else:
			learner = ORBITAgent(
				mdp,
				est_type=args.est_type,
				est_div=args.est_div,
				step_size=args.c,
				**shared_kwargs,
			)
		policy = learner.train()
		trained.append(policy)
		np.save(os.path.join(args.out_dir, f"pi_seed_{s}.npy"), policy)
	return trained


def evaluate_policies(policies, args):
	"""Score every policy at each level of a perturbation grid.

	Args:
		policies: Trained policies (e.g. the list returned by ``train_seeds``).
		args: Parsed options; ``perturb_end`` and ``perturb_step`` define the grid.

	Returns:
		Tuple ``(perturb_grid, means, stds)``. ``means[i]`` and ``stds[i]`` are
		the mean and ``np.std`` of the ``evaluate_policy`` scores over all
		policies at ``perturb_grid[i]``.
	"""
	# np.arange excludes its stop value; the small epsilon keeps perturb_end in the grid.
	grid = np.arange(0, args.perturb_end + 1e-8, args.perturb_step)
	mean_per_level = []
	std_per_level = []
	for level in grid:
		# A new SimulatedMDP is built for every call, so no env instance is shared between evaluations.
		level_scores = [evaluate_policy(SimulatedMDP(), policy, perturb=level) for policy in policies]
		mean_per_level.append(np.mean(level_scores))
		std_per_level.append(np.std(level_scores))
	return grid, np.array(mean_per_level), np.array(std_per_level)


def main():
	"""Train a policy per seed, evaluate it across perturbations, and save the results.

	Writes two files into ``args.out_dir``:
	- ``res_summary.npy``: row 0 holds mean returns and row 1 holds their std,
	  with one column per perturbation level.
	- ``res_perturb_grid.npy``: the perturbation levels those columns
	  correspond to.
	"""
	args = parse_args()
	# No global seeding is needed before training: train_seeds reseeds NumPy's
	# global RNG with each training seed before creating the env and agent.
	policies = train_seeds(args)
	# Reseed so evaluation randomness is reproducible across runs unless --eval_seed changes.
	np.random.seed(args.eval_seed)
	perturb_grid, means, stds = evaluate_policies(policies, args)

	out_path = os.path.join(args.out_dir, "res_summary.npy")
	np.save(out_path, np.vstack([means, stds]))
	# The summary matrix has no x-axis, and the grid depends on CLI options, so
	# store it alongside to keep the results interpretable on their own.
	grid_path = os.path.join(args.out_dir, "res_perturb_grid.npy")
	np.save(grid_path, perturb_grid)
	print("Saved summary to", out_path)
	print("Saved perturbation grid to", grid_path)
	print("Perturb grid:", perturb_grid)
	print("Mean returns:", means)
	print("Std returns:", stds)


# Run the train-then-evaluate pipeline only when executed as a script, not on import.
if __name__ == "__main__":
	main()
