import numpy as np
import math
import cvxpy as cp


class Environment:
    """Scores a spline for the Interpolator.

    The cost of a spline is its squared-error data fit plus ``eta`` times its
    roughness, where roughness is the integral of the squared ``rho``-th
    derivative summed over all polynomial pieces.
    """
    def __init__(self,
                 eta,
                 rho):
        # eta: weight of the roughness term; rho: derivative order being penalised
        self.eta = eta
        self.rho = rho

    def cost(self,spline,data,avg=True):
        """Total cost: loss + eta * roughness (both per piece when avg is True)."""
        fit_term = self.loss(spline,data,avg=avg)
        return fit_term + self.eta * self.roughness(spline,avg=avg)

    def loss(self,spline,data,avg=True):
        """Sum of squared residuals of the spline at data['time stamps'].

        Divided by the number of pieces only when avg is exactly True."""
        predicted = spline.reconstruct_spline(data['time stamps'])
        total = sum( (predicted - data['signal values'])**2 )
        return total / len(spline.pieces) if avg is True else total

    def roughness(self,spline,avg=True):
        """Sum of the roughness of every PolynomialPiece of the spline.

        When avg is exactly True the sum is divided by the number of pieces."""
        total = sum(self.piece_roughness(segment) for segment in spline.pieces)
        return total / len(spline.pieces) if avg is True else total

    def piece_roughness(self,piece):
        """Quadratic form coefficients^T R coefficients for a single piece."""
        span = piece.end_point - piece.start_point
        weights = Environment.piece_roughness_matrix(len(piece.coefficients),span,self.rho)
        return piece.coefficients.T @ weights @ piece.coefficients

    @staticmethod
    def piece_roughness_matrix(dim,u,rho):
        """Roughness matrix of a polynomial piece.

        Entry (r, c) is the integral over [0, u] of the product of the rho-th
        derivatives of x**r and x**c; rows/columns below rho are zero.
        Args:
            dim (int): dimensionality (order of the polynomial + 1)
            u : time section length (x - x_)
            rho : derivative order
        Returns:
            (matrix): roughness matrix"""
        matrix = np.zeros(( dim,dim ))
        for r in range(dim):
            for col in range(dim):
                if min(r, col) <= rho - 1:
                    continue  # derivative of order rho vanishes -> entry stays 0
                factor = math.prod((r - i) * (col - i) for i in range(rho))
                power = r + col - 2*rho + 1
                matrix[r][col] = factor * u**power / power
        return matrix

    @staticmethod
    def piece_roughness_tensor(dim,time_stamps,rho):
        """Stacks one roughness matrix per consecutive pair of time stamps.

        Returns an array of shape (len(time_stamps) - 1, dim, dim)."""
        gaps = time_stamps[1:] - time_stamps[:-1]
        tensor = np.zeros(( len(gaps),dim,dim ))
        for slot, gap in enumerate(gaps):
            tensor[slot] = Environment.piece_roughness_matrix(dim,gap,rho)
        return tensor
    
    
# PARENT CLASS
class Interpolator:
    """Base class for spline interpolators.

    Fits piecewise polynomials by solving a convex program that minimises the
    squared reconstruction error plus ``eta`` times the roughness of the
    ``rho``-th derivative, while forcing the value and the first ``smooth``
    derivatives to be continuous at every knot.
    """
    def __init__(self,
                 order,
                 smooth,
                 eta,
                 rho,
                 origin=None):
        """
        Args:
            order (int): polynomial order of each piece
            smooth (int): number of derivatives kept continuous at the knots
            eta : weight of the roughness term in the cost
            rho (int): derivative order penalised by the roughness term
            origin (dict, optional): {'x0': ..., 'y0': ...} anchor of the first
                piece; defaults to {'x0': 0, 'y0': 0}
        """
        self.order = order
        self.smooth = smooth
        self.eta = eta
        self.rho = rho
        # a fresh dict per instance: avoids a shared mutable default argument
        self.origin = {'x0': 0, 'y0': 0} if origin is None else origin
        # False until the first successful solve; afterwards a subclass with a
        # SignalState anchors new problems at that state instead of the origin
        self.origin_flag = False
        self.spline = self.Spline()

    def interpolate(self,time_stamps,signal_values):
        """Solves the interpolation optimization problem.

        One piece is fitted per time stamp: piece t spans
        [augmented_time_stamps[t], augmented_time_stamps[t+1]], where the
        augmented grid is time_stamps prefixed with the current pivot (the
        origin until the first successful solve, the signal-state pivot after).
        Args:
            time_stamps (array): right end points of the pieces
            signal_values (array): target value at each time stamp
        Returns:
            (matrix) [len(time_stamps) -x- order+1] or None: coefficients of each
            piece in the basis (x - x_)**i, or None when no solution was found
        """
        dim = self.order + 1
        phi = self.smooth + 1
        number_pieces = len(time_stamps)
        piece_basis = Interpolator.Spline.PolynomialPiece.piece_basis
        A = cp.Variable(( number_pieces,dim )) ### optimization variable

        # Subclasses without a SignalState (e.g. batch fitting) always start at the origin.
        state = getattr(self, 'state', None)
        if self.origin_flag is False or state is None:
            pivot = np.reshape(self.origin['x0'], (1,))
            constraints = [ A[0][0] == self.origin['y0'] ]
        else:
            pivot = np.reshape(state.signal_state['x_'], (1,))
            constraints = [ A[0][:phi] == state.signal_state['cc'] ]
        augmented_time_stamps = np.concatenate(( pivot, np.asarray(time_stamps) ))

        t_roughness = Environment.piece_roughness_tensor(dim,augmented_time_stamps,self.rho)
        cost = 0
        for t in range(number_pieces):
            # data-fit term: squared error of piece t at its right end point
            basis = piece_basis(dim,augmented_time_stamps[t],augmented_time_stamps[t+1])
            loss = cp.square( A[t] @ basis - signal_values[t] )
            cost += loss + self.eta * cp.quad_form( A[t],t_roughness[t] )
            if t >= 1:
                # continuity at the shared knot: the k-th coefficient of piece t
                # equals the k-th derivative of piece t-1 there, divided by k!
                for k in range(phi):
                    basis = piece_basis(dim,augmented_time_stamps[t-1],augmented_time_stamps[t],k=k)
                    constraints += [ A[t][k] == A[t-1] @ basis / math.factorial(k) ]
        problem = cp.Problem( cp.Minimize(cost),constraints )
        try:
            problem.solve(solver=cp.ECOS)
        except cp.SolverError:
            # ECOS unavailable or failed: fall back to SCS
            problem.solve(solver=cp.SCS)
        m_coefficients = A.value
        if m_coefficients is not None:
            # leave the origin only once a solution actually exists
            self.origin_flag = True
        return m_coefficients

    # INNER CLASSES
    class Spline:
        """The Spline object is intended to be accessed/modified from the above class"""
        def __init__(self):
            self.reset_spline()

        def reset_spline(self):
            """the spline is constructed as a list of polynomial pieces"""
            self.pieces = []

        def append_piece(self,coefficients,start_point,end_point):
            """the polynomial piece is created and added by the interpolator"""
            piece = self.PolynomialPiece(coefficients,start_point,end_point)
            self.pieces.append( piece )

        def reconstruct_spline(self,x,k=0):
            """Reconstructs the spline (or its k-th derivative) at the points x.

            Points covered by no piece evaluate to 0; at a knot shared by two
            pieces the piece appended last wins.
            Args:
                x (array or scalar): time instants where to evaluate
                k (int): order of the derivative
            Returns:
                (array) [x.size]: spline evaluations"""
            x = np.asarray(x)
            # scalars / single-element inputs become a 1-element vector
            if x.size == 1:
                x = x.reshape((1,))
            f = np.zeros(x.size)
            for piece in self.pieces:
                idxs = (x>=piece.start_point)&(x<=piece.end_point)
                f[idxs] = piece.reconstruct_piece(x[idxs],k=k)
            return f

        ## INNER (NESTED) CLASS
        class PolynomialPiece:
            """The PolynomialPiece object is intended to be accessed/modified from the above classes"""
            def __init__(self,
                         coefficients,
                         start_point,
                         end_point):
                self.coefficients = coefficients
                self.start_point = start_point
                self.end_point = end_point

            def reconstruct_piece(self,x,k=0):
                """Evaluates the polynomial piece (or its k-th derivative) -- CAUTION: even outside of its domain --
                Args:
                    x (array): time instants where to evaluate
                    k (int): order of the derivative
                Returns:
                    (array) [len(x) -x- 1]: evaluation of the polynomial piece at time instants in x"""
                dim = len(self.coefficients)
                m_basis = Interpolator.Spline.PolynomialPiece.piece_basis(dim,self.start_point,x,k=k)
                return self.coefficients.T @ m_basis

            @staticmethod
            def piece_basis(dim,x_,x,k=0):
                """Piece (polynomial) basis (for any kth derivative) array-evaluated
                Args:
                    dim (int): dimensionality (order of the polynomial + 1)
                    x_ : pivot point in the polynomial point value representation
                    x (array or scalar): time instants (flattened)
                    k (int): order of the derivative
                Returns:
                    m_basis (matrix) [dim -x- x.size]: basis matrix (1 column for each time instant)
                """
                # accept Python scalars / lists as well as numpy arrays and scalars
                x = np.asarray(x).reshape(-1)
                m_basis = np.zeros(( dim, x.size ))
                for i in range(k,dim):
                    # i*(i-1)*...*(i-k+1): factor produced by differentiating (x-x_)**i k times
                    c = 1
                    for j in range(k):
                        c *= (i - j)
                    m_basis[i] = c * (x - x_) ** (i - k)
                return m_basis

# CHILDREN CLASSES 
class OnlineInterpolator(Interpolator):
    """Interpolator that builds the spline one piece at a time.

    Each step solves the interpolation problem over all supplied time stamps,
    commits only the first resulting piece, and updates the SignalState so the
    next piece starts where the committed one ended.
    """

    def __init__(self,
                 order,
                 smooth,
                 rho, ## needed for the policy
                 eta,
                 origin=None):
        """Note: (rho, eta) order here differs from Interpolator's (eta, rho).

        origin defaults to {'x0': 0, 'y0': 0}; a fresh dict is built per
        instance to avoid a shared mutable default argument."""
        if origin is None:
            origin = {'x0': 0, 'y0': 0}
        super().__init__(order,smooth,eta,rho,origin)
        self.state = self.SignalState(smooth,origin)

    def forward(self,data):
        """Commits the next spline piece.

        Args:
            data (dict): 'time stamps' and 'signal values'; the first time stamp
                is the end point of the new piece
        Returns:
            (bool): True if a piece was appended, False if no solution was found
        """
        ## evaluate policy (retrieve the coefficients of the next piece)
        coefficients = self.policy(data)
        if coefficients is None:
            return False

        ## construct/update the corresponding spline piece and signal state
        start_point = self.state.signal_state['x_']
        end_point = data['time stamps'][0]
        self.spline.append_piece(coefficients,start_point,end_point)
        self.state.update_signal_state( self.spline.pieces[-1] )

        return True

    def policy(self,data):
        """Coefficients of the first piece of the optimal fit, or None if unsolved."""
        time_stamps = data['time stamps']
        signal_values = data['signal values']
        m_coefficients = self.interpolate(time_stamps,signal_values)
        if m_coefficients is None:
            return None
        return m_coefficients[0]

    ## INNER CLASS
    class SignalState:
        """Pivot point and continuity targets for the next piece.

        The SignalState object is intended to be accessed/modified from the above class"""
        def __init__(self,
                     smooth,
                     origin):
            self.smooth = smooth
            self.reset_signal_state(origin)

        def reset_signal_state(self,origin):
            """Resets the state to the origin: value y0 and zero derivatives."""
            signal_state = {}
            signal_state['x_'] = np.array(origin['x0'])
            ## cc: continuity constraints (value, then k-th derivative / k! for k = 1..smooth)
            signal_state['cc'] = np.concatenate(( np.array(origin['y0']).reshape(1,) , np.zeros(self.smooth) ))
            self.signal_state = signal_state

        def update_signal_state(self,piece):
            """Moves the pivot to the piece's end point and records its value and
            derivatives (each divided by k!) there as the next continuity targets."""
            dim = len(piece.coefficients)
            continuity_constraints = np.zeros(self.smooth + 1)
            for i in range(self.smooth + 1):
                basis = piece.piece_basis(dim,piece.start_point,piece.end_point,k=i)
                # the product is a size-1 array; extract the scalar explicitly rather
                # than relying on NumPy's deprecated array-to-scalar conversion
                value = np.ravel(piece.coefficients @ basis)[0]
                continuity_constraints[i] = value / math.factorial(i)
            self.signal_state['cc'] = continuity_constraints
            self.signal_state['x_'] = piece.end_point

class BatchInterpolator(Interpolator):
    """Interpolator that fits every piece of the spline in one optimization."""
    def __init__(self,
                 order,
                 smooth,
                 rho,
                 eta,
                 origin=None):
        """Note: (rho, eta) order here differs from Interpolator's (eta, rho).

        origin defaults to {'x0': 0, 'y0': 0}; a fresh dict is built per
        instance to avoid a shared mutable default argument."""
        if origin is None:
            origin = {'x0': 0, 'y0': 0}
        super().__init__(order,smooth,eta,rho,origin)

    def solve(self,data):
        """Fits all pieces at once and rebuilds the spline.

        Args:
            data (dict): 'time stamps' and 'signal values'
        Returns:
            (bool): True if the spline was rebuilt, False if the solver found no
            solution (the existing spline is then left unchanged)
        """
        time_stamps = data['time stamps']
        signal_values = data['signal values']
        # A batch fit always starts at the origin (see instantiate). Without this
        # reset, a second call would take the signal-state branch of interpolate,
        # which needs a `state` attribute this class never defines.
        self.origin_flag = False
        m_coefficients = self.interpolate(time_stamps,signal_values)
        if m_coefficients is None:
            return False
        self.instantiate(time_stamps,m_coefficients)
        return True

    def instantiate(self,time_stamps,m_coefficients):
        """Instantiates the spline from the time stamps and the coefficients.

        Piece t spans [augmented[t], augmented[t+1]], where augmented is
        time_stamps prefixed with the origin's x0."""
        self.spline.reset_spline()
        augmented_time_stamps = np.concatenate(( np.array(self.origin['x0']).reshape(1,),time_stamps ))
        for t,coefficients in enumerate(m_coefficients):
            self.spline.append_piece(coefficients,augmented_time_stamps[t],augmented_time_stamps[t+1])
