import argparse
import collections
import random
import time

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from pathlib import Path


class ReplayBuffer:
	"""Fixed-capacity circular replay buffer of (obs, act, rew, next_obs, done).

	Once ``capacity`` transitions are stored, each new push overwrites the
	oldest slot. Sampling draws indices uniformly *with replacement* from the
	currently filled portion of the buffer.
	"""

	def __init__(self, capacity, obs_dim):
		self.capacity = capacity
		self.obs_buf = np.zeros((capacity, obs_dim), dtype=np.float32)
		self.act_buf = np.zeros(capacity, dtype=np.int64)
		self.rew_buf = np.zeros(capacity, dtype=np.float32)
		self.next_obs_buf = np.zeros((capacity, obs_dim), dtype=np.float32)
		self.done_buf = np.zeros(capacity, dtype=np.float32)
		# ptr: next slot to write; size: number of valid (filled) slots.
		self.ptr, self.size = 0, 0

	def _buffers(self):
		"""Storage arrays in the field order used by ``push`` and ``sample``."""
		return (self.obs_buf, self.act_buf, self.rew_buf, self.next_obs_buf, self.done_buf)

	def push(self, obs, act, rew, next_obs, done):
		"""Write one transition into the current slot and advance the cursor."""
		slot = self.ptr
		for storage, value in zip(self._buffers(), (obs, act, rew, next_obs, done)):
			storage[slot] = value
		self.ptr = (slot + 1) % self.capacity
		if self.size < self.capacity:
			self.size += 1

	def sample(self, batch_size):
		"""Return a tuple of ``batch_size`` uniformly sampled transitions (per-field arrays)."""
		idxs = np.random.choice(self.size, batch_size, replace=True)
		return tuple(storage[idxs] for storage in self._buffers())


class QNetwork(nn.Module):
	"""MLP that maps an observation to one Q-value per discrete action.

	Architecture: Linear(obs_dim, hidden) -> ReLU -> Linear(hidden, hidden)
	-> ReLU -> Linear(hidden, act_dim). The layers live in ``self.net`` so the
	state_dict keys are ``net.0.*``, ``net.2.*`` and ``net.4.*``.
	"""

	def __init__(self, obs_dim, act_dim, hidden=128):
		super().__init__()
		layers = [nn.Linear(obs_dim, hidden), nn.ReLU()]
		layers += [nn.Linear(hidden, hidden), nn.ReLU()]
		layers.append(nn.Linear(hidden, act_dim))
		self.net = nn.Sequential(*layers)

	def forward(self, x):
		"""Map observations with trailing dim ``obs_dim`` to Q-values with trailing dim ``act_dim``."""
		return self.net(x)


def select_action(q_net, obs, action_set, epsilon):
	"""Pick a discrete action index epsilon-greedily.

	With probability ``epsilon`` a uniformly random index into ``action_set``
	is returned; otherwise the index with the largest Q-value predicted by
	``q_net`` for the single observation ``obs``. Inference runs without
	gradient tracking on the device that holds ``q_net``'s parameters.
	"""
	explore = random.random() < epsilon
	if explore:
		return random.randrange(len(action_set))
	with torch.no_grad():
		net_device = next(q_net.parameters()).device
		batch = torch.as_tensor(obs, dtype=torch.float32, device=net_device)[None]
		scores = q_net(batch)
		return int(scores.argmax(dim=1).item())


def _greedy_action(q_net, obs, device):
	"""Return the argmax action index of ``q_net`` for one observation.

	The observation tensor is created on ``device`` so that evaluation also
	works when the networks were moved off the CPU.
	"""
	with torch.no_grad():
		obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
		return int(torch.argmax(q_net(obs_t), dim=1).item())


def _run_greedy_episode(env, q_net, action_set, max_steps, device):
	"""Roll out one episode with the greedy policy of ``q_net``; return the summed reward."""
	obs, _ = env.reset()
	total = 0.0
	for _ in range(max_steps):
		act_idx = _greedy_action(q_net, obs, device)
		action = np.array([action_set[act_idx]], dtype=np.float32)
		obs, reward, terminated, truncated, _ = env.step(action)
		total += reward
		if terminated or truncated:
			break
	return total


def _double_q_step(q_net, target_net, optimizer, batch, gamma, value_clip):
	"""Take one gradient step on ``q_net`` toward a double-Q TD target.

	The greedy next action is selected by the online ``q_net`` and evaluated
	by ``target_net``. The caller passes the *other* network's target copy.
	The bootstrapped value is clipped from above at ``value_clip`` before
	forming ``r + gamma * (1 - done) * Q'``.

	Args:
		batch: tuple of tensors (obs, act, rew, next_obs, done) as returned
			by ``ReplayBuffer.sample`` and moved to the networks' device.

	Returns:
		(loss, mean predicted Q) as Python floats, for logging.
	"""
	obs_b, act_b, rew_b, next_obs_b, done_b = batch
	with torch.no_grad():
		next_actions = torch.argmax(q_net(next_obs_b), dim=1, keepdim=True)
		next_q = target_net(next_obs_b).gather(1, next_actions).squeeze(1)
		next_q = torch.clamp(next_q, max=value_clip)
		target = rew_b + gamma * (1.0 - done_b) * next_q

	q_pred = q_net(obs_b).gather(1, act_b.unsqueeze(1)).squeeze(1)
	loss = nn.functional.mse_loss(q_pred, target)
	optimizer.zero_grad()
	loss.backward()
	optimizer.step()
	return loss.item(), q_pred.mean().item()


def _soft_update(target_net, online_net, tau):
	"""Polyak-average online weights into the target: theta' <- (1 - tau) * theta' + tau * theta."""
	with torch.no_grad():
		for target_param, param in zip(target_net.parameters(), online_net.parameters()):
			target_param.mul_(1.0 - tau).add_(param, alpha=tau)


def train(args):
	"""Train twin Q-networks with double Q-learning on InvertedPendulum-v5.

	The continuous action range [-3, 3] is discretized into ``args.num_actions``
	evenly spaced values. Training is epsilon-greedy w.r.t. ``q1``. Both
	networks are updated from the same minibatch, each bootstrapping through
	the other's target network, and the targets are Polyak-averaged.

	After training, the greedy ``q1`` policy is evaluated for
	``args.eval_episodes`` episodes. A video is optionally recorded, and a
	checkpoint is written to ``args.save_dir``.

	Args:
		args: argparse namespace produced by ``main`` (fields such as
			``episodes``, ``max_episode_steps``, ``batch_size``, ``gamma``,
			``lr``, ``tau``, ``epsilon_*``, ``lambda_val``, ``device``, ...).

	Returns:
		dict with ``episode_rewards`` (list of per-episode returns),
		``total_steps``, ``elapsed`` (seconds) and ``eval_rewards``.
	"""
	device = args.device
	env = gym.make("InvertedPendulum-v5", max_episode_steps=args.max_episode_steps)
	obs, _ = env.reset(seed=args.seed)
	obs_dim = obs.shape[0]

	# Discretize the continuous action space for simplicity
	action_set = np.linspace(-3.0, 3.0, args.num_actions, dtype=np.float32)
	act_dim = len(action_set)

	q1 = QNetwork(obs_dim, act_dim, hidden=args.hidden).to(device)
	q2 = QNetwork(obs_dim, act_dim, hidden=args.hidden).to(device)
	target_q1 = QNetwork(obs_dim, act_dim, hidden=args.hidden).to(device)
	target_q2 = QNetwork(obs_dim, act_dim, hidden=args.hidden).to(device)
	target_q1.load_state_dict(q1.state_dict())
	target_q2.load_state_dict(q2.state_dict())

	optimizer1 = optim.Adam(q1.parameters(), lr=args.lr)
	optimizer2 = optim.Adam(q2.parameters(), lr=args.lr)
	replay = ReplayBuffer(args.replay_size, obs_dim)

	epsilon = args.epsilon_start
	epsilon_decay = (args.epsilon_start - args.epsilon_end) / max(args.epsilon_decay_steps, 1)
	epsilon_min = args.epsilon_end

	episode_rewards = []
	total_steps = 0
	start_time = time.time()
	last_loss1 = last_loss2 = last_q1 = last_q2 = None

	run = None
	if args.wandb:
		# NOTE(review): `wandb` is imported unconditionally at the top of the file,
		# so this None check can only trigger if that import is made optional.
		if wandb is None:
			print("wandb requested but unavailable; continuing without logging.")
		else:
			run = wandb.init(
				project=args.wandb_project,
				name=args.wandb_run_name,
				config=vars(args),
			)

	for ep in range(args.episodes):
		obs, _ = env.reset()
		ep_reward = 0.0
		for _ in range(args.max_episode_steps):
			act_idx = select_action(q1, obs, action_set, epsilon)
			action = np.array([action_set[act_idx]], dtype=np.float32)
			next_obs, reward, terminated, truncated, _ = env.step(action)
			# Only a true termination ends bootstrapping. A time-limit truncation is
			# not a terminal state, so its TD target must still use next_obs.
			replay.push(obs, act_idx, reward, next_obs, float(terminated))

			obs = next_obs
			ep_reward += reward
			total_steps += 1

			# Linear epsilon decay, floored at epsilon_min
			if epsilon > epsilon_min:
				epsilon = max(epsilon_min, epsilon - epsilon_decay)

			if replay.size >= args.batch_size:
				# Buffer dtypes (float32 / int64) carry over to the tensors.
				batch = tuple(
					torch.as_tensor(arr, device=device) for arr in replay.sample(args.batch_size)
				)
				# Each online net selects next actions; the other net's target evaluates them.
				last_loss1, last_q1 = _double_q_step(
					q1, target_q2, optimizer1, batch, args.gamma, args.lambda_val
				)
				last_loss2, last_q2 = _double_q_step(
					q2, target_q1, optimizer2, batch, args.gamma, args.lambda_val
				)
				_soft_update(target_q1, q1, args.tau)
				_soft_update(target_q2, q2, args.tau)

			if terminated or truncated:
				break

		episode_rewards.append(ep_reward)
		if (ep + 1) % args.log_interval == 0:
			elapsed = time.time() - start_time
			avg_last = np.mean(episode_rewards[-args.log_interval:])
			print(f"Episode {ep+1}/{args.episodes} | AvgReward(last {args.log_interval}): {avg_last:.2f} | Epsilon: {epsilon:.3f} | Time: {elapsed:.1f}s")
			if run:
				run.log(
					{
						"episode": ep + 1,
						"avg_reward_last": avg_last,
						"ep_reward": ep_reward,
						"epsilon": epsilon,
						"total_steps": total_steps,
						"q1_loss": last_loss1,
						"q2_loss": last_loss2,
						"q1_value": last_q1,
						"q2_value": last_q2,
					}
				)

	# Evaluation: deterministic greedy policy over the same discretized actions.
	# The training env is closed only after it is no longer used.
	eval_rewards = [
		_run_greedy_episode(env, q1, action_set, args.max_episode_steps, device)
		for _ in range(args.eval_episodes)
	]
	env.close()

	# Optional video recording of one greedy rollout
	if args.record_video:
		video_env = gym.make("InvertedPendulum-v5", max_episode_steps=args.max_episode_steps, render_mode="rgb_array")
		video_env = gym.wrappers.RecordVideo(video_env, video_folder=args.video_dir, name_prefix="doubleq_eval")
		try:
			vid_rew = _run_greedy_episode(video_env, q1, action_set, args.max_episode_steps, device)
		finally:
			video_env.close()
		if run:
			run.log({"video_reward": vid_rew})
		print(f"Recorded video to {args.video_dir}, reward: {vid_rew:.2f}")

	# Save trained policy and action set
	save_path = Path(args.save_dir)
	save_path.mkdir(parents=True, exist_ok=True)
	lambda_tag = int(args.lambda_val)
	ckpt_name = f"double_q_policy_lambda_{lambda_tag}_seed_{args.seed}.pt"
	torch.save(
		{
			"q1_state_dict": q1.state_dict(),
			"q2_state_dict": q2.state_dict(),
			"action_set": action_set,
			"obs_dim": obs_dim,
			"act_dim": act_dim,
			"args": vars(args),
		},
		save_path / ckpt_name,
	)
	print(f"Saved policy checkpoint to {save_path/ckpt_name}")

	# Finish the wandb run only after every metric (including video_reward) is logged.
	if run:
		run.finish()

	return {
		"episode_rewards": episode_rewards,
		"total_steps": total_steps,
		"elapsed": time.time() - start_time,
		"eval_rewards": eval_rewards,
	}


def _build_parser():
	"""Return the command-line parser for training/evaluation options (declaration order preserved)."""
	parser = argparse.ArgumentParser(
		description="Double Q-learning on InvertedPendulum-v5 (discretized actions).",
	)
	options = [
		("--episodes", dict(type=int, default=2000)),
		("--max_episode_steps", dict(type=int, default=500)),
		("--replay_size", dict(type=int, default=50_000)),
		("--batch_size", dict(type=int, default=64)),
		("--gamma", dict(type=float, default=0.99)),
		("--lr", dict(type=float, default=3e-4)),
		("--tau", dict(type=float, default=0.005)),
		("--epsilon_start", dict(type=float, default=1.0)),
		("--epsilon_end", dict(type=float, default=0.05)),
		("--epsilon_decay_steps", dict(type=int, default=20_000)),
		("--lambda_val", dict(type=float, default=500.0, help="Value clip for TD targets.")),
		("--hidden", dict(type=int, default=128)),
		("--num_actions", dict(type=int, default=11, help="Number of discrete actions across [-3, 3].")),
		("--seed", dict(type=int, default=0)),
		("--device", dict(type=str, default="cpu")),
		("--log_interval", dict(type=int, default=10)),
		("--wandb", dict(action="store_true", help="Enable wandb logging.")),
		("--wandb_project", dict(type=str, default="inverted-pendulum", help="wandb project name.")),
		("--wandb_run_name", dict(type=str, default=None, help="Optional wandb run name.")),
		("--eval_episodes", dict(type=int, default=10, help="Number of eval episodes after training.")),
		("--record_video", dict(action="store_true", help="Record a greedy rollout video after training.")),
		("--video_dir", dict(type=str, default="videos", help="Directory to save recorded videos.")),
		("--save_dir", dict(type=str, default="checkpoints", help="Directory to save trained policy.")),
	]
	for flag, spec in options:
		parser.add_argument(flag, **spec)
	return parser


def main():
	"""Parse CLI flags, seed the Python/NumPy/torch RNGs, train, and print a summary."""
	args = _build_parser().parse_args()

	for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
		seed_fn(args.seed)

	stats = train(args)
	print("Training finished.")
	final_avg = np.mean(stats['episode_rewards'][-10:])
	print(f"Steps: {stats['total_steps']} | Elapsed: {stats['elapsed']:.1f}s | Final avg reward (last 10): {final_avg:.2f}")
	eval_avg = np.mean(stats['eval_rewards'])
	print(f"Eval avg reward over {args.eval_episodes} eps: {eval_avg:.2f}")


if __name__ == "__main__":
	main()
