import argparse

import torch
from torch.nn import Softplus

from pina import Condition, CartesianDomain, Plotter, PINN
from pina.model import FeedForward
from pina.operators import grad
from pina.problem import SpatialProblem


class FirstOrderODE(SpatialProblem):
    """First-order ODE ``dy/dx + y = x`` on ``x`` in [0, 5] with ``y(0) = 1``.

    ``ode`` and ``fixed`` are referenced by name inside the class body (they
    are handed to ``Condition`` below), so they are written as plain
    two-argument functions ``(input_, output_)`` rather than instance methods.
    Each returns a residual that is zero when the condition is satisfied.
    """

    # Interval of the independent variable x; its lower end is also the point
    # where the boundary condition is imposed.
    x_rng = [0, 5]
    output_variables = ['y']
    spatial_domain = CartesianDomain({'x': x_rng})

    def ode(input_, output_):
        """Return the ODE residual ``grad(y, x) + y - x``.

        :param input_: values of the independent variable ``x``.
        :param output_: model prediction for ``y`` at ``input_``.
        :return: residual of the differential equation.
        """
        return grad(output_, input_) + output_ - input_

    def fixed(input_, output_):
        """Return the boundary residual ``y - 1``.

        :param input_: values of ``x`` (unused by this condition).
        :param output_: model prediction for ``y`` at ``input_``.
        :return: difference between the prediction and the expected value.
        """
        expected_value = 1.
        return output_ - expected_value

    def solution(self, input_):
        """Return the closed-form solution ``x - 1 + 2 * exp(-x)``.

        :param input_: values of the independent variable ``x``.
        :return: exact ``y`` evaluated at ``input_``.
        """
        return input_ - 1.0 + 2*torch.exp(-input_)

    conditions = {
        # Boundary condition at the left end of the interval.
        'bc': Condition(CartesianDomain({'x': x_rng[0]}), fixed),
        # ODE residual over the whole interval.
        'dd': Condition(CartesianDomain({'x': x_rng}), ode),
    }
    truth_solution = solution


if __name__ == "__main__":

    # Command line: exactly one of -s / -l must be given, followed by an
    # integer run id. NOTE(review): id_run is parsed but not used below.
    cli = argparse.ArgumentParser(description="Run PINA")
    mode = cli.add_mutually_exclusive_group(required=True)
    mode.add_argument("-s", "-save", action="store_true")
    mode.add_argument("-l", "-load", action="store_true")
    cli.add_argument("id_run", help="number of run", type=int)
    args = cli.parse_args()

    # Build the problem, the network approximating y(x), and the solver.
    ode_problem = FirstOrderODE()
    network = FeedForward(
        layers=[4, 4],
        output_variables=ode_problem.output_variables,
        input_variables=ode_problem.input_variables,
        func=Softplus,
    )
    solver = PINN(
        ode_problem, network, lr=0.03, error_norm='mse', regularizer=0)

    if not args.s:
        # -l: restore a previously saved state and plot the 'y' component.
        solver.load_state('pina.ode')
        Plotter().plot(solver, components=['y'])

    else:
        # -s: sample points for each condition (one grid point for 'bc',
        # thirty for 'dd'), show them, train, then save the state.
        for location, n_points in (('bc', 1), ('dd', 30)):
            solver.span_pts(
                {'variables': ['x'], 'mode': 'grid', 'n': n_points},
                locations=[location])
        Plotter().plot_samples(solver, ['x'])
        # presumably (number of epochs, logging frequency) — TODO confirm
        solver.train(1200, 50)
        solver.save_state('pina.ode')
