import numpy as np
from scipy.optimize import linprog
from scipy.stats import poisson

# The code below is a Python port of the algorithm from the paper
# "Estimating the Unseen: Improved Estimators for Entropy and Other Properties".
def unseen(f):
    """Estimate the histogram of the distribution behind a fingerprint.

    The fingerprint is split into a "sparse" part, for which the empirical
    histogram is used directly, and a "dense" part, for which two linear
    programs are solved: LP1 finds the histogram whose expected fingerprint
    (under a Poisson model) is closest to the observed one, and LP2 finds the
    histogram of smallest support whose LP1 objective is at most ``alpha``
    worse than that optimum.

    Args:
        f: Array-like fingerprint; ``f[i]`` is the number of distinct
            elements observed exactly ``i + 1`` times. It is flattened to
            one dimension.

    Returns:
        A tuple ``(histx, x)`` of 1-D arrays of equal length, sorted by
        increasing ``x``. ``histx[j]`` is the estimated number of elements
        whose probability is ``x[j]``; only strictly positive entries of
        ``histx`` are kept. Both arrays are empty when the fingerprint
        describes no samples (empty or all zeros).

    Raises:
        RuntimeError: If neither LP produced a usable solution vector.
    """
    # Work on a flat 1-D vector regardless of the input's shape.
    f = np.asarray(f).flatten()
    n = len(f)

    # Total sample size: k = sum(i * f(i)). Cast to float so that k * k below
    # cannot overflow a fixed-width integer for very large samples.
    k = float(np.dot(f, np.arange(1, n + 1)))

    # No samples at all: nothing to estimate, and every formula below would
    # divide by k.
    if k <= 0:
        return np.array([]), np.array([])

    ######## Algorithm Parameters ########
    gridFactor = 1.1                 # ratio between consecutive grid points
    alpha = 0.5                      # allowed slack on the LP1 objective in LP2
    xLPmin = 1 / (k * max(10, k))    # smallest probability on the LP grid
    maxLPIters = 1000
    ######################################

    # Split fingerprint into dense (LP) and sparse (empirical) portions.
    x = []
    histx = []
    fLP = np.zeros_like(f)

    for i in np.flatnonzero(f > 0):
        # Window of +/- ceil(sqrt(count)) around this count; i is 0-based, so
        # the count it represents is i + 1.
        half_width = int(np.ceil(np.sqrt(i + 1)))
        wind_start = max(0, i - half_width)
        wind_end = min(i + half_width + 1, n)  # exclusive end for slicing
        if np.sum(f[wind_start:wind_end]) < 2 * np.sqrt(i + 1):
            # Too few neighbours: trust the empirical estimate for this entry.
            x.append((i + 1) / k)
            histx.append(f[i])
        else:
            fLP[i] = f[i]

    # If no LP portion exists, return the empirical histogram.
    dense = np.flatnonzero(fLP > 0)
    if dense.size == 0:
        return np.array(histx), np.array(x)

    # Largest (0-based) fingerprint index handled by the LP.
    fmax = int(dense[-1])

    # LP mass: 1 - probability mass already covered by the empirical part.
    LPmass = 1 - np.dot(x, histx)

    # Truncate fLP at fmax and pad with ceil(sqrt(fmax + 1)) zeros.
    pad = int(np.ceil(np.sqrt(fmax + 1)))
    fLP = np.concatenate([fLP[:fmax + 1], np.zeros(pad)])
    szLPf = len(fLP)

    # Geometric grid of candidate probabilities between xLPmin and xLPmax.
    xLPmax = (fmax + 1) / k
    n_steps = int(np.ceil(np.log(xLPmax / xLPmin) / np.log(gridFactor)))
    xLP = xLPmin * gridFactor ** np.arange(0, n_steps + 1)
    szLPx = len(xLP)

    # Variable layout: [histogram on the grid | two slack variables per
    # fingerprint entry, interleaved].
    n_vars = szLPx + 2 * szLPf

    # LP1 objective: slack variables weighted by 1 / sqrt(fLP + 1).
    objf = np.zeros(n_vars)
    objf[szLPx:] = np.repeat(1 / np.sqrt(fLP + 1), 2)

    # Poisson probability of seeing an element (i + 1) times given grid
    # probability xLP[j]. Each histogram column is divided by xLP[j], so the
    # LP variable is the probability mass at xLP[j] (better conditioning).
    counts = np.arange(1, szLPf + 1)
    pois = poisson.pmf(counts[:, None], k * xLP[None, :]) / xLP

    # LP1 inequality constraints, two rows per fingerprint entry:
    #   expected - slack_a <= observed   and   -expected - slack_b <= -observed
    A = np.zeros((2 * szLPf, n_vars))
    A[0::2, :szLPx] = pois
    A[1::2, :szLPx] = -pois
    rows = np.arange(2 * szLPf)
    A[rows, szLPx + rows] = -1
    b = np.empty(2 * szLPf)
    b[0::2] = fLP
    b[1::2] = -fLP

    # LP1 equality constraint: xLP @ h = LPmass. After the same column
    # rescaling (xLP / xLP) the coefficients are all ones.
    Aeq = np.zeros((1, n_vars))
    Aeq[0, :szLPx] = 1.0
    beq = np.array([LPmass])

    # Solve LP1: minimize discrepancy.
    bounds = [(0, None)] * n_vars
    res1 = linprog(c=objf, A_ub=A, b_ub=b, A_eq=Aeq, b_eq=beq,
                   bounds=bounds, options={'maxiter': maxLPIters, 'disp': False})

    if not res1.success:
        print("Warning: LP1 did not converge. Proceeding to LP2.")

    vopt = res1.fun if res1.success else None

    # LP2: minimize the support size (sum of the histogram), with the same
    # column rescaling applied to the objective.
    objf2 = np.zeros(n_vars)
    objf2[:szLPx] = 1 / xLP

    # Extra constraint: LP1 objective may be at most alpha worse than optimum.
    A2 = np.vstack([A, objf])
    objective_cap = vopt + alpha if vopt is not None else alpha
    b2 = np.concatenate([b, [objective_cap]])

    # Solve LP2.
    res2 = linprog(c=objf2, A_ub=A2, b_ub=b2, A_eq=Aeq, b_eq=beq,
                   bounds=bounds, options={'maxiter': maxLPIters, 'disp': False})

    if not res2.success:
        print("Warning: LP2 did not converge.")

    # A failed solve may come back without a solution vector at all; indexing
    # it would raise an opaque TypeError. Fall back to the LP1 optimum (a
    # valid, though not minimal-support, histogram) when it is available.
    solution = res2.x
    if solution is None:
        if not res1.success or res1.x is None:
            raise RuntimeError(
                "unseen: neither linear program produced a solution "
                f"(LP1: {res1.message}; LP2: {res2.message})"
            )
        solution = res1.x

    # Remove the column scaling and append the LP part to the empirical part.
    sol2 = solution[:szLPx] / xLP
    x = np.concatenate([x, xLP])
    histx = np.concatenate([histx, sol2])

    # Sort by probability and drop non-positive histogram entries.
    order = np.argsort(x)
    x = x[order]
    histx = histx[order]
    mask = histx > 0
    return histx[mask], x[mask]

def make_finger(v):
    """Build the fingerprint (frequency of frequencies) of a sample.

    Args:
        v: Array-like of sample values. It is passed to ``np.unique``, which
            treats it as a flat collection of values.

    Returns:
        1-D integer array ``F`` where ``F[i]`` is the number of distinct
        values that occur exactly ``i + 1`` times in ``v``. The array is
        empty when ``v`` is empty.
    """
    # Count how many times each distinct value occurs in the sample.
    v = np.asarray(v)
    _, counts = np.unique(v, return_counts=True)

    # Histogram of those counts. bincount already sizes its output to
    # max(counts) + 1; minlength=1 keeps the f(0) slot present for an empty
    # sample so the slice below is valid (previously counts.max() raised).
    freq_of_freqs = np.bincount(counts, minlength=1)

    # Drop f(0), which is not used.
    return freq_of_freqs[1:]