import os
import sys
import subprocess
from itertools import product
from scipy.optimize import minimize_scalar
import numpy as np
from scipy.optimize import fsolve

def solve_ab(n, sizes, k, snr, C):
    """Numerically solve a two-equation system for the pair ``(a, b)``.

    With ``T = sum(sizes ** 2)`` and ``U = n ** 2 - T``, the system passed to
    ``scipy.optimize.fsolve`` is::

        a * T + b * U                          = n * (n - 1) * C / k
        (a - b) ** 2 / (k * (a + (k - 1) * b)) = snr

    The solver is started from ``a = 2 * C / k`` and ``b = C / k``.

    Args:
        n: Scalar entering ``n ** 2`` and ``n * (n - 1)`` (presumably the
            total number of nodes -- confirm with callers).
        sizes: Sequence of sizes; converted to a float array before squaring.
        k: Scalar used as divisor/multiplier in both equations (presumably
            the number of groups -- confirm with callers).
        snr: Target value for the left-hand side of the second equation.
        C: Constant scaling the right-hand side of the first equation and
            the initial guess.

    Returns:
        Tuple ``(a, b)`` with the two components of the root returned by
        ``fsolve``.
    """
    size_array = np.asarray(sizes, dtype=float)
    sum_sq = np.sum(size_array**2)       # T: sum of squared sizes
    remainder = n**2 - sum_sq            # U: n^2 minus that sum
    rhs = n * (n - 1) * C / k            # right-hand side of the first equation

    def residuals(candidate):
        # Each entry is "left-hand side minus right-hand side" of one
        # equation, so a root of this function solves the system.
        a_val, b_val = candidate
        linear_gap = a_val * sum_sq + b_val * remainder - rhs
        ratio = (a_val - b_val) ** 2 / (k * (a_val + (k - 1) * b_val))
        return [linear_gap, ratio - snr]

    # Starting point for the root search: a = 2C/k, b = C/k.
    start = [2 * C / k, C / k]
    root_a, root_b = fsolve(residuals, start)
    return root_a, root_b

def ab_to_pq(a, b, n):
    """Convert ``a`` and ``b`` into ``(p, q)`` by scaling with ``log(n) / n``.

    Each output is ``(x * log(max(n, 2))) / n``. The logarithm argument is
    clamped to at least 2, so the log factor is never below ``log(2)``; the
    division still uses ``n`` itself.

    Args:
        a: Value converted to ``p``.
        b: Value converted to ``q``.
        n: Size used for the ``log(n) / n`` scaling.

    Returns:
        Tuple ``(p, q)``.
    """
    log_term = np.log(max(n, 2))
    return tuple((coeff * log_term) / n for coeff in (a, b))


def snr_objective(a, s, k, total_ab):
    """Return the absolute gap between the SNR implied by ``a`` and ``s``.

    ``b`` is derived from ``a`` as ``(total_ab - a) / (k - 1)`` and the SNR is
    ``(a - b) ** 2 / (k * (a + (k - 1) * b))``.

    Args:
        a: Candidate value being evaluated.
        s: Target SNR.
        k: Scalar used to derive ``b`` and in the SNR denominator.
        total_ab: Value of ``a + (k - 1) * b`` that the derived ``b`` satisfies.

    Returns:
        ``abs(snr - s)``, or ``1e9`` when the candidate is rejected because
        ``a <= b`` or ``b <= 0``.
    """
    penalty = 1e9
    b = (total_ab - a) / (k - 1)

    # Only candidates with a > b > 0 are scored; anything else gets a large
    # constant so a minimizer is pushed away from it.
    if b <= 0 or a <= b:
        return penalty

    numerator = (a - b) ** 2
    denominator = k * (a + (k - 1) * b)
    return abs(numerator / denominator - s)

def find_a_given_snr(s, k, total_ab):
    """Search for the ``a`` whose implied SNR is closest to ``s``.

    ``snr_objective`` is minimized with ``scipy.optimize.minimize_scalar``
    (``method='bounded'``) over the open interval
    ``(total_ab / k, total_ab)``, shrunk by ``1e-4`` at each end. Inside that
    interval ``a > b`` and ``b > 0`` hold for
    ``b = (total_ab - a) / (k - 1)``.

    Args:
        s: Target SNR.
        k: Scalar used to derive ``b`` and the lower search bound.
        total_ab: Fixed value of ``a + (k - 1) * b``.

    Returns:
        Tuple ``(a, b)`` where ``a`` is the minimizer's result and ``b`` is
        derived from it.

    Raises:
        ValueError: If the shrunken search interval is empty.
    """
    margin = 1e-4
    search_lo = total_ab / k + margin
    search_hi = total_ab - margin

    # An empty interval means no candidate can satisfy a > b.
    if search_hi <= search_lo:
        raise ValueError("No feasible solution: cannot satisfy a > b under given total_ab and k")

    outcome = minimize_scalar(
        snr_objective,
        bounds=(search_lo, search_hi),
        args=(s, k, total_ab),
        method='bounded',
    )
    best_a = outcome.x
    return best_a, (total_ab - best_a) / (k - 1)