# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import io
import sys
import time
from logging import getLogger
from collections import OrderedDict
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_
from .optim import get_optimizer
from .utils import to_cuda, split_data
from collections import defaultdict
import torch.nn.functional as F
import seaborn as sns
import matplotlib.pyplot as plt
import copy
import contextlib
import wandb
from .DataGenKinetics import batch_generate, batch_generate_1, batch_generate_2, batch_generate_3, batch_generate_nv1,batch_generate_nv2,batch_generate_nv3, batch_generate_v3
import pickle

# NVIDIA apex is an optional dependency: record whether it can be imported so
# that the trainer can refuse `nvidia_apex=True` when it is unavailable.
# if torch.cuda.is_available():
has_apex = True
try:
    import apex
except Exception:
    # Best-effort probe: any import-time failure means apex cannot be used.
    has_apex = False

logger = getLogger()


class LoadParameters(object):
    """
    Collect the trainable parameters of a set of modules and reload their
    weights from a checkpoint file.

    `modules` is a mapping from a name to a module (it must expose
    `named_parameters()` and `load_state_dict()`); `params` is the experiment
    configuration (this class reads `dump_path` and `reload_checkpoint`).
    """

    def __init__(self, modules, params):
        self.modules = modules
        self.params = params
        self.set_parameters()

    def set_parameters(self):
        """
        Set parameters.

        Stores, under `self.parameters["model"]`, every parameter of every
        module whose `requires_grad` flag is set, and asserts that at least
        one such parameter exists.
        """
        self.parameters = {}
        named_params = []
        for v in self.modules.values():
            named_params.extend(
                [(k, p) for k, p in v.named_parameters() if p.requires_grad]
            )
        self.parameters["model"] = [p for k, p in named_params]
        for k, v in self.parameters.items():
            logger.info("Found %i parameters in %s." % (len(v), k))
            assert len(v) >= 1

    def reload_checkpoint(self, path=None, root=None, requires_grad=True):
        """
        Reload a checkpoint if we find one.

        Args:
            path: checkpoint file name (defaults to "checkpoint.pth").
            root: directory containing `path` (defaults to `params.dump_path`).
            requires_grad: when False, every parameter of the reloaded modules
                is frozen.

        If `root/path` does not exist, falls back to
        `params.reload_checkpoint/checkpoint.pth`; returns silently when
        `params.reload_checkpoint` is empty.
        """
        # NOTE(review): the default name ends in ".pth" while
        # Trainer.save_checkpoint in this file writes "<name>.pt" -- confirm
        # which extension callers expect.
        if path is None:
            path = "checkpoint.pth"
        if root is None:
            root = self.params.dump_path
        checkpoint_path = os.path.join(root, path)

        if not os.path.isfile(checkpoint_path):
            if self.params.reload_checkpoint == "":
                return
            checkpoint_path = os.path.join(
                self.params.reload_checkpoint, "checkpoint.pth"
            )
            assert os.path.isfile(checkpoint_path)

        logger.warning(f"Reloading checkpoint from {checkpoint_path} ...")
        data = torch.load(checkpoint_path, map_location="cpu")

        # reload model parameters
        for k, v in self.modules.items():
            try:
                weights = data[k]
                v.load_state_dict(weights)
            except RuntimeError:
                # Retry with the first dotted component of every key removed
                # (e.g. a wrapper prefix added when the weights were saved).
                weights = {
                    name.partition(".")[2]: tensor
                    for name, tensor in data[k].items()
                }
                v.load_state_dict(weights)
            # Kept for backward compatibility: some callers may read this
            # attribute, but on its own it does not affect the parameters.
            v.requires_grad = requires_grad
            if not requires_grad:
                # Actually freeze the parameters; assigning the attribute on
                # the module above never reaches them.
                v.requires_grad_(False)


class Trainer(object):
    def __init__(self, modules, env, params, path=None, root=None):
        """
        Initialize trainer.

        In order: collects trainable parameters, optionally wraps every module
        in DistributedDataParallel, builds the optimizer and AMP state, parses
        the stopping criterion and validation metrics, reloads a checkpoint if
        one is found, then sets up data export / reload and the training data
        iterators.

        Args:
            modules: mutable mapping from a module name to a module.
            env: environment object; `TRAINING_TASKS` is read and
                `create_train_iterator` is called on it.
            params: experiment configuration. `export_path_prefix` and
                `env_base_seed` may be written back to it.
            path: checkpoint file name forwarded to `reload_checkpoint`.
            root: checkpoint directory forwarded to `reload_checkpoint`.
        """
        # modules / params
        self.modules = modules
        self.params = params
        self.env = env
        self.logger = logger

        # epoch / iteration size
        self.n_steps_per_epoch = params.n_steps_per_epoch
        self.inner_epoch = self.total_samples = self.n_equations = 0
        self.infos_statistics = defaultdict(list)

        # data iterators
        self.iterators = {}

        # set parameters (must happen before the modules are wrapped below)
        self.set_parameters()

        # float16 / distributed (no AMP)
        assert params.amp >= 1 or not params.fp16
        assert params.amp >= 0 or params.accumulate_gradients == 1
        # assert not params.multi_gpu or params.amp == -1 or params.nvidia_apex
        assert not params.nvidia_apex or has_apex
        if params.multi_gpu:  # and params.amp == -1:
            logger.info("Using nn.parallel.DistributedDataParallel ...")
            for k in self.modules.keys():
                self.modules[k] = nn.parallel.DistributedDataParallel(
                    self.modules[k],
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    broadcast_buffers=True,
                )

        # set optimizer
        self.set_optimizer()

        # float16 / distributed (AMP)
        self.scaler = None
        if params.amp >= 0:
            self.init_amp()

        # stopping criterion used for early stopping, given as "<metric>,<patience>";
        # a leading "_" on the metric means "lower is better"
        if params.stopping_criterion != "":
            split = params.stopping_criterion.split(",")
            assert len(split) == 2 and split[1].isdigit()
            self.decrease_counts_max = int(split[1])
            self.decrease_counts = 0
            if split[0][0] == "_":
                self.stopping_criterion = (split[0][1:], False)
            else:
                self.stopping_criterion = (split[0], True)
            self.best_stopping_criterion = -1e12 if self.stopping_criterion[1] else 1e12
        else:
            self.stopping_criterion = None
            self.best_stopping_criterion = None

        # validation metrics: (name, bigger_is_better); only metrics whose
        # name contains "r2" are maximized
        self.metrics = [
            (m, "r2" in m) for m in params.validation_metrics.split(",") if m != ""
        ]
        # np.inf rather than np.infty: the latter alias was removed in NumPy 2.0
        self.best_metrics = {
            metric: (-np.inf if biggest else np.inf)
            for (metric, biggest) in self.metrics
        }

        # training statistics
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.stats = OrderedDict(
            [("processed_e", 0)]
            + [("processed_w", 0)]
            + sum(
                [[(x, []), (f"{x}-AVG-STOP-PROBS", [])] for x in env.TRAINING_TASKS], []
            )
        )
        self.last_time = time.time()

        # reload potential checkpoints
        self.reload_checkpoint(path=path, root=root)

        # file handler to export data
        if params.export_data:
            assert params.reload_data == ""
            params.export_path_prefix = os.path.join(params.dump_path, "data.prefix")
            self.file_handler_prefix = io.open(
                params.export_path_prefix, mode="a", encoding="utf-8"
            )
            logger.info(
                f"Data will be stored in prefix in: {params.export_path_prefix} ..."
            )

        # reload exported data
        if params.reload_data != "":
            logger.info(f"Reloading data from {params.reload_data}")
            assert params.num_workers in [0, 1]
            assert params.export_data is False
            s = [x.split(",") for x in params.reload_data.split(";") if len(x) > 0]
            assert len(s) >= 1
            train_path = os.path.join(params.reload_data, "data.prefix")
            test_path = os.path.join(params.reload_data, "data.prefix.test")
            # count the lines of test_path; an empty or missing file counts as
            # 0 (the previous enumerate loop left its index unbound on an
            # empty file)
            if os.path.isfile(test_path):
                with open(test_path) as f:
                    n_eqs = sum(1 for _ in f)
            else:
                n_eqs = 0
            if not os.path.isfile(test_path) or n_eqs < params.eval_size:
                split_data(train_path, params.eval_size)
            self.data_path = {"functions": (train_path, test_path)}

            logger.info(self.data_path)

            for task in self.env.TRAINING_TASKS:
                assert (task in self.data_path) == (task in params.tasks)
        else:
            self.data_path = None

        # create data loaders
        if not params.eval_only:
            if params.env_base_seed < 0:
                params.env_base_seed = np.random.randint(1_000_000_000)
            self.dataloader = {
                task: iter(self.env.create_train_iterator(task, self.data_path, params))
                for task in params.tasks
            }

    def set_new_train_iterator_params(self, args=None):
        """
        Replace the training iterators with new ones built with `args`.

        Args:
            args: extra arguments forwarded as the fourth positional argument
                of `env.create_train_iterator`; defaults to an empty dict.
        """
        # Build a fresh dict per call instead of sharing one mutable default.
        if args is None:
            args = {}
        params = self.params
        if params.env_base_seed < 0:
            params.env_base_seed = np.random.randint(1_000_000_000)
        self.dataloader = {
            task: iter(
                self.env.create_train_iterator(task, self.data_path, params, args)
            )
            for task in params.tasks
        }
        logger.info(
            "Succesfully replaced training iterator with following args:{}".format(args)
        )
        return

    def set_parameters(self):
        """
        Gather every trainable parameter of the modules.

        The parameters (those with `requires_grad` set) are stored as a flat
        list under `self.parameters["model"]`; at least one must exist.
        """
        self.parameters = {}
        trainable = [
            param
            for module in self.modules.values()
            for _name, param in module.named_parameters()
            if param.requires_grad
        ]
        self.parameters["model"] = trainable
        for group_name, group in self.parameters.items():
            logger.info("Found %i parameters in %s." % (len(group), group_name))
            assert len(group) >= 1

    def set_optimizer(self):
        """
        Build the optimizer over the "model" parameter group, using the
        learning rate and optimizer specification found in the parameters.
        """
        cfg = self.params
        model_parameters = self.parameters["model"]
        self.optimizer = get_optimizer(model_parameters, cfg.lr, cfg.optimizer)
        logger.info("Optimizer: %s" % type(self.optimizer))

    def init_amp(self):
        """
        Initialize AMP optimizer.
        """
        params = self.params
        assert (
            params.amp == 0
            and params.fp16 is False
            or params.amp in [1, 2, 3]
            and params.fp16 is True
        )
        mod_names = sorted(self.modules.keys())
        if params.nvidia_apex is True:
            modules, optimizer = apex.amp.initialize(
                [self.modules[k] for k in mod_names],
                self.optimizer,
                opt_level=("O%i" % params.amp),
            )
            self.modules = {k: module for k, module in zip(mod_names, modules)}
            self.optimizer = optimizer
        else:
            self.scaler = torch.cuda.amp.GradScaler()

    def optimize(self, loss):
        """
        Optimize.
        """
        # check NaN
        if (loss != loss).data.any():
            logger.warning("NaN detected")
            # exit()

        params = self.params

        # optimizer
        optimizer = self.optimizer

        # regular optimization
        if params.amp == -1:
            optimizer.zero_grad()
            loss.backward()
            if params.clip_grad_norm > 0:
                clip_grad_norm_(self.parameters["model"], params.clip_grad_norm)
            optimizer.step()

        # AMP optimization
        elif params.nvidia_apex is True:
            if (self.n_iter + 1) % params.accumulate_gradients == 0:
                with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if params.clip_grad_norm > 0:
                    clip_grad_norm_(
                        apex.amp.master_params(self.optimizer), params.clip_grad_norm
                    )
                optimizer.step()
                optimizer.zero_grad()
            else:
                with apex.amp.scale_loss(
                    loss, optimizer, delay_unscale=True
                ) as scaled_loss:
                    scaled_loss.backward()

        else:
            if params.accumulate_gradients > 1:
                loss = loss / params.accumulate_gradients
            self.scaler.scale(loss).backward()

            if (self.n_iter + 1) % params.accumulate_gradients == 0:
                if params.clip_grad_norm > 0:
                    self.scaler.unscale_(optimizer)
                    clip_grad_norm_(self.parameters["model"], params.clip_grad_norm)
                self.scaler.step(optimizer)
                self.scaler.update()
                optimizer.zero_grad()

    def iter(self):
        """
        End of iteration.
        """
        self.n_iter += 1
        self.n_total_iter += 1
        self.print_stats()

    def print_stats(self):
        """
        Print statistics about the training.
        """
        if self.params.use_wandb and self.params.is_master:
            if self.stats["functions"]:
                wandb.log({'loss':self.stats["functions"][-1],
                   'lr': self.optimizer.param_groups[0]["lr"],
                   })
        if self.n_total_iter % self.params.print_freq != 0:
            return

        s_total_eq = "- Total Eq: " + "{:.2e}".format(self.n_equations)
        s_iter = "%7i - " % self.n_total_iter
        s_stat = " || ".join(
            [
                "{}: {:7.4f}".format(k.upper().replace("_", "-"), np.mean(v))
                for k, v in self.stats.items()
                if type(v) is list and len(v) > 0
            ]
        )
        for k in self.stats.keys():
            if type(self.stats[k]) is list:
                del self.stats[k][:]

        # learning rates
        s_lr = (" - LR: ") + " / ".join(
            "{:.4e}".format(group["lr"]) for group in self.optimizer.param_groups
        )

        # processing speed
        new_time = time.time()
        diff = new_time - self.last_time
        s_speed = "{:7.2f} equations/s - {:8.2f} words/s - ".format(
            self.stats["processed_e"] * 1.0 / diff,
            self.stats["processed_w"] * 1.0 / diff,
        )
        max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2
        s_mem = " MEM: {:.2f} MB - ".format(max_mem)
        self.stats["processed_e"] = 0
        self.stats["processed_w"] = 0
        self.last_time = new_time
        # log speed + stats + learning rate
        logger.info(s_iter + s_speed + s_mem + s_stat + s_lr + s_total_eq)

    def get_generation_statistics(self, task):
        """
        Log and plot the statistics accumulated in `self.infos_statistics`.

        For each info type, logs the normalized frequency of every value that
        occurred, then saves a seaborn pair plot of all info types to
        `statistics_<epoch>.png` in the dump path and resets the accumulator.

        Args:
            task: unused by this method.
        """
        if not self.infos_statistics:
            # nothing was collected (previously an IndexError)
            logger.warning("No generation statistics collected, nothing to report.")
            return

        first_infos = next(iter(self.infos_statistics.values()))
        total_eqs = sum(x.shape[0] for x in first_infos)
        logger.info("Generation statistics (to generate {} eqs):".format(total_eqs))

        all_infos = defaultdict(list)
        for info_type, infos in self.infos_statistics.items():
            all_infos[info_type] = torch.cat(infos).tolist()
            # per-batch histograms, padded to a common length and summed
            infos = [torch.bincount(info) for info in infos]
            max_val = max(info.shape[0] for info in infos)
            aggregated_infos = torch.cat(
                [
                    F.pad(info, (0, max_val - info.shape[0])).unsqueeze(-1)
                    for info in infos
                ],
                -1,
            ).sum(-1)
            total = aggregated_infos.sum()
            non_zeros = aggregated_infos.nonzero(as_tuple=True)[0]
            vals = [
                (
                    non_zero.item(),
                    "{:.2e}".format((aggregated_infos[non_zero] / total).item()),
                )
                for non_zero in non_zeros
            ]
            logger.info("{}: {}".format(info_type, vals))
        all_infos = pd.DataFrame(all_infos)
        g = sns.PairGrid(all_infos)
        g.map_upper(sns.scatterplot)
        g.map_lower(sns.kdeplot, fill=True)
        g.map_diag(sns.histplot, kde=True)
        plt.savefig(
            os.path.join(self.params.dump_path, "statistics_{}.png".format(self.epoch))
        )
        # close the current figure so that repeated calls do not accumulate
        # open figures
        plt.close()

        self.infos_statistics = defaultdict(list)

    def save_checkpoint(self, name, include_optimizer=True, include_stats=False):
        """
        Save the model / checkpoints.

        Only the master process saves. The file is written to
        `<dump_path>/<name>.pt` and contains the epoch, iteration count, best
        metrics, a copy of the parameters, and the state dict of every module.

        Args:
            name: checkpoint name, without extension.
            include_optimizer: also store the optimizer state (and the
                gradient scaler state when a scaler exists).
            include_stats: also store `self.stats`.
        """
        if not self.params.is_master:
            return

        path = os.path.join(self.params.dump_path, "%s.pt" % name)
        logger.info("Saving %s to %s ..." % (name, path))

        data = {
            "epoch": self.epoch,
            "n_total_iter": self.n_total_iter,
            "best_metrics": self.best_metrics,
            "best_stopping_criterion": self.best_stopping_criterion,
            "params": dict(self.params.__dict__),
        }

        for k, v in self.modules.items():
            logger.warning(f"Saving {k} parameters ...")
            data[k] = v.state_dict()

        if include_optimizer:
            logger.warning("Saving optimizer ...")
            data["optimizer"] = self.optimizer.state_dict()
            if self.scaler is not None:
                data["scaler"] = self.scaler.state_dict()

        if include_stats:
            data["stats"] = self.stats

        # exist_ok avoids the check-then-create race of exists() + makedirs()
        os.makedirs(self.params.dump_path, exist_ok=True)

        torch.save(data, path)

    def reload_checkpoint(self, path=None, root=None, requires_grad=True, checkpoint_path=None):
        """
        Reload a checkpoint if we find one.

        When `checkpoint_path` is not given, the file `path` (default
        "checkpoint.pt") is looked up in `params.reload_checkpoint` if that is
        non-empty (the file must then exist), otherwise in `root`, otherwise in
        `params.dump_path`. A missing file in the last two cases, or a missing
        explicit `checkpoint_path`, only logs a warning.

        Args:
            path: checkpoint file name.
            root: directory to search when `params.reload_checkpoint` is empty.
            requires_grad: when False, every parameter of the reloaded modules
                is frozen.
            checkpoint_path: full path to a checkpoint; overrides the lookup.

        Returns:
            The "stats" entry of the checkpoint if it has one, else None
            (also None when no checkpoint was found).
        """
        if checkpoint_path is None:
            if path is None:
                path = "checkpoint.pt"

            if self.params.reload_checkpoint != "":
                checkpoint_path = os.path.join(self.params.reload_checkpoint, path)
                assert os.path.isfile(checkpoint_path)
            else:
                base_dir = root if root is not None else self.params.dump_path
                checkpoint_path = os.path.join(base_dir, path)
                if not os.path.isfile(checkpoint_path):
                    logger.warning(
                        "Checkpoint path does not exist, {}".format(checkpoint_path)
                    )
                    return
        elif not os.path.isfile(checkpoint_path):
            logger.warning(
                "Checkpoint path does not exist, {}".format(checkpoint_path)
            )
            return

        logger.warning(f"Reloading checkpoint from {checkpoint_path} ...")
        data = torch.load(checkpoint_path, map_location="cpu")

        # reload model parameters
        for k, v in self.modules.items():
            weights = data[k]
            try:
                v.load_state_dict(weights)
            except RuntimeError:  # remove the 'module.'
                weights = {
                    name.partition(".")[2]: tensor for name, tensor in data[k].items()
                }
                v.load_state_dict(weights)
            # Kept for backward compatibility; on its own this attribute does
            # not affect the parameters.
            v.requires_grad = requires_grad
            if not requires_grad:
                # Actually freeze the parameters of the reloaded module.
                v.requires_grad_(False)

        if self.params.amp == -1 or not self.params.nvidia_apex:
            logger.warning("Reloading checkpoint optimizer ...")
            self.optimizer.load_state_dict(data["optimizer"])
        else:
            logger.warning("Not reloading checkpoint optimizer.")
            for group_id, param_group in enumerate(self.optimizer.param_groups):
                if "num_updates" not in param_group:
                    logger.warning("No 'num_updates' for optimizer.")
                    continue
                logger.warning("Reloading 'num_updates' and 'lr' for optimizer.")
                param_group["num_updates"] = data["optimizer"]["param_groups"][
                    group_id
                ]["num_updates"]
                param_group["lr"] = self.optimizer.get_lr_for_step(
                    param_group["num_updates"]
                )

        # A scaler exists whenever native AMP is initialised, including with
        # amp == 0 / fp16 off, so key the reload on the scaler itself rather
        # than on fp16 (the old check asserted `scaler is None` in that case
        # and failed on resume).
        if self.scaler is not None:
            if "scaler" in data:
                logger.warning("Reloading gradient scaler ...")
                self.scaler.load_state_dict(data["scaler"])
            else:
                logger.warning("No gradient scaler state found in checkpoint.")
        else:
            assert "scaler" not in data

        # reload main metrics
        self.epoch = data["epoch"] + 1
        self.n_total_iter = data["n_total_iter"]
        self.best_metrics = data["best_metrics"]
        self.best_stopping_criterion = data["best_stopping_criterion"]
        logger.warning(
            f"Checkpoint reloaded. Resuming at epoch {self.epoch} / iteration {self.n_total_iter} ..."
        )
        if "stats" in data.keys():
            self.stats = data["stats"]
            return data["stats"]
        else:
            return None

    def save_periodic(self):
        """
        Save the models periodically.
        """
        if not self.params.is_master:
            return
        if (
            self.params.save_periodic > 0
            and self.epoch % self.params.save_periodic == 0
        ):
            self.save_checkpoint("periodic-%i" % self.epoch)

    def save_best_model(self, scores, metric="r2_zero", prefix=None, suffix=None):
        """
        Save best models according to given validation metrics.
        """
        if not self.params.is_master:
            return
        
        if "r2" in metric: 
            factor = 1
        else:
            factor = -1

        best_so_far = factor * self.best_metrics[metric]

        if factor * scores[metric] > best_so_far:
            self.best_metrics[metric] = scores[metric]
            logger.info("New best score for %s: %.6f" % (metric, scores[metric]))
            self.save_checkpoint("best" % metric)

    def end_epoch(self, scores):
        """
        End the epoch.
        """
        # stop if the stopping criterion has not improved after a certain number of epochs
        if self.stopping_criterion is not None and (
            self.params.is_master or not self.stopping_criterion[0].endswith("_mt_bleu")
        ):
            metric, biggest = self.stopping_criterion
            assert metric in scores, metric
            factor = 1 if biggest else -1
            if factor * scores[metric] > factor * self.best_stopping_criterion:
                self.best_stopping_criterion = scores[metric]
                logger.info(
                    "New best validation score: %f" % self.best_stopping_criterion
                )
                self.decrease_counts = 0
            else:
                logger.info(
                    "Not a better validation score (%i / %i)."
                    % (self.decrease_counts, self.decrease_counts_max)
                )
                self.decrease_counts += 1
            if self.decrease_counts > self.decrease_counts_max:
                logger.info(
                    "Stopping criterion has been below its best value for more "
                    "than %i epochs. Ending the experiment..."
                    % self.decrease_counts_max
                )
                if self.params.multi_gpu and "SLURM_JOB_ID" in os.environ:
                    os.system("scancel " + os.environ["SLURM_JOB_ID"])
                exit()
        self.save_checkpoint("checkpoint")
        self.epoch += 1

    def get_batch(self, task):
        """
        Return a training batch for a specific task.
        """
        try:
            batch = next(self.dataloader[task])
        except Exception as e:
            print(e)
            logger.error(
                "An unknown exception of type {0} occurred in line {1} when fetching batch. "
                "Arguments:{2!r}. Restarting ...".format(
                    type(e).__name__, sys.exc_info()[-1].tb_lineno, e.args
                )
            )
            if self.params.is_slurm_job:
                if int(os.environ["SLURM_PROCID"]) == 0:
                    logger.warning("Requeuing job " + os.environ["SLURM_JOB_ID"])
                    os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
                else:
                    logger.warning("Not the master process, no need to requeue.")
            raise
        return batch

    def export_data(self, task):
        """
        Export data to the disk.

        Generates one batch with `batch_generate` and appends one JSON line
        per sample to `self.file_handler_prefix`. Each line holds the sample's
        infos (as strings), its times and trajectory (floats formatted in
        scientific notation with `params.float_precision` digits) and the
        prefix form of its tree.

        Args:
            task: unused by this method.
        """
        samples = batch_generate(self.params.batch_size)

        for info in samples["infos"]:
            samples["infos"][info] = list(map(str, samples["infos"][info].tolist()))

        def get_dictionary_slice(idx, dico):
            # pick entry `idx` of every column of a dict of sequences
            return {key: dico[key][idx] for key in dico}

        def float_list_to_str_lst(lst, float_precision):
            # recursively format a (nested) list of floats, in place
            for i, item in enumerate(lst):
                if isinstance(item, list):
                    lst[i] = float_list_to_str_lst(item, float_precision)
                else:
                    lst[i] = f"%.{float_precision}e" % item
            return lst

        # Number of samples in the batch. `len(samples)` would count the keys
        # of the samples mapping instead.
        # NOTE(review): assumes samples["times"] holds one entry per sample --
        # confirm against batch_generate.
        processed_e = len(samples["times"])
        for i in range(processed_e):
            # prefix
            outputs = {**get_dictionary_slice(i, samples["infos"])}
            times = samples["times"][i].tolist()
            trajectory = samples["trajectory"][i].tolist()
            outputs["times"] = float_list_to_str_lst(
                times, self.params.float_precision
            )
            outputs["trajectory"] = float_list_to_str_lst(
                trajectory, self.params.float_precision
            )
            outputs["tree"] = samples["tree"][i].prefix()

            self.file_handler_prefix.write(json.dumps(outputs) + "\n")
        # one flush per batch is enough to make the rows visible on disk
        self.file_handler_prefix.flush()

        # number of processed sequences / words
        self.n_equations += self.params.batch_size
        self.total_samples += self.params.batch_size
        self.stats["processed_e"] += processed_e

    def enc_dec_step(self, task):
        """
        Encoding / decoding step.

        Runs one training step: generates a batch with `batch_generate_v3`,
        embeds and encodes the (time, trajectory) points, decodes the
        `full_equation` token sequence, computes the prediction loss
        (optionally adding masked-output and masked-input losses), calls
        `self.optimize` and updates the counters.

        Args:
            task: key of `self.stats` under which the loss value is recorded.

        Returns:
            The loss of this step as a Python number (`loss.item()`).
        """
        params = self.params
        embedder, encoder, decoder = (
            self.modules["embedder"],
            self.modules["encoder"],
            self.modules["decoder"],
        )
        # put all three modules in training mode
        embedder.train()
        encoder.train()
        decoder.train()
        env = self.env

        #params.batch_size=4

        samples = batch_generate_v3(params.batch_size)

        # keep the per-batch infos so that get_generation_statistics can
        # aggregate them later
        if self.params.debug_train_statistics:
            for info_type, info in samples["infos"].items():
                self.infos_statistics[info_type].append(info)

        times = samples["times"]
        trajectory = samples["trajectory"]
        infos = samples["infos"]
        condition = samples['condition']
        full_equtaion = samples['full_equation']


        # NOTE(review): the draw below can be 0, and `-0:` is the same slice as
        # `0:`, so a draw of 0 would set every column to NaN rather than none
        # -- confirm whether that is intended.
        if params.max_masked_variables:  # randomly mask some variables
            for seq_id in range(len(times)):
                n_masked_variables = min(np.random.randint(0, self.params.max_masked_variables + 1), infos["dimension"][seq_id]-1)
                trajectory[seq_id][:, -n_masked_variables:] = np.nan

        # x1[seq_id] is a list of [time, trajectory point] pairs
        x1 = []
        for seq_id in range(len(times)):
            x1.append([])
            for seq_l in range(len(times[seq_id])):
                x1[seq_id].append([times[seq_id][seq_l], trajectory[seq_id][seq_l]])
                
        if self.params.masked_input: #randomly mask some inputs
            input_tokens, len1 = embedder.forward(x1, return_before_embed=True)
            # randomly mask a fraction x of the input tokens along seq dimension
            mask = np.random.rand(*input_tokens.shape[:2]) < self.params.masked_input
            # NOTE(review): indexing with a boolean mask presumably yields a
            # copy, in which case this assignment leaves input_tokens
            # unchanged -- confirm.
            input_tokens[mask][self.params.float_descriptor_length:] = encoder.word2id['<MASK>']
            x1 = embedder.compress(embedder.embed(input_tokens))
            predict_input_tokens = input_tokens[:,:,self.params.float_descriptor_length:]
        else:
            x1, len1 = embedder(x1)

        # NOTE(review): full_equtaion2, full_equtaion_len and condition_len are
        # only assigned in the `else` branch but are used unconditionally
        # below, so the use_skeleton path looks like it would raise NameError
        # -- confirm.
        if self.params.use_skeleton:
            x2, len2 = self.env.batch_equations(
                self.env.word_to_idx(
                    samples["skeleton_tree_encoded"], float_input=False
                )
            )
        else:
            full_equtaion2, full_equtaion_len= self.env.batch_equations(
                self.env.word_to_idx(full_equtaion, float_input=False)
            )
            condition2, condition_len = self.env.batch_equations(
                self.env.word_to_idx(condition, float_input=False)
            )
            x2, len2 = self.env.batch_equations(
                self.env.word_to_idx(samples["tree_encoded"], float_input=False)
            )

        # prediction mask over the full-equation sequence: position t of
        # sequence idx is selected when t < full_equation_len[idx] - 1, except
        # for the first condition_len[idx] - 1 positions, which are cleared
        pred_alen = torch.arange(full_equtaion_len.max(), dtype=torch.long, device=full_equtaion_len.device)
        pred_mask_v1 = (
            pred_alen[:, None] < full_equtaion_len[None] - 1
        ) 
        for idx, every_condition_len in enumerate(condition_len):
            pred_mask_v1[:every_condition_len-1,idx] = False

        # target words to predict
        alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
        pred_mask = (
            alen[:, None] < len2[None] - 1
        )  # do not predict anything given the last target word

        # targets are the x2 tokens shifted by one position; their count must
        # match both masks
        y = x2[1:].T.masked_select(pred_mask[:-1].T)
        assert len(y) == (len2 - 1).sum().item()
        assert len(y) == pred_mask_v1.sum().item()
        
        if params.use_two_hot:
            assert self.env.equation_encoder.constant_encoder is not None
            y = self.env.ids_to_two_hot(
                ids=y.reshape(-1, 1), 
                support_size=len(self.env.equation_words) + len(self.env.constant_words)
            )
        
        # cuda
        ## Revised
        if not self.params.cpu:
            x2, len2, y ,full_equtaion2, full_equtaion_len = to_cuda(x2, len2, y,full_equtaion2, full_equtaion_len)
            #full_equtaion2 = to_cuda(full_equtaion2)
        # forward / loss

        with autocast_wrapper(params):
            encoded = encoder("fwd", x=x1, lengths=len1, causal=False)
            # the decoder is fed the full-equation tokens (the x2-based call
            # is kept below as a disabled string literal)
            decoded = decoder(
                "fwd",
                x=full_equtaion2,
                lengths=full_equtaion_len,
                causal=True,
                src_enc=encoded.transpose(0, 1),
                src_len=len1,
            )
            '''decoded = decoder(
                "fwd",
                x=x2,
                lengths=len2,
                causal=True,
                src_enc=encoded.transpose(0, 1),
                src_len=len1,
            )'''

            _scores, loss = decoder(
                "predict", tensor=decoded, pred_mask=pred_mask_v1, y=y, get_scores=False
            )
            # print(decoded.size(), pred_mask.size(), y.size())
            '''_scores, loss = decoder(
                "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=False
            )'''
            # print(loss)
            

            if self.params.masked_output:
                # randomly mask a fraction x of the input tokens along seq dimension
                # NOTE(review): output_mask is created without a device
                # argument while x2 may have been moved by to_cuda above --
                # confirm both end up on the same device.
                output_mask = torch.rand(x2.shape[:2]) < self.params.masked_output
                output_mask *= pred_mask
                predict_output_tokens = x2[output_mask]
                x2_masked = x2.clone()
                x2_masked[output_mask] = encoder.word2id['<MASK>']

                # second, non-causal decoder pass on the masked targets
                decoded = decoder(
                    "fwd",
                    x=x2_masked,
                    lengths=len2,
                    causal=False,
                    src_enc=encoded.transpose(0, 1),
                    src_len=len1,
                )
                _scores, loss_mlm = decoder(
                    "predict", tensor=decoded, pred_mask=output_mask, y=predict_output_tokens, get_scores=False
                )
                loss = loss + loss_mlm

            if self.params.masked_input:
                targets = predict_input_tokens # len, bs, float_descriptor_length * max_dim
                targets = targets[mask]
                encoded = encoded[mask]
                targets = targets.flatten()
                scores = encoder.proj(encoded) # len, bs, float_descriptor_length * max_dim * n_words
                scores = scores.view(-1, len(encoder.word2id))
                loss_numeric = F.cross_entropy(scores, targets)
                loss = loss + loss_numeric

        self.stats[task].append(loss.item())

        # optimize
        self.optimize(loss)

        # number of processed sequences / words
        self.inner_epoch += 1
        self.n_equations += len1.size(0)
        self.stats["processed_e"] += len1.size(0)
        self.stats["processed_w"] += (len1 + len2 - 2).sum().item()

        return loss.item()

    def enc_dec_step_2(self, task):
        """
        Encoding / decoding training step on a batch from `batch_generate_2`.

        Generates `params.batch_size` samples, optionally replaces trailing
        trajectory variables with NaN, embeds the (time, state) points, runs
        the encoder and the causal decoder, and optimizes the loss returned
        by the decoder's "predict" call. When `params.masked_output` and / or
        `params.masked_input` are set, the corresponding auxiliary losses are
        added to the total before optimizing.

        `task` is the key under which the loss value is appended in
        `self.stats`. Returns the total loss as a Python float.
        """
        params = self.params
        embedder = self.modules["embedder"]
        encoder = self.modules["encoder"]
        decoder = self.modules["decoder"]
        embedder.train()
        encoder.train()
        decoder.train()

        samples = batch_generate_2(params.batch_size)

        if params.debug_train_statistics:
            for info_type, info in samples["infos"].items():
                self.infos_statistics[info_type].append(info)

        times = samples["times"]
        trajectory = samples["trajectory"]
        infos = samples["infos"]

        if params.max_masked_variables:  # randomly mask some variables
            for seq_id in range(len(times)):
                # capped at dimension - 1 so that at least one variable stays observed
                n_masked_variables = min(
                    np.random.randint(0, params.max_masked_variables + 1),
                    infos["dimension"][seq_id] - 1,
                )
                # `-0:` is the same slice as `0:` and would blank every
                # variable, so only mask when the drawn count is positive.
                if n_masked_variables > 0:
                    trajectory[seq_id][:, -n_masked_variables:] = np.nan

        # one [time, state] pair per observed point of each sequence
        x1 = [
            [[t, trajectory[seq_id][seq_l]] for seq_l, t in enumerate(seq_times)]
            for seq_id, seq_times in enumerate(times)
        ]

        if params.masked_input:  # randomly mask some inputs
            input_tokens, len1 = embedder.forward(x1, return_before_embed=True)
            # draw a mask over the first two axes of the token array
            mask = np.random.rand(*input_tokens.shape[:2]) < params.masked_input
            # NOTE(review): `input_tokens[mask]` is boolean (advanced) indexing,
            # which yields a copy, so this assignment does not appear to modify
            # `input_tokens` itself -- confirm whether the masking is meant to
            # take effect (targets would then need to be taken before masking).
            input_tokens[mask][params.float_descriptor_length:] = encoder.word2id['<MASK>']
            x1 = embedder.compress(embedder.embed(input_tokens))
            predict_input_tokens = input_tokens[:, :, params.float_descriptor_length:]
        else:
            x1, len1 = embedder(x1)

        # target equations: skeleton or full tree encoding
        target_key = "skeleton_tree_encoded" if params.use_skeleton else "tree_encoded"
        x2, len2 = self.env.batch_equations(
            self.env.word_to_idx(samples[target_key], float_input=False)
        )

        # target words to predict
        alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
        pred_mask = (
            alen[:, None] < len2[None] - 1
        )  # do not predict anything given the last target word

        y = x2[1:].masked_select(pred_mask[:-1])
        assert len(y) == (len2 - 1).sum().item()

        if params.use_two_hot:
            assert self.env.equation_encoder.constant_encoder is not None
            y = self.env.ids_to_two_hot(
                ids=y.reshape(-1, 1),
                support_size=len(self.env.equation_words) + len(self.env.constant_words),
            )

        # cuda
        if not params.cpu:
            x2, len2, y = to_cuda(x2, len2, y)

        # forward / loss
        with autocast_wrapper(params):
            encoded = encoder("fwd", x=x1, lengths=len1, causal=False)
            decoded = decoder(
                "fwd",
                x=x2,
                lengths=len2,
                causal=True,
                src_enc=encoded.transpose(0, 1),
                src_len=len1,
            )
            _scores, loss = decoder(
                "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=False
            )

            if params.masked_output:
                # randomly pick a fraction of the target positions (restricted
                # to pred_mask) and replace them with the <MASK> token
                output_mask = torch.rand(x2.shape[:2]) < params.masked_output
                output_mask *= pred_mask
                predict_output_tokens = x2[output_mask]
                x2_masked = x2.clone()
                x2_masked[output_mask] = encoder.word2id['<MASK>']

                # non-causal pass over the masked targets; the original tokens
                # at the masked positions are the prediction targets
                decoded = decoder(
                    "fwd",
                    x=x2_masked,
                    lengths=len2,
                    causal=False,
                    src_enc=encoded.transpose(0, 1),
                    src_len=len1,
                )
                _scores, loss_mlm = decoder(
                    "predict",
                    tensor=decoded,
                    pred_mask=output_mask,
                    y=predict_output_tokens,
                    get_scores=False,
                )
                loss = loss + loss_mlm

            if params.masked_input:
                # presumably (len, bs, float_descriptor_length * max_dim) -- TODO confirm
                targets = predict_input_tokens
                targets = targets[mask]
                encoded = encoded[mask]
                targets = targets.flatten()
                # one row of scores per target token, over the encoder vocabulary
                scores = encoder.proj(encoded)
                scores = scores.view(-1, len(encoder.word2id))
                loss_numeric = F.cross_entropy(scores, targets)
                loss = loss + loss_numeric

        self.stats[task].append(loss.item())

        # optimize
        self.optimize(loss)

        # number of processed sequences / words
        self.inner_epoch += 1
        self.n_equations += len1.size(0)
        self.stats["processed_e"] += len1.size(0)
        self.stats["processed_w"] += (len1 + len2 - 2).sum().item()

        return loss.item()

    def enc_dec_step_3(self, task):
        """
        Encoding / decoding training step on a batch from `batch_generate_3`.

        Generates `params.batch_size` samples, optionally replaces trailing
        trajectory variables with NaN, embeds the (time, state) points, runs
        the encoder and the causal decoder, and optimizes the loss returned
        by the decoder's "predict" call. When `params.masked_output` and / or
        `params.masked_input` are set, the corresponding auxiliary losses are
        added to the total before optimizing.

        `task` is the key under which the loss value is appended in
        `self.stats`. Returns the total loss as a Python float.
        """
        params = self.params
        embedder = self.modules["embedder"]
        encoder = self.modules["encoder"]
        decoder = self.modules["decoder"]
        embedder.train()
        encoder.train()
        decoder.train()

        samples = batch_generate_3(params.batch_size)

        if params.debug_train_statistics:
            for info_type, info in samples["infos"].items():
                self.infos_statistics[info_type].append(info)

        times = samples["times"]
        trajectory = samples["trajectory"]
        infos = samples["infos"]

        if params.max_masked_variables:  # randomly mask some variables
            for seq_id in range(len(times)):
                # capped at dimension - 1 so that at least one variable stays observed
                n_masked_variables = min(
                    np.random.randint(0, params.max_masked_variables + 1),
                    infos["dimension"][seq_id] - 1,
                )
                # `-0:` is the same slice as `0:` and would blank every
                # variable, so only mask when the drawn count is positive.
                if n_masked_variables > 0:
                    trajectory[seq_id][:, -n_masked_variables:] = np.nan

        # one [time, state] pair per observed point of each sequence
        x1 = [
            [[t, trajectory[seq_id][seq_l]] for seq_l, t in enumerate(seq_times)]
            for seq_id, seq_times in enumerate(times)
        ]

        if params.masked_input:  # randomly mask some inputs
            input_tokens, len1 = embedder.forward(x1, return_before_embed=True)
            # draw a mask over the first two axes of the token array
            mask = np.random.rand(*input_tokens.shape[:2]) < params.masked_input
            # NOTE(review): `input_tokens[mask]` is boolean (advanced) indexing,
            # which yields a copy, so this assignment does not appear to modify
            # `input_tokens` itself -- confirm whether the masking is meant to
            # take effect (targets would then need to be taken before masking).
            input_tokens[mask][params.float_descriptor_length:] = encoder.word2id['<MASK>']
            x1 = embedder.compress(embedder.embed(input_tokens))
            predict_input_tokens = input_tokens[:, :, params.float_descriptor_length:]
        else:
            x1, len1 = embedder(x1)

        # target equations: skeleton or full tree encoding
        target_key = "skeleton_tree_encoded" if params.use_skeleton else "tree_encoded"
        x2, len2 = self.env.batch_equations(
            self.env.word_to_idx(samples[target_key], float_input=False)
        )

        # target words to predict
        alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
        pred_mask = (
            alen[:, None] < len2[None] - 1
        )  # do not predict anything given the last target word

        y = x2[1:].masked_select(pred_mask[:-1])
        assert len(y) == (len2 - 1).sum().item()

        if params.use_two_hot:
            assert self.env.equation_encoder.constant_encoder is not None
            y = self.env.ids_to_two_hot(
                ids=y.reshape(-1, 1),
                support_size=len(self.env.equation_words) + len(self.env.constant_words),
            )

        # cuda
        if not params.cpu:
            x2, len2, y = to_cuda(x2, len2, y)

        # forward / loss
        with autocast_wrapper(params):
            encoded = encoder("fwd", x=x1, lengths=len1, causal=False)
            decoded = decoder(
                "fwd",
                x=x2,
                lengths=len2,
                causal=True,
                src_enc=encoded.transpose(0, 1),
                src_len=len1,
            )
            _scores, loss = decoder(
                "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=False
            )

            if params.masked_output:
                # randomly pick a fraction of the target positions (restricted
                # to pred_mask) and replace them with the <MASK> token
                output_mask = torch.rand(x2.shape[:2]) < params.masked_output
                output_mask *= pred_mask
                predict_output_tokens = x2[output_mask]
                x2_masked = x2.clone()
                x2_masked[output_mask] = encoder.word2id['<MASK>']

                # non-causal pass over the masked targets; the original tokens
                # at the masked positions are the prediction targets
                decoded = decoder(
                    "fwd",
                    x=x2_masked,
                    lengths=len2,
                    causal=False,
                    src_enc=encoded.transpose(0, 1),
                    src_len=len1,
                )
                _scores, loss_mlm = decoder(
                    "predict",
                    tensor=decoded,
                    pred_mask=output_mask,
                    y=predict_output_tokens,
                    get_scores=False,
                )
                loss = loss + loss_mlm

            if params.masked_input:
                # presumably (len, bs, float_descriptor_length * max_dim) -- TODO confirm
                targets = predict_input_tokens
                targets = targets[mask]
                encoded = encoded[mask]
                targets = targets.flatten()
                # one row of scores per target token, over the encoder vocabulary
                scores = encoder.proj(encoded)
                scores = scores.view(-1, len(encoder.word2id))
                loss_numeric = F.cross_entropy(scores, targets)
                loss = loss + loss_numeric

        self.stats[task].append(loss.item())

        # optimize
        self.optimize(loss)

        # number of processed sequences / words
        self.inner_epoch += 1
        self.n_equations += len1.size(0)
        self.stats["processed_e"] += len1.size(0)
        self.stats["processed_w"] += (len1 + len2 - 2).sum().item()

        return loss.item()

    def enc_dec_step_nv1(self, task):
        """
        Encoding / decoding training step on a batch from `batch_generate_nv1`.

        Generates `params.batch_size` samples, optionally replaces trailing
        trajectory variables with NaN, embeds the (time, state) points, runs
        the encoder and the causal decoder, and optimizes the loss returned
        by the decoder's "predict" call. If the samples carry a "prompt"
        entry and input masking is disabled, each sequence's prompt is
        inserted at the front of its input before embedding. When
        `params.masked_output` and / or `params.masked_input` are set, the
        corresponding auxiliary losses are added to the total.

        `task` is the key under which the loss value is appended in
        `self.stats`. Returns the total loss as a Python float.
        """
        params = self.params
        embedder = self.modules["embedder"]
        encoder = self.modules["encoder"]
        decoder = self.modules["decoder"]
        embedder.train()
        encoder.train()
        decoder.train()

        samples = batch_generate_nv1(params.batch_size)

        if params.debug_train_statistics:
            for info_type, info in samples["infos"].items():
                self.infos_statistics[info_type].append(info)

        times = samples["times"]
        trajectory = samples["trajectory"]
        infos = samples["infos"]

        if params.max_masked_variables:  # randomly mask some variables
            for seq_id in range(len(times)):
                # capped at dimension - 1 so that at least one variable stays observed
                n_masked_variables = min(
                    np.random.randint(0, params.max_masked_variables + 1),
                    infos["dimension"][seq_id] - 1,
                )
                # `-0:` is the same slice as `0:` and would blank every
                # variable, so only mask when the drawn count is positive.
                if n_masked_variables > 0:
                    trajectory[seq_id][:, -n_masked_variables:] = np.nan

        # one [time, state] pair per observed point of each sequence
        x1 = [
            [[t, trajectory[seq_id][seq_l]] for seq_l, t in enumerate(seq_times)]
            for seq_id, seq_times in enumerate(times)
        ]

        if params.masked_input:  # randomly mask some inputs
            input_tokens, len1 = embedder.forward(x1, return_before_embed=True)
            # draw a mask over the first two axes of the token array
            mask = np.random.rand(*input_tokens.shape[:2]) < params.masked_input
            # NOTE(review): `input_tokens[mask]` is boolean (advanced) indexing,
            # which yields a copy, so this assignment does not appear to modify
            # `input_tokens` itself -- confirm whether the masking is meant to
            # take effect (targets would then need to be taken before masking).
            input_tokens[mask][params.float_descriptor_length:] = encoder.word2id['<MASK>']
            x1 = embedder.compress(embedder.embed(input_tokens))
            predict_input_tokens = input_tokens[:, :, params.float_descriptor_length:]
        else:
            # the prompt is only prepended on this (unmasked) path
            if "prompt" in samples:
                for seq_id in range(len(times)):
                    x1[seq_id].insert(0, samples["prompt"][seq_id])
            x1, len1 = embedder(x1)

        # target equations: skeleton or full tree encoding
        target_key = "skeleton_tree_encoded" if params.use_skeleton else "tree_encoded"
        x2, len2 = self.env.batch_equations(
            self.env.word_to_idx(samples[target_key], float_input=False)
        )

        # target words to predict
        alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
        pred_mask = (
            alen[:, None] < len2[None] - 1
        )  # do not predict anything given the last target word

        y = x2[1:].masked_select(pred_mask[:-1])
        assert len(y) == (len2 - 1).sum().item()

        if params.use_two_hot:
            assert self.env.equation_encoder.constant_encoder is not None
            y = self.env.ids_to_two_hot(
                ids=y.reshape(-1, 1),
                support_size=len(self.env.equation_words) + len(self.env.constant_words),
            )

        # cuda
        if not params.cpu:
            x2, len2, y = to_cuda(x2, len2, y)

        # forward / loss
        with autocast_wrapper(params):
            encoded = encoder("fwd", x=x1, lengths=len1, causal=False)
            decoded = decoder(
                "fwd",
                x=x2,
                lengths=len2,
                causal=True,
                src_enc=encoded.transpose(0, 1),
                src_len=len1,
            )
            _scores, loss = decoder(
                "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=False
            )

            if params.masked_output:
                # randomly pick a fraction of the target positions (restricted
                # to pred_mask) and replace them with the <MASK> token
                output_mask = torch.rand(x2.shape[:2]) < params.masked_output
                output_mask *= pred_mask
                predict_output_tokens = x2[output_mask]
                x2_masked = x2.clone()
                x2_masked[output_mask] = encoder.word2id['<MASK>']

                # non-causal pass over the masked targets; the original tokens
                # at the masked positions are the prediction targets
                decoded = decoder(
                    "fwd",
                    x=x2_masked,
                    lengths=len2,
                    causal=False,
                    src_enc=encoded.transpose(0, 1),
                    src_len=len1,
                )
                _scores, loss_mlm = decoder(
                    "predict",
                    tensor=decoded,
                    pred_mask=output_mask,
                    y=predict_output_tokens,
                    get_scores=False,
                )
                loss = loss + loss_mlm

            if params.masked_input:
                # presumably (len, bs, float_descriptor_length * max_dim) -- TODO confirm
                targets = predict_input_tokens
                targets = targets[mask]
                encoded = encoded[mask]
                targets = targets.flatten()
                # one row of scores per target token, over the encoder vocabulary
                scores = encoder.proj(encoded)
                scores = scores.view(-1, len(encoder.word2id))
                loss_numeric = F.cross_entropy(scores, targets)
                loss = loss + loss_numeric

        self.stats[task].append(loss.item())

        # optimize
        self.optimize(loss)

        # number of processed sequences / words
        self.inner_epoch += 1
        self.n_equations += len1.size(0)
        self.stats["processed_e"] += len1.size(0)
        self.stats["processed_w"] += (len1 + len2 - 2).sum().item()

        return loss.item()

    def enc_dec_step_nv2(self, task):
        """
        Encoding / decoding training step on a batch from `batch_generate_nv2`.

        Generates `params.batch_size` samples, optionally replaces trailing
        trajectory variables with NaN, embeds the (time, state) points, runs
        the encoder and the causal decoder, and optimizes the loss returned
        by the decoder's "predict" call. If the samples carry a "prompt"
        entry and input masking is disabled, each sequence's prompt is
        inserted at the front of its input before embedding. When
        `params.masked_output` and / or `params.masked_input` are set, the
        corresponding auxiliary losses are added to the total.

        `task` is the key under which the loss value is appended in
        `self.stats`. Returns the total loss as a Python float.
        """
        params = self.params
        embedder = self.modules["embedder"]
        encoder = self.modules["encoder"]
        decoder = self.modules["decoder"]
        embedder.train()
        encoder.train()
        decoder.train()

        samples = batch_generate_nv2(params.batch_size)

        if params.debug_train_statistics:
            for info_type, info in samples["infos"].items():
                self.infos_statistics[info_type].append(info)

        times = samples["times"]
        trajectory = samples["trajectory"]
        infos = samples["infos"]

        if params.max_masked_variables:  # randomly mask some variables
            for seq_id in range(len(times)):
                # capped at dimension - 1 so that at least one variable stays observed
                n_masked_variables = min(
                    np.random.randint(0, params.max_masked_variables + 1),
                    infos["dimension"][seq_id] - 1,
                )
                # `-0:` is the same slice as `0:` and would blank every
                # variable, so only mask when the drawn count is positive.
                if n_masked_variables > 0:
                    trajectory[seq_id][:, -n_masked_variables:] = np.nan

        # one [time, state] pair per observed point of each sequence
        x1 = [
            [[t, trajectory[seq_id][seq_l]] for seq_l, t in enumerate(seq_times)]
            for seq_id, seq_times in enumerate(times)
        ]

        if params.masked_input:  # randomly mask some inputs
            input_tokens, len1 = embedder.forward(x1, return_before_embed=True)
            # draw a mask over the first two axes of the token array
            mask = np.random.rand(*input_tokens.shape[:2]) < params.masked_input
            # NOTE(review): `input_tokens[mask]` is boolean (advanced) indexing,
            # which yields a copy, so this assignment does not appear to modify
            # `input_tokens` itself -- confirm whether the masking is meant to
            # take effect (targets would then need to be taken before masking).
            input_tokens[mask][params.float_descriptor_length:] = encoder.word2id['<MASK>']
            x1 = embedder.compress(embedder.embed(input_tokens))
            predict_input_tokens = input_tokens[:, :, params.float_descriptor_length:]
        else:
            # the prompt is only prepended on this (unmasked) path
            if "prompt" in samples:
                for seq_id in range(len(times)):
                    x1[seq_id].insert(0, samples["prompt"][seq_id])
            x1, len1 = embedder(x1)

        # target equations: skeleton or full tree encoding
        target_key = "skeleton_tree_encoded" if params.use_skeleton else "tree_encoded"
        x2, len2 = self.env.batch_equations(
            self.env.word_to_idx(samples[target_key], float_input=False)
        )

        # target words to predict
        alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
        pred_mask = (
            alen[:, None] < len2[None] - 1
        )  # do not predict anything given the last target word

        y = x2[1:].masked_select(pred_mask[:-1])
        assert len(y) == (len2 - 1).sum().item()

        if params.use_two_hot:
            assert self.env.equation_encoder.constant_encoder is not None
            y = self.env.ids_to_two_hot(
                ids=y.reshape(-1, 1),
                support_size=len(self.env.equation_words) + len(self.env.constant_words),
            )

        # cuda
        if not params.cpu:
            x2, len2, y = to_cuda(x2, len2, y)

        # forward / loss
        with autocast_wrapper(params):
            encoded = encoder("fwd", x=x1, lengths=len1, causal=False)
            decoded = decoder(
                "fwd",
                x=x2,
                lengths=len2,
                causal=True,
                src_enc=encoded.transpose(0, 1),
                src_len=len1,
            )
            _scores, loss = decoder(
                "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=False
            )

            if params.masked_output:
                # randomly pick a fraction of the target positions (restricted
                # to pred_mask) and replace them with the <MASK> token
                output_mask = torch.rand(x2.shape[:2]) < params.masked_output
                output_mask *= pred_mask
                predict_output_tokens = x2[output_mask]
                x2_masked = x2.clone()
                x2_masked[output_mask] = encoder.word2id['<MASK>']

                # non-causal pass over the masked targets; the original tokens
                # at the masked positions are the prediction targets
                decoded = decoder(
                    "fwd",
                    x=x2_masked,
                    lengths=len2,
                    causal=False,
                    src_enc=encoded.transpose(0, 1),
                    src_len=len1,
                )
                _scores, loss_mlm = decoder(
                    "predict",
                    tensor=decoded,
                    pred_mask=output_mask,
                    y=predict_output_tokens,
                    get_scores=False,
                )
                loss = loss + loss_mlm

            if params.masked_input:
                # presumably (len, bs, float_descriptor_length * max_dim) -- TODO confirm
                targets = predict_input_tokens
                targets = targets[mask]
                encoded = encoded[mask]
                targets = targets.flatten()
                # one row of scores per target token, over the encoder vocabulary
                scores = encoder.proj(encoded)
                scores = scores.view(-1, len(encoder.word2id))
                loss_numeric = F.cross_entropy(scores, targets)
                loss = loss + loss_numeric

        self.stats[task].append(loss.item())

        # optimize
        self.optimize(loss)

        # number of processed sequences / words
        self.inner_epoch += 1
        self.n_equations += len1.size(0)
        self.stats["processed_e"] += len1.size(0)
        self.stats["processed_w"] += (len1 + len2 - 2).sum().item()

        return loss.item()

    def enc_dec_step_nv3(self, task):
        """
        Encoding / decoding training step on a batch from `batch_generate_nv3`.

        Generates `params.batch_size` samples, optionally replaces trailing
        trajectory variables with NaN, embeds the (time, state) points, runs
        the encoder and the causal decoder, and optimizes the loss returned
        by the decoder's "predict" call. If the samples carry a "prompt"
        entry and input masking is disabled, each sequence's prompt is
        inserted at the front of its input before embedding. When
        `params.masked_output` and / or `params.masked_input` are set, the
        corresponding auxiliary losses are added to the total.

        `task` is the key under which the loss value is appended in
        `self.stats`. Returns the total loss as a Python float.
        """
        params = self.params
        embedder = self.modules["embedder"]
        encoder = self.modules["encoder"]
        decoder = self.modules["decoder"]
        embedder.train()
        encoder.train()
        decoder.train()

        samples = batch_generate_nv3(params.batch_size)

        if params.debug_train_statistics:
            for info_type, info in samples["infos"].items():
                self.infos_statistics[info_type].append(info)

        times = samples["times"]
        trajectory = samples["trajectory"]
        infos = samples["infos"]

        if params.max_masked_variables:  # randomly mask some variables
            for seq_id in range(len(times)):
                # capped at dimension - 1 so that at least one variable stays observed
                n_masked_variables = min(
                    np.random.randint(0, params.max_masked_variables + 1),
                    infos["dimension"][seq_id] - 1,
                )
                # `-0:` is the same slice as `0:` and would blank every
                # variable, so only mask when the drawn count is positive.
                if n_masked_variables > 0:
                    trajectory[seq_id][:, -n_masked_variables:] = np.nan

        # one [time, state] pair per observed point of each sequence
        x1 = [
            [[t, trajectory[seq_id][seq_l]] for seq_l, t in enumerate(seq_times)]
            for seq_id, seq_times in enumerate(times)
        ]

        if params.masked_input:  # randomly mask some inputs
            input_tokens, len1 = embedder.forward(x1, return_before_embed=True)
            # draw a mask over the first two axes of the token array
            mask = np.random.rand(*input_tokens.shape[:2]) < params.masked_input
            # NOTE(review): `input_tokens[mask]` is boolean (advanced) indexing,
            # which yields a copy, so this assignment does not appear to modify
            # `input_tokens` itself -- confirm whether the masking is meant to
            # take effect (targets would then need to be taken before masking).
            input_tokens[mask][params.float_descriptor_length:] = encoder.word2id['<MASK>']
            x1 = embedder.compress(embedder.embed(input_tokens))
            predict_input_tokens = input_tokens[:, :, params.float_descriptor_length:]
        else:
            # the prompt is only prepended on this (unmasked) path
            if "prompt" in samples:
                for seq_id in range(len(times)):
                    x1[seq_id].insert(0, samples["prompt"][seq_id])
            x1, len1 = embedder(x1)

        # target equations: skeleton or full tree encoding
        target_key = "skeleton_tree_encoded" if params.use_skeleton else "tree_encoded"
        x2, len2 = self.env.batch_equations(
            self.env.word_to_idx(samples[target_key], float_input=False)
        )

        # target words to predict
        alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
        pred_mask = (
            alen[:, None] < len2[None] - 1
        )  # do not predict anything given the last target word

        y = x2[1:].masked_select(pred_mask[:-1])
        assert len(y) == (len2 - 1).sum().item()

        if params.use_two_hot:
            assert self.env.equation_encoder.constant_encoder is not None
            y = self.env.ids_to_two_hot(
                ids=y.reshape(-1, 1),
                support_size=len(self.env.equation_words) + len(self.env.constant_words),
            )

        # cuda
        if not params.cpu:
            x2, len2, y = to_cuda(x2, len2, y)

        # forward / loss
        with autocast_wrapper(params):
            encoded = encoder("fwd", x=x1, lengths=len1, causal=False)
            decoded = decoder(
                "fwd",
                x=x2,
                lengths=len2,
                causal=True,
                src_enc=encoded.transpose(0, 1),
                src_len=len1,
            )
            _scores, loss = decoder(
                "predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=False
            )

            if params.masked_output:
                # randomly pick a fraction of the target positions (restricted
                # to pred_mask) and replace them with the <MASK> token
                output_mask = torch.rand(x2.shape[:2]) < params.masked_output
                output_mask *= pred_mask
                predict_output_tokens = x2[output_mask]
                x2_masked = x2.clone()
                x2_masked[output_mask] = encoder.word2id['<MASK>']

                # non-causal pass over the masked targets; the original tokens
                # at the masked positions are the prediction targets
                decoded = decoder(
                    "fwd",
                    x=x2_masked,
                    lengths=len2,
                    causal=False,
                    src_enc=encoded.transpose(0, 1),
                    src_len=len1,
                )
                _scores, loss_mlm = decoder(
                    "predict",
                    tensor=decoded,
                    pred_mask=output_mask,
                    y=predict_output_tokens,
                    get_scores=False,
                )
                loss = loss + loss_mlm

            if params.masked_input:
                # presumably (len, bs, float_descriptor_length * max_dim) -- TODO confirm
                targets = predict_input_tokens
                targets = targets[mask]
                encoded = encoded[mask]
                targets = targets.flatten()
                # one row of scores per target token, over the encoder vocabulary
                scores = encoder.proj(encoded)
                scores = scores.view(-1, len(encoder.word2id))
                loss_numeric = F.cross_entropy(scores, targets)
                loss = loss + loss_numeric

        self.stats[task].append(loss.item())

        # optimize
        self.optimize(loss)

        # number of processed sequences / words
        self.inner_epoch += 1
        self.n_equations += len1.size(0)
        self.stats["processed_e"] += len1.size(0)
        self.stats["processed_w"] += (len1 + len2 - 2).sum().item()

        return loss.item()


@contextlib.contextmanager
def autocast_wrapper(params):
    """
    Context manager wrapping a forward pass in `torch.cuda.amp.autocast`.

    Autocast is skipped (the body runs unchanged) when `params.amp` is -1
    or when `params.nvidia_apex` is truthy; otherwise the body runs inside
    `torch.cuda.amp.autocast()`.
    """
    skip_autocast = params.amp == -1 or params.nvidia_apex
    # pick the context once so there is a single `with` / `yield` below
    context = contextlib.nullcontext() if skip_autocast else torch.cuda.amp.autocast()
    with context:
        yield
