import os
import torch
import torch.nn as nn
import util
import math

from itertools import chain
from models.real_nvp import RealNVP, RealNVPLoss

class VectorDiscriminator(nn.Module):
    """MLP discriminator operating on the last dimension of its input.

    The network is Linear -> LeakyReLU(0.2) -> Linear -> LeakyReLU(0.2)
    -> Linear -> Sigmoid, ending in a single output unit.
    """

    def __init__(self, input_dim, hidden_dim=128):
        """
        Args:
            input_dim: Number of input features expected by the first layer.
            hidden_dim: Width of both hidden layers.
        """
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score of `x`; the last dimension has size 1."""
        score = self.model(x)
        return score
    
class MLPforGraphPredict(nn.Module):
    """Three-layer ReLU MLP used as a regression/prediction head.

    The network is Linear -> ReLU -> Linear -> ReLU -> Linear with no
    activation on the output.
    """

    def __init__(self, input_dim, hidden_dim=128, output_dim=1):
        """
        Args:
            input_dim: Number of input features expected by the first layer.
            hidden_dim: Width of both hidden layers.
            output_dim: Number of output units.
        """
        super().__init__()
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw (un-activated) prediction for `x`."""
        prediction = self.net(x)
        return prediction

class Flow2Flow(nn.Module):
    """Flow2Flow Model

    Normalizing flows for unpaired translation between two domains.
    Uses two normalizing flow models (RealNVP) for the generators,
    and two MLP discriminators (`VectorDiscriminator`). The generators map
    to a shared intermediate latent space `Z` with simple prior `p_Z`, and
    the whole system optimizes a hybrid GAN-MLE objective. An additional
    MLP predictor is trained on the latent codes of labeled samples.
    """
    def __init__(self, args):
        """Build the generators, the latent predictor and, when training,
        the discriminators, loss functions, optimizers and LR schedulers.

        Args:
            args: Configuration args passed in via the command line. The
                attributes read here are `gpu_ids`, `is_training`,
                `num_channels`, `num_scales`, `input_dim`, `hidden_dim`,
                `num_coupling_layers`, `initializer` and, when training,
                `clip_gradient`, `lambda_mle`, `clamp_jacobian`,
                `jc_lambda_min`, `jc_lambda_max`, `rnvp_lr`, `rnvp_beta_1`,
                `rnvp_beta_2`, `lr`, `beta_1`, `beta_2`, `use_mixer`.
        """
        super(Flow2Flow, self).__init__()
        self.device = 'cuda' if len(args.gpu_ids) > 0 else 'cpu'
        self.gpu_ids = args.gpu_ids
        self.is_training = args.is_training

        # Not used elsewhere in this class; kept for external readers.
        self.in_channels = args.num_channels
        self.out_channels = 4 ** (args.num_scales - 1) * self.in_channels

        # Set up RealNVP generators (g_src: X <-> Z, g_tgt: Y <-> Z)
        self.g_src = RealNVP(input_dim=args.input_dim,
                             hidden_dim=args.hidden_dim,
                             num_coupling_layers=args.num_coupling_layers)
        util.init_model(self.g_src, init_method=args.initializer)

        self.g_tgt = RealNVP(input_dim=args.input_dim,
                             hidden_dim=args.hidden_dim,
                             num_coupling_layers=args.num_coupling_layers)
        util.init_model(self.g_tgt, init_method=args.initializer)

        # Predictor head applied to latent codes produced by the generators.
        self.predictor = MLPforGraphPredict(
            input_dim=args.input_dim,
            hidden_dim=128,
            output_dim=1
        ).to(self.device)

        # Labeled (x, y) pairs per domain, filled via the add_labeled_data_* methods.
        self.labeled_data_a = []
        self.labeled_data_b = []
        self.labeled_data_a_val = []
        self.labeled_data_b_val = []

        if self.is_training:
            # Set up discriminators. They receive the same tensors the
            # generators consume/produce, so they are sized with
            # `args.input_dim` (previously `args.num_channels`, which only
            # worked when both settings happened to be equal).
            self.d_tgt = VectorDiscriminator(input_dim=args.input_dim, hidden_dim=128)
            self.d_src = VectorDiscriminator(input_dim=args.input_dim, hidden_dim=128)

            self._data_parallel()

            # Set up loss functions
            self.max_grad_norm = args.clip_gradient
            self.lambda_mle = args.lambda_mle
            self.mle_loss_fn = RealNVPLoss()
            self.gan_loss_fn = util.GANLoss(device=self.device, use_least_squares=True)

            self.clamp_jacobian = args.clamp_jacobian
            self.jc_loss_fn = util.JacobianClampingLoss(args.jc_lambda_min, args.jc_lambda_max)

            # NOTE(review): the predictor's parameters are registered in both
            # `opt_g` and `opt_pred`, so every predictor step updates them
            # twice. Left unchanged to keep optimizer state layouts
            # compatible -- confirm whether this is intended.
            all_params = chain(self.g_src.parameters(), self.g_tgt.parameters(), self.predictor.parameters())
            self.opt_g = torch.optim.Adam(
                all_params,
                lr=args.rnvp_lr,
                betas=(args.rnvp_beta_1, args.rnvp_beta_2)
            )

            self.opt_d = torch.optim.Adam(
                chain(self.d_tgt.parameters(), self.d_src.parameters()),
                lr=args.lr,
                betas=(args.beta_1, args.beta_2)
            )
            self.opt_pred = torch.optim.Adam(
                self.predictor.parameters(),
                lr=1e-3,
                betas=(0.9, 0.999)
            )
            self.optimizers = [self.opt_g, self.opt_d, self.opt_pred]
            self.schedulers = [util.get_lr_scheduler(opt, args) for opt in self.optimizers]

            # Buffers of previously generated samples fed to the
            # discriminators; capacity 0 when `args.use_mixer` is off.
            buffer_capacity = 50 if args.use_mixer else 0
            self.src2tgt_buffer = util.ImageBuffer(buffer_capacity)  # Buffer of generated tgt samples
            self.tgt2src_buffer = util.ImageBuffer(buffer_capacity)  # Buffer of generated src samples
        else:
            self._data_parallel()

        # Tensors in flow src -> lat -> tgt
        self.src = None
        self.src2lat = None
        self.src2tgt = None

        # Tensors in flow tgt -> lat -> src
        self.tgt = None
        self.tgt2lat = None
        self.tgt2src = None

        # Jacobian clamping tensors
        self.src_jc = None
        self.tgt_jc = None
        self.src2tgt_jc = None
        self.tgt2src_jc = None

        # Discriminator loss
        self.loss_d_tgt = None
        self.loss_d_src = None
        self.loss_d = None

        # Generator GAN loss
        self.loss_gan_src = None
        self.loss_gan_tgt = None
        self.loss_gan = None

        # Generator MLE loss
        self.loss_mle_src = None
        self.loss_mle_tgt = None
        self.loss_mle = None

        # Jacobian Clamping loss
        self.loss_jc_src = None
        self.loss_jc_tgt = None
        self.loss_jc = None

        # Generator total loss
        self.loss_g = None

    def set_inputs(self, src_input, tgt_input=None):
        """Store the batch to be used by the next forward/backward pass.

        Both tensors are moved to `self.device`. When `tgt_input` is omitted
        the previously stored `self.tgt` is left untouched.

        Args:
            src_input: Tensor with src input
            tgt_input: Tensor with tgt input (optional)
        """
        target_device = self.device
        self.src = src_input.to(target_device)
        if tgt_input is None:
            return
        self.tgt = tgt_input.to(target_device)

    def forward(self):
        """Intentionally does nothing; the forward pass runs inside the `backward_g*` methods."""
        return None

    def test(self):
        """Run a forward pass through the generator for test inference.
        Used during test inference only, as this throws away forward-pass values,
        which would be needed for backprop.

        Important: Call `set_inputs` prior to each successive call to `test`.
        """
        # Disable auto-grad because we will not backprop
        with torch.no_grad():
            src2lat, _ = self.g_src(self.src, reverse=False)
            src2lat2tgt, _ = self.g_tgt(src2lat, reverse=True)
            # self.src2tgt = torch.tanh(src2lat2tgt)
            self.src2tgt = src2lat2tgt

            tgt2lat, _ = self.g_tgt(self.tgt, reverse=False)
            tgt2lat2src, _ = self.g_src(tgt2lat, reverse=True)
            # self.tgt2src = torch.tanh(tgt2lat2src)
            self.tgt2src = tgt2lat2src

    def _forward_d(self, d, real_img, fake_img):
        """Compute the loss of one discriminator on a real and a fake batch.

        Args:
            d: Discriminator to evaluate.
            real_img: Batch that should be classified as real.
            fake_img: Generated batch; it is detached so that no gradient
                reaches the generators.

        Returns:
            Half the sum of the real and fake GAN losses.
        """
        real_term = self.gan_loss_fn(d(real_img), is_tgt_real=True)
        fake_term = self.gan_loss_fn(d(fake_img.detach()), is_tgt_real=False)
        return (real_term + fake_term) * 0.5

    def backward_d(self):
        """Compute both discriminator losses and backpropagate their sum.

        The fake batches are drawn through the sample buffers; results are
        stored in `self.loss_d_tgt`, `self.loss_d_src` and `self.loss_d`.
        """
        # tgt discriminator: real tgt vs. (buffered) src -> tgt samples
        fake_tgt = self.src2tgt_buffer.sample(self.src2tgt)
        self.loss_d_tgt = self._forward_d(self.d_tgt, self.tgt, fake_tgt)

        # src discriminator: real src vs. (buffered) tgt -> src samples
        fake_src = self.tgt2src_buffer.sample(self.tgt2src)
        self.loss_d_src = self._forward_d(self.d_src, self.src, fake_src)

        combined = self.loss_d_tgt + self.loss_d_src
        self.loss_d = combined
        combined.backward()

    def backward_g(self):
        """Run one generator update followed by one predictor update.

        Stage 1 backpropagates the Flow2Flow objective (GAN + MLE +
        Jacobian clamping) and steps `opt_g`. Stage 2 backpropagates the
        predictor loss on labeled samples and steps `opt_g` and `opt_pred`.

        The predictor forward pass is executed only after the first
        optimizer step. Building its graph beforehand and calling
        `backward()` afterwards fails autograd's in-place version check,
        because `opt_g.step()` modifies the generator weights in place.

        Unlike `train_iter`, no gradient clipping is applied here.
        """
        # ========== 1) Flow2Flow objective ==========
        self.opt_g.zero_grad()
        self.opt_pred.zero_grad()
        self.backward_g_flow2flow()
        self.opt_g.step()

        # ========== 2) Predictor objective (fresh forward pass) ==========
        self.opt_g.zero_grad()
        self.opt_pred.zero_grad()
        self.backward_g_predictor()
        self.opt_g.step()
        self.opt_pred.step()

    def _compute_predictor_loss_per_batch(self, sample_size=10):
        """Predictor MSE on a random mini-batch of labeled samples per domain.

        For each domain that has labeled data, up to `sample_size` pairs are
        drawn without replacement, encoded by that domain's generator
        (g_src for A, g_tgt for B) and scored by the predictor.

        Args:
            sample_size: Maximum number of labeled pairs drawn per domain.

        Returns:
            The mean of the per-domain MSE losses, or a one-element zero
            tensor (with `requires_grad=True`) when no domain has labeled
            data.
        """
        criterion = nn.MSELoss()
        # Domain A is always processed before domain B.
        domains = (
            (self.labeled_data_a, self.g_src),
            (self.labeled_data_b, self.g_tgt),
        )

        running = 0.0
        used_domains = 0
        for pool, flow in domains:
            if len(pool) == 0:
                continue
            take = min(sample_size, len(pool))
            picked = [pool[i] for i in torch.randperm(len(pool))[:take]]
            inputs = torch.stack([pair[0] for pair in picked]).to(self.device)
            targets = torch.stack([pair[1] for pair in picked]).to(self.device)

            latent, _ = flow(inputs, reverse=False)
            running = running + criterion(self.predictor(latent), targets)
            used_domains += 1

        if used_domains == 0:
            return torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True)
        return running / used_domains

    def add_labeled_data_a(self, data_x, data_y):
        """Append labeled training pairs for domain A.

        A 1-D `data_x` is stored as a single `(data_x, data_y)` pair;
        otherwise `data_x` and `data_y` are iterated in lockstep along
        their first dimension and each `(x_i, y_i)` is stored.

        Args:
            data_x: shape [D] or [batch, D]
            data_y: shape [1] or [batch, 1]

        The caller is responsible for passing data that already lives where
        it should (e.g. CPU tensors); `data_x` must provide `.dim()`.
        """
        if data_x.dim() == 1:
            pairs = [(data_x, data_y)]
        else:
            pairs = zip(data_x, data_y)
        self.labeled_data_a.extend(pairs)
    
    def add_labeled_data_b(self, data_x, data_y):
        """Append labeled training pairs for domain B.

        A 1-D `data_x` is stored as a single `(data_x, data_y)` pair;
        otherwise the two inputs are iterated in lockstep along their first
        dimension and each `(x_i, y_i)` is stored.
        """
        if data_x.dim() == 1:
            pairs = [(data_x, data_y)]
        else:
            pairs = zip(data_x, data_y)
        self.labeled_data_b.extend(pairs)

    def add_labeled_data_a_val(self, data_x, data_y):
        """Append labeled validation pairs for domain A.

        A 1-D `data_x` is stored as a single `(data_x, data_y)` pair;
        otherwise the two inputs are iterated in lockstep along their first
        dimension and each `(x_i, y_i)` is stored.
        """
        if data_x.dim() == 1:
            pairs = [(data_x, data_y)]
        else:
            pairs = zip(data_x, data_y)
        self.labeled_data_a_val.extend(pairs)

    def add_labeled_data_b_val(self, data_x, data_y):
        """Append labeled validation pairs for domain B.

        A 1-D `data_x` is stored as a single `(data_x, data_y)` pair;
        otherwise the two inputs are iterated in lockstep along their first
        dimension and each `(x_i, y_i)` is stored.
        """
        if data_x.dim() == 1:
            pairs = [(data_x, data_y)]
        else:
            pairs = zip(data_x, data_y)
        self.labeled_data_b_val.extend(pairs)

    def validate_predictor_loss(self):
        """Evaluate the predictor on the labeled validation sets.

        Each non-empty validation set is encoded in one batch by its
        domain's generator (g_src for A, g_tgt for B) and scored with MSE.
        The model is switched to eval mode for the duration of the call and
        is put back into whichever mode it was in before (previously it was
        always left in training mode).

        Returns:
            The mean of the per-domain MSE values as a float, or 0.0 when
            both validation sets are empty.
        """
        device = self.device
        criterion = nn.MSELoss()
        splits = (
            (self.labeled_data_a_val, self.g_src),  # Domain A val
            (self.labeled_data_b_val, self.g_tgt),  # Domain B val
        )

        was_training = self.training
        domain_losses = []
        self.eval()
        try:
            with torch.no_grad():
                for pool, flow in splits:
                    if len(pool) == 0:
                        continue
                    inputs = torch.stack([item[0] for item in pool]).to(device)
                    targets = torch.stack([item[1] for item in pool]).to(device)
                    latent, _ = flow(inputs, reverse=False)
                    domain_losses.append(criterion(self.predictor(latent), targets).item())
        finally:
            # Restore the caller's mode even if evaluation raised.
            self.train(was_training)

        if not domain_losses:
            return 0.0
        return sum(domain_losses) / len(domain_losses)
    
    def backward_g_flow2flow(self):
        """Forward both translation directions and backprop the generator loss.

        Computes the MLE loss of each generator on its own domain, the GAN
        loss of each translated batch against the opposite discriminator and,
        when `self.clamp_jacobian` is set, the Jacobian clamping loss. The
        sum is stored in `self.loss_flow2flow` and also in `self.loss_g`, so
        that `get_loss_dict` reports the generator total during training.
        Optimizer stepping is left to the caller.
        """
        if self.clamp_jacobian:
            # Appends perturbed copies of src/tgt (doubles the batch).
            self._jc_preprocess()

        # src -> Z -> tgt
        self.src2lat, sldj_src2lat = self.g_src(self.src, reverse=False)
        self.loss_mle_src = self.lambda_mle * self.mle_loss_fn(self.src2lat, sldj_src2lat)

        src2tgt, _ = self.g_tgt(self.src2lat, reverse=True)
        self.src2tgt = src2tgt

        # tgt -> Z -> src
        self.tgt2lat, sldj_tgt2lat = self.g_tgt(self.tgt, reverse=False)
        self.loss_mle_tgt = self.lambda_mle * self.mle_loss_fn(self.tgt2lat, sldj_tgt2lat)

        tgt2src, _ = self.g_src(self.tgt2lat, reverse=True)
        self.tgt2src = tgt2src

        if self.clamp_jacobian:
            # Split the doubled batches back into original and perturbed halves.
            self._jc_postprocess()
            self.loss_jc_src = self.jc_loss_fn(self.src2tgt, self.src2tgt_jc, self.src, self.src_jc)
            self.loss_jc_tgt = self.jc_loss_fn(self.tgt2src, self.tgt2src_jc, self.tgt, self.tgt_jc)
            self.loss_jc = self.loss_jc_src + self.loss_jc_tgt
        else:
            # Reset all three so no stale per-domain value survives.
            self.loss_jc_src = self.loss_jc_tgt = self.loss_jc = 0.

        # GAN loss
        self.loss_gan_src = self.gan_loss_fn(self.d_tgt(self.src2tgt), is_tgt_real=True)
        self.loss_gan_tgt = self.gan_loss_fn(self.d_src(self.tgt2src), is_tgt_real=True)
        self.loss_gan = self.loss_gan_src + self.loss_gan_tgt

        self.loss_mle = self.loss_mle_src + self.loss_mle_tgt

        self.loss_flow2flow = self.loss_gan + self.loss_mle + self.loss_jc
        self.loss_g = self.loss_flow2flow

        self.loss_flow2flow.backward(retain_graph=True)


    def backward_g_predictor(self):
        """Backpropagate the predictor loss of one labeled mini-batch (10 samples per domain at most)."""
        self._compute_predictor_loss_per_batch(sample_size=10).backward()


    def train_iter(self):
        """Run one full training iteration: generators, predictor, discriminators.

        Each stage zeroes the relevant gradients, backpropagates its loss,
        clips gradients to `self.max_grad_norm` and steps its optimizer(s).
        """
        self.forward()

        generator_side = (self.opt_g, self.opt_pred)

        # (A) Flow2Flow objective -> only opt_g is clipped and stepped
        for optimizer in generator_side:
            optimizer.zero_grad()
        self.backward_g_flow2flow()
        util.clip_grad_norm(self.opt_g, self.max_grad_norm)
        self.opt_g.step()

        # (B) Predictor objective -> both opt_g and opt_pred are stepped
        for optimizer in generator_side:
            optimizer.zero_grad()
        self.backward_g_predictor()
        for optimizer in generator_side:
            util.clip_grad_norm(optimizer, self.max_grad_norm)
        for optimizer in generator_side:
            optimizer.step()

        # (C) Discriminator objective
        self.opt_d.zero_grad()
        self.backward_d()
        util.clip_grad_norm(self.opt_d, self.max_grad_norm)
        self.opt_d.step()

    def train_iter_alignflow(self):
        """Run one training iteration without the predictor stage.

        First the generators are updated on the Flow2Flow objective, then
        the discriminators; both stages clip gradients to
        `self.max_grad_norm` before stepping.
        """
        stages = (
            (self.opt_g, self.backward_g_flow2flow),  # 1) generator
            (self.opt_d, self.backward_d),            # 2) discriminator
        )
        for optimizer, run_backward in stages:
            optimizer.zero_grad()
            run_backward()
            util.clip_grad_norm(optimizer, self.max_grad_norm)
            optimizer.step()

    def train_iter_all_one_epoch(self, batch_size=32):
        """Train the predictor (and generators) for one pass over the labeled data.

        The number of batches is determined by the larger of the two labeled
        sets; the smaller set is oversampled by wrapping around its shuffled
        index list. Each batch averages the available per-domain MSE losses,
        backpropagates, and steps `opt_g` (and `opt_pred` when present). No
        gradient clipping is applied.

        Args:
            batch_size: Number of positions of the longer set covered per batch.
        """
        self.train()

        pool_a = self.labeled_data_a  # [(x_a, y_a), ...]
        pool_b = self.labeled_data_b  # [(x_b, y_b), ...]
        n_a, n_b = len(pool_a), len(pool_b)

        if n_a == 0 and n_b == 0:
            print("No labeled data in domain A / B, skip predictor training.")
            return

        longest = max(n_a, n_b)
        num_batches = math.ceil(longest / batch_size)

        # Shuffle A before B so the random stream is consumed in a fixed order.
        perm_a = torch.randperm(n_a) if n_a > 0 else []
        perm_b = torch.randperm(n_b) if n_b > 0 else []

        criterion = nn.MSELoss()

        def gather(pool, perm, lo, hi):
            """Stack the (x, y) pairs for positions [lo, hi), wrapping around `pool`."""
            if len(pool) == 0:
                return None, None
            # Modulo wraps the position, oversampling the smaller set.
            picks = [perm[pos % len(pool)].item() for pos in range(lo, hi)]
            if len(picks) == 0:
                return None, None
            xs = torch.stack([pool[j][0] for j in picks]).to(self.device)
            ys = torch.stack([pool[j][1] for j in picks]).to(self.device)
            return xs, ys

        for batch_no in range(num_batches):
            lo = batch_no * batch_size
            hi = min(lo + batch_size, longest)

            xa, ya = gather(pool_a, perm_a, lo, hi)
            xb, yb = gather(pool_b, perm_b, lo, hi)

            self.opt_g.zero_grad()
            if hasattr(self, "opt_pred"):
                self.opt_pred.zero_grad()

            running = 0.0
            used_domains = 0
            # A -> Z via g_src, then B -> Z via g_tgt
            for flow, xs, ys in ((self.g_src, xa, ya), (self.g_tgt, xb, yb)):
                if xs is None or xs.size(0) == 0:
                    continue
                latent, _ = flow(xs, reverse=False)
                running = running + criterion(self.predictor(latent), ys)
                used_domains += 1

            if used_domains > 0:
                mean_loss = running / used_domains
                mean_loss.backward()
                self.opt_g.step()
                if hasattr(self, "opt_pred"):
                    self.opt_pred.step()

        print(f"Finished 1 epoch of predictor training with {num_batches} batches (batch_size={batch_size}).")

    def compute_losses_no_grad(self):
        """Recompute every generator and discriminator loss without gradients.

        Runs both translation directions on the current inputs, fills all
        `self.loss_*` attributes (including `self.loss_g`) and returns the
        result of `get_loss_dict()`. Nothing is backpropagated and the sample
        buffers are not consulted.
        """
        with torch.no_grad():
            if self.clamp_jacobian:
                self._jc_preprocess()

            # src -> Z -> tgt
            self.src2lat, src_sldj = self.g_src(self.src, reverse=False)
            self.loss_mle_src = self.lambda_mle * self.mle_loss_fn(self.src2lat, src_sldj)
            self.src2tgt, _ = self.g_tgt(self.src2lat, reverse=True)

            # tgt -> Z -> src
            self.tgt2lat, tgt_sldj = self.g_tgt(self.tgt, reverse=False)
            self.loss_mle_tgt = self.lambda_mle * self.mle_loss_fn(self.tgt2lat, tgt_sldj)
            self.tgt2src, _ = self.g_src(self.tgt2lat, reverse=True)

            if not self.clamp_jacobian:
                self.loss_jc_src = self.loss_jc_tgt = self.loss_jc = 0.
            else:
                self._jc_postprocess()
                self.loss_jc_src = self.jc_loss_fn(self.src2tgt, self.src2tgt_jc, self.src, self.src_jc)
                self.loss_jc_tgt = self.jc_loss_fn(self.tgt2src, self.tgt2src_jc, self.tgt, self.tgt_jc)
                self.loss_jc = self.loss_jc_src + self.loss_jc_tgt

            # Generator side: GAN + MLE + Jacobian clamping
            self.loss_gan_src = self.gan_loss_fn(self.d_tgt(self.src2tgt), is_tgt_real=True)
            self.loss_gan_tgt = self.gan_loss_fn(self.d_src(self.tgt2src), is_tgt_real=True)
            self.loss_gan = self.loss_gan_src + self.loss_gan_tgt

            self.loss_mle = self.loss_mle_src + self.loss_mle_tgt
            self.loss_g = self.loss_gan + self.loss_mle + self.loss_jc

            # Discriminator side: real vs. freshly generated (detached) samples
            self.loss_d_tgt = self._forward_d(self.d_tgt, self.tgt, self.src2tgt)
            self.loss_d_src = self._forward_d(self.d_src, self.src, self.tgt2src)
            self.loss_d = self.loss_d_tgt + self.loss_d_src

        return self.get_loss_dict()

    def get_loss_dict(self):
        """Collect the current losses as plain Python floats.

        Only attributes that currently hold a `torch.Tensor` are included;
        entries that are still `None` (or a plain number) are left out.
        """
        names = (
            # Generator loss
            'loss_gan', 'loss_jc', 'loss_mle', 'loss_g',
            # Discriminator loss
            'loss_d_src', 'loss_d_tgt', 'loss_d',
        )
        current = {name: getattr(self, name) for name in names}

        # Map scalars to floats for interpretation outside of the model
        report = {}
        for name, value in current.items():
            if isinstance(value, torch.Tensor):
                report[name] = value.item()
        return report

    def get_image_dict(self):
        """Get a dictionary of the current samples after `util.un_normalize`.

        Returns: Dictionary with keys {src, src2tgt} and, when training, also
        {tgt, tgt2src}. Each value is whatever `util.un_normalize` returns
        for the corresponding tensor.
        """
        tensors = {'src': self.src, 'src2tgt': self.src2tgt}

        if self.is_training:
            # When training, include full cycles
            tensors['tgt'] = self.tgt
            tensors['tgt2src'] = self.tgt2src

        return {name: util.un_normalize(tensor) for name, tensor in tensors.items()}

    def on_epoch_end(self):
        """Callback for end of epoch.

        Advances every learning-rate scheduler in `self.schedulers` by one step.
        """
        for lr_scheduler in self.schedulers:
            lr_scheduler.step()

    def get_learning_rate(self):
        """Return the learning rate of the first param group of the first optimizer."""
        first_optimizer = self.optimizers[0]
        first_group = first_optimizer.param_groups[0]
        return first_group['lr']

    def _data_parallel(self):
        """Wrap the generators (and, when training, the discriminators) in
        `nn.DataParallel` over `self.gpu_ids` and move them to `self.device`."""
        attr_names = ['g_src', 'g_tgt']
        if self.is_training:
            attr_names += ['d_src', 'd_tgt']

        for attr in attr_names:
            wrapped = nn.DataParallel(getattr(self, attr), self.gpu_ids)
            setattr(self, attr, wrapped.to(self.device))

    def _jc_preprocess(self):
        """Pre-process inputs for Jacobian Clamping. Doubles batch size.

        See Also:
            Algorithm 1 from https://arxiv.org/1802.08768v2
        """
        delta = torch.randn_like(self.src)
        src_jc = self.src + delta / delta.norm()
        src_jc.clamp_(-1, 1)
        self.src = torch.cat((self.src, src_jc), dim=0)

        delta = torch.randn_like(self.tgt)
        tgt_jc = self.tgt + delta / delta.norm()
        tgt_jc.clamp_(-1, 1)
        self.tgt = torch.cat((self.tgt, tgt_jc), dim=0)

    def _jc_preprocess(self):
        """Pre-process inputs for Jacobian Clamping. Doubles batch size."""
        delta = torch.randn_like(self.src)
        src_jc = self.src + delta / delta.norm(dim=1, keepdim=True)
        src_jc = src_jc.clamp(-1, 1) 
        self.src = torch.cat((self.src, src_jc), dim=0)

        delta = torch.randn_like(self.tgt)
        tgt_jc = self.tgt + delta / delta.norm(dim=1, keepdim=True)
        tgt_jc = tgt_jc.clamp(-1, 1)
        self.tgt = torch.cat((self.tgt, tgt_jc), dim=0)

    def _jc_postprocess(self):
        """Split the doubled batches back into original and perturbed halves.

        Each of `src`, `tgt`, `src2tgt` and `tgt2src` is chunked in two along
        dim 0; the first half replaces the attribute and the second half is
        stored under the same name with a `_jc` suffix.
        """
        for attr in ('src', 'tgt', 'src2tgt', 'tgt2src'):
            original_half, perturbed_half = getattr(self, attr).chunk(2, dim=0)
            setattr(self, attr, original_half)
            setattr(self, attr + '_jc', perturbed_half)