import math
import time
from math import sqrt
from typing import Any, Callable

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch._prims_common import DeviceLikeType
from torch.nn.functional import log_softmax, one_hot, softmax
from torch.utils.data import DataLoader

from args import AdversarialTrainingConfig, NPGConfig, PPOConfig, QLearningConfig
from mdp.mdp_dataset import (
    MDPDataset,
    MDPDatasetImages,
    MDPDatasetImagesTorch,
    MDPDatasetTorch,
    process_miniworld_images,
)
from mdp.mdp_env import MDPController
from mdp.nn_utils import BatchedConv2D, BatchedLinear, Reshape, SwapDims
from mdp.utils import compute_fisher_matrix, gae, get_discounted_rewards, normalize
from net import ImageTransformer, Transformer, ValueAndPolicyNetwork


class MDPTransformerController(MDPController):
    """Controller whose actions come from an in-context ``Transformer`` policy.

    Transitions passed to :meth:`append` are stored in ``self._dataset`` and
    handed to the model as context on every :meth:`sample_actions` call.
    :meth:`update` fine-tunes the model with a cross-entropy loss against
    target actions; it requires ``lr`` to have been given at construction and
    is a no-op when ``frozen`` is set.
    """

    model: Transformer
    sample: bool  # True: sample from softmax(logits); False: greedy argmax
    _dataset: MDPDataset
    _optimizer: torch.optim.Optimizer | None  # None when constructed without ``lr``
    _frozen: bool  # when True, update() returns immediately without training

    def __init__(
        self,
        model: Transformer,
        n_envs: int,
        n_steps: int,
        n_states: int,
        state_dim: int,
        n_actions: int,
        sample: bool = False,
        *,
        frozen: bool = False,
        lr: float | None = None,
        device: DeviceLikeType | None = None,
    ):
        super().__init__(n_envs, n_steps, state_dim, n_states, n_actions, device)

        self.model = model
        self.sample = sample
        self._frozen = frozen

        self.init_dataset(n_envs, n_steps, n_states, state_dim, n_actions, device)

        self._optimizer = None
        if lr is not None:
            self._optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=1e-4)

    def init_dataset(self, n_envs: int, n_steps: int, n_states: int, state_dim: int, n_actions: int, device: DeviceLikeType | None = None) -> None:
        """Create the context buffer; subclasses override this to change the dataset type."""
        self._dataset = MDPDataset(n_envs, n_steps, n_states, state_dim, n_actions, device)

    def clear_dataset(self) -> None:
        """Empty the context buffer."""
        self._dataset.clear()

    def append(
        self,
        states: Tensor,
        actions: Tensor,
        rewards: Tensor,
        states_next: Tensor,
        rewards_original: Tensor,
        extras: dict[str, Any] | None = None,
    ) -> None:
        """Store one step of transitions in the context buffer.

        ``extras`` is accepted for interface compatibility with the image
        controllers and is not used here.
        """
        self._dataset.append(states, actions, rewards, states_next, rewards_original)

    def sample_actions(self, states: Tensor) -> Tensor:
        """Pick one action per environment from the model's logits.

        Args:
            states: current states. When ``states.shape[1] == 1`` the values are
                treated as integer state ids and one-hot encoded to
                ``model.state_dim`` before building the query.

        Returns:
            Float32 one-hot actions of width ``n_actions``.
        """
        context = self._dataset.get_context_for_transformer()
        if states.shape[1] == 1:
            states = one_hot(states.squeeze(-1).long(), self.model.state_dim)
        query_line = self.model.make_query_line(states, self.n_envs)

        outputs = self.model.predict_actions(context, query_line).detach()

        if self.sample:
            probs = torch.softmax(outputs, dim=-1)
            action_indices = torch.multinomial(probs, 1)[..., 0]
        else:
            action_indices = torch.argmax(outputs, dim=-1)

        return one_hot(action_indices, self.n_actions).to(torch.float32)

    def update(self, dataset: MDPDatasetTorch | MDPDatasetImagesTorch, adv_train_config: AdversarialTrainingConfig) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Fine-tune the model on ``dataset`` for ``adv_train_config.victim_iters`` epochs.

        Uses a summed cross-entropy loss between predicted logits and the
        dataset's target actions (mini-batches of 64, shuffled). The model's
        ``test`` flag is cleared during training and always restored to
        ``True`` afterwards, even if training raises.

        Returns:
            ``(metrics, aux)``: one metrics dict per epoch with
            ``"train/victim_loss"`` (summed loss divided by ``n_steps``) and
            ``"train/victim_time"`` (seconds); ``aux`` is empty. Returns
            ``([], {})`` when the controller is frozen.
        """
        if self._frozen:
            return [], {}

        assert self._optimizer is not None

        metrics: list[dict[str, Any]] = []
        loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
        train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

        self.model.test = False
        try:
            for _ in range(adv_train_config.victim_iters):
                start_time = time.time()
                epoch_train_loss = 0.0

                for batch, true_actions in train_loader:
                    pred_actions = self.model(batch)
                    del batch  # release the batch before the backward pass

                    true_actions = true_actions.reshape(-1, self.n_actions)
                    pred_actions = pred_actions.reshape(-1, self.n_actions)

                    self._optimizer.zero_grad()
                    loss: Tensor = loss_fn(pred_actions, true_actions)
                    loss.backward()
                    self._optimizer.step()
                    epoch_train_loss += loss.item() / self.n_steps

                metrics.append(
                    {
                        "train/victim_loss": epoch_train_loss,
                        "train/victim_time": time.time() - start_time,
                    }
                )
        finally:
            # Never leave the model in training mode, even if an epoch fails.
            self.model.test = True

        return metrics, {}


class MDPImageTransformerController(MDPTransformerController):
    """``MDPTransformerController`` variant whose context and queries also carry images.

    Transitions (with images) are stored in an ``MDPDatasetImages`` buffer.
    Query images are preprocessed with ``process_miniworld_images`` before
    being passed to the ``ImageTransformer``.
    """

    model: ImageTransformer
    _dataset: MDPDatasetImages

    def init_dataset(self, n_envs: int, n_steps: int, n_states: int, state_dim: int, n_actions: int, device: DeviceLikeType | None = None) -> None:
        """Use an image-aware context buffer instead of the plain ``MDPDataset``."""
        self._dataset = MDPDatasetImages(n_envs, n_steps, n_states, state_dim, n_actions, device)

    def append(
        self,
        states: Tensor,
        images: Any,
        actions: Tensor,
        rewards: Tensor,
        states_next: Tensor,
        images_next: Any,
        rewards_original: Tensor,
        extras: dict[str, Any] | None = None,
    ) -> None:
        """Store one step of transitions (including images) in the context buffer.

        When ``extras`` is omitted a fresh empty dict is created for this call,
        so no dict object is shared between calls or with the dataset.
        """
        if extras is None:
            extras = {}
        self._dataset.append(states, images, actions, rewards, states_next, images_next, rewards_original, extras=extras)

    def sample_actions(self, states: Tensor, images: Any) -> Tensor:
        """Pick one action per environment from the model's logits.

        Args:
            states: current states used to build the query line.
            images: raw current frames, preprocessed with
                ``process_miniworld_images`` and moved to ``self.device``.

        Returns:
            Float32 one-hot actions of width ``n_actions``.
        """
        query_images = process_miniworld_images(images).to(device=self.device)

        context_images, context = self._dataset.get_context_for_transformer()
        query_line = self.model.make_query_line(states, self.n_envs)

        outputs = self.model.predict_actions((context_images, context), (query_images, query_line)).detach()

        if self.sample:
            probs = torch.softmax(outputs, dim=-1)
            action_indices = torch.multinomial(probs, 1)[..., 0]
        else:
            action_indices = torch.argmax(outputs, dim=-1)

        return one_hot(action_indices, self.n_actions).to(torch.float32)


class MDPOptimalController(MDPController):
    """Controller that looks up precomputed optimal actions for the current states."""

    # Indexed as [env, state] for flat states or [env, coord0, coord1] for
    # 2-D coordinate states (see sample_actions).
    optimal_actions: Tensor

    def __init__(self, optimal_actions: Tensor, n_envs: int, n_steps: int, n_states: int, state_dim: int, n_actions: int, device: DeviceLikeType | None = None):
        super().__init__(n_envs, n_steps, state_dim, n_states, n_actions, device)

        # Presumably (n_envs, n_states, n_actions) for flat states; the shape is
        # not validated here — TODO confirm with callers.
        self.optimal_actions = optimal_actions

    def sample_actions(self, states: Tensor) -> Tensor:
        """Return the stored optimal action of each environment's current state.

        Args:
            states: 2-D tensor with one row per environment and either one
                column (flat state id) or two columns (grid coordinates). The
                values are used directly as indices.

        Returns:
            The looked-up actions as a float tensor, for both state layouts.

        Raises:
            NotImplementedError: for any other number of state columns.
        """
        assert states.ndim == 2

        env_idx = torch.arange(self.n_envs)
        if states.shape[1] == 1:
            actions = self.optimal_actions[env_idx, states.squeeze(1)]
        elif states.shape[1] == 2:
            actions = self.optimal_actions[env_idx, states[:, 0], states[:, 1]]
        else:
            raise NotImplementedError()

        # Cast in both branches so callers always receive the same dtype.
        return actions.float()


class MDPRandomController(MDPController):
    """Controller that ignores the state and picks uniformly random actions."""

    def sample_actions(self, states: Tensor) -> Tensor:
        """Return one float one-hot random action for each row of ``states``."""
        batch_size = states.shape[0]
        random_indices = torch.randint(0, self.n_actions, (batch_size,), device=self.device)
        encoded = one_hot(random_indices, self.n_actions)
        return encoded.float()


class PPOController(MDPController):
    """Proximal Policy Optimization controller backed by a ``ValueAndPolicyNetwork``.

    :meth:`update` performs one PPO step per minibatch. The step combines a
    clipped policy ratio, an optionally clipped value loss and an entropy
    bonus.
    """

    config: PPOConfig
    model: ValueAndPolicyNetwork
    _optimizer: torch.optim.Optimizer

    def __init__(self, config: PPOConfig, n_envs: int, n_steps: int, n_states: int, state_dim: int, n_actions: int, device: DeviceLikeType | None = None):
        super().__init__(n_envs, n_steps, state_dim, n_states, n_actions, device)

        self.config = config
        self.reinitialize()

    def reinitialize(self) -> None:
        """(Re)create the network and its AdamW optimizer from ``self.config``."""
        self.model = ValueAndPolicyNetwork(self.config, self.n_envs, self.state_dim, self.n_actions).to(self.device)
        self._optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.lr, weight_decay=1e-4)

    def sample_actions(self, states: Tensor) -> Tensor:
        """Sample one action per environment from the current policy.

        A length-1 step axis (dim 1) is inserted before the forward pass and
        squeezed from the sample. The result is float one-hot of width
        ``model.action_dim``.
        """
        pi, _ = self.model(states[:, None, :].float())
        a = pi.sample().squeeze(1)
        return one_hot(a, self.model.action_dim).float()

    def update(
        self, datasets: list[MDPDatasetTorch] | MDPDatasetTorch, adv_training_config: AdversarialTrainingConfig
    ) -> tuple[list[dict[str, float]], dict[str, float]]:
        """Run one PPO optimisation pass over ``datasets``.

        Args:
            datasets: rollouts (a single dataset is also accepted). Their
                ``states``/``actions`` are concatenated along dim 1 (steps).
            adv_training_config: unused; kept for interface compatibility.

        Returns:
            ``(metrics, aux)``: ``metrics`` holds the loss terms and diagnostics
            of the last minibatch; ``aux`` is empty.
        """
        clip = self.config.clip
        clipvf = self.config.clipvf
        value_loss_coef = self.config.value_loss_coef
        entropy_bonus_coef = self.config.entropy_bonus_coef

        n_minibatches = 1

        if isinstance(datasets, MDPDatasetTorch):
            datasets = [datasets]

        all_states = []
        all_actions = []
        all_values = []
        all_advantages = []
        all_logprobs = []

        for dataset in datasets:
            all_states.append(dataset.states)
            all_actions.append(dataset.actions)
            with torch.no_grad():
                pi, values = self.model(dataset.states.float())
                # Pre-update log-probabilities of the actions actually taken. The
                # same quantity is recomputed below after the update; the PPO ratio
                # is only meaningful when both refer to the same actions.
                logprobs: Tensor = pi.log_prob(dataset.actions.argmax(-1))
                advantages = gae(dataset, values, self.config.gamma, self.config.lam, self.n_steps)

            all_values.append(values)
            all_advantages.append(advantages)
            all_logprobs.append(logprobs)

        all_states = torch.concat(all_states, dim=1)
        all_actions = torch.concat(all_actions, dim=1)
        all_values = torch.concat(all_values, dim=1)
        all_advantages = torch.concat(all_advantages, dim=1)
        all_logprobs = torch.concat(all_logprobs, dim=1)

        mb_size = math.ceil((datasets[0].n_steps * len(datasets)) / n_minibatches)

        for i_mb in range(n_minibatches):
            mb = slice(i_mb * mb_size, (i_mb + 1) * mb_size)
            mb_states = all_states[:, mb]
            mb_actions = all_actions[:, mb]
            mb_values_stored = all_values[:, mb]
            mb_advantages = all_advantages[:, mb]
            mb_logprobs_stored = all_logprobs[:, mb]

            pi, mb_values = self.model(mb_states.float())
            mb_logprobs: Tensor = pi.log_prob(mb_actions.argmax(-1))

            mb_advantages_normalized = normalize(mb_advantages)

            sampled_return = (mb_values_stored + mb_advantages).detach()

            # Value clipping: limit how far the new value estimate may move
            # from the stored one (disabled when clipvf is None).
            value_delta = mb_values - mb_values_stored
            if clipvf is not None:
                value_delta = value_delta.clamp(min=-clipvf, max=clipvf)
            clipped_value = mb_values_stored + value_delta
            value_loss_batch = torch.max((mb_values - sampled_return) ** 2, (clipped_value - sampled_return) ** 2)
            value_loss = 0.5 * value_loss_batch.mean()

            ratio = torch.exp(mb_logprobs - mb_logprobs_stored)
            clipped_ratio = ratio.clamp(min=1.0 - clip, max=1.0 + clip)
            policy_reward = torch.min(ratio * mb_advantages_normalized, clipped_ratio * mb_advantages_normalized)
            clip_fraction = ((ratio - 1.0).abs() > clip).float().mean()

            policy_loss = -policy_reward.mean()
            entropy_bonus: Tensor = pi.entropy().mean()

            loss = policy_loss + value_loss_coef * value_loss - entropy_bonus_coef * entropy_bonus

            approx_kl_divergence = 0.5 * ((mb_logprobs_stored - mb_logprobs) ** 2).mean()

            self._optimizer.zero_grad()
            loss.backward()
            # overclip = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.5)
            self._optimizer.step()

        metrics = [
            {
                "train/loss": loss.item(),
                "train/policy_reward": -policy_loss.item(),
                "train/value_loss": value_loss.item(),
                # "train/clip_extra_norm": overclip.item(),
                "train/entropy_bonus": entropy_bonus.item(),
                "train/kl_divergence": approx_kl_divergence.item(),
                "train/clip_fraction": clip_fraction.item(),
            }
        ]
        aux = {}
        return metrics, aux


class MDPNPGController(MDPController):
    """Tabular natural-policy-gradient (NPG) controller.

    Each environment has its own table of action logits, ``_policy``, with
    shape ``(n_envs, n_states, n_actions)``. States arrive as 2-D coordinates
    and are flattened with ``_decode_state``. :meth:`update` estimates state
    values from Monte-Carlo returns and computes GAE advantages. It then steps
    the logits along ``F^{-1} g`` per environment, where ``g`` is the policy
    gradient and ``F`` is the damped matrix from ``compute_fisher_matrix``.
    """

    sample: bool  # NOTE(review): stored but not read; sample_actions always samples
    config: NPGConfig

    _decode_state: Tensor  # [square_len, 1]: coordinates (a, b) -> flat index a * square_len + b
    _policy: Tensor  # action logits per (env, state)
    _value_totals: Tensor  # running sum of discounted returns per (env, state), kept across updates
    _value_counts: Tensor  # running visit counts per (env, state), kept across updates

    def __init__(self, config: NPGConfig, n_envs: int, n_steps: int, n_states: int, state_dim: int, n_actions: int, sample: bool = False, device: DeviceLikeType | None = None):
        super().__init__(n_envs, n_steps, state_dim, n_states, n_actions, device)

        self.config = config
        self.sample = sample

        # States are coordinates on a square grid with side int(sqrt(n_states)).
        square_len = int(sqrt(self.n_states))
        self._decode_state = torch.tensor([square_len, 1.0], device=self.device)

        self.reinitialize()

    def reinitialize(self) -> None:
        """Reset the logits to uniform random values in [0, 1) and clear the value statistics."""
        self._policy = torch.rand((self.n_envs, self.n_states, self.n_actions), device=self.device, requires_grad=True)

        self._value_totals = torch.zeros((self.n_envs, self.n_states), device=self.device)
        self._value_counts = torch.zeros((self.n_envs, self.n_states), device=self.device, dtype=torch.int)

    def sample_actions(self, states: Tensor) -> Tensor:
        """Sample one action per environment from the softmax of its logits.

        Returns float one-hot actions of width ``n_actions``. Sampling is used
        regardless of ``self.sample``.
        """
        flat_states = torch.einsum("es,s->e", states.float(), self._decode_state).int()
        probs = softmax(self._policy[torch.arange(self.n_envs), flat_states], dim=-1)
        actions = torch.multinomial(probs, 1)[..., 0]

        return one_hot(actions, self.n_actions).float()

    def update(
        self, datasets: list[MDPDatasetTorch], adv_train_config: AdversarialTrainingConfig | None = None, *, damping: float = 1e-3
    ) -> tuple[list[dict[str, Tensor]], dict[str, Tensor]]:
        """Perform one natural-policy-gradient step on the logit table.

        Args:
            datasets: rollouts (a single ``MDPDatasetTorch`` is also accepted).
            adv_train_config: unused; kept for interface compatibility.
            damping: multiple of the identity added to each Fisher matrix
                before solving.

        Returns:
            ``(metrics, aux)``:
            - ``metrics`` holds the mean per-environment Frobenius norm of the
              logit change.
            - ``aux["values"]`` holds the running mean return per
              (env, state), 0 where never visited.
            - ``aux["advantages"]`` holds the mean normalised advantage per
              (env, state, action) over this call, 0 where unvisited.
        """
        lr = self.config.lr
        gamma = self.config.gamma
        lam = self.config.lam

        if isinstance(datasets, MDPDatasetTorch):
            datasets = [datasets]

        device = datasets[0].device
        env_idx = torch.arange(self.n_envs)

        # Monte-Carlo state values: sum of discounted returns and visit counts over this call's datasets.
        batch_totals = torch.zeros((self.n_envs, self.n_states), device=device)
        batch_counts = torch.zeros((self.n_envs, self.n_states), device=device, dtype=torch.int)

        for dataset in datasets:
            states = torch.einsum("eis,s->ei", dataset.states.detach().float(), self._decode_state).int()
            discounted_rewards = get_discounted_rewards(dataset, gamma)

            for i in range(dataset.n_steps):
                batch_totals[env_idx, states[:, i]] += discounted_rewards[:, i]
                batch_counts[env_idx, states[:, i]] += 1

        # Fold this call's statistics into the running totals exactly once.
        self._value_totals += batch_totals
        self._value_counts += batch_counts

        # Running mean return. Unvisited states have a total of 0, so clamping
        # their count to 1 yields 0 instead of NaN.
        values = self._value_totals / self._value_counts.clamp(min=1)

        advantages = torch.zeros((self.n_envs, self.n_states, self.n_actions), device=device)
        advantages_counts = torch.zeros((self.n_envs, self.n_states, self.n_actions), device=device, dtype=torch.int)

        for dataset in datasets:
            states = torch.einsum("eis,s->ei", dataset.states.detach().float(), self._decode_state).int()
            actions = dataset.actions.detach().argmax(dim=2)

            step_values = values[env_idx[:, None], states]

            step_advantages = gae(dataset, step_values, gamma, lam, dataset.context_len)
            step_advantages = normalize(step_advantages)
            for i in range(dataset.n_steps):
                # Exactly one (state, action) cell per environment per step. A
                # column-vector env index here would broadcast to an
                # (n_envs, n_envs) grid and mix environments.
                advantages[env_idx, states[:, i], actions[:, i]] += step_advantages[:, i]
                advantages_counts[env_idx, states[:, i], actions[:, i]] += 1

        advantages = advantages / advantages_counts.clamp(min=1)

        # NOTE(review): the objective below uses `states`, `dataset` and
        # `step_advantages` from the LAST dataset only. Earlier datasets feed the
        # value/advantage statistics but not the gradient — confirm intended.
        log_probs = log_softmax(self._policy[env_idx[:, None], states, :], dim=-1)
        selected_log_probs = torch.sum(log_probs * dataset.actions.detach(), dim=-1)
        policy_gradient_loss = torch.sum(step_advantages * selected_log_probs)

        # First-order gradient only: nothing differentiates through it, so no graph is retained.
        policy_gradient = torch.autograd.grad(policy_gradient_loss, self._policy)[0]

        grad = torch.empty((self.n_envs, self.n_states, self.n_actions), device=device)
        for env in range(self.n_envs):
            fisher = compute_fisher_matrix(self._policy[env], states[env])
            fisher += damping * torch.eye(fisher.shape[0], device=self.device)
            grad[env] = torch.linalg.solve(fisher, policy_gradient[env].flatten()).reshape(self.n_states, self.n_actions)

        with torch.no_grad():
            policy_new = self._policy + lr * grad
            policy_diff = torch.norm(policy_new - self._policy, dim=(1, 2))
            self._policy = policy_new.requires_grad_(True)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self._policy.grad = None

        metrics = [
            {
                "train/policy_difference": policy_diff.mean(),
            }
        ]
        aux = {
            "values": values,
            "advantages": advantages,
        }

        return metrics, aux


class MDPQLearningController(MDPController):
    """Tabular Q-learning controller with decaying epsilon-greedy exploration.

    Each environment has its own Q-table, ``_qvalues``, with shape
    ``(n_envs, square_len**2, n_actions)``. The table is trained with AdamW on
    a smooth-L1 TD loss. States are 2-D coordinates flattened with
    ``_decode_state``.
    """

    config: QLearningConfig
    square_len: int  # side of the square state grid: int(sqrt(n_states))

    _loss_fn: Callable[[Tensor, Tensor], Tensor]
    _optimizer: torch.optim.Optimizer
    _decode_state: Tensor  # [square_len, 1]: coordinates (a, b) -> flat index a * square_len + b
    _qvalues: Tensor

    current_step: int  # number of sample_actions calls; drives the exploration decay

    def __init__(self, config: QLearningConfig, n_envs: int, n_steps: int, n_states: int, state_dim: int, n_actions: int, device: DeviceLikeType | None = None):
        super().__init__(n_envs, n_steps, state_dim, n_states, n_actions, device)

        self.config = config

        self.square_len = int(sqrt(self.n_states))
        self._decode_state = torch.tensor([self.square_len, 1.0], device=self.device)

        self._loss_fn = torch.nn.SmoothL1Loss()

        self.reinitialize()

    def reinitialize(self) -> None:
        """Reset the Q-table to uniform random values in [0, 1), rebuild the optimizer, and restart the decay."""
        self._qvalues = torch.nn.Parameter(torch.rand((self.n_envs, self.square_len**2, self.n_actions), device=self.device, requires_grad=True))
        self._optimizer = torch.optim.AdamW([self._qvalues], lr=self.config.lr, weight_decay=1e-4)
        self.current_step = 0

    def get_current_frac_greedy(self) -> Tensor:
        """Return a boolean mask of the environments that act greedily this step.

        The threshold decays exponentially from ``frac_greedy_start`` to
        ``frac_greedy_end`` with time constant ``frac_greedy_decay_factor``
        (in steps). Each environment is greedy when its uniform draw exceeds
        the threshold.
        """
        start = self.config.frac_greedy_start
        end = self.config.frac_greedy_end
        decay_factor = self.config.frac_greedy_decay_factor

        threshold = end + (start - end) * math.exp(-1.0 * self.current_step / decay_factor)
        greedy_probability = torch.rand((self.n_envs,), device=self.device)

        return greedy_probability > threshold

    def sample_actions(self, states: Tensor) -> Tensor:
        """Choose epsilon-greedy actions (float one-hot) and advance the step counter."""
        actions = torch.randint(0, self.n_actions, (self.n_envs,), device=self.device)
        greedy = self.get_current_frac_greedy()

        states = torch.einsum("es,s->e", states.float(), self._decode_state).int()
        qvals = self._qvalues[torch.arange(self.n_envs), states]
        actions_greedy = qvals.argmax(-1)

        actions[greedy] = actions_greedy[greedy]

        self.current_step += 1

        return one_hot(actions, self.n_actions).float()

    def estimate_qvals(self, states_decoded: Tensor, actions: Tensor) -> Tensor:
        """Return Q(s, a) per (env, step) for flat states and one-hot actions."""
        return torch.einsum("eia,eia->ei", self._qvalues[torch.arange(self.n_envs)[:, None], states_decoded, :], actions)

    def update(
        self, datasets: list[MDPDatasetTorch] | MDPDatasetTorch, adv_train_config: AdversarialTrainingConfig | None = None
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Take one TD gradient step per dataset.

        Consecutive entries along the step axis are treated as transitions
        ``(s_t, a_t, r_t, s_{t+1})`` with the Q-learning target
        ``r_t + gamma * max_a Q(s_{t+1}, a)``. No terminal masking is applied.

        Args:
            datasets: rollouts (a single ``MDPDatasetTorch`` is also accepted).
            adv_train_config: unused; kept for interface compatibility.

        Returns:
            ``(metrics, aux)``: ``metrics`` holds the summed loss over
            datasets; ``aux`` is empty.
        """
        gamma = self.config.gamma

        if isinstance(datasets, MDPDatasetTorch):
            datasets = [datasets]

        env_idx = torch.arange(self.n_envs)[:, None]
        total_loss = 0.0

        for dataset in datasets:
            states = torch.einsum("eis,s->ei", dataset.states.float(), self._decode_state).int()
            next_states = states[:, 1:]
            states = states[:, :-1]

            actions = dataset.actions[:, :-1]
            rewards = dataset.rewards[:, :-1]

            qvals = self.estimate_qvals(states, actions)
            with torch.no_grad():
                # Greedy bootstrap over all actions at s_{t+1}. Evaluating the action
                # taken at s_t in state s_{t+1} is not a valid TD target.
                qvals_next = self._qvalues[env_idx, next_states, :].max(dim=-1).values

            loss = self._loss_fn(qvals, rewards + gamma * qvals_next)
            total_loss += loss.item()

            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

        metrics = [
            {
                "train/loss": total_loss,
            }
        ]
        aux = {}

        return metrics, aux


class MDPImageRandController(MDPController):
    """Image-environment controller that picks uniformly random actions.

    Transitions (including images) are still recorded in an
    ``MDPDatasetImages`` buffer via :meth:`append`.
    """

    _dataset: MDPDatasetImages

    def __init__(self, n_envs: int, n_steps: int, state_dim: int, n_states: int, n_actions: int, device: DeviceLikeType | None = None):
        # NOTE: unlike most controllers in this module, the positional order here
        # is (state_dim, n_states).
        super().__init__(n_envs, n_steps, state_dim, n_states, n_actions, device)
        self._dataset = MDPDatasetImages(n_envs, n_steps, n_states, state_dim, n_actions, device)

    def clear_dataset(self) -> None:
        """Empty the transition buffer."""
        self._dataset.clear()

    def append(
        self,
        states: Tensor,
        images: Any,
        actions: Tensor,
        rewards: Tensor,
        states_next: Tensor,
        images_next: Any,
        rewards_original: Tensor,
        extras: dict[str, Any] | None = None,
    ) -> None:
        """Record one step of transitions (including images).

        When ``extras`` is omitted a fresh empty dict is created for this call,
        so no dict object is shared between calls or with the dataset.
        """
        if extras is None:
            extras = {}
        self._dataset.append(states, images, actions, rewards, states_next, images_next, rewards_original, extras=extras)

    def sample_actions(self, states: Tensor, images: Any) -> Tensor:
        """Return one float32 one-hot random action per environment (inputs are ignored)."""
        action_indices = torch.randint(0, self.n_actions, (self.n_envs,), device=self.device)

        return one_hot(action_indices, self.n_actions).to(torch.float32)


class MDPImageQLearningFAController(MDPController):
    """Epsilon-greedy controller with a per-environment CNN Q-function over images.

    A batched convolutional encoder (one set of weights per environment) maps
    images to an 8-dim embedding, and a batched linear head maps that
    embedding to one Q-value per action. Both are trained with AdamW on a
    smooth-L1 TD loss.

    NOTE(review): :meth:`update` bootstraps from Q(s_{t+1}, a_{t+1}) using the
    logged next action (a SARSA-style target), not max_a Q(s_{t+1}, a) —
    confirm this is intended.
    """

    config: QLearningConfig

    _loss_fn: Callable[[Tensor, Tensor], Tensor]
    _optimizer: torch.optim.Optimizer
    _dataset: MDPDatasetImages

    current_step: int  # number of sample_actions calls; drives the exploration decay

    def init_dataset(self, n_envs: int, n_steps: int, n_states: int, state_dim: int, n_actions: int, device: DeviceLikeType | None = None) -> None:
        """Create the image transition buffer."""
        self._dataset = MDPDatasetImages(n_envs, n_steps, n_states, state_dim, n_actions, device)

    def __init__(self, config: QLearningConfig, n_envs: int, n_steps: int, n_actions: int, device: DeviceLikeType | None = None):
        # The base class and dataset get state_dim=2 and n_states=0.
        super().__init__(n_envs, n_steps, 2, 0, n_actions, device)

        self.init_dataset(n_envs, n_steps, 0, 2, n_actions, device)

        self.config = config

        self._loss_fn = torch.nn.SmoothL1Loss()

        self.reinitialize()

    def append(
        self,
        states: Tensor,
        images: Any,
        actions: Tensor,
        rewards: Tensor,
        states_next: Tensor,
        images_next: Any,
        rewards_original: Tensor,
        extras: dict[str, Any] | None = None,
    ) -> None:
        """Record one step of transitions (including images).

        When ``extras`` is omitted a fresh empty dict is created for this call,
        so no dict object is shared between calls or with the dataset.
        """
        if extras is None:
            extras = {}
        self._dataset.append(states, images, actions, rewards, states_next, images_next, rewards_original, extras=extras)

    def reinitialize(self) -> None:
        """(Re)build the image encoder, Q-value head and optimizer, and reset the step counter."""
        # NOTE(review): assumes 25x25 inputs (presumably what
        # process_miniworld_images produces) and unpadded convolutions in
        # BatchedConv2D — TODO confirm. The spatial size after the stride-2 and
        # stride-1 3x3 convs sizes the flattened features.
        image_size = 25
        size = image_size
        size = (size - 3) // 2 + 1
        size = (size - 3) // 1 + 1
        image_embedding_dim = 8

        self.nenv_image_encoder = nn.Sequential(
            BatchedConv2D(self.n_envs, 3, 16, kernel_size=3, stride=2),
            nn.ReLU(),
            BatchedConv2D(self.n_envs, 16, 16, kernel_size=3, stride=1),
            nn.ReLU(),
            Reshape((self.n_envs, -1, int(16 * size * size))),
            SwapDims(-1, -2),
            BatchedLinear(self.n_envs, int(16 * size * size), image_embedding_dim),
            nn.ReLU(),
        ).to(self.device)
        self.nenv_pred_qvals = BatchedLinear(self.n_envs, image_embedding_dim, self.n_actions).to(self.device)

        self._optimizer = torch.optim.AdamW([*self.nenv_image_encoder.parameters(), *self.nenv_pred_qvals.parameters()], lr=self.config.lr, weight_decay=1e-4)
        self.current_step = 0

    def get_current_frac_greedy(self) -> Tensor:
        """Return a boolean mask of the environments that act greedily this step.

        The threshold decays exponentially from ``frac_greedy_start`` to
        ``frac_greedy_end`` with time constant ``frac_greedy_decay_factor``
        (in steps). Each environment is greedy when its uniform draw exceeds
        the threshold.
        """
        start = self.config.frac_greedy_start
        end = self.config.frac_greedy_end
        decay_factor = self.config.frac_greedy_decay_factor

        threshold = end + (start - end) * math.exp(-1.0 * self.current_step / decay_factor)
        greedy_probability = torch.rand((self.n_envs,), device=self.device)

        return greedy_probability > threshold

    def get_qvals(self, images: list[np.ndarray] | Tensor, process=True) -> Tensor:
        """Return the Q-values predicted for ``images``.

        Args:
            images: raw frames when ``process`` is True. These are
                preprocessed with ``process_miniworld_images`` and moved to
                ``self.device``. When ``process`` is False, ``images`` must be
                an already-processed tensor.
            process: whether to preprocess ``images`` first.
        """
        if process:
            query_images = process_miniworld_images(images).to(device=self.device)
        else:
            query_images = images
        hidden_state = self.nenv_image_encoder(query_images)
        qvals = self.nenv_pred_qvals(hidden_state)

        return qvals.squeeze(-1)

    def sample_actions(self, states: Tensor, images: list[np.ndarray]) -> Tensor:
        """Choose epsilon-greedy actions (float one-hot) from image Q-values and advance the step counter."""
        actions = torch.randint(0, self.n_actions, (self.n_envs,), device=self.device)
        greedy = self.get_current_frac_greedy()

        qvals = self.get_qvals(images)
        actions_greedy = qvals.argmax(-1)

        actions[greedy] = actions_greedy[greedy]

        self.current_step += 1

        return one_hot(actions, self.n_actions).float()

    def estimate_qvals(self, images: Tensor, actions: Tensor) -> Tensor:
        """Return Q(s, a) for processed images and one-hot ``actions`` (reduced over the last axis)."""
        qvals = self.get_qvals(images, process=False)
        qvals_actions = torch.einsum("...sa,...sa->...s", qvals.reshape_as(actions), actions)
        return qvals_actions

    def update(
        self, datasets: list[MDPDatasetImagesTorch] | MDPDatasetImagesTorch, adv_train_config: AdversarialTrainingConfig | None = None
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Take one TD gradient step per dataset.

        Consecutive entries along the step axis are treated as transitions
        with target ``r_t + gamma * Q(s_{t+1}, a_{t+1})``. No terminal masking
        is applied.

        Args:
            datasets: rollouts (a single ``MDPDatasetImagesTorch`` is also accepted).
            adv_train_config: unused; kept for interface compatibility.

        Returns:
            ``(metrics, aux)``: ``metrics`` holds the summed loss over
            datasets; ``aux`` is empty.
        """
        gamma = self.config.gamma

        if isinstance(datasets, MDPDatasetImagesTorch):
            datasets = [datasets]

        total_loss = 0.0

        for dataset in datasets:
            images = dataset._images

            next_images = images[:, 1:]
            images = images[:, :-1]

            next_actions = dataset.actions[:, 1:]
            actions = dataset.actions[:, :-1]
            rewards = dataset.rewards[:, :-1]

            qvals = self.estimate_qvals(images, actions)
            with torch.no_grad():
                qvals_next = self.estimate_qvals(next_images, next_actions)

            loss = self._loss_fn(qvals, rewards + gamma * qvals_next)
            total_loss += loss.item()

            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

        metrics = [
            {
                "train/loss": total_loss,
            }
        ]
        aux = {}

        return metrics, aux
