"""
DQN network and agent implementations.
"""

from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from .constants import EPSILON_ACTIONS
from .core import _default_device
from .replay import ReplayBuffer


class DQNNet(nn.Module):
    """
    Two-hidden-layer MLP producing one Q-value per discrete action.

    Layout: Linear(state_dim, h1) -> ReLU -> Linear(h1, h2) -> ReLU
    -> Linear(h2, num_actions). The output layer has no activation.

    Args:
        state_dim: size of the last input dimension.
        num_actions: number of Q-value outputs.
        hidden_sizes: exactly two hidden widths ``(h1, h2)``; any other
            length raises ``ValueError`` on unpacking.
    """

    def __init__(self, state_dim, num_actions, hidden_sizes=(256, 128)):
        super().__init__()
        width_a, width_b = hidden_sizes
        # Creation order fixes both the state_dict keys and the order in
        # which torch's RNG is consumed for weight initialisation.
        self.fc1 = nn.Linear(state_dim, width_a)
        self.fc2 = nn.Linear(width_a, width_b)
        self.out = nn.Linear(width_b, num_actions)

    def forward(self, x):
        """Map ``x`` (last dim = state_dim) to Q-values (last dim = num_actions)."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.out(x)


class DQNEpsilonAgent:
    """
    PyTorch DQN agent selecting among discrete actions mapped to EPSILON_ACTIONS.

    Action indices returned by :meth:`select_action` and
    :meth:`select_action_batch` index into ``self.epsilon_actions``.

    Three networks are kept:

    * ``policy_net`` -- optimised by :meth:`train_step` on ``self.device``.
    * ``target_net`` -- supplies bootstrap targets. It is refreshed by a hard
      copy every ``target_update_interval`` training steps, or, when ``tau``
      is not None, by Polyak averaging after every training step.
    * ``actor_net`` -- used for action selection with CPU tensors. On a CPU
      device it is the same object as ``policy_net``. Otherwise it is a CPU
      copy refreshed every ``actor_update_interval`` training steps, so acting
      can lag the latest trained weights by up to that many steps.
    """

    def __init__(
        self,
        state_dim,
        epsilon_actions=EPSILON_ACTIONS,
        lr=3e-4,
        gamma=1.0,
        buffer_capacity=50000,
        batch_size_default=64,
        target_update_interval=1000,
        tau=None,
        grad_clip=10.0,
        hidden_sizes=(256, 128),
        device=None,
        seed=0,
        actor_update_interval=100,
        double_dqn=True,
    ):
        """
        Args:
            state_dim: length of a state vector (input size of the Q-network).
            epsilon_actions: action values; stored as a float32 array, and its
                length is the number of discrete actions.
            lr: AdamW learning rate (AMSGrad variant, no weight decay).
            gamma: discount factor used in the TD target (1.0 = undiscounted).
            buffer_capacity: capacity passed to ``ReplayBuffer``.
            batch_size_default: batch size used when ``train_step`` gets None.
            target_update_interval: hard target-copy period in training steps.
                Only used when ``tau`` is None, and must then be non-zero.
            tau: soft-update coefficient; when not None, every training step
                does ``target <- (1 - tau) * target + tau * policy``.
            grad_clip: max global gradient norm, or None to disable clipping.
            hidden_sizes: two hidden-layer widths for ``DQNNet``.
            device: torch device (``torch.device``, string or index) for the
                policy/target nets and the replay buffer. None means
                ``_default_device()``.
            seed: seeds the numpy Generator used for exploration and torch's
                global RNG (a process-wide side effect).
            actor_update_interval: training steps between CPU actor syncs.
                Only used on non-CPU devices, where it must be non-zero.
            double_dqn: use Double-DQN targets (policy net picks the next
                action, target net evaluates it).
        """
        self.state_dim = int(state_dim)
        self.epsilon_actions = np.asarray(epsilon_actions, dtype=np.float32)
        self.num_actions = len(self.epsilon_actions)

        self.lr = float(lr)
        self.gamma = float(gamma)
        self.batch_size_default = int(batch_size_default)

        self.target_update_interval = int(target_update_interval)
        self.tau = tau
        self.grad_clip = grad_clip
        self.hidden_sizes = tuple(hidden_sizes)
        self.double_dqn = bool(double_dqn)

        # Normalise to torch.device. ``self.device.type`` is read below and
        # would raise AttributeError for a plain string such as "cuda:0".
        self.device = torch.device(device if device is not None else _default_device())

        # Numpy generator that drives every epsilon-greedy exploration draw.
        self.rng = np.random.default_rng(seed)

        # Global torch seeding so network initialisation is reproducible.
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        self.policy_net = DQNNet(self.state_dim, self.num_actions, hidden_sizes=self.hidden_sizes).to(self.device)
        self.target_net = DQNNet(self.state_dim, self.num_actions, hidden_sizes=self.hidden_sizes).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.actor_update_interval = int(actor_update_interval)
        if self.device.type == "cpu":
            # Training already runs on CPU: act directly with the live weights.
            self.actor_net = self.policy_net
        else:
            # Action selection feeds CPU tensors built with torch.from_numpy,
            # so it keeps its own CPU replica, refreshed by sync_actor().
            self.actor_net = DQNNet(self.state_dim, self.num_actions, hidden_sizes=self.hidden_sizes).to("cpu")
            self.actor_net.eval()
            self.sync_actor()

        self.optimizer = optim.AdamW(
            self.policy_net.parameters(),
            lr=self.lr,
            weight_decay=0.0,
            amsgrad=True,
        )

        self.criterion = nn.SmoothL1Loss()
        self.buffer = ReplayBuffer(buffer_capacity, self.state_dim, device=self.device)
        self.train_steps = 0

    def sync_actor(self):
        """Copy policy_net weights (GPU) -> actor_net (CPU); no-op when they are the same net."""
        if self.actor_net is self.policy_net:
            return
        cpu_state = {k: v.detach().cpu() for k, v in self.policy_net.state_dict().items()}
        self.actor_net.load_state_dict(cpu_state)
        self.actor_net.eval()

    def q_values(self, s_np):
        """For debugging: returns Q(s,·) from policy_net as a numpy array (num_actions,)."""
        with torch.no_grad():
            s = torch.tensor(s_np, dtype=torch.float32, device=self.device).unsqueeze(0)
            q = self.policy_net(s).squeeze(0)
        return q.detach().cpu().numpy()

    def _greedy_actions(self, S):
        """
        Return argmax-Q action indices from ``actor_net`` for a float32 batch.

        ``S`` is a float32 numpy array with a leading batch axis. It is wrapped
        with ``torch.from_numpy`` (no copy), i.e. it is a CPU tensor. Returns an
        int64 numpy array with one action index per row. Touches no RNG.
        """
        with torch.inference_mode():
            q = self.actor_net(torch.from_numpy(S))
            return q.argmax(dim=1).cpu().numpy().astype(np.int64)

    def select_action(self, s, explore_eps):
        """
        ε-greedy action for a single state. If explore_eps<=0 no RNG is touched.

        Args:
            s: one state, array-like convertible to float32.
            explore_eps: exploration probability. ``<= 0`` is purely greedy,
                ``>= 1`` purely random (no uniform draw is consumed).

        Returns:
            int action index into ``self.epsilon_actions``.
        """
        eps = float(explore_eps)

        if eps > 0.0:
            # ``or`` short-circuits, so eps >= 1 skips the uniform draw.
            if eps >= 1.0 or self.rng.random() < eps:
                return int(self.rng.integers(0, self.num_actions))

        s_np = np.asarray(s, dtype=np.float32)
        return int(self._greedy_actions(np.expand_dims(s_np, 0))[0])

    def select_action_batch(self, S, explore_eps):
        """
        Batch ε-greedy over actions.

        Args:
            S: batch of states; converted to a float32 array whose first axis
                is the batch size ``B``.
            explore_eps: scalar (Python number, numpy scalar, 0-d array or
                0-d tensor) or per-row array of shape ``(B,)``.

        Returns:
            int64 numpy array of shape ``(B,)`` with action indices.

        Raises:
            ValueError: if a non-scalar ``explore_eps`` has size other than B.
        """
        S = np.asarray(S, dtype=np.float32)
        B = int(S.shape[0])

        # np.ndim (unlike np.isscalar) treats 0-d arrays/tensors as scalars.
        if np.ndim(explore_eps) == 0:
            eps = float(explore_eps)
            all_greedy = eps <= 0.0
            all_random = eps >= 1.0
        else:
            eps = np.asarray(explore_eps, dtype=np.float32).reshape(-1)
            if eps.shape[0] != B:
                raise ValueError("explore_eps must be scalar or shape (B,)")
            all_greedy = bool(np.all(eps <= 0.0))
            all_random = bool(np.all(eps >= 1.0))

        # Fast paths: a single forward pass or a single random draw.
        if all_greedy:
            return self._greedy_actions(S)
        if all_random:
            return self.rng.integers(0, self.num_actions, size=B).astype(np.int64)

        explore_mask = self.rng.random(B) < eps
        A = np.empty(B, dtype=np.int64)

        n_rand = int(explore_mask.sum())
        if n_rand > 0:
            A[explore_mask] = self.rng.integers(0, self.num_actions, size=n_rand)

        greedy_mask = ~explore_mask
        if greedy_mask.any():
            # Only rows that did not explore need a forward pass.
            A[greedy_mask] = self._greedy_actions(S[greedy_mask])

        return A

    def _update_target(self):
        """Soft (tau) or periodic hard update of target_net from policy_net."""
        if self.tau is not None:
            tau = float(self.tau)
            # Only parameters are blended; DQNNet registers no buffers.
            with torch.no_grad():
                for p, p_t in zip(self.policy_net.parameters(), self.target_net.parameters()):
                    p_t.mul_(1.0 - tau).add_(tau * p)
        elif self.train_steps % self.target_update_interval == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

    def train_step(self, batch_size=None):
        """
        One DQN optimization step. Returns loss tensor or None if not enough data.

        The TD target is ``R + gamma * Q_next * (1 - done)``, where ``Q_next``
        comes from target_net evaluating either its own argmax (vanilla DQN) or
        policy_net's argmax (Double DQN). The loss is SmoothL1, with optional
        gradient-norm clipping.

        Args:
            batch_size: transitions to sample; None uses ``batch_size_default``.

        Returns:
            Detached scalar loss tensor, or None when the buffer holds fewer
            than ``batch_size`` transitions.
        """
        if batch_size is None:
            batch_size = self.batch_size_default
        batch_size = int(batch_size)

        if self.buffer.size < batch_size:
            return None

        # NOTE(review): assumes ReplayBuffer.sample returns tensors on
        # self.device with A as int64 indices and D as a bool tensor (``~D`` is
        # used below) -- confirm against replay.py.
        S, A, R, S_next, D = self.buffer.sample(batch_size)

        q_all = self.policy_net(S)
        q_sa = q_all.gather(1, A.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            not_done = (~D).float()
            if self.double_dqn:
                a_next = self.policy_net(S_next).argmax(dim=1, keepdim=True)
                q_next = self.target_net(S_next).gather(1, a_next).squeeze(1)
            else:
                q_next = self.target_net(S_next).max(dim=1).values
            target = R + self.gamma * q_next * not_done

        loss = self.criterion(q_sa, target)

        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        if self.grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), float(self.grad_clip))
        self.optimizer.step()

        self.train_steps += 1
        self._update_target()

        if (self.actor_net is not self.policy_net) and (self.train_steps % self.actor_update_interval == 0):
            self.sync_actor()

        return loss.detach()


# Public API: names exported by ``from <this module> import *``.
__all__ = ["DQNNet", "DQNEpsilonAgent"]
