'''
This file contains the implementation of the topological and geometric regularizers.
'''

from modules.loss_utils import *


def _normalize_by_max(distances):
    """Scale a distance tensor by its global maximum (taken over all elements).

    When the maximum is zero (e.g. every point coincides), a plain division
    would produce 0 / 0 = NaN and poison training. Instead, the divisor is
    clamped to the smallest positive normal value of the tensor's dtype. For
    any maximum above that value the result, and the gradient through the
    maximum, are identical to ``distances / distances.max()``.
    Assumes the entries are non-negative distances — TODO confirm against
    ``topo_euclidean_distance_matrix``.
    """
    import torch

    max_distance = distances.max()
    eps = torch.finfo(distances.dtype).tiny
    return distances / max_distance.clamp_min(eps)


def topo_loss(model, x):
    """
    Computes the topological loss for a given model and input data.

    The input is encoded with ``model.encode(x, **model.encode_args)``. If the
    model has a truthy ``add_time_feature`` attribute, a time feature is then
    added to ``x``. The latent is computed from the original ``x``; the time
    feature only affects the input-space distances. Euclidean distance matrices
    are computed in both spaces, each scaled by its global maximum. The
    topological signature distance between them is divided by the batch size.

    Args:
        model (torch.nn.Module): The model used to encode the input data. It must
            have an ``encode`` method and an ``encode_args`` mapping of keyword
            arguments for it. It may optionally have an ``add_time_feature``
            attribute; a missing attribute is treated as ``False``.
        x (torch.Tensor): The input data tensor of shape (B, N, D), where B is the
            batch size, N is the number of instances, and D is the dimensionality
            of each instance.

    Returns:
        torch.Tensor: The batch-normalized topological error, a scalar tensor
        representing the topological loss.
    """
    # encode using model (before any time feature is appended to x)
    latent = model.encode(x, **model.encode_args)
    if getattr(model, 'add_time_feature', False):
        # add a time feature to each instance in x
        x = add_time_feature(x)

    # compute and normalize distances in the original space and latent space
    x_distances = _normalize_by_max(topo_euclidean_distance_matrix(x))  # (B, N, N)
    latent_distances = _normalize_by_max(topo_euclidean_distance_matrix(latent))  # (B, N, N)

    # compute topological signature distance
    topo_sig = TopologicalSignatureDistance()
    topo_error = topo_sig(x_distances, latent_distances)

    # normalize topo_error according to batch_size
    batch_size = x.size(0)
    return topo_error / float(batch_size)


def geo_loss(model, x, bandwidth):
    """
    Computes the geometric loss for a given model and input data.

    Args:
        model (torch.nn.Module): The model used to encode the input data. It must
            have an ``encode`` method. It may optionally have:
            ``loader`` — when equal to ``'MicroTraffic'``, axes 1 and 2 of ``x``
            are swapped before the Laplacian is built (a missing attribute is
            treated as ``None``);
            ``add_time_feature`` — when truthy, a time feature is added to each
            instance in ``x`` (a missing attribute is treated as ``False``).
        x (torch.Tensor): The input data tensor.
        bandwidth (float): The bandwidth parameter for the Laplacian computation.

    Returns:
        torch.Tensor: The computed geometric loss.

    Notes:
        - The latent representation is computed from the original ``x``. The
          permutation and time feature only affect the Laplacian input.
        - For 'MicroTraffic' data the permutation switches ``x`` to the layout
          (n_samples, n_timesteps, n_agents, n_features).
        - The Laplacian matrix L of ``x`` is computed with the given bandwidth.
        - H_tilde is obtained from L and the latent representation via ``get_JGinvJT``.
        - The loss is the relaxed distortion measure of H_tilde.
    """
    # encode using model
    # NOTE(review): unlike topo_loss, model.encode_args is not forwarded here —
    # confirm this is intended.
    latent = model.encode(x)

    # Switch axes to (n_samples, n_timesteps, n_agents, n_features) for MicroTraffic data
    if getattr(model, 'loader', None) == 'MicroTraffic':
        x = x.permute(0, 2, 1, 3)
    if getattr(model, 'add_time_feature', False):
        # add a time feature to each instance in x
        x = add_time_feature(x)

    laplacian = get_laplacian(x, bandwidth=bandwidth)
    H_tilde = get_JGinvJT(laplacian, latent)
    return relaxed_distortion_measure_JGinvJT(H_tilde)