"""
Main component: the trainer handles everything:
    * initializations
    * training
    * saving
"""
import inspect
import warnings
from copy import deepcopy
from pathlib import Path
from time import time

import numpy as np
from comet_ml import ExistingExperiment, Experiment

warnings.simplefilter("ignore", UserWarning)

import torch
import torch.nn as nn
from addict import Dict
from torch import autograd, sigmoid, softmax
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from climategan.data import get_all_loaders
from climategan.discriminator import Discriminator, create_discriminator
from climategan.eval_metrics import accuracy, mIOU
from climategan.fid import compute_val_fid
from climategan.generator import Generator, create_generator
from climategan.logger import Logger
from climategan.losses import get_losses
from climategan.optim import get_optimizer
from climategan.transforms import DiffTransforms
from climategan.tutils import (
    divide_pred,
    get_num_params,
    get_WGAN_gradient,
    print_num_parameters,
    shuffle_batch_tuple,
    vgg_preprocess,
    zero_grad,
)
from climategan.utils import (
    Timer,
    comet_kwargs,
    div_dict,
    find_target_size,
    flatten_opts,
    get_display_indices,
    get_existing_comet_id,
    get_latest_opts,
    merge,
    sum_dict,
)

try:
    import torch_xla.core.xla_model as xm  # type: ignore
except ImportError:
    pass


class Trainer:
    """Main trainer class"""

    def __init__(self, opts, comet_exp=None, verbose=0, device=None):
        """Trainer class to gather various model training procedures
        such as training, evaluating, saving and logging

        init:
        * creates the trainer's Logger
        * stores `comet_exp` as self.exp when it is a comet.ml experiment
        * sets the device (`device` if provided, else cuda:0 if available,
          else CPU)
        * creates one GradScaler per optimizer if opts.train.amp is set

        Args:
            opts (addict.Dict): options to configure the trainer, the data, the models
            comet_exp (comet_ml.Experiment | comet_ml.ExistingExperiment, optional):
                experiment to log to. Any other value (including None) leaves
                self.exp as None, i.e. no comet logging. Defaults to None.
            verbose (int, optional): printing level to debug. Defaults to 0.
            device (torch.device, optional): device to use. Defaults to None
                (automatic selection, see above).

        Raises:
            ValueError: if opts.train.amp is set while G or D uses ExtraAdam
        """
        super().__init__()

        self.opts = opts
        self.verbose = verbose
        self.logger = Logger(self)

        self.losses = None
        self.G = self.D = None
        self.real_val_fid_stats = None
        self.use_pl4m = False
        self.is_setup = False
        self.loaders = self.all_loaders = None
        self.exp = None

        self.current_mode = "train"
        self.diff_transforms = None
        self.kitti_pretrain = self.opts.train.kitti.pretrain
        self.pseudo_training_tasks = set(self.opts.train.pseudo.tasks)

        self.lr_names = {}
        self.base_display_images = {}
        self.kitty_display_images = {}
        self.domain_labels = {"s": 0, "r": 1}

        self.device = device or torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )

        # resume_from_path() hands us an ExistingExperiment when re-using a
        # previous comet run: accept both experiment flavours explicitly so
        # that resumed runs keep logging
        if isinstance(comet_exp, (Experiment, ExistingExperiment)):
            self.exp = comet_exp

        if self.opts.train.amp:
            optimizers = [
                self.opts.gen.opt.optimizer.lower(),
                self.opts.dis.opt.optimizer.lower(),
            ]
            if "extraadam" in optimizers:
                raise ValueError(
                    "AMP does not work with ExtraAdam ({})".format(optimizers)
                )
            # one scaler per optimizer (see update_G / update_D)
            self.grad_scaler_d = GradScaler()
            self.grad_scaler_g = GradScaler()

        # -------------------------------
        # -----  Legacy Overwrites  -----
        # -------------------------------
        # either depth-fusion flag implies use_dada
        if (
            self.opts.gen.s.depth_feat_fusion is True
            or self.opts.gen.s.depth_dada_fusion is True
        ):
            self.opts.gen.s.use_dada = True

    @torch.no_grad()
    def paint_and_mask(self, image_batch, mask_batch=None, resolution="approx"):
        """
        Paints a batch of images (or a single image with a batch dim of 1). If
        masks are not provided, they are inferred from the masker.
        Resolution can either be the train-time resolution or the closest
        multiple of 2 ** spade_n_up

        Operations performed without gradient

        If resolution == "approx" then the output image has the shape:
            (dim // 2 ** spade_n_up) * 2 ** spade_n_up, for dim in [height, width]
            eg: (1000, 1300) => (896, 1280) for spade_n_up = 7
        If resolution == "exact" then the output image has the same shape:
            we first process in "approx" mode then upsample bilinear
        If resolution == "basic" image output shape is the train-time's
            (typically 640x640)
        If resolution == "upsample" image is inferred as "basic" and
            then upsampled to original size

        Args:
            image_batch (torch.Tensor): 4D batch of images to flood
            mask_batch (torch.Tensor, optional): Masks for the images.
                Defaults to None (infer with Masker).
            resolution (str, optional): "approx", "exact" or False

        Returns:
            torch.Tensor: N x C x H x W where H and W depend on `resolution`
        """
        assert resolution in {"approx", "exact", "basic", "upsample"}
        previous_mode = self.current_mode
        if previous_mode == "train":
            self.eval_mode()

        if mask_batch is None:
            mask_batch = self.G.mask(x=image_batch)
        else:
            assert len(image_batch) == len(mask_batch)
            assert image_batch.shape[-2:] == mask_batch.shape[-2:]

        if resolution not in {"approx", "exact"}:
            painted = self.G.paint(mask_batch, image_batch)

            if resolution == "upsample":
                painted = nn.functional.interpolate(
                    painted, size=image_batch.shape[-2:], mode="bilinear"
                )
        else:
            # save latent shape
            zh = self.G.painter.z_h
            zw = self.G.painter.z_w
            # adapt latent shape to approximately keep the resolution
            self.G.painter.z_h = (
                image_batch.shape[-2] // 2 ** self.opts.gen.p.spade_n_up
            )
            self.G.painter.z_w = (
                image_batch.shape[-1] // 2 ** self.opts.gen.p.spade_n_up
            )

            painted = self.G.paint(mask_batch, image_batch)

            self.G.painter.z_h = zh
            self.G.painter.z_w = zw
            if resolution == "exact":
                painted = nn.functional.interpolate(
                    painted, size=image_batch.shape[-2:], mode="bilinear"
                )

        if previous_mode == "train":
            self.train_mode()

        return painted

    def _p(self, *args, **kwargs):
        """
        verbose-dependant print util
        """
        if self.verbose > 0:
            print(*args, **kwargs)

    @classmethod
    def resume_from_path(
        cls,
        path,
        overrides=None,
        setup=True,
        inference=False,
        new_exp=False,
        device=None,
        verbose=1,
    ):
        """
        Resume and optionally setup a trainer from a specific path,
        using the latest opts and checkpoint. Requires path to contain a
        checkpoints/ directory, plus whatever get_latest_opts and
        get_existing_comet_id read from it (presumably opts.yaml and url.txt,
        or their incremented versions)

        Args:
            path (str | pathlib.Path): Trainer to resume
            overrides (dict, optional): Override loaded opts with those.
                Defaults to None (no override).
            setup (bool, optional): Whether or not to setup the trainer before
                returning it. Defaults to True.
            inference (bool, optional): Setup should be done in inference mode or not.
                Defaults to False.
            new_exp (bool | None, optional): True creates a new comet experiment,
                None disables comet logging, anything else re-uses the existing
                comet experiment found in path. Defaults to False.
            device (torch.device, optional): Device to use
            verbose (int, optional): verbosity forwarded to the trainer.
                Defaults to 1.

        Returns:
            climategan.Trainer: Loaded and resumed trainer
        """
        # a fresh dict per call rather than a shared mutable default argument
        if overrides is None:
            overrides = {}

        p = Path(path).expanduser().resolve()
        assert p.exists(), f"{p} does not exist"

        c = p / "checkpoints"
        assert c.exists() and c.is_dir(), f"{c} is not an existing directory"

        opts = get_latest_opts(p)
        opts = Dict(merge(overrides, opts))
        opts.train.resume = True

        if new_exp is None:
            exp = None
        elif new_exp is True:
            exp = Experiment(project_name="climategan", **comet_kwargs)
            exp.log_asset_folder(
                str(Path(__file__).parent), recursive=True, log_file_name=True
            )
            exp.log_parameters(flatten_opts(opts))
        else:
            comet_id = get_existing_comet_id(p)
            exp = ExistingExperiment(previous_experiment=comet_id, **comet_kwargs)

        trainer = cls(opts, comet_exp=exp, device=device, verbose=verbose)

        if setup:
            trainer.setup(inference=inference)
        return trainer

    def save(self):
        save_dir = Path(self.opts.output_path) / Path("checkpoints")
        save_dir.mkdir(exist_ok=True)
        save_path = save_dir / "latest_ckpt.pth"

        # Construct relevant state dicts / optims:
        # Save at least G
        save_dict = {
            "epoch": self.logger.epoch,
            "G": self.G.state_dict(),
            "g_opt": self.g_opt.state_dict(),
            "step": self.logger.global_step,
        }

        if self.D is not None and get_num_params(self.D) > 0:
            save_dict["D"] = self.D.state_dict()
            save_dict["d_opt"] = self.d_opt.state_dict()

        if (
            self.logger.epoch >= self.opts.train.min_save_epoch
            and self.logger.epoch % self.opts.train.save_n_epochs == 0
        ):
            torch.save(save_dict, save_dir / f"epoch_{self.logger.epoch}_ckpt.pth")
            if not self.opts.train.always_save_latest:
                torch.save(save_dict, save_path)

        if self.opts.train.always_save_latest:
            torch.save(save_dict, save_path)

    def resume(self, inference=False):
        tpu = "xla" in str(self.device)
        if tpu:
            print("Resuming on TPU:", self.device)

        m_path = Path(self.opts.load_paths.m)
        p_path = Path(self.opts.load_paths.p)
        pm_path = Path(self.opts.load_paths.pm)
        output_path = Path(self.opts.output_path)

        map_loc = self.device if not tpu else "cpu"

        if "m" in self.opts.tasks and "p" in self.opts.tasks:
            # ----------------------------------------
            # -----  Masker and Painter Loading  -----
            # ----------------------------------------

            # want to resume a pm model but no path was provided:
            # resume a single pm model from output_path
            if all([str(p) == "none" for p in [m_path, p_path, pm_path]]):
                checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"
                print("Resuming P+M model from", str(checkpoint_path))
                checkpoint = torch.load(checkpoint_path, map_location=map_loc)

            # want to resume a pm model with a pm_path provided:
            # resume a single pm model from load_paths.pm
            # depending on whether a dir or a file is specified
            elif str(pm_path) != "none":
                assert pm_path.exists()

                if pm_path.is_dir():
                    checkpoint_path = pm_path / "checkpoints/latest_ckpt.pth"
                else:
                    assert pm_path.suffix == ".pth"
                    checkpoint_path = pm_path

                print("Resuming P+M model from", str(checkpoint_path))
                checkpoint = torch.load(checkpoint_path, map_location=map_loc)

            # want to resume a pm model, pm_path not provided:
            # m_path and p_path must be provided as dirs or pth files
            elif m_path != p_path:
                assert m_path.exists()
                assert p_path.exists()

                if m_path.is_dir():
                    m_path = m_path / "checkpoints/latest_ckpt.pth"

                if p_path.is_dir():
                    p_path = p_path / "checkpoints/latest_ckpt.pth"

                assert m_path.suffix == ".pth"
                assert p_path.suffix == ".pth"

                print(f"Resuming P+M model from \n  -{p_path} \nand \n  -{m_path}")
                m_checkpoint = torch.load(m_path, map_location=map_loc)
                p_checkpoint = torch.load(p_path, map_location=map_loc)
                checkpoint = merge(m_checkpoint, p_checkpoint)

            else:
                raise ValueError(
                    "Cannot resume a P+M model with provided load_paths:\n{}".format(
                        self.opts.load_paths
                    )
                )

        else:
            # ----------------------------------
            # -----  Single Model Loading  -----
            # ----------------------------------

            # cannot specify both paths
            if str(m_path) != "none" and str(p_path) != "none":
                raise ValueError(
                    "Opts tasks are {} but received 2 values for the load_paths".format(
                        self.opts.tasks
                    )
                )

            # specified m
            elif str(m_path) != "none":
                assert m_path.exists()
                assert "m" in self.opts.tasks
                model = "M"
                if m_path.is_dir():
                    m_path = m_path / "checkpoints/latest_ckpt.pth"
                checkpoint_path = m_path

            # specified m
            elif str(p_path) != "none":
                assert p_path.exists()
                assert "p" in self.opts.tasks
                model = "P"
                if p_path.is_dir():
                    p_path = p_path / "checkpoints/latest_ckpt.pth"
                checkpoint_path = p_path

            # specified neither p nor m: resume from output_path
            else:
                model = "P" if "p" in self.opts.tasks else "M"
                checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"

            print(f"Resuming {model} model from {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=map_loc)

        # On TPUs must send the data to the xla device as it cannot be mapped
        # there directly from torch.load
        if tpu:
            checkpoint = xm.send_cpu_data_to_device(checkpoint, self.device)

        # -----------------------
        # -----  Restore G  -----
        # -----------------------
        if inference:
            incompatible_keys = self.G.load_state_dict(checkpoint["G"], strict=False)
            if incompatible_keys.missing_keys:
                print("WARNING: Missing keys in self.G.load_state_dict, keeping inits")
                print(incompatible_keys.missing_keys)
            if incompatible_keys.unexpected_keys:
                print("WARNING: Ignoring Unexpected keys in self.G.load_state_dict")
                print(incompatible_keys.unexpected_keys)
        else:
            self.G.load_state_dict(checkpoint["G"])

        if inference:
            # only G is needed to infer
            print("Done loading checkpoints.")
            return

        self.g_opt.load_state_dict(checkpoint["g_opt"])

        # ------------------------------
        # -----  Resume scheduler  -----
        # ------------------------------
        # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
        for _ in range(self.logger.epoch + 1):
            self.update_learning_rates()

        # -----------------------
        # -----  Restore D  -----
        # -----------------------
        if self.D is not None and get_num_params(self.D) > 0:
            self.D.load_state_dict(checkpoint["D"])
            self.d_opt.load_state_dict(checkpoint["d_opt"])

        # ---------------------------
        # -----  Resore logger  -----
        # ---------------------------
        self.logger.epoch = checkpoint["epoch"]
        self.logger.global_step = checkpoint["step"]
        self.exp.log_text(
            "Resuming from epoch {} & step {}".format(
                checkpoint["epoch"], checkpoint["step"]
            )
        )
        # Round step to even number for extraGradient
        if self.logger.global_step % 2 != 0:
            self.logger.global_step += 1

    def eval_mode(self):
        """
        Set trainer's models in eval mode
        """
        if self.G is not None:
            self.G.eval()
        if self.D is not None:
            self.D.eval()
        self.current_mode = "eval"

    def train_mode(self):
        """
        Set trainer's models in train mode
        """
        if self.G is not None:
            self.G.train()
        if self.D is not None:
            self.D.train()
        self.current_mode = "train"

    def assert_z_matches_x(self, x, z):
        assert x.shape[0] == (
            z.shape[0] if not isinstance(z, (list, tuple)) else z[0].shape[0]
        ), "x-> {}, z->{}".format(
            x.shape, z.shape if not isinstance(z, (list, tuple)) else z[0].shape
        )

    def batch_to_device(self, b):
        """sends the data in b to self.device

        Args:
            b (dict): the batch dictionnay

        Returns:
            dict: the batch dictionnary with its "data" field sent to self.device
        """
        for task, tensor in b["data"].items():
            b["data"][task] = tensor.to(self.device)
        return b

    def sample_painter_z(self, batch_size):
        return self.G.sample_painter_z(batch_size, self.device)

    @property
    def train_loaders(self):
        """Zip the current training loaders together

        Returns:
            zip: iterator yielding one tuple per step, holding one batch per
                loader in self.loaders["train"] (in the dict's order); stops
                with the shortest loader
        """
        per_domain = self.loaders["train"]
        return zip(*per_domain.values())

    @property
    def val_loaders(self):
        """Zip the current validation loaders together

        Returns:
            zip: iterator yielding one tuple per step, holding one batch per
                loader in self.loaders["val"] (in the dict's order); stops
                with the shortest loader
        """
        per_domain = self.loaders["val"]
        return zip(*per_domain.values())

    def compute_latent_shape(self):
        """Compute the latent shape, i.e. the Encoder's output shape,
        from a batch.

        Raises:
            ValueError: If no loader, the latent_shape cannot be inferred

        Returns:
            tuple: (c, h, w)
        """
        x = None
        for mode in self.all_loaders:
            for domain in self.all_loaders.loaders[mode]:
                x = (
                    self.all_loaders[mode][domain]
                    .dataset[0]["data"]["x"]
                    .to(self.device)
                )
                break
            if x is not None:
                break

        if x is None:
            raise ValueError("No batch found to compute_latent_shape")

        x = x.unsqueeze(0)
        z = self.G.encode(x)
        return z.shape[1:] if not isinstance(z, (list, tuple)) else z[0].shape[1:]

    def g_opt_step(self):
        """Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation
        step every other step
        """
        if "extra" in self.opts.gen.opt.optimizer.lower() and (
            self.logger.global_step % 2 == 0
        ):
            self.g_opt.extrapolation()
        else:
            self.g_opt.step()

    def d_opt_step(self):
        """Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation
        step every other step
        """
        if "extra" in self.opts.dis.opt.optimizer.lower() and (
            self.logger.global_step % 2 == 0
        ):
            self.d_opt.extrapolation()
        else:
            self.d_opt.step()

    def update_learning_rates(self):
        if self.g_scheduler is not None:
            self.g_scheduler.step()
        if self.d_scheduler is not None:
            self.d_scheduler.step()

    def setup(self, inference=False):
        """Prepare the trainer before it can be used to train the models:
        * create the data loaders (not in inference mode)
        * initialize G, and set the painter's latent shape if there is one
        * in inference mode: optionally resume G, switch to eval mode and stop
        * otherwise initialize D, the optimizers and schedulers for G (and D
          if it has parameters), the losses, the display images, the data
          source, and resume from a checkpoint if opts.train.resume is set

        Args:
            inference (bool, optional): only prepare what is needed to infer
                with G (no loaders, no D, no optimizers). Defaults to False.
        """
        self.logger.global_step = 0
        start_time = time()
        self.logger.time.start_time = start_time
        verbose = self.verbose

        if not inference:
            self.all_loaders = get_all_loaders(self.opts)

        # -----------------------
        # -----  Generator  -----
        # -----------------------
        __t = time()
        print("Creating generator...")

        self.G: Generator = create_generator(
            self.opts, device=self.device, no_init=inference, verbose=verbose
        )

        # truthy when the painter has parameters, otherwise depends on what
        # G.load_val_painter() returns
        self.has_painter = get_num_params(self.G.painter) or self.G.load_val_painter()

        if self.has_painter:
            self.G.painter.set_latent_shape(find_target_size(self.opts, "x"), True)

        print(f"Generator OK in {time() - __t:.1f}s.")

        if inference:  # Inference mode: no more than a Generator needed
            print("Inference mode: no Discriminator, no optimizers")
            print_num_parameters(self)
            self.switch_data(to="base")
            if self.opts.train.resume:
                self.resume(True)
            self.eval_mode()
            print("Trainer is in evaluation mode.")
            print("Setup done.")
            self.is_setup = True
            return

        # ---------------------------
        # -----  Discriminator  -----
        # ---------------------------

        self.D: Discriminator = create_discriminator(
            self.opts, self.device, verbose=verbose
        )
        print("Discriminator OK.")

        print_num_parameters(self)

        # --------------------------
        # -----  Optimization  -----
        # --------------------------
        # Get different optimizers for each task (different learning rates)
        self.g_opt, self.g_scheduler, self.lr_names["G"] = get_optimizer(
            self.G, self.opts.gen.opt, self.opts.tasks
        )

        # a parameter-less D gets neither optimizer nor scheduler
        if get_num_params(self.D) > 0:
            self.d_opt, self.d_scheduler, self.lr_names["D"] = get_optimizer(
                self.D, self.opts.dis.opt, self.opts.tasks, True
            )
        else:
            self.d_opt, self.d_scheduler = None, None

        self.losses = get_losses(self.opts, verbose, device=self.device)

        # differentiable augmentations are only created for the painter task
        if "p" in self.opts.tasks and self.opts.gen.p.diff_aug.use:
            self.diff_transforms = DiffTransforms(self.opts.gen.p.diff_aug)

        if verbose > 0:
            # report the size of each loader's dataset
            for mode, mode_dict in self.all_loaders.items():
                for domain, domain_loader in mode_dict.items():
                    print(
                        "Loader {} {} : {}".format(
                            mode, domain, len(domain_loader.dataset)
                        )
                    )

        # ----------------------------
        # -----  Display images  -----
        # ----------------------------
        self.set_display_images()

        # -------------------------------
        # -----  Log Architectures  -----
        # -------------------------------
        self.logger.log_architecture()

        # -----------------------------
        # -----  Set data source  -----
        # -----------------------------
        if self.kitti_pretrain:
            self.switch_data(to="kitti")
        else:
            self.switch_data(to="base")

        # -------------------------
        # -----  Setup Done.  -----
        # -------------------------
        # blank out the carriage-return progress line left by
        # set_display_images
        print(" " * 50, end="\r")
        print("Done creating display images")

        if self.opts.train.resume:
            print("Resuming Model (inference: False)")
            self.resume(False)
        else:
            print("Not resuming: starting a new model")

        print("Setup done.")
        self.is_setup = True

    def switch_data(self, to="kitti"):
        caller = inspect.stack()[1].function
        print(f"[{caller}] Switching data source to", to)
        self.data_source = to
        if to == "kitti":
            self.display_images = self.kitty_display_images
            if self.all_loaders is not None:
                self.loaders = {
                    mode: {"s": self.all_loaders[mode]["kitti"]}
                    for mode in self.all_loaders
                }
        else:
            self.display_images = self.base_display_images
            if self.all_loaders is not None:
                self.loaders = {
                    mode: {
                        domain: self.all_loaders[mode][domain]
                        for domain in self.all_loaders[mode]
                        if domain != "kitti"
                    }
                    for mode in self.all_loaders
                }
        if (
            self.logger.global_step % 2 != 0
            and "extra" in self.opts.dis.opt.optimizer.lower()
        ):
            print(
                "Warning: artificially bumping step to run an extrapolation step first."
            )
            self.logger.global_step += 1

    def set_display_images(self, use_all=False):
        """Select and store the samples used as display images

        For each mode and domain of self.all_loaders, samples are read from
        the loader's dataset and stored in self.base_display_images, or in
        self.kitty_display_images for the kitti domain while pre-training on
        kitti (the kitti domain is skipped otherwise). If a comet experiment
        is set, the "paths" of each selected sample are logged as parameters.

        Args:
            use_all (bool, optional): use every sample of each dataset rather
                than the indices given by get_display_indices.
                Defaults to False.
        """
        for mode, mode_dict in self.all_loaders.items():

            if self.kitti_pretrain:
                self.kitty_display_images[mode] = {}
            self.base_display_images[mode] = {}

            for domain in mode_dict:
                is_kitti = domain == "kitti"

                # kitti images only matter while pre-training on kitti
                if is_kitti and not self.kitti_pretrain:
                    continue

                if is_kitti:
                    target_dict = self.kitty_display_images
                else:
                    target_dict = self.base_display_images

                dataset = self.all_loaders[mode][domain].dataset
                if use_all:
                    display_indices = list(range(len(dataset)))
                else:
                    display_indices = get_display_indices(
                        self.opts, domain, len(dataset)
                    )
                ldis = len(display_indices)
                print(
                    f"       Creating {ldis} {mode} {domain} display images...",
                    end="\r",
                    flush=True,
                )

                selected = []
                for i in display_indices:
                    # progress feedback, overwritten in place
                    print(f"({i})", end="\r")
                    # indices past the end of the dataset are dropped
                    if i < len(dataset):
                        selected.append(Dict(dataset[i]))
                target_dict[mode][domain] = selected

                if self.exp is None:
                    continue
                for im_id, d in enumerate(target_dict[mode][domain]):
                    self.exp.log_parameter(
                        "display_image_{}_{}_{}".format(mode, domain, im_id),
                        d["paths"],
                    )

    def train(self):
        """For each epoch:
        * train
        * eval
        * save
        """
        assert self.is_setup
        eval_epoch = self.opts.val.eval_every_epoch or 1

        for self.logger.epoch in range(
            self.logger.epoch, self.logger.epoch + self.opts.train.epochs
        ):
            # backprop painter's disc loss to masker
            if (
                self.logger.epoch == self.opts.gen.p.pl4m_epoch
                and get_num_params(self.G.painter) > 0
                and "p" in self.opts.tasks
                and self.opts.gen.m.use_pl4m
            ):
                print(
                    "\n\n >>> Enabling pl4m at epoch {}\n\n".format(self.logger.epoch)
                )
                self.use_pl4m = True

            self.run_epoch()
            if (self.logger.epoch + 1) % eval_epoch == 0:
                self.run_evaluation(verbose=1)
            self.save()

            # end vkitti2 pre-training
            if self.logger.epoch == self.opts.train.kitti.epochs - 1:
                self.switch_data(to="base")
                self.kitti_pretrain = False

            # end pseudo training
            if self.logger.epoch == self.opts.train.pseudo.epochs - 1:
                print(
                    "\n>>> Epoch {}: end of pseudo-training <<<".format(
                        self.logger.epoch + 1
                    )
                )
                self.pseudo_training_tasks = set()

    def run_epoch(self):
        """Runs an epoch:
        * checks trainer is setup and switches to train mode
        * iterates over tuples holding one batch per training loader
        * sends batches to device, indexed by their domain
        * updates G (with D frozen) then D (unless pre-training on kitti)
        * steps the learning-rate schedulers (unless pre-training on kitti)
          and logs learning rates and timings
        """
        assert self.is_setup
        self.train_mode()
        if self.exp is not None:
            self.exp.log_parameter("epoch", self.logger.epoch)

        # the zipped loaders stop with the shortest one
        epoch_len = min(len(loader) for loader in self.loaders["train"].values())
        epoch_desc = f"Epoch {self.logger.epoch}"
        self.logger.time.epoch_start = time()

        progress_bar = tqdm(
            self.train_loaders,
            desc=epoch_desc,
            total=epoch_len,
            mininterval=0.5,
            unit="batch",
        )

        for batches in progress_bar:

            self.logger.time.step_start = time()
            batches = shuffle_batch_tuple(batches)

            multi_domain_batch = {}
            for batch in batches:
                # The `[0]` is because the domain is contained in a list
                domain = batch["domain"][0]
                multi_domain_batch[domain] = self.batch_to_device(batch)

            # ------------------------------
            # -----  Update Generator  -----
            # ------------------------------

            has_d_opt = self.d_opt is not None

            # freeze params of the discriminator
            if has_d_opt:
                for param in self.D.parameters():
                    param.requires_grad = False

            self.update_G(multi_domain_batch)

            # ----------------------------------
            # -----  Update Discriminator  -----
            # ----------------------------------

            # unfreeze params of the discriminator
            if self.d_opt is not None and not self.kitti_pretrain:
                for param in self.D.parameters():
                    param.requires_grad = True

                self.update_D(multi_domain_batch)

            # -------------------------
            # -----  Log Metrics  -----
            # -------------------------
            self.logger.global_step += 1
            self.logger.log_step_time(time())

        if not self.kitti_pretrain:
            self.update_learning_rates()

        self.logger.log_learning_rates()
        self.logger.log_epoch_time(time())

    def update_G(self, multi_domain_batch, verbose=0):
        """Perform one update of G from multi_domain_batch, a dictionary
        domain => batch

        * zeroes G's gradients
        * computes the loss with self.get_G_loss, under autocast when
          self.opts.train.amp is set
        * backpropagates and steps the optimizer: through the gradient
          scaler with AMP, otherwise with g_opt_step()
            * g_opt.step() or .extrapolation() depending on self.logger.global_step
        * logs losses with self.logger.log_losses(model_to_update="G")

        Args:
            multi_domain_batch (dict): dictionary of domain batches
            verbose (int, optional): forwarded to get_G_loss. Defaults to 0.
        """
        zero_grad(self.G)

        if not self.opts.train.amp:
            # full-precision update
            g_loss = self.get_G_loss(multi_domain_batch, verbose)
            g_loss.backward()
            self.g_opt_step()
        else:
            # mixed-precision update: scaled backward, scaler-driven step
            with autocast():
                g_loss = self.get_G_loss(multi_domain_batch, verbose)
            scaler = self.grad_scaler_g
            scaler.scale(g_loss).backward()
            scaler.step(self.g_opt)
            scaler.update()

        self.logger.log_losses(model_to_update="G", mode="train")

    def update_D(self, multi_domain_batch, verbose=0):
        """Perform one update of D from multi_domain_batch, a dictionary
        domain => batch

        * zeroes D's gradients
        * computes the loss with self.get_D_loss, under autocast when
          self.opts.train.amp is set
        * backpropagates and steps the optimizer: through the gradient
          scaler with AMP, otherwise with d_opt_step()
        * stores the total loss in the logger and logs the losses

        Args:
            multi_domain_batch (dict): dictionary of domain batches
            verbose (int, optional): forwarded to get_D_loss. Defaults to 0.
        """
        zero_grad(self.D)

        if not self.opts.train.amp:
            # full-precision update
            d_loss = self.get_D_loss(multi_domain_batch, verbose)
            d_loss.backward()
            self.d_opt_step()
        else:
            # mixed-precision update: scaled backward, scaler-driven step
            with autocast():
                d_loss = self.get_D_loss(multi_domain_batch, verbose)
            scaler = self.grad_scaler_d
            scaler.scale(d_loss).backward()
            scaler.step(self.d_opt)
            scaler.update()

        total = d_loss.item()
        self.logger.losses.disc.total_loss = total
        self.logger.log_losses(model_to_update="D", mode="train")

    def get_D_loss(self, multi_domain_batch, verbose=0):
        """Compute the discriminators' total loss over all domain batches.

        For each (domain, batch) pair:

        * "rf" domain, when self.has_painter (Painter discriminator):
            * paint a fake image with self.G.paint(m, x) under torch.no_grad()
            * optionally apply self.diff_transforms to both fake and real
            * score real and fake images with self.D["p"] (either the
              global + local pair, or the single discriminator fed with
              mask-concatenated inputs) and accumulate self.losses["D"]["p"]
        * any other domain (Masker / ADVENT discriminators):
            * encode x with self.G.encode
            * if the batch has "s" data: self.masker_s_loss(..., for_="D")
            * if the batch has "m" data: self.masker_m_loss(..., for_="D"),
              computing the depth prediction / SPADE conditioning first when
              the options require them
            * each returned loss is multiplied by
              opts.train.lambdas.advent.adv_main before being accumulated

        Side effect: self.logger.losses.disc is updated with a python-number
        copy of every accumulated term.

        See readme's update d section for details

        Args:
            multi_domain_batch (dict): maps a domain name to its batch; each
                batch is read through batch["data"][key]
            verbose (int, optional): not used by this method. Defaults to 0.

        Returns:
            The sum of every accumulated term: a tensor as soon as one loss
            was computed, the int 0 otherwise.
        """

        # One accumulator per discriminator family; entries start as the int 0
        disc_loss = {"m": {"Advent": 0}, "s": {"Advent": 0}}
        if self.opts.dis.p.use_local_discriminator:
            disc_loss["p"] = {"global": 0, "local": 0}
        else:
            disc_loss["p"] = {"gan": 0}

        for domain, batch in multi_domain_batch.items():
            x = batch["data"]["x"]

            # ---------------------
            # -----  Painter  -----
            # ---------------------
            if domain == "rf" and self.has_painter:
                m = batch["data"]["m"]
                # Generate the fake image without tracking gradients:
                # only D is updated from this loss
                with torch.no_grad():
                    # see spade compute_discriminator_loss
                    fake = self.G.paint(m, x)
                    if self.opts.gen.p.diff_aug.use:
                        fake = self.diff_transforms(fake)
                        x = self.diff_transforms(x)
                    fake = fake.detach()
                    fake.requires_grad_()

                if self.opts.dis.p.use_local_discriminator:
                    # Global discriminator sees full images, local one sees
                    # images multiplied by the mask
                    fake_d_global = self.D["p"]["global"](fake)
                    real_d_global = self.D["p"]["global"](x)

                    fake_d_local = self.D["p"]["local"](fake * m)
                    real_d_local = self.D["p"]["local"](x * m)

                    global_loss = self.losses["D"]["p"](fake_d_global, False, True)
                    global_loss += self.losses["D"]["p"](real_d_global, True, True)

                    local_loss = self.losses["D"]["p"](fake_d_local, False, True)
                    local_loss += self.losses["D"]["p"](real_d_local, True, True)

                    disc_loss["p"]["global"] += global_loss
                    disc_loss["p"]["local"] += local_loss
                else:
                    # Single discriminator: mask is concatenated on the channel
                    # axis, real and fake are stacked on the batch axis so that
                    # D runs once; divide_pred splits its output back
                    real_cat = torch.cat([m, x], axis=1)
                    fake_cat = torch.cat([m, fake], axis=1)
                    real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
                    real_fake_d = self.D["p"](real_fake_cat)
                    real_d, fake_d = divide_pred(real_fake_d)
                    # Assigned (not accumulated): this branch only runs for "rf"
                    disc_loss["p"]["gan"] = self.losses["D"]["p"](fake_d, False, True)
                    disc_loss["p"]["gan"] += self.losses["D"]["p"](real_d, True, True)

            # --------------------
            # -----  Masker  -----
            # --------------------
            else:
                z = self.G.encode(x)
                s_pred = d_pred = cond = z_depth = None

                if "s" in batch["data"]:
                    # Depth features are only needed by the segmentation
                    # decoder when DADA is enabled
                    if "d" in self.opts.tasks and self.opts.gen.s.use_dada:
                        d_pred, z_depth = self.G.decoders["d"](z)

                    step_loss, s_pred = self.masker_s_loss(
                        x, z, d_pred, z_depth, None, domain, for_="D"
                    )
                    # NOTE(review): masker_s_loss already multiplies its ADVENT
                    # term by adv_main when for_="D", so the weight appears to
                    # be applied twice here -- confirm this is intended
                    step_loss *= self.opts.train.lambdas.advent.adv_main
                    disc_loss["s"]["Advent"] += step_loss

                if "m" in batch["data"]:
                    if "d" in self.opts.tasks:
                        # Compute the depth prediction lazily (it may already
                        # exist from the segmentation step above)
                        if self.opts.gen.m.use_spade:
                            if d_pred is None:
                                d_pred, z_depth = self.G.decoders["d"](z)
                            cond = self.G.make_m_cond(d_pred, s_pred, x)
                        elif self.opts.gen.m.use_dada:
                            if d_pred is None:
                                d_pred, z_depth = self.G.decoders["d"](z)

                    step_loss, _ = self.masker_m_loss(
                        x,
                        z,
                        None,
                        domain,
                        for_="D",
                        cond=cond,
                        z_depth=z_depth,
                        depth_preds=d_pred,
                    )
                    # NOTE(review): same double adv_main scaling as for the
                    # segmentation term above -- confirm this is intended
                    step_loss *= self.opts.train.lambdas.advent.adv_main
                    disc_loss["m"]["Advent"] += step_loss

        # Store python numbers (not tensors) in the logger
        self.logger.losses.disc.update(
            {
                dom: {
                    k: v.item() if isinstance(v, torch.Tensor) else v
                    for k, v in d.items()
                }
                for dom, d in disc_loss.items()
            }
        )

        # Total loss: sum over every discriminator family and every term
        loss = sum(v for d in disc_loss.values() for k, v in d.items())
        return loss

    def get_G_loss(self, multi_domain_batch, verbose=0):
        m_loss = p_loss = None

        # For now, always compute "representation loss"
        g_loss = 0

        if any(t in self.opts.tasks for t in "msd"):
            m_loss = self.get_masker_loss(multi_domain_batch)
            self.logger.losses.gen.masker = m_loss.item()
            g_loss += m_loss

        if "p" in self.opts.tasks and not self.kitti_pretrain:
            p_loss = self.get_painter_loss(multi_domain_batch)
            self.logger.losses.gen.painter = p_loss.item()
            g_loss += p_loss

        assert g_loss != 0 and not isinstance(g_loss, int), "No update in get_G_loss!"

        self.logger.losses.gen.total_loss = g_loss.item()

        return g_loss

    def get_masker_loss(self, multi_domain_batch):  # TODO update docstrings
        """Only update the representation part of the model, meaning everything
        but the translation part

        Args:
            multi_domain_batch (dict): dictionnary mapping domain names to batches from
            the trainer's loaders

        Returns:
            torch.Tensor: scalar loss tensor, weighted according to opts.train.lambdas
        """
        m_loss = 0
        for domain, batch in multi_domain_batch.items():
            # We don't care about the flooded domain here
            if domain == "rf":
                continue

            x = batch["data"]["x"]
            z = self.G.encode(x)

            # --------------------------------------
            # -----  task-specific losses (2)  -----
            # --------------------------------------
            d_pred = s_pred = z_depth = None
            for task in ["d", "s", "m"]:
                if task not in batch["data"]:
                    continue

                target = batch["data"][task]

                if task == "d":
                    loss, d_pred, z_depth = self.masker_d_loss(
                        x, z, target, domain, "G"
                    )
                    m_loss += loss
                    self.logger.losses.gen.task["d"][domain] = loss.item()

                elif task == "s":
                    loss, s_pred = self.masker_s_loss(
                        x, z, d_pred, z_depth, target, domain, "G"
                    )
                    m_loss += loss
                    self.logger.losses.gen.task["s"][domain] = loss.item()

                elif task == "m":
                    cond = None
                    if self.opts.gen.m.use_spade:
                        if not self.opts.gen.m.detach:
                            d_pred = d_pred.clone()
                            s_pred = s_pred.clone()
                        cond = self.G.make_m_cond(d_pred, s_pred, x)

                    loss, _ = self.masker_m_loss(
                        x,
                        z,
                        target,
                        domain,
                        "G",
                        cond=cond,
                        z_depth=z_depth,
                        depth_preds=d_pred,
                    )
                    m_loss += loss
                    self.logger.losses.gen.task["m"][domain] = loss.item()

        return m_loss

    def get_painter_loss(self, multi_domain_batch):
        """Computes the translation loss when flooding/deflooding images

        Args:
            multi_domain_batch (dict): dictionnary mapping domain names to batches from
            the trainer's loaders

        Returns:
            torch.Tensor: scalar loss tensor, weighted according to opts.train.lambdas
        """
        step_loss = 0
        # self.g_opt.zero_grad()
        lambdas = self.opts.train.lambdas
        batch_domain = "rf"
        batch = multi_domain_batch[batch_domain]

        x = batch["data"]["x"]
        # ! different mask: hides water to be reconstructed
        # ! 1 for water, 0 otherwise
        m = batch["data"]["m"]
        fake_flooded = self.G.paint(m, x)

        # ----------------------
        # -----  VGG Loss  -----
        # ----------------------
        if lambdas.G.p.vgg != 0:
            loss = self.losses["G"]["p"]["vgg"](
                vgg_preprocess(fake_flooded * m), vgg_preprocess(x * m)
            )
            loss *= lambdas.G.p.vgg
            self.logger.losses.gen.p.vgg = loss.item()
            step_loss += loss

        # ---------------------
        # -----  TV Loss  -----
        # ---------------------
        if lambdas.G.p.tv != 0:
            loss = self.losses["G"]["p"]["tv"](fake_flooded * m)
            loss *= lambdas.G.p.tv
            self.logger.losses.gen.p.tv = loss.item()
            step_loss += loss

        # --------------------------
        # -----  Context Loss  -----
        # --------------------------
        if lambdas.G.p.context != 0:
            loss = self.losses["G"]["p"]["context"](fake_flooded, x, m)
            loss *= lambdas.G.p.context
            self.logger.losses.gen.p.context = loss.item()
            step_loss += loss

        # ---------------------------------
        # -----  Reconstruction Loss  -----
        # ---------------------------------
        if lambdas.G.p.reconstruction != 0:
            loss = self.losses["G"]["p"]["reconstruction"](fake_flooded, x, m)
            loss *= lambdas.G.p.reconstruction
            self.logger.losses.gen.p.reconstruction = loss.item()
            step_loss += loss

        # -------------------------------------
        # -----  Local & Global GAN Loss  -----
        # -------------------------------------
        if self.opts.gen.p.diff_aug.use:
            fake_flooded = self.diff_transforms(fake_flooded)
            x = self.diff_transforms(x)

        if self.opts.dis.p.use_local_discriminator:
            fake_d_global = self.D["p"]["global"](fake_flooded)
            fake_d_local = self.D["p"]["local"](fake_flooded * m)

            real_d_global = self.D["p"]["global"](x)

            # Note: discriminator returns [out_1,...,out_num_D] outputs
            # Each out_i is a list [feat1, feat2, ..., pred_i]

            self.logger.losses.gen.p.gan = 0

            loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False)
            loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)
            loss *= lambdas.G["p"]["gan"]

            self.logger.losses.gen.p.gan = loss.item()

            step_loss += loss

            # -----------------------------------
            # -----  Feature Matching Loss  -----
            # -----------------------------------
            # (only on global discriminator)
            # Order must be real, fake
            if self.opts.dis.p.get_intermediate_features:
                loss = self.losses["G"]["p"]["featmatch"](real_d_global, fake_d_global)
                loss *= lambdas.G["p"]["featmatch"]

                if isinstance(loss, float):
                    self.logger.losses.gen.p.featmatch = loss
                else:
                    self.logger.losses.gen.p.featmatch = loss.item()

                step_loss += loss

        # -------------------------------------------
        # -----  Single Discriminator GAN Loss  -----
        # -------------------------------------------
        else:
            real_cat = torch.cat([m, x], axis=1)
            fake_cat = torch.cat([m, fake_flooded], axis=1)
            real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)

            real_fake_d = self.D["p"](real_fake_cat)
            real_d, fake_d = divide_pred(real_fake_d)

            loss = self.losses["G"]["p"]["gan"](fake_d, True, False)
            self.logger.losses.gen.p.gan = loss.item()
            step_loss += loss

            # -----------------------------------
            # -----  Feature Matching Loss  -----
            # -----------------------------------
            if self.opts.dis.p.get_intermediate_features and lambdas.G.p.featmatch != 0:
                loss = self.losses["G"]["p"]["featmatch"](real_d, fake_d)
                loss *= lambdas.G.p.featmatch

                if isinstance(loss, float):
                    self.logger.losses.gen.p.featmatch = loss
                else:
                    self.logger.losses.gen.p.featmatch = loss.item()

                step_loss += loss

        return step_loss

    def masker_d_loss(self, x, z, target, domain, for_="G"):
        assert for_ in {"G", "D"}
        self.assert_z_matches_x(x, z)
        assert x.shape[0] == target.shape[0]
        zero_loss = torch.tensor(0.0, device=self.device)
        weight = self.opts.train.lambdas.G.d.main

        prediction, z_depth = self.G.decoders["d"](z)

        if self.opts.gen.d.classify.enable:
            target.squeeze_(1)

        full_loss = self.losses["G"]["tasks"]["d"](prediction, target)
        full_loss *= weight

        if weight == 0 or (domain == "r" and "d" not in self.pseudo_training_tasks):
            return zero_loss, prediction, z_depth

        return full_loss, prediction, z_depth

    def masker_s_loss(self, x, z, depth_preds, z_depth, target, domain, for_="G"):
        assert for_ in {"G", "D"}
        assert domain in {"r", "s"}
        self.assert_z_matches_x(x, z)
        assert x.shape[0] == target.shape[0] if target is not None else True
        full_loss = torch.tensor(0.0, device=self.device)
        softmax_preds = None
        # --------------------------
        # -----  Segmentation  -----
        # --------------------------
        pred = None
        if for_ == "G" or self.opts.gen.s.use_advent:
            pred = self.G.decoders["s"](z, z_depth)

        # Supervised segmentation loss: crossent for sim domain,
        # crossent_pseudo for real ; loss is crossent in any case
        if for_ == "G":
            if domain == "s" or "s" in self.pseudo_training_tasks:
                if domain == "s":
                    logger = self.logger.losses.gen.task["s"]["crossent"]
                    weight = self.opts.train.lambdas.G["s"]["crossent"]
                else:
                    logger = self.logger.losses.gen.task["s"]["crossent_pseudo"]
                    weight = self.opts.train.lambdas.G["s"]["crossent_pseudo"]

                if weight != 0:
                    # Cross-Entropy loss
                    loss_func = self.losses["G"]["tasks"]["s"]["crossent"]
                    loss = loss_func(pred, target.squeeze(1))
                    loss *= weight
                    full_loss += loss
                    logger[domain] = loss.item()

            if domain == "r":
                weight = self.opts.train.lambdas.G["s"]["minent"]
                if self.opts.gen.s.use_minent and weight != 0:
                    softmax_preds = softmax(pred, dim=1)
                    # Entropy minimization loss
                    loss = self.losses["G"]["tasks"]["s"]["minent"](softmax_preds)
                    loss *= weight
                    full_loss += loss

                    self.logger.losses.gen.task["s"]["minent"]["r"] = loss.item()

        # Fool ADVENT discriminator
        if self.opts.gen.s.use_advent:
            if self.opts.gen.s.use_dada and depth_preds is not None:
                depth_preds = depth_preds.detach()
            else:
                depth_preds = None

            if for_ == "D":
                domain_label = domain
                logger = {}
                loss_func = self.losses["D"]["advent"]
                pred = pred.detach()
                weight = self.opts.train.lambdas.advent.adv_main
            else:
                domain_label = "s"
                logger = self.logger.losses.gen.task["s"]["advent"]
                loss_func = self.losses["G"]["tasks"]["s"]["advent"]
                weight = self.opts.train.lambdas.G["s"]["advent"]

            if (for_ == "D" or domain == "r") and weight != 0:
                if softmax_preds is None:
                    softmax_preds = softmax(pred, dim=1)
                loss = loss_func(
                    softmax_preds,
                    self.domain_labels[domain_label],
                    self.D["s"]["Advent"],
                    depth_preds,
                )
                loss *= weight
                full_loss += loss
                logger[domain] = loss.item()

                if for_ == "D":
                    # WGAN: clipping or GP
                    if self.opts.dis.s.gan_type == "GAN" or "WGAN_norm":
                        pass
                    elif self.opts.dis.s.gan_type == "WGAN":
                        for p in self.D["s"]["Advent"].parameters():
                            p.data.clamp_(
                                self.opts.dis.s.wgan_clamp_lower,
                                self.opts.dis.s.wgan_clamp_upper,
                            )
                    elif self.opts.dis.s.gan_type == "WGAN_gp":
                        prob_need_grad = autograd.Variable(pred, requires_grad=True)
                        d_out = self.D["s"]["Advent"](prob_need_grad)
                        gp = get_WGAN_gradient(prob_need_grad, d_out)
                        gp_loss = gp * self.opts.train.lambdas.advent.WGAN_gp
                        full_loss += gp_loss
                    else:
                        raise NotImplementedError

        return full_loss, pred

    def masker_m_loss(
        self, x, z, target, domain, for_="G", cond=None, z_depth=None, depth_preds=None
    ):
        assert for_ in {"G", "D"}
        assert domain in {"r", "s"}
        self.assert_z_matches_x(x, z)
        assert x.shape[0] == target.shape[0] if target is not None else True
        full_loss = torch.tensor(0.0, device=self.device)

        pred_logits = self.G.decoders["m"](z, cond=cond, z_depth=z_depth)
        pred_prob = sigmoid(pred_logits)
        pred_prob_complementary = 1 - pred_prob
        prob = torch.cat([pred_prob, pred_prob_complementary], dim=1)

        if for_ == "G":
            # TV loss
            weight = self.opts.train.lambdas.G.m.tv
            if weight != 0:
                loss = self.losses["G"]["tasks"]["m"]["tv"](pred_prob)
                loss *= weight
                full_loss += loss

                self.logger.losses.gen.task["m"]["tv"][domain] = loss.item()

            weight = self.opts.train.lambdas.G.m.bce
            if domain == "s" and weight != 0:
                # CrossEnt Loss
                loss = self.losses["G"]["tasks"]["m"]["bce"](pred_logits, target)
                loss *= weight
                full_loss += loss
                self.logger.losses.gen.task["m"]["bce"]["s"] = loss.item()

            if domain == "r":

                weight = self.opts.train.lambdas.G["m"]["gi"]
                if self.opts.gen.m.use_ground_intersection and weight != 0:
                    # GroundIntersection loss
                    loss = self.losses["G"]["tasks"]["m"]["gi"](pred_prob, target)
                    loss *= weight
                    full_loss += loss
                    self.logger.losses.gen.task["m"]["gi"]["r"] = loss.item()

                weight = self.opts.train.lambdas.G.m.pl4m
                if self.use_pl4m and weight != 0:
                    # Painter loss
                    pl4m_loss = self.painter_loss_for_masker(x, pred_prob)
                    pl4m_loss *= weight
                    full_loss += pl4m_loss
                    self.logger.losses.gen.task.m.pl4m.r = pl4m_loss.item()

                weight = self.opts.train.lambdas.advent.ent_main
                if self.opts.gen.m.use_minent and weight != 0:
                    # MinEnt loss
                    loss = self.losses["G"]["tasks"]["m"]["minent"](prob)
                    loss *= weight
                    full_loss += loss
                    self.logger.losses.gen.task["m"]["minent"]["r"] = loss.item()

        if self.opts.gen.m.use_advent:
            # AdvEnt loss
            if self.opts.gen.m.use_dada and depth_preds is not None:
                depth_preds = depth_preds.detach()
                depth_preds = torch.nn.functional.interpolate(
                    depth_preds, size=x.shape[-2:], mode="nearest"
                )
            else:
                depth_preds = None

            if for_ == "D":
                domain_label = domain
                logger = {}
                loss_func = self.losses["D"]["advent"]
                prob = prob.detach()
                weight = self.opts.train.lambdas.advent.adv_main
            else:
                domain_label = "s"
                logger = self.logger.losses.gen.task["m"]["advent"]
                loss_func = self.losses["G"]["tasks"]["m"]["advent"]
                weight = self.opts.train.lambdas.advent.adv_main

            if (for_ == "D" or domain == "r") and weight != 0:
                loss = loss_func(
                    prob.to(self.device),
                    self.domain_labels[domain_label],
                    self.D["m"]["Advent"],
                    depth_preds,
                )
                loss *= weight
                full_loss += loss
                logger[domain] = loss.item()

            if for_ == "D":
                # WGAN: clipping or GP
                if self.opts.dis.m.gan_type == "GAN" or "WGAN_norm":
                    pass
                elif self.opts.dis.m.gan_type == "WGAN":
                    for p in self.D["s"]["Advent"].parameters():
                        p.data.clamp_(
                            self.opts.dis.m.wgan_clamp_lower,
                            self.opts.dis.m.wgan_clamp_upper,
                        )
                elif self.opts.dis.m.gan_type == "WGAN_gp":
                    prob_need_grad = autograd.Variable(prob, requires_grad=True)
                    d_out = self.D["s"]["Advent"](prob_need_grad)
                    gp = get_WGAN_gradient(prob_need_grad, d_out)
                    gp_loss = self.opts.train.lambdas.advent.WGAN_gp * gp
                    full_loss += gp_loss
                else:
                    raise NotImplementedError

        return full_loss, prob

    def painter_loss_for_masker(self, x, m):
        # pl4m loss
        # painter should not be updated
        for param in self.G.painter.parameters():
            param.requires_grad = False
        # TODO for param in self.D.painter.parameters():
        #     param.requires_grad = False

        fake_flooded = self.G.paint(m, x)

        if self.opts.dis.p.use_local_discriminator:
            fake_d_global = self.D["p"]["global"](fake_flooded)
            fake_d_local = self.D["p"]["local"](fake_flooded * m)

            # Note: discriminator returns [out_1,...,out_num_D] outputs
            # Each out_i is a list [feat1, feat2, ..., pred_i]

            pl4m_loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False)
            pl4m_loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)
        else:
            real_cat = torch.cat([m, x], axis=1)
            fake_cat = torch.cat([m, fake_flooded], axis=1)
            real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)

            real_fake_d = self.D["p"](real_fake_cat)
            _, fake_d = divide_pred(real_fake_d)

            pl4m_loss = self.losses["G"]["p"]["gan"](fake_d, True, False)

        if "p" in self.opts.tasks:
            for param in self.G.painter.parameters():
                param.requires_grad = True

        return pl4m_loss

    @torch.no_grad()
    def run_evaluation(self, verbose=0):
        print("******************* Running Evaluation ***********************")
        start_time = time()
        self.eval_mode()
        val_logger = None
        nb_of_batches = None
        for i, multi_batch_tuple in enumerate(self.val_loaders):
            # create a dictionnary (domain => batch) from tuple
            # (batch_domain_0, ..., batch_domain_i)
            # and send it to self.device
            nb_of_batches = i + 1
            multi_domain_batch = {
                batch["domain"][0]: self.batch_to_device(batch)
                for batch in multi_batch_tuple
            }
            self.get_G_loss(multi_domain_batch, verbose)

            if val_logger is None:
                val_logger = deepcopy(self.logger.losses.generator)
            else:
                val_logger = sum_dict(val_logger, self.logger.losses.generator)

        val_logger = div_dict(val_logger, nb_of_batches)
        self.logger.losses.generator = val_logger
        self.logger.log_losses(model_to_update="G", mode="val")

        for d in self.opts.domains:
            self.logger.log_comet_images("train", d)
            self.logger.log_comet_images("val", d)

        if "m" in self.opts.tasks and self.has_painter and not self.kitti_pretrain:
            self.logger.log_comet_combined_images("train", "r")
            self.logger.log_comet_combined_images("val", "r")

        if self.exp is not None:
            print()

        if "m" in self.opts.tasks or "s" in self.opts.tasks:
            self.eval_images("val", "r")
            self.eval_images("val", "s")

        if "p" in self.opts.tasks and not self.kitti_pretrain:
            val_fid = compute_val_fid(self)
            if self.exp is not None:
                self.exp.log_metric("val_fid", val_fid, step=self.logger.global_step)
            else:
                print("Validation FID Score", val_fid)

        self.train_mode()
        timing = int(time() - start_time)
        print("****************** Done in {}s *********************".format(timing))

    def eval_images(self, mode, domain):
        """Compute accuracy and mIOU on the display images of a given mode
        and domain, and log (or print) their averages.

        Metrics are gathered per task: always a "m" entry, a "s" entry when
        "s" is in self.opts.tasks, and a "d" entry when "d" is in the tasks,
        the domain is "s" and depth is trained as classification
        (opts.gen.d.classify.enable). A task without any score is reported
        as -1.

        Results go to self.exp.log_metrics when an experiment exists,
        otherwise they are printed.

        Args:
            mode (str): key of self.display_images (e.g. "val")
            domain (str): domain to evaluate; "s" is replaced by "kitti"
                when self.kitti_pretrain, "rf" and domains without display
                images are skipped

        Returns:
            0 once the metrics have been reported, None if the domain was
            skipped
        """
        if domain == "s" and self.kitti_pretrain:
            domain = "kitti"
        if domain == "rf" or domain not in self.display_images[mode]:
            return

        metric_funcs = {"accuracy": accuracy, "mIOU": mIOU}
        # task => metric name => list of per-image scores
        metric_avg_scores = {"m": {}}
        if "s" in self.opts.tasks:
            metric_avg_scores["s"] = {}
        if "d" in self.opts.tasks and domain == "s" and self.opts.gen.d.classify.enable:
            metric_avg_scores["d"] = {}

        for key in metric_funcs:
            for task in metric_avg_scores:
                metric_avg_scores[task][key] = []

        for im_set in self.display_images[mode][domain]:
            # Images are evaluated one at a time: add a batch dimension
            x = im_set["data"]["x"].unsqueeze(0).to(self.device)
            z = self.G.encode(x)

            s_pred = d_pred = z_depth = None

            if "d" in metric_avg_scores:
                d_pred, z_depth = self.G.decoders["d"](z)
                d_pred = d_pred.detach().cpu()

                if domain == "s":
                    d = im_set["data"]["d"].unsqueeze(0).detach()

                    for metric in metric_funcs:
                        metric_score = metric_funcs[metric](d_pred, d)
                        metric_avg_scores["d"][metric].append(metric_score)

            if "s" in metric_avg_scores:
                # Depth features are only needed by the segmentation decoder
                # when DADA is enabled
                if z_depth is None:
                    if self.opts.gen.s.use_dada and "d" in self.opts.tasks:
                        _, z_depth = self.G.decoders["d"](z)
                s_pred = self.G.decoders["s"](z, z_depth).detach().cpu()
                s = im_set["data"]["s"].unsqueeze(0).detach()

                for metric in metric_funcs:
                    metric_score = metric_funcs[metric](s_pred, s)
                    metric_avg_scores["s"][metric].append(metric_score)

            # NOTE(review): this tests membership in the top-level options,
            # whereas every other check in this method uses self.opts.tasks
            # -- confirm this is the intended condition
            if "m" in self.opts:
                cond = None
                # NOTE(review): d_pred and s_pred were moved to the CPU above
                # while x is on self.device -- confirm make_m_cond handles it
                if s_pred is not None and d_pred is not None:
                    cond = self.G.make_m_cond(d_pred, s_pred, x)
                if z_depth is None:
                    if self.opts.gen.m.use_dada and "d" in self.opts.tasks:
                        _, z_depth = self.G.decoders["d"](z)

                pred_mask = (
                    (self.G.mask(z=z, cond=cond, z_depth=z_depth)).detach().cpu()
                )
                # Binarize the mask, then build the 2-channel version
                # [1 - mask, mask] used by mIOU
                pred_mask = (pred_mask > 0.5).to(torch.float32)
                pred_prob = torch.cat([1 - pred_mask, pred_mask], dim=1)

                m = im_set["data"]["m"].unsqueeze(0).detach()

                for metric in metric_funcs:
                    if metric != "mIOU":
                        metric_score = metric_funcs[metric](pred_mask, m)
                    else:
                        metric_score = metric_funcs[metric](pred_prob, m)

                    metric_avg_scores["m"][metric].append(metric_score)

        # Average per-image scores; tasks without scores yield nan...
        metric_avg_scores = {
            task: {
                metric: np.mean(values) if values else float("nan")
                for metric, values in met_dict.items()
            }
            for task, met_dict in metric_avg_scores.items()
        }
        # ... which is then reported as -1
        metric_avg_scores = {
            task: {
                metric: value if not np.isnan(value) else -1
                for metric, value in met_dict.items()
            }
            for task, met_dict in metric_avg_scores.items()
        }
        if self.exp is not None:
            self.exp.log_metrics(
                flatten_opts(metric_avg_scores),
                prefix=f"metrics_{mode}_{domain}",
                step=self.logger.global_step,
            )
        else:
            print(f"metrics_{mode}_{domain}")
            print(flatten_opts(metric_avg_scores))

        return 0

    def functional_test_mode(self):
        """Configure the trainer for a functional test run.

        * redirect self.opts.output_path to ~/climategan/functional_tests
          (created if needed)
        * drop an "is_functional.test" marker file in it, which allows
          self.del_output_path to delete the directory
        * flag the comet experiment, if any, as a functional test
        * register self.del_output_path to run at interpreter exit
        """
        import atexit

        self.opts.output_path = Path("~/climategan/functional_tests").expanduser()

        test_dir = Path(self.opts.output_path)
        test_dir.mkdir(parents=True, exist_ok=True)
        (test_dir / "is_functional.test").write_text(
            "trainer functional test - delete this dir"
        )

        if self.exp is not None:
            self.exp.log_parameter("is_functional_test", True)

        # Clean up the test directory when the process exits
        atexit.register(self.del_output_path)

    def del_output_path(self, force=False):
        import shutil

        if not Path(self.opts.output_path).exists():
            return

        if (Path(self.opts.output_path) / "is_functional.test").exists() or force:
            shutil.rmtree(self.opts.output_path)

    def compute_flood(self, x, z=None, m=None, s=None, cloudy=None, bin_value=-1):
        """
        Applies a flood (mask + paint) to an input image, with optionally
        pre-computed masker z or mask

        Args:
            x (torch.Tensor): B x C x H x W -1:1 input image
            z (torch.Tensor, optional): B x C x H x W Masker latent vector.
                Defaults to None.
            m (torch.Tensor, optional): B x 1 x H x W Mask. Defaults to None.
            bin_value (float, optional): Mask binarization value.
                Set to -1 to use smooth masks (no binarization)

        Returns:
            torch.Tensor: B x 3 x H x W -1:1 flooded image
        """

        if m is None:
            z_depth = None
            if z is None:
                z = self.G.encode(x)
            if "d" in self.opts.tasks and self.opts.gen.m.use_dada:
                _, z_depth = self.G.decoders["d"](z)
            m = self.G.mask(x=x, z=z, z_depth=z_depth)

        if bin_value >= 0:
            m = (m > bin_value).to(m.dtype)

        if cloudy:
            assert s is not None
            return self.G.paint_cloudy(m, x, s)

        return self.G.paint(m, x)
