"""
DQN implementation with PyTorch.

Adapted from CleanRL:
https://github.com/vwxyzjn/cleanrl
"""

import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from agents.base import AbstractAgent


def linear_schedule(start_e: float, end_e: float, duration: int, t: int) -> float:
    """Linearly interpolate from ``start_e`` to ``end_e`` over ``duration`` steps.

    Returns ``start_e`` at ``t <= 0``, ``end_e`` for any ``t >= duration`` and
    the straight-line value in between. Works for both decaying
    (``start_e > end_e``) and increasing (``start_e < end_e``) schedules.

    Args:
        start_e: Value at the start of the schedule.
        end_e: Value held once the schedule has finished.
        duration: Number of steps the transition takes. A non-positive
            duration means "already finished" and returns ``end_e``.
        t: Current step.

    Returns:
        The scheduled value at step ``t``.
    """
    if duration <= 0:
        # Degenerate schedule: jump straight to the final value instead of
        # dividing by zero.
        return end_e
    # Clamp progress so the value never leaves the [start_e, end_e] segment,
    # regardless of which end is larger.
    fraction = min(max(t / duration, 0.0), 1.0)
    return start_e + (end_e - start_e) * fraction


class QNetwork(nn.Module):
    """MLP mapping an observation vector to one Q-value per action.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear, with two hidden layers
    of ``hidden_width`` units each.
    """

    def __init__(self, state_dim: int, n_actions: int, hidden_width: int = 256):
        super().__init__()
        # The input width is the product of state_dim, so a shape tuple such as
        # (2, 3) yields a 6-feature input layer. forward() does not flatten, so
        # callers must pass observations already flattened to that width.
        in_features = np.prod(state_dim)
        layers = [
            nn.Linear(in_features, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_actions),
        ]
        # Sequential indexes the layers 0..4, which become the state_dict key
        # prefixes ("network.0.weight", ...).
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values for ``x``; the last dimension has size ``n_actions``."""
        q_values = self.network(x)
        return q_values


class DQN(AbstractAgent):
    """Deep Q-Network agent with an online and a target Q-network.

    The online network is trained with Adam on a mean-squared TD error. The
    target network supplies the bootstrap values and is blended towards the
    online network every ``target_network_update_freq`` calls to
    :meth:`update` (``tau=1.0`` makes that a hard copy). The learning rate can
    optionally be annealed linearly via :meth:`anneal_lr`.
    """

    def __init__(
        self,
        state_dim: int,
        n_actions: int,
        hidden_width: int = 256,
        learning_rate: float = 1e-3,
        final_learning_rate: float = 1e-4,
        batch_size: int = 128,
        gamma: float = 0.99,
        update_epochs: int = 4,
        target_network_update_freq: int = 10,
        tau: float = 1.0,
        use_anneal_lr: bool = True,
        device: torch.device = torch.device("cpu"),
    ):
        """
        Args:
            state_dim: Observation size; passed to QNetwork and used to
                reshape batches in update().
            n_actions: Number of discrete actions (Q-network output width).
            hidden_width: Width of the Q-network hidden layers.
            learning_rate: Initial Adam learning rate.
            final_learning_rate: Learning rate anneal_lr() reaches at the end
                of training.
            batch_size: Stored on the agent. NOTE(review): update() does not
                read it; it trains on everything buffer.get_data() returns.
            gamma: Discount factor for the TD target.
            update_epochs: Gradient steps taken per update() call.
            target_network_update_freq: update() refreshes the target network
                when its ``step`` argument is a multiple of this value.
            tau: Blend factor for target updates (1.0 = hard copy).
            use_anneal_lr: Whether anneal_lr() changes the learning rate.
            device: Device the networks are placed on; act() moves
                observations there.
        """
        self.state_dim = state_dim
        self.n_actions = n_actions
        self.batch_size = batch_size
        self.update_epochs = update_epochs
        self.learning_rate = learning_rate
        self.final_learning_rate = final_learning_rate
        self.gamma = gamma
        self.tau = tau
        self.use_anneal_lr = use_anneal_lr
        self.device = device
        self.target_network_update_freq = target_network_update_freq

        self.q_network = QNetwork(state_dim, n_actions, hidden_width).to(device)
        self.target_network = QNetwork(state_dim, n_actions, hidden_width).to(device)
        # Start both networks from identical weights.
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)

    def act(self, state: np.ndarray) -> int:
        """Return the greedy action (argmax Q-value) for a single observation.

        No exploration is performed here. The forward pass runs under
        ``torch.no_grad()`` because action selection needs no gradients.
        """
        state_t = torch.Tensor(state).to(self.device)
        with torch.no_grad():
            q_values = self.q_network(state_t)
        # argmax over the flattened tensor: correct for one observation of
        # shape (state_dim,) or (1, state_dim).
        return torch.argmax(q_values).item()

    def anneal_lr(self, current_step: int, total_steps: int) -> None:
        """Linearly decay the learning rate towards ``final_learning_rate``.

        Does nothing when ``use_anneal_lr`` is False. Progress is clamped to
        [0, 1], so the rate never overshoots ``final_learning_rate``; a
        non-positive ``total_steps`` jumps straight to the final rate instead
        of dividing by zero.
        """
        if not self.use_anneal_lr:
            return
        if total_steps <= 0:
            fraction = 1.0
        else:
            fraction = min(max(current_step / total_steps, 0.0), 1.0)
        lr = self.learning_rate - (self.learning_rate - self.final_learning_rate) * fraction
        self.optimizer.param_groups[0]["lr"] = lr

    def update(self, buffer, step) -> dict:
        """Run ``update_epochs`` gradient steps on the data held by ``buffer``.

        Args:
            buffer: Object whose ``get_data()`` returns
                ``(obs, next_obs, actions, rewards, dones)`` tensors.
                NOTE(review): presumably already on ``self.device``; this method
                does not move them. ``dones`` is assumed numeric (0/1), since
                ``1 - dones`` raises for bool tensors in PyTorch.
            step: Caller-supplied counter; the target network is refreshed
                when ``step % target_network_update_freq == 0``.

        Returns:
            Metrics dict with the epoch-averaged ``"losses/td_loss"`` and
            ``"losses/q_values"`` plus the current ``"charts/learning_rate"``.
        """
        metrics = defaultdict(float)

        obs, next_obs, actions, rewards, dones = buffer.get_data()

        obs = obs.reshape(-1, self.state_dim)
        next_obs = next_obs.reshape(-1, self.state_dim)
        actions = actions.reshape(-1, 1).long()
        rewards = rewards.reshape(-1)
        dones = dones.reshape(-1)

        # The target network is not touched by the optimizer, so the TD target
        # is identical for every epoch below: compute it once.
        with torch.no_grad():
            target_max, _ = self.target_network(next_obs).max(dim=1)
            td_target = rewards + self.gamma * target_max * (1 - dones)

        for _ in range(self.update_epochs):
            # squeeze(1) rather than squeeze(): keeps shape (N,) even when
            # N == 1, so it always lines up with td_target.
            old_val = self.q_network(obs).gather(1, actions).squeeze(1)
            loss = F.mse_loss(old_val, td_target)

            metrics["losses/td_loss"] += loss.item()
            metrics["losses/q_values"] += old_val.mean().item()

            # optimize the model
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Soft/hard target update: target <- tau * online + (1 - tau) * target.
        if step % self.target_network_update_freq == 0:
            for target_network_param, q_network_param in zip(
                self.target_network.parameters(), self.q_network.parameters()
            ):
                target_network_param.data.copy_(
                    self.tau * q_network_param.data
                    + (1.0 - self.tau) * target_network_param.data
                )

        # Report per-epoch averages rather than sums.
        for key in metrics:
            metrics[key] /= self.update_epochs

        metrics["charts/learning_rate"] = self.optimizer.param_groups[0]["lr"]

        return metrics

    def save(self, path: str) -> None:
        """Write both networks' state_dicts into directory ``path`` (created if missing)."""
        os.makedirs(path, exist_ok=True)
        torch.save(self.q_network.state_dict(), os.path.join(path, "q_network.pth"))
        torch.save(
            self.target_network.state_dict(), os.path.join(path, "target_network.pth")
        )

    def load(self, path: str) -> None:
        """Load both networks' state_dicts from directory ``path``.

        Tensors are mapped onto ``self.device``, so checkpoints saved on a GPU
        load on a CPU-only machine. NOTE: torch.load unpickles the file; only
        load checkpoints from trusted sources.
        """
        self.q_network.load_state_dict(
            torch.load(os.path.join(path, "q_network.pth"), map_location=self.device)
        )
        self.target_network.load_state_dict(
            torch.load(
                os.path.join(path, "target_network.pth"), map_location=self.device
            )
        )
