import numpy as np


#######################
# data structure for quadrature
# nodes and weights
class QUAD:
    """Quadrature rule stored as two 1-D arrays: nodes ``X`` and weights ``w``."""

    def __init__(self, length: int, X=1, w=1):
        """Create a rule with ``length`` points.

        If ``X`` is left as a plain ``int`` (the default), both ``X`` and
        ``w`` arguments are ignored and uninitialised storage for ``length``
        points is allocated.  Otherwise ``X`` and ``w`` must be arrays whose
        ``size`` equals ``length``; flattened copies of them are stored.

        Raises:
            Exception: if the sizes of ``X``, ``w`` and ``length`` disagree.
        """
        if type(X) is int:
            # Placeholder default: nothing to copy, just allocate storage.
            self.set_length(length)
            return

        if X.size != w.size or w.size != length:
            raise Exception("QUAD: dimension mismatch", X.size, w.size, length)
        # flatten() returns copies, so the rule never aliases the inputs.
        self.X = X.flatten()
        self.w = w.flatten()

    def length(self):
        """Return the number of points in the rule."""
        return self.X.size

    def set_length(self, length):
        """Replace nodes and weights with uninitialised arrays of ``length``."""
        self.X = np.empty(length)
        self.w = np.empty(length)

    def deepcopy(self):
        """Return a new QUAD holding copies of this rule's nodes and weights."""
        return QUAD(self.length(), self.X, self.w)

    def print(self):
        """Write the point count, weights and nodes to standard output."""
        print("Quadrature rule with ", self.length(), " points.")
        print("Weights=", self.w)
        print("Nodes=", self.X)

    def do_quad(self, test):
        """Apply the rule to ``test``: sum of ``test(X[i]) * w[i]`` over all points.

        ``test`` is called once per node with that single node value.
        """
        total = 0.0
        for i, node in enumerate(self.X):
            total += test(node) * self.w[i]
        return total


###########################
def simplify(quad: "QUAD", m: int, cutoff=0):
    """Reduce ``quad`` in place using an m-point Gauss quadrature (GQ) rule.

    Without a cut-off (``cutoff`` is an int), the whole rule is replaced by
    the m-point rule obtained from ``lanczos`` followed by ``gauss``.

    With a cut-off pair, nodes ``<= cutoff[0]`` are lumped into a single node
    at ``cutoff[0]`` and nodes ``>= cutoff[1]`` into a single node at
    ``cutoff[1]``, each carrying the summed weight of its group.  Only the
    nodes strictly between the two cut-offs are replaced by the m-point rule,
    so the result has ``m + 2`` points.

    The rule is left untouched when the part that would be reduced (the whole
    rule, or the interior when a cut-off is used) has no more than ``m``
    nodes.

    Args:
        quad: rule to reduce; modified in place.
        m: number of Gauss points used for the reduced part.
        cutoff: any int (default ``0``) disables the cut-off; otherwise an
            indexable pair ``(lower, upper)``.

    Returns:
        None.
    """
    use_cutoff = not isinstance(cutoff, int)

    if use_cutoff:
        lower, upper = cutoff[0], cutoff[1]
        small = quad.X <= lower  # nodes lumped at the lower cut-off
        big = quad.X >= upper  # nodes lumped at the upper cut-off
        interior = (quad.X > lower) & (quad.X < upper)
        if np.count_nonzero(interior) <= m:
            # Interior already short enough.  Returning here (rather than
            # falling through with a stale index) guarantees the GQ is never
            # built from the lumped outer nodes.
            return
        w1 = np.sum(quad.w[small])
        w2 = np.sum(quad.w[big])
        nodes, weights = quad.X[interior], quad.w[interior]
    else:
        if quad.length() <= m:
            return  # rule already short enough: no action
        nodes, weights = quad.X, quad.w

    # m-point GQ for the selected nodes; np.array(...) builds a fresh
    # (nodes, weights) table, so quad's own arrays are not touched here.
    ab = lanczos(m, np.array([nodes, weights]).T)
    xw = gauss(m, ab)

    # Store the reduced rule back into quad.
    if use_cutoff:
        quad.set_length(m + 2)
        quad.X[0] = cutoff[0]
        quad.w[0] = w1
        quad.X[m + 1] = cutoff[1]
        quad.w[m + 1] = w2
        quad.X[1:m + 1] = xw[:, 0]
        quad.w[1:m + 1] = xw[:, 1]
    else:
        quad.set_length(m)
        quad.X[0:m] = xw[:, 0]
        quad.w[0:m] = xw[:, 1]


##############################################
def lanczos(N, xw):
    """Lanczos-type recursion for a discrete measure (after Gautschi's OPQ).

    Args:
        N: number of coefficient rows to return.
        xw: two-column array; column 0 holds the nodes and column 1 the
            weights.  It is not modified.

    Returns:
        Array with two columns and ``min(N, rows of xw)`` rows, in the layout
        consumed by ``gauss``: column 0 becomes the diagonal of the Jacobi
        matrix, column 1 the squared off-diagonal entries, and entry
        ``[0, 1]`` is the total weight.
    """
    xw = np.asarray(xw)
    Ncap = np.size(xw, 0)  # number of rows
    # Work on a float copy of the nodes: a plain slice would be a view, so
    # the updates below would overwrite the caller's array (and be truncated
    # for integer input).
    p0 = np.array(xw[:, 0], dtype=float)
    p1 = np.zeros(Ncap)
    p1[0] = xw[0, 1]  # first weight
    for n in range(Ncap - 1):
        # Fold point n + 1 into the coefficients computed so far.
        pn = xw[n + 1, 1]
        gam = 1
        sig = 0
        t = 0
        xlam = xw[n + 1, 0]
        # Only the first N coefficients are requested, and entry k depends
        # only on entries <= k, so the sweep can stop at N.
        for k in range(min(n + 2, N)):
            rho = p1[k] + pn
            tmp = gam * rho
            tsig = sig  # sig from the previous k, needed if the new sig is 0

            if rho <= 0:
                gam = 1
                sig = 0
            else:
                gam = p1[k] / rho
                sig = pn / rho
            tk = sig * (p0[k] - xlam) - gam * t
            p0[k] = p0[k] - (tk - t)
            t = tk
            if sig <= 0:
                pn = tsig * p1[k]
            else:
                pn = (t ** 2) / sig
            p1[k] = tmp
    return np.array([p0[:N], p1[:N]]).T


######################################
######################################
def gauss(N, ab):
    """Build an N-point rule from the coefficient table ``ab``.

    Adapted from Gautschi's OPQ Matlab routines.  A symmetric tridiagonal
    matrix is formed with ``ab[:N, 0]`` on the diagonal and
    ``sqrt(ab[1:N, 1])`` on the two off-diagonals.  Its eigenvalues are the
    nodes; the weights are ``ab[0, 1]`` times the squared first components
    of the eigenvectors.

    Args:
        N: target length of the rule.
        ab: two-column array with at least N rows (presumably the
            coefficients produced by ``lanczos`` -- confirm with callers).

    Returns:
        Array of shape (N, 2): column 0 nodes, column 1 weights.

    Raises:
        Exception: if ``ab`` has fewer than N rows.
    """
    rows = np.size(ab, 0)
    if rows < N:
        raise Exception("gauss: input array ab too short")

    jacobi = np.zeros([N, N])
    idx = np.arange(N)
    jacobi[idx, idx] = ab[:N, 0]
    off_diag = np.sqrt(ab[1:N, 1])
    jacobi[idx[1:], idx[:-1]] = off_diag  # sub-diagonal
    jacobi[idx[:-1], idx[1:]] = off_diag  # super-diagonal (symmetric)

    nodes, vecs = np.linalg.eigh(jacobi)
    weights = ab[0, 1] * vecs[0, :] ** 2
    return np.array([nodes, weights]).T


########################### 
def compress(y, w, m):
    """Compress the rule with nodes ``y`` and weights ``w`` to ``m`` points.

    The rule is wrapped in a QUAD and passed to ``simplify`` without a
    cut-off.  Points whose weight is not strictly positive are then dropped.

    Args:
        y: array of nodes.
        w: array of weights, same size as ``y``.
        m: target number of points.

    Returns:
        Tuple ``(yc, qc)`` of column vectors of shape (n, 1) holding the
        retained nodes and their weights.  The weights are returned as they
        are; no logarithm is applied.
    """
    quad = QUAD(y.size, y, w)
    print("compress to ", m, "points with GQ")
    simplify(quad, m)

    # Keep strictly positive weights only.
    thresh = 0.0
    keep = quad.w > thresh
    yc = quad.X[keep].reshape(-1, 1)
    qc = quad.w[keep].reshape(-1, 1)
    return yc, qc
