# -*- coding: utf-8 -*-
"""
Simulations for the paper *Sharp Matrix Empirical Bernstein Inequalities*
Hongjian Wang and Aaditya Ramdas
NeurIPS 2025 Poster
"""

import numpy as np

def maurer_pontil_radius(v, n, alpha=.05):
  """Maurer-Pontil empirical Bernstein CI radius for [0,1]-bounded RVs.

  Args:
    v: Bessel-corrected sample variance.
    n: number of samples (must exceed 1, since the range term divides by n-1).
    alpha: error level; the bound uses log(2/alpha).

  Returns:
    sqrt(2 v log(2/alpha) / n) + 7 log(2/alpha) / (3 (n-1)).
  """
  confidence_log = np.log(2/alpha)
  variance_term = np.sqrt(2*v*confidence_log/n)
  range_term = 7*confidence_log/(3*(n-1))
  return variance_term + range_term


def maurer_pontil_radius_sharp(v, n, alpha=.05):
  """Maurer-Pontil empirical Bernstein CI radius for [0,1]-bounded RVs,
  sharpened by a different alpha-splitting scheme.

  Args:
    v: Bessel-corrected sample variance.
    n: number of samples (must exceed 1).
    alpha: error level.

  Returns:
    L1/(3n) + sqrt(2 v L1 / n) + 2 sqrt(L1 L2 / (n (n-1))),
    with L1 = log(n / ((n-1) alpha)) and L2 = log(n / alpha).
  """
  log_a = np.log(n/((n-1)*alpha))
  log_b = np.log(n/alpha)
  linear_term = log_a/(3*n)
  variance_term = np.sqrt(2*v*log_a/n)
  cross_term = 2*np.sqrt(log_a*log_b / (n*(n-1)))
  return linear_term + variance_term + cross_term


def bernstein_radius(n, alpha=.05, variance=1/4):
  """Bernstein CI radius for [0,1]-bounded RVs with a known variance.

  Args:
    n: number of samples.
    alpha: error level; the bound uses log(1/alpha).
    variance: the (known) variance of the RV.

  Returns:
    log(1/alpha)/(3n) + sqrt(2 variance log(1/alpha) / n).
  """
  log_inv_alpha = np.log(1/alpha)
  range_term = log_inv_alpha/(3*n)
  variance_term = np.sqrt(2*variance*log_inv_alpha/n)
  return range_term + variance_term


def wsr_eb_radius(s_lambda, s_v_psi_lambda, alpha=.05):
  """Radius (log(1/alpha) + s_v_psi_lambda) / s_lambda.

  Args:
    s_lambda: sum of the lambdas used so far.
    s_v_psi_lambda: accumulated variance-times-psi penalty.
    alpha: error level.
  """
  numerator = np.log(1/alpha) + s_v_psi_lambda
  return numerator / s_lambda

class seq_radius:
  """Base class for confidence radii updated one observation at a time.

  Subclasses override `observe` to ingest a sample and `radius` to report
  the current radius. In this base class both are no-ops returning None.
  """
  def __init__(self, alpha=.05):
    # number of observations seen so far, and the error level
    self.n = 0
    self.alpha = alpha
  def observe(self, x):
    """Ingest one observation (no-op here)."""
    return None
  def radius(self):
    """Current radius (None here)."""
    return None

class VarianceBasedRadius(seq_radius):
  """Sequential radius base that tracks running sums for a sample variance.

  Keeps S = sum of x and V = sum of x^2 over the n observations seen.
  Subclasses supply `radius`, typically computed from `v()`.
  """
  def __init__(self, alpha=.05):
    super().__init__(alpha)
    self.S = 0  # running sum of observations
    self.V = 0  # running sum of squared observations
  def observe(self, x):
    """Record one scalar observation."""
    self.n += 1
    self.S += x
    self.V += x*x
  def v(self):
    """Return the Bessel-corrected sample variance (V - S^2/n) / (n-1).

    Requires n >= 2. The result is clamped at 0. The sum-of-squares formula
    can cancel to a tiny negative number in floating point (e.g. for a constant
    stream), which would otherwise make downstream sqrt(v) return nan.
    """
    raw = (self.V - self.S*self.S/self.n) / (self.n-1)
    return max(raw, 0.0)


class MaurerPontilRadius(VarianceBasedRadius):
  """Sequential Maurer-Pontil empirical Bernstein radius for [0,1]-bounded RVs."""
  def radius(self):
    """Radius from the current sample variance, sample count and alpha."""
    current_variance = self.v()
    return maurer_pontil_radius(current_variance, self.n, self.alpha)


class MaurerPontilRadiusSharp(VarianceBasedRadius):
  """Sequential Maurer-Pontil radius using the sharpened alpha-splitting scheme."""
  def radius(self):
    """Radius from the current sample variance, sample count and alpha."""
    sample_variance = self.v()
    return maurer_pontil_radius_sharp(sample_variance, self.n, self.alpha)


class BernsteinRadius(seq_radius):
  """Sequential Bernstein radius with a known (oracle) variance.

  Observed values are ignored; only the count n advances.
  """
  def __init__(self, alpha=.05, variance=1/4):
    super().__init__(alpha)
    self.variance = variance  # known variance passed to bernstein_radius
  def observe(self, x):
    """Count one observation; its value is not used."""
    self.n = self.n + 1
  def radius(self):
    """Bernstein radius for the current count."""
    return bernstein_radius(self.n, alpha=self.alpha, variance=self.variance)


class WSREbRadius(seq_radius):
  """Sequential scalar empirical Bernstein radius with predictable plug-in lambdas.

  Presumably the Waudby-Smith & Ramdas (WSR) predictable plug-in empirical
  Bernstein construction -- TODO confirm against the paper. At each step the
  lambda and the centring mean come from the PREVIOUS step's estimates, and
    s_lambda       = sum_t lambda_t
    s_v_psi_lambda = sum_t 4 (x_t - prev_muhat)^2 * psiE(lambda_t)
  so that radius() = (log(1/alpha) + s_v_psi_lambda) / s_lambda.

  Args:
    alpha: error level.
    final_N: total number of observations that will be fed in; it enters
      every lambda (get_wsr_radius passes len(X)).
  """
  def __init__(self, alpha=.05, final_N=1000):
    super().__init__(alpha)
    self.s_lambda = 0  # running sum of lambdas
    self.s_v_psi_lambda = 0  # running sum of 4 (x - prev_muhat)^2 psiE(lambda)
    self.S = 0  # running sum of observations
    self.prev_muhat = .5  # mean estimate used before any data is seen
    self.prev_sigmasqhat = .25  # variance estimate used before any data is seen
    self.log1a = np.log(1/alpha)
    self.s_xi_minus_muhat_sq = 0  # running sum of (x_i - muhat_i)^2
    self.final_N = final_N
  def psiE(self, u):
    """Return (-log(1-u) - u) / 4; finite only for u < 1."""
    return (-np.log(1-u) - u)/4
  def observe(self, x):
    """Record one observation and update the running sums.

    Statement order matters: lambda and the centring use the prev_* values,
    which are only overwritten at the very end.
    """
    self.n += 1
    self.S += x
    # mean estimate shrunk toward 1/2, as if one extra observation of 1/2 was seen
    muhat = (self.S + .5)/(self.n + 1)
    self.s_xi_minus_muhat_sq += np.square(x - muhat)
    # variance estimate shrunk toward 1/4 in the same way
    sigmasqhat = (.25 + self.s_xi_minus_muhat_sq)/(self.n + 1)
    # lambda from the previous variance estimate, clipped at .75 so psiE stays finite
    lamb = np.min([np.sqrt(2*self.log1a/(self.final_N*self.prev_sigmasqhat)), .75])
    self.s_lambda += lamb
    self.s_v_psi_lambda += 4*np.square(x-self.prev_muhat)*self.psiE(lamb)
    self.prev_muhat = muhat
    self.prev_sigmasqhat = sigmasqhat
  def radius(self):
    """Current radius: (log(1/alpha) + s_v_psi_lambda) / s_lambda."""
    return wsr_eb_radius(self.s_lambda, self.s_v_psi_lambda, self.alpha)


def get_wsr_radius(X: np.array, alpha=.05):
  """Feed every sample of X to a fresh WSREbRadius (final_N = len(X)); return its radius."""
  tracker = WSREbRadius(alpha, len(X))
  for sample in X:
    tracker.observe(sample)
  return tracker.radius()

# Scalar experiment: stream np.random.rand() draws and, at each checkpoint n,
# print every radius divided by the oracle Bernstein radius
# (1/12 is the variance of the Uniform[0,1) draws below).
mp1 = MaurerPontilRadius()
mp2 = MaurerPontilRadiusSharp()
b = BernsteinRadius(variance=1/12)
sample_sizes = [ 100, 1000, 10000, 100000, 1000000 ]

# Preallocate the sample buffer. Growing it with np.append copies the whole
# array on every step, which is quadratic in the number of samples.
X = np.empty(sample_sizes[-1])
for n in range(1,sample_sizes[-1]+1):
  x = np.random.rand()
  mp1.observe(x)
  mp2.observe(x)
  b.observe(x)
  X[n-1] = x
  if n in sample_sizes:
    # WSR is recomputed from scratch on the first n samples (final_N = n)
    print(f"n={n}, MP-to-Bern={mp1.radius()/b.radius()}, Sharp-MP-to-Bern={mp2.radius()/b.radius()}, WSR-to-Bern={get_wsr_radius(X[:n])/b.radius()}")

# now matrices
import numpy as np

def opnorm(M: np.matrix):
  """Operator (spectral) norm of M, i.e. its largest singular value.

  M is always real symmetric here, so this equals its largest absolute eigenvalue.
  """
  return np.linalg.norm(M, 2)


def mat_maurer_pontil_radius_sharp(v: np.matrix, n, alpha=.05):
  """Sharp matrix Maurer-Pontil empirical Bernstein CI radius for [0,1]-bounded RMs.

  Args:
    v: empirical variance matrix (d x d).
    n: number of samples (must exceed 1).
    alpha: error level.

  Returns:
    L1/(3n) + sqrt(2 L1 / n) * (sqrt(||v||) + min(a, b)), where
    L1 = log(n d / ((n-1) alpha)), L2 = log(2 n d / alpha),
    a = sqrt(L2 / (2 n ||v||)) and b = (2 L2 / n)^(1/4).
  """
  vop = opnorm(v)
  d = v.shape[0]
  log_term_1 = np.log(n*d/((n-1)*alpha))
  log_term_2 = np.log(2*n*d/alpha)
  quartic_term = np.power(2*log_term_2/n, 1/4)
  if vop > 0:
    correction = np.min([
        np.sqrt(log_term_2/(2*n*vop)),
        quartic_term
    ])
  else:
    # With ||v|| = 0 the first candidate is +inf, so the minimum is the quartic
    # term. Take it directly instead of dividing by zero (RuntimeWarning).
    correction = quartic_term
  return log_term_1/(3*n) + np.sqrt(2*log_term_1/n) * (np.sqrt(vop) + correction)


def mat_maurer_pontil_radius_sharp_var(v: np.matrix, n, alpha=.05):
  """Sharp matrix Maurer-Pontil empirical Bernstein CI radius for [0,1]-bounded RMs.

  Variant whose second-order term is O(1/n) instead of O(n^{-3/4}).

  Args:
    v: "paired" variance estimator (d x d).
    n: number of samples (must exceed 1).
    alpha: error level.

  Returns:
    L1/(3n) + sqrt(2 L1 ||v|| / n) + (sqrt(5/3) + 1) sqrt(L1 L2) / n, where
    L1 = log(n d / ((n-1) alpha)) and L2 = log(2 n d / alpha).
  """
  spectral = opnorm(v)
  dim = v.shape[0]
  log_a = np.log(n*dim/((n-1)*alpha))
  log_b = np.log(2*n*dim/alpha)
  linear_term = log_a/(3*n)
  variance_term = np.sqrt(2*log_a*spectral/n)
  second_order_coef = np.sqrt(5/3) + 1
  second_order_term = second_order_coef*np.sqrt(log_a*log_b)/n
  return linear_term + variance_term + second_order_term


def mat_bernstein_radius(n, variance: np.matrix, alpha=.05):
  """Matrix Bernstein CI radius for [0,1]-bounded RMs with a known variance matrix.

  Args:
    n: number of samples.
    variance: the (known) variance matrix (d x d).
    alpha: error level; the bound uses log(d/alpha).

  Returns:
    log(d/alpha)/(3n) + sqrt(2 ||variance|| log(d/alpha) / n).
  """
  dim = variance.shape[0]
  log_dim_over_alpha = np.log(dim/alpha)
  spectral = opnorm(variance)
  return log_dim_over_alpha/(3*n) + np.sqrt(2*spectral*log_dim_over_alpha/n)


def mat_wsr_eb_radius(s_lambda, s_v_psi_lambda: np.matrix, alpha=.05):
  """Radius (log(d/alpha) + ||s_v_psi_lambda||) / s_lambda, d = s_v_psi_lambda.shape[0].

  Args:
    s_lambda: sum of the lambdas used so far.
    s_v_psi_lambda: accumulated matrix variance-times-psi penalty.
    alpha: error level.
  """
  numerator = np.log(s_v_psi_lambda.shape[0]/alpha) + opnorm(s_v_psi_lambda)
  return numerator / s_lambda

class SharpMatrixMaurerPontilRadius(seq_radius):
  """Sequential sharp matrix Maurer-Pontil radius from the Bessel-corrected variance.

  Keeps S = sum of x and V = sum of x @ x over the n (dim x dim) observations,
  and feeds v() to mat_maurer_pontil_radius_sharp.
  """
  def __init__(self, dim, alpha=.05):
    super().__init__(alpha)
    # Plain ndarrays rather than np.asmatrix. NumPy discourages np.matrix (its
    # constructor emits a PendingDeprecationWarning); computed values are unchanged.
    self.S = np.zeros((dim,dim))  # running sum of observations
    self.V = np.zeros((dim,dim))  # running sum of x @ x
  def observe(self, x: np.ndarray):
    """Record one dim x dim observation (assumed symmetric, as opnorm expects)."""
    self.n += 1
    self.S += x
    self.V += np.matmul(x, x)
  def v(self):
    """Bessel-corrected variance matrix (V - S @ S / n) / (n - 1); requires n >= 2.

    For symmetric observations this equals (1/(n-1)) * sum_i (x_i - mean)^2.
    """
    return (self.V - np.matmul(self.S, self.S)/self.n) / (self.n-1)
  def radius(self):
    """Current radius from the variance matrix, n and alpha."""
    return mat_maurer_pontil_radius_sharp(self.v(), self.n, self.alpha)


class SharpMatrixMaurerPontilRadiusVar(seq_radius):
  """Sequential sharp matrix Maurer-Pontil radius using a paired variance estimator.

  Observations are grouped into consecutive pairs (1st, 2nd), (3rd, 4th), ...
  V accumulates (x_even - x_odd)^2 over completed pairs, and v() = V / n.
  Radii come from mat_maurer_pontil_radius_sharp_var.
  """
  def __init__(self, dim, alpha=.05):
    super().__init__(alpha)
    # Plain ndarrays rather than np.asmatrix (np.matrix is discouraged by NumPy
    # and warns on construction); computed values are unchanged.
    self.S = np.zeros((dim,dim))  # running sum of observations
    self.V = np.zeros((dim,dim))  # running sum of squared pair differences
    self.oddx = None  # first element of the currently open pair
  def observe(self, x: np.ndarray):
    """Record one dim x dim observation (assumed symmetric, as opnorm expects)."""
    self.n += 1
    self.S += x
    if self.n % 2 == 1:
      # odd-numbered observation: open a new pair
      self.oddx = x
    else:
      # even-numbered observation: close the pair and add (x - oddx)^2
      diff = x - self.oddx
      self.V += np.matmul(diff, diff)
  def v(self):
    """Paired sample variance V / n.

    A trailing unpaired observation (odd n) adds nothing to V but still counts in n.
    """
    return self.V / self.n
  def radius(self):
    """Current radius from the paired variance, n and alpha."""
    return mat_maurer_pontil_radius_sharp_var(self.v(), self.n, self.alpha)

class MatrixBernsteinRadius(seq_radius):
  """Sequential matrix Bernstein radius with a known (oracle) variance matrix.

  Observed values are ignored; only the count n advances.
  """
  def __init__(self, variance_matrix, alpha=.05):
    super().__init__(alpha)
    self.variance_matrix = variance_matrix  # known variance passed to mat_bernstein_radius
  def observe(self, x):
    """Count one observation; its value is not used."""
    self.n = self.n + 1
  def radius(self):
    """Matrix Bernstein radius for the current count."""
    return mat_bernstein_radius(self.n, variance=self.variance_matrix, alpha=self.alpha)


class MatrixWSREbRadius(seq_radius):
  """Sequential matrix empirical Bernstein radius with predictable plug-in lambdas.

  Matrix analogue of WSREbRadius (presumably the Waudby-Smith & Ramdas
  predictable plug-in construction -- TODO confirm against the paper). At step t
  it uses the estimates from step t-1 (prev_mubar, prev_vbar) to choose
  lambda_t, then accumulates
    s_lambda       = sum_t lambda_t
    s_v_psi_lambda = sum_t (x_t - prev_mubar)^2 * psiE(lambda_t)
  and radius() = (log(d/alpha) + ||s_v_psi_lambda||) / s_lambda.

  Args:
    dim: side length d of the observed matrices.
    alpha: error level.
    final_N: total number of observations that will be fed in; it enters
      every lambda (get_Mwsr_radius passes len(X)).
  """
  def __init__(self, dim, alpha=.05, final_N=1000):
    super().__init__(alpha)
    self.dim = dim
    # Plain ndarrays rather than np.asmatrix (np.matrix is discouraged by NumPy
    # and warns on construction); computed values are unchanged.
    self.s_lambda = 0
    self.s_v_psi_lambda = np.zeros((dim,dim))
    self.S = np.zeros((dim,dim))  # running sum of x
    self.V = np.zeros((dim,dim))  # running sum of x @ x
    self.prev_mubar = np.zeros((dim,dim))  # mean estimate from the previous step
    self.prev_vbar = .25  # variance-scale estimate from the previous step
    self.logda = np.log(dim/alpha)
    self.final_N = final_N
  def psiE(self, u):
    """Return -log(1-u) - u; finite only for u < 1."""
    return - np.log(1-u) - u
  def observe(self, x):
    """Record one dim x dim observation and update the running sums.

    Statement order matters: lambda and the centring use the prev_* values,
    which are only overwritten at the very end.

    NOTE(review): lambda is not clipped (the scalar version clips at .75).
    For t >= 2 the floor 5*logda/n on vbar gives lambda_t^2 <= 2(t-1)/(5*final_N),
    which is below 1 while t <= final_N. At t = 1, prev_vbar = .25 gives
    lambda^2 = 8*logda/final_N, which reaches 1 when final_N <= 8*logda and
    makes psiE non-finite. Confirm such small final_N is never used.
    """
    self.n += 1
    self.S += x
    self.V += np.matmul(x, x)
    # biased (1/n) variance matrix; its norm is floored at 5*logda/n
    Vbar = (self.V - np.matmul(self.S, self.S)/self.n) / self.n
    vbar = np.max([ opnorm(Vbar), 5*self.logda/self.n ])
    mubar = self.S/self.n

    # predictable lambda: uses only the previous step's variance estimate
    lamb = np.sqrt(
        2*self.logda/(self.final_N*self.prev_vbar)
    )

    self.s_lambda += lamb
    # centre x at the previous step's mean estimate (also predictable)
    self.s_v_psi_lambda += np.matmul(x-self.prev_mubar, x-self.prev_mubar)*self.psiE(lamb)
    self.prev_mubar = mubar
    self.prev_vbar = vbar
  def radius(self):
    """Current radius: (log(d/alpha) + ||s_v_psi_lambda||) / s_lambda."""
    return mat_wsr_eb_radius(self.s_lambda, self.s_v_psi_lambda, self.alpha)


def get_Mwsr_radius(X: list, dim, alpha=.05):
  """Feed every matrix in X to a fresh MatrixWSREbRadius (final_N = len(X)); return its radius."""
  total = len(X)
  tracker = MatrixWSREbRadius(dim, alpha, total)
  for sample in X:
    tracker.observe(sample)
  return tracker.radius()

# Matrix experiment: 3x3 random matrices x = a*P1 + b*P2 + c*P3, where a, b, c
# are independent np.random.rand() draws and P_i = u_i u_i^T. At each checkpoint
# n, print every radius divided by the oracle matrix Bernstein radius.
sample_sizes = [ 100, 1000, 10000, 100000, 1000000 ]
X = []

# get 3 orthonormal vectors of 3 dimension
u1 = np.array([ [1/np.sqrt(2)], [1/np.sqrt(2)], [0] ])
u2 = np.array([ [0], [0], [1] ])
u3 = np.array([ [1/np.sqrt(2)], [-1/np.sqrt(2)], [0] ])
# The rank-one projectors are loop-invariant. Build them once instead of
# recomputing three outer products for every one of the 10^6 samples.
P1 = np.matmul(u1, u1.T)
P2 = np.matmul(u2, u2.T)
P3 = np.matmul(u3, u3.T)
# Each coefficient is Uniform[0,1) (variance 1/12) and the u_i are orthonormal,
# so the true variance matrix of x is (1/12)(P1 + P2 + P3).
variance_matrix = (1/12)* (P1 + P2 + P3)

mp = SharpMatrixMaurerPontilRadius(dim=3)
mpv = SharpMatrixMaurerPontilRadiusVar(dim=3)
b = MatrixBernsteinRadius(variance_matrix=variance_matrix)

for n in range(1,sample_sizes[-1]+1):
  # the three draws are consumed left to right, in the same order as before
  x = P1*np.random.rand() + P2*np.random.rand() + P3*np.random.rand()
  mp.observe(x)
  mpv.observe(x)
  b.observe(x)
  X.append(x)
  if n in sample_sizes:
    # WSR is recomputed from scratch on the first n samples (final_N = n)
    print(f"n={n}, MP-to-Bern={mp.radius()/b.radius()}, MP-var-to-Bern={mpv.radius()/b.radius()}, WSR-to-Bern={get_Mwsr_radius(X, dim=3)/b.radius()}")
