import torch


class OnlineApprox:
    """Online approximation of a signal by projection onto a finite basis.

    The tracked signal is approximated as a linear combination of ``order``
    basis functions. The coefficients follow the linear ODE

        coeff' = -A_mod(t) * (A @ coeff) + B_mod(t) * (B * x)

    which is integrated with one explicit Euler step of size ``opt['tau']``
    each time a new sample ``x`` is passed to :meth:`update_coefficients`.

    Subclasses must implement :meth:`computeA`, :meth:`computeB`,
    :meth:`computeA_mod`, :meth:`computeB_mod` and :meth:`base_fun`.

    Args:
        order: number of basis functions ``N`` used in the approximation.
        opt: dictionary of options. ``opt['tau']`` is the Euler step size;
            the whole dict is kept in ``self.opt`` for subclasses (e.g. for
            numerical-integration settings).
        device: torch device on which the state tensors are allocated.
    """

    def __init__(self, order, opt, device='cpu'):
        self.device = device
        self.opt = opt
        self.N = order
        self.A = self.computeA()
        self.B = self.computeB()
        self.t = 1.  # step counter; the current time is t * tau
        self.tau = opt['tau']
        self.coeff = None  # allocated lazily on the first update (shape depends on x)

    def computeA(self):
        """Return the (N, N) state matrix ``A``."""
        raise NotImplementedError("Implement this in a son class!")

    def computeB(self):
        """Return the (N,) input vector ``B``."""
        raise NotImplementedError("Implement this in a son class!")

    def computeA_mod(self):
        """Return the time-dependent scalar multiplying ``A @ coeff``."""
        raise NotImplementedError("Implement this in a son class!")

    def computeB_mod(self):
        """Return the time-dependent scalar multiplying ``B * x``."""
        raise NotImplementedError("Implement this in a son class!")

    def euler_step(self, variable, variable_dot):
        """Return ``variable + tau * variable_dot`` (one explicit Euler step)."""
        return variable + self.tau * variable_dot

    def update_coefficients(self, x):
        """Advance the coefficients by one Euler step using the sample ``x``.

        Args:
            x: current value(s) of the tracked signal: a Python number or a
                tensor of any shape. Every element is approximated
                independently, so ``self.coeff`` has shape ``x.shape + (N,)``.
        """
        # Accept Python scalars and samples living on another device.
        x = torch.as_tensor(x, device=self.device)
        if self.coeff is None:
            # One length-N coefficient vector per element of the input.
            self.coeff = torch.zeros(x.shape + (self.N,), device=self.device)

        # The coefficients get promoted to the sample's dtype (e.g. float64)
        # after the first step; matrix products need both operands to share
        # a dtype, so bring A along.
        A = self.A.to(dtype=self.coeff.dtype)

        # coeff_dot = -A_mod * (A @ coeff) + B_mod * (B * x)
        decay = -self.computeA_mod() * torch.einsum('...ij,...j->...i', A, self.coeff)
        # B has shape (N,) and x[..., None] has shape x.shape + (1,); their
        # product is x.shape + (N,) for every input rank, including 0-d.
        drive = self.computeB_mod() * (self.B * x.unsqueeze(-1))

        self.coeff = self.euler_step(self.coeff, decay + drive)
        self.t += 1

    def compute_approximation(self, s):
        """Evaluate the current approximation at time ``s``.

        Args:
            s: evaluation time; must lie in ``[0, tau * t]``.

        Returns:
            ``coeff @ base_fun(s)``: a tensor with the shape of the samples
            passed to :meth:`update_coefficients`.

        Raises:
            RuntimeError: if no sample has been processed yet.
        """
        # The independent variable must lie in the domain of the shifted basis.
        assert 0 <= s <= self.tau * self.t
        if self.coeff is None:
            raise RuntimeError("No coefficients yet: call update_coefficients first.")
        # Align the basis dtype with the (possibly promoted) coefficients.
        basis = self.base_fun(s).to(dtype=self.coeff.dtype)
        return torch.matmul(self.coeff, basis)

    def base_fun(self, s):
        """Return the (N,) vector of basis functions evaluated at ``s``."""
        raise NotImplementedError("Implement this in a son class!")



class HippoApprox(OnlineApprox):
    """Online approximation on a scaled, normalized Legendre basis.

    The matrices match the HiPPO-LegS construction:

        A[i, j] = sqrt(2i+1) * sqrt(2j+1)   for i > j
        A[i, i] = i + 1
        A[i, j] = 0                          for i < j
        B[i]    = sqrt(2i+1)

    Both modulators equal ``1 / (t * tau)``. The basis functions are
    ``sqrt(2n+1) * P_n(2s / (t * tau) - 1)``, i.e. Legendre polynomials
    rescaled from ``[-1, 1]`` onto the window ``[0, t * tau]``.
    """

    def __init__(self, order, opt, device='cpu'):
        super().__init__(order, opt, device)

    def computeA(self):
        """Return the (N, N) lower-triangular state matrix described above."""
        scale = torch.sqrt(2 * torch.arange(self.N, device=self.device) + 1)
        # Outer product sqrt(2i+1) * sqrt(2j+1), keeping only the strictly
        # lower-triangular part (i > j); everything with i <= j becomes 0.
        _A = torch.tril(scale.view(-1, 1) * scale.view(1, -1), diagonal=-1)
        # Diagonal entries are i + 1.
        _A.diagonal().copy_(torch.arange(1, self.N + 1, device=self.device, dtype=_A.dtype))
        return _A

    def computeB(self):
        """Return the (N,) input vector B[i] = sqrt(2i+1)."""
        return torch.sqrt(2 * torch.arange(self.N, device=self.device, dtype=torch.float32) + 1)

    def computeA_mod(self):
        """Return the scalar modulator 1 / (t * tau) applied to the A term."""
        return 1. / (self.t * self.tau)

    def computeB_mod(self):
        """Return the scalar modulator 1 / (t * tau) applied to the B term."""
        return 1. / (self.t * self.tau)

    def base_fun(self, s):
        """Return the N normalized shifted Legendre polynomials evaluated at ``s``.

        Args:
            s: evaluation time; must lie in ``[0, tau * t]``.
        """
        assert 0 <= s <= self.tau * self.t
        norm = torch.sqrt(2 * torch.arange(self.N, device=self.device) + 1)
        # Map s from [0, t * tau] onto the Legendre domain [-1, 1].
        return norm * self.LegP(2 * s / (self.t * self.tau) - 1)

    def LegP(self, x):
        """Return the Legendre polynomials P_0(x) ... P_{N-1}(x) as an (N,) tensor."""
        order = torch.arange(self.N, device=self.device)
        return torch.special.legendre_polynomial_p(x, order)





