"""TD3 implementation."""

import copy
import logging
import pickle as pkl
from pathlib import Path
from typing import List, Optional

import cloudpickle
import hydra
import numpy as np
from omegaconf import DictConfig
import torch
import torch.nn.functional as F

from src.qd.td3_utils import Experience, MLPCritic, ReplayBuffer


# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class TD3:
    """TD3 implementation modified for DQD-RL.

    Trains a greedy actor and a twin critic for the objective, and exposes
    ``gradient_ascent`` so that externally supplied solutions can be improved
    with the learned critic.

    Adapted from TD3 code by Scott Fujimoto:
    https://github.com/sfujim/TD3/blob/master/TD3.py
    """

    def __init__(
        self,
        config: DictConfig,
        agent_cfg: DictConfig,
        seed: Optional[int] = None,
    ):
        """Builds the replay buffer, actor, critic, their targets and optimizers.

        Args:
            config: TD3 hyperparameters. This constructor reads
                ``buffer_size``, ``state_dim``, ``action_dim`` and
                ``adam_learning_rate``; the training methods read more keys.
            agent_cfg: Hydra config that is instantiated to build the actor.
            seed: Seed for the numpy RNG (used for target smoothing noise);
                also passed to the replay buffer. ``None`` means unseeded.
        """
        self.config = config
        self.agent_cfg = agent_cfg
        # Technically, the actions could have different bounds for different
        # dims, but this works most of the time.
        self.max_action = 1.0

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"TD3 Device: {self.device}")

        #
        # Attributes below need to be saved in save(), and they get replaced by
        # load().
        #

        self.rng = np.random.default_rng(seed)
        self.buffer = ReplayBuffer(
            config.buffer_size,
            config.state_dim,
            config.action_dim,
            seed,
        )

        # Total training iters so far.
        self.total_it = 0

        # actor/critic train to maximize the reward
        self.actor = hydra.utils.instantiate(self.agent_cfg).to(self.device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_opt = torch.optim.Adam(
            self.actor.parameters(), lr=self.config.adam_learning_rate
        )
        self.critic = MLPCritic(
            config.state_dim,
            config.action_dim,
        ).to(self.device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_opt = torch.optim.Adam(
            self.critic.parameters(), lr=self.config.adam_learning_rate
        )

        logger.info("Created TD3 actor and critics")

    def get_actor(self) -> np.ndarray:
        """Returns the TD3 actor that optimizes for performance.

        The actor is returned in the form produced by its ``to_numpy()``
        method.
        """
        return self.actor.to_numpy()

    @staticmethod
    def _soft_update(
        source: torch.nn.Module, target: torch.nn.Module, tau: float
    ) -> None:
        """Moves ``target`` params toward ``source`` params in place (Polyak averaging).

        Each target parameter becomes ``tau * source + (1 - tau) * target``.
        """
        for param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def train_td3(self):
        """Trains the actor and critics with TD3.

        Runs ``config.train_critics_itrs`` iterations. Each iteration samples
        ``config.batch_size`` transitions and takes one critic step. Every
        ``config.target_update_freq`` iterations (counted globally via
        ``self.total_it``), it also takes one actor step and soft-updates both
        target networks with rate ``config.target_update_rate``.
        """
        logger.info(f"Training critics for {self.config.train_critics_itrs} itrs")
        logger.info(f"Replay buffer size: {len(self.buffer)}")

        for _ in range(self.config.train_critics_itrs):
            self.total_it += 1

            # Sample replay buffer.
            batch = self.buffer.sample_tensors(self.config.batch_size)

            with torch.no_grad():
                # Select action according to policy and add clipped noise.
                # We use numpy's rng since it is harder to reproduce PyTorch
                # randomness. NOTE: despite its name, smoothing_noise_variance
                # scales unit-normal samples, i.e. it acts as the noise std.
                noise = (
                    self.rng.standard_normal(size=batch.action.shape, dtype=np.float32)
                    * self.config.smoothing_noise_variance
                )
                noise = (
                    torch.from_numpy(noise)
                    .clamp(
                        -self.config.smoothing_noise_clip,
                        self.config.smoothing_noise_clip,
                    )
                    .to(self.device)
                )

                next_action = (self.actor_target(batch.next_obs) + noise).clamp(
                    -self.max_action, self.max_action
                )

                # Compute the target Q value. Taking the min of the twin
                # target critics is TD3's clipped double-Q estimate.
                target_q1, target_q2 = self.critic_target(batch.next_obs, next_action)
                target_q = torch.min(target_q1, target_q2)

                # Terminal transitions (done == 1) do not bootstrap.
                target_q = (
                    batch.reward[:, None]
                    + (1.0 - batch.done[:, None]) * self.config.discount * target_q
                )

            # Get current Q estimates.
            current_q1, current_q2 = self.critic(batch.obs, batch.action)

            # Compute critic loss.
            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(
                current_q2, target_q
            )

            # Optimize the critic.
            self.critic_opt.zero_grad()
            critic_loss.backward()
            self.critic_opt.step()

            # Delayed policy updates.
            if self.total_it % self.config.target_update_freq == 0:
                # Compute actor loss. The backward pass also writes critic
                # .grad; that is harmless because critic_opt.zero_grad()
                # runs before the next critic step.
                actor_loss = -self.critic.q1(batch.obs, self.actor(batch.obs)).mean()

                # Optimize the actor.
                self.actor_opt.zero_grad()
                actor_loss.backward()
                self.actor_opt.step()

                # Update the frozen target models (critic first, then actor).
                tau = self.config.target_update_rate
                self._soft_update(self.critic, self.critic_target, tau)
                self._soft_update(self.actor, self.actor_target, tau)

        logger.info(f"Finished training critics - total itrs: {self.total_it}")

    def gradient_ascent(self, sol: np.ndarray) -> np.ndarray:
        """Performs config.gradient_steps steps of gradient ascent on sol.

        A fresh actor is built from ``agent_cfg``, loaded with ``sol``, and
        optimized with Adam (``config.gradient_learning_rate``) to maximize
        ``self.critic.q1`` on batches of ``config.pg_batch_size`` observations
        sampled from the replay buffer. The critic itself is not stepped.

        Adapted from PGA-ME:
        https://github.com/ollenilsson19/PGA-MAP-Elites/blob/master/variational_operators.py#L24

        Args:
            sol: Actor parameters in the form accepted by the actor's
                ``from_numpy()``.

        Returns:
            The optimized actor parameters, as produced by ``to_numpy()``.
        """
        actor = hydra.utils.instantiate(self.agent_cfg).from_numpy(sol).to(self.device)
        actor_opt = torch.optim.Adam(
            actor.parameters(), lr=self.config.gradient_learning_rate
        )

        for _ in range(self.config.gradient_steps):
            batch = self.buffer.sample_tensors(self.config.pg_batch_size)
            obs = batch.obs
            # NOTE: backward also accumulates into self.critic's .grad; those
            # are cleared by critic_opt.zero_grad() in train_td3.
            actor_loss = -self.critic.q1(obs, actor(obs)).mean()
            actor_opt.zero_grad()
            actor_loss.backward()
            actor_opt.step()

        return actor.to_numpy()

    def save(self, pickle_path: Path, pytorch_path: Path):
        """Saves data to a pickle file and a PyTorch file.

        The PyTorch file holds the state dicts of the actor, critic, their
        targets and optimizers. The pickle file holds the RNG, the replay
        buffer and the iteration counter.

        See here for more info:
        https://pytorch.org/tutorials/beginner/saving_loading_models.html#save
        """
        logger.info("Saving TD3 pickle data")
        with pickle_path.open("wb") as file:
            cloudpickle.dump(
                {
                    "rng": self.rng,
                    "buffer": self.buffer,
                    "total_it": self.total_it,
                },
                file,
            )

        logger.info("Saving TD3 PyTorch data")
        torch.save(
            {
                "actor": (
                    self.actor.state_dict(),
                    self.actor_target.state_dict(),
                    self.actor_opt.state_dict(),
                ),
                "critic": (
                    self.critic.state_dict(),
                    self.critic_target.state_dict(),
                    self.critic_opt.state_dict(),
                ),
            },
            pytorch_path,
        )

    def load(self, pickle_path: Path, pytorch_path: Path) -> "TD3":
        """Loads data from files saved by save().

        Tensors are mapped onto ``self.device``, so a checkpoint written on a
        GPU machine can be restored on a CPU-only one (and vice versa).

        Returns:
            ``self``, to allow chaining.
        """
        # SECURITY: pickle.load can execute arbitrary code; only load
        # checkpoint files from trusted sources.
        with open(pickle_path, "rb") as file:
            pickle_data = pkl.load(file)
            self.rng = pickle_data["rng"]
            self.buffer = pickle_data["buffer"]
            self.total_it = pickle_data["total_it"]

        # Without map_location, CUDA-saved tensors fail to deserialize when
        # CUDA is unavailable.
        pytorch_data = torch.load(pytorch_path, map_location=self.device)

        self.actor.load_state_dict(pytorch_data["actor"][0])
        self.actor_target.load_state_dict(pytorch_data["actor"][1])
        self.actor_opt.load_state_dict(pytorch_data["actor"][2])

        self.critic.load_state_dict(pytorch_data["critic"][0])
        self.critic_target.load_state_dict(pytorch_data["critic"][1])
        self.critic_opt.load_state_dict(pytorch_data["critic"][2])

        return self
