'''
    In this script we introduce the Chi measure, which, given a point i,
    is the fraction of its k nearest neighbors j (in the activation space)
    that share the final classification label with i. For example, if
    point i is the activation of a cat image, its label is `cat`: a
    neighbor j in its k-neighborhood that is the image of a `dog` does
    not share the output label and lowers the fraction, whereas another
    `cat` raises it. Tracking the evolution of this measure across the
    visual stream can inform about whether some form of organization in
    the representation emerges.
'''

import numpy as np
import cvxpy as cp

from sklearn.neighbors import NearestNeighbors

class Chi():
    '''
        Compute the mean neighbor overlap with the ground truth
        classification labels: for each point i, the fraction of its k
        nearest neighbors (the point itself excluded) that carry the
        same label as i, averaged over all points.

        Constructor args:
        k (int): The number of nearest neighbors to consider when
                evaluating the Chi fraction.

        metric (str): The metric forwarded to
                sklearn.neighbors.NearestNeighbors.

        Call args:
        X (np.ndarray): 2-D Matrix of shape (N, D) representing the N
                        points in D dimensional space.

        L (np.ndarray): Array of shape (N,) or (N, 1) holding the
                        classification ground truth label of each point.

        Returns:
        Chi (float): The mean fraction (across all N points, normalized
                    in [0, 1]) of neighbors sharing the ground truth
                    classification label.

        Raises:
        ValueError: If X is not 2-D, if L does not provide exactly one
                    label per point, or if k is not in [1, N - 1].
    '''

    def __init__(self, k : int = 15, metric : str = 'euclidean') -> None:
        # Number of nearest neighbors to consider
        self.k = k

        # Which metric to use for nearest neighbors computation
        self.metric = metric


    def __call__(self, X : np.ndarray, L : np.ndarray) -> float:
        # np.shape reads the `.shape` attribute when present, so X is
        # handed to scikit-learn untouched
        shape = np.shape(X)
        if len(shape) != 2:
            raise ValueError(f'X must be 2-D with shape (N, D), got shape {shape}')
        n_points = shape[0]

        # Accept both (N,) and (N, 1) label layouts
        labels = np.asarray(L).reshape(-1)
        if labels.shape[0] != n_points:
            raise ValueError(
                f'L must provide one label per point: got {labels.shape[0]} '
                f'labels for {n_points} points'
            )

        # Each point needs k neighbors other than itself
        if not 1 <= self.k < n_points:
            raise ValueError(
                f'k must satisfy 1 <= k < N, got k={self.k} with N={n_points}'
            )

        nn = NearestNeighbors(n_neighbors=self.k, metric=self.metric, n_jobs=-1)

        # Calling kneighbors() without a query returns the neighbors of the
        # fitted points with each point excluded from its own neighborhood.
        # This yields exactly k genuine neighbors per point, so the fraction
        # below can actually reach 1, and it does not rely on the point
        # being listed first among its own neighbors (duplicates break that).
        neighbors = nn.fit(X).kneighbors(return_distance=False)

        # (N, k) boolean mask: does neighbor j share the label of point i?
        same_label = labels[neighbors] == labels[:, np.newaxis]

        # Mean over neighbors and points at once (every row has k entries)
        return float(np.mean(same_label))

def RChi(X1, X2):
    '''
        Fit the best linear map from the coordinates X1 to the
        coordinates X2 and report how much of X2 it fails to explain.

        With X1 of shape (n, D1) and X2 of shape (n, D2), a matrix A of
        shape (D2, D1) is optimized to minimize

            R = Sum_d { Sum_n [X2.T - A @ X1.T]^2 / Sum_n [X2.T]^2 }

        i.e. the squared residual of every X2 coordinate d, summed over
        the points and normalized by the squared magnitude of that same
        coordinate, then accumulated over the coordinates.

        An optimum R close to zero is read as a sign that X1 and X2
        differ only by a global linear transformation (example: they are
        different partial random projections of points inhabiting the
        same manifold). Used in conjunction with Chi, this helps estimate
        how much of a change in neighbors is due to genuine (non linear)
        mixing and how much is only apparent, being a trivial consequence
        of a global linear transformation (which alters the local
        neighbors).

        Returns:
        (R_opt, A_opt): The value returned by the solver for the
                        objective and the corresponding value of A.
    '''

    # Second axis of each input is its coordinate dimension
    (_, dim_src), (_, dim_dst) = X1.shape, X2.shape

    # Unknown linear map taking X1 coordinates to X2 coordinates
    linear_map = cp.Variable((dim_dst, dim_src))

    # Work in (coordinates, points) layout so that A acts from the left
    target = X2.T
    residual = target - linear_map @ X1.T

    # Per-coordinate squared error and normalization (summed over points)
    squared_error = cp.sum(residual**2, axis=1)
    squared_norm = cp.sum(target**2, axis=1)

    objective = cp.Minimize(cp.sum(squared_error / squared_norm))
    problem = cp.Problem(objective)

    R_opt = problem.solve()

    return R_opt, linear_map.value