# File bernstein_comparision.py

"""
We have a neural network that supposedly provides a sufficient condition for
p(x,y) >= 0: the polynomial is nonnegative whenever the network's output is PSD.

If p = sum alpha_i * B_i in the Bernstein basis has nonnegative alpha_i,
then p is nonnegative in [0,1]. 

"""

import sympy
import torch
import scipy.interpolate
import numpy as np
from scipy.special import binom

# THIS FUNCTION IS ADAPTED FROM
# https://github.com/caslabuiowa/BeBOT/blob/10fe7a8d21e75ecf20d42d58ed95462446aa9863/polynomial/bernstein.py
# That codebase is licensed under GNU GENERAL PUBLIC LICENSE  Version 3, 29 June 2007
def elevMatrix(N, R=1):
    """Build the degree-elevation matrix for a Bezier curve.

    The result T has shape (N+1, N+R+1) and entries

        T[j, i] = binom(N, j) * binom(R, i - j) / binom(N + R, i)

    Multiplying degree-N Bezier coefficients by T (as a dot product,
    B_(N) * T) yields the coefficients of the same curve in degree N+R.

    :param N: Degree of the Bezier curve being elevated
    :type N: int
    :param R: Number of degrees to elevate the Bezier curve
    :type R: int
    :return: Elevation matrix to raise a Bezier curve of degree N by R degrees
    :rtype: numpy.ndarray
    """
    T = np.zeros((N + 1, N + R + 1))
    # Column vector of source indices j and row vector of target indices i;
    # broadcasting evaluates the entry formula for every (j, i) pair at once.
    src = np.arange(N + 1).reshape(-1, 1)
    dst = np.arange(N + R + 1).reshape(1, -1)
    # binom(R, i - j) is zero whenever i - j falls outside 0..R, which leaves
    # the entries outside the band at zero.
    T[:, :] = binom(N, src) * binom(R, dst - src) / binom(N + R, dst)
    return T


def solve_bernstein(p, maxR=None):
    """
    Sufficient (not necessary) nonnegativity test via Bernstein coefficients.

    Input is an array of polynomial coefficients as
    coefficients of monomials y^d, y^(d-1) * x, ..., x^d

    This polynomial is converted to a univariate polynomial on the
    interval [0,1] with the same sign by substituting x -> 2x-1 and
    y -> x(1-x). For 0 < x < 1 the ratio (2x-1)/(x(1-x)) goes from -Inf to
    Inf, and on [0,1] the denominator x(1-x) is nonnegative, so multiplying
    through by its d-th power does not change the sign.

    The result is then written in the Bernstein basis with scipy.interpolate
    and degree-elevated by maxR.

    :param p: Coefficients of the homogeneous polynomial, p[i] multiplying
        x^i * y^(d-i) with d = len(p) - 1
    :param maxR: Number of degrees to elevate the Bernstein representation
        before checking signs; defaults to 4 * d
    :return: True if and only if all of the Bernstein coefficients are
        nonnegative after elevating degree by maxR (as returned by numpy.all)
    :raises ValueError: if p is empty
    """
    if len(p) == 0:
        raise ValueError("p must contain at least one coefficient")

    x, y = sympy.symbols("x y")
    n = len(p) - 1
    # sympify so that a degree-0 input (a bare number) still supports .subs
    ppoly = sympy.sympify(sum(p[i] * x**i * y**(n - i) for i in range(n + 1)))
    transformed = ppoly.subs(x, 2*x - 1).subs(y, x*(1 - x))  # to interval poly

    # Name the generator explicitly: without it sympy.Poly refuses constant
    # (including identically zero) expressions.
    coeffs = [float(c) for c in sympy.Poly(transformed, x).all_coeffs()]

    # The transformed polynomial has degree 2n only when p[0] != 0; otherwise
    # it is shorter. Pad with leading zeros so the Bernstein representation
    # always has degree 2n and matches the elevation matrix below.
    degree = 2 * n
    coeffs = [0.0] * (degree + 1 - len(coeffs)) + coeffs
    p2coeffs = np.array(coeffs, dtype=float).reshape(-1, 1)

    # PPoly expects the highest power first, which is the all_coeffs order.
    powerpoly = scipy.interpolate.PPoly(p2coeffs, [0, 1])
    bernsteinpoly = scipy.interpolate.BPoly.from_power_basis(powerpoly)

    if maxR is None:
        maxR = 4 * n
    elevated = np.transpose(bernsteinpoly.c) @ elevMatrix(degree, maxR)
    return np.all(elevated >= 0)

def search_smallest_deg_elev(bernsteinc):
    """
    Placeholder -- not implemented yet.

    The original note reads "Find the smallest R such that" and stops there.
    Presumably R is a degree elevation applied to the Bernstein coefficients
    `bernsteinc` -- TODO confirm the intended condition and implement.

    For now the argument is ignored and None is returned.
    """
    return None

def ispsd(M, cutoff=0):
    """
    Checks if a Symmetric(!) matrix is positive semidefinite

    Symmetry is assumed, not verified. The test is that every eigenvalue is
    at least `cutoff`. A batch of matrices is accepted as well, in which case
    the result is True only if every matrix in the batch passes.

    :param M: Symmetric matrix (or batch of symmetric matrices) as a tensor
    :param cutoff: Smallest eigenvalue that still counts as nonnegative
    :return: True if all eigenvalues are >= cutoff
    :rtype: bool
    """
    # Only the eigenvalues are needed, so skip computing eigenvectors.
    eigenvalues = torch.linalg.eigvalsh(M)
    # torch.all reduces over every dimension in one call; the builtin all()
    # iterates in Python and fails on batched (>1-D) eigenvalue tensors.
    return bool(torch.all(eigenvalues >= cutoff))
