import torch
import math
# def fractional_pow(base, exponent):
#     eps = 1e-4
#     return torch.exp(exponent * torch.log(base + eps))

def fractional_pow_taylor(base, exponent, n_terms=4):
    """Approximate ``base ** exponent`` with a truncated exponential series.

    Writes the power as ``exp(x)`` with ``x = exponent * log(base)`` and sums
    the series ``1 + x + x**2/2! + ... + x**n_terms/n_terms!``.

    Args:
      base: tensor of bases; it is passed to ``torch.log``.
      exponent: scalar or tensor multiplied element-wise with ``log(base)``.
      n_terms: number of series terms added after the leading 1.
    Returns:
      Tensor holding the truncated-series approximation.
    """
    # The series argument does not change between iterations.
    x = torch.log(base) * exponent
    result = torch.ones_like(base)
    term = torch.ones_like(base)

    for n in range(1, n_terms + 1):
        # Dividing by n at every step builds x**n / n! incrementally, so no
        # separate factorial accumulator is needed.
        term = term * x / n
        result = result + term

    return result

def fractional_pow(base, exponent):
    """Return ``base`` raised to ``exponent`` using ``torch.pow``.

    This is the single place where the solvers in this file compute
    fractional powers, so the implementation can be swapped in one spot.

    Args:
      base: tensor or scalar base.
      exponent: tensor or scalar exponent; at least one of the two arguments
        has to be a tensor for ``torch.pow`` to accept the call.
    Returns:
      The element-wise power as a tensor.
    """
    return torch.pow(base, exponent)

def caputoEuler(alpha, f, y0, tspan, device):
    """Integrate the Caputo equation D^alpha y(t) = f(t, y) with the
    fractional explicit Euler (one-step Adams-Bashforth) rule.

    For every index k of ``tspan`` the right-hand side is evaluated at the
    current state and the next state is rebuilt from the full history:

        y_{k+1} = y0 + 1/Gamma(alpha) * sum_{j=0}^{k} b_{j,k+1} * f(t_j, y_j)
        b_{j,k+1} = h^alpha / alpha * ((k + 1 - j)^alpha - (k - j)^alpha)

    Args:
      alpha: fractional order as a tensor (it is passed to ``torch.lgamma``).
      f: callable ``f(t, y)``. ``t`` is handed over as a 0-dim tensor on
         ``device`` and ``y`` is the current state tensor.
      y0: initial state tensor; it is moved to ``device`` and cast to float32.
      tspan: 1-D tensor of time points. Only the first and last entries and
        the length are used to derive the step size h, so the points must be
        equally spaced and there must be at least two of them.
      device: torch device on which the computation is carried out.
    Returns:
      The state after ``len(tspan)`` update steps. Only this final state is
      returned, not the whole trajectory.
    See also:
      K. Diethelm et al. (2004) Detailed error analysis for a fractional Adams
         method
      C. Li and F. Zeng (2012) Finite Difference Methods for Fractional
         Differential Equations
    """
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)
    gamma_alpha = 1 / torch.exp(torch.lgamma(alpha))
    # h^alpha / alpha is shared by every weight, so compute it once.
    step_scale = fractional_pow(h, alpha) / alpha
    y0 = y0.to(device).to(torch.float32)
    fhistory = []
    yn = y0.clone()
    for k in range(N):
        tn = tspan[k].item()
        fhistory.append(f(torch.tensor(tn, device=device), yn))
        # Column vector of history indices j = 0..k, one weight per row.
        j_vals = torch.arange(0, k + 1, dtype=torch.float32, device=device).unsqueeze(1)
        b_j_k_1 = step_scale * (
            fractional_pow(k + 1 - j_vals, alpha) - fractional_pow(k - j_vals, alpha)
        ).to(device)

        # Each weight row has shape (1,), so it broadcasts over the state.
        weighted = torch.stack([w * f_j for w, f_j in zip(b_j_k_1, fhistory)])
        yn = y0 + gamma_alpha * torch.sum(weighted, dim=0)

    return yn


def caputoEuler_memory(alpha, f, y0, tspan, device, memory_k=3):
    """Fractional explicit Euler rule for D^alpha y(t) = f(t, y) with a
    truncated (short-memory) history.

    Same update as the full-history Euler scheme,

        y_{k+1} = y0 + 1/Gamma(alpha) * sum_j b_{j,k+1} * f(t_j, y_j)
        b_{j,k+1} = h^alpha / alpha * ((k + 1 - j)^alpha - (k - j)^alpha)

    but the sum only runs over j = max(0, k - memory_k) .. k, i.e. over the
    most recent ``memory_k + 1`` right-hand-side evaluations.

    Args:
      alpha: fractional order; it is passed to ``math.gamma`` and used as the
        exponent of ``torch.pow`` with an integer base, so presumably a 0-dim
        tensor -- confirm against callers.
      f: callable ``f(t, y)`` returning a tensor; ``t`` is ``tspan[k]``.
      y0: initial state tensor of any shape.
      tspan: 1-D tensor of equally spaced time points (at least two); only
        the end points and the length are used to derive the step size.
      device: torch device the weights and stored evaluations are moved to.
      memory_k: number of past evaluations kept in addition to the current one.
    Returns:
      The state after ``len(tspan)`` update steps (final state only).
    See also:
      K. Diethelm et al. (2004) Detailed error analysis for a fractional Adams
         method
      C. Li and F. Zeng (2012) Finite Difference Methods for Fractional
         Differential Equations
    """
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)
    gamma_alpha = 1 / math.gamma(alpha)
    # h^alpha / alpha is shared by every weight, so compute it once.
    step_scale = torch.pow(h, alpha) / alpha
    # Always a list: a pre-allocated tensor has no ``append`` and made the
    # one-element-y0 case fail.
    fhistory = []
    yn = y0
    for k in range(N):
        fhistory.append(f(tspan[k], yn))
        b_all_k = 0
        for j in range(max(0, k - memory_k), k + 1):
            b_j_k_1 = step_scale * (
                torch.pow(k + 1 - j, alpha) - torch.pow(k - j, alpha)
            ).to(device)
            # Out-of-place accumulation keeps broadcasting and autograd safe.
            b_all_k = b_all_k + b_j_k_1 * fhistory[j].to(device)
        yn = y0 + gamma_alpha * b_all_k
    return yn





def caputoEuler_corrector(alpha, f, y0, tspan, device):
    """Fractional Euler predictor followed by one Adams-Moulton-type
    corrector sweep for the Caputo equation D^alpha y(t) = f(t, y).

    Stage 1 (predictor) runs the explicit rule

        y_{k+1} = y0 + 1/Gamma(alpha) * sum_{j=0}^{k} b_{j,k+1} * f(t_j, y_j)

    and keeps the right-hand-side evaluations. Stage 2 (corrector) restarts
    from y0 and rebuilds the state with the trapezoidal weights a_{j,k+1},
    using its own evaluations for the history sum and the predictor
    evaluation of the same index for the final a_{k+1,k+1} term.

    Args:
      alpha: fractional order as a tensor (it is passed to ``torch.lgamma``).
      f: callable ``f(t, y)`` returning a tensor shaped like ``y0``. The
         predictor passes ``t`` as a 0-dim tensor on ``device``; the corrector
         passes it as a Python float.
      y0: initial state tensor of any shape; moved to ``device`` and cast to
        float32.
      tspan: 1-D tensor of equally spaced time points (at least two); only
        the end points and the length are used to derive the step size.
      device: torch device on which the computation is carried out.
    Returns:
      The corrected state after ``len(tspan)`` update steps (final state
      only).
    See also:
      K. Diethelm et al. (2004) Detailed error analysis for a fractional Adams
         method
      C. Li and F. Zeng (2012) Finite Difference Methods for Fractional
         Differential Equations
    """
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)
    gamma_alpha = 1 / torch.exp(torch.lgamma(alpha))
    y0 = y0.to(device).to(torch.float32)
    # Common factor of the corrector weights: h^alpha / (alpha * (alpha + 1)).
    a_item = torch.pow(h, alpha) / (alpha * (alpha + 1))
    # Common factor of the predictor weights: h^alpha / alpha.
    step_scale = fractional_pow(h, alpha) / alpha

    # ---- Stage 1: explicit predictor, collecting f(t_k, y_k). ----
    fhistory = []
    yn = y0.clone()
    for k in range(N):
        tn = tspan[k].item()
        fhistory.append(f(torch.tensor(tn, device=device), yn))
        j_vals = torch.arange(0, k + 1, dtype=torch.float32, device=device).unsqueeze(1)
        b_j_k_1 = step_scale * (
            fractional_pow(k + 1 - j_vals, alpha) - fractional_pow(k - j_vals, alpha)
        )
        weighted = torch.stack([w * f_j for w, f_j in zip(b_j_k_1, fhistory)])
        yn = y0 + gamma_alpha * torch.sum(weighted, dim=0)

    # ---- Stage 2: one corrector sweep. ----
    # The buffer follows y0's own shape, so states of any rank are accepted.
    fhistory_new = torch.zeros((N, *y0.shape), device=device)
    # Weights are reshaped to (k + 1, 1, ..., 1) to broadcast over the state.
    weight_shape = (-1,) + (1,) * y0.dim()
    yn_corrector = y0.clone()
    for k in range(N):
        tn = tspan[k].item()
        fhistory_new[k] = f(tn, yn_corrector)

        a_j_k_1 = torch.zeros((k + 1, 1), dtype=torch.float32, device=device)
        a_j_k_1[0] = a_item * (torch.pow(k, alpha + 1) - (k - alpha) * torch.pow(k + 1, alpha))
        for j in range(1, k + 1):
            a_j_k_1[j] = a_item * (
                torch.pow(k + 2 - j, alpha + 1)
                + torch.pow(k - j, alpha + 1)
                - 2 * torch.pow(k + 1 - j, alpha + 1)
            )

        a_all = torch.sum(a_j_k_1.view(weight_shape) * fhistory_new[:k + 1], dim=0)
        # Final term uses the predictor's evaluation at the same index.
        a_k_k = a_item * fhistory[k]
        yn_corrector = y0 + gamma_alpha * (a_all + a_k_k)

    return yn_corrector



def GL_method(alpha, f, y0, tspan, device):
    """Explicit Gruenwald-Letnikov stepping for D^alpha y(t) = f(t, y).

    The weights follow the recurrence c_0 = 1,
    c_j = (1 - (1 + alpha) / j) * c_{j-1}, and each step computes

        y_k = f(t_k, y_{k-1}) * h^alpha - sum_{j=1}^{k} c_j * y_{k-j}

    Args:
      alpha: fractional order (scalar or 0-dim tensor).
      f: callable ``f(t, y)`` returning a tensor; ``t`` is ``tspan[k]``.
      y0: initial state tensor.
      tspan: 1-D tensor of equally spaced time points; only the end points
        and the length are used to derive the step size.
      device: torch device holding the weights and the memory term.
    Returns:
      The state at the last time point (``y0`` itself when ``tspan`` has a
      single entry).
    """
    n_points = len(tspan)
    step = (tspan[n_points - 1] - tspan[0]) / (n_points - 1)

    # Gruenwald-Letnikov weights, kept in float64.
    weights = torch.zeros(n_points + 1, dtype=torch.float64).to(device)
    weights[0] = 1
    for idx in range(1, n_points + 1):
        weights[idx] = (1 - (1 + alpha) / idx) * weights[idx - 1]

    states = [y0]
    current = y0
    for k in range(1, n_points):
        # Lag 1 pairs with the newest state, lag k with y0.
        memory = 0
        for lag, past in zip(range(1, k + 1), reversed(states)):
            memory = (memory + weights[lag] * past).to(device)
        current = f(tspan[k], current) * torch.pow(step, alpha) - memory
        states.append(current)
    return current
import scipy.special as sp
# Helper function to compute binomial coefficients
def binom(alpha, k):
    """Return the generalized binomial coefficient C(alpha, k).

    Delegates to ``scipy.special.binom``, which accepts real ``alpha``. The
    explicit ratio gamma(alpha + 1) / (gamma(k + 1) * gamma(alpha - k + 1))
    overflows to inf/nan once ``k`` exceeds roughly 170, and is undefined
    where ``alpha - k + 1`` is a non-positive integer.

    Args:
      alpha: upper argument; may be non-integer.
      k: lower argument.
    Returns:
      The coefficient as a NumPy floating-point value.
    """
    return sp.binom(alpha, k)
def GL_order_n(alpha, coefficient, f, y0, tspan, device):
    """Gruenwald-Letnikov stepping for a multi-term fractional equation.

    For every order ``alpha[i]`` the signed binomial weights
    ``(-1)**k * binom(alpha[i], k)`` are tabulated. Each step forms the
    memory sum of every term over the stored states, divides it by
    ``h**alpha[i]``, combines the terms with ``coefficient`` and updates

        y_k = (f(t_k, y_{k-1}) - total) / sum_i h**(-alpha[i])

    Args:
      alpha: sequence of fractional orders (converted with ``torch.tensor``).
      coefficient: iterable of tensors, one per order, stacked into a single
        tensor and reshaped with ``view(-1, 1, 1)``.
      f: callable ``f(t, y)`` returning a tensor; ``t`` is ``tspan[k]``.
      y0: initial state tensor. The weights are reshaped with
        ``view(-1, 1, 1)``, so ``y0`` is presumably 2-D -- confirm against
        callers.
      tspan: 1-D tensor of equally spaced time points; only the end points
        and the length are used to derive the step size.
      device: torch device for the weight tables and work buffers.
    Returns:
      The state at the last time point (``y0`` itself when ``tspan`` has a
      single entry).
    """
    n_points = len(tspan)
    step = (tspan[-1] - tspan[0]) / (n_points - 1)
    orders = torch.tensor(alpha, dtype=torch.float32, device=device, requires_grad=False)
    term_weights = torch.stack(list(coefficient))

    # One row of signed binomial weights per fractional order.
    gl_weights = torch.stack([
        torch.tensor(
            [(-1) ** k * binom(order, k) for k in range(n_points + 1)],
            dtype=torch.float32,
            device=device,
        )
        for order in alpha
    ])

    # NOTE(review): the normaliser sums h**(-alpha_i) without applying
    # ``coefficient`` -- confirm that this is intended.
    normaliser = 1 / torch.sum((1 / step) ** orders)
    step_powers = (step ** orders).view(-1, 1, 1)

    states = [y0]
    current = y0
    for k in range(1, n_points):
        # Newest state first, so row entry 1 multiplies y_{k-1}.
        newest_first = torch.stack(states)[:k].flip(dims=[0])

        memory = torch.zeros(len(alpha), *y0.shape, dtype=torch.float32, device=device)
        for i, row in enumerate(gl_weights):
            memory[i] = (row[1:k + 1].view(-1, 1, 1) * newest_first).sum(dim=0)

        memory = memory / step_powers
        total = (term_weights.view(-1, 1, 1) * memory).sum(dim=0)

        current = (f(tspan[k], current) - total) * normaliser
        states.append(current)

    return current

def _l1_history_sum(k, alpha, states):
    """Return sum_{j=0}^{k-2} R_{k,j} * (states[j+1] - states[j]) with
    R_{k,j} = (k - j)^(1 - alpha) - (k - j - 1)^(1 - alpha)."""
    total = 0
    # The L1 formula has k increments; the newest one (j = k - 1, weight 1)
    # is the unknown being solved for, so the history covers j = 0 .. k - 2.
    for j in range(k - 1):
        R_k_j = torch.pow(k - j, 1 - alpha) - torch.pow(k - j - 1, 1 - alpha)
        total = total + R_k_j * (states[j + 1] - states[j])
    return total


def implicit_l1(alpha, f, y0, tspan, device):
    """Solve D^alpha y(t) = f(t, y) with the L1 discretisation of the Caputo
    derivative, followed by three fixed-point correction sweeps.

    The L1 formula at t_k gives

        y_k = y_{k-1} + h^alpha * Gamma(2 - alpha) * f_k
              - sum_{j=0}^{k-2} R_{k,j} * (y_{j+1} - y_j)

    The first sweep uses f_k = f(t_k, y_{k-1}). Each of the three later
    sweeps restarts from y0 and uses f_k = f(t_k, y_k) with y_k taken from
    the previous sweep.

    Args:
      alpha: fractional order; it is passed to ``math.gamma`` and used inside
        ``torch.pow`` with an integer base, so presumably a 0-dim tensor --
        confirm against callers.
      f: callable ``f(t, y)`` returning a tensor; ``t`` is ``tspan[k]``.
      y0: initial state tensor.
      tspan: 1-D tensor of equally spaced time points; only the end points
        and the length are used to derive the step size.
      device: torch device the right-hand-side evaluations are moved to.
    Returns:
      The state at the last time point after the final sweep (``y0`` itself
      when ``tspan`` has a single entry).
    """
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)
    u_h = torch.pow(h, alpha) * math.gamma(2 - alpha)

    # First sweep: explicit in f, evaluated at the previous state.
    yn = y0
    yn_all = [yn]
    for k in range(1, N):
        fhistory_k = f(tspan[k], yn).to(device)
        yn = yn + u_h * fhistory_k - _l1_history_sum(k, alpha, yn_all)
        yn_all.append(yn)

    # Correction sweeps: re-evaluate f at the previous sweep's states.
    yn_corrector = y0
    for _ in range(3):
        yn_corrector = y0
        yn_all_corrector = [yn_corrector]
        for k in range(1, N):
            f_predictor = f(tspan[k], yn_all[k]).to(device)
            yn_corrector = (
                yn_corrector
                + u_h * f_predictor
                - _l1_history_sum(k, alpha, yn_all_corrector)
            )
            yn_all_corrector.append(yn_corrector)
        yn_all = yn_all_corrector

    return yn_corrector

def RLcoeffs(index_k, index_j, alpha):
    """Return one coefficient of the RL differintegral operator.

    Three cases are distinguished: the first column (``index_j == 0``), the
    diagonal (``index_j == index_k``, coefficient 1) and every other entry,
    which is a second difference of powers of ``index_k - index_j``.

    see Baleanu, D., Diethelm, K., Scalas, E., and Trujillo, J.J. (2012). Fractional
        Calculus: Models and Numerical Methods. World Scientific.
    """
    if index_j == 0:
        leading = (index_k - 1) ** (1 - alpha)
        trailing = (index_k + alpha - 1) * index_k ** -alpha
        return leading - trailing

    if index_j == index_k:
        return 1

    gap = index_k - index_j
    order = 1 - alpha
    return (gap + 1) ** order + (gap - 1) ** order - 2 * gap ** order

def product_trap(alpha, f, y0, tspan, device):
    """Explicit stepping for D^alpha y(t) = f(t, y) using the ``RLcoeffs``
    (product-trapezoidal) weights.

    Each step computes

        y_k = Gamma(2 - alpha) * f(t_k, y_{k-1}) * h^alpha
              - sum_{j=0}^{k-1} RLcoeffs(k, j, alpha) * y_j

    Args:
      alpha: fractional order; it is passed to ``math.gamma`` and
        ``torch.pow``.
      f: callable ``f(t, y)`` returning a tensor; ``t`` is ``tspan[k]``.
      y0: initial state tensor.
      tspan: 1-D tensor of equally spaced time points; only the end points
        and the length are used to derive the step size.
      device: torch device the memory term is moved to.
    Returns:
      The state at the last time point (``y0`` itself when ``tspan`` has a
      single entry).
    """
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)

    yn = y0
    y_history = [y0]
    for k in range(1, N):
        tn = tspan[k]
        # Memory term over every state computed so far (j = 0 .. k - 1).
        right = 0
        for j, y_j in enumerate(y_history):
            right = (right + RLcoeffs(k, j, alpha) * y_j).to(device)
        yn = math.gamma(2 - alpha) * f(tn, yn) * torch.pow(h, alpha) - right
        y_history.append(yn)
    return yn

def product_trap_corrector(alpha, f, y0, tspan, device):
    """``RLcoeffs``-weighted stepping for D^alpha y(t) = f(t, y) with one
    re-evaluation of the right-hand side per step.

    Each step first forms a predicted state

        y_k^P = Gamma(2 - alpha) * f(t_k, y_{k-1}^P) * h^alpha - right

    and then the stored state

        y_k = Gamma(2 - alpha) * f(t_k, y_k^P) * h^alpha - right

    where ``right = sum_{j=0}^{k-1} RLcoeffs(k, j, alpha) * y_j`` is built
    from the stored states. The predicted value, not the stored one, is what
    the next step feeds back into ``f``.

    Args:
      alpha: fractional order; it is passed to ``math.gamma`` and
        ``torch.pow``.
      f: callable ``f(t, y)`` returning a tensor; ``t`` is ``tspan[k]``.
      y0: initial state tensor.
      tspan: 1-D tensor of equally spaced time points; only the end points
        and the length are used to derive the step size.
      device: torch device the memory term is moved to.
    Returns:
      The stored state at the last time point (``y0`` itself when ``tspan``
      has a single entry).
    """
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)

    yn = y0
    yn_new = y0
    y_history = [y0]
    for k in range(1, N):
        tn = tspan[k]
        # Memory term over every stored state so far (j = 0 .. k - 1).
        right = 0
        for j, y_j in enumerate(y_history):
            right = (right + RLcoeffs(k, j, alpha) * y_j).to(device)
        # Predict, then re-evaluate f at the prediction.
        yn = math.gamma(2 - alpha) * f(tn, yn) * torch.pow(h, alpha) - right
        yn_new = math.gamma(2 - alpha) * f(tn, yn) * torch.pow(h, alpha) - right
        y_history.append(yn_new)
    return yn_new


# def product_trap(alpha, f, y0, tspan,device):
#
#     # (d, alpha, f, y0, tspan) = _check_args(alpha, f, y0, tspan)
#     # tspan = torch.arange(0, 10, 1)
#     N = len(tspan)
#     # print("tsapn: ", tspan)
#     # print("N: ", N)
#     h = (tspan[N-1] - tspan[0])/(N - 1)
#
#     c = torch.zeros(N+1, dtype=torch.float64).to(device)
#     c[0] = 1
#     for j in range(1, N+1):
#         c[j] = (1 - (1+alpha)/j) * c[j-1]
#     yn =y0
#     y_history = []
#     y_history.append(y0)
#     for k in range(1, N):
#         tn = tspan[k]
#         right = 0
#         for j in range(0, k):
#             right = (right + RLcoeffs(k, j, alpha) / math.gamma(2- alpha) * y_history[j]).to(device)
#         yn = math.gamma(2 + alpha) * f(tn,yn) * torch.pow(h, alpha)-math.gamma(2 + alpha) * right
#         # y_history.append(yn)
#
#         yn_new =math.gamma(2 + alpha) * f(tn,yn) * torch.pow(h, alpha)- math.gamma(2 + alpha) * right
#         y_history.append(yn_new)
#     # yn = y0 + yn
#     return yn_new




def PIEX_method(alpha, f, y0, tspan, device):
    """Explicit product-integration rectangular rule for the Caputo equation
    D^alpha y(t) = f(t, y).

    Each step computes

        y_k = y0 + h^alpha * sum_{j=0}^{k-1} b_{k-j-1} * f(t_j, y_j)

    with ``b_n = bcoefficients(n, alpha)``.

    Args:
      alpha: fractional order; it is forwarded to ``bcoefficients`` and used
        as the exponent of ``h``.
      f: callable ``f(t, y)`` returning a tensor shaped like ``y0``. Each
         ``f(t_j, y_j)`` is evaluated once and reused by later steps, so ``f``
         is assumed to be deterministic.
      y0: initial state tensor of any shape.
      tspan: 1-D tensor of equally spaced time points; only the end points
        and the length are used to derive the step size.
      device: torch device holding the state history.
    Returns:
      The state at the last time point (``y0`` itself when ``tspan`` has a
      single entry).
    See also:
    ```tex
    @inproceedings{Garrappa2018NumericalSO,
      title={Numerical Solution of Fractional Differential Equations: A Survey and a Software Tutorial},
      author={Roberto Garrappa},
      year={2018}
    }
    ```
    """
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)

    yn = y0
    # The history buffer follows y0's own shape (default float dtype).
    y_history = torch.zeros((N, *y0.shape)).to(device)
    y_history[0] = y0
    # f_history[j] holds f(t_j, y_j); stored states never change afterwards.
    f_history = []
    for k in range(1, N):
        f_history.append(f(tspan[k - 1], y_history[k - 1]))
        middle = 0
        for j, f_j in enumerate(f_history):
            middle = middle + bcoefficients(k - j - 1, alpha) * f_j
        yn = y0 + middle * (h ** alpha)
        y_history[k] = yn
    return yn

def PIIM_method(alpha, f, y0, tspan, device):
    """Product-integration rectangular rule with implicit-style indexing for
    the Caputo equation D^alpha y(t) = f(t, y).

    Each step computes

        y_k = y0 + h^alpha * sum_{j=1}^{k} b_{k-j} * f(t_j, y_j)

    with ``b_n = bcoefficients(n, alpha)``.

    NOTE(review): no nonlinear solve is performed. When the j == k term is
    evaluated, ``y_history[k]`` has not been written yet, so ``f`` sees the
    zero-initialised slot. Confirm whether an iteration or a predictor value
    was intended.

    Args:
      alpha: fractional order; it is forwarded to ``bcoefficients`` and used
        as the exponent of ``h``.
      f: callable ``f(t, y)`` returning a tensor shaped like ``y0``. Each
         ``f(t_j, y_j)`` for a finished step is evaluated once and reused, so
         ``f`` is assumed to be deterministic.
      y0: initial state tensor of any shape.
      tspan: 1-D tensor of equally spaced time points; only the end points
        and the length are used to derive the step size.
      device: torch device holding the state history.
    Returns:
      The state at the last time point (``y0`` itself when ``tspan`` has a
      single entry).
    See also:
    ```tex
    @inproceedings{Garrappa2018NumericalSO,
      title={Numerical Solution of Fractional Differential Equations: A Survey and a Software Tutorial},
      author={Roberto Garrappa},
      year={2018}
    }
    ```
    """
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)

    yn = y0
    # The history buffer follows y0's own shape (default float dtype).
    y_history = torch.zeros((N, *y0.shape)).to(device)
    y_history[0] = y0
    # f_history[j - 1] holds f(t_j, y_j) for the finished steps j = 1 .. k - 1.
    f_history = []
    for k in range(1, N):
        if k > 1:
            f_history.append(f(tspan[k - 1], y_history[k - 1]))
        # Current-step term, evaluated on the not-yet-filled (zero) slot.
        f_current = f(tspan[k], y_history[k])

        middle = 0
        for j, f_j in enumerate(f_history, start=1):
            middle = middle + bcoefficients(k - j, alpha) * f_j
        middle = middle + bcoefficients(0, alpha) * f_current

        yn = y0 + middle * (h ** alpha)
        y_history[k] = yn
    return yn
def bcoefficients(n, α):
    """Return the rectangular product-integration weight
    ``((n + 1)**α - n**α) / Gamma(α + 1)``."""
    power_increment = (n + 1) ** α - n ** α
    return power_increment / math.gamma(α + 1)


def PIIM_trap_method(alpha, f, y0, tspan, device):
    """Product-integration trapezoidal rule with implicit-style indexing for
    the Caputo equation D^alpha y(t) = f(t, y).

    Each step computes

        y_k = y0 + h^alpha * (a~_k * f(t_0, y0)
                              + sum_{j=1}^{k} a_{k-j} * f(t_j, y_j))

    with ``a_n = a_k_alpha(n, alpha)`` and the starting weight

        a~_k = ((k - 1)^(alpha + 1) - k^alpha * (k - alpha - 1)) / Gamma(alpha + 2)

    NOTE(review): no nonlinear solve is performed. When the j == k term is
    evaluated, ``y_history[k]`` has not been written yet, so ``f`` sees the
    zero-initialised slot. Confirm whether an iteration or a predictor value
    was intended.

    Args:
      alpha: fractional order; it is passed to ``math.gamma`` and
        ``a_k_alpha`` and used as the exponent of ``h``.
      f: callable ``f(t, y)`` returning a tensor shaped like ``y0``. Each
         ``f(t_j, y_j)`` for a finished step is evaluated once and reused, so
         ``f`` is assumed to be deterministic.
      y0: initial state tensor of any shape.
      tspan: 1-D tensor of equally spaced time points; only the end points
        and the length are used to derive the step size.
      device: torch device holding the state history.
    Returns:
      The state at the last time point (``y0`` itself when ``tspan`` has a
      single entry).
    See also:
    ```tex
    @inproceedings{Garrappa2018NumericalSO,
      title={Numerical Solution of Fractional Differential Equations: A Survey and a Software Tutorial},
      author={Roberto Garrappa},
      year={2018}
    }
    ```
    """
    N = len(tspan)
    h = (tspan[N - 1] - tspan[0]) / (N - 1)

    t0 = tspan[0]
    yn = y0
    # The history buffer follows y0's own shape (default float dtype).
    y_history = torch.zeros((N, *y0.shape)).to(device)
    y_history[0] = y0
    # f(t0, y0) is the same on every step; evaluate it lazily so that a
    # single-point tspan still makes no call to f.
    f_start = None
    # f_history[j - 1] holds f(t_j, y_j) for the finished steps j = 1 .. k - 1.
    f_history = []
    for k in range(1, N):
        if f_start is None:
            f_start = f(t0, y0)
        if k > 1:
            f_history.append(f(tspan[k - 1], y_history[k - 1]))
        # Current-step term, evaluated on the not-yet-filled (zero) slot.
        f_current = f(tspan[k], y_history[k])

        # The first power is alpha + 1 (not alpha): only then do the weights
        # of a step add up to k^alpha / Gamma(alpha + 1).
        a_tilde_k = (
            (k - 1) ** (alpha + 1) - (k ** alpha) * (k - alpha - 1)
        ) / math.gamma(alpha + 2)

        middle = 0
        for j, f_j in enumerate(f_history, start=1):
            middle = middle + a_k_alpha(k - j, alpha) * f_j
        middle = middle + a_k_alpha(0, alpha) * f_current

        yn = y0 + (a_tilde_k * f_start + middle) * (h ** alpha)
        y_history[k] = yn
    return yn

def a_k_alpha(k, alpha):
    """Return the trapezoidal product-integration weight a_k.

    ``a_0 = 1 / Gamma(alpha + 2)``; for ``k != 0`` the numerator is the second
    difference ``(k - 1)^(alpha + 1) - 2 * k^(alpha + 1) + (k + 1)^(alpha + 1)``.
    """
    if k == 0:
        numerator = 1
    else:
        power = alpha + 1
        numerator = (k - 1) ** power - 2 * (k ** power) + (k + 1) ** power
    return numerator / math.gamma(alpha + 2)

if __name__ == '__main__':
    import matplotlib.pyplot as plt
    import numpy as np

    # Demo: run the solvers on a test problem of order 0.5 and print each
    # final state next to the reference value t**2 - t.
    def f(t, y):
        # Right-hand side of the test problem (order 0.5 is hard-coded here).
        rhs = (
            (2 / math.gamma(3 - 0.5)) * (t ** (2 - 0.5))
            - (1 / math.gamma(2 - 0.5)) * (t ** (1 - 0.5))
            - y
            + t ** 2
            - t
        )
        return torch.tensor(rhs, dtype=torch.float32)

    y0 = torch.tensor([0])
    t = 4
    tspan = torch.arange(0, t, t / 160)
    alpha = torch.tensor(0.5)

    print(y0.shape)
    print(len(y0.shape))

    # Every solver is run first; the results are printed afterwards.
    labelled_solvers = [
        ("Euler Predictor: ", caputoEuler),
        ("Euler Predictor-corrector: ", caputoEuler_corrector),
        ('GL_method: ', GL_method),
        ('implicit_l1: ', implicit_l1),
        ('product_trap: ', product_trap),
        ('product_trap_corrector: ', product_trap_corrector),
    ]
    final_states = [
        solver(alpha, f, y0, tspan, device='cpu') for _, solver in labelled_solvers
    ]
    for (label, _), final_state in zip(labelled_solvers, final_states):
        print(label, final_state.item())

    y_gt = t ** 2 - t
    print("ground truth: ", y_gt)