import tempfile
from pathlib import Path
from functools import partial
from typing import Optional, Union, Any
from collections.abc import MutableMapping

from einops import rearrange, pack, unpack
from beartype import beartype
from omegaconf import OmegaConf, DictConfig
import wandb
import numpy as np
import torch
from torch.optim import Adam
from torch.nn import functional as ff
from torch.nn.utils import clip_grad
from torch import autograd
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.data import ReplayBuffer
from geomloss import SamplesLoss

from helpers import logger
from agents.nets import (
    log_module_info, Actor, TanhGaussActor, Critic, Discriminator, RandomPredictor)
from agents.losses import HLGaussLoss, compute_mmd_loss
from agents.pwil import PWILRewarder

from agents.diffusion import DiffusionDiscriminator


class Agent(object):

    @beartype
    def __init__(self,
                 net_shapes: dict[str, tuple[int, ...]],
                 min_ac: np.ndarray,
                 max_ac: np.ndarray,
                 device: torch.device,
                 hps: MutableMapping[Any, Any],
                 generator: torch.Generator,
                 rb: Optional[ReplayBuffer] = None,
                 expert_dataset: Optional[ReplayBuffer] = None,
                 expert_atoms: Optional[list[TensorDict]] = None,
                 all_expert_atoms: Optional[list[TensorDict]] = None):
        """Build the actor, the twin critics, the reward model (if any) and the optimizers.

        Args:
            net_shapes: Must contain the keys "ob_shape" and "ac_shape".
            min_ac: Lower action bounds; stored as a float tensor on `device`.
            max_ac: Upper action bounds; stored as a float tensor on `device`.
            device: Device on which every tensor and network is created.
            hps: Hyper-parameters; asserted to be an omegaconf `DictConfig`.
            generator: Forwarded to the actor only when `hps.prefer_td3_over_sac`
                is false (i.e. when the `TanhGaussActor` is used).
            rb: Replay buffer, stored as `self.rb`.
            expert_dataset: Expert data, stored as `self.expert_dataset`.
            expert_atoms: Required when `hps.method == "pwil"`; used to build
                one `PWILRewarder` per environment.
            all_expert_atoms: Required when `hps.wasserstein_two` is set; vectorized
                and standardized for use in `distance_to_expert`.

        Raises:
            ValueError: If `hps.ngt_loss` (for the "ngt" method) or `hps.method`
                is not one of the supported values.
            AssertionError: If one of the sanity checks on the arguments or
                hyper-parameters fails.
        """

        ob_shape = net_shapes["ob_shape"]
        ac_shape = net_shapes["ac_shape"]

        self.device = device

        self.min_ac = torch.tensor(min_ac, dtype=torch.float, device=self.device)
        self.max_ac = torch.tensor(max_ac, dtype=torch.float, device=self.device)

        assert isinstance(hps, DictConfig)
        self.hps = hps

        # for reward assembly
        self.constant_one = torch.ones((1,), device=self.device)

        # for gradient penalty
        # (a column of ones with `hps.batch_size` rows, created once and reused)
        self.grad_outputs = rearrange(
            torch.ones(self.hps.batch_size, device=self.device),
            "b -> b 1",
        )

        # progress counters, all starting at zero
        self.timesteps_so_far = 0
        self.actor_updates_so_far = 0
        self.qnet_updates_so_far = 0
        self.reward_updates_so_far = 0

        self.best_eval_ep_ret = -float("inf")  # updated in orchestrator

        # sanity checks on the hyper-parameters
        assert self.hps.segment_len <= self.hps.batch_size
        if self.hps.clip_norm <= 0:
            logger.info("clip_norm <= 0, hence disabled")
        assert 0. <= float(self.hps.label_smooth) <= 1.

        # replay buffer
        self.rb = rb

        # expert dataset
        self.expert_dataset = expert_dataset

        # expert atoms
        if self.hps.method == "pwil":  # tools for training only
            assert expert_atoms is not None
            # pwil rewarder
            # (one independent rewarder per environment, each built from the same atoms)
            self.pwil_rewarder = [
                PWILRewarder(
                    self.vectorize_expert_atoms(expert_atoms),
                    self.device,
                    self.hps.input_mode,
                    ob_shape,
                    ac_shape,
                    horizon=1000,
                )
                for _ in range(self.hps.num_envs)
            ]

        if self.hps.wasserstein_two:  # tools for eval only
            assert all_expert_atoms is not None
            self.vectorized_expert_atoms = self.vectorize_expert_atoms(all_expert_atoms)  # all
            # per-feature statistics over all expert atoms; kept because
            # `distance_to_expert` normalizes the agent atoms with the same stats
            self.avg_ = self.vectorized_expert_atoms.mean(
                dim=0, keepdim=True)
            self.std_ = self.vectorized_expert_atoms.std(
                dim=0, keepdim=True) + 1e-6  # small offset avoids a division by zero
            # standardize the expert atoms in place
            self.vectorized_expert_atoms -= self.avg_
            self.vectorized_expert_atoms /= self.std_

        # create online and target nets

        actor_net_args = [ob_shape, ac_shape, (256, 256), self.min_ac, self.max_ac]
        actor_net_kwargs = {"layer_norm": self.hps.layer_norm}
        # the deterministic actor takes an exploration noise scale,
        # the stochastic one takes the random generator
        if self.hps.prefer_td3_over_sac:
            actor_net_kwargs.update({"exploration_noise": self.hps.actor_noise_std})
        else:
            actor_net_kwargs.update({"generator": generator})

        # build a real actor only to harvest its parameters
        self.actor = (Actor if self.hps.prefer_td3_over_sac else TanhGaussActor)(
            *actor_net_args, **actor_net_kwargs, device=self.device)
        self.actor_params = TensorDict.from_module(self.actor, as_module=True)
        # the target starts as a copy of the online parameter values
        self.actor_target = self.actor_params.data.clone()

        # discard params of net
        # (the actor is re-created on the "meta" device and the harvested
        # parameters are loaded into it)
        self.actor = (Actor if self.hps.prefer_td3_over_sac else TanhGaussActor)(
            *actor_net_args, **actor_net_kwargs, device="meta")
        self.actor_params.to_module(self.actor)
        self.actor_detach = (Actor if self.hps.prefer_td3_over_sac else TanhGaussActor)(
            *actor_net_args, **actor_net_kwargs, device=self.device)

        # copy params to actor_detach without grad
        TensorDict.from_module(self.actor).data.to_module(self.actor_detach)
        # the acting policies (used by `predict`) wrap `actor_detach`
        if self.hps.prefer_td3_over_sac:
            self.policy = TensorDictModule(
                self.actor_detach.exploit,
                in_keys=["observations"],
                out_keys=["action"],
            )
            self.policy_explore = TensorDictModule(
                self.actor_detach.explore,
                in_keys=["observations"],
                out_keys=["action"],
            )
        else:
            # both wrappers call `get_action`; they differ only in the declared
            # output key, which `predict` reads back ("mode" or "sample")
            self.policy = TensorDictModule(
                self.actor_detach.get_action,
                in_keys=["observations"],
                out_keys=["mode"],
            )
            self.policy_explore = TensorDictModule(
                self.actor_detach.get_action,
                in_keys=["observations"],
                out_keys=["sample"],
            )

        if self.hps.compile:
            self.policy = torch.compile(self.policy, mode=None)
            self.policy_explore = torch.compile(self.policy_explore, mode=None)

        qnet_net_args = [ob_shape, ac_shape, (256, 256)]
        qnet_net_kwargs = {"layer_norm": self.hps.layer_norm}

        # twin critics: their parameters are gathered in a single TensorDict
        # (presumably stacked along a leading dim, since `update_qnets` and
        # `update_actor` vmap over dim 0 of it)
        self.qnet1 = Critic(*qnet_net_args, **qnet_net_kwargs, device=self.device)
        self.qnet2 = Critic(*qnet_net_args, **qnet_net_kwargs, device=self.device)
        self.qnet_params = TensorDict.from_modules(self.qnet1, self.qnet2, as_module=True)
        self.qnet_target = self.qnet_params.data.clone()
        # discard params of net
        self.qnet = Critic(*qnet_net_args, **qnet_net_kwargs, device="meta")
        self.qnet_params.to_module(self.qnet)

        reward_net_args = [
            ob_shape, ac_shape, (256, 256), self.hps.input_mode, self.hps.activation]
        reward_net_kwargs_keys = ["spectral_norm", "dropout"]
        reward_net_kwargs = {k: getattr(self.hps, k) for k in reward_net_kwargs_keys}

        # build the reward model that matches the chosen method
        match self.hps.method:

            case "ngt":

                # loss
                # defaults, overridden only by the "hl-gauss" loss below
                out_scale, num_bins = None, 1

                # one criterion for policy data ("p") and one for expert data ("e");
                # all of them are element-wise (no reduction)
                match self.hps.ngt_loss:

                    case "l1":
                        self.rnet_criterion = {
                            "p": partial(ff.l1_loss, reduction="none"),
                            "e": partial(ff.l1_loss, reduction="none"),
                        }
                    case "ln-cosh":
                        self.rnet_criterion = {
                            "p": lambda x, y: torch.log(torch.cosh(x - y)),
                            "e": lambda x, y: torch.log(torch.cosh(x - y)),
                        }
                    case "huber":
                        self.rnet_criterion = {
                            "p": partial(ff.huber_loss, reduction="none"),
                            "e": partial(ff.huber_loss, reduction="none"),
                        }
                    case "huber-softmax":
                        self.rnet_criterion = {
                            "p": lambda x, y: partial(ff.huber_loss, reduction="none")(
                                    ff.softmax(x, dim=1), ff.softmax(y, dim=1)),
                            "e": lambda x, y: partial(ff.huber_loss, reduction="none")(
                                    ff.softmax(x, dim=1), ff.softmax(y, dim=1)),
                        }
                    case "mse":
                        self.rnet_criterion = {
                            "p": partial(ff.mse_loss, reduction="none"),
                            "e": partial(ff.mse_loss, reduction="none"),
                        }
                    case "mse-softmax":
                        self.rnet_criterion = {
                            "p": lambda x, y: 0.5 * partial(ff.mse_loss, reduction="none")(
                                    ff.softmax(x, dim=1), ff.softmax(y, dim=1)),
                            "e": lambda x, y: 0.5 * partial(ff.mse_loss, reduction="none")(
                                    ff.softmax(x, dim=1), ff.softmax(y, dim=1)),
                        }
                    case "hl-gauss":
                        out_scale = self.hps.hlgauss_minmax_value
                        num_bins = int(self.hps.hlgauss_num_bins)
                        # shared settings; only sigma differs between "p" and "e"
                        hlgauss_loss = partial(
                            HLGaussLoss,
                            min_value=-out_scale,
                            max_value=+out_scale,
                            num_bins=num_bins,
                            device=self.device,
                        )
                        self.rnet_criterion = {
                            "p": hlgauss_loss(sigma=self.hps.hlgauss_sigma_p),
                            "e": hlgauss_loss(sigma=self.hps.hlgauss_sigma_e),
                        }
                    case _:
                        raise ValueError("invalid NGT loss")

                # trainable predictor: its output size is multiplied by the number of bins
                self.predictor = RandomPredictor(
                    *[*reward_net_args,
                      self.hps.out_size * num_bins,
                      out_scale],
                    **{**reward_net_kwargs,
                       "v2": self.hps.v2,
                       "dropout": self.hps.dropout},
                    device=self.device,
                )
                # prior: made untrainable, dropout disabled, and no bin multiplier
                self.prior = RandomPredictor(
                    *[*reward_net_args,
                      self.hps.out_size,
                      out_scale],
                    **{**reward_net_kwargs,
                       "v2": self.hps.v2,
                       "dropout": False},
                    make_untrainable=True,
                    device=self.device,
                )

            case "samdac" | "mmd-samdac" | "w-samdac":
                self.discriminator = Discriminator(
                    *reward_net_args, **reward_net_kwargs, device=self.device)
            case "diffail":
                self.discriminator = DiffusionDiscriminator(
                    ob_shape, ac_shape, self.max_ac, self.hps.input_mode, device=self.device)
            case "pwil" | "bc" | "random":
                pass
            case _:
                raise ValueError("invalid method")

        # wraps `self.compute_reward` so that it reads from / writes to a TensorDict
        self.reward = TensorDictModule(
            self.compute_reward,
            in_keys=["observations", "actions", "next_observations"],
            out_keys=["rewards"],
        )

        # set up the optimizers

        self.q_optimizer = Adam(
            self.qnet.parameters(),
            lr=self.hps.qnets_lr,
            capturable=self.hps.cudagraphs and not self.hps.compile,
        )
        self.actor_optimizer = Adam(
            self.actor.parameters(),
            lr=self.hps.actor_lr,
            capturable=self.hps.cudagraphs and not self.hps.compile,
        )

        if not self.hps.prefer_td3_over_sac:
            # setup log(alpha) if SAC is chosen
            self.log_alpha = torch.as_tensor(self.hps.alpha_init, device=self.device).log()

            if self.hps.autotune:
                # create learnable Lagrangian multiplier
                # common trick: learn log(alpha) instead of alpha directly
                self.log_alpha.requires_grad = True
                self.targ_ent = -ac_shape[-1]  # set target entropy to -|A|
                self.alpha_optimizer = Adam(
                    [self.log_alpha],
                    lr=self.hps.log_alpha_lr,
                    capturable=self.hps.cudagraphs and not self.hps.compile,
                )

        log_module_info(self.actor)
        log_module_info(self.qnet1)
        log_module_info(self.qnet2)

        if self.hps.method in {"pwil", "bc", "random"}:
            pass  # not a "learned reward" method
        else:
            # the optimized reward model is the predictor for "ngt",
            # and the discriminator for every other learned-reward method
            self.reward_optimizer = Adam(
                (
                    rnet := (  # walrus for the module logger below
                        self.predictor
                        if self.hps.method == "ngt"
                        else self.discriminator
                    )
                ).parameters(),
                lr=self.hps.reward_lr,
                capturable=self.hps.cudagraphs and not self.hps.compile,
            )

            log_module_info(rnet)

    @beartype
    def vectorize_expert_atoms(self, expert_atoms: list[TensorDict]) -> torch.Tensor:
        """Aggregate all the transitions from the demos into one vector"""
        vector_list = []

        for td in expert_atoms:  # TensorDicts

            match self.hps.input_mode:
                case "sa":
                    vec = torch.cat(
                        [
                            td["observations"],
                            td["actions"],
                        ],
                        dim=1,
                    )  # shape: [T, ob_dim + ac_dim]
                case "ss":
                    vec = torch.cat(
                        [
                            td["observations"],
                            td["next_observations"],
                        ],
                        dim=1,
                    )  # shape: [T, 2 * ob_dim]
                case "s":
                    vec = td["observations"]
                case _:
                    raise ValueError("invalid input mode")

            vector_list.append(vec)

        return torch.cat(vector_list, dim=0)

    @beartype
    def distance_to_expert(self, trajectory: dict[str, np.ndarray]) -> Optional[torch.Tensor]:
        """Compute the Wasserstein-2 distance between the agent and the expert.
        N.B.: despite taking a trajectory as input, the boundaries between
        trajectories are not available, because the expert trajectories are merged
        into a single vector of transitions.
        """
        with torch.no_grad():  # only ever used for eval
            # prepare agent atom
            observations = torch.as_tensor(
                trajectory["observations"], dtype=torch.float, device=self.device)
            vectorized_agent_atoms = observations.squeeze(dim=1)
            match self.hps.input_mode:
                case "sa":
                    actions = torch.as_tensor(
                        trajectory["actions"], dtype=torch.float, device=self.device)
                    vectorized_agent_atoms = torch.cat(
                        [
                            vectorized_agent_atoms,
                            actions.squeeze(dim=1),
                        ],
                        dim=-1,
                    )
                case "ss":
                    next_observations = torch.as_tensor(
                        trajectory["next_observations"], dtype=torch.float, device=self.device)
                    vectorized_agent_atoms = torch.cat(
                        [
                            vectorized_agent_atoms,
                            next_observations.squeeze(dim=1),
                        ],
                        dim=-1,
                    )
                case "s":
                    pass
                case _:
                    raise ValueError("invalid input mode")

            # handle cases where we want to abandon computation
            if vectorized_agent_atoms.size(dim=0) == 0:
                return None
            if torch.isnan(vectorized_agent_atoms).any():
                return None
            if torch.isinf(vectorized_agent_atoms).any():
                return None

            # normalize vectorized agent atoms with the stats
            # computed from the vectorized expert atoms
            vectorized_agent_atoms -= self.avg_
            vectorized_agent_atoms /= self.std_

            # create the criterion object using Geomloss
            # we use the Wasserstein-2 distance (approx with the Sinkhorn algorithm)
            sinkhorn_ent_reg = 0.05
            criterion = SamplesLoss("sinkhorn", p=2, blur=sinkhorn_ent_reg)
            # doc: by default, uses constant (i.e. uniform) weights = 1/number of samples

            # compute the criterion estimate from samples
            return criterion(
                rearrange(
                    vectorized_agent_atoms,
                    "t d -> 1 t d",
                ),
                rearrange(
                    self.vectorized_expert_atoms,
                    "t d -> 1 t d",
                ),
            )

    @beartype
    def batched_qf(self,
                   params: TensorDict,
                   ob: torch.Tensor,
                   ac: torch.Tensor,
                   next_q_value: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Use two qnet networks from params"""
        with params.to_module(self.qnet):
            vals = self.qnet(ob, ac)
            if next_q_value is not None:
                return ff.mse_loss(vals.view(-1), next_q_value)
            return vals

    @beartype
    def pi(self, params: TensorDict, ob: torch.Tensor) -> torch.Tensor:
        """Use an actor network from params"""
        with params.to_module(self.actor):
            return self.actor(ob)

    @property
    @beartype
    def alpha(self) -> Optional[torch.Tensor]:
        if not self.hps.prefer_td3_over_sac:
            return self.log_alpha.exp()
        return None

    @beartype
    def predict(self, in_td: TensorDict, *, explore: bool) -> np.ndarray:
        """Predict with policy, with or without perturbation"""

        out_td = self.policy_explore(in_td) if explore else self.policy(in_td)

        if self.hps.prefer_td3_over_sac:
            action = out_td["action"]
        else:
            action = out_td["sample" if explore else "mode"]

        action.clamp(self.min_ac, self.max_ac)
        return action.cpu().numpy()

    @beartype
    def update_qnets(self, batch: TensorDict) -> TensorDict:

        self.q_optimizer.zero_grad()

        with torch.no_grad():

            if self.hps.method == "pwil":
                batch["rewards"] = batch["pwil_rewards"]
            else:
                # populate batch with rewards
                batch = self.reward(batch)

            # compute target action
            if self.hps.prefer_td3_over_sac:
                # using TD3
                next_state_log_pi = None
                pi_next_target = self.pi(
                    self.actor_target, batch["next_observations"])  # target actor
                # why use `pi`: we only have a handle on the target actor parameters
                if self.hps.targ_actor_smoothing:
                    n_ = batch["actions"].clone().detach().normal_(0., self.hps.td3_std)
                    n_ = n_.clamp(-self.hps.td3_c, self.hps.td3_c)
                    next_action = (pi_next_target + n_).clamp(self.min_ac, self.max_ac)
                else:
                    next_action = pi_next_target
            else:
                # using SAC
                next_action, next_state_log_pi, _ = self.actor.get_action(
                    batch["next_observations"]).values()

            qf_next_target = torch.vmap(self.batched_qf, (0, None, None))(
                self.qnet_target, batch["next_observations"], next_action,
            )

            qf_min = qf_next_target.min(0).values
            if self.hps.bcq_style_targ_mix:
                # use BCQ style of target mixing: soft minimum
                qf_max = qf_next_target.max(0).values
                q_prime = ((0.75 * qf_min) + (0.25 * qf_max))
            else:
                # use TD3 style of target mixing: hard minimum
                q_prime = qf_min

            if not self.hps.prefer_td3_over_sac:  # only for SAC
                # add the causal entropy regularization term
                q_prime -= self.alpha * next_state_log_pi

            # assemble the Bellman target
            targ_q = batch["rewards"].flatten() + (
                ~batch["dones"].flatten()
            ).float() * self.hps.gamma * q_prime.view(-1)

        qf_a_values = torch.vmap(self.batched_qf, (0, None, None, None))(
            self.qnet_params, batch["observations"], batch["actions"], targ_q,
        )
        qf_loss = qf_a_values.sum(0)

        qf_loss.backward()
        self.q_optimizer.step()

        return TensorDict(
            {
                "loss/qf_loss": qf_loss.detach(),
                **{
                    f"reward/p{str(k).zfill(2)}": batch["rewards"].quantile(q=(k / 100.))
                    for k in [1, 5, 10, 50, 90, 95, 99]
                },
            },
        )

    @beartype
    def update_actor(self, batch: TensorDict) -> TensorDict:

        self.actor_optimizer.zero_grad()

        if self.hps.prefer_td3_over_sac:
            # using TD3
            action_from_actor = self.actor(batch["observations"])
        else:
            # using SAC
            action_from_actor, state_log_pi, _ = self.actor.get_action(
                batch["observations"]).values()
            # here, there are two gradient pathways: the reparam trick makes the sampling process
            # differentiable (pathwise derivative), and logp is a score function gradient estimator
            # intuition: aren't they competing and therefore messing up with each other's compute
            # graphs? to understand what happens, write down the closed form of the Normal's logp
            # (or see this formula in nets.py) and replace x by mean + eps * std
            # it shows that with both of these gradient pathways, the mean receives no gradient
            # only the std receives some (they cancel out)
            # moreover, if only the std receives gradient, we can expect subpar results if this std
            # is state independent
            # this can be observed here, and has been noted in openai/spinningup
            # in native PyTorch, it is equivalent to using `log_prob` on a sample from `rsample`
            # note also that detaching the action in the logp (using `sample`, and not `rsample`)
            # yields to poor results, showing how allowing for non-zero gradients for the mean
            # can have a destructive effect, and that is why SAC does not allow them to flow.

        if self.hps.prefer_td3_over_sac:
            qf_pi = torch.vmap(self.batched_qf, (0, None, None))(
                self.qnet_params.data, batch["observations"], action_from_actor)
            min_qf_pi = qf_pi[0]
            actor_loss = -min_qf_pi
        else:
            qf_pi = torch.vmap(self.batched_qf, (0, None, None))(
                self.qnet_params.data, batch["observations"], action_from_actor)
            min_qf_pi = qf_pi.min(0).values
            actor_loss = (self.alpha.detach() * state_log_pi) - min_qf_pi
        actor_loss = actor_loss.mean()

        actor_loss.backward()
        if self.hps.clip_norm > 0:
            clip_grad.clip_grad_norm_(self.actor.parameters(), self.hps.clip_norm)
        self.actor_optimizer.step()

        if self.hps.prefer_td3_over_sac:
            return TensorDict(
                {
                    "loss/actor_loss": actor_loss.detach(),
                },
            )

        if self.hps.autotune:
            self.alpha_optimizer.zero_grad()
            with torch.no_grad():
                _, state_log_pi, _ = self.actor.get_action(
                    batch["observations"]).values()
            alpha_loss = (self.alpha * (-state_log_pi - self.targ_ent).detach()).mean()  # alpha

            alpha_loss.backward()
            self.alpha_optimizer.step()

            return TensorDict(
                {
                    "loss/actor_loss": actor_loss.detach(),
                    "loss/alpha_loss": alpha_loss.detach(),
                    "vitals/alpha": self.alpha.detach(),
                },
            )

        return TensorDict(
            {
                "loss/actor_loss": actor_loss.detach(),
                "vitals/alpha": self.alpha.detach(),
            },
        )

    @beartype
    def update_reward(self, p_batch: TensorDict, e_batch: TensorDict) -> TensorDict:
        """Take one gradient step on the learned reward model.

        The inputs of the reward model are read from both batches according to
        `hps.input_mode`, then a method-specific loss is built from `hps.method`
        and minimized with `self.reward_optimizer`.

        Args:
            p_batch: Batch of policy (agent) transitions, labelled "fake" by the
                discriminator-based methods.
            e_batch: Batch of expert transitions, labelled "real".

        Returns:
            A TensorDict of detached loss terms; empty for "pwil", whose reward
            is not learned.

        Raises:
            ValueError: If `hps.input_mode` or `hps.method` is not supported.
        """

        # NOTE(review): `__init__` creates `reward_optimizer` only for methods
        # outside {"pwil", "bc", "random"}; with "bc" or "random" the call below
        # would fail on the missing attribute. Presumably callers never update
        # the reward for those methods; confirm.
        if self.hps.method != "pwil":
            self.reward_optimizer.zero_grad()

        # first input is always the observation; the second depends on the mode
        p_input_a = p_batch["observations"]
        e_input_a = e_batch["observations"]
        match self.hps.input_mode:
            case "ss":
                p_input_b = p_batch["next_observations"]
                e_input_b = e_batch["next_observations"]
            case "sa":
                p_input_b = p_batch["actions"]
                e_input_b = e_batch["actions"]
            case "s":
                p_input_b = None
                e_input_b = None
            case _:
                raise ValueError("invalid input mode")

        match self.hps.method:

            case "ngt":

                p_inputs = (p_input_a, p_input_b)
                e_inputs = (e_input_a, e_input_b)

                # element-wise discrepancy between the trainable predictor and the prior
                p_loss_ = self.rnet_criterion["p"](
                    self.predictor(*p_inputs), self.prior(*p_inputs))
                e_loss_ = self.rnet_criterion["e"](
                    self.predictor(*e_inputs), self.prior(*e_inputs))
                # N.B.: the scores above are not reduced

                # use only the desired proportion of experience per update
                p_loss = p_loss_.mean(dim=-1)
                e_loss = e_loss_.mean(dim=-1)
                # uniform draws with the shape of the per-sample losses
                p_mask = p_loss.clone().detach().uniform_()
                e_mask = e_loss.clone().detach().uniform_()

                # keep a sample when its draw falls below the proportion
                p_mask = (p_mask < self.hps.p_proportion_of_exp_per_update).float()
                e_mask = (e_mask < self.hps.e_proportion_of_exp_per_update).float()
                # masked mean; the denominator is floored at one to avoid dividing by zero
                p_loss = (p_mask * p_loss).sum() / torch.max(self.constant_one, p_mask.sum())
                e_loss = (e_mask * e_loss).sum() / torch.max(self.constant_one, e_mask.sum())

                # squeeze to get 0-dim tensors
                p_loss = p_loss.squeeze()
                e_loss = e_loss.squeeze()

                losses = {
                    "loss/p_loss": p_loss.detach(),
                    "loss/e_loss": e_loss.detach(),
                }

                # gradient descent on expert data; gradient ascent on policy data
                reward_loss = e_loss - (self.hps.advers_p_ascent_scale * p_loss)

            case "samdac":
                # compute scores
                p_scores = self.discriminator(p_input_a, p_input_b)
                e_scores = self.discriminator(e_input_a, e_input_b)

                # entropy loss
                scores, _ = pack([p_scores, e_scores], "* d")  # concat along the batch dim, d is 1
                # cross-entropy of the predictions against themselves
                entropy = ff.binary_cross_entropy_with_logits(
                    input=scores,
                    target=torch.sigmoid(scores),
                )
                # negative sign: minimizing the total loss increases this entropy
                entropy_loss = -self.hps.ent_reg_scale * entropy

                # create labels
                fake_labels = 0. * torch.ones_like(p_scores)
                real_labels = 1. * torch.ones_like(e_scores)

                # apply label smoothing to real labels (one-sided label smoothing)
                # (the real labels are redrawn uniformly in [1 - offset, 1 + offset])
                if (offset := self.hps.label_smooth) != 0:
                    real_labels.uniform_(1. - offset, 1. + offset)
                    logger.debug("applied one-sided label smoothing")

                # binary classification
                p_loss = ff.binary_cross_entropy_with_logits(
                    input=p_scores,
                    target=fake_labels,
                )
                e_loss = ff.binary_cross_entropy_with_logits(
                    input=e_scores,
                    target=real_labels,
                )
                p_e_loss = p_loss + e_loss

                # sum losses
                reward_loss = p_e_loss + entropy_loss

                losses = {
                    "loss/entropy_loss": entropy_loss.detach(),
                    "loss/p_loss": p_loss.detach(),
                    "loss/e_loss": e_loss.detach(),
                    "loss/p_e_loss": p_e_loss.detach(),
                }

                if self.hps.grad_pen:

                    # add gradient penalty to loss
                    grad_pen = self.grad_pen(p_input_a, p_input_b, e_input_a, e_input_b)
                    reward_loss += (self.hps.grad_pen_scale * grad_pen)
                    losses.update(
                        {
                            "loss/grad_pen": grad_pen.detach(),
                        },
                    )

            case "diffail":
                # NOTE(review): with input mode "s" the second inputs are None and
                # the concatenations below would fail; presumably this method is
                # never combined with that mode; confirm.
                p_inputs = torch.cat([p_input_a, p_input_b], dim=1)
                e_inputs = torch.cat([e_input_a, e_input_b], dim=1)
                p_scores = self.discriminator.loss(p_inputs, disc_ddpm=True)
                e_scores = self.discriminator.loss(e_inputs, disc_ddpm=True)
                p_scores = rearrange(p_scores, "b -> b 1")
                e_scores = rearrange(e_scores, "b -> b 1")

                # create labels
                fake_labels = 0. * torch.ones_like(p_scores)
                real_labels = 1. * torch.ones_like(e_scores)

                # binary classification
                # (plain BCE here, not the logits variant: the scores are
                # presumably already in [0, 1]; verify against the discriminator)
                p_loss = ff.binary_cross_entropy(
                    input=p_scores,
                    target=fake_labels,
                )
                e_loss = ff.binary_cross_entropy(
                    input=e_scores,
                    target=real_labels,
                )
                p_e_loss = p_loss + e_loss

                # compute gradient penalty
                grad_pen = self.grad_pen(p_input_a, p_input_b, e_input_a, e_input_b)

                reward_loss = p_e_loss + (self.hps.grad_pen_scale * grad_pen)

                losses = {
                    "loss/p_e_loss": p_e_loss.detach(),
                    "loss/grad_pen": grad_pen.detach(),
                }

            case "mmd-samdac":
                p_scores = self.discriminator(p_input_a, p_input_b)
                e_scores = self.discriminator(e_input_a, e_input_b)

                # compute MMD loss
                mmd_loss = compute_mmd_loss(p_scores, e_scores, self.hps.mmd_sigma)

                # compute gradient penalty
                grad_pen = self.grad_pen(p_input_a, p_input_b, e_input_a, e_input_b)

                # negative sign: the MMD term is maximized by the reward model
                reward_loss = -mmd_loss + (self.hps.grad_pen_scale * grad_pen)

                losses = {
                    "loss/mmd_loss": mmd_loss.detach(),
                    "loss/grad_pen": grad_pen.detach(),
                }

            case "w-samdac":
                # compute scores
                p_scores = self.discriminator(p_input_a, p_input_b)
                e_scores = self.discriminator(e_input_a, e_input_b)

                # compute the dual EMD distance (== Wasserstein-1)
                dual_emd = e_scores.mean() - p_scores.mean()

                # compute the reward loss as "minus" the W1 loss
                reward_loss = -dual_emd

                # compute gradient penalty
                grad_pen = self.grad_pen(p_input_a, p_input_b, e_input_a, e_input_b)

                reward_loss += (self.hps.grad_pen_scale * grad_pen)

                losses = {
                    "loss/dual_emd": dual_emd.detach(),
                    "loss/grad_pen": grad_pen.detach(),
                }

            case "pwil":
                # the reward is not learned in this method
                return TensorDict({})

            case _:
                raise ValueError("invalid method")

        # a single optimizer step on whichever loss was assembled above
        reward_loss.backward()
        self.reward_optimizer.step()

        return TensorDict(
            {
                **losses,
            },
        )

    @beartype
    def grad_pen(self,
                 p_input_a: torch.Tensor,
                 p_input_b: torch.Tensor,
                 e_input_a: torch.Tensor,
                 e_input_b: torch.Tensor) -> torch.Tensor:
        """Gradient penalty evaluated at random mixes of policy and expert inputs.

        The `a` and `b` parts of each pair are concatenated along the last dim, blended with
        one uniform random coefficient per sample, split back into parts, and fed to the
        network selected by `self.hps.method`. The 2-norm of the gradient of the outputs
        w.r.t. these blended inputs is compared with `self.hps.grad_pen_targ`; the gap is
        squared and averaged. With `self.hps.one_sided_pen`, only norms above the target
        are penalized.

        Args:
            p_input_a: first part of the policy inputs.
            p_input_b: second part of the policy inputs.
            e_input_a: first part of the expert inputs.
            e_input_b: second part of the expert inputs.

        Returns:
            The scalar gradient penalty (still attached to the graph, `create_graph=True`).
        """
        # glue the two parts of each pair together along the last dim
        policy_cat, _ = pack([p_input_a, p_input_b], "b *")
        expert_cat, packed_shapes = pack([e_input_a, e_input_b], "b *")

        # one mixing coefficient per sample, then take the point on the segment
        mix = torch.rand(p_input_a.size(0), 1, device=self.device)
        blended = mix * policy_cat + ((1. - mix) * expert_cat)

        # split back into the original parts and make each of them differentiable
        parts = unpack(blended, packed_shapes, "b *")
        for part in parts:
            part.requires_grad_(requires_grad=True)

        method = self.hps.method
        if method == "ngt":
            wrt = parts
            outputs = self.predictor(*parts)
        elif method == "diffail":
            # here the loss is fed a single concatenated input,
            # and its 1-D output is reshaped into a column
            wrt = torch.cat(parts, dim=1)
            outputs = self.discriminator.loss(wrt, disc_ddpm=True)
            outputs = rearrange(outputs, "b -> b 1")
        else:
            wrt = parts
            outputs = self.discriminator(*parts)

        # differentiate the outputs w.r.t. the blended inputs, keeping the graph around
        # so that the penalty itself can be backpropagated through
        grads = autograd.grad(
            outputs=outputs,
            inputs=wrt,
            grad_outputs=self.grad_outputs,
            create_graph=True,
            retain_graph=True,
        )
        flat_grads, _ = pack(list(grads), "b *")
        norms = flat_grads.norm(2, dim=-1)

        gap = norms - self.hps.grad_pen_targ
        if self.hps.one_sided_pen:
            # only penalize the gradient for having a norm GREATER than the target
            gap = torch.max(torch.zeros_like(norms), gap)
        # otherwise the norm is penalized for being LOWER OR GREATER than the target

        return gap.pow(2).mean()

    @beartype
    def compute_reward(self,
                       state: torch.Tensor,
                       action: torch.Tensor,
                       next_state: torch.Tensor) -> torch.Tensor:
        """Compute the reward of a batch of transitions with the method in `self.hps.method`.

        The network always receives `state` as first input; the second input is picked by
        `self.hps.input_mode`: `next_state` ("ss"), `action` ("sa"), or nothing ("s").
        The network involved is put in eval mode for the forward pass (done without
        gradient tracking) and switched back to train mode before returning.

        Args:
            state: batch of states.
            action: batch of actions.
            next_state: batch of next states.

        Returns:
            The reward tensor.

        Raises:
            ValueError: if the input mode or the method is unknown, or if the method is
                "pwil" (which is not supposed to go through this function).
        """
        mode = self.hps.input_mode
        if mode == "ss":
            second = next_state
        elif mode == "sa":
            second = action
        elif mode == "s":
            second = None
        else:
            raise ValueError("invalid input mode")
        first = state

        method = self.hps.method

        if method == "ngt":
            self.predictor.eval()

            # the raw signal is the per-sample discrepancy between predictor and prior
            pair = (first, second)
            with torch.no_grad():
                gap = self.rnet_criterion["e"](
                    self.predictor(*pair), self.prior(*pair))
            gap = gap.mean(dim=-1, keepdim=True)

            # step 1: flip the sign + blow up to have non-teeny spread
            if self.hps.survivorship:
                blown_up = torch.exp(torch.abs(gap) / self.hps.temperature) - 1.
                reward = -torch.sign(gap) * blown_up
            else:
                reward = (-gap / self.hps.temperature).exp()

            # step 2: rescale by a percentile range (robust to outliers)
            if self.hps.tighter_percentile_range:
                lo_q, hi_q = 0.10, 0.90
            else:
                lo_q, hi_q = 0.05, 0.95
            perc_lo = reward.quantile(q=lo_q)
            perc_hi = reward.quantile(q=hi_q)
            reward /= ((perc_hi - perc_lo) + 1e-8)

            # step 3: recenter on a percentile mark of the rescaled reward
            # (upper mark with survivorship, lower mark without)
            reward -= reward.quantile(q=hi_q if self.hps.survivorship else lo_q)

            # step 4 (optional): stretch reward values, sign(r) * (2^|r| - 1)
            if self.hps.stretch_with_symexp:
                stretched = torch.exp(
                    torch.abs(reward) * torch.log(torch.tensor(2.0))) - 1.
                reward = torch.sign(reward) * stretched

            self.predictor.train()
            return reward

        if method == "samdac":
            self.discriminator.eval()

            with torch.no_grad():
                score = self.discriminator(first, second)

            # counterpart of GAN's minimax (a.k.a. "saturating") loss
            # numerics: 0 for non-expert-like states, goes to +inf for expert-like states
            # suits envs whose trajs are cut off on bad (non-expert-like) behavior
            # e.g. walking simulations that get cut off when the robot falls over
            minimax_reward = -torch.log(1. - torch.sigmoid(score) + 1e-8)
            if self.hps.minimax_only:
                reward = minimax_reward
            else:
                # add the counterpart of GAN's non-saturating loss
                # (recommended in the original GAN paper and later in Fedus et al. 2017)
                # numerics: 0 for expert-like states, goes to -inf for non-expert-like states
                # suits envs whose trajs are cut off on good (expert-like) behavior
                # e.g. mountain car, which gets cut off when the car reaches the destination
                # the sum of both is as in AIRL (Fu et al. 2018)
                # numerics: might be better might be way worse
                reward = ff.logsigmoid(score) + minimax_reward

            self.discriminator.train()
            return reward

        if method == "diffail":
            self.discriminator.eval()

            # this one expects a single input: concatenate along the last dim when needed
            x = first if second is None else pack([first, second], "b *")[0]
            with torch.no_grad():
                logits = self.discriminator.calc_reward(x)
            reward = - torch.log(1 - logits + 1e-8)

            self.discriminator.train()
            return reward

        if method in ("mmd-samdac", "w-samdac"):
            self.discriminator.eval()

            # the raw score is used as reward
            with torch.no_grad():
                reward = self.discriminator(first, second)

            self.discriminator.train()
            return reward

        if method == "pwil":
            raise ValueError("should not be here!")

        raise ValueError("invalid method")

    @beartype
    def update_targ_nets(self):
        """Move the target networks towards the online ones, in place.

        The update runs on every call when `self.hps.prefer_td3_over_sac` is set; otherwise
        only when `self.qnet_updates_so_far` is a multiple of
        `self.hps.crit_targ_update_freq`. The target actor is only updated under TD3
        (SAC does not use a target actor).
        """
        using_td3 = self.hps.prefer_td3_over_sac
        if not using_td3 and (
                self.qnet_updates_so_far % self.hps.crit_targ_update_freq != 0):
            # not on an update step: leave the targets untouched
            return

        # lerp is defined as x' = x + w (y-x), which is equivalent to x' = (1-w) x + w y
        weight = self.hps.polyak
        self.qnet_target.lerp_(self.qnet_params.data, weight)
        if using_td3:
            self.actor_target.lerp_(self.actor_params.data, weight)

    @beartype
    def behavioral_cloning(self, batch: TensorDict) -> TensorDict:
        """Take one behavioral cloning step on the actor.

        The actor is regressed (mean squared error) onto `batch["actions"]` from
        `batch["observations"]`, and the actor optimizer is stepped once.

        Args:
            batch: must hold the "observations" and "actions" entries.

        Returns:
            A TensorDict with the detached loss under "loss/bc_loss".
        """
        self.actor_optimizer.zero_grad()

        observations = batch["observations"]
        if self.hps.prefer_td3_over_sac:
            # deterministic actor: the forward pass directly gives the action
            predicted = self.actor(observations)
        else:
            # stochastic actor: the action is the first of the three returned entries
            predicted, _, _ = self.actor.get_action(observations).values()

        bc_loss = ff.mse_loss(predicted, batch["actions"])
        bc_loss.backward()
        self.actor_optimizer.step()

        return TensorDict({"loss/bc_loss": bc_loss.detach()})

    @beartype
    def save(self, path: Path, sfx: Optional[str] = None):
        """Save the agent to disk, and to wandb servers for the "best" checkpoint.

        Args:
            path: directory in which the checkpoint file is written.
            sfx: optional qualifier. When given, the file is named `ckpt_{sfx}.pth`;
                otherwise it is named `.ckpt_{timesteps_so_far}ts.pth`
                (design choice: hide the ckpt saved without an extra qualifier).
        """
        parent = path
        if sfx is not None:
            fname, label = f"ckpt_{sfx}", sfx
        else:
            # no qualifier: identify the checkpoint by its timestep count,
            # in the file name and in the log message alike
            label = f"{self.timesteps_so_far}ts"
            fname = f".ckpt_{label}"
        ckpt_path = parent / f"{fname}.pth"

        checkpoint = {
            "hps": self.hps,  # handy for archeology
            "timesteps_so_far": self.timesteps_so_far,
            # and now the state_dict objects
            "actor": self.actor.state_dict(),
            "qnet1": self.qnet1.state_dict(),
            "qnet2": self.qnet2.state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "q_optimizer": self.q_optimizer.state_dict(),
        }
        # save checkpoint to filesystem
        torch.save(checkpoint, ckpt_path)
        logger.info(f"{label} model saved to disk")

        if sfx == "best":
            # upload the model to wandb servers; `base_path` is the directory
            # the uploaded file path is made relative to
            wandb.save(str(ckpt_path), base_path=str(parent))
            logger.warn("model saved to wandb")

    @beartype
    def load_from_disk(self, path: Path):
        """Load the checkpoint stored at `path` into this agent.

        Restores the actor, both Q-networks and both optimizers, plus the timestep counter
        when the checkpoint carries one.

        Args:
            path: checkpoint file, with the entries written by `save`.

        Raises:
            KeyError: if one of the state_dict entries is missing from the checkpoint.
        """
        # `map_location` remaps the stored tensors onto this agent's device, so that a
        # checkpoint written on an accelerator can still be loaded where none is available
        # NOTE(security): `weights_only=False` unpickles arbitrary objects (the checkpoints
        # written by `save` embed the `hps` config object) -- only load trusted files
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        if "timesteps_so_far" in checkpoint:
            self.timesteps_so_far = checkpoint["timesteps_so_far"]
        # the "strict" argument of `load_state_dict` is True by default
        for name in ("actor", "qnet1", "qnet2", "actor_optimizer", "q_optimizer"):
            getattr(self, name).load_state_dict(checkpoint[name])

    @staticmethod
    @beartype
    def compare_dictconfigs(
        dictconfig1: MutableMapping[Any, Any],
        dictconfig2: MutableMapping[Any, Any],
    ) -> dict[str, dict[Any, Any]]:
        """Compare two DictConfig objects of depth=1 and return the differences.

        Args:
            dictconfig1: the reference config.
            dictconfig2: the config compared against the reference.

        Returns:
            A dictionary with the keys "added", "removed", and "changed", in this order:
            "added" maps the keys only present in `dictconfig2` to their value,
            "removed" maps the keys only present in `dictconfig1` to their value,
            "changed" maps the shared keys whose values differ to a
            `{"from": <value in dictconfig1>, "to": <value in dictconfig2>}` dict.

        Raises:
            AssertionError: if either argument is not a DictConfig.
        """
        # N.B.: the values are typed `Any` on purpose: config values can be floats, None,
        # nested containers, etc., and a narrower return hint would let the runtime
        # type-checker reject perfectly valid diffs
        assert isinstance(dictconfig1, DictConfig)
        assert isinstance(dictconfig2, DictConfig)

        keys1 = set(dictconfig1.keys())
        keys2 = set(dictconfig2.keys())

        return {
            "added": {key: dictconfig2[key] for key in keys2 - keys1},
            "removed": {key: dictconfig1[key] for key in keys1 - keys2},
            "changed": {
                key: {"from": dictconfig1[key], "to": dictconfig2[key]}
                for key in keys1 & keys2
                if dictconfig1[key] != dictconfig2[key]
            },
        }

    @beartype
    def load(self, wandb_run_path: str, model_name: str = "ckpt_best.pth"):
        """Fetch a checkpoint file from a wandb run and load it into this agent.

        The config stored with the run is first compared with the current one, and the
        differences are logged; they do not prevent the loading.

        Args:
            wandb_run_path: path of the wandb run to fetch the file from.
            model_name: name of the checkpoint file within that run.
        """
        run = wandb.Api().run(wandb_run_path)

        # log how the current cfg differs from the cfg of the stored model
        stored_cfg: DictConfig = OmegaConf.create(run.config)
        # N.B.: dicts preserve the insertion order (Python 3.7 and later),
        # hence the values come out as added, removed, changed
        added, removed, changed = self.compare_dictconfigs(stored_cfg, self.hps).values()
        logger.warn(f"added  : {added}")
        logger.warn(f"removed: {removed}")
        logger.warn(f"changed: {changed}")

        # the file only needs to live on disk for the time it takes to load it
        with tempfile.TemporaryDirectory() as scratch_dir:
            # download the model file from wandb servers
            run.file(model_name).download(root=scratch_dir, replace=True)
            logger.warn("model downloaded from wandb to disk")
            # load the agent stored in this file
            self.load_from_disk(Path(scratch_dir, model_name))
            logger.warn("model loaded")
