"""
This module provides computational tools for exploring the geometric properties 
of embeddings generated by neural networks. 
It contains functions to calculate the Jacobian matrix, compute pullback metrics, 
and explore equivalence classes in embedding spaces. 
This module leverages PyTorch for tensor operations and gradient computations.
"""

import os
import pathlib
import time
from typing import List
from collections import defaultdict
import pickle
from numpy import around
import torch


def save_object(obj, filename):
    """
    Serializes an object to disk with pickle.

    Args:
        obj: Any picklable object.
        filename: Destination path. The file is opened in "wb" mode, so an
        existing file at that path is overwritten.
    """
    with open(filename, "wb") as handle:
        # Use the most recent pickle protocol supported by the interpreter.
        pickle.Pickler(handle, pickle.HIGHEST_PROTOCOL).dump(obj)


def jacobian(nn_output: torch.Tensor, nn_input: torch.Tensor):
    """
    Computes the Jacobian of a 1-D network output with respect to the
    network input, for the first element along the input's leading axis.

    Args:
        nn_output (torch.Tensor): Output tensor, iterated along its first
        dimension; every component must be connected to `nn_input` in the
        autograd graph.
        nn_input (torch.Tensor): Input tensor with gradients enabled.

    Returns:
        torch.Tensor: Detached tensor of shape
        `nn_input.shape[1:] + (nn_output.size(0),)`. The gradient of each
        output component (which has the shape of `nn_input`) is stacked on
        a new last axis, then index 0 of the leading axis is selected.
    """
    component_grads = []
    for idx in range(nn_output.shape[0]):
        # retain_graph=True keeps the graph alive so the remaining output
        # components can still be differentiated.
        (grad_wrt_input,) = torch.autograd.grad(
            [nn_output[idx]], nn_input, retain_graph=True
        )
        component_grads.append(grad_wrt_input)

    stacked = torch.stack(component_grads, dim=-1)
    return stacked[0].detach()


def pullback(
    input_simec: torch.Tensor,
    output_simec: torch.Tensor,
    g: torch.Tensor,
    eq_class_emb_ids: List[int] = None,
):
    """
    Builds the pullback of the metric `g` through the network Jacobian and
    returns its eigendecomposition.

    With `J = jacobian(output_simec, input_simec)`, the batched product
    `J @ g @ J^T` is formed with `torch.bmm` (so `J` and `g` must be 3-D
    batches with the same leading size), cast to double precision, and
    decomposed with `torch.linalg.eigh`.

    Args:
        input_simec (torch.Tensor): Input embeddings tensor (gradients
        enabled).
        output_simec (torch.Tensor): Output embedding computed from
        `input_simec`.
        g (torch.Tensor): Batch of metric tensors used as the Riemannian
        metric of the output space.
        eq_class_emb_ids (List[int], optional): When non-empty, only these
        entries of the Jacobian's leading axis are kept.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Eigenvalues and eigenvectors of
        the pullback metric, as returned by `torch.linalg.eigh`.
    """
    jac_matrix = jacobian(output_simec, input_simec)
    if eq_class_emb_ids:
        # Restrict the computation to the selected embeddings.
        jac_matrix = jac_matrix[eq_class_emb_ids]

    metric = jac_matrix.bmm(g).bmm(jac_matrix.transpose(-2, -1))
    # eigh reads only the upper triangle of the (symmetric) metric.
    return torch.linalg.eigh(metric.type(torch.double), UPLO="U")


def pullback_eigenvalues(
    input_embedding: torch.Tensor,
    model: torch.nn.Module,
    pred_id: int,
    device: torch.device,
    keep_timing: bool = True,
    out_dir: str = ".",
):
    """
    Calculates the eigenvalues of the pullback metric tensor derived from a
    given model's embeddings, saves them to disk and returns them.

    Args:
        input_embedding (torch.Tensor): The input embedding tensor. It is
        cloned, so the caller's tensor is not modified.
        model (torch.nn.Module): Neural network model that produces output
        embeddings from the input embeddings. The first element of its
        return value is used and indexed as `[0, pred_id]`
        (presumably shaped (batch, sequence, features) -- confirm against
        the model in use).
        pred_id (int): Index of the output embedding considered for the
        pullback calculation.
        device (torch.device): Device to perform the computation on.
        keep_timing (bool, optional): Whether the elapsed time is stored
        under the "time" key of the saved dictionary. Defaults to True.
        out_dir (str, optional): Directory (created if missing) where
        "pullback_eigenvalues.pkl" is written. The file is always written,
        regardless of `keep_timing`.

    Returns:
        torch.Tensor: Eigenvalues of the pullback metric.
    """
    tic = time.perf_counter()

    # Clone the embedded input and enable gradients on it.
    input_emb = input_embedding.clone().to(device).requires_grad_(True)
    output_emb = model(input_emb)[0].to(device)
    target = output_emb[0, pred_id].squeeze()

    # Identity matrices used as the standard Riemannian metric of the output
    # embedding space. `pullback` multiplies the Jacobian (one matrix per
    # input position, with one column per output component) by g, so g needs
    # one (output_dim x output_dim) block per input position and the same
    # dtype as the input gradients.
    g = (
        torch.eye(target.size(0), dtype=input_emb.dtype, device=device)
        .unsqueeze(0)
        .repeat(input_emb.size(1), 1, 1)
    )

    # Compute the eigenvalues of the pullback metric.
    eigenvalues, _ = pullback(
        output_simec=target,
        input_simec=input_emb,
        g=g,
    )

    pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)

    results = {
        "input_embedding": input_emb.cpu(),
        "output_embedding": output_emb.cpu(),
        "eigenvalues": eigenvalues.cpu(),
    }
    if keep_timing:
        results["time"] = time.perf_counter() - tic
    save_object(results, os.path.join(out_dir, "pullback_eigenvalues.pkl"))

    return eigenvalues


def _select_eigenpairs(
    eigenvalues, eigenvectors, threshold, same_equivalence_class, device
):
    """Pick, per embedding, one random eigenpair on the requested side of `threshold`."""
    n_eigen = eigenvalues.size(-1)
    if same_equivalence_class:
        counts = torch.count_nonzero(eigenvalues < threshold, dim=1)
    else:
        counts = torch.count_nonzero(eigenvalues > threshold, dim=1)

    chosen_vecs, chosen_vals = [], []
    for emb, count in enumerate(counts.tolist()):
        if count:
            # torch.linalg.eigh returns eigenvalues in ascending order: the
            # ones below the threshold are the first `count` entries, the
            # ones above it are the last `count` entries.
            offset = 0 if same_equivalence_class else n_eigen - count
            id_eigen = offset + torch.randint(0, count, (1,)).item()
            chosen_vecs.append(
                eigenvectors[emb, :, id_eigen].type(torch.float).to(device)
            )
            chosen_vals.append(
                eigenvalues[emb, id_eigen].type(torch.float).to(device)
            )
        else:
            # No admissible direction for this embedding: do not move it.
            chosen_vecs.append(
                torch.zeros(eigenvectors.size(-1), dtype=torch.float, device=device)
            )
            chosen_vals.append(torch.zeros((), dtype=torch.float, device=device))
    return torch.stack(chosen_vecs, dim=0), torch.stack(chosen_vals, dim=0)


def explore(
    same_equivalence_class: bool,
    input_embedding: torch.Tensor,
    model: torch.nn.Module,
    threshold: float,
    n_iterations: int,
    pred_id: int,
    device: torch.device,
    eq_class_emb_ids: List[int] = None,
    keep_timing: bool = True,
    save_each: int = 10,
    out_dir: str = ".",
):
    """
    Explore the manifold defined by the model's embedding space to analyze
    transitions within or between equivalence classes.

    At every iteration the pullback metric is diagonalized, one eigenvector
    per embedding is drawn at random (among those with eigenvalue below
    `threshold` when staying in the same class, above it otherwise) and the
    input embeddings are moved along it. The step size `delta` is not a
    parameter: it is recomputed at each iteration as
    `1 / sqrt(max eigenvalue)` (same class) or `2 / sqrt(max eigenvalue)`
    (different class).

    Args:
        same_equivalence_class (bool): Flag indicating whether to stay within
        the same equivalence class.
        input_embedding (torch.Tensor): Batch of embeddings. It is cloned, so
        the caller's tensor is not modified; only batch element 0 is moved.
        model (torch.nn.Module): The neural network model. The first element
        of its return value is used and indexed as `[0, pred_id]`.
        threshold (float): Threshold for considering an eigenvalue as zero
        (or as a critical value for class change).
        n_iterations (int): Number of iterations to perform. With 0, only
        the initial state is saved and the stored "delta" is None.
        pred_id (int): Index to select specific embeddings for prediction.
        device (torch.device): Device to run the computations on.
        eq_class_emb_ids (list, optional): Indices of the embeddings to move.
        When empty or None, all embeddings are moved.
        keep_timing (bool, optional): Whether to store the accumulated
        computation time (excluding intermediate saves) in the final file.
        save_each (int, optional): An intermediate state is saved every
        `save_each` iterations. Must be non-zero.
        out_dir (str, optional): Directory to save the outputs (created if
        missing).

    Returns:
        None: Saves intermediate results to `<out_dir>/<iteration>.pkl` and
        the final state to `<out_dir>/<n_iterations>.pkl`.
    """

    # Clone the embedded input, enable gradients and run the first forward pass.
    input_emb = input_embedding.clone().to(device).requires_grad_(True)
    output_emb = model(input_emb)[0].to(device)

    n_explored = len(eq_class_emb_ids) if eq_class_emb_ids else input_emb.size(1)

    # Identity matrices used as the standard Riemannian metric of the output
    # embedding space: one (output_dim x output_dim) block per explored
    # embedding, with the dtype of the input gradients so that the batched
    # products in `pullback` are well defined.
    metric_dim = output_emb[0, pred_id].squeeze().size(0)
    g = (
        torch.eye(metric_dim, dtype=input_emb.dtype, device=device)
        .unsqueeze(0)
        .repeat(n_explored, 1, 1)
    )

    # Keep track of the length of the polygonal path.
    distance = torch.zeros(n_explored, device=device)

    os.makedirs(out_dir, exist_ok=True)

    elapsed = 0.0
    # Stays None when no iteration is performed.
    delta = None
    step_scale = 1 if same_equivalence_class else 2

    for i in range(n_iterations):
        tic = time.perf_counter()

        # Compute the pullback metric and its eigenvalues and eigenvectors.
        eigenvalues, eigenvectors = pullback(
            output_simec=output_emb[0, pred_id].squeeze(),
            input_simec=input_emb,
            g=g,
            eq_class_emb_ids=eq_class_emb_ids or None,
        )

        # Select, for each embedding, a random eigenvector whose eigenvalue
        # lies on the requested side of the threshold.
        eigenvecs, eigenvals = _select_eigenpairs(
            eigenvalues, eigenvectors, threshold, same_equivalence_class, device
        )

        with torch.no_grad():
            # Proceed along the selected directions.
            delta = step_scale / torch.sqrt(torch.max(eigenvalues))
            if eq_class_emb_ids:
                input_emb[0, eq_class_emb_ids] = (
                    input_emb[0, eq_class_emb_ids] + eigenvecs * delta
                )
            else:
                input_emb[0] = input_emb[0] + eigenvecs * delta
            distance += eigenvals * delta

        # Prepare for next iteration.
        input_emb = input_emb.to(device).requires_grad_(True)
        output_emb = model(input_emb)[0].to(device)

        save_time = 0.0
        if i % save_each == 0:
            tic_save = time.perf_counter()
            print(f"Iteration: {i}\tDelta: {around(delta.cpu().numpy(), 5)}")
            save_object(
                {
                    "input_embedding": input_emb.cpu(),
                    "output_embedding": output_emb.cpu(),
                    "distance": distance.cpu(),
                    "iteration": i,
                    "delta": delta.cpu(),
                },
                os.path.join(out_dir, f"{i}.pkl"),
            )
            save_time = time.perf_counter() - tic_save

        # Time spent writing intermediate states is not counted.
        elapsed += time.perf_counter() - tic - save_time

    final_state = {
        "input_embedding": input_emb.cpu(),
        "output_embedding": output_emb.cpu(),
        "distance": distance.cpu(),
        "iteration": n_iterations,
    }
    if keep_timing:
        final_state["time"] = elapsed
    final_state["delta"] = delta.cpu() if delta is not None else None
    save_object(final_state, os.path.join(out_dir, f"{n_iterations}.pkl"))
