import torch
from scipy.integrate import solve_ivp
import numpy as np

def simple_cosine(t, frequency, phase=0., amplitude=1.):
    r"""Sample a cosine wave at times ``t``.

    :math:`y(t) = A \cos(2\pi f t + \phi)`

    where :math:`A` is ``amplitude``, :math:`f` is ``frequency`` (cycles per
    unit of ``t``) and :math:`\phi` is ``phase`` (radians).

    For discrete (integer) ``t``, frequencies ``f`` and ``f + k`` (integer
    ``k``) produce identical samples, which is why ``frequency`` is meant to
    lie in (0, 1).

    Args:
        t: sample times; a tensor, or anything ``torch.as_tensor`` accepts
            (Python scalar, list, NumPy array). Tensors are used as-is.
        frequency: cycles per unit of ``t``.
        phase: phase offset in radians.
        amplitude: peak amplitude ``A``.

    Returns:
        torch.Tensor of cosine samples, broadcast over the inputs.
    """
    # as_tensor returns tensors unchanged (autograd preserved) and lets
    # scalars/lists/arrays be passed, which torch.cos would otherwise reject.
    t = torch.as_tensor(t)
    return amplitude * torch.cos(2 * torch.pi * frequency * t + phase)


def rectified_cosine(t, frequency, phase=0., amplitude=1.):
    r"""Half-wave rectified cosine: :math:`\max(0, A \cos(2\pi f t + \phi))`.

    Args:
        t: sample times; a tensor, or anything ``torch.as_tensor`` accepts
            (Python scalar, list, NumPy array). Tensors are used as-is.
        frequency: cycles per unit of ``t``.
        phase: phase offset in radians.
        amplitude: amplitude ``A`` before rectification.

    Returns:
        torch.Tensor equal to the cosine where it is positive and zero
        elsewhere.
    """
    # as_tensor returns tensors unchanged (autograd preserved) and lets
    # scalars/lists/arrays be passed, which torch.cos would otherwise reject.
    t = torch.as_tensor(t)
    y = amplitude * torch.cos(2 * torch.pi * frequency * t + phase)
    # Multiplying by the boolean mask (instead of clamp) keeps the gradient
    # exactly zero at y == 0, matching the original behavior.
    return y * (y > 0)


def fun_roessler(t, s, a, b, c):
    """Right-hand side of the Rössler system in ``solve_ivp``'s ``fun(t, y, *args)`` form.

        dx/dt = -y - z
        dy/dt = x + a*y
        dz/dt = b + z*(x - c)

    ``t`` is not used (the system is autonomous) but is required by the
    solver's calling convention. ``s`` is unpacked as the state ``(x, y, z)``.
    Returns the three derivatives as a list.
    """
    px, py, pz = s
    dx = -py - pz
    dy = px + a * py
    dz = b + pz * (px - c)
    return [dx, dy, dz]


def roessler_system(total_time, time_step, ic=None, roessler_params=(0.2, 0.2, 5.7),
                    **solver_kwargs):
    """Integrate the Rössler system and sample it on an evenly spaced grid.

    Args:
        total_time: end of the sampling window; samples cover ``[0, total_time]``.
        time_step: target sample spacing. ``floor(total_time / time_step) + 1``
            points are spread evenly over ``[0, total_time]`` with
            ``np.linspace``, so the spacing equals ``time_step`` exactly only
            when ``total_time`` is a multiple of it.
        ic: initial state ``(x, y, z)`` as a tensor or array-like; defaults to
            zeros. It is imposed at ``t = -time_step`` (the start of the
            integration span), not at ``t = 0``.
        roessler_params: ``(a, b, c)`` forwarded to ``fun_roessler``.
        **solver_kwargs: extra keyword arguments for
            ``scipy.integrate.solve_ivp`` (e.g. ``method``, ``rtol``, ``atol``).

    Returns:
        ``(t, states)``: ``t`` is the NumPy array of sample times and
        ``states`` is a float32 tensor of shape ``(3, n_pts)``, scaled down by 10.

    Raises:
        RuntimeError: if the integrator reports failure. Without this check,
            the result would silently contain fewer than ``n_pts`` samples.
    """
    if ic is None:
        # Build a fresh default per call instead of sharing one tensor object.
        ic = torch.zeros(3)
    ratio = total_time / time_step
    # Snap ratios that are integral up to floating-point error (e.g.
    # 0.3 / 0.1 == 2.9999999999999996) instead of letting int() truncate them
    # to the step below, which would silently widen the sample spacing.
    nearest = int(round(ratio))
    n_steps = nearest if np.isclose(ratio, nearest, rtol=1e-9, atol=0.0) else int(ratio)
    n_pts = n_steps + 1
    t_eval = np.linspace(0, total_time, n_pts)
    # Accept lists/arrays, tensors requiring grad, or tensors on any device.
    y0 = torch.as_tensor(ic).detach().cpu().numpy()
    # NOTE(review): the span starts at -time_step, so `ic` is the state one
    # step before the first returned sample. This is kept for backward
    # compatibility; confirm it is intended.
    sol = solve_ivp(fun_roessler, [-time_step, total_time + time_step], y0,
                    t_eval=t_eval, args=roessler_params, **solver_kwargs)
    if not sol.success:
        raise RuntimeError(f"Rössler integration failed: {sol.message}")
    return sol.t, torch.from_numpy(sol.y).float() / 10  # shape = (3, n_pts) | division by 10 to adjust range


def fun_vanderpol(t, s, mu):
    """Right-hand side of the Van der Pol oscillator in ``solve_ivp``'s ``fun(t, y, *args)`` form.

        dx/dt = v
        dv/dt = mu*(1 - x**2)*v - x

    ``t`` is not used (the system is autonomous) but is required by the
    solver's calling convention. ``s`` is unpacked as ``(x, v)``.
    Returns the two derivatives as a list.
    """
    pos, vel = s
    damping = mu * (1 - pos**2)
    return [vel, damping * vel - pos]


def vanderpol_system(total_time, time_step, ic=None, vanderPol_params=(2,),
                     **solver_kwargs):
    """Integrate the Van der Pol oscillator and sample it on an evenly spaced grid.

    Args:
        total_time: end of the sampling window; samples cover ``[0, total_time]``.
        time_step: target sample spacing. ``floor(total_time / time_step) + 1``
            points are spread evenly over ``[0, total_time]`` with
            ``np.linspace``, so the spacing equals ``time_step`` exactly only
            when ``total_time`` is a multiple of it.
        ic: initial state ``(x, v)`` as a tensor or array-like; defaults to
            ``(0, 1)``. It is imposed at ``t = -time_step`` (the start of the
            integration span), not at ``t = 0``.
        vanderPol_params: ``(mu,)`` forwarded to ``fun_vanderpol``.
        **solver_kwargs: extra keyword arguments for
            ``scipy.integrate.solve_ivp`` (e.g. ``method``, ``rtol``, ``atol``).

    Returns:
        ``(t, states)``: ``t`` is the NumPy array of sample times and
        ``states`` is a float32 tensor of shape ``(2, n_pts)``, scaled down by 10.

    Raises:
        RuntimeError: if the integrator reports failure. Without this check,
            the result would silently contain fewer than ``n_pts`` samples.
    """
    if ic is None:
        # Build a fresh default per call instead of sharing one tensor object.
        ic = torch.tensor([0, 1])
    ratio = total_time / time_step
    # Snap ratios that are integral up to floating-point error (e.g.
    # 0.3 / 0.1 == 2.9999999999999996) instead of letting int() truncate them
    # to the step below, which would silently widen the sample spacing.
    nearest = int(round(ratio))
    n_steps = nearest if np.isclose(ratio, nearest, rtol=1e-9, atol=0.0) else int(ratio)
    n_pts = n_steps + 1
    t_eval = np.linspace(0, total_time, n_pts)
    # Accept lists/arrays, tensors requiring grad, or tensors on any device.
    y0 = torch.as_tensor(ic).detach().cpu().numpy()
    # NOTE(review): the span starts at -time_step, so `ic` is the state one
    # step before the first returned sample. This is kept for backward
    # compatibility; confirm it is intended.
    sol = solve_ivp(fun_vanderpol, [-time_step, total_time + time_step], y0,
                    t_eval=t_eval, args=vanderPol_params, **solver_kwargs)
    if not sol.success:
        raise RuntimeError(f"Van der Pol integration failed: {sol.message}")
    return sol.t, torch.from_numpy(sol.y).float() / 10  # shape = (2, n_pts) | scaled down by 10