import sys
import numpy as np

"""
Computes the Bayes risk (population loss) by Monte Carlo estimation.
Usage: python <script> nu Lmax Lrandom model
nu : float ; signal strength.
Lmax : int ; sequence length (the maximum sequence length when Lrandom is nonzero).
Lrandom : int (0 or 1) ; if nonzero, the sequence length is uniform on {1, ..., Lmax} and the risk is averaged over every length; if 0, all sequences have length Lmax.
model : "spiked" or "max" ; selects the model: the spiked-SLR or the max-SLR.
"""

# Command-line arguments, in order: nu Lmax Lrandom model.
if len(sys.argv) < 5:
    raise SystemExit("usage: {} NU LMAX LRANDOM {{spiked,max}}".format(sys.argv[0]))

nu = float(sys.argv[1])
Lmax = int(sys.argv[2])
# Used only for its truthiness: 0 -> fixed length Lmax, nonzero -> average over 1..Lmax.
Lrandom = int(sys.argv[3])
model = sys.argv[4]

# Explicit checks rather than `assert`, which is stripped under `python -O`
# and would let an unknown model reach _err (where `err` is never bound).
if model not in ("spiked", "max"):
    raise SystemExit("model must be 'spiked' or 'max', got {!r}".format(model))
# Lmax < 1 would otherwise crash later (IndexError or ZeroDivisionError).
if Lmax < 1:
    raise SystemExit("Lmax must be >= 1, got {}".format(Lmax))

# Number of Monte Carlo samples drawn per sequence length.
nMC = int(3e6)

def _err(mv, mk, nu, l, model_name=None, n_samples=None, rng=None, chunk_size=None):
    """Monte Carlo estimate of the Bayes risk at a fixed sequence length ``l``.

    Draws ``n_samples`` independent columns of ``l`` standard normals ``chi``
    and returns ``1 - mv * q``, where ``q`` is the Monte Carlo average of

    * ``"spiked"``: ``1 / (1 + sum_{j>=1} exp(s * (chi_j - chi_0 - s)))``
      with ``s = sqrt(nu * mk)``;
    * ``"max"``: ``l * (1 / (1 + sum_{j>=1} exp(nu * mk * (chi_j - chi_0))))**2``.

    Parameters
    ----------
    mv, mk : float
        Multiplicative factors: ``mk`` scales ``nu`` inside the exponent and
        ``mv`` scales the averaged term (the script passes 1 for both).
    nu : float
        Signal strength.
    l : int
        Sequence length; must be >= 1.
    model_name : {"spiked", "max"}, optional
        Model to evaluate. Defaults to the module-level ``model`` (CLI choice).
    n_samples : int, optional
        Number of Monte Carlo samples. Defaults to the module-level ``nMC``.
    rng : numpy.random.Generator, optional
        Random source. Defaults to the legacy global ``np.random`` state, so
        ``np.random.seed`` keeps working as before.
    chunk_size : int, optional
        Samples processed per batch. By default chosen so that a batch holds
        at most 2**24 normals, which bounds memory independently of ``l``.

    Returns
    -------
    float
        The estimated risk ``1 - mv * q``.

    Raises
    ------
    ValueError
        If the model name is unknown, or ``l``, ``n_samples`` or
        ``chunk_size`` is smaller than 1.
    """
    if model_name is None:
        model_name = model  # module-level CLI choice
    if n_samples is None:
        n_samples = nMC
    if model_name not in ("spiked", "max"):
        raise ValueError("unknown model: {!r}".format(model_name))
    if l < 1:
        raise ValueError("sequence length must be >= 1, got {}".format(l))
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1, got {}".format(n_samples))

    sampler = np.random if rng is None else rng
    if chunk_size is None:
        # Peak memory is a few arrays of l * chunk_size floats; drawing all
        # (l, n_samples) normals at once grows without bound in l.
        chunk_size = max(1, (1 << 24) // l)
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1, got {}".format(chunk_size))
    chunk_size = min(chunk_size, n_samples)

    total = 0.0
    done = 0
    while done < n_samples:
        m = min(chunk_size, n_samples - done)
        chis = sampler.normal(0, 1, (l, m))
        # exp may overflow to inf for strong signals; 1/(1+inf) == 0 is the
        # correct limit, so the overflow warning is just noise.
        with np.errstate(over="ignore"):
            if model_name == "spiked":
                s = np.sqrt(nu * mk)
                p0 = 1 / (1 + np.sum(np.exp(s * (chis[1:, :] - chis[0, :] - s)), axis=0))
                total += np.sum(p0)
            else:  # "max"
                p0 = 1 / (1 + np.sum(np.exp(nu * mk * (chis[1:, :] - chis[0, :])), axis=0))
                total += np.sum(p0 ** 2)
        done += m

    q = total / n_samples
    if model_name == "max":
        q *= l
    return 1 - mv * q

# Both multiplicative factors handed to _err are fixed to 1 in this script.
mv = mk = 1

if not Lrandom:
    # Fixed-length sequences: one Monte Carlo estimate at length Lmax.
    err = _err(mv, mk, nu, Lmax)
else:
    # Uniform length on {1, ..., Lmax}: equal-weight average of the
    # per-length risks, evaluated in increasing order of length.
    running = 0
    for length in range(1, Lmax + 1):
        running += _err(mv, mk, nu, length)
    err = running / Lmax

print(err)
