import argparse
import random
import time
from pathlib import Path

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb


class ReplayBuffer:
	"""Fixed-capacity circular replay buffer with uniform random sampling.

	Transitions live in preallocated NumPy arrays. Once ``capacity`` transitions
	have been stored, each new one overwrites the oldest slot.
	"""

	def __init__(self, capacity, obs_dim):
		self.capacity = capacity
		obs_shape = (capacity, obs_dim)
		scalar_shape = (capacity,)
		self.obs_buf = np.zeros(obs_shape, dtype=np.float32)
		self.act_buf = np.zeros(scalar_shape, dtype=np.int64)
		self.rew_buf = np.zeros(scalar_shape, dtype=np.float32)
		self.next_obs_buf = np.zeros(obs_shape, dtype=np.float32)
		self.done_buf = np.zeros(scalar_shape, dtype=np.float32)
		# Next slot to write, and how many slots currently hold valid data.
		self.ptr = 0
		self.size = 0

	def _buffers(self):
		"""Return the storage arrays in (obs, act, rew, next_obs, done) order."""
		return (self.obs_buf, self.act_buf, self.rew_buf, self.next_obs_buf, self.done_buf)

	def push(self, obs, act, rew, next_obs, done):
		"""Store one transition at the write cursor, overwriting the oldest when full."""
		slot = self.ptr
		for storage, value in zip(self._buffers(), (obs, act, rew, next_obs, done)):
			storage[slot] = value
		self.ptr = (slot + 1) % self.capacity
		self.size = min(self.size + 1, self.capacity)

	def sample(self, batch_size):
		"""Draw ``batch_size`` stored transitions uniformly, with replacement.

		Returns:
			Tuple ``(obs, act, rew, next_obs, done)`` of NumPy arrays whose leading
			dimension is ``batch_size``.
		"""
		picks = np.random.choice(self.size, batch_size, replace=True)
		return tuple(storage[picks] for storage in self._buffers())


class QNetwork(nn.Module):
	"""MLP with two ReLU hidden layers that outputs one Q-value per discrete action."""

	def __init__(self, obs_dim, act_dim, hidden=128):
		super().__init__()
		# Hidden stack: obs_dim -> hidden -> hidden, each Linear followed by ReLU.
		widths = [obs_dim, hidden, hidden]
		layers = []
		for fan_in, fan_out in zip(widths[:-1], widths[1:]):
			layers.append(nn.Linear(fan_in, fan_out))
			layers.append(nn.ReLU())
		# Linear output head with no activation: raw Q-values.
		layers.append(nn.Linear(hidden, act_dim))
		self.net = nn.Sequential(*layers)

	def forward(self, x):
		"""Map observations ``x`` (last dim ``obs_dim``) to Q-values (last dim ``act_dim``)."""
		q_values = self.net(x)
		return q_values


def select_action(q_net, obs, action_set, epsilon):
	"""Epsilon-greedy choice of an index into ``action_set``.

	With probability ``epsilon`` a uniformly random index is returned. Otherwise
	the index of the largest Q-value predicted by ``q_net`` for ``obs`` is returned.
	The observation is placed on the same device as the network's parameters.
	"""
	explore = random.random() < epsilon
	if explore:
		return random.randrange(len(action_set))
	with torch.no_grad():
		device = next(q_net.parameters()).device
		batch = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
		best = torch.argmax(q_net(batch), dim=1)
		return int(best.item())


def _greedy_action(q_net, obs, device):
	"""Return the index of the highest-valued action for one observation.

	The observation is moved to ``device`` so evaluation also works when the
	network does not live on the CPU.
	"""
	with torch.no_grad():
		obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
		return int(torch.argmax(q_net(obs_t), dim=1).item())


def _run_greedy_episode(env, q_net, action_set, max_steps, device):
	"""Roll out one episode with the greedy policy; return its undiscounted return."""
	obs, _ = env.reset()
	total = 0.0
	for _ in range(max_steps):
		act_idx = _greedy_action(q_net, obs, device)
		action = np.array([action_set[act_idx]], dtype=np.float32)
		obs, reward, terminated, truncated, _ = env.step(action)
		total += reward
		if terminated or truncated:
			break
	return total


def _dqn_update(q_net, target_q, optimizer, replay, args):
	"""Run one clipped-DQN gradient step on a replay minibatch.

	Returns:
		(loss value, mean predicted Q-value of the taken actions) as Python floats.
	"""
	device = args.device
	obs_b, act_b, rew_b, next_obs_b, done_b = replay.sample(args.batch_size)
	obs_b = torch.as_tensor(obs_b, dtype=torch.float32, device=device)
	act_b = torch.as_tensor(act_b, dtype=torch.int64, device=device)
	rew_b = torch.as_tensor(rew_b, dtype=torch.float32, device=device)
	next_obs_b = torch.as_tensor(next_obs_b, dtype=torch.float32, device=device)
	done_b = torch.as_tensor(done_b, dtype=torch.float32, device=device)

	# Clipped DQN targets: the bootstrapped next-state value is capped at lambda_val.
	with torch.no_grad():
		next_q = target_q(next_obs_b).max(dim=1).values.clamp(max=args.lambda_val)
		target = rew_b + args.gamma * (1 - done_b) * next_q

	q_pred = q_net(obs_b).gather(1, act_b.unsqueeze(1)).squeeze(1)
	loss = nn.functional.mse_loss(q_pred, target)
	optimizer.zero_grad()
	loss.backward()
	optimizer.step()
	return loss.item(), q_pred.mean().item()


def _soft_update(target_net, source_net, tau):
	"""Polyak update in place: target <- (1 - tau) * target + tau * source."""
	with torch.no_grad():
		for target_param, param in zip(target_net.parameters(), source_net.parameters()):
			target_param.data.mul_(1 - tau)
			target_param.data.add_(tau * param.data)


def _init_wandb(args):
	"""Start a wandb run when ``args.wandb`` is set; return the run or ``None``."""
	if not args.wandb:
		return None
	# NOTE(review): ``wandb`` is imported unconditionally at module level, so this
	# branch only triggers if that import is made optional.
	if wandb is None:
		print("wandb requested but unavailable; continuing without logging.")
		return None
	return wandb.init(
		project=args.wandb_project,
		name=args.wandb_run_name,
		config=vars(args),
	)


def _record_greedy_video(args, q_net, action_set):
	"""Record one greedy rollout into ``args.video_dir`` and return its reward."""
	video_env = gym.make(
		"InvertedPendulum-v5", max_episode_steps=args.max_episode_steps, render_mode="rgb_array"
	)
	video_env = gym.wrappers.RecordVideo(video_env, video_folder=args.video_dir, name_prefix="dqn_eval")
	try:
		return _run_greedy_episode(video_env, q_net, action_set, args.max_episode_steps, args.device)
	finally:
		# Release the recorder/renderer even if the rollout raises.
		video_env.close()


def _save_checkpoint(args, q_net, action_set, obs_dim, act_dim):
	"""Save Q-network weights, the action set and the run config under ``args.save_dir``."""
	save_path = Path(args.save_dir)
	save_path.mkdir(parents=True, exist_ok=True)
	lambda_tag = int(args.lambda_val)
	ckpt_path = save_path / f"dqn_policy_lambda_{lambda_tag}_seed_{args.seed}.pt"
	# NOTE(review): ``action_set`` is a NumPy array; loaders on newer PyTorch may
	# need ``torch.load(..., weights_only=False)`` — confirm against the consumer.
	torch.save(
		{
			"q_state_dict": q_net.state_dict(),
			"action_set": action_set,
			"obs_dim": obs_dim,
			"act_dim": act_dim,
			"args": vars(args),
		},
		ckpt_path,
	)
	print(f"Saved policy checkpoint to {ckpt_path}")
	return ckpt_path


def train(args):
	"""Train a clipped-target DQN on InvertedPendulum-v5 with discretized actions.

	The continuous action range [-3, 3] is split into ``args.num_actions`` evenly
	spaced values. After training, the greedy policy is evaluated for
	``args.eval_episodes`` episodes and optionally recorded to video. The
	Q-network is then checkpointed to ``args.save_dir``.

	Args:
		args: parsed CLI namespace carrying the hyperparameters read below.

	Returns:
		dict with ``episode_rewards`` (one per training episode), ``total_steps``,
		``elapsed`` (seconds since training started) and ``eval_rewards``.
	"""
	env = gym.make("InvertedPendulum-v5", max_episode_steps=args.max_episode_steps)
	obs, _ = env.reset(seed=args.seed)
	obs_dim = obs.shape[0]

	# Discretize the continuous action space for simplicity
	action_set = np.linspace(-3.0, 3.0, args.num_actions, dtype=np.float32)
	act_dim = len(action_set)

	q_net = QNetwork(obs_dim, act_dim, hidden=args.hidden).to(args.device)
	target_q = QNetwork(obs_dim, act_dim, hidden=args.hidden).to(args.device)
	target_q.load_state_dict(q_net.state_dict())

	optimizer = optim.Adam(q_net.parameters(), lr=args.lr)
	replay = ReplayBuffer(args.replay_size, obs_dim)

	# Linear epsilon schedule: epsilon_start -> epsilon_end over epsilon_decay_steps env steps.
	epsilon = args.epsilon_start
	epsilon_decay = (args.epsilon_start - args.epsilon_end) / max(args.epsilon_decay_steps, 1)
	epsilon_min = args.epsilon_end

	episode_rewards = []
	total_steps = 0
	start_time = time.time()
	last_loss = None
	last_q = None

	run = _init_wandb(args)

	try:
		for ep in range(args.episodes):
			obs, _ = env.reset()
			ep_reward = 0.0
			for _ in range(args.max_episode_steps):
				act_idx = select_action(q_net, obs, action_set, epsilon)
				action = np.array([action_set[act_idx]], dtype=np.float32)
				next_obs, reward, terminated, truncated, _ = env.step(action)
				# Only true termination stops bootstrapping. A time-limit truncation
				# still has a valid successor state, so its target must bootstrap.
				replay.push(obs, act_idx, reward, next_obs, float(terminated))

				obs = next_obs
				ep_reward += reward
				total_steps += 1

				if epsilon > epsilon_min:
					epsilon = max(epsilon_min, epsilon - epsilon_decay)

				if replay.size >= args.batch_size:
					last_loss, last_q = _dqn_update(q_net, target_q, optimizer, replay, args)
					_soft_update(target_q, q_net, args.tau)

				if terminated or truncated:
					break

			episode_rewards.append(ep_reward)
			# Guard against log_interval <= 0, which would divide by zero in the modulo.
			if args.log_interval > 0 and (ep + 1) % args.log_interval == 0:
				elapsed = time.time() - start_time
				avg_last = np.mean(episode_rewards[-args.log_interval:])
				print(
					f"Episode {ep+1}/{args.episodes} | AvgReward(last {args.log_interval}): {avg_last:.2f} "
					f"| Epsilon: {epsilon:.3f} | Time: {elapsed:.1f}s"
				)
				if run is not None:
					run.log(
						{
							"episode": ep + 1,
							"avg_reward_last": avg_last,
							"ep_reward": ep_reward,
							"epsilon": epsilon,
							"total_steps": total_steps,
							"q_loss": last_loss,
							"q_value": last_q,
						}
					)

		# Evaluation: deterministic greedy policy over the same discretized actions.
		# Runs before env.close() so the environment is still open.
		eval_rewards = [
			_run_greedy_episode(env, q_net, action_set, args.max_episode_steps, args.device)
			for _ in range(args.eval_episodes)
		]
	finally:
		env.close()

	# Optional video recording of one greedy rollout
	if args.record_video:
		vid_rew = _record_greedy_video(args, q_net, action_set)
		# Log before run.finish() so the value reaches the still-open run.
		if run is not None:
			run.log({"video_reward": vid_rew})
		print(f"Recorded video to {args.video_dir}, reward: {vid_rew:.2f}")

	if run is not None:
		run.finish()

	_save_checkpoint(args, q_net, action_set, obs_dim, act_dim)

	return {
		"episode_rewards": episode_rewards,
		"total_steps": total_steps,
		"elapsed": time.time() - start_time,
		"eval_rewards": eval_rewards,
	}


def _build_arg_parser():
	"""Return the command-line parser for the clipped-DQN training script."""
	parser = argparse.ArgumentParser(description="Clipped DQN on InvertedPendulum-v5 (discretized actions).")
	add = parser.add_argument
	# Episode length and replay/optimization hyperparameters.
	add("--episodes", type=int, default=2000)
	add("--max_episode_steps", type=int, default=500)
	add("--replay_size", type=int, default=50_000)
	add("--batch_size", type=int, default=64)
	add("--gamma", type=float, default=0.99)
	add("--lr", type=float, default=3e-4)
	add("--tau", type=float, default=0.005)
	# Exploration schedule.
	add("--epsilon_start", type=float, default=1.0)
	add("--epsilon_end", type=float, default=0.05)
	add("--epsilon_decay_steps", type=int, default=20_000)
	# Model / action discretization.
	add("--lambda_val", type=float, default=500.0, help="Value clip for TD targets.")
	add("--hidden", type=int, default=128)
	add("--num_actions", type=int, default=11, help="Number of discrete actions across [-3, 3].")
	# Runtime and logging.
	add("--seed", type=int, default=0)
	add("--device", type=str, default="cpu")
	add("--log_interval", type=int, default=10)
	add("--wandb", action="store_true", help="Enable wandb logging.")
	add("--wandb_project", type=str, default="inverted-pendulum", help="wandb project name.")
	add("--wandb_run_name", type=str, default=None, help="Optional wandb run name.")
	# Post-training evaluation and artifacts.
	add("--eval_episodes", type=int, default=10, help="Number of eval episodes after training.")
	add("--record_video", action="store_true", help="Record a greedy rollout video after training.")
	add("--video_dir", type=str, default="videos", help="Directory to save recorded videos.")
	add("--save_dir", type=str, default="checkpoints", help="Directory to save trained policy.")
	return parser


def main():
	"""Parse CLI flags, seed the global RNGs, run training, and print a summary."""
	args = _build_arg_parser().parse_args()

	# Seed Python, NumPy and PyTorch global RNGs.
	for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
		seed_fn(args.seed)

	stats = train(args)
	print("Training finished.")
	recent_avg = np.mean(stats["episode_rewards"][-10:])
	print(
		f"Steps: {stats['total_steps']} | Elapsed: {stats['elapsed']:.1f}s | "
		f"Final avg reward (last 10): {recent_avg:.2f}"
	)
	eval_avg = np.mean(stats["eval_rewards"])
	print(f"Eval avg reward over {args.eval_episodes} eps: {eval_avg:.2f}")


# Script entry point: train only when this file is executed directly, not on import.
if __name__ == "__main__":
	main()
