import numpy as np
from numba import jit
from numba.typed import List as TypedList
from numba.core import types


@jit(nopython=True)
def _expit(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` (elementwise for arrays)."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

@jit(nopython=True)
def _draw_random_index(
    C_shape: tuple[int, int],
    p_cumsum: np.ndarray
):
    """
    Sample one (row, column) position of a matrix by inverse-CDF sampling.

    Parameters
    ----------
    C_shape: tuple[int, int]
        Shape ``(rows, cols)`` of the matrix. Only ``cols`` is needed to
        unravel the flat (row-major) index.
    p_cumsum: np.ndarray
        1-D cumulative sum of the row-major flattened sampling
        probabilities (non-decreasing; last entry should be ~1).

    Returns
    -------
    row_index, col_index: int
        Position of the sampled entry.
    """
    cols = C_shape[1]
    n_entries = p_cumsum.shape[0]

    # Scale the uniform draw by the final cumulative value so it always lies
    # strictly below p_cumsum[-1], even when roundoff leaves that total a bit
    # under 1.0. An unscaled draw above it would make searchsorted return
    # n_entries, i.e. a row past the end of the matrix (and nopython code
    # does not bounds-check the later X[i] access).
    u = np.random.rand() * p_cumsum[n_entries - 1]

    # side="right" returns the first entry whose cumulative probability is
    # strictly greater than u, so zero-probability entries are never chosen
    # (with side="left", a draw of exactly 0.0 could pick such an entry).
    index = np.searchsorted(p_cumsum, u, side="right")

    # Defensive clamp, only reachable if p_cumsum[-1] is not positive.
    if index >= n_entries:
        index = n_entries - 1

    return index // cols, index % cols

@jit(nopython=True)
def _fit_loop(
    X, Y, p_cumsum, col_sums, neg_weight, eta, t_iter, C_shape, ones, VVt, pot, pot_community, verbose
):
    """
    Run ``t_iter`` stochastic gradient-ascent steps on the embeddings.

    Each step samples one entry (i, j) via ``_draw_random_index`` and updates
    row ``i`` of ``X`` and all of ``Y`` in place.

    Parameters
    ----------
    X, Y: np.ndarray
        Embedding matrices (one row per node), modified in place.
    p_cumsum: np.ndarray
        Cumulative sampling distribution over the flattened matrix.
    col_sums: np.ndarray
        Per-column weights of the negative term (``fit`` passes the column
        sums of the co-occurrence matrix).
    neg_weight: float
        Scalar multiplier of the negative term.
    eta: float
        Step size.
    t_iter: int
        Number of steps.
    C_shape: tuple[int, int]
        Shape of the sampled matrix, forwarded to ``_draw_random_index``.
    ones: np.ndarray
        Column vector used for ``pot`` (``fit`` passes the unit all-ones
        vector of shape (n, 1)).
    VVt: np.ndarray
        Matrix applied to ``X`` for ``pot_community``.
    pot, pot_community: typed lists of float
        Potential histories; appended to every 5000 steps.
    verbose: bool
        Print progress every 1,000,000 steps.

    Returns
    -------
    X, Y, pot, pot_community
    """
    for t in range(t_iter):
        if verbose and (t > 0) and (t % 1_000_000 == 0):
            print("Iteration:", t, "/", t_iter)

        # 1. Draw indices (i, j) with probability given by p_cumsum
        i, j = _draw_random_index(C_shape, p_cumsum)

        # 2. Compute only the necessary row (Row i): scores x_i . y_l for all l
        x_i = X[i, :] # Shape (d,); a view into X
        dots = Y @ x_i # Shape (n,)
        q_row = _expit(dots)

        # 3. Construct the H vector for row i: the negative term for every
        #    column, plus the positive term at the sampled column j.
        h_vec = -(neg_weight * col_sums * q_row)
        h_vec[j] += (1.0 - q_row[j])

        # 4. Compute Gradients (both from the pre-update X and Y)
        grad_X_i = h_vec @ Y # Shape (d,)
        grad_Y = h_vec.reshape(-1, 1) * x_i.reshape(1, -1)  # outer(h_vec, x_i)

        # 5. Apply updates (ascent step)
        X[i] += eta * grad_X_i
        Y += eta * grad_Y

        # Only calculate potential every 5000 steps
        if t % 5000 == 0 and t > 0:
            norm_X = np.linalg.norm(X)
            # NOTE(review): [0][0] keeps only the first embedding coordinate
            # of X.T @ ones; for d > 1 confirm this is intended rather than
            # the full squared norm.
            pot.append((( (X.T @ ones)/(norm_X))**2)[0][0])
            pot_community.append((( np.linalg.norm(VVt @ X)/(norm_X))**2))

    return X, Y, pot, pot_community


class SkipgramNegativeSampling:
    """
    Implements Skip-gram with negative sampling, trained by stochastic
    gradient ascent on entries sampled from a co-occurrence matrix.

    Parameters
    ----------
    eta: float, default = 0.1
        The learning rate.

    t_iter: int, default = 10
        The number of stochastic gradient iterations to perform; each
        iteration samples one entry of ``C``.

    s_n: int, default = 1
        The negative sampling parameter.

    r: float, default = 0.1
        Standard deviation of the zero-mean Gaussian used to initialize the
        embeddings.

    d: int, default = 1
        The embedding dimension.

    random_state: int, default = 1234
        Seed for the random number generators: the instance ``Generator``
        used for initialization, NumPy's global RNG, and the RNG used by the
        compiled training loop.

    verbose: bool, default = False
        If True, print progress messages during ``fit``.

    Attributes
    ----------
    X, Y: np.ndarray of shape (n, d)
        The learned embeddings.

    pot_: list of float
        History of ``((X.T @ u)[0] / ||X||_F) ** 2`` with ``u`` the unit
        all-ones vector: the value at initialization, then one entry every
        5000 iterations.

    pot_community_: list of float
        History of ``(||P @ X||_F / ||X||_F) ** 2`` where ``P`` is the
        orthogonal projector onto the community-contrast subspace built
        from ``K``; recorded alongside ``pot_``.
    """

    def __init__(
        self,
        eta: float = 0.1,
        t_iter: int = 10,
        s_n: int = 1,
        r: float = 0.1,
        d: int = 1,
        random_state: int = 1234,
        verbose: bool = False,
    ):
        self.eta = eta
        self.t_iter = t_iter
        self.s_n = s_n
        self.r = r
        self.d = d
        self.random_state = random_state
        self.verbose = verbose

    def fit(
        self,
        C: np.ndarray,
        K: int = None,
        warm_start: bool = False
    ):
        """
        Fit the embeddings to the co-occurrence matrix ``C``.

        Parameters
        ----------
        C: np.ndarray of shape (n, n)
            Non-negative co-occurrence counts. Entry (i, j) is sampled with
            probability ``C[i, j] / C.sum()``.

        K: int
            Number of communities used to build the projector for
            ``pot_community_``. Required despite the ``None`` default.

        warm_start: bool, default = False
            If True, continue from the embeddings and potential histories of
            a previous call to ``fit`` instead of re-initializing them.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``C`` is not a square 2-D array, has negative entries or a
            non-positive total, or if ``K`` is missing or smaller than 1.
        """
        C = np.asarray(C)
        if C.ndim != 2 or C.shape[0] != C.shape[1]:
            raise ValueError(
                f"C must be a square 2-D array, got shape {C.shape}."
            )
        if K is None or K < 1:
            raise ValueError(f"K must be a positive integer, got {K!r}.")

        n = C.shape[0]
        m = C.sum()
        # A zero/NaN total or negative entries would make p_cumsum invalid
        # and let the compiled sampler produce out-of-range indices.
        if np.any(C < 0) or not m > 0:
            raise ValueError(
                "C must have non-negative entries and a positive total."
            )
        # Row-major flattening: flat index k corresponds to entry (k // n, k % n).
        p_cumsum = np.cumsum(C.flatten() / m)

        self.rng = np.random.default_rng(self.random_state)
        np.random.seed(self.random_state)
        # On a warm start the compiled RNG is deliberately not reseeded, so the
        # continued run draws new indices instead of replaying the first run's.
        if not warm_start and self.random_state is not None:
            self._seed_compiled_rng(self.random_state)

        # Initialize embeddings (or resume from a previous fit)
        if warm_start:
            X = self.X
            Y = self.Y
            # Re-wrap the stored histories as typed lists so the compiled loop
            # receives the same container type as on a fresh fit (plain Python
            # lists would go through Numba's deprecated reflected-list path).
            pot = TypedList(self.pot_)
            pot_community = TypedList(self.pot_community_)
        else:
            X, Y = self._initialize_embeddings(n, self.d, self.r)

        # PRE-COMPUTATION
        col_sums = C.sum(axis=0)  # Shape (n,)
        neg_weight = self.s_n / m

        # Unit-norm all-ones column vector, shape (n, 1)
        ones = np.ones((n, 1))
        ones = ones / np.linalg.norm(ones)

        VVt = self._community_projector(n, K)

        if not warm_start:
            pot_val, pot_community_val = self._potentials(X, ones, VVt)
            pot = TypedList([pot_val])
            pot_community = TypedList([pot_community_val])

        # OPTIMIZATION
        if self.verbose:
            print(f"Starting iterations (Total: {self.t_iter})...")

        X, Y, pot, pot_community = _fit_loop(
            X, Y, p_cumsum, col_sums, neg_weight, self.eta, self.t_iter,
            C.shape, ones, VVt, pot, pot_community, self.verbose
        )

        if self.verbose:
            print("Iterations finished.")

        self.X = X
        self.Y = Y
        self.pot_ = list(pot)
        self.pot_community_ = list(pot_community)

        return self

    @staticmethod
    def _seed_compiled_rng(seed):
        """Seed the RNG used by Numba-compiled code (e.g. inside ``_fit_loop``)."""
        # Numba-compiled code draws from its own RNG state; calling
        # np.random.seed from interpreted Python does not touch it, so the
        # seeding has to happen from inside a compiled function.
        @jit(nopython=True)
        def _seed(value):
            np.random.seed(value)

        _seed(seed)

    @staticmethod
    def _community_projector(n, K):
        """
        Return the (n, n) orthogonal projector onto K - 1 community contrasts.

        The n node indices are split into K contiguous blocks of roughly
        n / K indices each; contrast vector k is +1 on block k and -1 on
        block k + 1, scaled to unit norm. NOTE(review): this assumes nodes
        are ordered by community in equal-size blocks — confirm against how
        ``C`` is built.
        """
        if K == 1:
            # No contrast vectors: the projector onto the empty span is zero.
            return np.zeros((n, n))

        U_community = np.zeros((n, K - 1))
        for k in range(K - 1):
            block_start1 = int(k * (n / K))
            block_end1 = int((k + 1) * (n / K))
            block_start2 = int((k + 1) * (n / K))
            block_end2 = int((k + 2) * (n / K))
            U_community[block_start1:block_end1, k] = 1
            U_community[block_start2:block_end2, k] = -1
            U_community[:, k] = U_community[:, k] / np.linalg.norm(U_community[:, k])

        # U (U^T U)^{-1} U^T projects onto the column space of U_community.
        return U_community @ np.linalg.inv(U_community.T @ U_community) @ U_community.T

    @staticmethod
    def _potentials(X, ones, VVt):
        """
        Return ``(pot, pot_community)`` for embeddings ``X``.

        Uses the same formulas as the periodic recording in ``_fit_loop``.
        """
        norm_X = np.linalg.norm(X)
        pot_val = (((X.T @ ones) / norm_X) ** 2)[0][0]
        pot_community_val = (np.linalg.norm(VVt @ X) / norm_X) ** 2
        return pot_val, pot_community_val

    def _initialize_embeddings(
        self,
        n,
        d,
        r
    ):
        """Draw X and Y of shape (n, d) i.i.d. from N(0, r**2) using ``self.rng``."""
        X = self.rng.normal(0, r, (n, d))
        Y = self.rng.normal(0, r, (n, d))

        return X, Y
    

    


    
