import argparse
from pathlib import Path
from types import SimpleNamespace

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import wandb

import train as train_module


class QNetwork(nn.Module):
	"""Fully connected Q-value network with two hidden ReLU layers.

	Maps a batch of observations of width ``obs_dim`` to ``act_dim`` outputs
	per observation.
	"""

	def __init__(self, obs_dim, act_dim, hidden=128):
		"""Create the layer stack.

		Args:
			obs_dim: Size of the input observation vector.
			act_dim: Number of outputs produced by the final layer.
			hidden: Width of both hidden layers.
		"""
		super().__init__()
		# Linear layers are created in input-to-output order and stored in a
		# Sequential at indices 0, 2 and 4, which fixes the state_dict keys.
		input_layer = nn.Linear(obs_dim, hidden)
		middle_layer = nn.Linear(hidden, hidden)
		output_layer = nn.Linear(hidden, act_dim)
		stack = [input_layer, nn.ReLU(), middle_layer, nn.ReLU(), output_layer]
		self.net = nn.Sequential(*stack)

	def forward(self, x):
		"""Return the network output for the input tensor ``x``."""
		q_values = self.net(x)
		return q_values


def load_policy(ckpt_path, device="cpu"):
	"""Load a saved Q-network and its action set from a checkpoint file.

	Args:
		ckpt_path: Path to a checkpoint readable by ``torch.load``. It must
			contain the keys ``obs_dim``, ``act_dim``, ``q_state_dict`` and
			``action_set``. An optional ``args`` entry (a dict or an object
			with attributes) supplies the hidden width via ``hidden``.
		device: Device that tensors are mapped to and the network is moved to.

	Returns:
		Tuple ``(q_net, action_set)``; ``q_net`` is a ``QNetwork`` in eval mode.

	Raises:
		KeyError: If one of the required checkpoint keys is missing.
	"""
	# SECURITY: weights_only=False unpickles arbitrary Python objects, which can
	# execute code. Only load checkpoints that come from a trusted source.
	ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

	# The stored args may be absent, None, a plain dict, or a namespace-like
	# object; fall back to the default hidden width of 128 when unspecified.
	saved_args = ckpt.get("args") or {}
	if isinstance(saved_args, dict):
		hidden = saved_args.get("hidden", 128)
	else:
		hidden = getattr(saved_args, "hidden", 128)

	obs_dim = ckpt["obs_dim"]
	act_dim = ckpt["act_dim"]
	q_net = QNetwork(obs_dim, act_dim, hidden=hidden).to(device)
	q_net.load_state_dict(ckpt["q_state_dict"])
	q_net.eval()
	action_set = ckpt["action_set"]
	return q_net, action_set


def evaluate(q_net, action_set, device, episodes=50, max_episode_steps=500, seed=0, perturb=0.0, xml_file=None):
	"""Roll out the greedy policy of ``q_net`` with random action perturbations.

	Args:
		q_net: Network called on a ``(1, obs_dim)`` float32 tensor; the argmax
			over dimension 1 of its output selects an index into ``action_set``.
		action_set: Indexable collection of scalar actions.
		device: Device on which the observation tensor is created.
		episodes: Number of episodes to run.
		max_episode_steps: Step cap, passed to ``gym.make`` and also used as
			the bound of the per-episode loop.
		seed: Base seed. Episode ``ep`` resets the environment with
			``seed + ep``; the perturbation random stream is seeded with ``seed``.
		perturb: Probability, per step, of replacing the greedy action index
			with one drawn uniformly from ``range(len(action_set))``.
		xml_file: Optional ``xml_file`` keyword forwarded to ``gym.make``.

	Returns:
		List with one accumulated reward (float) per episode.
	"""
	env_kwargs = {"max_episode_steps": max_episode_steps}
	if xml_file:
		env_kwargs["xml_file"] = xml_file

	# Use a dedicated generator so the perturbations are reproducible for a
	# given seed and do not depend on (or disturb) the global NumPy RNG state.
	rng = np.random.default_rng(seed)
	num_actions = len(action_set)

	env = gym.make("InvertedPendulum-v5", **env_kwargs)
	rewards = []
	try:
		for ep in range(episodes):
			obs, _ = env.reset(seed=seed + ep)
			ep_rew = 0.0
			for _ in range(max_episode_steps):
				with torch.no_grad():
					obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
					act_idx = int(torch.argmax(q_net(obs_t), dim=1).item())
				if rng.random() < perturb:
					act_idx = int(rng.integers(num_actions))
				action = np.array([action_set[act_idx]], dtype=np.float32)
				obs, reward, terminated, truncated, _ = env.step(action)
				ep_rew += float(reward)
				if terminated or truncated:
					break
			rewards.append(ep_rew)
	finally:
		# Release the environment even if a rollout raises.
		env.close()
	return rewards


def _parse_args():
	"""Build the command-line parser, parse argv and reject unusable values."""
	parser = argparse.ArgumentParser(description="Train per seed then evaluate under action perturbations.")
	parser.add_argument("--seeds", type=int, default=10, help="Number of seeds (0..seeds-1).")
	parser.add_argument("--episodes", type=int, default=50, help="Episodes per perturbation per seed (eval).")
	parser.add_argument("--perturb_start", type=float, default=0.0)
	parser.add_argument("--perturb_end", type=float, default=0.05)
	parser.add_argument("--perturb_step", type=float, default=0.01)
	parser.add_argument("--max_episode_steps", type=int, default=500)
	parser.add_argument("--device", type=str, default="cpu")
	parser.add_argument("--out_path", type=str, default="res/multi_eval.npy")
	parser.add_argument("--xml_file", type=str, default=None, help="Optional custom MuJoCo XML for evaluation env.")
	parser.add_argument("--train_episodes", type=int, default=2000)
	parser.add_argument("--replay_size", type=int, default=50_000)
	parser.add_argument("--batch_size", type=int, default=64)
	parser.add_argument("--gamma", type=float, default=0.99)
	parser.add_argument("--lr", type=float, default=3e-4)
	parser.add_argument("--tau", type=float, default=0.005)
	parser.add_argument("--epsilon_start", type=float, default=1.0)
	parser.add_argument("--epsilon_end", type=float, default=0.05)
	parser.add_argument("--epsilon_decay_steps", type=int, default=20_000)
	parser.add_argument("--hidden", type=int, default=128)
	parser.add_argument("--num_actions", type=int, default=11)
	parser.add_argument("--lambda_val", type=float, default=500.0)
	parser.add_argument("--save_dir", type=str, default="checkpoints")
	parser.add_argument("--log_interval", type=int, default=10)
	parser.add_argument("--xml_dir", type=str, default=None, help="Directory containing custom XML files for eval. If set, eval over each XML.")
	parser.add_argument("--wandb", action="store_true", help="Enable wandb logging during training.")
	parser.add_argument("--wandb_project", type=str, default="inverted-pendulum", help="wandb project name for training runs.")
	parser.add_argument("--wandb_run_name", type=str, default=None, help="wandb run name for training runs.")
	args = parser.parse_args()

	# Fail fast on inputs that would otherwise yield empty/NaN statistics
	# (or an invalid np.arange call) only after training has already run.
	if args.seeds < 1:
		parser.error("--seeds must be >= 1")
	if args.episodes < 1:
		parser.error("--episodes must be >= 1")
	if args.perturb_step <= 0:
		parser.error("--perturb_step must be > 0")
	if args.perturb_end < args.perturb_start:
		parser.error("--perturb_end must be >= --perturb_start")
	return args


def _resolve_xml_files(args):
	"""Return the XML models to evaluate on; ``[None]`` means the default model.

	``--xml_dir`` takes precedence; otherwise a single ``--xml_file`` is used.
	"""
	if args.xml_dir:
		xml_dir = Path(args.xml_dir)
		xml_files = sorted(xml_dir.glob("*.xml"))
		if not xml_files:
			raise FileNotFoundError(f"No XML files found in {xml_dir}")
		return xml_files
	if args.xml_file:
		return [args.xml_file]
	return [None]


def _xml_arg_for_env(xml_path):
	"""Return the ``xml_file`` value handed to ``evaluate`` (None for default)."""
	if xml_path is None:
		return None
	candidate = Path(xml_path).expanduser()
	# NOTE(review): Gymnasium's MuJoCo envs appear to resolve bare relative
	# names against their bundled assets directory; passing an absolute path
	# for files that exist locally avoids that. Confirm against the installed
	# Gymnasium version.
	if candidate.exists():
		return str(candidate.absolute())
	return str(xml_path)


def _ensure_checkpoint(args, ckpt_root, lambda_tag, seed):
	"""Return the checkpoint path for ``seed``, training the policy if it is missing."""
	ckpt_path = ckpt_root / f"dqn_policy_lambda_{lambda_tag}_seed_{seed}.pt"
	if ckpt_path.exists():
		return ckpt_path

	# Suffix the requested run name with the seed so per-seed runs stay distinct.
	run_name = f"{args.wandb_run_name}_seed_{seed}" if args.wandb_run_name else None
	train_args = SimpleNamespace(
		episodes=args.train_episodes,
		max_episode_steps=args.max_episode_steps,
		replay_size=args.replay_size,
		batch_size=args.batch_size,
		gamma=args.gamma,
		lr=args.lr,
		tau=args.tau,
		epsilon_start=args.epsilon_start,
		epsilon_end=args.epsilon_end,
		epsilon_decay_steps=args.epsilon_decay_steps,
		hidden=args.hidden,
		num_actions=args.num_actions,
		seed=seed,
		device=args.device,
		log_interval=args.log_interval,
		wandb=args.wandb,
		wandb_project=args.wandb_project,
		wandb_run_name=run_name,
		eval_episodes=0,
		record_video=False,
		video_dir="",
		save_dir=str(ckpt_root),
		lambda_val=args.lambda_val,
	)
	train_module.train(train_args)
	# Training is expected to have written the checkpoint at exactly this path.
	if not ckpt_path.exists():
		raise FileNotFoundError(f"Checkpoint not found for seed {seed}: {ckpt_path}")
	return ckpt_path


def _result_path(out_path, lambda_tag, xml_path):
	"""Derive the ``.npz`` output path, tagged with lambda and the XML stem."""
	out_path = Path(out_path)
	xml_tag = "" if xml_path is None else f"_{Path(xml_path).stem}"
	base_name = out_path.name
	# Drop a trailing .npy/.npz so the tags are inserted before the extension.
	for suffix in (".npy", ".npz"):
		if base_name.endswith(suffix):
			base_name = base_name[: -len(suffix)]
			break
	return out_path.with_name(f"{base_name}_lambda_{lambda_tag}{xml_tag}.npz")


def main():
	"""Train (if needed) one policy per seed and evaluate it over a perturbation grid.

	For every XML model (or the default model), every seed in
	``range(args.seeds)`` and every perturbation probability in the grid,
	``evaluate`` is run. The mean and standard deviation across seeds of the
	per-seed mean returns, the raw per-episode rewards and the grid itself are
	saved to one ``.npz`` file per XML model, and a summary is printed.

	Raises:
		FileNotFoundError: If ``--xml_dir`` contains no ``*.xml`` files, or if
			a checkpoint is still missing after training.
	"""
	args = _parse_args()

	# NOTE(review): int() truncates, so non-integer lambda values share a tag;
	# the checkpoint file name built from it presumably has to match what
	# train_module.train writes, so it is left unchanged.
	lambda_tag = int(args.lambda_val)
	# The small epsilon makes the grid include perturb_end despite float rounding.
	perturb_grid = np.arange(args.perturb_start, args.perturb_end + 1e-9, args.perturb_step)
	xml_list = _resolve_xml_files(args)
	ckpt_root = Path(args.save_dir) / f"episodes_{args.train_episodes}_actions_{args.num_actions}"

	for xml_path in xml_list:
		env_xml = _xml_arg_for_env(xml_path)
		per_seed_results = []
		per_seed_episode_rewards = []
		for seed in range(args.seeds):
			ckpt_path = _ensure_checkpoint(args, ckpt_root, lambda_tag, seed)
			q_net, action_set = load_policy(ckpt_path, device=args.device)
			seed_episode_rewards = [
				evaluate(
					q_net,
					action_set,
					device=args.device,
					episodes=args.episodes,
					max_episode_steps=args.max_episode_steps,
					seed=seed * 1000,
					perturb=float(p),
					xml_file=env_xml,
				)
				for p in perturb_grid
			]
			per_seed_results.append([np.mean(rewards) for rewards in seed_episode_rewards])
			per_seed_episode_rewards.append(seed_episode_rewards)

		per_seed_results = np.array(per_seed_results)
		mean_over_seeds = per_seed_results.mean(axis=0)
		std_over_seeds = per_seed_results.std(axis=0)
		per_seed_episode_rewards = np.array(per_seed_episode_rewards)  # shape (seeds, len(perturb_grid), episodes)

		out_path = _result_path(args.out_path, lambda_tag, xml_path)
		out_path.parent.mkdir(parents=True, exist_ok=True)
		# Save mean/std and per-episode rewards
		np.savez(
			out_path,
			mean=mean_over_seeds,
			std=std_over_seeds,
			per_seed_episode_rewards=per_seed_episode_rewards,
			perturb_grid=perturb_grid,
		)

		print("XML:", xml_path if xml_path else "default")
		print("Perturb grid:", perturb_grid)
		print("Mean returns:", mean_over_seeds)
		print("Std returns:", std_over_seeds)
		print(f"Saved results to {out_path}")


# Run the train-then-evaluate sweep only when executed as a script, not on import.
if __name__ == "__main__":
	main()
