import argparse
from pathlib import Path

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn


class QNetwork(nn.Module):
	"""Fully connected Q-network: two hidden ReLU layers, one output per discrete action.

	The submodules live in ``self.net`` at indices 0, 2 and 4 (Linear layers), which is
	what the checkpoint's ``q_state_dict`` keys (``net.0.weight`` ...) are matched against.
	"""

	def __init__(self, obs_dim, act_dim, hidden=128):
		super().__init__()
		widths = [obs_dim, hidden, hidden, act_dim]
		modules = []
		for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
			# Activation goes *between* Linear layers only, never after the output layer.
			if depth > 0:
				modules.append(nn.ReLU())
			modules.append(nn.Linear(fan_in, fan_out))
		self.net = nn.Sequential(*modules)

	def forward(self, x):
		"""Return Q-values for a batch of observations ``x``."""
		q_values = self.net(x)
		return q_values


def load_policy(ckpt_path, device="cpu"):
	"""Rebuild the Q-network and discrete action set stored in a training checkpoint.

	Args:
		ckpt_path: Path to the checkpoint file.
		device: Device onto which tensors are mapped and the network is moved.

	Returns:
		Tuple ``(q_net, action_set, args)``: the network in eval mode with the saved
		weights loaded, the checkpoint's ``action_set`` entry, and the checkpoint's
		``args`` mapping (``{}`` when the key is absent).
	"""
	# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects, which can
	# execute code. It is needed here because the checkpoint includes numpy arrays
	# (action_set); only pass checkpoints from a trusted source.
	checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
	train_args = checkpoint.get("args", {})
	network = QNetwork(
		checkpoint["obs_dim"],
		checkpoint["act_dim"],
		hidden=train_args.get("hidden", 128),
	).to(device)
	network.load_state_dict(checkpoint["q_state_dict"])
	network.eval()
	return network, checkpoint["action_set"], train_args


def _format_lambda_tag(lambda_val):
	"""Render ``lambda_val`` for use in a file name ("na" when None).

	Integral values are written without a decimal point (10.0 -> "10"); any other value
	is kept verbatim so that e.g. 0.5 and 0.7 stay distinct instead of both
	truncating to "0".
	"""
	if lambda_val is None:
		return "na"
	try:
		as_float = float(lambda_val)
	except (TypeError, ValueError):
		return str(lambda_val)
	if as_float.is_integer():
		return str(int(as_float))
	return str(lambda_val)


def _video_name_prefix(xml_file, perturb, lambda_val):
	"""Build the RecordVideo file-name prefix from the XML name, perturbation and lambda."""
	# Presumably custom XML files are named "..._<pole length>.xml"; the last
	# underscore-separated token of the stem is used as the length tag.
	len_tag = Path(xml_file).stem.rsplit("_", 1)[-1] if xml_file else "default"
	return f"eval_len-{len_tag}_pert-{perturb}_lambda-{_format_lambda_tag(lambda_val)}"


def _greedy_action_index(q_net, obs, device):
	"""Return the index of the largest Q-value that ``q_net`` predicts for one observation."""
	with torch.no_grad():
		obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
		return int(torch.argmax(q_net(obs_t), dim=1).item())


def evaluate(
	q_net,
	action_set,
	device,
	episodes=10,
	max_episode_steps=500,
	seed=0,
	perturb=0.0,
	xml_file=None,
	record_video=False,
	video_dir="videos_eval",
	lambda_val=None,
):
	"""Roll out the greedy (optionally epsilon-perturbed) policy and collect episode returns.

	Args:
		q_net: Network mapping a batch of observations to one Q-value per action index.
		action_set: Indexable collection of scalar actions; Q-output ``i`` selects
			``action_set[i]``.
		device: Torch device used for the forward pass.
		episodes: Number of episodes to run.
		max_episode_steps: Step cap passed to ``gym.make`` and used as the loop bound.
		seed: Base seed. Episode ``ep`` resets with ``seed + ep``, and the perturbation
			RNG is seeded with ``seed`` so perturbed runs are reproducible.
		perturb: Per-step probability of replacing the greedy action with a uniformly
			random action index.
		xml_file: Optional MuJoCo XML forwarded to the environment.
		record_video: If True, the last episode runs in a separate ``rgb_array`` env
			wrapped in ``RecordVideo``.
		video_dir: Folder for the recorded video.
		lambda_val: Value embedded in the video file name.

	Returns:
		List with the summed reward of each episode, in episode order.
	"""
	env_kwargs = {"max_episode_steps": max_episode_steps}
	if xml_file:
		env_kwargs["xml_file"] = xml_file
	# Dedicated generator: perturbation is reproducible for a given seed and neither
	# depends on nor disturbs numpy's global random state.
	rng = np.random.default_rng(seed)
	rewards = []
	env = gym.make("InvertedPendulum-v5", **env_kwargs)
	try:
		for ep in range(episodes):
			video_env = None
			if record_video and ep == episodes - 1:
				video_env = gym.make("InvertedPendulum-v5", **{**env_kwargs, "render_mode": "rgb_array"})
				video_env = gym.wrappers.RecordVideo(
					video_env,
					video_folder=video_dir,
					name_prefix=_video_name_prefix(xml_file, perturb, lambda_val),
				)
			active_env = env if video_env is None else video_env
			try:
				obs, _ = active_env.reset(seed=seed + ep)
				ep_rew = 0.0
				for _ in range(max_episode_steps):
					act_idx = _greedy_action_index(q_net, obs, device)
					if rng.random() < perturb:
						act_idx = int(rng.integers(len(action_set)))
					action = np.array([action_set[act_idx]], dtype=np.float32)
					obs, reward, terminated, truncated, _ = active_env.step(action)
					ep_rew += reward
					if terminated or truncated:
						break
			finally:
				# RecordVideo writes the file on close, so close even if the rollout fails.
				if video_env is not None:
					video_env.close()
			rewards.append(ep_rew)
	finally:
		env.close()
	return rewards


def main():
	"""CLI entry point: load a checkpoint, evaluate it, and print per-episode and summary returns."""
	parser = argparse.ArgumentParser(description="Evaluate a saved clipped DQN policy on InvertedPendulum-v5.")
	parser.add_argument("checkpoint", type=str, help="Path to saved checkpoint (dqn_policy.pt).")
	parser.add_argument("--episodes", type=int, default=10, help="Number of evaluation episodes.")
	parser.add_argument("--max_episode_steps", type=int, default=500, help="Max steps per episode.")
	parser.add_argument("--device", type=str, default="cpu", help="cpu or cuda.")
	parser.add_argument("--seed", type=int, default=0, help="Base seed for eval episodes.")
	parser.add_argument("--perturb", type=float, default=0.0, help="Probability of random action instead of greedy.")
	parser.add_argument("--xml_file", type=str, default=None, help="Optional custom MuJoCo XML to override pole length, etc.")
	parser.add_argument("--record_video", action="store_true", help="Record the last evaluation episode.")
	parser.add_argument("--video_dir", type=str, default="videos", help="Directory to save evaluation video.")
	args = parser.parse_args()

	# Reject nonsensical values up front with a usage message. Otherwise they surface
	# later as a NaN mean (no episodes) or a silently meaningless perturbation rate.
	if args.episodes < 1:
		parser.error("--episodes must be >= 1")
	if args.max_episode_steps < 1:
		parser.error("--max_episode_steps must be >= 1")
	if not 0.0 <= args.perturb <= 1.0:
		parser.error("--perturb must be in [0, 1]")

	q_net, action_set, saved_args = load_policy(args.checkpoint, device=args.device)
	# Only used to tag the recorded video's file name.
	lambda_val = saved_args.get("lambda_val", None)
	rewards = evaluate(
		q_net,
		action_set,
		device=args.device,
		episodes=args.episodes,
		max_episode_steps=args.max_episode_steps,
		seed=args.seed,
		perturb=args.perturb,
		xml_file=args.xml_file,
		record_video=args.record_video,
		video_dir=args.video_dir,
		lambda_val=lambda_val,
	)
	print("Per-episode cumulative returns:", rewards)
	print(f"Eval rewards over {args.episodes} eps: mean {np.mean(rewards):.2f}, std {np.std(rewards):.2f}")


if __name__ == "__main__":
	main()
