from .base import ODESolver
import torch
import numpy as np
from torch import nn
# Public solver classes exported by this module.
__all__ = [
    'Euler', 'RK2', 'RK4',
    'H2', 'Verlet', 'H4', 'H3', 'H2explicit',
    'Yoshida4', 'Yoshida6',
]

class FixedGridSolver(ODESolver):
    """Base class for solvers that march over a uniform time grid of spacing ``h``.

    Subclasses implement :meth:`step`. :meth:`integrate` builds the grid and
    hands it to ``integrate_predefined_grids``, which is not defined here
    (presumably provided by ``ODESolver``).
    """

    def __init__(self, func, t0, y0, t1=1.0, h=0.1, rtol=1e-3, atol=1e-6, neval_max=500000,
                 print_neval=False, print_direction=False, step_dif_ratio=1e-3, safety=0.9,
                 regenerate_graph=False, dense_output=True, interpolation_method='cubic',
                 end_point_mode=False, print_time=False):
        """Create a fixed-grid solver.

        All arguments are forwarded to ``ODESolver.__init__``. ``h`` is
        mandatory. When it is not already a tensor, it is converted to a
        float tensor on the device of ``y0[0]``.

        Raises:
            ValueError: if ``h`` is None.
        """
        # A fixed-grid solver cannot choose its own step size. Fail fast with a
        # clear message instead of crashing later in torch.tensor(None).
        if h is None:
            raise ValueError('Stepsize h is required for fixed grid solvers')

        super(FixedGridSolver, self).__init__(
            func=func, t0=t0, y0=y0, t1=t1, h=h, rtol=rtol,
            atol=atol, neval_max=neval_max,
            print_neval=print_neval, print_direction=print_direction,
            step_dif_ratio=step_dif_ratio, safety=safety,
            regenerate_graph=regenerate_graph, dense_output=dense_output,
            interpolation_method=interpolation_method,
            end_point_mode=end_point_mode, print_time=print_time)

        if not isinstance(h, torch.Tensor):
            h = torch.tensor(h).to(y0[0].device).float()
        self.h = h

        # Number of uniform steps covering |t1 - t0|. self.t0 / self.t1 are not
        # assigned in this class (presumably set by ODESolver). Taking abs() of
        # h keeps Nt non-negative for a negative step size, matching
        # integrate(), which already steps by |h| * time_direction.
        # NOTE: because of rounding, t0 + Nt*|h| may differ from t1 when the
        # span is not an integer multiple of h.
        self.Nt = round(abs(self.t1.item() - self.t0.item()) / abs(self.h.item()))

    def step(self, *args, **kwargs):
        """Advance the state by one step; concrete solvers must override this."""
        raise NotImplementedError(
            '{} must implement step()'.format(type(self).__name__))

    def integrate(self, y0, t0, predefine_steps=None, return_steps=False, t_eval=None):
        """Integrate over a fixed time grid.

        Args:
            y0: initial state, forwarded to ``integrate_predefined_grids``.
            t0: initial time, forwarded to ``integrate_predefined_grids``.
                NOTE(review): the default grid is built from ``self.t0``,
                not from this argument; confirm the two always agree.
            predefine_steps: optional grid of time points. When None, the grid
                ``self.t0 + k * |h| * time_direction`` for ``k = 1..Nt`` is
                used, flattened to a 1-D float tensor.
            return_steps: if True, also return the grid that was used.
            t_eval: forwarded to ``integrate_predefined_grids``.

        Returns:
            The result of ``integrate_predefined_grids``, or
            ``(result, steps)`` when ``return_steps`` is True.
        """
        if predefine_steps is None:
            # Each point is computed from t0 directly, so errors do not accumulate.
            steps = [self.t0 + (n + 1) * torch.abs(self.h) * self.time_direction for n in range(self.Nt)]
            steps = torch.stack(steps).view(-1).float()
        else:
            steps = predefine_steps

        out = self.integrate_predefined_grids(y0, t0, predefine_steps=steps, t_eval=t_eval)

        return (out, steps) if return_steps else out

class Euler(FixedGridSolver):
    """Explicit (forward) Euler method: ``y_next = y + dt * func(t, y)``."""

    order = 1

    def step(self, func, t, dt, y, return_variables=False):
        """Take a single forward-Euler step.

        Args:
            func: right-hand side, called as ``func(t, y)``. It must yield one
                derivative per component of ``y``.
            t: current time.
            dt: step size.
            y: tuple of state components.
            return_variables: if True, also return the evaluated derivative.

        Returns:
            ``(y_next, None)``, or ``(y_next, None, slopes)`` when
            ``return_variables`` is True. The second slot is always None.
        """
        slopes = func(t, y)
        y_next = tuple(state + dt * slope for state, slope in zip(y, slopes))
        if not return_variables:
            return y_next, None
        return y_next, None, slopes

class RK2(FixedGridSolver):
    """Explicit midpoint Runge-Kutta method (second order)."""

    order = 2

    def step(self, func, t, dt, y, return_variables=False):
        """Take one midpoint-rule step.

        Args:
            func: right-hand side, called as ``func(t, y)``.
            t: current time.
            dt: step size.
            y: tuple of state components.
            return_variables: if True, also return both stage derivatives.

        Returns:
            ``(y_next, None)``, or ``(y_next, None, [k1, k2])`` when
            ``return_variables`` is True.
        """
        half_dt = 1.0 / 2.0 * dt
        t_mid = t + dt / 2.0

        k_start = func(t, y)
        y_mid = tuple(state + half_dt * slope for state, slope in zip(y, k_start))
        k_mid = func(t_mid, y_mid)

        y_next = tuple(state + dt * slope for state, slope in zip(y, k_mid))
        if return_variables:
            return y_next, None, [k_start, k_mid]
        return y_next, None

class RK4(FixedGridSolver):
    """Classical four-stage Runge-Kutta method (fourth order)."""

    order = 4

    def step(self, func, t, dt, y, return_variables=False):
        """Take one classical RK4 step.

        Args:
            func: right-hand side, called as ``func(t, y)``.
            t: current time.
            dt: step size.
            y: tuple of state components.
            return_variables: if True, also return the four stage derivatives.

        Returns:
            ``(y_next, None)``, or ``(y_next, None, [k1, k2, k3, k4])`` when
            ``return_variables`` is True.
        """
        half_dt = 1.0 / 2.0 * dt
        t_half = t + dt / 2.0

        def shifted(scale, slopes):
            # Trial state y + scale * slopes, component-wise.
            return tuple(state + scale * s for state, s in zip(y, slopes))

        k1 = func(t, y)
        k2 = func(t_half, shifted(half_dt, k1))
        k3 = func(t_half, shifted(half_dt, k2))
        k4 = func(t + dt, shifted(dt, k3))

        # Pre-scaled weights; the summation order matches y + w1 k1 + w2 k2 + w3 k3 + w4 k4.
        sixth = 1.0 / 6.0 * dt
        third = 1.0 / 3.0 * dt
        y_next = tuple(
            state + sixth * a + third * b + third * c + sixth * d
            for state, a, b, c, d in zip(y, k1, k2, k3, k4)
        )
        if return_variables:
            return y_next, None, [k1, k2, k3, k4]
        return y_next, None
class H2(FixedGridSolver):
    """First-order symplectic-Euler-type step (drift, then kick).

    The position update is scaled by the hard-coded time-dependent factor
    ``1 + sin(t)``.
    NOTE(review): this factor looks specific to one test Hamiltonian; confirm.
    """

    order = 1

    def step(self, func, t, dt, y, return_variables=False):
        """Take one drift-then-kick step.

        Args:
            func: force term, called as ``func(positions, t)``. Note the
                ``(y, t)`` argument order, unlike the Runge-Kutta solvers.
            t: current time (number or tensor).
            dt: step size.
            y: tuple whose first half is treated as positions and second half
                as velocities.
            return_variables: if True, also return the evaluated force.

        Returns:
            ``(y_next, None)``, or ``(y_next, None, force)`` when
            ``return_variables`` is True.
        """
        half = len(y) // 2
        pos, vel = y[:half], y[half:]

        # np.sin cannot consume CUDA tensors or tensors that require grad, so
        # use torch.sin for tensors and keep np.sin for plain numbers.
        factor = 1 + (torch.sin(t) if torch.is_tensor(t) else np.sin(t))

        # Drift with the old velocities, then kick using the force at the new positions.
        new_pos = tuple(p + dt * factor * v for p, v in zip(pos, vel))
        force = func(new_pos, t)
        new_vel = tuple(v + dt * f for f, v in zip(force, vel))

        out = tuple(new_pos) + tuple(new_vel)
        if return_variables:
            return out, None, force
        return out, None
        
class Verlet(FixedGridSolver):
    """Drift-kick-drift (position) Verlet step, second order."""

    order = 2

    def step(self, func, t, dt, y, return_variables=False):
        """Take one drift(dt/2) - kick(dt) - drift(dt/2) step.

        Args:
            func: force term, called as ``func(positions)`` with no time
                argument.
                NOTE(review): H2/H3/H4 call ``func(positions, t)`` instead;
                confirm the intended signature.
            t: current time (unused by this step).
            dt: step size.
            y: tuple whose first half is treated as positions and second half
                as velocities.
            return_variables: if True, also return the evaluated force.

        Returns:
            ``(y_next, None)``, or ``(y_next, None, force)`` when
            ``return_variables`` is True.
        """
        split = len(y) // 2
        positions, velocities = y[:split], y[split:]
        half_dt = 1.0 / 2.0 * dt

        def drift(pos, vel):
            # Half-step position update.
            return tuple(p + half_dt * v for p, v in zip(pos, vel))

        mid_positions = drift(positions, velocities)
        forces = func(mid_positions)
        new_velocities = tuple(v + dt * a for a, v in zip(forces, velocities))
        new_positions = drift(mid_positions, new_velocities)

        out = new_positions + new_velocities
        if return_variables:
            return out, None, forces
        return out, None
        
class H4(FixedGridSolver):
    """Fourth-order kick/drift composition with a time-dependent kinetic factor.

    ``P`` is the stage table, with one column per stage:

    * row 0: kick weights (velocity update); they sum to 1.
    * row 1: drift weights (position update); they sum to 1.
    * row 2: time offset (in units of dt) at which ``func`` is evaluated for
      each kick. It equals the cumulative drift weight before that stage.
    * row 3: time offset at which the factor ``1 + sin(t)`` is evaluated for
      each drift. It equals the cumulative kick weight through that stage.

    The kick/drift sequence is palindromic (the last drift weight is 0).
    All six columns are therefore required; without the last one the
    velocity weights only sum to 41/48.
    """

    order = 4

    P = np.array([[7.0/48.0, 3.0/8.0, -1.0/48.0, -1.0/48.0, 3.0/8.0, 7.0/48.0],
                  [1.0/3.0, -1.0/3.0, 1.0, -1.0/3.0, 1.0/3.0, 0.0],
                  [0, 1.0/3.0, 0.0, 1.0, 2.0/3.0, 1.0],
                  [7.0/48.0, 25.0/48.0, 0.5, 23.0/48.0, 41.0/48.0, 1.0]])
    P = torch.from_numpy(P).float()

    def step(self, func, t, dt, y, return_variables=False):
        """Take one step of the composition.

        ``self.P`` (not ``H4.P``) is read so that classes reusing this method
        through ``H4.step(self, ...)`` supply their own table.

        Args:
            func: force term, called as ``func(positions, time)``.
            t: current time.
            dt: step size.
            y: tuple whose first half is treated as positions and second half
                as velocities.
            return_variables: if True, also return the last stage's force.

        Returns:
            ``(y_next, None)``, or ``(y_next, None, last_force)`` when
            ``return_variables`` is True.
        """
        half = len(y) // 2
        pos, vel = y[:half], y[half:]
        kick_w, drift_w, kick_t, drift_t = self.P

        # Run every column of the table. Stopping one short drops the final
        # kick, leaving the method inconsistent and no longer symmetric.
        for i in range(self.P.shape[1]):
            force = func(pos, t + dt * kick_t[i])
            # The argument is always a tensor here (P is a tensor) and may live
            # on the GPU or require grad, where np.sin would fail.
            factor = 1 + torch.sin(t + dt * drift_t[i])
            vel = tuple(v + dt * kick_w[i] * f for v, f in zip(vel, force))
            pos = tuple(p + dt * factor * drift_w[i] * v for p, v in zip(pos, vel))

        out = tuple(pos) + tuple(vel)
        if return_variables:
            return out, None, force
        return out, None
        
        
class H3(FixedGridSolver):
    """Third-order kick/drift composition with a time-dependent kinetic factor.

    Rows 0 and 1 of ``P`` (kick and drift weights) are the coefficients of
    Ruth's third-order symplectic integrator.
    Row 2 is the time offset for each force evaluation, equal to the
    cumulative drift weight before the stage.
    Row 3 is the time offset for each ``1 + sin(t)`` factor, equal to the
    cumulative kick weight through the stage.
    """

    order = 3

    P = np.array([[7.0/24.0, 3.0/4.0, -1.0/24.0],
                  [2.0/3.0, -2.0/3.0, 1.0],
                  [0.0, 2.0/3.0, 0.0],
                  [7.0/24.0, 25.0/24.0, 1.0]])
    P = torch.from_numpy(P).float()

    def step(self, func, t, dt, y, return_variables=False):
        """Take one step of the three-stage composition.

        Args:
            func: force term, called as ``func(positions, time)``.
            t: current time.
            dt: step size.
            y: tuple whose first half is treated as positions and second half
                as velocities.
            return_variables: if True, also return the last stage's force.

        Returns:
            ``(y_next, None)``, or ``(y_next, None, last_force)`` when
            ``return_variables`` is True.
        """
        half = len(y) // 2
        pos, vel = y[:half], y[half:]
        kick_w, drift_w, kick_t, drift_t = self.P

        for i in range(self.P.shape[1]):
            force = func(pos, t + dt * kick_t[i])
            # Always a tensor argument (P is a tensor); torch.sin also works
            # on GPU and autograd tensors, unlike np.sin.
            factor = 1 + torch.sin(t + dt * drift_t[i])
            vel = tuple(v + dt * kick_w[i] * f for v, f in zip(vel, force))
            pos = tuple(p + dt * factor * drift_w[i] * v for p, v in zip(pos, vel))

        out = tuple(pos) + tuple(vel)
        if return_variables:
            return out, None, force
        return out, None
        
class H2explicit(FixedGridSolver):
    """Second-order kick/drift composition with a time-dependent kinetic factor.

    The stages are: kick with weight 0, drift 1/2, kick 1, drift 1/2.
    The first stage still evaluates ``func`` (its result does not change the
    velocities because its weight is 0).
    Rows of ``P`` follow the same convention as ``H3``/``H4``:
    kick weights, drift weights, force-time offsets, kinetic-factor time offsets.
    """

    order = 2

    P = np.array([[0.0, 1.0],
                  [1.0/2.0, 1.0/2.0],
                  [0.0, 1.0/2.0],
                  [0.0, 1.0]])
    P = torch.from_numpy(P).float()

    def step(self, func, t, dt, y, return_variables=False):
        """Take one step of the two-stage composition.

        Args:
            func: force term, called as ``func(positions, time)``.
            t: current time.
            dt: step size.
            y: tuple whose first half is treated as positions and second half
                as velocities.
            return_variables: if True, also return the last stage's force.

        Returns:
            ``(y_next, None)``, or ``(y_next, None, last_force)`` when
            ``return_variables`` is True.
        """
        half = len(y) // 2
        pos, vel = y[:half], y[half:]
        kick_w, drift_w, kick_t, drift_t = self.P

        for i in range(self.P.shape[1]):
            force = func(pos, t + dt * kick_t[i])
            # Always a tensor argument (P is a tensor); torch.sin also works
            # on GPU and autograd tensors, unlike np.sin.
            factor = 1 + torch.sin(t + dt * drift_t[i])
            vel = tuple(v + dt * kick_w[i] * f for v, f in zip(vel, force))
            pos = tuple(p + dt * factor * drift_w[i] * v for p, v in zip(pos, vel))

        out = tuple(pos) + tuple(vel)
        if return_variables:
            return out, None, force
        return out, None
                
        
        
class Yoshida4(FixedGridSolver):
    """Fourth-order method built from Yoshida's triple-jump composition of Verlet.

    One step applies ``Verlet.step`` with sub-steps ``a*dt, b*dt, a*dt``,
    where ``a = 1 / (2 - 2**(1/3))`` and ``b = 1 - 2*a``.
    """

    order = 4

    def step(self, f, t, dt, y, return_variables=False):
        """Take one composed step.

        Args:
            f: force term, passed through to ``Verlet.step`` as ``func``.
            t: current time.
            dt: step size.
            y: tuple of positions followed by velocities (Verlet layout).
            return_variables: forwarded to each ``Verlet.step`` call.

        Returns:
            The tuple returned by the last ``Verlet.step`` call: ``(y_next, None)``,
            or ``(y_next, None, last_force)`` when ``return_variables`` is True.
        """
        a = 1 / (2 - 2 ** (1 / 3))
        b = 1 - 2 * a

        state, t_sub = y, t
        for coeff in (a, b, a):
            # Verlet.step returns 2 or 3 items depending on return_variables.
            # Keep the whole tuple rather than unpacking into three names,
            # which would fail when return_variables is False.
            result = Verlet.step(self, func=f, t=t_sub, dt=coeff * dt, y=state,
                                 return_variables=return_variables)
            state = result[0]
            t_sub = t_sub + coeff * dt

        return result
        
class Yoshida6(FixedGridSolver):
    """Sixth-order method built from Yoshida's triple-jump composition of H4.

    One step applies ``H4.step`` with sub-steps ``a*dt, b*dt, a*dt``, where
    ``a = 1 / (2 - 2**(1/5))`` and ``b = 1 - 2*a``. The triple jump raises the
    order by two only for a symmetric base method. H4's kick/drift sequence
    is palindromic.
    """

    order = 6

    # H4.step reads ``self.P``, so this attribute is the stage table used by
    # every sub-step. Share H4's exact table instead of a hand-copied one.
    # The copy had the two time-offset rows swapped and its coefficients
    # rounded to four decimals (e.g. 0.1458 for 7/48).
    P = H4.P

    def step(self, f, t, dt, y, return_variables=False):
        """Take one composed step.

        Args:
            f: force term, passed through to ``H4.step`` as ``func``.
            t: current time.
            dt: step size.
            y: tuple of positions followed by velocities (H4 layout).
            return_variables: forwarded to each ``H4.step`` call.

        Returns:
            The tuple returned by the last ``H4.step`` call: ``(y_next, None)``,
            or ``(y_next, None, last_force)`` when ``return_variables`` is True.
        """
        a = 1 / (2 - 2 ** (1 / 5))
        b = 1 - 2 * a

        state, t_sub = y, t
        for coeff in (a, b, a):
            # H4.step returns 2 or 3 items depending on return_variables, so
            # keep the tuple whole instead of unpacking into three names.
            result = H4.step(self, func=f, t=t_sub, dt=coeff * dt, y=state,
                             return_variables=return_variables)
            state = result[0]
            t_sub = t_sub + coeff * dt

        return result