import argparse
import copy
import gymnasium as gym
import logging
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple, deque
from typing import Tuple, List


from .utils import ReplayMemoryBase, ReplayMemory, StochasticTrainingAlgorithm
from .sac import arguments as sac_arguments
from . import utils
from ..delayed_mdp import SimulatedInteractionLayer, StationaryAction, ExtendedState
from latency_env.modules.model import NNPredictiveModel

import latency_env.misc.argparser_types as at
from latency_env.misc.argparser import Argument as Arg
from latency_env.misc.argparser import ArgumentList as ArgList

LOG = logging.getLogger(__name__)
LOG.addHandler(logging.NullHandler())


@torch.no_grad()
def _pad_dim(dim, data, target_len):
    """
    Pads the first dimension such of data such that it is never less than
    the specified target length.
      data: [L*, _]
    """
    assert dim < len(data.shape)
    if data.shape[dim] == target_len:
        return data

    # Note that the F.pad is specified backwards, so that the first two
    # elements of pad actually specify how much that the last dimension
    # should be padded by. Since we only want to pad the first dimension,
    # we first need to say that all dimensions below it should not be
    # padded.
    padpfx = tuple(0 for _ in range(2 * (len(data.shape) - (dim + 1))))
    padlen = max(0, target_len - data.shape[dim])
    padsfx = tuple(0 for _ in range(2 * dim))

    padding = padpfx + (0, padlen) + padsfx

    return F.pad(data, pad=padding)


class MBPACLatentReplayMemory(ReplayMemory):
    """
    Replay buffer for MBPAC.

    In addition to the regular transition stored by ``ReplayMemory``, it keeps
    what is needed to regenerate the applied action with the predictive model:
    the origin observation, the memorized action sequence, the delay, the delay
    shift and the length of the action sequence. This is stored once for the
    current state and once (``next_*``) for the next state.

    The memorized action sequences vary in length per entry. They are kept as
    raw tensors in numpy object arrays and zero-padded to a common length when
    a batch is fetched with a numpy index array.
    """
    ParameterOrder = ReplayMemory.ParameterOrder + [
        "obs_state",      "mem_actions",      "delay",      "delayshift",      "mem_actions_len",
        "next_obs_state", "next_mem_actions", "next_delay", "next_delayshift", "next_mem_actions_len",
    ]

    def __init__(self, env: SimulatedInteractionLayer, *args, **kwargs):
        super().__init__(*args, env=env, **kwargs)
        self.rb["obs_state"] = torch.zeros((self.maxsize,) + self.s_shape, requires_grad=False, dtype=self.s_type)
        self.rb["delay"] = torch.zeros((self.maxsize,), requires_grad=False, dtype=torch.long)
        self.rb["delayshift"] = torch.zeros((self.maxsize,), requires_grad=False, dtype=torch.long)
        self.rb["next_obs_state"] = torch.zeros((self.maxsize,) + self.s_shape, requires_grad=False, dtype=self.s_type)
        self.rb["next_delay"] = torch.zeros((self.maxsize,), requires_grad=False, dtype=torch.long)
        self.rb["next_delayshift"] = torch.zeros((self.maxsize,), requires_grad=False, dtype=torch.long)

        # These have varying sizes, we need to store these separately as raw objects.
        self.rb["mem_actions"] = np.empty((self.maxsize,), dtype=object)
        self.rb["next_mem_actions"] = np.empty((self.maxsize,), dtype=object)

        self.rb["mem_actions_len"] = torch.zeros((self.maxsize,), requires_grad=False, dtype=torch.long)
        self.rb["next_mem_actions_len"] = torch.zeros((self.maxsize,), requires_grad=False, dtype=torch.long)

        # The object-array entries need custom (de)serialization hooks.
        self.get_hooks["mem_actions"] = lambda idxs: self._get_memact_rb("mem_actions", idxs)
        self.get_hooks["next_mem_actions"] = lambda idxs: self._get_memact_rb("next_mem_actions", idxs)
        self.add_hooks["mem_actions"] = lambda data, idxs: self._add_memact_rb("mem_actions", data, idxs)
        self.add_hooks["next_mem_actions"] = lambda data, idxs: self._add_memact_rb("next_mem_actions", data, idxs)

    def _get_memact_rb(self, name, idxs):
        """
        Fetch memorized action sequences stored under ``name``.

        With a numpy index array (usually sampled from np.random.randint),
        the selected [L*, A] tensors are zero-padded along their first
        dimension to the longest selected length and stacked into
        [N, max L*, A]. Any other kind of index is forwarded to the underlying
        object array unchanged.
        """
        if not isinstance(idxs, np.ndarray):
            return self.rb[name][idxs]

        selected = self.rb[name][idxs]
        # A plain max over the lengths. np.vectorize is only a convenience
        # loop, and it also evaluates its function once more to infer the
        # output dtype.
        maxlen = max(t.shape[0] for t in selected)

        return torch.stack([_pad_dim(0, t, maxlen) for t in selected])

    def _add_memact_rb(self, name, data, idxs):
        """Store a memorized action sequence ([L*, A]) as a tensor of ``a_type``."""
        # Just make sure that it is a tensor, then we're happy
        data = torch.as_tensor(data, dtype=self.a_type) # [L*, A]
        assert data.shape[1:] == self.a_shape
        self.rb[name][idxs] = data


def arguments_hook(args):
    """
    Post-parse hook: when no model minibatch size was given (it is None),
    fall back to the regular batch size.
    """
    if args.model_batch_size is not None:
        return
    args.model_batch_size = args.batch_size

# Command-line arguments for MBPAC: the SAC arguments plus the options for the
# predictive model and for how actions are predicted and dispatched.
arguments = sac_arguments + ArgList(
    _hooks=[arguments_hook],
    lr_model = Arg("--lr-model", default=1e-5, type=at.posfloat,
                   help="Model learning rate."),
    wd_model = Arg("--wd-model", default=0.0, type=at.nonnegfloat,
                   help="Model weight decay (L2 loss)."),
    optim_model = Arg("--optim-model", default="adam", type=str.lower,
                      help="Model network optimizer.", choices=utils.OPTIMIZERS),
    model_train_window = Arg("--mtw", "--model-train-window", default=16, type=at.posint,
                             help="Model training window, aka the horizon."),
    model_batch_size = Arg("--model-batch-size", default=None, type=at.posint,
                           _auto_default=False,
                           help="Size of minibatches when performing model network "
                                "updates (will default to regular --batch-size "
                                "if unset)."),
    model_clamp_output = Arg("--model-clamp-output", action=argparse.BooleanOptionalAction, default=False,
                             help="Clamp the outputs from the model and the environment."),
    predict_deterministically = Arg("--predict-deterministically", action=argparse.BooleanOptionalAction, default=False,
                                    help="Query the policy deterministically when generating the predicted action packets."),
    delay_memorization = Arg("--delay-memorization", action=argparse.BooleanOptionalAction, default=False,
                             help="Memorize previously dispatched actions to be used for predictions, rather than the reported contents of the action buffer."),
    stall_interaction = Arg("--stall-interaction", action=argparse.BooleanOptionalAction, default=False,
                            help="Stall the interaction with the interaction layer, waiting for acknowledgment that our action has been received before sending the next one."),
)


class MBPAC(StochasticTrainingAlgorithm):
    """
    Model-Based Predictive Actor-Critic (MBPAC)

    (Would like to rename this MDAPH: Model-based Distribution Agent on a Predictive Horizon
     Pronounced as "emdaff". But since MBPAC is in the article we will keep it.)

    Learns a predictive model to estimate future state information, which is
    fed as input to the policy.

    Actions are sent to the interaction layer as packets ``(t, M)``. ``M`` is
    an [L, L, A] array (L = interaction-layer horizon). ``M[d - 1]`` is the
    action sequence generated for an assumed delay of ``d`` steps, and
    ``M[:, h]`` is the ``h``-th action along the predicted horizon.
    """
    def __init__(self, pi: nn.Module,
                       q1: nn.Module,
                       q2: nn.Module,
                       env: SimulatedInteractionLayer,
                       model: NNPredictiveModel,
                       obs_metric: "pytorch differentiable function",
                       args: "arguments"):
        """
        pi : nn.Module
          Policy network, called on embeddings produced by ``model``.
        q1, q2 : nn.Module
          Twin Q networks (``args.use_double_q`` must be set).
        env : SimulatedInteractionLayer
          Interaction layer providing extended states.
        model : NNPredictiveModel
          Predictive model used to embed observations and action sequences.
        obs_metric : function
          A function that computes the distance between two observations. It
          is passed as ``distance`` to the model loss, and its ``clamp`` method
          is used on observations when ``--model-clamp-output`` is set.
        args : argparse.Namespace
          Parsed ``arguments``.

        Raises ValueError if both delay memorization and interaction stalling
        are enabled.
        """
        super().__init__(env=env, args=args)
        self.env = env
        self.p = args
        assert self.p.use_double_q

        if sum([int(self.p.delay_memorization), int(self.p.stall_interaction)]) > 1:
            raise ValueError("Can only specify one of delay memorization and interaction layer stalling.")

        assert isinstance(self.env, SimulatedInteractionLayer), "Expecting back observation packets from an interaction layer"

        # Setup networks
        self.pi = pi
        self.pi_targ = copy.deepcopy(self.pi)
        self.q1 = q1
        self.q1_targ = copy.deepcopy(self.q1)
        self.q2 = q2
        self.q2_targ = copy.deepcopy(self.q2)
        self.model = model
        self.obs_metric = obs_metric

        # Alpha parameter
        self.log_alpha = nn.Parameter(torch.tensor(
            np.log(self.p.initial_alpha),
            dtype=torch.float32,
            device=self.p.device,
        ))

        # Move everything to proper device before creating optimizers
        self.pi = self.pi.to(self.device)
        self.pi_targ = self.pi_targ.to(self.device)
        self.q1 = self.q1.to(self.device)
        self.q1_targ = self.q1_targ.to(self.device)
        self.q2 = self.q2.to(self.device)
        self.q2_targ = self.q2_targ.to(self.device)
        self.model = self.model.to(self.device)

        # Setup optimizers and LR schedulers
        self.pi_opt = utils.make_optimizer(self.pi, optim=args.optim_pi, lr=args.lr_pi)
        self.q_opt = utils.make_optimizer(self.q1, self.q2, optim=args.optim_q, lr=args.lr_q)
        self.alpha_opt = utils.make_optimizer(self.log_alpha, optim=args.optim_alpha, lr=args.lr_alpha)
        self.model_opt = utils.make_optimizer(self.model, optim=args.optim_model, lr=args.lr_model, weight_decay=args.wd_model)

        # Setup replay buffer
        self.replay = MBPACLatentReplayMemory(
            maxsize=self.p.replay_size,
            env=self.env,
        )

        # State on how many steps since we last performed the polyak updates
        self.steps_since_update = self.p.update_delay
        self.updates_since_polyak = self.p.polyak_delay

    def evalcopy(self):
        """
        Return a blob with CPU deep copies of the policy, Q networks and model,
        plus the environment, metric and arguments. ``from_evalcopy`` can
        rebuild an agent from it.
        """
        blob = super().evalcopy()
        blob["pi"] = copy.deepcopy(self.pi).to("cpu")
        blob["q1"] = copy.deepcopy(self.q1).to("cpu")
        blob["q2"] = copy.deepcopy(self.q2).to("cpu")
        blob["env"] = self.env
        blob["model"] = copy.deepcopy(self.model).to("cpu")
        blob["obs_metric"] = self.obs_metric
        blob["p"] = self.p
        return blob

    @classmethod
    def from_evalcopy(cls, blob, device=None):
        """
        Rebuild an agent from an ``evalcopy`` blob.

        If ``device`` is given, the blob and its arguments are shallow-copied
        before the device is overridden, so the original blob is not modified.
        """
        if device is not None:
            blob = copy.copy(blob)
            blob["p"] = copy.copy(blob["p"])
            blob["p"].device = device
        return cls(
            pi=blob["pi"],
            q1=blob["q1"],
            q2=blob["q2"],
            env=blob["env"],
            model=blob["model"],
            obs_metric=blob["obs_metric"],
            args=blob["p"],
        )

    def modules(self):
        """All networks owned by this agent, including the target copies."""
        return (self.pi, self.pi_targ, self.q1, self.q1_targ, self.q2, self.q2_targ, self.model)

    @torch.no_grad()
    def clamp(self, x):
        """
        Clamps the extended state x according to the metric.

        When ``--model-clamp-output`` is set, ``x.s_obs`` is replaced by
        ``obs_metric.clamp(s_obs)`` (as a numpy array). Otherwise ``x`` is
        returned with its observation unchanged.
        """
        if self.p.model_clamp_output:
            cs = self.obs_metric.clamp(torch.as_tensor(x.s_obs)).numpy()
        else:
            cs = x.s_obs
        return x._replace(s_obs=cs)

    @torch.no_grad()
    def memorized_action_predictor(
            self,
            actpkt_stack: List[Tuple["t", "Actions"]],
            delay: int,
            ):
        """
        Constructs the memorized trajectory of guesses. This is a direct
        implementation of the memorized action predictor from the paper.

        Entry k of the result (k = 0 .. delay-1) is the first action of row
        ``delay - 1`` of the packet at stack index ``k - delay``. The entries
        therefore walk over the last ``delay`` packets, oldest first. Positions
        that would lie before the start of the stack use the default packet at
        index 0 (t = -1).

        actpkt_stack : list of (t, M)
          Action packets sent so far, oldest first. Index 0 must be the
          default packet with t == -1.
        delay : int
          Assumed delay, must be > 0.

        Returns a numpy array of shape [delay, A].
        """
        assert delay > 0

        actlist = []
        for i in reversed(range(delay)):
            if i < len(actpkt_stack):
                # i+1 rationale: [-delay,-(delay-1),...,-1]. The `<` bound
                # keeps -(i+1) a valid index (it reaches index 0 at most).
                (_, M) = actpkt_stack[-(i + 1)]
            else:
                # This uses the default action with t = -1
                (t, M) = actpkt_stack[0]
                assert t == -1

            # delay-1 and 0 is because of 0-indexing. Would be delay,1 if 1-indexing as in the paper.
            actlist.append(M[delay-1][0])

        return np.stack(actlist) # [delay, A]

    @torch.no_grad()
    def sample_action(self, state_action_info: Tuple["ExtendedState", "ActionTrace"],
                      random=False,
                      deterministic=False,
                      with_raw_sample=False,
                      with_logprob=False,
                      with_info=False,
                      **kwargs):
        """
        Generate the action packet to send for the current extended state.

        state_action_info : (x, actpkt_stack)
          ``x`` is the current extended state. ``actpkt_stack`` is the list of
          action packets sent so far, with the default packet (t = -1) at
          index 0. The stack is only read when delay memorization is enabled.
        random : bool
          Fill the packet with random actions instead of querying the policy.

        Returns ``StationaryAction`` when interaction stalling is enabled and
        ``x.delayshift != 0``. Otherwise returns the packet ``(x.t, M)``, where
        ``M`` is an [L, L, A] numpy array and ``M[d - 1]`` holds the actions
        generated for an assumed delay of ``d`` steps.

        NOTE(review): ``deterministic`` is not used here; the policy is
        queried with ``self.p.predict_deterministically`` instead. Raw-sample,
        log-prob and info outputs are not supported.
        """
        x, actpkt_stack = state_action_info
        s = x.s_obs

        # actpkt_stack is only needed for the delay-memorization predictor.

        s_obs = torch.as_tensor(x.s_obs, dtype=torch.float32, device=self.device).unsqueeze(0) # [1, |S|]
        a_mem = torch.as_tensor(x.a_mem, dtype=torch.float32, device=self.device).unsqueeze(0) # [1, L, |A|]

        (_, S) = s_obs.shape
        (_, L, A) = a_mem.shape
        assert L > 0

        assert not with_raw_sample, "Raw sample not yet supported for PAID"
        assert not with_logprob, "Log prob not yet supported for PAID"
        assert not with_info, "With info not yet supported for PAID"

        if self.p.stall_interaction:
            # If we stall the interaction, then we are only generating actions
            # for when we know that the action buffer has fresh information,
            # i.e. when it hasn't shifted.
            if x.delayshift != 0:
                return StationaryAction

        if random:
            return (
                x.t,
                np.array([
                    [
                        StochasticTrainingAlgorithm.sample_action(self, s,
                                                                  random=True,
                                                                  with_raw_sample=False,
                                                                  with_logprob=False,
                                                                  with_info=False)
                        for h in range(L)
                    ]
                    for d in range(L)
                ]) # [L, L, A]
            )

        if self.p.delay_memorization:
            # Collect the memorized predictions for each delay in a matrix
            a_pred = torch.stack([
                _pad_dim(0,
                    torch.as_tensor(
                        self.memorized_action_predictor(actpkt_stack, d),
                        dtype=torch.float32,
                        device=self.device,
                    ), # [d, A]
                    target_len=L,
                ) # [L, A]
                for d in range(1, L+1)
            ]) # [L, L, A]

            # We use the same observation
            s_obs = s_obs.repeat(L, 1)

            assert s_obs.shape == (L, S)
            assert a_pred.shape == (L, L, A)
            embed = self.model.embed_inputs(s_obs, torch.nan_to_num(a_pred), omit_h0=True) # [L, L, H]
            # Row l holds the memorized actions for delay l+1; its embedding
            # after l+1 actions sits at position l (h0 is omitted).
            embed = torch.stack([embed[l, l] for l in range(L)]) # [L, H]
        else:
            embed = self.model.embed_inputs(s_obs, torch.nan_to_num(a_mem), omit_h0=True) # [1, L, H]
            embed = embed.squeeze(0) # (L, H)

        (_, H) = embed.shape
        assert embed.shape == (L, H)

        # (delay, horizon, A): roll every delay hypothesis forward with the
        # policy's own actions, one horizon step per iteration.
        M = torch.zeros((L, L, A), dtype=torch.float32, device=embed.device)
        for l in range(L):
            a_gen, _, _, _ = self.pi(embed, deterministic=self.p.predict_deterministically)
            M[:,l] = a_gen
            embed = self.model.embed_onestep(embed, a_gen)

        action_packet = (x.t, M.cpu().numpy())
        return action_packet

    def evaluate(self, trajectory_length=None, env=None, **kwargs):
        """
        Evaluates the MBPAC policy.

        Runs the base-class evaluation on ``env`` (default: the training
        environment) with trajectories of at most ``trajectory_length`` steps
        (default: ``self.p.max_trajectory``). Observations are clamped, and the
        same action-packet stack as in ``train`` is maintained, so both the
        naive and the delay-memorization predictors can be used.
        """
        if env is None:
            env = self.env
        if trajectory_length is None:
            trajectory_length = self.p.max_trajectory

        # See train() for more details on this
        actpkt_stack = []

        def fn_reset():
            nonlocal actpkt_stack
            x, info = env.extended_reset()
            x = self.clamp(x)

            ap0 = np.stack([x.a_mem] * self.env.horizon)
            actpkt_stack = [(-1, ap0)]

            return ((x, actpkt_stack), info)

        def fn_step(a):
            nonlocal actpkt_stack

            x_next, r, term, trunc, info = env.extended_step(a)
            x_next = self.clamp(x_next)

            actpkt_stack.append(a)

            return ((x_next, actpkt_stack), r, term, trunc, info)

        return super().evaluate(trajectory_length=trajectory_length, env=env,
                                fn_reset=fn_reset, fn_step=fn_step, **kwargs)

    def train(self, n_episodes=None, n_steps=None, follow_policy=True):
        """
        Trains the networks contained within with MBPAC (SAC on model embeddings).

        For every episode yielded by ``self.progress``:
          1. Roll out one trajectory (random actions when ``follow_policy`` is
             False). Each transition is stored in the replay buffer together
             with what is needed to regenerate the applied action.
          2. For every environment step taken, run ``update_iterations``
             minibatch updates (subject to ``update_delay``). These update the
             Q networks, the policy, the temperature and, if it is optimizable,
             the predictive model. A Polyak update runs every ``polyak_delay``
             updates.
          3. Record the mean losses and the episode reward.
        """
        for ep in self.progress(n_episodes=n_episodes, n_steps=n_steps):
            # Collect a trajectory
            with torch.no_grad():
                self.replay.new_trajectory()
                self.set_mode("eval")

                r_total = 0.0
                (x, _), done = self.env.extended_reset(), False
                x = self.clamp(x)
                assert x.a_mem.shape[0] == self.env.horizon

                # The basic action packet to use to represent a_{-1}, a_{-2}, etc.
                # Just full of the default actions.
                ap0 = (-1, np.stack([x.a_mem] * self.env.horizon)) # (Int, [L, L, A])

                # Memorization stacks
                obs_stack = [] # List[X]
                actpkt_stack = [ap0] # List[(Int, Tensor[L, L, A])]

                taken_steps = 0
                for step in self.progress_steps(limit=self.p.max_trajectory):
                    # Check for termination before counting, so taken_steps is
                    # exactly the number of environment steps performed.
                    if done:
                        break
                    taken_steps += 1

                    # Take a step
                    ap = self.sample_action((x, actpkt_stack), random=(not follow_policy))

                    x_next, r, terminated, truncated, _ = self.env.extended_step(ap)
                    x_next = self.clamp(x_next)
                    done = terminated or truncated

                    obs_stack.append((x, r, x_next, terminated))
                    actpkt_stack.append(ap)
                    x = x_next
                    r_total += r

                self.env.close()
                self.set_mode("train")

                def extract_a_mem(x_orig, x_app):
                    # Rebuild the action sequence that was fed to the model when
                    # the action applied in x_app was generated (at x_app.t_origin).
                    # Note: actpkt_stack[0] has t = -1
                    assert actpkt_stack[x_app.t_origin+1][0] == x_app.t_origin

                    if self.p.delay_memorization:
                        # Same stack prefix as sample_action saw at t_origin.
                        a_mem = self.memorized_action_predictor(actpkt_stack[:x_app.t_origin+1], x_app.delay)
                    else:
                        # Naive predictor: Assumes the actions reported by the obs packet to be executed
                        a_mem = x_orig.a_mem[:x_app.delay]

                    (_, M) = actpkt_stack[x_app.t_origin+1]
                    assert M[x_app.delay - 1].shape[0] == self.env.horizon

                    # Add the shifted actions. Note that we cannot generate
                    # actions beyond the horizon, so in the worst case we are
                    # generating the last action in the packet. That is the
                    # reason of the horizon-1 being the upper limit of the
                    # memorized actions being used to predict the next action.
                    a_shift = M[x_app.delay-1][:min(x_app.delayshift, self.env.horizon - 1)]
                    return np.concatenate((a_mem, a_shift))

                # Collect trajectory from evaluation
                for i, (x, r, x_next, terminated) in enumerate(obs_stack):
                    assert i == x.t
                    if x.delay is None:
                        continue # This is an initial action

                    assert x.t_origin < i
                    origin_x = obs_stack[x.t_origin][0]
                    origin_x_next = obs_stack[x_next.t_origin][0]

                    a_mem = extract_a_mem(origin_x, x)
                    a_mem_next = extract_a_mem(origin_x_next, x_next)

                    self.replay.add(
                        # Regular SAC experience
                        x.s_obs, x.a_mem[0], r, x_next.s_obs, terminated,
                        # The information used to regenerate the action that was applied to the environment
                        #  1. The observed state
                        #  2. The assumed sequence of actions between origin_x.s_obs and x.s_obs
                        #  3. The delay that the action was generated for
                        #  4. How much the actions had shifted in the buffer (i.e. which index to select)
                        origin_x.s_obs,      a_mem,      x.delay,      x.delayshift,      a_mem.shape[0],
                        # Same but now for the next state
                        origin_x_next.s_obs, a_mem_next, x_next.delay, x_next.delayshift, a_mem_next.shape[0],
                    )

            # Update for the same number of steps that we took in the environment
            n_updates = taken_steps * self.p.update_iterations
            obs_dim = self.env.env.observation_space.shape[0]
            losses_q = np.zeros((n_updates,), dtype=np.float32)
            losses_pi = np.zeros((n_updates,), dtype=np.float32)
            losses_alpha = np.zeros((n_updates,), dtype=np.float32)
            losses_model = np.zeros((n_updates,), dtype=np.float32)
            # NOTE(review): model_error / model_std are never written below, so
            # the recorded model/err* and model/std* values are always 0.
            model_error = np.zeros((n_updates, obs_dim), dtype=np.float32)
            model_std = np.zeros((n_updates, obs_dim), dtype=np.float32)
            for step in range(taken_steps):
                # Minibatch network updates
                self.steps_since_update -= 1
                if len(self.replay) > self.p.batch_size and (self.steps_since_update <= 0):
                    self.steps_since_update = self.p.update_delay
                    for update_step in range(step * self.p.update_iterations, (step + 1) * self.p.update_iterations):
                        (
                            # Regular SAC sample
                            bs, ba, br, bsn, bd,
                            # Info for regenerating ba   ([N, S], [N, L*, A], [N], [N], [N])
                            bs_o,  bs_am,  bs_d,  bs_sh,  bs_am_len,
                            # Same but for next state's action
                            bsn_o, bsn_am, bsn_d, bsn_sh, bsn_am_len,
                        ) = self.replay.sample(self.p.batch_size, device=self.device)

                        # Pre-fetch the alpha for later
                        with torch.no_grad():
                            alpha = torch.exp(self.log_alpha)

                        # Compute latent representations for ŝ and ŝ_next predictions
                        with torch.no_grad():
                            N = bs_o.shape[0]

                            # Make sure that they have the same length
                            max_am_len = max(bs_am.shape[1], bsn_am.shape[1])
                            bs_am  = _pad_dim(1, bs_am,  max_am_len)
                            bsn_am = _pad_dim(1, bsn_am, max_am_len)

                            # We concatenate predictions for this state and the
                            # next, to allow for better efficiency. (bc: batch concatenated)
                            # (torch.cat along batch dimension)
                            bc_o =      torch.cat((bs_o,       bsn_o),      0) # [2N, S]
                            bc_am =     torch.cat((bs_am,      bsn_am),     0) # [2N, L, A]
                            bc_am_len = torch.cat((bs_am_len,  bsn_am_len), 0) # [2N]

                            embed = self.model.embed_inputs(
                                bc_o,
                                torch.nan_to_num(bc_am),
                                omit_h0=True,
                            ) # [2N, L, H]

                            # Extract the prediction used for each sample: the
                            # embedding after bc_am_len[n] actions (-1 since we
                            # omit h0). One gather instead of a per-sample
                            # Python loop with a .item() sync per row.
                            rows = torch.arange(2 * N, device=embed.device)
                            embed = embed[rows, bc_am_len.to(embed.device) - 1] # [2N, H]

                            # Split the embeddings. For pi objective, we want
                            # to differentiate on the horizon.
                            bs_emb, bsn_emb = torch.split(embed, N)

                        # Q update
                        self.q_opt.zero_grad()
                        q1val = self.q1(bs, ba)
                        q2val = self.q2(bs, ba)

                        with torch.no_grad():
                            an, _, logp_an, _ = self.pi(bsn_emb)
                            qval_targ = torch.min(self.q1_targ(bsn, an), self.q2_targ(bsn, an))

                            ys = br + self.p.discount * (1 - bd) * (qval_targ - alpha * logp_an)

                        loss_q1 = ((q1val - ys)**2).mean()
                        loss_q2 = ((q2val - ys)**2).mean()
                        loss_q = loss_q1 + loss_q2
                        loss_q.backward()
                        self.q_opt.step()

                        # pi (policy) update
                        self.pi_opt.zero_grad()

                        a, _, logp_a, _ = self.pi(bs_emb)

                        qval = torch.min(self.q1(bs, a), self.q2(bs, a))

                        loss_pi = -(qval - alpha * logp_a).mean()
                        loss_pi.backward()
                        self.pi_opt.step()

                        # alpha (temperature) update
                        self.alpha_opt.zero_grad()

                        # J(a) = E[-a logp(a|s) - a*H]
                        loss_alpha = -1.0 * (
                            torch.exp(self.log_alpha) * (
                                logp_a.mean() + # (Reusing the logp computed from the policy update)
                                self.p.entropy_threshold
                            ).clone().detach().requires_grad_(False)
                        )
                        loss_alpha.backward()

                        self.alpha_opt.step()

                        # model update (if it can be optimized)
                        if self.model.optimizable:
                            lengths, (tbs, tba, tbsn) = self.replay.sample_trajectory(
                                n=self.p.model_batch_size,
                                k=self.p.model_train_window,
                                keys=["state", "action", "next_state"],
                                device=self.device,
                            )

                            self.model_opt.zero_grad()

                            loss_model, _ = self.model.loss(
                                state=tbs[:,0],
                                actions=tba,
                                next_states=tbsn,
                                lengths=lengths,
                                distance=self.obs_metric,
                            )
                            loss_model.backward()
                            self.model_opt.step()
                        else:
                            loss_model = torch.tensor(0.0, requires_grad=False)

                        losses_q[update_step] = loss_q.cpu().detach().numpy()
                        losses_pi[update_step] = loss_pi.cpu().detach().numpy()
                        losses_alpha[update_step] = loss_alpha.cpu().detach().numpy()
                        losses_model[update_step] = loss_model.cpu().detach().numpy()

                        # Polyak (soft target) updates
                        self.updates_since_polyak -= 1
                        if self.updates_since_polyak <= 0:
                            self.updates_since_polyak = self.p.polyak_delay
                            with torch.no_grad():
                                for train, target in [(self.pi, self.pi_targ), (self.q1, self.q1_targ), (self.q2, self.q2_targ)]:
                                    for p_train, p_targ in zip(train.parameters(), target.parameters()):
                                        p_targ.data.mul_(1 - self.p.polyak)
                                        p_targ.data.add_(self.p.polyak * p_train.data)

            self.record_iteration_value("critic/loss", losses_q.mean())
            self.record_iteration_value("actor/loss", losses_pi.mean())
            self.record_iteration_value("temperature/loss", losses_alpha.mean())
            self.record_iteration_value("temperature/value", torch.exp(self.log_alpha).item())
            self.record_iteration_value("model/loss", losses_model.mean())
            for j in range(obs_dim):
                self.record_iteration_value(f"model/err{j}", model_error[:,j].mean())
                self.record_iteration_value(f"model/std{j}", model_std[:,j].mean())

            self.record_episode_total_reward(r_total)

        LOG.debug("train() done")
