import torch
import math
def fractional_pow(base, exponent, eps=1e-4):
    """Smoothed power ``(base + eps) ** exponent`` computed as ``exp(exponent * log(base + eps))``.

    The ``eps`` shift keeps the logarithm finite at ``base == 0``, so the
    derivative with respect to ``exponent`` (``result * log(base + eps)``) stays
    finite. The price is a small bias: for ``base == 0`` the result is
    ``eps ** exponent`` rather than ``0``.

    Args:
        base: tensor or Python number; expected to be >= 0 (values below
            ``-eps`` produce NaN from the logarithm).
        exponent: tensor or Python number.
        eps: non-negative shift added to ``base`` before the logarithm
            (default ``1e-4``, the previously hard-coded value).

    Returns:
        Tensor with the broadcast shape of ``base`` and ``exponent``.
    """
    # torch.log rejects Python scalars; as_tensor is a no-op for tensor input.
    base = torch.as_tensor(base)
    return torch.exp(exponent * torch.log(base + eps))

def fractional_pow_taylor(base, exponent, n_terms=4):
    """Approximate ``base ** exponent`` with a truncated Taylor series of ``exp``.

    Writes ``base ** exponent = exp(x)`` with ``x = exponent * log(base)`` and
    returns ``sum_{n=0}^{n_terms} x**n / n!``. The truncation error grows with
    ``|x|``, so this is only accurate when ``base ** exponent`` is close to 1.

    Args:
        base: tensor of positive values (``log`` is taken elementwise).
        exponent: tensor or Python number broadcastable against ``base``.
        n_terms: highest power of ``x`` kept in the series (``0`` returns ones).

    Returns:
        Tensor shaped like ``base`` (broadcast with ``exponent``).
    """
    # The argument of exp is the same for every term, so compute it once.
    x = exponent * torch.log(base)
    result = torch.ones_like(base)
    term = torch.ones_like(base)
    for n in range(1, n_terms + 1):
        # Recurrence term_n = term_{n-1} * x / n  ==>  term_n = x**n / n!
        term = term * x / n
        result = result + term
    return result

def caputoEuler_graphcon(alpha, f, y0, tspan, device):
    """Integrate the Caputo equation ``D^alpha y(t) = f(t, y)`` with the
    fractional forward-Euler (one-step Adams-Bashforth) rule.

    With ``N = len(tspan)`` and ``h = (tspan[-1] - tspan[0]) / (N - 1)``, step
    ``k = 0 .. N-1`` computes

        y_{k+1} = y0 + 1/Gamma(alpha) * sum_{j=0}^{k} b_{j,k+1} f(t_j, y_j),
        b_{j,k+1} = h^alpha / alpha * ((k + 1 - j)^alpha - (k - j)^alpha).

    Every power is evaluated through ``fractional_pow``, i.e. as
    ``(x + 1e-4)^alpha``; this keeps gradients w.r.t. ``alpha`` finite but
    slightly biases the weights (the ``(k - j)^alpha`` term is not exactly 0
    at ``j == k``).

    Args:
      alpha: fractional order in (0, 1), given as a torch tensor
        (``torch.lgamma`` is applied to it).
      f: callable ``f(t, y)``; ``t`` is passed as a 0-dim tensor on ``device``
        and ``y`` is the current state.
      y0: initial state tensor; it is moved to ``device`` and cast to float32.
      tspan: 1-D tensor of equally spaced time points; ``tspan[0]`` is the
        time of ``y0``.
      device: torch device used for the state and the weights.

    Returns:
      Only the final state, after ``N`` steps, i.e. at ``tspan[0] + N * h``.
      Intermediate states are not returned.

    See also:
      K. Diethelm et al. (2004) Detailed error analysis for a fractional Adams
         method
      C. Li and F. Zeng (2012) Finite Difference Methods for Fractional
         Differential Equations
    """
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)
    gamma_alpha = 1 / torch.exp(torch.lgamma(alpha))
    # h^alpha / alpha is the same for every step, so evaluate it only once.
    step_scale = fractional_pow(h, alpha) / alpha
    y0 = y0.to(device).to(torch.float32)
    fhistory = []  # f(t_j, y_j) for j = 0 .. k
    yn = y0.clone()
    for k in range(N):
        tn = tspan[k].item()
        fhistory.append(f(torch.tensor(tn, device=device), yn))
        j_vals = torch.arange(0, k + 1, dtype=torch.float32, device=device).unsqueeze(1)
        b_j_k_1 = (step_scale * (fractional_pow(k + 1 - j_vals, alpha) - fractional_pow(k - j_vals, alpha))).to(device)
        # Each row b_j_k_1[j] has shape (1,) and broadcasts against fhistory[j].
        weighted = torch.stack([w * f_j for w, f_j in zip(b_j_k_1, fhistory)])
        b_all_k = torch.sum(weighted, dim=0)
        yn = y0 + gamma_alpha * b_all_k
    return yn







def caputoEuler_corrector(alpha, f, y0, tspan, device):
    """Integrate the Caputo equation ``D^alpha y(t) = f(t, y)`` with the
    fractional Adams-Bashforth-Moulton predictor-corrector method.

    With ``N = len(tspan)``, ``h = (tspan[-1] - tspan[0]) / (N - 1)`` and
    ``f_j = f(t_j, y_j)`` evaluated at the corrected states, step ``k``
    (``k = 0 .. N-1``) computes

        predictor: p_{k+1} = y0 + 1/Gamma(alpha) * sum_{j=0}^{k} b_{j,k+1} f_j
        corrector: y_{k+1} = y0 + h^alpha / Gamma(alpha + 2)
                             * (f(t_{k+1}, p_{k+1}) + sum_{j=0}^{k} a_{j,k+1} f_j)

    where
        b_{j,k+1} = h^alpha / alpha * ((k+1-j)^alpha - (k-j)^alpha),
        a_{0,k+1} = k^(alpha+1) - (k-alpha) (k+1)^alpha,
        a_{j,k+1} = (k-j+2)^(alpha+1) + (k-j)^(alpha+1) - 2 (k-j+1)^(alpha+1),  1 <= j <= k.

    Args:
      alpha: fractional order in (0, 1); a Python float or 0-dim tensor.
      f: callable ``f(t, y)`` returning a tensor broadcastable to ``y0.shape``.
        ``t`` is ``tspan[k]``, or ``tspan[-1] + h`` for the last corrector step.
      y0: initial state tensor of any shape (``(1,)``, ``(d,)``, ``(n, d)``, ...);
        it is moved to ``device``.
      tspan: 1-D sequence of at least two equally spaced time points;
        ``tspan[0]`` is the time of ``y0``.
      device: torch device on which the right-hand-side history is stored.

    Returns:
      The corrected state after ``N`` steps, i.e. at ``tspan[0] + N * h``.
      Right-hand-side values are stored in float32.

    Raises:
      ValueError: if ``tspan`` has fewer than two points (the step is undefined).

    See also:
      K. Diethelm et al. (2004) Detailed error analysis for a fractional Adams
         method
      C. Li and F. Zeng (2012) Finite Difference Methods for Fractional
         Differential Equations
    """
    N = len(tspan)
    if N < 2:
        raise ValueError("tspan must contain at least two time points")
    h = (tspan[N - 1] - tspan[0]) / (N - 1)

    gamma_alpha = 1 / math.gamma(float(alpha))
    h_alpha = h ** alpha
    b_scale = h_alpha / alpha  # predictor prefactor
    # Corrector prefactor; multiplied by gamma_alpha it equals h^alpha / Gamma(alpha + 2).
    a_item = h_alpha / (alpha * (alpha + 1))

    y0 = y0.to(device)
    # Per-node weights become (k+1, 1, ..., 1) so they broadcast over the state dims.
    weight_shape = (-1,) + (1,) * y0.dim()
    fhistory = torch.zeros((N, *y0.shape), dtype=torch.float32, device=device)

    yn = y0  # corrected state y_k
    for k in range(N):
        t_k = tspan[k]
        fhistory[k] = f(t_k, yn).to(device)
        f_past = fhistory[:k + 1]
        j = torch.arange(k + 1, dtype=torch.float32, device=device)

        # Predictor: fractional rectangle rule over the corrected history.
        b = b_scale * ((k + 1 - j) ** alpha - (k - j) ** alpha)
        y_pred = y0 + gamma_alpha * torch.sum(b.reshape(weight_shape) * f_past, dim=0)

        # Corrector: fractional trapezoidal rule; the new node uses f at the predictor.
        a = (k + 2 - j) ** (alpha + 1) + (k - j) ** (alpha + 1) - 2 * (k + 1 - j) ** (alpha + 1)
        a[0] = k ** (alpha + 1) - (k - alpha) * (k + 1) ** alpha
        t_next = tspan[k + 1] if k + 1 < N else t_k + h
        f_pred = f(t_next, y_pred).to(device=device, dtype=fhistory.dtype)
        yn = y0 + gamma_alpha * a_item * (torch.sum(a.reshape(weight_shape) * f_past, dim=0) + f_pred)
    return yn



def GL_method(alpha, f, y0, tspan, device):
    """Integrate ``D^alpha y(t) = f(t, y)`` with the explicit Grünwald-Letnikov scheme.

    With ``N = len(tspan)``, ``h = (tspan[-1] - tspan[0]) / (N - 1)`` and the
    GL weights ``c_0 = 1``, ``c_j = (1 - (1 + alpha) / j) * c_{j-1}``, step
    ``k = 1 .. N-1`` computes

        y_k = h^alpha * f(t_k, y_{k-1}) - sum_{j=1}^{k} c_j * y_{k-j}.

    NOTE(review): ``y0`` enters only through the memory sum (it is not
    subtracted from the history nor added back), so for non-zero ``y0`` this
    discretises the Grünwald-Letnikov derivative of ``y`` rather than the
    Caputo derivative — confirm this is intended.

    Args:
      alpha: fractional order in (0, 1); a Python float or 0-dim tensor.
      f: callable ``f(t, y)`` returning a tensor broadcastable to ``y0.shape``.
      y0: initial state tensor of any shape (``(1,)``, ``(d,)``, ``(n, d)``, ...).
      tspan: 1-D sequence of equally spaced time points; ``tspan[0]`` is the
        time of ``y0``.
      device: torch device on which the history and weights are stored.

    Returns:
      ``y_{N-1}``, the state at ``tspan[N-1]`` (``y0`` itself when ``N == 1``).

    See also:
        ```tex
        @INPROCEEDINGS{8742063,
        author={Clemente-López, D. and Muñoz-Pacheco, J. M. and Félix-Beltrán, O. G. and Volos, C.},
        booktitle={2019 8th International Conference on Modern Circuits and Systems Technologies (MOCAST)},
        title={Efficient Computation of the Grünwald-Letnikov Method for ARM-Based Implementations of Fractional-Order Chaotic Systems},
        year={2019},
        doi={10.1109/MOCAST.2019.8742063}}
        ```
    """
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)
    h_alpha = h ** alpha

    # c_j = prod_{i=1}^{j} (1 - (1 + alpha) / i); kept in float64 because the
    # products and the long memory sum accumulate round-off.
    idx = torch.arange(1, N + 1, dtype=torch.float64, device=device)
    c = torch.cat([
        torch.ones(1, dtype=torch.float64, device=device),
        torch.cumprod(1 - (1 + alpha) / idx, dim=0),
    ])

    # Per-node weights become (k, 1, ..., 1) so they broadcast over the state dims.
    weight_shape = (-1,) + (1,) * y0.dim()
    y_history = torch.zeros((N, *y0.shape), device=device)
    y_history[0] = y0
    yn = y0
    for k in range(1, N):
        # Reversing c_1..c_k pairs c_k with y_0 and c_1 with y_{k-1}.
        weights = torch.flip(c[1:k + 1], dims=(0,)).reshape(weight_shape)
        right = torch.sum(weights * y_history[:k], dim=0)
        yn = f(tspan[k], yn) * h_alpha - right
        y_history[k] = yn
    return yn


def PIEX_method(alpha, f, y0, tspan, device):
    """Integrate the Caputo equation ``D^alpha y(t) = f(t, y)`` with the explicit
    product-integration rectangular rule.

    With ``N = len(tspan)`` and ``h = (tspan[-1] - tspan[0]) / (N - 1)``, step
    ``k = 1 .. N-1`` computes

        y_k = y0 + h^alpha * sum_{j=0}^{k-1} b_{k-1-j} f(t_j, y_j),
        b_n = ((n + 1)^alpha - n^alpha) / Gamma(alpha + 1).

    Each ``f(t_j, y_j)`` is evaluated once and cached, so ``f`` is called
    ``N - 1`` times in total.

    Args:
      alpha: fractional order in (0, 1); a Python float or 0-dim tensor.
      f: callable ``f(t, y)``; every call must return the same shape,
        broadcastable to ``y0.shape``.
      y0: initial state tensor of any shape (``(1,)``, ``(d,)``, ``(n, d)``, ...).
      tspan: 1-D sequence of equally spaced time points; ``tspan[0]`` is the
        time of ``y0``.
      device: torch device on which the state history is stored.

    Returns:
      ``y_{N-1}``, the state at ``tspan[N-1]`` (``y0`` itself when ``N == 1``).

    See also:
    ```tex
    @inproceedings{Garrappa2018NumericalSO,
      title={Numerical Solution of Fractional Differential Equations: A Survey and a Software Tutorial},
      author={Roberto Garrappa},
      year={2018}
    }
    ```
    """
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)
    h_alpha = h ** alpha
    inv_gamma = 1 / math.gamma(float(alpha) + 1)

    y_history = torch.zeros((N, *y0.shape), device=device)
    y_history[0] = y0
    f_history = []  # f(t_j, y_j) for j = 0 .. k-1
    yn = y0
    for k in range(1, N):
        # Only f(t_{k-1}, y_{k-1}) is new at this step; earlier values are reused.
        f_history.append(f(tspan[k - 1], y_history[k - 1]))
        f_past = torch.stack(f_history)
        # n = k-1-j for j = 0 .. k-1; float64 limits cancellation in (n+1)^a - n^a.
        n = torch.arange(k - 1, -1, -1, dtype=torch.float64, device=f_past.device)
        b = (((n + 1) ** alpha - n ** alpha) * inv_gamma).to(f_past.dtype)
        b = b.reshape((-1,) + (1,) * (f_past.dim() - 1))
        yn = y0 + h_alpha * torch.sum(b * f_past, dim=0)
        y_history[k] = yn
    return yn

def PIIM_method(alpha, f, y0, tspan, device, n_iter=1):
    """Integrate the Caputo equation ``D^alpha y(t) = f(t, y)`` with the implicit
    product-integration rectangular rule.

    With ``N = len(tspan)`` and ``h = (tspan[-1] - tspan[0]) / (N - 1)``, step
    ``k = 1 .. N-1`` solves

        y_k = y0 + h^alpha * (sum_{j=1}^{k-1} b_{k-j} f(t_j, y_j) + b_0 f(t_k, y_k)),
        b_n = ((n + 1)^alpha - n^alpha) / Gamma(alpha + 1),

    for ``y_k`` with ``n_iter`` fixed-point iterations started from ``y_{k-1}``.
    Past values ``f(t_j, y_j)`` are evaluated once and cached.

    Args:
      alpha: fractional order in (0, 1); a Python float or 0-dim tensor.
      f: callable ``f(t, y)``; every call must return the same shape,
        broadcastable to ``y0.shape``.
      y0: initial state tensor of any shape (``(1,)``, ``(d,)``, ``(n, d)``, ...).
      tspan: 1-D sequence of equally spaced time points; ``tspan[0]`` is the
        time of ``y0``.
      device: torch device on which the state history is stored.
      n_iter: number of fixed-point iterations for the implicit term (>= 1).
        With the default of 1 the implicit term is approximated by a single
        sweep from ``y_{k-1}``. More iterations approach the implicit solution;
        they converge when the map is a contraction, roughly when
        ``h^alpha * |df/dy| / Gamma(alpha + 1) < 1``.

    Returns:
      ``y_{N-1}``, the state at ``tspan[N-1]`` (``y0`` itself when ``N == 1``).

    Raises:
      ValueError: if ``n_iter < 1``.

    See also:
    ```tex
    @inproceedings{Garrappa2018NumericalSO,
      title={Numerical Solution of Fractional Differential Equations: A Survey and a Software Tutorial},
      author={Roberto Garrappa},
      year={2018}
    }
    ```
    """
    if n_iter < 1:
        raise ValueError("n_iter must be at least 1")
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)
    h_alpha = h ** alpha
    # 1 / Gamma(alpha + 1) is also the weight b_0 of the implicit node.
    inv_gamma = 1 / math.gamma(float(alpha) + 1)

    y_history = torch.zeros((N, *y0.shape), device=device)
    y_history[0] = y0
    f_history = []  # f(t_j, y_j) for j = 1 .. k-1
    yn = y0
    for k in range(1, N):
        if k >= 2:
            f_history.append(f(tspan[k - 1], y_history[k - 1]))
        memory = 0
        if f_history:
            f_past = torch.stack(f_history)
            # n = k - j for j = 1 .. k-1; float64 limits cancellation in (n+1)^a - n^a.
            n = torch.arange(k - 1, 0, -1, dtype=torch.float64, device=f_past.device)
            b = (((n + 1) ** alpha - n ** alpha) * inv_gamma).to(f_past.dtype)
            b = b.reshape((-1,) + (1,) * (f_past.dim() - 1))
            memory = torch.sum(b * f_past, dim=0)
        # y_k is still unknown here; iterate the implicit equation from y_{k-1}.
        y_guess = y_history[k - 1]
        for _ in range(n_iter):
            y_guess = y0 + h_alpha * (memory + inv_gamma * f(tspan[k], y_guess))
        yn = y_guess
        y_history[k] = yn
    return yn
def bcoefficients(n, α):
    """Rectangular product-integration weight ``b_n = ((n+1)^α - n^α) / Gamma(α + 1)``."""
    increment = (n + 1) ** α - n ** α
    normaliser = math.gamma(α + 1)
    return increment / normaliser


def PIIM_trap_method(alpha, f, y0, tspan, device, n_iter=1):
    """Integrate the Caputo equation ``D^alpha y(t) = f(t, y)`` with the implicit
    product-integration trapezoidal rule.

    With ``N = len(tspan)`` and ``h = (tspan[-1] - tspan[0]) / (N - 1)``, step
    ``k = 1 .. N-1`` solves

        y_k = y0 + h^alpha * (at_k f(t_0, y0) + sum_{j=1}^{k-1} a_{k-j} f(t_j, y_j)
                              + a_0 f(t_k, y_k)),
        at_k = ((k-1)^(alpha+1) - k^alpha (k - alpha - 1)) / Gamma(alpha + 2),
        a_0  = 1 / Gamma(alpha + 2),
        a_n  = ((n-1)^(alpha+1) - 2 n^(alpha+1) + (n+1)^(alpha+1)) / Gamma(alpha + 2),

    for ``y_k`` with ``n_iter`` fixed-point iterations started from ``y_{k-1}``.
    Past values ``f(t_j, y_j)`` (including ``f(t_0, y0)``) are evaluated once and cached.

    Args:
      alpha: fractional order in (0, 1); a Python float or 0-dim tensor.
      f: callable ``f(t, y)``; every call must return the same shape,
        broadcastable to ``y0.shape``.
      y0: initial state tensor of any shape (``(1,)``, ``(d,)``, ``(n, d)``, ...).
      tspan: 1-D sequence of equally spaced time points; ``tspan[0]`` is the
        time of ``y0``.
      device: torch device on which the state history is stored.
      n_iter: number of fixed-point iterations for the implicit term (>= 1).
        With the default of 1 the implicit term is approximated by a single
        sweep from ``y_{k-1}``. More iterations approach the implicit solution;
        they converge when the map is a contraction, roughly when
        ``h^alpha * |df/dy| / Gamma(alpha + 2) < 1``.

    Returns:
      ``y_{N-1}``, the state at ``tspan[N-1]`` (``y0`` itself when ``N < 2``).

    Raises:
      ValueError: if ``n_iter < 1``.

    See also:
    ```tex
    @inproceedings{Garrappa2018NumericalSO,
      title={Numerical Solution of Fractional Differential Equations: A Survey and a Software Tutorial},
      author={Roberto Garrappa},
      year={2018}
    }
    ```
    """
    if n_iter < 1:
        raise ValueError("n_iter must be at least 1")
    N = len(tspan)
    if N < 2:
        return y0
    h = (tspan[N - 1] - tspan[0]) / (N - 1)
    h_alpha = h ** alpha
    # 1 / Gamma(alpha + 2) is also the weight a_0 of the implicit node.
    a_0 = 1 / math.gamma(float(alpha) + 2)

    y_history = torch.zeros((N, *y0.shape), device=device)
    y_history[0] = y0
    f_0 = f(tspan[0], y0)  # reused with weight at_k at every step
    f_history = []  # f(t_j, y_j) for j = 1 .. k-1
    yn = y0
    for k in range(1, N):
        if k >= 2:
            f_history.append(f(tspan[k - 1], y_history[k - 1]))
        # (k-1) is raised to alpha+1 so that all weights sum to
        # k^alpha / Gamma(alpha + 1), i.e. constants are integrated exactly.
        a_tilde = ((k - 1) ** (alpha + 1) - k ** alpha * (k - alpha - 1)) * a_0
        memory = a_tilde * f_0
        if f_history:
            f_past = torch.stack(f_history)
            # n = k - j for j = 1 .. k-1, in float64 to limit cancellation.
            n = torch.arange(k - 1, 0, -1, dtype=torch.float64, device=f_past.device)
            a = ((n - 1) ** (alpha + 1) - 2 * n ** (alpha + 1) + (n + 1) ** (alpha + 1)) * a_0
            a = a.to(f_past.dtype).reshape((-1,) + (1,) * (f_past.dim() - 1))
            memory = memory + torch.sum(a * f_past, dim=0)
        # y_k is still unknown here; iterate the implicit equation from y_{k-1}.
        y_guess = y_history[k - 1]
        for _ in range(n_iter):
            y_guess = y0 + h_alpha * (memory + a_0 * f(tspan[k], y_guess))
        yn = y_guess
        y_history[k] = yn
    return yn

def a_k_alpha(k, alpha):
    """Trapezoidal product-integration weight ``a_k``.

    ``a_0 = 1 / Gamma(alpha + 2)`` and, for ``k != 0``,
    ``a_k = ((k-1)^(alpha+1) - 2 k^(alpha+1) + (k+1)^(alpha+1)) / Gamma(alpha + 2)``.
    """
    denom = math.gamma(alpha + 2)
    if k != 0:
        second_difference = (k - 1) ** (alpha + 1) - 2 * (k ** (alpha + 1)) + (k + 1) ** (alpha + 1)
        return second_difference / denom
    return 1 / denom

if __name__ == '__main__':
    # Demo problem: D^0.5 y(t) = f(t, y), y(0) = 0, with exact solution
    # y(t) = t**2 - t (the first two terms of f are the Caputo derivative of it).

    def f(t, y):
        """Right-hand side of the demo problem; returns a float32 tensor."""
        y = (2 / math.gamma(3 - 0.5)) * (t**(2 - 0.5)) - (1 / math.gamma(2 - 0.5)) * (t**(1 - 0.5)) - y + t**2 - t
        # as_tensor avoids the copy-construct warning torch.tensor emits for tensor input.
        return torch.as_tensor(y, dtype=torch.float32)

    y0 = torch.tensor([0])
    t = 2
    tspan = torch.arange(0, t, t / 160)
    alpha = torch.tensor(0.5)
    # caputoEuler_graphcon is this module's fractional forward-Euler solver.
    y = caputoEuler_graphcon(alpha, f, y0, tspan, device='cpu')
    y1 = caputoEuler_corrector(alpha, f, y0, tspan, device='cpu')
    y2 = GL_method(alpha, f, y0, tspan, device='cpu')
    print("Euler Predictor: ", y.item())
    print("Euler Predictor-corrector: ", y1.item())
    print('GL_method: ', y2.item())
    y_gt = t**2 - t
    print("ground truth: ", y_gt)