"""
ELSA: Embedding Local Spatial Autocorrelation

Implementation of the ELSA metric as proposed in the paper
"ELSA: LOCAL SPATIAL AUTOCORRELATION OF EMBEDDINGS".

This module provides a class-based implementation of ELSA that is inspired
by the structure of esda.moran.Moran_Local from the PySAL ecosystem.
"""

import numpy as np
from libpysal.weights import W
from libpysal.weights.spatial_lag import lag_spatial
from sklearn.preprocessing import minmax_scale

class ELSA:
    """
    Embedding Local Spatial Autocorrelation (ELSA).

    ELSA is a measure of local spatial autocorrelation for high-dimensional
    embeddings. It adapts the local Moran's I statistic to work with vector
    data by using cosine similarity as the core measure of attribute similarity.

    Parameters
    ----------
    x : array_like
        An array of shape (n, d) where n is the number of observations
        and d is the dimension of the embeddings. It is assumed that the
        embeddings are L2-normalized. The input is converted with
        ``numpy.asarray``.
    w : libpysal.weights.W
        A spatial weights object describing the same n observations.
    permutations : int, optional
        The number of random permutations for calculating pseudo-p-values for
        statistical significance. Defaults to 999. If 0 (or None), no
        permutations are run. Must not be negative.
    seed : None, int or numpy.random.Generator, optional
        Source of randomness for the permutation test. ``None`` (default)
        uses NumPy's global random state. Any other value is passed to
        ``numpy.random.default_rng`` so results are reproducible.

    Attributes
    ----------
    e : numpy.ndarray
        The ELSA statistic for each observation. It is all zeros when the
        embeddings show no variation in similarity to the mean embedding.
    z : numpy.ndarray
        The standardized cosine similarity for each observation. This is the
        embedding's similarity to the global mean embedding, standardized.
        It is all zeros when that similarity is constant up to round-off.
    wz : numpy.ndarray
        The spatial lag of z.
    q : numpy.ndarray
        The quadrant for each observation. Takes values 1-4:
        1: High-High (High z, High wz) - Hotspot
        2: Low-Low (Low z, Low wz)   - Coldspot
        3: Low-High (Low z, High wz)  - Spatial Outlier
        4: High-Low (High z, Low wz)  - Spatial Outlier
        A value of exactly 0 counts as "High".
    p_sim : numpy.ndarray
        The pseudo-p-value for each observation based on permutations
        (two-sided, comparing absolute values). Only calculated if
        permutations > 0.
    sim : numpy.ndarray
        An array of shape (n, permutations) storing the simulated ELSA values
        for each observation. Only calculated if permutations > 0.
    """

    # Relative tolerance below which the spread of the similarity scores is
    # treated as floating-point round-off instead of real variation.
    _STD_RTOL = 1e-12

    def __init__(self, x, w, permutations=999, seed=None):
        if not isinstance(w, W):
            raise ValueError("w must be a libpysal.weights.W object.")
        # Accept lists / DataFrames as well as ndarrays (asarray is a no-op
        # for an ndarray, so existing callers see the same object).
        x = np.asarray(x)
        if x.ndim != 2:
            raise ValueError("x must be a 2-D array of shape (n, d).")
        if x.shape[0] != w.n:
            raise ValueError(
                "The number of observations in x does not match the "
                "number of observations in the weights object w."
            )
        if permutations is not None and permutations < 0:
            raise ValueError("permutations must be a non-negative integer.")

        self.w = w
        self.x = x
        self.permutations = permutations
        self.seed = seed
        self.n, self.d = self.x.shape

        self._run()

        if self.permutations:
            self._run_permutations()
            self._calculate_p_values()

    def _run(self):
        """Compute z, its spatial lag wz, the ELSA statistic e and quadrants q."""
        # Mean embedding vector.
        x_bar = self.x.mean(axis=0)

        # x_bar is in general NOT unit-length (a mean of unit vectors has norm
        # <= 1), so x @ x_bar equals cos(x_i, x_bar) * ||x_bar||. That common
        # positive factor cancels in the standardization below, so z is the
        # standardized cosine similarity provided the rows of x are
        # L2-normalized.
        sim_to_mean = self.x @ x_bar

        self.z = self._standardize(sim_to_mean)
        self.wz = lag_spatial(self.w, self.z)
        self.e = self._elsa_values(self.z, self.wz)
        self.q = self._classify_quadrants(self.z, self.wz)

    @classmethod
    def _standardize(cls, values):
        """Return ``values`` shifted to mean 0 and scaled to unit (population) std.

        If the standard deviation is within floating-point noise of zero
        (relative to the largest magnitude in ``values``), return all zeros.
        Without this check, round-off would be amplified into spurious +/-1
        scores.
        """
        mu = values.mean()
        sigma = values.std()
        scale = np.abs(values).max(initial=0.0)
        if sigma <= cls._STD_RTOL * scale:
            # All embeddings are identical or equally aligned with the mean.
            return np.zeros(values.shape[0])
        return (values - mu) / sigma

    @staticmethod
    def _elsa_values(z, wz):
        """Return the local statistic ``(n - 1) * z * wz / sum(z**2)``.

        When ``z`` is identically zero, the ratio would be 0/0 (NaN). Zeros are
        returned instead, since there is no variation to autocorrelate.
        """
        z2ss = np.sum(z * z)
        if z2ss == 0:
            return np.zeros(z.shape[0])
        return ((z.shape[0] - 1) * z * wz) / z2ss

    @staticmethod
    def _classify_quadrants(z, wz):
        """Label each observation 1 (HH), 2 (LL), 3 (LH) or 4 (HL).

        Values >= 0 count as "High", so ties at exactly 0 are assigned
        consistently. Before this fix, a negative value next to a zero fell
        through to High-High.
        """
        high_z = z >= 0
        high_wz = wz >= 0
        q = np.full(z.shape[0], 4, dtype='int')  # High-Low unless set below
        q[high_z & high_wz] = 1  # High-High
        q[~high_z & ~high_wz] = 2  # Low-Low
        q[~high_z & high_wz] = 3  # Low-High
        return q

    def _run_permutations(self):
        """Run the permutation test, storing simulated statistics in ``sim``."""
        # NOTE(review): this is a total permutation of z. z_i itself can be
        # shuffled into i's own neighbourhood, and the whole vector is reused
        # for every observation. esda.Moran_Local instead uses conditional
        # randomization (z_i held fixed, neighbours drawn from the other
        # n - 1 values). Confirm which scheme the ELSA paper intends.
        # None keeps the historical behaviour of drawing from the global NumPy
        # state, so np.random.seed() still controls the result.
        rng = np.random if self.seed is None else np.random.default_rng(self.seed)

        z_shuffled = np.copy(self.z)
        self.sim = np.zeros((self.n, self.permutations))

        for i in range(self.permutations):
            rng.shuffle(z_shuffled)
            wz_p = lag_spatial(self.w, z_shuffled)
            # sum(z**2) is permutation-invariant, so the observed z is reused.
            self.sim[:, i] = self._elsa_values(self.z, wz_p)

    def _calculate_p_values(self):
        """Calculate two-sided pseudo p-values from the permutation results."""
        # Count how many simulated absolute values are >= the observed one.
        larger = np.abs(self.sim) >= np.abs(self.e)[:, np.newaxis]
        count = larger.sum(axis=1)

        # +1 in numerator and denominator counts the observed statistic as
        # one of the possible arrangements.
        self.p_sim = (count + 1) / (self.permutations + 1)

    @classmethod
    def by_col(cls, df, cols, w=None, **kwargs):
        """
        Compute ELSA for a set of columns in a dataframe.

        Parameters
        ----------
        df : pandas.DataFrame
            Dataframe containing the embedding vectors.
        cols : str or sequence of str
            Column name(s) to be used as the embedding dimensions.
        w : libpysal.weights.W
            A spatial weights object. Required; the ``None`` default exists
            only for signature compatibility.
        **kwargs : dict
            Additional keyword arguments passed to the ELSA constructor
            (e.g. ``permutations``, ``seed``).

        Returns
        -------
        ELSA
            An ELSA object.

        Raises
        ------
        ValueError
            If ``w`` is not provided.
        """
        # ``w`` is a named parameter, so it can never arrive via **kwargs;
        # a None check is sufficient.
        if w is None:
            raise ValueError("A spatial weights object (w) must be provided.")

        if isinstance(cols, str):
            cols = [cols]
        # list() avoids pandas treating a tuple as a single MultiIndex key.
        x = df[list(cols)].to_numpy()
        return cls(x, w=w, **kwargs)