# PDE for elastic pendulum

from pde.base import PDE
from integrator.base import Solver

import torch
from torch import sin, cos

class EP(PDE):
    """Elastic (spring) pendulum expressed as a first-order ODE system.

    State vector ``[theta, r, theta_t, r_t]``: the pendulum angle ``theta``,
    the spring length ``r``, and their time derivatives. With
    ``theta = 0`` gravity acts purely along the spring, stretching it
    (the ``+ g cos(theta)`` term below).

    Equations of motion (from the Lagrangian
    ``L = 1/2 m (r_t^2 + r^2 theta_t^2) + m g r cos(theta) - 1/2 k (r - l0)^2``)::

        d/dt theta   = theta_t
        d/dt r       = r_t
        d/dt theta_t = -(g sin(theta) + 2 r_t theta_t) / r
        d/dt r_t     = r theta_t^2 - (k/m) (r - l0) + g cos(theta)

    where ``k/m = k_over_mg * g``.

    NOTE(review): earlier versions had the Coriolis term in the angular
    equation with the wrong sign and without the factor 2, and multiplied
    the spring term by ``g`` twice; both are corrected here.
    """

    def __init__(self, ic=None, g=9.801, k_over_mg=1, l0=2):
        """Store the physical parameters.

        Args:
            ic: initial condition, forwarded unchanged to the ``PDE`` base
                (presumably a 4-element ``[theta, r, theta_t, r_t]`` tensor
                -- confirm against the caller).
            g: gravitational acceleration.
            k_over_mg: spring constant divided by ``m * g``, i.e. the
                dimensionless stiffness; ``k/m = k_over_mg * g``.
            l0: natural (unstretched) spring length.
        """
        # NOTE(review): the ``None`` and ``0`` positional arguments are
        # whatever the PDE base expects (presumably "no spatial grid");
        # confirm against PDE.__init__.
        super().__init__(None, ic, 0)
        self.g = g
        self.k2mg = k_over_mg
        self.l0 = l0

    def attach_solver(self, solver: Solver):
        """Install this system's right-hand side as ``solver.rhs``.

        The installed callable ``rhs(state, t)`` unpacks ``state`` into
        ``(theta, r, theta_t, r_t)`` along its first axis and returns the
        time derivative with the same layout. Parameters are read from
        ``self`` on every call, so later changes to ``g``/``k2mg``/``l0``
        take effect without re-attaching.
        """
        def rhs(state, t):
            # ``t`` is unused: the system is autonomous. It is kept in the
            # signature because that is how the solver invokes rhs.
            phi, r, phit, rt = state
            g = self.g
            # Angular equation: d(r^2 theta_t)/dt = -g r sin(theta), which
            # produces the 2 * r_t * theta_t term. Singular at r == 0.
            v1 = -(g * sin(phi) + 2 * phit * rt) / r
            # Radial equation: centrifugal + spring + gravity along the spring.
            # k/m = k_over_mg * g, so the spring term is g * k2mg * (r - l0).
            v2 = r * phit**2 - g * (self.k2mg * (r - self.l0) - cos(phi))
            # d/dt[theta, r] = [theta_t, r_t]; append the two accelerations.
            return torch.cat((state[2:], torch.stack((v1, v2))))

        solver.rhs = rhs

            
        
        