from .adaptive_grid_solver import AdaptiveGridSolver
from .fixed_grid_solver import FixedGridSolver
import torch
import torch.nn as nn

# Public solver classes exported by this module.
__all__ = [
    'Sym12Async',
    'FixedStep_Sym12Async',
    'Yoshida_ALF2',
    'FixedStep_Yoshida_ALF2',
    'Suzuki_ALF2',
    'FixedStep_Suzuki_ALF2',
]

# Coefficient of the ALF velocity update v1 = 2*Lambda*(vt1 - v0) + v0.
# The inverse step divides by (2*Lambda - 1), so Lambda must not be 0.5.
Lambda = 1.0




def sym12async_forward(func, t, dt, y, adaptive=False):
    """Advance the state by one asynchronous-leapfrog (ALF) step of size ``dt``.

    ``y`` is a flat tuple: its first half holds position components and its
    second half the matching velocity components. The positions are drifted
    half a step, ``func`` is evaluated at the midpoint, the velocities are
    updated, and the positions are drifted the remaining half step.

    Args:
        func: callable ``func(time, positions)`` returning a tuple with one
            entry per position component.
        t: start time of the step.
        dt: step size.
        y: tuple ``(positions..., velocities...)``; assumed to have even length.
        adaptive: when True, also return a local error estimate.

    Returns:
        ``(out, error, variables)``:

        - ``out`` is the new ``(positions..., velocities...)`` tuple.
        - ``error`` is ``tuple(vt1*dt - v0*dt)`` when ``adaptive``, else ``None``.
        - ``variables`` is ``[vt1, y1]`` when ``adaptive``, else ``[v1, y1]``.
    """
    half = len(y) // 2
    pos0 = y[0:half]
    vel0 = y[half:]

    # First half drift with the incoming velocity.
    pos_mid = tuple(p + 0.5 * dt * v for p, v in zip(pos0, vel0))
    vel_mid = func(t + 0.5 * dt, pos_mid)

    # Velocity update (with Lambda == 1 this is v1 = 2*vt1 - v0).
    vel1 = tuple(2 * Lambda * (vm - v) + v for vm, v in zip(vel_mid, vel0))

    # Second half drift with the updated velocity.
    pos1 = tuple(p + 0.5 * dt * v for p, v in zip(pos_mid, vel1))

    out = pos1 + vel1

    if not adaptive:
        # NOTE(review): this branch returns v1 in the first slot while the
        # adaptive branch returns vt1; callers that unpack the result as
        # ``[vt1, y1]`` receive v1 here. Confirm which one the adjoint expects.
        return out, None, [vel1, pos_mid]

    error = tuple(vm * dt - v * dt for vm, v in zip(vel_mid, vel0))
    return out, error, [vel_mid, pos_mid]

def sym12async_inverse(func, t1, dt, y):
    """Undo one ALF step: recover the state at ``t1 - dt`` from the state at ``t1``.

    Args:
        func: the same callable used by ``sym12async_forward``.
        t1: end time of the step; must support ``.to(device)`` (tensor).
        dt: step size; must support ``.to(device)`` (tensor).
        y: tuple ``(positions..., velocities...)`` at time ``t1``.

    Returns:
        ``(out, [y1, vt1])``:

        - ``out`` is the reconstructed ``(positions..., velocities...)`` at ``t1 - dt``.
        - ``y1`` is the midpoint position.
        - ``vt1`` is ``func`` evaluated at that midpoint.
    """
    device = y[0].device
    t1 = t1.to(device)
    dt = dt.to(device)
    t_start = t1 - dt

    half = len(y) // 2
    pos_end = y[0:half]
    vel_end = y[half:]

    # Reverse the second half drift to get the midpoint position.
    pos_mid = tuple(p - 0.5 * dt * v for p, v in zip(pos_end, vel_end))
    vel_mid = func(t_start + 0.5 * dt, pos_mid)

    # Solve v1 = 2*Lambda*(vt1 - v0) + v0 for v0 (singular when Lambda == 0.5).
    denom = 2.0 * Lambda - 1.0
    vel_start = tuple((2 * Lambda * vm - v) / denom for vm, v in zip(vel_mid, vel_end))

    # Reverse the first half drift.
    pos_start = tuple(p - 0.5 * dt * v for p, v in zip(pos_mid, vel_start))

    return pos_start + vel_start, [pos_mid, vel_mid]


def Yoshida1(f, t, h, u, adaptive = False):
    """Advance ``u`` by ``h`` with a Yoshida composition of ALF steps.

    Six ALF sub-steps (``sym12async_forward``) are applied with sizes
    ``(a, a, b, b, a, a) * h / 2``, where ``a = 1 / (2 - 2**(1/3))`` and
    ``b = 1 - 2*a`` (Yoshida triple-jump coefficients).

    Args:
        f: callable ``f(time, positions)`` used by each ALF sub-step.
        t: start time.
        h: step size.
        u: tuple ``(positions..., velocities...)``.
        adaptive: accepted for signature compatibility; it has no effect
            (sub-steps are always taken non-adaptively, so the error is None).

    Returns:
        The ``(out, error, variables)`` triple produced by the last sub-step.
    """
    a = 1/(2 - 2**(1/3))
    b = 1 - 2*a

    result = None
    state, t_cur = u, t
    for frac in (a, a, b, b, a, a):
        sub_dt = (frac*h)/2
        result = sym12async_forward(f, t_cur, sub_dt, state)
        state = result[0]
        t_cur = t_cur + sub_dt
    return result
    
    
def Yoshida2(f, t, h, u, adaptive = False):
    """Advance ``u`` by ``h`` with a Yoshida composition of ALF steps.

    Six ALF sub-steps (``sym12async_forward``) are applied with sizes
    ``(a, a, b, b, a, a) * h / 2``, where ``a = 1 / (2 - 2**(1/3))`` and
    ``b = 1 - 2*a`` (Yoshida triple-jump coefficients).

    When ``adaptive`` is True, a reference value for the position components
    is computed with one Dormand-Prince (5th-order weights) step of size ``h``
    on ``dy/dt = f(t, y)``, starting from the input positions. The error is
    the difference between the ALF positions and that reference, for every
    position component.

    Args:
        f: callable ``f(time, positions)`` returning one entry per position
            component.
        t: start time.
        h: step size.
        u: tuple ``(positions..., velocities...)``; assumed to have even length.
        adaptive: whether to compute the error estimate.

    Returns:
        ``(out, error, variables)``:

        - ``out`` is the state after the step.
        - ``error`` is a tuple with one entry per position component, or ``None``.
        - ``variables`` is the auxiliary list returned by the final ALF sub-step.
          NOTE(review): sub-steps run non-adaptively, so its first entry is v1
          rather than vt1. Confirm this is what the adjoint expects.
    """
    n_pos = len(u) // 2
    u0 = u[:n_pos]

    a = 1/(2 - 2**(1/3))
    b = 1 - 2*a

    out, t_cur, variables = u, t, None
    for frac in (a, a, b, b, a, a):
        sub_dt = (frac*h)/2
        out, _, variables = sym12async_forward(f, t_cur, sub_dt, out)
        t_cur = t_cur + sub_dt

    if not adaptive:
        return out, None, variables

    # Dormand-Prince reference step for the position components.
    dt = h
    k1 = f(t, u0)
    k2 = f(t + dt / 5,
           tuple(_y + 1 / 5 * dt * _k1 for _y, _k1 in zip(u0, k1)))
    k3 = f(t + dt * 3 / 10,
           tuple(_y + 3 / 40 * dt * _k1 + 9.0 / 40.0 * dt * _k2
                 for _y, _k1, _k2 in zip(u0, k1, k2)))
    k4 = f(t + dt * 4. / 5.,
           tuple(_y + 44. / 45. * dt * _k1 - 56. / 15. * dt * _k2 + 32. / 9. * dt * _k3
                 for _y, _k1, _k2, _k3 in zip(u0, k1, k2, k3)))
    k5 = f(t + dt * 8. / 9.,
           tuple(_y + 19372. / 6561. * dt * _k1 - 25360. / 2187. * dt * _k2
                 + 64448. / 6561. * dt * _k3 - 212. / 729. * dt * _k4
                 for _y, _k1, _k2, _k3, _k4 in zip(u0, k1, k2, k3, k4)))
    k6 = f(t + dt,
           tuple(_y + 9017. / 3168. * dt * _k1 - 355. / 33. * dt * _k2
                 + 46732. / 5247. * dt * _k3 + 49. / 176. * dt * _k4
                 - 5103. / 18656. * dt * _k5
                 for _y, _k1, _k2, _k3, _k4, _k5 in zip(u0, k1, k2, k3, k4, k5)))

    # 5th-order combination (the k2 weight is zero).
    out2 = tuple(_y + 35. / 384. * dt * _k1 + 500. / 1113. * dt * _k3
                 + 125. / 192. * dt * _k4 - 2187. / 6784. * dt * _k5
                 + 11. / 84. * dt * _k6
                 for _y, _k1, _k3, _k4, _k5, _k6 in zip(u0, k1, k3, k4, k5, k6))

    # Compare every position component of the ALF result with the reference.
    error = tuple(_y1 - _y2 for _y1, _y2 in zip(out[:n_pos], out2))
    return out, error, variables
    
    
 
def Yoshida_inverse(f, t, h, u):
    """Undo one Yoshida step: recover the state at ``t - h`` from ``u`` at ``t``.

    The forward sub-step fractions ``(a, a, b, b, a, a)`` form a palindrome,
    so walking backwards from ``t`` uses the same sequence, with each sub-step
    undone by ``sym12async_inverse``.

    Args:
        f: the callable used by the forward step.
        t: end time of the forward step.
        h: step size of the forward step.
        u: tuple ``(positions..., velocities...)`` at time ``t``.

    Returns:
        The ``(out, variables)`` pair produced by the last inverse sub-step.
    """
    a = 1/(2 - 2**(1/3))
    b = 1 - 2*a

    result = None
    state, t_end = u, t
    for frac in (a, a, b, b, a, a):
        sub_dt = (frac*h)/2
        result = sym12async_inverse(f, t_end, sub_dt, state)
        state = result[0]
        t_end = t_end - sub_dt
    return result

class Sym12Async(AdaptiveGridSolver):
    """Adaptive-step solver that takes a single ALF step per step call."""

    order = 1

    def step(self, func, t, dt, y, return_variables=False):
        """Take one ALF step with an error estimate.

        Returns ``(out, error, variables)`` if ``return_variables`` is True,
        otherwise ``(out, error)``.
        """
        new_state, err, aux = sym12async_forward(func, t, dt, y, adaptive=True)
        if not return_variables:
            return new_state, err
        return new_state, err, aux

    def inverse_async(self, func, t1, dt, y):
        """Recover the state at ``t1 - dt`` from ``y`` at ``t1``."""
        return sym12async_inverse(func, t1, dt, y)

class FixedStep_Sym12Async(FixedGridSolver):
    """Fixed-step solver that takes a single ALF step per step call."""

    order = 1

    def step(self, func, t, dt, y, return_variables=False):
        """Take one ALF step without an error estimate (error is None).

        Returns ``(out, error, variables)`` if ``return_variables`` is True,
        otherwise ``(out, error)``.
        """
        new_state, err, aux = sym12async_forward(func, t, dt, y, adaptive=False)
        if not return_variables:
            return new_state, err
        return new_state, err, aux

    def inverse_async(self, func, t1, dt, y):
        """Recover the state at ``t1 - dt`` from ``y`` at ``t1``."""
        return sym12async_inverse(func, t1, dt, y)
    
    
class Yoshida_ALF2(AdaptiveGridSolver):
    """Adaptive-step solver using the Yoshida composition of ALF steps."""

    order = 4

    def step(self, func, t, dt, y, return_variables=False):
        """Take one Yoshida step with an error estimate.

        Returns ``(out, error, variables)`` if ``return_variables`` is True,
        otherwise ``(out, error)``.
        """
        new_state, err, aux = Yoshida2(func, t, dt, y, adaptive=True)
        if not return_variables:
            return new_state, err
        return new_state, err, aux

    def inverse_async(self, func, t1, dt, y):
        """Recover the state at ``t1 - dt`` from ``y`` at ``t1``."""
        return Yoshida_inverse(func, t1, dt, y)

class FixedStep_Yoshida_ALF2(FixedGridSolver):
    """Fixed-step solver using the Yoshida composition of ALF steps."""

    order = 4

    def step(self, func, t, dt, y, return_variables=False):
        """Take one Yoshida step without an error estimate (error is None).

        Returns ``(out, error, variables)`` if ``return_variables`` is True,
        otherwise ``(out, error)``.
        """
        new_state, err, aux = Yoshida2(func, t, dt, y, adaptive=False)
        if not return_variables:
            return new_state, err
        return new_state, err, aux

    def inverse_async(self, func, t1, dt, y):
        """Recover the state at ``t1 - dt`` from ``y`` at ``t1``."""
        return Yoshida_inverse(func, t1, dt, y)
    
    

def Suzuki1(f, t, h, u, adaptive = False):
    """Advance ``u`` by ``h`` with Suzuki's 5-stage composition of ALF steps.

    Ten ALF sub-steps (``sym12async_forward``) are applied with sizes
    ``(a, a, a, a, b, b, a, a, a, a) * h / 2``, where
    ``a = 1 / (4 - 4**(1/3))`` and ``b = 1 - 4*a``.

    Args:
        f: callable ``f(time, positions)`` used by each ALF sub-step.
        t: start time.
        h: step size.
        u: tuple ``(positions..., velocities...)``.
        adaptive: accepted for signature compatibility; it has no effect
            (sub-steps are always taken non-adaptively, so the error is None).

    Returns:
        The ``(out, error, variables)`` triple produced by the last sub-step.
    """
    a = 1/(4 - 4**(1/3))
    b = 1 - 4*a

    result = None
    state, t_cur = u, t
    for frac in (a, a, a, a, b, b, a, a, a, a):
        sub_dt = (frac*h)/2
        result = sym12async_forward(f, t_cur, sub_dt, state)
        state = result[0]
        t_cur = t_cur + sub_dt
    return result
def Suzuki2(f, t, h, u, adaptive = False):
    """Advance ``u`` by ``h`` with Suzuki's 5-stage composition of ALF steps.

    Ten ALF sub-steps (``sym12async_forward``) are applied with sizes
    ``(a, a, a, a, b, b, a, a, a, a) * h / 2``, where
    ``a = 1 / (4 - 4**(1/3))`` and ``b = 1 - 4*a``.

    When ``adaptive`` is True, a reference value for the position components
    is computed with one Dormand-Prince (5th-order weights) step of size ``h``
    on ``dy/dt = f(t, y)``, starting from the input positions. The error is
    the difference between the ALF positions and that reference, for every
    position component.

    Args:
        f: callable ``f(time, positions)`` returning one entry per position
            component.
        t: start time.
        h: step size.
        u: tuple ``(positions..., velocities...)``; assumed to have even length.
        adaptive: whether to compute the error estimate.

    Returns:
        ``(out, error, variables)``:

        - ``out`` is the state after the step.
        - ``error`` is a tuple with one entry per position component, or ``None``.
        - ``variables`` is the auxiliary list returned by the final ALF sub-step.
          NOTE(review): sub-steps run non-adaptively, so its first entry is v1
          rather than vt1. Confirm this is what the adjoint expects.
    """
    n_pos = len(u) // 2
    u0 = u[:n_pos]

    # Parenthesised cube root: ``4**1/3`` would parse as (4**1)/3.
    a = 1/(4 - 4**(1/3))
    b = 1 - 4*a

    out, t_cur, variables = u, t, None
    for frac in (a, a, a, a, b, b, a, a, a, a):
        sub_dt = (frac*h)/2
        out, _, variables = sym12async_forward(f, t_cur, sub_dt, out)
        t_cur = t_cur + sub_dt

    if not adaptive:
        return out, None, variables

    # Dormand-Prince reference step for the position components.
    dt = h
    k1 = f(t, u0)
    k2 = f(t + dt / 5,
           tuple(_y + 1 / 5 * dt * _k1 for _y, _k1 in zip(u0, k1)))
    k3 = f(t + dt * 3 / 10,
           tuple(_y + 3 / 40 * dt * _k1 + 9.0 / 40.0 * dt * _k2
                 for _y, _k1, _k2 in zip(u0, k1, k2)))
    k4 = f(t + dt * 4. / 5.,
           tuple(_y + 44. / 45. * dt * _k1 - 56. / 15. * dt * _k2 + 32. / 9. * dt * _k3
                 for _y, _k1, _k2, _k3 in zip(u0, k1, k2, k3)))
    k5 = f(t + dt * 8. / 9.,
           tuple(_y + 19372. / 6561. * dt * _k1 - 25360. / 2187. * dt * _k2
                 + 64448. / 6561. * dt * _k3 - 212. / 729. * dt * _k4
                 for _y, _k1, _k2, _k3, _k4 in zip(u0, k1, k2, k3, k4)))
    k6 = f(t + dt,
           tuple(_y + 9017. / 3168. * dt * _k1 - 355. / 33. * dt * _k2
                 + 46732. / 5247. * dt * _k3 + 49. / 176. * dt * _k4
                 - 5103. / 18656. * dt * _k5
                 for _y, _k1, _k2, _k3, _k4, _k5 in zip(u0, k1, k2, k3, k4, k5)))

    # 5th-order combination (the k2 weight is zero).
    out2 = tuple(_y + 35. / 384. * dt * _k1 + 500. / 1113. * dt * _k3
                 + 125. / 192. * dt * _k4 - 2187. / 6784. * dt * _k5
                 + 11. / 84. * dt * _k6
                 for _y, _k1, _k3, _k4, _k5, _k6 in zip(u0, k1, k3, k4, k5, k6))

    # Pair state components (not tensor rows) with the reference components.
    error = tuple(_y1 - _y2 for _y1, _y2 in zip(out[:n_pos], out2))
    return out, error, variables
    
def Suzuki_inverse(f, t, h, u):
    """Undo one Suzuki step: recover the state at ``t - h`` from ``u`` at ``t``.

    The forward sub-step fractions ``(a, a, a, a, b, b, a, a, a, a)`` form a
    palindrome, so walking backwards from ``t`` uses the same sequence, with
    each sub-step undone by ``sym12async_inverse``. The coefficients match the
    forward step: ``a = 1 / (4 - 4**(1/3))`` and ``b = 1 - 4*a``, so the
    sub-steps sum to exactly ``h``.

    Args:
        f: the callable used by the forward step.
        t: end time of the forward step.
        h: step size of the forward step.
        u: tuple ``(positions..., velocities...)`` at time ``t``.

    Returns:
        The ``(out, variables)`` pair produced by the last inverse sub-step.
    """
    # Parenthesised cube root: ``4**1/3`` would parse as (4**1)/3.
    a = 1/(4 - 4**(1/3))
    b = 1 - 4*a

    result = None
    state, t_end = u, t
    # The two middle sub-steps use ``b``, mirroring the forward composition.
    for frac in (a, a, a, a, b, b, a, a, a, a):
        sub_dt = (frac*h)/2
        result = sym12async_inverse(f, t_end, sub_dt, state)
        state = result[0]
        t_end = t_end - sub_dt
    return result
    
class Suzuki_ALF2(AdaptiveGridSolver):
    """Adaptive-step solver using Suzuki's composition of ALF steps."""

    order = 4

    def step(self, func, t, dt, y, return_variables=False):
        """Take one Suzuki step with an error estimate.

        Returns ``(out, error, variables)`` if ``return_variables`` is True,
        otherwise ``(out, error)``.
        """
        new_state, err, aux = Suzuki2(func, t, dt, y, adaptive=True)
        if not return_variables:
            return new_state, err
        return new_state, err, aux

    def inverse_async(self, func, t1, dt, y):
        """Recover the state at ``t1 - dt`` from ``y`` at ``t1``."""
        return Suzuki_inverse(func, t1, dt, y)

class FixedStep_Suzuki_ALF2(FixedGridSolver):
    """Fixed-step solver using Suzuki's composition of ALF steps."""

    order = 4

    def step(self, func, t, dt, y, return_variables=False):
        """Take one Suzuki step without an error estimate (error is None).

        Returns ``(out, error, variables)`` if ``return_variables`` is True,
        otherwise ``(out, error)``.
        """
        new_state, err, aux = Suzuki1(func, t, dt, y, adaptive=False)
        if not return_variables:
            return new_state, err
        return new_state, err, aux

    def inverse_async(self, func, t1, dt, y):
        """Recover the state at ``t1 - dt`` from ``y`` at ``t1``."""
        return Suzuki_inverse(func, t1, dt, y)