"""
before runing the code, use 'source /pathto/miniconda3/bin/activate; conda activate base' 
to set the juliacall environment
"""

import numpy as np
import time
import test_helper as helper
from juliacall import Main as jl

# Lazy-initialisation state for the Julia bridge; both names are rebound by
# _ensure_gswdesign() on its first call.
# True once `using GSWDesign` has been evaluated through juliacall.
_GSWDESIGN_READY = False
# Cached Julia callable created by jl.seval; stays None until initialisation.
_SAMPLER = None


def _ensure_gswdesign():
    """Return the cached Julia sampling callable, creating it on first use.

    On the first call this evaluates ``using GSWDesign`` in the Julia session
    and builds an anonymous Julia function ``(A, phi)`` that calls
    ``sample_gs_walk(A, phi, 1)``, takes the first element of the result and
    converts it elementwise with ``Int``. Later calls reuse the cached
    callable without touching Julia again.

    Any exception raised by ``jl.seval`` propagates to the caller, and the
    ready flag is left unset so that a later call retries the load.
    """
    global _GSWDESIGN_READY, _SAMPLER

    # Fast path: the package was already loaded during an earlier call.
    if _GSWDESIGN_READY:
        return _SAMPLER

    jl.seval("using GSWDesign")
    _SAMPLER = jl.seval("(A, phi) -> Int.(sample_gs_walk(A, phi, 1)[1])")
    # Only mark as ready after both evaluations succeeded.
    _GSWDESIGN_READY = True
    return _SAMPLER

def _normalized_poly_weights(k, weights):
    """Return ``(powers, weights)`` for the polynomial degrees ``1..k``.

    ``weights`` defaults to all ones; entries are replaced by their absolute
    values and rescaled to sum to one.

    Raises:
        ValueError: if ``weights`` does not have length ``k`` or sums to zero
            after taking absolute values (this includes ``k < 1``).
    """
    powers = np.arange(1, k + 1, dtype=float)
    if weights is None:
        weights = np.ones_like(powers)
    else:
        weights = np.asarray(weights, dtype=float)
        if weights.shape[0] != powers.shape[0]:
            raise ValueError(f"weights must have length {k}")

    # Enforce nonnegative weights and normalize to sum to 1.
    weights = np.abs(weights)
    total = weights.sum()
    if total <= 0.0:
        raise ValueError("weights must have at least one positive entry")
    return powers, weights / total


def GSwalk_poly_low_rank_many(V, k, rank_k, phi, its, weights=None, balance=False, exp=False):
    """Draw ``its`` Gram-Schmidt-walk assignments from low-rank kernel features.

    Each iteration picks ``rank_k`` landmark columns of ``V`` uniformly without
    replacement, evaluates the kernel between all columns and the landmarks
    (``C``, shape ``(n, rank_k)``), and builds the Nystrom-style factor
    ``A = C @ inv(L).T`` where ``L`` is the Cholesky factor of the
    landmark-by-landmark block ``W = C[S, :]``. ``A`` and ``phi`` are then
    passed to the Julia sampler returned by ``_ensure_gswdesign()``.

    Args:
        V: array-like of shape (d, n), the original feature vectors stored as
            columns. It is rescaled so that the largest column norm is 1
            (left unchanged when every column is zero).
        k: highest polynomial degree; the kernel entry for inner product ``g``
            is ``sum_j weights[j-1] * g**j`` for ``j = 1..k``. Ignored when
            ``exp`` is True.
        rank_k: number of landmark columns (rank of the approximation). It is
            capped at ``n`` and replaced by ``d`` when ``exp`` is True.
        phi: balance-robustness trade-off forwarded to the sampler as a float,
            b_i = [sqrt(phi) e_i, sqrt(1 - phi) v_i].
        its: number of assignments to draw. When ``its <= 0`` an empty list is
            returned immediately, without loading Julia.
        weights: optional length-``k`` weights for the polynomial terms.
            Absolute values are taken and the result is normalized to sum to
            one; defaults to equal weights. Ignored when ``exp`` is True.
        balance: accepted for interface compatibility; this implementation
            does not read it.
        exp: when True, use the kernel ``exp(g) / e`` instead of the
            polynomial kernel.

    Returns:
        A list with ``its`` entries. Each entry is a list of Python ints in
        {+1, -1}: +1 where the sampler output is positive, -1 otherwise
        (presumably one entry per column of ``V``; the length is whatever the
        Julia sampler returns).

    Raises:
        ValueError: if ``V`` is not 2-D, ``rank_k`` is negative, or ``weights``
            is invalid. These checks run before Julia is loaded.
        RuntimeError: if the GSWDesign package cannot be loaded in Julia.
        numpy.linalg.LinAlgError: if ``W`` is not positive definite even after
            adding a ``1e-8`` diagonal jitter.

    Note:
        The landmark subsets come from a generator seeded with a fixed value
        (12345), so every call visits the same sequence of subsets for a
        given ``n`` and ``rank_k``.
    """
    V = np.asarray(V, dtype=float)
    if V.ndim != 2:
        raise ValueError("V must be a 2-D array of shape (d, n)")
    d, n = V.shape

    # Nothing to sample: skip argument processing and the Julia start-up cost.
    if its <= 0:
        return []

    if exp:
        rank_k = d
    rank_k = min(rank_k, n)
    if rank_k < 0:
        raise ValueError("rank_k must be nonnegative")

    # The weights do not change between iterations, so validate and normalize
    # them once, before anything expensive happens.
    if not exp:
        powers, poly_weights = _normalized_poly_weights(k, weights)

    assignments: list[list[int]] = []
    rng = np.random.default_rng(12345)  # single seed for reproducibility

    # Normalize so the maximum column norm is at most 1 (if nonzero).
    max_norm = np.linalg.norm(V, axis=0).max()
    if max_norm > 0:
        V = V / max_norm

    try:
        sampler = _ensure_gswdesign()
    except Exception as exc:
        raise RuntimeError(
            "Failed to load GSWDesign in Julia. Install it with: "
            'using Pkg; Pkg.add("GSWDesign")'
        ) from exc

    phi = float(phi)

    for _ in range(its):
        S = rng.choice(n, size=rank_k, replace=False)

        # Inner products between every column and the landmark columns.
        gram_subset = V.T @ V[:, S]

        if exp:
            C = np.exp(gram_subset) / np.exp(1)
        else:
            C = np.sum((gram_subset[..., None] ** powers) * poly_weights, axis=2)

        # Landmark-by-landmark kernel block.
        W = C[S, :]

        try:
            W_chol = np.linalg.cholesky(W)
        except np.linalg.LinAlgError:
            # W can be numerically singular; retry once with a small ridge.
            jitter = 1e-8 * np.eye(W.shape[0])
            W_chol = np.linalg.cholesky(W + jitter)

        # A = C @ inv(L).T, computed with a linear solve instead of forming
        # the explicit inverse. dim: n-by-(rank_k)
        A = np.linalg.solve(W_chol, C.T).T

        A_jl = jl.Array(np.ascontiguousarray(A, dtype=float))
        assignment_vec = sampler(A_jl, phi)
        assignment_arr = np.array(assignment_vec, dtype=np.int8)
        assignment = np.where(assignment_arr > 0, 1, -1).astype(int).tolist()
        assignments.append(assignment)

    return assignments

def _smoke_test() -> None:
    """Run the sampler on a small random problem and check the output shape.

    Builds a 5-by-10 standard-normal feature matrix from a generator seeded
    with 0, requests two assignments with a degree-2 polynomial kernel and
    three landmark columns, prints them, and asserts that two assignments
    were returned with one entry per column.
    """
    n_features, n_units = 5, 10
    features = np.random.default_rng(0).standard_normal((n_features, n_units))

    drawn = GSwalk_poly_low_rank_many(
        features,
        k=2,
        rank_k=3,
        phi=0.5,
        its=2,
        weights=np.array([0.5, 0.5]),
    )
    print(drawn)

    assert len(drawn) == 2
    lengths_ok = True
    for single in drawn:
        if len(single) != n_units:
            lengths_ok = False
            break
    assert lengths_ok


# Run the smoke test only when this file is executed directly as a script.
if __name__ == "__main__":
    _smoke_test()


